#!/usr/bin/env python3
"""
path_profile_tool.py (combined, all-in-one)
============================================
For each microwave path in ATT_Network.csv:
  1. Builds (or reuses) two remote VRT mosaics -- no tile downloads:
       - glo30_remote.vrt : Copernicus GLO-30 DSM (global, 30m,
         includes terrain + tree canopy + most buildings)
       - usgs_remote.vrt  : USGS 3DEP 1m DSM (US only, patchy lidar
         coverage, includes canopy/buildings where available)
     Both point at /vsicurl/ URLs on public cloud storage; GDAL only
     fetches the small blocks needed for each sample point via HTTP
     range requests.
  2. Generates a KML line (EPSG:4326) between Site A and Site B, with
     a placemark at each site (sample points are not written to the KML)
  3. Samples the geodesic at ~80 m intervals. For each point, tries
     USGS 1m first, falls back to GLO-30 if nodata/out-of-bounds.
  4. Saves a per-path table (CSV): X (lon), Y (lat), Z (m), source
  5. Plots the surface profile vs. distance, with a line-of-sight
     reference line and obstruction points (surface above LOS)
     highlighted, color-coded by data source

Outputs (per path, named by Ref_ID_Path_ID):
  kml/<Ref_ID>_<Path_ID>.kml
  tables/<Ref_ID>_<Path_ID>.csv
  plots/<Ref_ID>_<Path_ID>.png

Usage:
  python path_profile_tool.py ATT_Network.csv --interval 80 --out output

  # Optional:
  #   --start-id / --end-id   restrict to a Ref_ID range
  #   --overwrite             reprocess paths even if outputs exist
  #   --skip-usgs             don't query USGS 3DEP at all, use GLO-30
  #                           for everything (faster, lower-res)
  #   --glo30-vrt / --usgs-tile-list  override the default GLO-30 VRT and
  #                           USGS tile-list paths (built once and reused
  #                           on subsequent runs)
  #   --rebuild-vrt           rebuild both from scratch even if they exist
"""

import argparse
import csv
import json
import math
import os
import subprocess
import time
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime, timedelta

# --- GDAL/curl tuning for robust remote (/vsicurl/) COG reads ---------
# Set before importing rasterio so GDAL picks these up at driver init.
# setdefault() is used throughout, so any value already present in the
# caller's environment takes precedence over the defaults below.
os.environ.setdefault("GDAL_DISABLE_READDIR_ON_OPEN", "EMPTY_DIR")  # don't list bucket dirs
os.environ.setdefault("CPL_VSIL_CURL_ALLOWED_EXTENSIONS", ".tif,.tiff,.vrt")  # only raster/VRT extensions over /vsicurl/
os.environ.setdefault("GDAL_HTTP_MAX_RETRY", "5")        # retry transient failures
os.environ.setdefault("GDAL_HTTP_RETRY_DELAY", "2")      # seconds between retries
os.environ.setdefault("GDAL_HTTP_TIMEOUT", "30")         # per-request timeout (s)
os.environ.setdefault("GDAL_HTTP_CONNECTTIMEOUT", "10")  # connection timeout (s)
os.environ.setdefault("CPL_VSIL_CURL_CACHE_SIZE", "200000000")  # 200MB block cache
os.environ.setdefault("VSI_CACHE", "TRUE")               # enable the per-file VSI cache sized below
os.environ.setdefault("VSI_CACHE_SIZE", "50000000")      # 50MB per-file cache
os.environ.setdefault("GDAL_CACHEMAX", "512")            # MB, GDAL block cache

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import numpy as np
import requests
import rasterio
import pyproj
from pyproj import Geod
import simplekml

# Shared geodesic calculator used for path length, azimuths and the
# along-path sample points.
GEOD = Geod(ellps="WGS84")  # EPSG:4326 reference ellipsoid

# Base URL from which per-tile GLO-30 COG URLs are built (see build_glo30_vrt).
GLO30_BUCKET = "https://copernicus-dem-30m.s3.amazonaws.com"
# Product-search endpoint queried by usgs_tile_urls_for_bbox.
TNM_API = "https://tnmaccess.nationalmap.gov/api/v1/products"
# Value sent as the `datasets` query parameter of the TNM_API request.
USGS_1M_DATASET = "Digital Elevation Model (DEM) 1 meter"


# ----------------------------------------------------------------------
# Remote VRT builders (no downloads -- /vsicurl/ + HTTP range requests)
# ----------------------------------------------------------------------
def glo30_tile_name(lat, lon):
    """Return the Copernicus GLO-30 COG tile name for a lat/lon.

    The name is built from the floored integer degrees of `lat` and `lon`:
    a hemisphere letter (N/S, E/W) followed by the zero-padded absolute
    value (2 digits for latitude, 3 for longitude).
    """
    lat_deg = math.floor(lat)
    lon_deg = math.floor(lon)
    lat_hemi = "S" if lat_deg < 0 else "N"
    lon_hemi = "W" if lon_deg < 0 else "E"
    return (f"Copernicus_DSM_COG_10_{lat_hemi}{abs(lat_deg):02d}_00_"
            f"{lon_hemi}{abs(lon_deg):03d}_00_DEM")


def glo30_tiles_for_bbox(lat1, lon1, lat2, lon2):
    """Return the set of GLO-30 tile names covering the box spanned by two points.

    The two corners may be given in any order; every whole-degree cell
    between the floored latitudes and floored longitudes (inclusive) is
    included.
    """
    south, north = sorted((math.floor(lat1), math.floor(lat2)))
    west, east = sorted((math.floor(lon1), math.floor(lon2)))
    # Name each cell from its centre so the floor() inside glo30_tile_name
    # lands back on the cell's own integer degrees.
    return {
        glo30_tile_name(lat_deg + 0.5, lon_deg + 0.5)
        for lat_deg in range(south, north + 1)
        for lon_deg in range(west, east + 1)
    }


def build_glo30_vrt(rows, vrt_path, force=False):
    """Build (or reuse) a VRT mosaic of remote GLO-30 tiles at `vrt_path`.

    rows:     iterable of (lat1, lon1, lat2, lon2) tuples, one per path.
    vrt_path: output VRT file; the tile URL list is written next to it as
              `vrt_path + ".input_list.txt"`.
    force:    when True an existing VRT is deleted and rebuilt; otherwise an
              existing VRT is reused untouched.

    Each candidate tile is probed with an HTTP HEAD request; only tiles
    answering 200 are added (as /vsicurl/ URLs). Probe failures are
    reported instead of being silently ignored, because a tile dropped by
    a transient error would otherwise stay missing from the VRT on every
    later run that reuses it.

    Raises subprocess.CalledProcessError if gdalbuildvrt fails.
    """
    if os.path.exists(vrt_path):
        if not force:
            print(f"Reusing existing {vrt_path}")
            return
        os.remove(vrt_path)

    tiles = set()
    for lat1, lon1, lat2, lon2 in rows:
        tiles |= glo30_tiles_for_bbox(lat1, lon1, lat2, lon2)

    print(f"GLO-30: {len(tiles)} candidate tiles, checking existence...")
    urls = []
    unchecked = []
    for t in sorted(tiles):
        url = f"{GLO30_BUCKET}/{t}/{t}.tif"
        try:
            r = requests.head(url, timeout=15)
        except requests.RequestException as e:
            # Best-effort: keep going, but make the gap visible.
            unchecked.append(t)
            print(f"GLO-30: WARNING: could not check {t}: {e}")
            continue
        if r.status_code == 200:
            urls.append("/vsicurl/" + url)
        elif r.status_code != 404:
            # 404 is the expected answer for a tile that does not exist;
            # anything else may be a transient server/permission problem.
            unchecked.append(t)
            print(f"GLO-30: WARNING: unexpected HTTP {r.status_code} for {t}")
    if unchecked:
        print(f"GLO-30: WARNING: {len(unchecked)} tiles could not be verified and are "
              f"left out of {vrt_path}; rebuild the VRT to retry them")

    print(f"GLO-30: {len(urls)} tiles exist, building {vrt_path}")
    if not urls:
        print(f"GLO-30: WARNING: no tiles found, {vrt_path} was not written")
        return

    list_file = vrt_path + ".input_list.txt"
    with open(list_file, "w") as f:
        f.write("\n".join(urls) + "\n")
    subprocess.run(["gdalbuildvrt", "-input_file_list", list_file, vrt_path], check=True)


def usgs_tile_urls_for_bbox(bbox, buffer=0.005, max_retries=3):
    """Query the TNM products API for USGS 1m DEM GeoTIFF URLs covering a box.

    bbox:        (x1, y1, x2, y2) in lon/lat degrees; corners in any order.
    buffer:      margin in degrees added on every side of the box.
    max_retries: number of request attempts before giving up.

    Returns a list of download URLs ending in .tif/.tiff. This is a
    best-effort lookup: any request or parsing failure is retried, and an
    empty list is returned once all attempts are exhausted (including when
    `max_retries` is zero or negative).
    """
    west = min(bbox[0], bbox[2]) - buffer
    east = max(bbox[0], bbox[2]) + buffer
    south = min(bbox[1], bbox[3]) - buffer
    north = max(bbox[1], bbox[3]) + buffer
    params = {
        "datasets": USGS_1M_DATASET,
        "bbox": f"{west},{south},{east},{north}",
        "outputFormat": "JSON",
        "max": 50,
    }
    for _attempt in range(max_retries):
        try:
            r = requests.get(TNM_API, params=params, timeout=30)
            r.raise_for_status()
            data = r.json()
            # `or` guards against explicit nulls: a single item with
            # "downloadURL": null must not discard the whole response.
            candidates = [(item.get("downloadURL") or "")
                          for item in (data.get("items") or [])]
            return [u for u in candidates if u.lower().endswith((".tif", ".tiff"))]
        except Exception:
            # Deliberately broad: callers fall back to GLO-30 when no USGS
            # tiles are returned, so a failed lookup is retried, not fatal.
            continue
    return []


def build_usgs_tile_list(rows, list_path, force=False):
    """Build or reuse the USGS 1m tile URL caches.

    Two files are maintained:
      - list_path: one /vsicurl/ URL per line, de-duplicated across all
        paths (informational).
      - list_path + '.per_path.json': mapping Path_ID2 -> list of /vsicurl/
        URLs found for that path's own bounding box. This per-path mapping
        is what sampling uses, so a query point never has to be tested
        against tiles belonging to other paths.

    `rows` is the list of CSV row dicts (Path_ID2 plus site coordinates).
    When both files exist and `force` is false they are loaded and
    returned as-is. Returns (all_urls, per_path).
    """
    per_path_path = list_path + ".per_path.json"
    have_cache = os.path.exists(list_path) and os.path.exists(per_path_path)
    if have_cache and not force:
        print(f"Reusing existing {list_path} and {per_path_path}")
        with open(list_path) as f:
            stripped = (line.strip() for line in f)
            global_urls = [s for s in stripped if s]
        with open(per_path_path) as f:
            per_path = json.load(f)
        return global_urls, per_path

    unique_tiles = set()
    per_path = {}
    total = len(rows)
    for done, row in enumerate(rows, 1):
        lat_a = float(row["Site_A_Latitude"])
        lon_a = float(row["Site_A_Longitude"])
        lat_b = float(row["Site_ B_Latitude"])
        lon_b = float(row["Site_B_Longitude"])
        path_tiles = set(usgs_tile_urls_for_bbox((lon_a, lat_a, lon_b, lat_b)))
        per_path[row["Path_ID2"]] = ["/vsicurl/" + u for u in sorted(path_tiles)]
        unique_tiles.update(path_tiles)
        if done % 100 == 0:
            print(f"USGS 1m: {done}/{total} paths queried, {len(unique_tiles)} unique tiles so far")

    vsicurl_urls = ["/vsicurl/" + u for u in sorted(unique_tiles)]
    print(f"USGS 1m: {len(vsicurl_urls)} unique tiles -> {list_path}, "
          f"per-path mapping -> {per_path_path}")
    with open(list_path, "w") as f:
        f.writelines(u + "\n" for u in vsicurl_urls)
    with open(per_path_path, "w") as f:
        json.dump(per_path, f)
    return vsicurl_urls, per_path


class MultiCRSRasterSet:
    """A set of remote rasters, possibly in different CRSs (e.g. USGS 1m
    tiles in various UTM zones). For each query point (given in EPSG:4326
    lon/lat), checks each tile's bounds (transforming the point into that
    tile's CRS) and samples the first tile that contains it.

    Datasets and transformers are opened lazily and cached -- a tile is
    only opened the first time a query reaches it. Tiles that fail to
    open are remembered and skipped on later queries."""

    def __init__(self, urls):
        self.urls = urls
        self._ds = {}        # url -> rasterio dataset (or False if failed to open)
        self._transformer = {}  # url -> pyproj Transformer EPSG:4326 -> tile CRS

    def _get(self, url):
        """Return the cached dataset for `url`, opening it on first use.

        Returns False when the dataset (or its transformer) could not be
        created; the failure is cached so the URL is not retried."""
        if url not in self._ds:
            ds = None
            try:
                ds = rasterio.open(url)
                transformer = pyproj.Transformer.from_crs(
                    "EPSG:4326", ds.crs, always_xy=True)
            except Exception:
                # If the dataset opened but the transformer could not be
                # built, release the handle instead of leaking it.
                if ds is not None:
                    ds.close()
                self._ds[url] = False
            else:
                self._ds[url] = ds
                self._transformer[url] = transformer
        return self._ds[url]

    def sample(self, lon, lat):
        """Return the band-1 value at (lon, lat) from the first tile that
        contains the point and holds valid data, or NaN if none does."""
        for url in self.urls:
            ds = self._get(url)
            if ds is False:
                continue
            x, y = self._transformer[url].transform(lon, lat)
            b = ds.bounds
            if not (b.left <= x <= b.right and b.bottom <= y <= b.top):
                continue
            try:
                val = float(next(ds.sample([(x, y)]))[0])
                nodata = ds.nodata
                if nodata is not None and val == nodata:
                    continue
                if math.isnan(val):
                    continue
                return val
            except Exception:
                # A read error on one tile should not hide valid data in
                # another overlapping tile.
                continue
        return float("nan")

    def close(self):
        """Close every opened dataset and drop the caches.

        Safe to call more than once; after closing, no stale (closed)
        handle is left behind for a later sample() to pick up."""
        for ds in self._ds.values():
            if ds is not False:
                ds.close()
        self._ds.clear()
        self._transformer.clear()


# ----------------------------------------------------------------------
# Geometry helpers
# ----------------------------------------------------------------------
def sample_points(lat1, lon1, lat2, lon2, interval_m):
    """Return (lon, lat, distance_m) samples along the geodesic A -> B.

    Samples are spaced `interval_m` apart starting at site A; the final
    sample is always site B itself at the full path distance. A
    zero-length path yields the single point at site A.
    """
    fwd_az, _back_az, total_m = GEOD.inv(lon1, lat1, lon2, lat2)
    if total_m == 0:
        return [(lon1, lat1, 0.0)]

    n_steps = max(1, math.ceil(total_m / interval_m))
    samples = []
    for step in range(n_steps):
        along_m = min(step * interval_m, total_m)
        lon, lat, _ = GEOD.fwd(lon1, lat1, fwd_az, along_m)
        samples.append((lon, lat, along_m))
    # The end point uses the exact input coordinates rather than a
    # forward-projected approximation.
    samples.append((lon2, lat2, total_m))
    return samples


# ----------------------------------------------------------------------
# Elevation lookup: primary raster with fallback raster
# ----------------------------------------------------------------------
def sample_raster(ds, coords):
    """Sample band 1 of `ds` at each (lon, lat) in `coords`.

    Returns one float per coordinate. NaN marks a point that is outside
    the dataset bounds, equals the dataset's nodata value, is NaN in the
    raster, or could not be read. When `ds` is None every result is NaN.
    """
    nan = float("nan")
    if ds is None:
        return [nan] * len(coords)

    missing = ds.nodata
    box = ds.bounds

    def read_one(lon, lat):
        inside = box.left <= lon <= box.right and box.bottom <= lat <= box.top
        if not inside:
            return nan
        try:
            value = float(next(ds.sample([(lon, lat)]))[0])
            unusable = (missing is not None and value == missing) or math.isnan(value)
        except Exception:
            return nan
        return nan if unusable else value

    return [read_one(lon, lat) for lon, lat in coords]


def get_elevations(points, primary_set, fallback_ds):
    """Look up a surface elevation for every sample point.

    points:      list of (lon, lat, distance_m) tuples.
    primary_set: MultiCRSRasterSet or None.
    fallback_ds: rasterio dataset (EPSG:4326), or None.

    Returns (elevs, sources), two lists parallel to `points`.
    sources[i] is 'usgs_1m' when the primary set supplied the value,
    'glo30' when the fallback did, and 'none' (with elevs[i] = NaN) when
    neither had data.

    The fallback raster is only sampled for points the primary set could
    not answer, so no remote reads are spent on values that would be
    discarded.
    """
    coords = [(lon, lat) for lon, lat, _ in points]

    if primary_set is None:
        primary_vals = [float("nan")] * len(coords)
    else:
        primary_vals = [primary_set.sample(lon, lat) for lon, lat in coords]

    # Indices still lacking a value after the primary lookup.
    pending = [i for i, pv in enumerate(primary_vals) if math.isnan(pv)]
    fallback_vals = dict(zip(
        pending, sample_raster(fallback_ds, [coords[i] for i in pending])))

    elevs, sources = [], []
    for i, pv in enumerate(primary_vals):
        if not math.isnan(pv):
            elevs.append(pv)
            sources.append("usgs_1m")
            continue
        fv = fallback_vals[i]
        if not math.isnan(fv):
            elevs.append(fv)
            sources.append("glo30")
        else:
            elevs.append(float("nan"))
            sources.append("none")
    return elevs, sources


# ----------------------------------------------------------------------
# KML
# ----------------------------------------------------------------------
def write_kml(path_id, name_a, name_b, lat1, lon1, lat2, lon2, points, elevs, sources, out_path):
    """Write a KML file with the A-B path line and one marker per site.

    The line joins the two site coordinates only (clamped to ground, red,
    width 2). `points`, `elevs` and `sources` are accepted but not written
    to the KML.
    """
    doc = simplekml.Kml()
    doc.document.name = path_id

    # Path line: just the two endpoints
    path_line = doc.newlinestring(name=path_id)
    path_line.coords = [(lon1, lat1, 0), (lon2, lat2, 0)]
    path_line.altitudemode = simplekml.AltitudeMode.clamptoground
    line_style = path_line.style.linestyle
    line_style.width = 2
    line_style.color = simplekml.Color.red

    # Site markers, A first then B
    for label, lonlat in ((name_a, (lon1, lat1)), (name_b, (lon2, lat2))):
        doc.newpoint(name=label, coords=[lonlat])

    doc.save(out_path)


# ----------------------------------------------------------------------
# Table + plot
# ----------------------------------------------------------------------
def write_table(points, elevs, sources, out_path):
    """Write the per-path sample table as CSV.

    One row per sample: index, distance along the path (m), longitude,
    latitude, surface elevation (m, blank when NaN) and the data source.
    """
    header = ["index", "distance_m", "X_lon", "Y_lat", "Z_surface_elev_m", "source"]
    with open(out_path, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(header)
        for idx, (pt, elev, src) in enumerate(zip(points, elevs, sources)):
            lon, lat, dist_m = pt
            elev_txt = f"{elev:.2f}" if not math.isnan(elev) else ""
            writer.writerow([idx, f"{dist_m:.2f}", f"{lon:.8f}", f"{lat:.8f}",
                             elev_txt, src])


def plot_profile(path_id, name_a, name_b, points, elevs, sources, out_path,
                  ant_a_main_amsl=None, ant_a_div_amsl=None,
                  ant_b_main_amsl=None, ant_b_div_amsl=None,
                  main_a=None, div_a=None, main_b=None, div_b=None,
                  frequency_ghz=6.0, k_factor=4.0 / 3.0,
                  site_name_a="", site_name_b="",
                  az_a=None, az_b=None,
                  lat_a=None, lon_a=None, lat_b=None, lon_b=None):
    """Render the path-profile plot for one path and save it to `out_path`.

    MW convention: LOS is straight between antenna heights, earth is
    bulged upward per K factor. True AMSL coordinate system.
    Three terrain curves shown simultaneously (K=inf, K=4/3, K=1/2).
    One set of Fresnel zone lines (F1, 0.6F1) relative to straight LOS.
    Obstruction and F1-violation markers shown per K factor.

    Inputs are metric (distances in `points` in meters, `elevs` and the
    *_amsl antenna heights in meters); the plot is drawn in miles / feet.
    main_*/div_* are antenna dicts (or None) whose "antenna_height_agl_m"
    is shown in the info boxes. `sources` and `k_factor` are accepted for
    interface compatibility but are not used by the drawing code.

    The figure is always closed, even when drawing or saving raises, so a
    failing path cannot leave an open figure behind in a long-lived
    worker process.
    """
    dists = np.array([p[2] for p in points]) / 1609.344   # meters -> miles
    elevs_raw = np.array(elevs) * 3.28084                  # meters -> feet AMSL
    valid = ~np.isnan(elevs_raw)

    # Convert antenna AMSL values (arrive in meters) to feet
    def to_ft(v):
        return v * 3.28084 if v is not None else None
    ant_a_main_amsl = to_ft(ant_a_main_amsl)
    ant_a_div_amsl  = to_ft(ant_a_div_amsl)
    ant_b_main_amsl = to_ft(ant_b_main_amsl)
    ant_b_div_amsl  = to_ft(ant_b_div_amsl)

    # Distances in km for bulge / Fresnel math
    D_km = dists[-1] * 1.60934
    d1_km = dists * 1.60934
    d2_km = D_km - d1_km

    def earth_bulge_ft(k):
        """Upward earth bulge in feet at each sample point for a given K factor."""
        if k is None or not np.isfinite(k) or k <= 0 or D_km <= 0:
            return np.zeros_like(dists)
        return (d1_km * d2_km) / (12.74 * k) * 3.28084

    # F1 Fresnel zone radius in feet (same for all K factors)
    if D_km > 0 and frequency_ghz > 0:
        f1_ft = 17.32 * np.sqrt(np.clip(d1_km * d2_km, 0, None) / (frequency_ghz * D_km)) * 3.28084
    else:
        f1_ft = np.zeros_like(dists)

    # (K factor, label, color, line width) for each terrain curve
    K_CONFIGS = [
        (np.inf,     "K=\u221e (flat earth)", "tab:green",  1.0),
        (4.0 / 3.0, "K=4/3 (standard)",       "tab:orange", 1.2),
        (0.5,        "K=1/2 (worst case)",     "tab:red",    1.2),
    ]

    # Mast offset computed early so draw_los_set can use it to clip LOS lines
    path_len_mi = dists[-1]
    mast_offset = path_len_mi * 0.008
    x_a = dists[0]  + mast_offset
    x_b = dists[-1] - mast_offset

    fig = plt.figure(figsize=(12, 9.0))
    try:
        ax = fig.add_axes((0.08, 0.30, 0.88, 0.62))

        # Terrain: original color scheme (green/orange/red per K factor)
        for k, _label, color, lw in K_CONFIGS:
            eff_elev = elevs_raw + earth_bulge_ft(k)
            if valid.any():
                ax.plot(dists[valid], eff_elev[valid], "-", color=color, lw=lw,
                        alpha=0.8, zorder=2)

        # Fill under K=inf terrain
        if valid.any():
            ax.fill_between(dists, np.nanmin(elevs_raw) - 50, elevs_raw,
                             color="tab:green", alpha=0.15, zorder=1)

        def draw_los_set(y0, y1, los_color, is_main=True, div_suffix=""):
            if y0 is None or y1 is None:
                return
            # LOS line clipped to mast positions so it terminates at each triangle
            ax.plot([x_a, x_b], [y0, y1], "-", color=los_color, lw=1.5, zorder=5)
            # Fresnel zones span the full path (zone boundaries, not endpoint lines)
            los_full = np.interp(dists, [dists[0], dists[-1]], [y0, y1])
            los_f1   = los_full - f1_ft
            los_06f1 = los_full - 0.6 * f1_ft
            ax.plot(dists, los_f1,   "-.", color="tab:blue", lw=1.0, zorder=5)
            ax.plot(dists, los_06f1, ":",  color="tab:cyan",  lw=1.2, zorder=5)

            # Markers reuse each terrain curve's own color.
            for k, _label, oc, _lw in K_CONFIGS:
                eff_elev = elevs_raw + earth_bulge_ft(k)
                obstructed = valid & (eff_elev > los_full)
                f1_viol = valid & (eff_elev > los_f1) & ~obstructed
                if obstructed.any():
                    ax.scatter(dists[obstructed], eff_elev[obstructed],
                               color=oc, marker="x", s=40, zorder=6, linewidths=1.5)
                if f1_viol.any():
                    ax.scatter(dists[f1_viol], eff_elev[f1_viol],
                               color=oc, marker="x", s=25, zorder=6,
                               alpha=0.6, linewidths=1.0)

        # Main LOS: antenna heights when known, else ground at the endpoints
        los_y0 = ant_a_main_amsl if ant_a_main_amsl is not None else (elevs_raw[0] if valid[0] else None)
        los_y1 = ant_b_main_amsl if ant_b_main_amsl is not None else (elevs_raw[-1] if valid[-1] else None)
        draw_los_set(los_y0, los_y1, "black", is_main=True)

        # Diversity LOS lines
        if ant_a_main_amsl is not None and ant_b_div_amsl is not None:
            draw_los_set(ant_a_main_amsl, ant_b_div_amsl, "#550077", div_suffix="B div")
        if ant_a_div_amsl is not None and ant_b_main_amsl is not None:
            draw_los_set(ant_a_div_amsl, ant_b_main_amsl, "#220044", div_suffix="A div")

        # Antenna masts
        def draw_site(d, ground_z, main_amsl, div_amsl):
            if ground_z is None or math.isnan(ground_z):
                return
            # Draw mast from ground to main antenna only -- line terminates
            # at the triangle tip, never extends beyond it.
            if main_amsl is not None:
                ax.plot([d, d], [ground_z, main_amsl], color="black", lw=1.5, zorder=7)
                ax.plot(d, main_amsl, "^", color="black", ms=7, zorder=8)
            # If diversity exists, draw the mast segment between main and
            # div (or from ground to div if main is absent). Line again
            # terminates at the div triangle.
            if div_amsl is not None:
                base = main_amsl if main_amsl is not None else ground_z
                ax.plot([d, d], [base, div_amsl], color="black", lw=1.5, zorder=7)
                ax.plot(d, div_amsl, "^", mfc="white", mec="black", ms=7, zorder=8)

        if valid[0]:
            draw_site(x_a, elevs_raw[0], ant_a_main_amsl, ant_a_div_amsl)
        if valid[-1]:
            draw_site(x_b, elevs_raw[-1], ant_b_main_amsl, ant_b_div_amsl)

        # Axis formatting
        ax.set_xlim(0, path_len_mi)
        ax.set_xlabel(f"Path length ({path_len_mi:.2f} mi)", fontsize=10)
        ax.set_ylabel("Elevation (ft AMSL)", fontsize=10)
        ax.set_title(f"Path Profile: {path_id}", fontsize=11, fontweight="bold")

        # Scale tick intervals to path length
        if path_len_mi <= 5:
            major_interval, minor_interval = 1.0, 0.25
        elif path_len_mi <= 15:
            major_interval, minor_interval = 2.0, 0.5
        elif path_len_mi <= 40:
            major_interval, minor_interval = 5.0, 1.0
        elif path_len_mi <= 100:
            major_interval, minor_interval = 10.0, 2.0
        else:
            major_interval, minor_interval = 20.0, 5.0

        ax.xaxis.set_major_locator(plt.MultipleLocator(major_interval))
        ax.xaxis.set_minor_locator(plt.MultipleLocator(minor_interval))
        ax.tick_params(axis="x", which="major", length=6, width=1)
        ax.tick_params(axis="x", which="minor", length=3, width=0.5)
        ax.grid(True, alpha=0.4, linestyle="-", linewidth=0.5)

        # --- Info boxes below plot (matching Nokia example) -------------------
        def fmt_amsl(v_ft):
            return f"{v_ft:.0f} ft ASL" if v_ft is not None else "N/A"

        def fmt_agl_both(main, div):
            """Format main (and diversity, if any) antenna AGL heights in feet."""
            parts = []
            if main is not None:
                parts.append(f"{main['antenna_height_agl_m'] * 3.28084:.1f}")
            if div is not None:
                parts.append(f"{div['antenna_height_agl_m'] * 3.28084:.1f}")
            return (", ".join(parts) + " ft AGL") if parts else "N/A"

        a_elev_ft = elevs_raw[0] if valid[0] else None
        b_elev_ft = elevs_raw[-1] if valid[-1] else None

        # Left box: Site A (lines whose value is unknown are dropped)
        left_lines = [
            site_name_a or name_a,
            f"Call Sign:  {name_a}",
            f"Latitude:   {lat_a:.6f}" if lat_a is not None else "",
            f"Longitude:  {lon_a:.6f}" if lon_a is not None else "",
            f"Azimuth:    {az_a:.2f}\u00b0" if az_a is not None else "",
            f"Elevation:  {fmt_amsl(a_elev_ft)}",
            f"Antenna CL: {fmt_agl_both(main_a, div_a)}",
        ]
        left_lines = [l for l in left_lines if l]

        # Center box: path parameters
        center_lines = [
            f"Frequency (MHz) = {frequency_ghz * 1000:.1f}",
            "K = 4/3, 1/2",
            "%F1 = 100.00, 60.00",
            f"Path length ({path_len_mi:.2f} mi)",
        ]

        # Right box: Site B
        right_lines = [
            site_name_b or name_b,
            f"Call Sign:  {name_b}",
            f"Latitude:   {lat_b:.6f}" if lat_b is not None else "",
            f"Longitude:  {lon_b:.6f}" if lon_b is not None else "",
            f"Azimuth:    {az_b:.2f}\u00b0" if az_b is not None else "",
            f"Elevation:  {fmt_amsl(b_elev_ft)}",
            f"Antenna CL: {fmt_agl_both(main_b, div_b)}",
        ]
        right_lines = [l for l in right_lines if l]

        box_props = dict(boxstyle="round,pad=0.5", facecolor="white", edgecolor="black", linewidth=0.8)

        fig.text(0.08, 0.20, "\n".join(left_lines), fontsize=8, va="top", ha="left",
                 fontfamily="monospace", bbox=box_props)
        # Center box moved lower so it doesn't overlap x-axis label
        fig.text(0.5, 0.20, "\n".join(center_lines), fontsize=8, va="top", ha="center",
                 fontfamily="monospace", bbox=box_props)
        fig.text(0.92, 0.20, "\n".join(right_lines), fontsize=8, va="top", ha="right",
                 fontfamily="monospace", bbox=box_props)

        fig.savefig(out_path, dpi=120, bbox_inches="tight")
    finally:
        # Release the figure even on failure; pyplot keeps every open
        # figure alive until it is explicitly closed.
        plt.close(fig)


def load_fcc_data(antennas_path, paths_path):
    """Load FCC antenna and path records from CSV files.

    Returns a dict with three keys:
      "antennas":     dict (call_sign, location_number, antenna_number) ->
                      {antenna_height_agl_m, antenna_make, antenna_model,
                      polarization}. Rows without a height are skipped.
      "by_location":  dict (call_sign, location_number) -> list of antenna
                      dicts (the same objects as in "antennas"), for
                      fallback lookups when a PA-referenced antenna_number
                      has no height.
      "path_records": dict frozenset({call_sign_a, call_sign_b}) -> list of
                      PA rows (the raw csv.DictReader dicts). Rows missing
                      either call sign are skipped.

    A path that is None/empty or does not exist contributes an empty dict.
    Short CSV rows (fewer cells than the header) are tolerated: their
    missing cells are treated as empty strings.
    """
    def cell(value):
        # csv.DictReader fills cells missing from a short row with None.
        return (value or "").strip()

    antennas = {}
    by_location = {}
    if antennas_path and os.path.exists(antennas_path):
        with open(antennas_path, newline="") as f:
            for row in csv.DictReader(f):
                cs = cell(row["call_sign"])
                loc = cell(row["location_number"])
                ant_num = cell(row["antenna_number"])
                h = cell(row.get("antenna_height_agl_m"))
                if not h:
                    continue
                d = {
                    "antenna_height_agl_m": float(h),
                    "antenna_make": cell(row.get("antenna_make")),
                    "antenna_model": cell(row.get("antenna_model")),
                    "polarization": cell(row.get("polarization")),
                }
                antennas[(cs, loc, ant_num)] = d
                by_location.setdefault((cs, loc), []).append(d)

    path_records = {}
    if paths_path and os.path.exists(paths_path):
        with open(paths_path, newline="") as f:
            for row in csv.DictReader(f):
                a, b = cell(row["call_sign"]), cell(row["receive_call_sign"])
                if not a or not b:
                    continue
                key = frozenset((a, b))
                path_records.setdefault(key, []).append(row)
    return {"antennas": antennas, "by_location": by_location, "path_records": path_records}


def find_path_antennas(call_sign_a, call_sign_b, fcc_data, debug=False):
    """Pick the main and diversity antennas at both ends of a path.

    Returns (main_a, div_a, main_b, div_b); each item is an antenna dict
    {antenna_height_agl_m, antenna_make, antenna_model, polarization} or
    None when the site has no such antenna.

    Rule (matches populate_summary_v2.py's find_main_div): for a site X
    in pair (X, Y), gather X's own TX antennas from the PA records where
    call_sign=X and receive_call_sign=Y. Antennas whose heights are less
    than 0.5 m apart (dual-polarization copies of one physical antenna)
    collapse to the first one seen. Among the distinct antennas left, the
    TALLEST is X's MAIN antenna and the next-tallest (if any) is X's
    DIVERSITY antenna.

    Only each site's own TX-side PA records are consulted, never the
    RX-side references (which are frequently dangling in AN.dat)."""
    antenna_index = fcc_data["antennas"]
    pair_records = fcc_data["path_records"].get(
        frozenset((call_sign_a, call_sign_b)), [])

    def height_of(ant):
        return ant["antenna_height_agl_m"]

    def ranked_tx_antennas(site, peer):
        distinct = []
        for rec in pair_records:
            if rec["call_sign"].strip() != site or rec["receive_call_sign"].strip() != peer:
                continue
            lookup = (site, rec["tx_location_number"].strip(),
                      rec["tx_antenna_number"].strip())
            ant = antenna_index.get(lookup)
            if not ant:
                continue
            is_duplicate = any(abs(height_of(ant) - height_of(kept)) < 0.5
                               for kept in distinct)
            if not is_duplicate:
                distinct.append(ant)
        return sorted(distinct, key=height_of, reverse=True)

    a_ants = ranked_tx_antennas(call_sign_a, call_sign_b)
    b_ants = ranked_tx_antennas(call_sign_b, call_sign_a)

    if debug:
        print(f"    [{call_sign_a} <-> {call_sign_b}] "
              f"A heights={[a['antenna_height_agl_m'] for a in a_ants]}, "
              f"B heights={[b['antenna_height_agl_m'] for b in b_ants]}")

    def main_and_div(ranked):
        # Pad so a site with 0 or 1 antennas yields None for the gaps.
        first, second = (ranked + [None, None])[:2]
        return first, second

    main_a, div_a = main_and_div(a_ants)
    main_b, div_b = main_and_div(b_ants)
    return main_a, div_a, main_b, div_b


def load_raat(path):
    """Load per-path RAAT data from ATT_Mobility_Network_Definition_vN.csv.

    Returns dict Path_ID -> {a_raat, b_raat, frequency_ghz, site_name_a,
    site_name_b}:
      - a_raat / b_raat: main antenna height AGL (m)
      - frequency_ghz: fr_frequency_assigned_MHz / 1000, or None when the
        column is missing, blank, unparsable or zero
      - site_name_a / site_name_b: a_desc / b_desc, stripped

    Returns an empty dict when `path` is None/empty or does not exist.
    Rows whose a_raat or b_raat is missing or not numeric are skipped --
    including short rows, whose absent cells csv.DictReader reports as
    None.
    """
    raat = {}
    if not path or not os.path.exists(path):
        return raat
    with open(path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            pid = (row.get("Path_ID") or "").strip()
            try:
                a = float(row["a_raat"])
                b = float(row["b_raat"])
            except (ValueError, KeyError, TypeError):
                # TypeError: cell is None because the row was short.
                continue
            try:
                freq_mhz = float(row["fr_frequency_assigned_MHz"])
            except (ValueError, KeyError, TypeError):
                freq_mhz = None
            raat[pid] = {
                "a_raat": a,
                "b_raat": b,
                "frequency_ghz": freq_mhz / 1000.0 if freq_mhz else None,
                "site_name_a": (row.get("a_desc") or "").strip(),
                "site_name_b": (row.get("b_desc") or "").strip(),
            }
    return raat


# ----------------------------------------------------------------------
# Main processing
# ----------------------------------------------------------------------
# Module-level dict used both for multiprocessing-worker state and (in
# sequential mode) just to track "have we printed a write-verification
# line yet" -- see _W["_verified"] in process_path below.
# Worker state keys (fallback_ds, usgs_per_path, fcc_data, raat, out_dir,
# interval, k_factor, default_frequency_ghz, debug) are filled in by
# _init_worker and read back by _worker_task.
_W = {}


def process_path(row, interval_m, usgs_per_path, fallback_ds, out_dir,
                  fcc_data=None, raat=None, debug=False,
                  frequency_ghz=6.0, k_factor=4.0 / 3.0):
    """Process one path row: sample elevations, then write KML, CSV and PNG.

    row:           CSV row dict for the path (ids, call signs, coordinates).
    interval_m:    sample spacing along the path, in meters.
    usgs_per_path: mapping Path_ID2 -> list of USGS tile URLs (or None).
    fallback_ds:   GLO-30 dataset used where USGS has no data.
    out_dir:       base directory holding the kml/, tables/ and plots/
                   subdirectories (they are not created here).
    fcc_data:      result of load_fcc_data, or None.
    raat:          result of load_raat, or None.
    frequency_ghz: default frequency, overridden by the per-path RAAT
                   frequency when one is available.

    Returns (base_name, n_points, n_usgs, n_glo30, n_none).
    """
    ref_id = row["Ref_ID"]
    path_id = row["Path_ID2"]
    name_a = row["Site_A_Call_Sign"]
    name_b = row["Site_B_Call_Sign"]
    lat1, lon1 = float(row["Site_A_Latitude"]), float(row["Site_A_Longitude"])
    lat2, lon2 = float(row["Site_ B_Latitude"]), float(row["Site_B_Longitude"])

    # Per-path data from RAAT/v12 file; anything that is not a dict is
    # treated as "no entry".
    entry = (raat or {}).get(path_id, {})
    if not isinstance(entry, dict):
        entry = {}
    raat_a = entry.get("a_raat")
    raat_b = entry.get("b_raat")
    # Per-path frequency wins over the CLI default when present
    frequency_ghz = entry.get("frequency_ghz") or frequency_ghz
    # Site names: prefer v12 a_desc/b_desc, fall back to ATT_Network.csv columns
    site_name_a = entry.get("site_name_a") or row.get("Site_A_Name", "").strip()
    site_name_b = entry.get("site_name_b") or row.get("Site_B_Name", "").strip()

    # Timestamps after each stage, used for the optional timing report.
    stamps = [time.perf_counter()]
    pts = sample_points(lat1, lon1, lat2, lon2, interval_m)
    stamps.append(time.perf_counter())

    # Per-path raster set: only the USGS tiles recorded for this path in
    # the cached per-path mapping, NOT the full global tile list -- avoids
    # opening thousands of irrelevant /vsicurl/ tiles per query point.
    path_urls = (usgs_per_path or {}).get(path_id, [])
    primary_set = MultiCRSRasterSet(path_urls) if path_urls else None
    try:
        elevs, sources = get_elevations(pts, primary_set, fallback_ds)
    finally:
        if primary_set is not None:
            primary_set.close()
    stamps.append(time.perf_counter())

    if not fcc_data:
        fcc_data = {"antennas": {}, "by_location": {}, "path_records": {}}

    if debug:
        print(f"  [{path_id}]")
    main_a, div_a, main_b, div_b = find_path_antennas(name_a, name_b, fcc_data, debug)
    stamps.append(time.perf_counter())

    # MAIN antenna height AGL: RAAT wins whenever it has a value. The
    # FCC-derived main_a/main_b then only contribute make/model, and
    # div_a/div_b stay FCC-derived.
    if raat_a is not None:
        main_a = dict(main_a or {}, antenna_height_agl_m=raat_a)
    if raat_b is not None:
        main_b = dict(main_b or {}, antenna_height_agl_m=raat_b)

    def antenna_amsl(ground_elev, ant):
        # No antenna, or no ground elevation to stand it on -> unknown.
        known = ant is not None and not math.isnan(ground_elev)
        return ground_elev + ant["antenna_height_agl_m"] if known else None

    ground_a, ground_b = elevs[0], elevs[-1]
    ant_a_main_amsl = antenna_amsl(ground_a, main_a)
    ant_a_div_amsl = antenna_amsl(ground_a, div_a)
    ant_b_main_amsl = antenna_amsl(ground_b, main_b)
    ant_b_div_amsl = antenna_amsl(ground_b, div_b)

    # Compute geodesic azimuths for info boxes, normalised to [0, 360)
    az_a_to_b, az_b_to_a, _ = GEOD.inv(lon1, lat1, lon2, lat2)
    az_a_to_b = az_a_to_b % 360
    az_b_to_a = az_b_to_a % 360

    base = path_id
    kml_path = os.path.join(out_dir, "kml", base + ".kml")
    table_path = os.path.join(out_dir, "tables", base + ".csv")
    png_path = os.path.join(out_dir, "plots", base + ".png")

    write_kml(path_id, name_a, name_b, lat1, lon1, lat2, lon2, pts, elevs, sources, kml_path)
    stamps.append(time.perf_counter())
    write_table(pts, elevs, sources, table_path)
    stamps.append(time.perf_counter())
    plot_profile(path_id, name_a, name_b, pts, elevs, sources, png_path,
                 ant_a_main_amsl=ant_a_main_amsl, ant_a_div_amsl=ant_a_div_amsl,
                 ant_b_main_amsl=ant_b_main_amsl, ant_b_div_amsl=ant_b_div_amsl,
                 main_a=main_a, div_a=div_a, main_b=main_b, div_b=div_b,
                 frequency_ghz=frequency_ghz, k_factor=k_factor,
                 site_name_a=site_name_a, site_name_b=site_name_b,
                 az_a=az_a_to_b, az_b=az_b_to_a,
                 lat_a=lat1, lon_a=lon1, lat_b=lat2, lon_b=lon2)
    stamps.append(time.perf_counter())

    # Confirm the PNG landed where expected, with a fresh mtime. Printed
    # for the first path handled in this process, and always when debug
    # or the PROFILE_TIMING / VERIFY_WRITE environment variables are set.
    timing_on = debug or os.environ.get("PROFILE_TIMING")
    if timing_on or os.environ.get("VERIFY_WRITE") or not _W.get("_verified"):
        try:
            st = os.stat(png_path)
            print(f"  [{path_id}] wrote {png_path} "
                  f"({st.st_size} bytes, mtime={datetime.fromtimestamp(st.st_mtime):%Y-%m-%d %H:%M:%S})")
        except OSError as e:
            print(f"  [{path_id}] WARNING: could not stat {png_path}: {e}")
        _W["_verified"] = True

    if timing_on:
        sample_s, elev_s, antenna_s, kml_s, table_s, plot_s = (
            later - earlier for earlier, later in zip(stamps, stamps[1:]))
        print(f"  [{path_id}] timing: sample={sample_s:.2f}s elev={elev_s:.2f}s "
              f"antenna={antenna_s:.2f}s kml={kml_s:.2f}s table={table_s:.2f}s plot={plot_s:.2f}s "
              f"total={stamps[-1]-stamps[0]:.2f}s")

    n_usgs = sources.count("usgs_1m")
    n_glo = sources.count("glo30")
    n_none = sources.count("none")
    return base, len(pts), n_usgs, n_glo, n_none


# ----------------------------------------------------------------------
# Multiprocessing worker support
# ----------------------------------------------------------------------
# Each worker process needs its own GDAL/rasterio handles -- these can't
# be shared/pickled across processes, so they're (re)opened once per
# worker via _init_worker and stashed in module-level _W (defined above).


def _init_worker(glo30_vrt, usgs_per_path, fcc_data, raat, out_dir,
                 interval, k_factor, default_frequency_ghz, debug):
    """Populate this process's module-level ``_W`` state.

    Used as the ``ProcessPoolExecutor`` initializer. The GLO-30 fallback
    dataset is opened here, inside the worker, and every other argument is
    stored under a key of the same name so that ``_worker_task`` can read
    it back later.
    """
    _W.update(
        fallback_ds=rasterio.open(glo30_vrt),
        usgs_per_path=usgs_per_path,
        fcc_data=fcc_data,
        raat=raat,
        out_dir=out_dir,
        interval=interval,
        k_factor=k_factor,
        default_frequency_ghz=default_frequency_ghz,
        debug=debug,
    )


def _worker_task(row):
    """Run ``process_path`` for one CSV row using the state cached in ``_W``.

    Exceptions derived from ``Exception`` are not propagated; the outcome is
    reported as a tuple instead:

      ("ok",  row, result)                      on success
      ("err", row, (type name, message, tb))    on failure

    where ``tb`` is the formatted traceback text.
    """
    # Keys of _W, in the positional order process_path takes them after `row`.
    state_keys = (
        "interval", "usgs_per_path", "fallback_ds", "out_dir",
        "fcc_data", "raat", "debug",
        "default_frequency_ghz", "k_factor",
    )
    try:
        # The lookups stay inside the try so a missing key is reported as an
        # "err" outcome like any other failure.
        outcome = process_path(row, *[_W[key] for key in state_keys])
    except Exception as exc:
        return ("err", row, (type(exc).__name__, str(exc), traceback.format_exc()))
    return ("ok", row, outcome)



def main():
    """Command-line entry point: profile every path listed in the input CSV.

    Steps, in order:
      1. Parse arguments and make the output/VRT/tile-list paths absolute.
      2. Create the kml/, tables/ and plots/ sub-directories and start a
         fresh errors.csv (kept across runs unless --overwrite is given).
      3. Read the CSV, keeping rows whose Ref_ID is all digits, then apply
         --start-id / --end-id / --limit.
      4. Build (or reuse) the GLO-30 VRT and, unless --skip-usgs, the USGS
         tile list.
      5. Load the FCC antenna data and the RAAT table.
      6. Drop paths whose three output files already exist (unless
         --overwrite) and process the rest, sequentially or in a
         ProcessPoolExecutor when --workers > 1.
      7. Print start/end timestamps and the elapsed time.

    Per-path failures are printed and appended to errors.csv; they do not
    stop the run.
    """
    start_time = datetime.now()
    print(f"Run started: {start_time:%Y-%m-%d %H:%M:%S}")

    ap = argparse.ArgumentParser()
    ap.add_argument("csv_file")
    ap.add_argument("--interval", type=float, default=80.0, help="sample interval (m)")
    ap.add_argument("--out", default="output", help="output directory")
    ap.add_argument("--start-id", type=int, default=None, help="min Ref_ID (inclusive)")
    ap.add_argument("--end-id", type=int, default=None, help="max Ref_ID (inclusive)")
    ap.add_argument("--overwrite", action="store_true",
                    help="reprocess paths even if output files already exist "
                         "(default: skip them, for resumable runs)")
    ap.add_argument("--skip-usgs", action="store_true",
                    help="don't build/use the USGS 3DEP 1m VRT; use GLO-30 for everything")
    ap.add_argument("--glo30-vrt", default="glo30_remote.vrt")
    ap.add_argument("--usgs-tile-list", default="usgs_tiles.txt")
    ap.add_argument("--rebuild-vrt", action="store_true",
                    help="rebuild glo30-vrt and usgs-tile-list from scratch even if they "
                         "already exist (use this if you previously built them for a "
                         "smaller subset of paths, e.g. via --limit)")
    ap.add_argument("--limit", type=int, default=None,
                    help="only process the first N paths (after any --start-id/--end-id "
                         "filtering) -- handy for quick tests")
    ap.add_argument("--fcc-antennas", default="fcc_antennas.csv",
                    help="fcc_antennas.csv from fcc_antenna_heights.py")
    ap.add_argument("--fcc-paths", default="fcc_paths.csv",
                    help="fcc_paths.csv (from PA.dat) from fcc_antenna_heights.py")
    ap.add_argument("--raat-csv", default="ATT_Mobility_Network_Definition_v9.csv",
                    help="CSV with Path_ID, a_raat, b_raat columns -- used as the "
                         "authoritative Main Antenna Height AGL for every path")
    ap.add_argument("--debug-antennas", action="store_true",
                    help="print antenna-matching diagnostics for each path")
    ap.add_argument("--frequency-ghz", type=float, default=6.0,
                    help="operating frequency in GHz, used for Fresnel zone radius (default 6.0)")
    ap.add_argument("--k-factor", type=float, default=4.0 / 3.0,
                    help="effective earth radius factor k for earth-curvature bulge (default 4/3)")
    ap.add_argument("--workers", type=int, default=1,
                    help="number of parallel worker processes (default 1 = sequential)")
    args = ap.parse_args()

    # Absolute paths, so worker processes and log messages are unambiguous
    # about where files live regardless of the current directory.
    args.out = os.path.abspath(args.out)
    args.glo30_vrt = os.path.abspath(args.glo30_vrt)
    args.usgs_tile_list = os.path.abspath(args.usgs_tile_list)
    if args.raat_csv:
        args.raat_csv = os.path.abspath(args.raat_csv)
    print(f"Output directory: {args.out}")

    for sub in ("kml", "tables", "plots"):
        os.makedirs(os.path.join(args.out, sub), exist_ok=True)

    # errors.csv accumulates across resumed runs; it is only reset when it
    # does not exist yet or --overwrite was requested.
    error_log_path = os.path.join(args.out, "errors.csv")
    if not os.path.exists(error_log_path) or args.overwrite:
        with open(error_log_path, "w", newline="") as ef:
            csv.writer(ef).writerow(["Ref_ID", "Path_ID2", "error_type", "error_message", "traceback"])

    with open(args.csv_file, newline="") as f:
        # DictReader fills the columns missing from a short row with None
        # (not with .get()'s default), so coalesce to "" before stripping.
        rows = [r for r in csv.DictReader(f)
                if (r.get("Ref_ID") or "").strip().isdigit()]

    if args.start_id is not None:
        rows = [r for r in rows if int(r["Ref_ID"]) >= args.start_id]
    if args.end_id is not None:
        rows = [r for r in rows if int(r["Ref_ID"]) <= args.end_id]

    if args.limit is not None:
        rows = rows[:args.limit]

    # NOTE(review): "Site_ B_Latitude" has a space after the underscore --
    # presumably this matches the input CSV header verbatim; confirm against
    # the data file before "fixing" it.
    bbox_rows = [
        (float(r["Site_A_Latitude"]), float(r["Site_A_Longitude"]),
         float(r["Site_ B_Latitude"]), float(r["Site_B_Longitude"]))
        for r in rows
    ]

    # --- Build (or reuse) GLO-30 VRT and USGS tile list ----------------
    build_glo30_vrt(bbox_rows, args.glo30_vrt, force=args.rebuild_vrt)
    usgs_urls, usgs_per_path = [], {}
    if not args.skip_usgs:
        usgs_urls, usgs_per_path = build_usgs_tile_list(rows, args.usgs_tile_list, force=args.rebuild_vrt)

    # --- Process paths -------------------------------------------------
    print(f"Processing {len(rows)} path(s) at {args.interval} m intervals.")

    if usgs_urls:
        print(f"  primary (USGS 3DEP 1m): {len(usgs_urls)} tile(s), multi-CRS, sampled individually")
    else:
        print("  primary: none (using GLO-30 for everything)")

    print(f"  fallback (Copernicus GLO-30): {args.glo30_vrt}")

    fcc_data = load_fcc_data(args.fcc_antennas, args.fcc_paths)
    if fcc_data["antennas"]:
        print(f"  FCC antenna data: {len(fcc_data['antennas'])} antennas, "
              f"{len(fcc_data['path_records'])} unique site-pairs")

    raat = load_raat(args.raat_csv)
    print(f"  RAAT/v12 data: {len(raat)} path entries from {args.raat_csv} "
          f"(raat, frequency, site names)")

    # Filter out paths already done (unless --overwrite)
    todo = []
    for row in rows:
        base = row["Path_ID2"]
        csv_path = os.path.join(args.out, "tables", base + ".csv")
        png_path = os.path.join(args.out, "plots", base + ".png")
        kml_path = os.path.join(args.out, "kml", base + ".kml")
        if not args.overwrite and all(os.path.exists(p) for p in (csv_path, png_path, kml_path)):
            print(f"{base}: already done, skipping")
            continue
        todo.append(row)

    n_total = len(rows)
    n_todo = len(todo)
    n_skip = n_total - n_todo
    if n_skip:
        print(f"Skipped {n_skip} already-done path(s); {n_todo} to process.")

    def log_result(i, status, row, payload):
        """Print one progress line; on error also append a row to errors.csv."""
        # The counter runs over the paths actually processed in this run, so
        # the denominator is n_todo (skipped paths are never counted in i).
        if status == "ok":
            base, n, n_usgs, n_glo, n_none = payload
            print(f"[{i}/{n_todo}] {base}: {n} pts (1m={n_usgs}, glo30={n_glo}, missing={n_none})")
        else:
            err_type, err_msg, tb = payload
            print(f"[{i}/{n_todo}] {row.get('Path_ID2')}: ERROR {err_msg}")
            with open(error_log_path, "a", newline="") as ef:
                w = csv.writer(ef)
                # Flatten the traceback so each failure stays on one CSV line.
                w.writerow([row.get("Ref_ID"), row.get("Path_ID2"), err_type, err_msg,
                            tb.replace("\n", " | ")])

    if args.workers > 1 and todo:
        print(f"Using {args.workers} parallel worker process(es).")
        with ProcessPoolExecutor(
                max_workers=args.workers, initializer=_init_worker,
                initargs=(args.glo30_vrt, usgs_per_path, fcc_data, raat, args.out,
                          args.interval, args.k_factor, args.frequency_ghz, args.debug_antennas)
        ) as ex:
            futures = {ex.submit(_worker_task, row): row for row in todo}
            for i, fut in enumerate(as_completed(futures), 1):
                try:
                    status, row, payload = fut.result()
                except Exception as e:
                    # _worker_task already traps ordinary exceptions, so this
                    # only fires when the future itself failed (e.g. the
                    # worker process died or its initializer raised). Record
                    # it against the submitted row instead of aborting the
                    # run with nothing written to errors.csv.
                    status, row = "err", futures[fut]
                    payload = (type(e).__name__, str(e), traceback.format_exc())
                log_result(i, status, row, payload)
    else:
        with rasterio.open(args.glo30_vrt) as fallback_ds:
            for i, row in enumerate(todo, 1):
                try:
                    result = process_path(
                        row, args.interval, usgs_per_path, fallback_ds, args.out,
                        fcc_data, raat, args.debug_antennas,
                        args.frequency_ghz, args.k_factor)
                    log_result(i, "ok", row, result)
                except Exception as e:
                    log_result(i, "err", row, (type(e).__name__, str(e), traceback.format_exc()))

    end_time = datetime.now()
    elapsed = end_time - start_time
    # format as HH:MM:SS, allowing for >24h runs
    total_seconds = int(elapsed.total_seconds())
    hh, rem = divmod(total_seconds, 3600)
    mm, ss = divmod(rem, 60)
    print(f"Run ended:   {end_time:%Y-%m-%d %H:%M:%S}")
    print(f"Elapsed:     {hh:02d}:{mm:02d}:{ss:02d}")


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
