#!/usr/bin/env python3
"""
path_profile_tool.py (combined, all-in-one)
============================================
For each microwave path in ATT_Network.csv:
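  1. Builds (or reuses) two remote elevation sources -- no tile downloads:
       - glo30_remote.vrt : VRT mosaic of Copernicus GLO-30 DSM tiles
         (global, 30m, includes terrain + tree canopy + most buildings)
       - usgs_tiles.txt (+ usgs_tiles.txt.per_path.json) : USGS 3DEP 1m
         DEM tile URLs found per path via the TNM API (US only, patchy
         lidar coverage). This is a bare-earth DEM -- it does NOT include
         canopy or buildings. Tiles can be in different UTM CRSs, so they
         are sampled individually rather than mosaicked into a VRT.
     Both point at /vsicurl/ URLs on public cloud storage; GDAL only
     fetches the small blocks needed for each sample point via HTTP
     range requests.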
  2. Generates a KML line (EPSG:4326) between Site A and Site B, with
     80 m sample-point placemarks
  3. Samples the geodesic at ~80 m intervals. For each point, tries
     USGS 1m first, falls back to GLO-30 if nodata/out-of-bounds.
  4. Saves a per-path table (CSV): X (lon), Y (lat), Z (m), source
  5. Plots the surface profile vs. distance, with a line-of-sight
     reference line and obstruction points (surface above LOS)
     highlighted, color-coded by data source

Outputs (per path, named by Ref_ID_Path_ID):
  kml/<Ref_ID>_<Path_ID>.kml
  tables/<Ref_ID>_<Path_ID>.csv
  plots/<Ref_ID>_<Path_ID>.png

Usage:
  python path_profile_tool.py ATT_Network.csv --interval 80 --out output

  # Optional:
  #   --start-id / --end-id   restrict to a Ref_ID range
  #   --overwrite             reprocess paths even if outputs exist
  #   --skip-usgs             don't query USGS 3DEP at all, use GLO-30
  #                           for everything (faster, lower-res)
  #   --glo30-vrt / --usgs-tile-list  override the default GLO-30 VRT and
  #                           USGS tile-list paths (built once and reused
  #                           on subsequent runs; --rebuild-vrt forces a
  #                           rebuild)
"""

import argparse
import csv
import json
import math
import os
import subprocess
import time
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime, timedelta

# --- GDAL/curl tuning for robust remote (/vsicurl/) COG reads ---------
# Set before importing rasterio so GDAL picks these up at driver init.
# os.environ.setdefault() only fills in values that are not already set,
# so any of these can be overridden from the calling shell.
os.environ.setdefault("GDAL_DISABLE_READDIR_ON_OPEN", "EMPTY_DIR")  # don't list bucket dirs on open
os.environ.setdefault("CPL_VSIL_CURL_ALLOWED_EXTENSIONS", ".tif,.tiff,.vrt")  # limit /vsicurl/ to these file types
os.environ.setdefault("GDAL_HTTP_MAX_RETRY", "5")        # retry transient failures
os.environ.setdefault("GDAL_HTTP_RETRY_DELAY", "2")      # seconds between retries
os.environ.setdefault("GDAL_HTTP_TIMEOUT", "30")         # per-request timeout (s)
os.environ.setdefault("GDAL_HTTP_CONNECTTIMEOUT", "10")  # connection-establishment timeout (s)
os.environ.setdefault("CPL_VSIL_CURL_CACHE_SIZE", "200000000")  # 200MB block cache
os.environ.setdefault("VSI_CACHE", "TRUE")               # enable the per-file VSI read cache
os.environ.setdefault("VSI_CACHE_SIZE", "50000000")      # 50MB per-file cache
os.environ.setdefault("GDAL_CACHEMAX", "512")            # MB, GDAL block cache

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import numpy as np
import requests
import rasterio
import pyproj
from pyproj import Geod
import simplekml

GEOD = Geod(ellps="WGS84")  # EPSG:4326 reference ellipsoid; used for geodesic distance/azimuth

# Public bucket holding the Copernicus GLO-30 COG tiles; tile URLs are built
# as <bucket>/<tile>/<tile>.tif by build_glo30_vrt.
GLO30_BUCKET = "https://copernicus-dem-30m.s3.amazonaws.com"
# USGS TNM Access API product-search endpoint, queried once per path bbox.
TNM_API = "https://tnmaccess.nationalmap.gov/api/v1/products"
# TNM dataset name that selects the 3DEP 1 m DEM product in TNM queries.
USGS_1M_DATASET = "Digital Elevation Model (DEM) 1 meter"


# ----------------------------------------------------------------------
# Remote VRT builders (no downloads -- /vsicurl/ + HTTP range requests)
# ----------------------------------------------------------------------
def glo30_tile_name(lat, lon):
    """Return the Copernicus GLO-30 COG tile stem for the 1x1 degree cell holding (lat, lon).

    Cells are named after their south-west corner, e.g. (40.5, -74.2) ->
    "Copernicus_DSM_COG_10_N40_00_W075_00_DEM".
    """
    south, west = math.floor(lat), math.floor(lon)
    lat_tag = ("S" if south < 0 else "N") + f"{abs(south):02d}"
    lon_tag = ("W" if west < 0 else "E") + f"{abs(west):03d}"
    return f"Copernicus_DSM_COG_10_{lat_tag}_00_{lon_tag}_00_DEM"


def glo30_tiles_for_bbox(lat1, lon1, lat2, lon2):
    """Return the set of GLO-30 tile stems for every 1x1 degree cell in the
    lat/lon rectangle spanned by the two corner points (inclusive)."""
    south, north = sorted((math.floor(lat1), math.floor(lat2)))
    west, east = sorted((math.floor(lon1), math.floor(lon2)))
    # Name each cell from its centre so floor() can never land on a neighbour.
    return {
        glo30_tile_name(cell_lat + 0.5, cell_lon + 0.5)
        for cell_lat in range(south, north + 1)
        for cell_lon in range(west, east + 1)
    }


def build_glo30_vrt(rows, vrt_path, force=False, max_retries=3):
    """Build (or reuse) a remote GDAL VRT over the GLO-30 tiles the paths touch.

    rows: iterable of (lat1, lon1, lat2, lon2) path endpoints; every 1x1
        degree GLO-30 tile in each path's lat/lon bbox is a candidate.
    vrt_path: output VRT. An existing file is reused unless `force` is True,
        in which case it is deleted and rebuilt.
    max_retries: HEAD attempts per candidate tile on network errors or HTTP
        5xx responses (at least one attempt is always made).

    Candidates are probed with HTTP HEAD and only tiles answering 200 are
    mosaicked; 4xx answers are treated as "tile does not exist" (e.g.
    all-ocean cells) and skipped quietly. Tiles that could not be checked
    at all are reported rather than silently dropped, and if no tile exists
    a warning is printed and no VRT is written. The gdalbuildvrt input list
    is left next to the VRT as <vrt_path>.input_list.txt.

    Raises subprocess.CalledProcessError if gdalbuildvrt fails.
    """
    if os.path.exists(vrt_path):
        if not force:
            print(f"Reusing existing {vrt_path}")
            return
        os.remove(vrt_path)

    tiles = set()
    for lat1, lon1, lat2, lon2 in rows:
        tiles |= glo30_tiles_for_bbox(lat1, lon1, lat2, lon2)

    print(f"GLO-30: {len(tiles)} candidate tiles, checking existence...")
    attempts = max(1, max_retries)
    urls, unchecked = [], []
    with requests.Session() as session:  # reuse one connection pool for all HEADs
        for t in sorted(tiles):
            url = f"{GLO30_BUCKET}/{t}/{t}.tif"
            status = None
            for attempt in range(attempts):
                try:
                    status = session.head(url, timeout=15).status_code
                except requests.RequestException:
                    status = None
                if status is not None and status < 500:
                    break  # definitive answer: 200 = exists, 4xx = absent
                if attempt < attempts - 1:
                    time.sleep(2 ** attempt)  # back off before retrying a transient failure
            if status == 200:
                urls.append("/vsicurl/" + url)
            elif status is None or status >= 500:
                unchecked.append(t)

    if unchecked:
        shown = ", ".join(unchecked[:5]) + (" ..." if len(unchecked) > 5 else "")
        print(f"GLO-30: WARNING: could not check {len(unchecked)} tile(s) after "
              f"{attempts} attempt(s); omitted from the VRT: {shown}")
    print(f"GLO-30: {len(urls)} tiles exist, building {vrt_path}")
    if not urls:
        print(f"GLO-30: WARNING: no tiles found, {vrt_path} was not built")
        return
    list_file = vrt_path + ".input_list.txt"
    with open(list_file, "w") as f:
        f.write("\n".join(urls) + "\n")
    subprocess.run(["gdalbuildvrt", "-input_file_list", list_file, vrt_path], check=True)


def usgs_tile_urls_for_bbox(bbox, buffer=0.005, max_retries=3,
                            page_size=50, retry_delay=2.0, max_pages=20):
    """Query the TNM Access API for USGS 3DEP 1 m DEM GeoTIFF URLs in a bbox.

    bbox: (x1, y1, x2, y2) in lon/lat degrees, corners in any order. It is
        normalized to west/south/east/north and padded by `buffer` degrees.
    max_retries: attempts per API page (at least one is always made).
    page_size / max_pages: results are fetched `page_size` at a time using
        `offset` until the response's `total` is reached (or a page comes
        back empty), so large bboxes are not cut off at the first page.
        If the response has no integer `total`, only one page is fetched.
    retry_delay: base back-off in seconds; it doubles after each failure.

    Returns a list of download URLs ending in .tif/.tiff (may contain
    duplicates; callers dedupe). If a page cannot be fetched after all
    retries, a warning is printed and the URLs collected so far are returned
    (an empty list if the first page failed).
    """
    west = min(bbox[0], bbox[2]) - buffer
    east = max(bbox[0], bbox[2]) + buffer
    south = min(bbox[1], bbox[3]) - buffer
    north = max(bbox[1], bbox[3]) + buffer
    base_params = {
        "datasets": USGS_1M_DATASET,
        "bbox": f"{west},{south},{east},{north}",
        "outputFormat": "JSON",
        "max": page_size,
    }
    attempts = max(1, max_retries)
    urls = []
    offset = 0
    for _ in range(max(1, max_pages)):
        # First page is requested exactly as before (no offset parameter).
        params = dict(base_params, offset=offset) if offset else base_params
        data, last_error = None, None
        for attempt in range(attempts):
            try:
                r = requests.get(TNM_API, params=params, timeout=30)
                r.raise_for_status()
                data = r.json()
                break
            except (requests.RequestException, ValueError) as e:  # HTTP/network/JSON errors
                last_error = e
                if attempt < attempts - 1:
                    time.sleep(retry_delay * (2 ** attempt))
        if not isinstance(data, dict):
            print(f"USGS 1m: WARNING: TNM query failed for bbox {base_params['bbox']} "
                  f"(offset {offset}): {last_error}")
            break
        items = data.get("items") or []
        for item in items:
            # downloadURL may be missing or null; treat both as "no URL".
            url = (item.get("downloadURL") or "") if isinstance(item, dict) else ""
            if url.lower().endswith((".tif", ".tiff")):
                urls.append(url)
        offset += len(items)
        total = data.get("total")
        if not items or not isinstance(total, int) or offset >= total:
            break
    return urls


def build_usgs_tile_list(rows, list_path, force=False):
    """Build/reuse the cached USGS 3DEP 1m tile lookup for `rows`.

    Files:
      - list_path: plain-text list of all unique /vsicurl/ USGS 1m tile URLs
        (deduplicated across all paths -- kept for informational purposes)
      - list_path + '.per_path.json': dict Path_ID2 -> list of /vsicurl/ URLs
        for just that path's bounding box (typically 1-5 tiles). This is
        what's actually used for sampling -- avoids looping over thousands
        of irrelevant URLs per query point.

    `rows` is the full list of CSV row dicts (needs Path_ID2 + coordinates).

    If both files exist and `force` is False they are reused. Any Path_ID2
    in `rows` missing from the cached mapping (e.g. the cache was built for
    a --limit subset) is queried now and both files are rewritten, so those
    paths do not silently fall back to GLO-30 only. With `force`, every row
    is queried from scratch.

    Returns (global_url_list, per_path_mapping).
    """
    per_path_path = list_path + ".per_path.json"
    per_path = {}
    to_query = rows
    if os.path.exists(list_path) and os.path.exists(per_path_path) and not force:
        with open(per_path_path) as f:
            per_path = json.load(f)
        to_query = [row for row in rows if row["Path_ID2"] not in per_path]
        if not to_query:
            print(f"Reusing existing {list_path} and {per_path_path}")
            with open(list_path) as f:
                global_urls = [line.strip() for line in f if line.strip()]
            return global_urls, per_path
        print(f"USGS 1m: {per_path_path} lacks {len(to_query)} of the requested "
              f"paths; querying just those")

    # /vsicurl/ URLs already cached for paths that are not being re-queried.
    known = {u for vsi in per_path.values() for u in vsi}
    found_urls = set()
    for i, row in enumerate(to_query, 1):
        lat1, lon1 = float(row["Site_A_Latitude"]), float(row["Site_A_Longitude"])
        # NOTE: key spelled "Site_ B_Latitude" (with a space), as in process_path.
        lat2, lon2 = float(row["Site_ B_Latitude"]), float(row["Site_B_Longitude"])
        found = usgs_tile_urls_for_bbox((lon1, lat1, lon2, lat2))
        per_path[row["Path_ID2"]] = ["/vsicurl/" + u for u in sorted(set(found))]
        found_urls.update(found)
        if i % 100 == 0:
            print(f"USGS 1m: {i}/{len(to_query)} paths queried, "
                  f"{len(found_urls)} unique tiles so far")

    vsicurl_urls = sorted(known | {"/vsicurl/" + u for u in found_urls})
    print(f"USGS 1m: {len(vsicurl_urls)} unique tiles -> {list_path}, "
          f"per-path mapping -> {per_path_path}")
    with open(list_path, "w") as f:
        f.writelines(u + "\n" for u in vsicurl_urls)
    with open(per_path_path, "w") as f:
        json.dump(per_path, f)
    return vsicurl_urls, per_path


class MultiCRSRasterSet:
    """A set of remote rasters, possibly in different CRSs (e.g. USGS 1m
    tiles in various UTM zones). For each query point (given in EPSG:4326
    lon/lat), checks each tile's bounds (transforming the point into that
    tile's CRS) and samples the first tile, in `urls` order, that contains
    the point and has a valid (non-nodata, non-NaN) value there.

    Datasets and transformers are opened lazily and cached -- only tiles
    that are actually hit get opened. A tile that fails to open, or whose
    CRS cannot be transformed to, is remembered as unusable and skipped.
    Call close() when done, or use the object as a context manager."""

    def __init__(self, urls):
        self.urls = urls
        self._ds = {}           # url -> open rasterio dataset, or False if unusable
        self._transformer = {}  # url -> pyproj Transformer EPSG:4326 -> tile CRS

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def _get(self, url):
        """Return the cached dataset for `url`, opening it on first use, or False."""
        if url not in self._ds:
            ds = None
            try:
                ds = rasterio.open(url)
                transformer = pyproj.Transformer.from_crs(
                    "EPSG:4326", ds.crs, always_xy=True)
            except Exception:
                # Best effort over remote tiles: mark unusable and move on, but
                # don't leak a dataset that opened before the CRS step failed.
                if ds is not None:
                    ds.close()
                self._ds[url] = False
            else:
                self._transformer[url] = transformer
                self._ds[url] = ds
        return self._ds[url]

    def sample(self, lon, lat):
        """Return the first valid value at (lon, lat) across the tiles, else NaN."""
        for url in self.urls:
            ds = self._get(url)
            if ds is False:
                continue
            x, y = self._transformer[url].transform(lon, lat)
            b = ds.bounds
            if not (b.left <= x <= b.right and b.bottom <= y <= b.top):
                continue
            try:
                val = float(next(ds.sample([(x, y)]))[0])
                nodata = ds.nodata
                if nodata is not None and val == nodata:
                    continue
                if math.isnan(val):
                    continue
                return val
            except Exception:
                continue
        return float("nan")

    def close(self):
        """Close every dataset opened so far and reset the caches (safe to call twice)."""
        for ds in self._ds.values():
            if ds is not False:
                ds.close()
        self._ds.clear()
        self._transformer.clear()


# ----------------------------------------------------------------------
# Geometry helpers
# ----------------------------------------------------------------------
def sample_points(lat1, lon1, lat2, lon2, interval_m):
    """Return points every `interval_m` metres along the WGS84 geodesic A -> B.

    Each point is (lon, lat, distance_from_A_m). The first point is A and
    the last is exactly B, so the final spacing may be shorter than
    `interval_m`. Coincident endpoints yield [(lon1, lat1, 0.0)].

    Raises ValueError if `interval_m` is not a positive number.
    """
    if not interval_m > 0:  # written this way so NaN is rejected too
        raise ValueError(f"interval_m must be > 0, got {interval_m!r}")

    az12, _az21, dist = GEOD.inv(lon1, lat1, lon2, lat2)
    if dist == 0:
        return [(lon1, lat1, 0.0)]

    n_intervals = max(1, math.ceil(dist / interval_m))
    points = []
    for i in range(n_intervals + 1):
        d = min(i * interval_m, dist)
        lon, lat, _ = GEOD.fwd(lon1, lat1, az12, d)
        points.append((lon, lat, d))
    # Pin the last point to B exactly (avoids forward-geodesic round-off).
    points[-1] = (lon2, lat2, dist)
    return points


# ----------------------------------------------------------------------
# Elevation lookup: primary raster with fallback raster
# ----------------------------------------------------------------------
def sample_raster(ds, coords):
    """Sample the first band of a lon/lat raster at each (lon, lat) in `coords`.

    Returns one float per coordinate, in order. The value is NaN when `ds`
    is None, the point lies outside ds.bounds (edges inclusive), the value
    equals ds.nodata or is NaN, or the read raises.
    """
    missing = float("nan")
    if ds is None:
        return [missing] * len(coords)
    nodata = ds.nodata
    box = ds.bounds

    def one(lon, lat):
        inside = box.left <= lon <= box.right and box.bottom <= lat <= box.top
        if not inside:
            return missing
        try:
            value = float(next(ds.sample([(lon, lat)]))[0])
        except Exception:
            return missing
        if (nodata is not None and value == nodata) or math.isnan(value):
            return missing
        return value

    return [one(lon, lat) for lon, lat in coords]


def get_elevations(points, primary_set, fallback_ds):
    """Look up a surface elevation for each sample point, USGS 1m first.

    points: [(lon, lat, distance_m), ...] as produced by sample_points.
    primary_set: MultiCRSRasterSet of USGS 1m tiles, or None.
    fallback_ds: rasterio dataset in EPSG:4326 lon/lat (GLO-30), or None.

    Returns (elevs, sources), parallel lists; sources[i] is one of
    'usgs_1m', 'glo30', 'none' (elevs[i] is NaN for 'none').

    The fallback raster is only sampled at points the primary set could not
    fill, so paths fully covered by USGS 1m make no GLO-30 reads.
    """
    coords = [(lon, lat) for lon, lat, _ in points]
    nan = float("nan")
    if primary_set is None:
        primary_vals = [nan] * len(coords)
    else:
        primary_vals = [primary_set.sample(lon, lat) for lon, lat in coords]

    # Remote reads dominate the cost: ask GLO-30 only about the gaps.
    gaps = [i for i, v in enumerate(primary_vals) if math.isnan(v)]
    fallback_vals = dict(zip(gaps, sample_raster(fallback_ds, [coords[i] for i in gaps])))

    elevs, sources = [], []
    for i, pv in enumerate(primary_vals):
        if not math.isnan(pv):
            elevs.append(pv)
            sources.append("usgs_1m")
            continue
        fv = fallback_vals[i]
        if not math.isnan(fv):
            elevs.append(fv)
            sources.append("glo30")
        else:
            elevs.append(nan)
            sources.append("none")
    return elevs, sources


# ----------------------------------------------------------------------
# KML
# ----------------------------------------------------------------------
def write_kml(path_id, name_a, name_b, lat1, lon1, lat2, lon2, points, elevs, sources, out_path,
              interval_m=None):
    """Write the path KML to `out_path`.

    Contents: a red, ground-clamped A-B line named `path_id`, one placemark
    per site (named by call sign), and a folder with one placemark per
    sample point labelled with its distance and surface elevation/source.

    points/elevs/sources are the parallel per-sample lists from
    sample_points/get_elevations. `interval_m` is used only for the folder
    label. When omitted, it is taken from the spacing of the first two
    sample points, so the label follows the actual --interval.
    """
    kml = simplekml.Kml()
    kml.document.name = path_id

    line = kml.newlinestring(name=path_id)
    line.coords = [(lon1, lat1, 0), (lon2, lat2, 0)]
    line.altitudemode = simplekml.AltitudeMode.clamptoground
    line.style.linestyle.width = 2
    line.style.linestyle.color = simplekml.Color.red

    kml.newpoint(name=name_a, coords=[(lon1, lat1)])
    kml.newpoint(name=name_b, coords=[(lon2, lat2)])

    if interval_m is None and len(points) > 1:
        interval_m = points[1][2] - points[0][2]
    spacing = f"{interval_m:g} m, " if interval_m is not None else ""
    folder = kml.newfolder(name=f"Sample Points ({spacing}with surface elev.)")
    for i, ((lon, lat, d), z, src) in enumerate(zip(points, elevs, sources)):
        zlabel = "n/a" if math.isnan(z) else f"{z:.1f} m ({src})"
        folder.newpoint(name=f"P{i} d={d:.0f}m z={zlabel}", coords=[(lon, lat)])

    kml.save(out_path)


# ----------------------------------------------------------------------
# Table + plot
# ----------------------------------------------------------------------
def write_table(points, elevs, sources, out_path):
    """Write the per-sample profile table to `out_path` as CSV.

    Columns: index, distance_m (2 dp), X_lon / Y_lat (8 dp),
    Z_surface_elev_m (2 dp, blank when NaN), source.
    """
    header = ["index", "distance_m", "X_lon", "Y_lat", "Z_surface_elev_m", "source"]
    records = (
        [idx, f"{dist:.2f}", f"{lon:.8f}", f"{lat:.8f}",
         f"{z:.2f}" if not math.isnan(z) else "", src]
        for idx, ((lon, lat, dist), z, src) in enumerate(zip(points, elevs, sources))
    )
    with open(out_path, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(header)
        writer.writerows(records)


def plot_profile(path_id, name_a, name_b, points, elevs, sources, out_path,
                  ant_a_main_amsl=None, ant_a_div_amsl=None,
                  ant_b_main_amsl=None, ant_b_div_amsl=None,
                  main_a=None, div_a=None, main_b=None, div_b=None,
                  frequency_ghz=6.0, k_factor=4.0 / 3.0,
                  site_name_a="", site_name_b=""):
    """Render the path-profile PNG: surface elevation (ft AMSL) vs. distance (miles).

    points: [(lon, lat, distance_m), ...]; elevs: per-point surface
    elevation in METERS (NaN = unknown); sources: per-point tag
    ('usgs_1m' | 'glo30' | 'none').
    ant_*_amsl: antenna heights AMSL in METERS (converted to feet here), or
    None. main_*/div_*: antenna dicts. Only their 'antenna_height_agl_m' is
    read, for legend labels, and only when the matching AMSL is given.

    Draws a k-factor-curved line of sight main->main, plus main->diversity
    in each direction, each with F1 and 0.6 F1 clearance lines. For the
    main LOS, an end without a main-antenna AMSL falls back to the endpoint
    surface elevation. Surface points above a LOS get a red 'x'; points
    that only intrude into F1 get a hollow orange-red circle.
    """
    m_to_ft = 3.28084
    dist_m = np.array([p[2] for p in points], dtype=float)
    dists = dist_m / 1609.344                        # meters -> miles (x axis)
    elevs = np.array(elevs, dtype=float) * m_to_ft   # meters -> feet
    sources = np.array(sources)

    fig = plt.figure(figsize=(11, 6.5))
    ax = fig.add_axes((0.08, 0.34, 0.88, 0.58))  # leave room below for call signs + legend
    valid = ~np.isnan(elevs)
    if valid.any():
        ax.fill_between(dists, np.nanmin(elevs) - 5, elevs, color="tab:green", alpha=0.3)
    ax.plot(dists, elevs, "-", color="tab:brown", lw=1, label="Surface (terrain+obstacles)")

    m1 = sources == "usgs_1m"
    m2 = sources == "glo30"
    if m1.any():
        # The TNM "DEM 1 meter" product is a bare-earth DEM (no canopy or
        # buildings), unlike the GLO-30 DSM -- make that visible in the legend.
        ax.scatter(dists[m1], elevs[m1], s=8, color="saddlebrown",
                   label="USGS 3DEP 1m (bare earth)")
    if m2.any():
        ax.scatter(dists[m2], elevs[m2], s=8, color="darkgoldenrod", label="Copernicus GLO-30")

    # Antenna AMSL values arrive in meters; the plot is in feet.
    def to_ft(v):
        return v * m_to_ft if v is not None else None
    ant_a_main_amsl = to_ft(ant_a_main_amsl)
    ant_a_div_amsl = to_ft(ant_a_div_amsl)
    ant_b_main_amsl = to_ft(ant_b_main_amsl)
    ant_b_div_amsl = to_ft(ant_b_div_amsl)

    # Earth-curvature bulge (k-adjusted) and first Fresnel zone radius, in
    # meters from distances in km, then converted to feet:
    #   bulge = d1*d2 / (12.74*k)      F1 = 17.32*sqrt(d1*d2 / (f_GHz*D))
    # km come straight from the sampled meters (not via miles and a
    # truncated mile->km factor).
    d1_km = dist_m / 1000.0
    D_km = d1_km[-1]
    d2_km = D_km - d1_km
    if D_km > 0 and k_factor > 0:
        bulge_ft = (d1_km * d2_km) / (12.74 * k_factor) * m_to_ft
    else:
        bulge_ft = np.zeros_like(dists)
    if D_km > 0 and frequency_ghz > 0:
        f1_ft = 17.32 * np.sqrt(np.clip(d1_km * d2_km, 0, None) / (frequency_ghz * D_km)) * m_to_ft
    else:
        f1_ft = np.zeros_like(dists)

    def plot_los_with_fresnel(y0, y1, los_color, f1_color, f06_color,
                                los_label, f1_label, f06_label,
                                obstruction_label, f1_violation_label):
        los_straight = np.interp(dists, [dists[0], dists[-1]], [y0, y1])
        los_curve = los_straight - bulge_ft
        los_f1 = los_curve - f1_ft
        los_06f1 = los_curve - 0.6 * f1_ft

        ax.plot(dists, los_curve, "--", color=los_color, label=los_label)
        ax.plot(dists, los_f1, "-.", color=f1_color, lw=1, label=f1_label)
        ax.plot(dists, los_06f1, ":", color=f06_color, lw=1.3, label=f06_label)

        obstructed = valid & (elevs > los_curve)
        if obstructed.any():
            ax.scatter(dists[obstructed], elevs[obstructed], color="red", zorder=5, marker="x",
                       label=obstruction_label)

        f1_violation = valid & (elevs > los_f1) & ~obstructed
        if f1_violation.any():
            # 'x' is an unfilled marker: combined with facecolors="none",
            # matplotlib draws it in the transparent face colour (invisible).
            # A hollow circle with an explicit edge colour shows up and stays
            # distinct from the red 'x' obstruction markers.
            ax.scatter(dists[f1_violation], elevs[f1_violation], zorder=5,
                       marker="o", facecolors="none", edgecolors="orangered",
                       label=f1_violation_label)

    # LOS (main antennas): use FCC antenna AMSL elevations if available,
    # else fall back to ground elevation at the endpoints.
    los_y0 = ant_a_main_amsl if ant_a_main_amsl is not None else (elevs[0] if valid[0] else None)
    los_y1 = ant_b_main_amsl if ant_b_main_amsl is not None else (elevs[-1] if valid[-1] else None)

    if los_y0 is not None and los_y1 is not None:
        los_label = "Line of sight (main antennas)" \
            if (ant_a_main_amsl is not None and ant_b_main_amsl is not None) \
            else "Line of sight (endpoints)"
        plot_los_with_fresnel(
            los_y0, los_y1, "tab:blue", "tab:green", "tab:orange",
            los_label, f"F1 (k={k_factor:.2f}, {frequency_ghz:g} GHz)", "0.6 F1",
            "Obstruction (surface > LOS)", "F1 zone violation")

    # Diversity LOS: diversity antennas are receive-only, so a diversity
    # line runs from the MAIN antenna at one end to the DIVERSITY antenna
    # at the other end (never diversity-to-diversity).
    if ant_a_main_amsl is not None and ant_b_div_amsl is not None:
        plot_los_with_fresnel(
            ant_a_main_amsl, ant_b_div_amsl, "tab:purple", "mediumvioletred", "plum",
            f"Line of sight ({name_a} main \u2192 {name_b} diversity)",
            f"F1 -- {name_b} diversity (k={k_factor:.2f}, {frequency_ghz:g} GHz)",
            f"0.6 F1 -- {name_b} diversity",
            f"Obstruction ({name_b} diversity)", f"F1 zone violation ({name_b} diversity)")

    if ant_a_div_amsl is not None and ant_b_main_amsl is not None:
        plot_los_with_fresnel(
            ant_a_div_amsl, ant_b_main_amsl, "indigo", "darkmagenta", "orchid",
            f"Line of sight ({name_a} diversity \u2192 {name_b} main)",
            f"F1 -- {name_a} diversity (k={k_factor:.2f}, {frequency_ghz:g} GHz)",
            f"0.6 F1 -- {name_a} diversity",
            f"Obstruction ({name_a} diversity)", f"F1 zone violation ({name_a} diversity)")

    # Draw antenna masts + markers at each endpoint
    def draw_site(d, ground_z, main_amsl, div_amsl, main_ant, div_ant, name):
        if math.isnan(ground_z):
            return
        top = max([h for h in (main_amsl, div_amsl) if h is not None], default=None)
        if top is not None:
            ax.plot([d, d], [ground_z, top], color="gray", lw=1.5, zorder=4)
        if main_amsl is not None:
            ax.scatter([d], [main_amsl], color="black", marker="^", s=40, zorder=6,
                       label=f"{name} main ({main_ant['antenna_height_agl_m'] * m_to_ft:.0f} ft AGL)")
        if div_amsl is not None:
            ax.scatter([d], [div_amsl], facecolors="none", edgecolors="black", marker="^", s=40,
                       zorder=6, label=f"{name} div ({div_ant['antenna_height_agl_m'] * m_to_ft:.0f} ft AGL)")

    if valid[0]:
        draw_site(dists[0], elevs[0], ant_a_main_amsl, ant_a_div_amsl, main_a, div_a, name_a)
    if valid[-1]:
        draw_site(dists[-1], elevs[-1], ant_b_main_amsl, ant_b_div_amsl, main_b, div_b, name_b)

    ax.text(dists[0], -0.22, name_a, transform=ax.get_xaxis_transform(),
            ha="center", va="top", fontsize=9, fontweight="bold")
    if site_name_a:
        ax.text(dists[0], -0.32, site_name_a, transform=ax.get_xaxis_transform(),
                ha="center", va="top", fontsize=8, color="dimgray")
    ax.text(dists[-1], -0.22, name_b, transform=ax.get_xaxis_transform(),
            ha="center", va="top", fontsize=9, fontweight="bold")
    if site_name_b:
        ax.text(dists[-1], -0.32, site_name_b, transform=ax.get_xaxis_transform(),
                ha="center", va="top", fontsize=8, color="dimgray")

    ax.set_xlabel("Distance along path (miles)")
    ax.set_ylabel("Elevation (ft AMSL)")
    ax.set_title(f"Path Profile: {path_id}\n{name_a} \u2192 {name_b}")
    ax.grid(True, alpha=0.3)
    fig.legend(fontsize=7, loc="lower center", bbox_to_anchor=(0.5, 0.0),
               ncol=3, frameon=True)
    fig.savefig(out_path, dpi=120)
    plt.close(fig)


def load_fcc_data(antennas_path, paths_path):
    """Load FCC antenna heights and path (PA) records from two CSV exports.

    Returns a dict with three keys:
      "antennas": (call_sign, location_number, antenna_number) ->
          {antenna_height_agl_m (float), antenna_make, antenna_model, polarization}
      "by_location": (call_sign, location_number) -> list of the same antenna
          dicts in file order (kept for fallback lookups; find_path_antennas
          does not use it).
      "path_records": frozenset({call_sign, receive_call_sign}) -> list of
          the raw PA CSV row dicts for that unordered call-sign pair.

    A falsy or non-existent path leaves the corresponding dicts empty.
    Antenna rows whose antenna_height_agl_m is blank or non-numeric are
    skipped, as are PA rows missing either call sign. Short CSV rows (which
    csv.DictReader pads with None) are treated as having blank fields.
    Raises KeyError if a required column is absent from a file's header.
    """
    def text(row, key):
        # Optional column: missing key or a None padding value both read as "".
        return (row.get(key) or "").strip()

    antennas = {}
    by_location = {}
    if antennas_path and os.path.exists(antennas_path):
        with open(antennas_path, newline="") as f:
            for row in csv.DictReader(f):
                cs = (row["call_sign"] or "").strip()
                loc = (row["location_number"] or "").strip()
                ant_num = (row["antenna_number"] or "").strip()
                try:
                    height = float(text(row, "antenna_height_agl_m"))
                except ValueError:
                    continue  # blank or non-numeric height
                d = {
                    "antenna_height_agl_m": height,
                    "antenna_make": text(row, "antenna_make"),
                    "antenna_model": text(row, "antenna_model"),
                    "polarization": text(row, "polarization"),
                }
                antennas[(cs, loc, ant_num)] = d
                by_location.setdefault((cs, loc), []).append(d)

    path_records = {}
    if paths_path and os.path.exists(paths_path):
        with open(paths_path, newline="") as f:
            for row in csv.DictReader(f):
                a = (row["call_sign"] or "").strip()
                b = (row["receive_call_sign"] or "").strip()
                if not a or not b:
                    continue
                path_records.setdefault(frozenset((a, b)), []).append(row)
    return {"antennas": antennas, "by_location": by_location, "path_records": path_records}


def find_path_antennas(call_sign_a, call_sign_b, fcc_data, debug=False):
    """Pick main/diversity antennas for both ends of the (A, B) path.

    Returns (main_a, div_a, main_b, div_b). Each is an antenna dict from
    fcc_data["antennas"] ({antenna_height_agl_m, antenna_make,
    antenna_model, polarization}) or None.

    For a site X in the pair (X, Y), every PA row of the pair with
    call_sign == X and receive_call_sign == Y contributes the antenna keyed
    by (X, tx_location_number, tx_antenna_number), when it is present.
    Antennas whose heights are within 0.5 m of an already-kept one are
    treated as dual-polarization copies of the same physical antenna and
    dropped; the first one seen wins. The remaining antennas are ranked by
    height: the tallest is X's MAIN antenna and the next one (if any) is
    X's DIVERSITY antenna. This matches populate_summary_v2.py's
    find_main_div.

    Only each site's own TX-side PA records are used, never RX-side
    references (which are frequently dangling in AN.dat). With `debug`,
    the ranked heights for both sides are printed.
    """
    antenna_index = fcc_data["antennas"]
    pair_rows = fcc_data["path_records"].get(frozenset((call_sign_a, call_sign_b)), [])

    def ranked_tx_antennas(tx_cs, rx_cs):
        distinct = []
        for rec in pair_rows:
            if rec["call_sign"].strip() != tx_cs or rec["receive_call_sign"].strip() != rx_cs:
                continue
            key = (tx_cs, rec["tx_location_number"].strip(), rec["tx_antenna_number"].strip())
            ant = antenna_index.get(key)
            if not ant:
                continue
            height = ant["antenna_height_agl_m"]
            if any(abs(height - kept["antenna_height_agl_m"]) < 0.5 for kept in distinct):
                continue  # same physical antenna, other polarization
            distinct.append(ant)
        return sorted(distinct, key=lambda a: a["antenna_height_agl_m"], reverse=True)

    a_side = ranked_tx_antennas(call_sign_a, call_sign_b)
    b_side = ranked_tx_antennas(call_sign_b, call_sign_a)

    if debug:
        print(f"    [{call_sign_a} <-> {call_sign_b}] "
              f"A heights={[x['antenna_height_agl_m'] for x in a_side]}, "
              f"B heights={[x['antenna_height_agl_m'] for x in b_side]}")

    def main_and_div(ranked):
        return (ranked[0] if ranked else None,
                ranked[1] if len(ranked) > 1 else None)

    main_a, div_a = main_and_div(a_side)
    main_b, div_b = main_and_div(b_side)
    return main_a, div_a, main_b, div_b


def load_raat(path):
    """Load per-path antenna heights from the RAAT CSV.

    Returns dict Path_ID -> (a_raat_m, b_raat_m), both floats. These are
    used as the authoritative MAIN antenna height AGL for every path (100%
    coverage, see ATT_Mobility_Network_Definition_v9.csv).

    A falsy or non-existent `path` yields {}. Rows with a blank Path_ID, or
    with a missing/blank/non-numeric a_raat or b_raat (including short rows
    that csv.DictReader pads with None), are skipped. The file is read as
    UTF-8 with an optional BOM.
    """
    raat = {}
    if not path or not os.path.exists(path):
        return raat
    with open(path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            pid = (row.get("Path_ID") or "").strip()
            if not pid:
                continue
            try:
                a = float(row["a_raat"])
                b = float(row["b_raat"])
            except (ValueError, TypeError, KeyError):  # TypeError: None from a short row
                continue
            raat[pid] = (a, b)
    return raat


# ----------------------------------------------------------------------
# Main processing
# ----------------------------------------------------------------------
# Module-level dict used both for multiprocessing-worker state and (in
# sequential mode) just to track "have we printed a write-verification
# line yet" -- see _W["_verified"] in process_path below. In a worker
# process, _init_worker fills it with that worker's GLO-30 dataset and the
# run settings, which _worker_task then passes to process_path.
_W: dict = {}


def process_path(row, interval_m, usgs_per_path, fallback_ds, out_dir,
                  fcc_data=None, raat=None, debug=False,
                  frequency_ghz=6.0, k_factor=4.0 / 3.0):
    """Profile one path: sample it, look up elevations, resolve antennas, write outputs.

    row: one network-CSV row dict (Ref_ID, Path_ID2, site call signs and
        optional names, Site A/B coordinates).
    usgs_per_path: Path_ID2 -> list of /vsicurl/ USGS 1m tile URLs, or None.
    fallback_ds: GLO-30 dataset used wherever USGS has no value.
    fcc_data / raat: results of load_fcc_data / load_raat, or None.

    Writes <out_dir>/kml/<Path_ID2>.kml, <out_dir>/tables/<Path_ID2>.csv and
    <out_dir>/plots/<Path_ID2>.png. Returns
    (Path_ID2, n_points, n_usgs_1m, n_glo30, n_none).
    """
    ref_id = row["Ref_ID"]  # noqa: F841 -- read but unused; outputs are named by Path_ID2
    path_id = row["Path_ID2"]
    name_a, name_b = row["Site_A_Call_Sign"], row["Site_B_Call_Sign"]
    site_name_a = row.get("Site_A_Name", "").strip()
    site_name_b = row.get("Site_B_Name", "").strip()
    lat1, lon1 = float(row["Site_A_Latitude"]), float(row["Site_A_Longitude"])
    # NOTE: key spelled "Site_ B_Latitude" (with a space), as in build_usgs_tile_list.
    lat2, lon2 = float(row["Site_ B_Latitude"]), float(row["Site_B_Longitude"])

    marks = [time.perf_counter()]
    pts = sample_points(lat1, lon1, lat2, lon2, interval_m)
    marks.append(time.perf_counter())

    # Open only the few USGS tiles cached for THIS path, never the global
    # tile list (which would mean probing thousands of /vsicurl/ URLs).
    tile_urls = (usgs_per_path or {}).get(path_id, [])
    usgs_set = MultiCRSRasterSet(tile_urls) if tile_urls else None
    try:
        elevs, sources = get_elevations(pts, usgs_set, fallback_ds)
    finally:
        if usgs_set is not None:
            usgs_set.close()
    marks.append(time.perf_counter())

    fcc_data = fcc_data or {"antennas": {}, "by_location": {}, "path_records": {}}
    if debug:
        print(f"  [{path_id}]")
    main_a, div_a, main_b, div_b = find_path_antennas(name_a, name_b, fcc_data, debug)
    marks.append(time.perf_counter())

    # RAAT is the authoritative MAIN antenna height AGL (per user
    # instruction). It replaces only the height of the FCC-derived main
    # antenna dict, which still supplies make/model; div_a/div_b stay
    # FCC-derived. Without a RAAT entry the FCC height is kept.
    raat_a, raat_b = (raat or {}).get(path_id, (None, None))

    def override_height(ant, height):
        if height is None:
            return ant
        return dict(ant or {}, antenna_height_agl_m=height)

    main_a = override_height(main_a, raat_a)
    main_b = override_height(main_b, raat_b)

    def amsl(ground_elev, ant):
        if ant is None or math.isnan(ground_elev):
            return None
        return ground_elev + ant["antenna_height_agl_m"]

    ant_a_main_amsl = amsl(elevs[0], main_a)
    ant_a_div_amsl = amsl(elevs[0], div_a)
    ant_b_main_amsl = amsl(elevs[-1], main_b)
    ant_b_div_amsl = amsl(elevs[-1], div_b)

    base = path_id
    kml_path, table_path, png_path = (
        os.path.join(out_dir, subdir, base + ext)
        for subdir, ext in (("kml", ".kml"), ("tables", ".csv"), ("plots", ".png"))
    )

    write_kml(path_id, name_a, name_b, lat1, lon1, lat2, lon2, pts, elevs, sources, kml_path)
    marks.append(time.perf_counter())
    write_table(pts, elevs, sources, table_path)
    marks.append(time.perf_counter())
    plot_profile(path_id, name_a, name_b, pts, elevs, sources, png_path,
                 ant_a_main_amsl=ant_a_main_amsl, ant_a_div_amsl=ant_a_div_amsl,
                 ant_b_main_amsl=ant_b_main_amsl, ant_b_div_amsl=ant_b_div_amsl,
                 main_a=main_a, div_a=div_a, main_b=main_b, div_b=div_b,
                 frequency_ghz=frequency_ghz, k_factor=k_factor,
                 site_name_a=site_name_a, site_name_b=site_name_b)
    marks.append(time.perf_counter())

    # Report where the PNG landed (size + mtime) once per process, and on
    # every path under debug / PROFILE_TIMING / VERIFY_WRITE.
    if (debug or os.environ.get("PROFILE_TIMING") or os.environ.get("VERIFY_WRITE")
            or not _W.get("_verified")):
        try:
            st = os.stat(png_path)
            print(f"  [{path_id}] wrote {png_path} "
                  f"({st.st_size} bytes, mtime={datetime.fromtimestamp(st.st_mtime):%Y-%m-%d %H:%M:%S})")
        except OSError as e:
            print(f"  [{path_id}] WARNING: could not stat {png_path}: {e}")
        _W["_verified"] = True

    if debug or os.environ.get("PROFILE_TIMING"):
        stages = ("sample", "elev", "antenna", "kml", "table", "plot")
        spans = " ".join(f"{stage}={end - start:.2f}s"
                         for stage, start, end in zip(stages, marks, marks[1:]))
        print(f"  [{path_id}] timing: {spans} total={marks[-1] - marks[0]:.2f}s")

    n_usgs, n_glo, n_none = (sum(s == tag for s in sources)
                             for tag in ("usgs_1m", "glo30", "none"))
    return base, len(pts), n_usgs, n_glo, n_none


# ----------------------------------------------------------------------
# Multiprocessing worker support
# ----------------------------------------------------------------------
# Each worker process needs its own GDAL/rasterio handles -- these can't
# be shared/pickled across processes, so they're (re)opened once per
# worker via _init_worker and stashed in module-level _W (defined above).


def _init_worker(glo30_vrt, usgs_per_path, fcc_data, raat, out_dir,
                  interval, frequency_ghz, k_factor, debug):
    """Worker-process initializer: populate this process's module-level _W.

    Opens the GLO-30 VRT once in this process (GDAL/rasterio handles cannot
    be shared or pickled across processes). It is stored under
    "fallback_ds" alongside the run settings that _worker_task forwards to
    process_path.
    """
    glo30_ds = rasterio.open(glo30_vrt)
    _W.update(
        fallback_ds=glo30_ds,
        usgs_per_path=usgs_per_path,
        fcc_data=fcc_data,
        raat=raat,
        out_dir=out_dir,
        interval=interval,
        frequency_ghz=frequency_ghz,
        k_factor=k_factor,
        debug=debug,
    )


def _worker_task(row):
    """Run ``process_path`` for one CSV row inside a pool worker.

    Settings come from the per-process ``_W`` dict filled by ``_init_worker``.
    Exceptions are never propagated; instead a tagged tuple is returned:

        ("ok",  row, <process_path result>)
        ("err", row, (exception type name, message, formatted traceback))
    """
    try:
        state = _W
        outcome = process_path(
            row,
            state["interval"],
            state["usgs_per_path"],
            state["fallback_ds"],
            state["out_dir"],
            state["fcc_data"],
            state["raat"],
            state["debug"],
            state["frequency_ghz"],
            state["k_factor"],
        )
    except Exception as exc:
        details = (type(exc).__name__, str(exc), traceback.format_exc())
        return ("err", row, details)
    return ("ok", row, outcome)



def _parse_args(argv=None):
    """Parse command-line options and make file/directory paths absolute."""
    ap = argparse.ArgumentParser()
    ap.add_argument("csv_file")
    ap.add_argument("--interval", type=float, default=80.0, help="sample interval (m)")
    ap.add_argument("--out", default="output", help="output directory")
    ap.add_argument("--start-id", type=int, default=None, help="min Ref_ID (inclusive)")
    ap.add_argument("--end-id", type=int, default=None, help="max Ref_ID (inclusive)")
    ap.add_argument("--overwrite", action="store_true",
                     help="reprocess paths even if output files already exist "
                          "(default: skip them, for resumable runs)")
    ap.add_argument("--skip-usgs", action="store_true",
                     help="don't build/use the USGS 3DEP 1m VRT; use GLO-30 for everything")
    ap.add_argument("--glo30-vrt", default="glo30_remote.vrt")
    ap.add_argument("--usgs-tile-list", default="usgs_tiles.txt")
    ap.add_argument("--rebuild-vrt", action="store_true",
                     help="rebuild glo30-vrt and usgs-tile-list from scratch even if they "
                          "already exist (use this if you previously built them for a "
                          "smaller subset of paths, e.g. via --limit)")
    ap.add_argument("--limit", type=int, default=None,
                     help="only process the first N paths (after any --start-id/--end-id "
                          "filtering) -- handy for quick tests")
    ap.add_argument("--fcc-antennas", default="fcc_antennas.csv",
                     help="fcc_antennas.csv from fcc_antenna_heights.py")
    ap.add_argument("--fcc-paths", default="fcc_paths.csv",
                     help="fcc_paths.csv (from PA.dat) from fcc_antenna_heights.py")
    ap.add_argument("--raat-csv", default="ATT_Mobility_Network_Definition_v9.csv",
                     help="CSV with Path_ID, a_raat, b_raat columns -- used as the "
                          "authoritative Main Antenna Height AGL for every path")
    ap.add_argument("--debug-antennas", action="store_true",
                     help="print antenna-matching diagnostics for each path")
    ap.add_argument("--frequency-ghz", type=float, default=6.0,
                     help="operating frequency in GHz, used for Fresnel zone radius (default 6.0)")
    ap.add_argument("--k-factor", type=float, default=4.0 / 3.0,
                     help="effective earth radius factor k for earth-curvature bulge (default 4/3)")
    ap.add_argument("--workers", type=int, default=1,
                     help="number of parallel worker processes (default 1 = sequential)")
    args = ap.parse_args(argv)
    # Absolute paths make the printed locations unambiguous and are handed
    # unchanged to worker processes.
    args.out = os.path.abspath(args.out)
    args.glo30_vrt = os.path.abspath(args.glo30_vrt)
    args.usgs_tile_list = os.path.abspath(args.usgs_tile_list)
    if args.raat_csv:
        args.raat_csv = os.path.abspath(args.raat_csv)
    return args


# Accepted spellings of the Site B latitude column. The first matches the
# naming of the other three coordinate columns; the second (stray space) is
# the spelling this script previously required, kept for existing inputs.
_SITE_B_LAT_KEYS = ("Site_B_Latitude", "Site_ B_Latitude")


def _site_b_latitude(row):
    """Return Site B latitude from a CSV row, accepting either header spelling."""
    for key in _SITE_B_LAT_KEYS:
        if key in row:
            return float(row[key])
    raise KeyError(_SITE_B_LAT_KEYS[0])


def _format_elapsed(elapsed):
    """Format a timedelta as HH:MM:SS, letting hours exceed 24 for long runs."""
    total_seconds = int(elapsed.total_seconds())
    hh, rem = divmod(total_seconds, 3600)
    mm, ss = divmod(rem, 60)
    return f"{hh:02d}:{mm:02d}:{ss:02d}"


def main():
    """Run the path-profile pipeline over the selected rows of the network CSV.

    Steps:
      1. Parse options and create ``kml``/``tables``/``plots`` under ``--out``.
      2. Load rows with a numeric ``Ref_ID`` and apply
         ``--start-id``/``--end-id``/``--limit``.
      3. Build or reuse the GLO-30 VRT and, unless ``--skip-usgs``, the USGS
         tile list.
      4. Load FCC antenna data and the RAAT table.
      5. Skip paths whose CSV, PNG and KML outputs all exist (unless
         ``--overwrite``).
      6. Run ``process_path`` on the rest, either sequentially or in a process
         pool when ``--workers`` > 1.

    Per-path failures are appended to ``<out>/errors.csv`` rather than
    aborting the run.
    """
    start_time = datetime.now()
    print(f"Run started: {start_time:%Y-%m-%d %H:%M:%S}")

    args = _parse_args()
    print(f"Output directory: {args.out}")

    for sub in ("kml", "tables", "plots"):
        os.makedirs(os.path.join(args.out, sub), exist_ok=True)

    # errors.csv accumulates across resumable runs. It is (re)started with a
    # header only when missing or when everything is being reprocessed.
    error_log_path = os.path.join(args.out, "errors.csv")
    if not os.path.exists(error_log_path) or args.overwrite:
        with open(error_log_path, "w", newline="") as ef:
            csv.writer(ef).writerow(["Ref_ID", "Path_ID2", "error_type", "error_message", "traceback"])

    with open(args.csv_file, newline="") as f:
        rows = [r for r in csv.DictReader(f) if r.get("Ref_ID", "").strip().isdigit()]

    if args.start_id is not None:
        rows = [r for r in rows if int(r["Ref_ID"]) >= args.start_id]
    if args.end_id is not None:
        rows = [r for r in rows if int(r["Ref_ID"]) <= args.end_id]

    if args.limit is not None:
        rows = rows[:args.limit]

    bbox_rows = [
        (float(r["Site_A_Latitude"]), float(r["Site_A_Longitude"]),
         _site_b_latitude(r), float(r["Site_B_Longitude"]))
        for r in rows
    ]

    # --- Build (or reuse) GLO-30 VRT and USGS tile list ----------------
    build_glo30_vrt(bbox_rows, args.glo30_vrt, force=args.rebuild_vrt)
    usgs_urls, usgs_per_path = [], {}
    if not args.skip_usgs:
        usgs_urls, usgs_per_path = build_usgs_tile_list(rows, args.usgs_tile_list, force=args.rebuild_vrt)

    # --- Process paths -------------------------------------------------
    print(f"Processing {len(rows)} path(s) at {args.interval} m intervals.")

    if usgs_urls:
        print(f"  primary (USGS 3DEP 1m): {len(usgs_urls)} tile(s), multi-CRS, sampled individually")
    else:
        print("  primary: none (using GLO-30 for everything)")

    print(f"  fallback (Copernicus GLO-30): {args.glo30_vrt}")

    fcc_data = load_fcc_data(args.fcc_antennas, args.fcc_paths)
    if fcc_data["antennas"]:
        print(f"  FCC antenna data: {len(fcc_data['antennas'])} antennas, "
              f"{len(fcc_data['path_records'])} unique site-pairs")

    raat = load_raat(args.raat_csv)
    print(f"  RAAT data: {len(raat)} path entries from {args.raat_csv}")

    # Filter out paths already done (unless --overwrite)
    todo = []
    for row in rows:
        base = row["Path_ID2"]
        csv_path = os.path.join(args.out, "tables", base + ".csv")
        png_path = os.path.join(args.out, "plots", base + ".png")
        kml_path = os.path.join(args.out, "kml", base + ".kml")
        if not args.overwrite and all(os.path.exists(p) for p in (csv_path, png_path, kml_path)):
            print(f"{base}: already done, skipping")
            continue
        todo.append(row)

    n_total = len(rows)
    n_todo = len(todo)
    n_skip = n_total - n_todo
    if n_skip:
        print(f"Skipped {n_skip} already-done path(s); {n_todo} to process.")

    def log_result(i, status, row, payload):
        """Print one progress line; on failure also append a row to errors.csv."""
        # i counts only paths actually processed, so the denominator is the
        # to-do count -- using the total (incl. skipped) would never reach N/N.
        if status == "ok":
            base, n, n_usgs, n_glo, n_none = payload
            print(f"[{i}/{n_todo}] {base}: {n} pts (1m={n_usgs}, glo30={n_glo}, missing={n_none})")
        else:
            err_type, err_msg, tb = payload
            print(f"[{i}/{n_todo}] {row.get('Path_ID2')}: ERROR {err_msg}")
            with open(error_log_path, "a", newline="") as ef:
                w = csv.writer(ef)
                w.writerow([row.get("Ref_ID"), row.get("Path_ID2"), err_type, err_msg,
                             tb.replace("\n", " | ")])

    if args.workers > 1 and todo:
        print(f"Using {args.workers} parallel worker process(es).")
        with ProcessPoolExecutor(
                max_workers=args.workers, initializer=_init_worker,
                initargs=(args.glo30_vrt, usgs_per_path, fcc_data, raat, args.out,
                          args.interval, args.frequency_ghz, args.k_factor, args.debug_antennas)
        ) as ex:
            futures = {ex.submit(_worker_task, row): row for row in todo}
            for i, fut in enumerate(as_completed(futures), 1):
                try:
                    status, row, payload = fut.result()
                except Exception as e:
                    # _worker_task already traps per-path errors, so this means
                    # the worker process itself failed (initializer error,
                    # native crash -> BrokenProcessPool, ...). Log it against the
                    # submitted row instead of aborting without a record.
                    status, row, payload = (
                        "err", futures[fut],
                        (type(e).__name__, str(e), traceback.format_exc()))
                log_result(i, status, row, payload)
    else:
        with rasterio.open(args.glo30_vrt) as fallback_ds:
            for i, row in enumerate(todo, 1):
                try:
                    result = process_path(
                        row, args.interval, usgs_per_path, fallback_ds, args.out,
                        fcc_data, raat, args.debug_antennas,
                        args.frequency_ghz, args.k_factor)
                    log_result(i, "ok", row, result)
                except Exception as e:
                    log_result(i, "err", row, (type(e).__name__, str(e), traceback.format_exc()))

    end_time = datetime.now()
    print(f"Run ended:   {end_time:%Y-%m-%d %H:%M:%S}")
    print(f"Elapsed:     {_format_elapsed(end_time - start_time)}")


if __name__ == "__main__":
    # Script entry point. The guard also keeps main() from running when a
    # ProcessPoolExecutor start method (e.g. spawn) re-imports this module
    # in worker processes.
    main()
