"""
Generates DAT_A and DAT_B CSV files for each path, replacing the original
PostgreSQL/PostGIS calproject_6 pipeline.

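For each row of the network-definition CSV (e.g. ATT_Mobility_Network_Definition_v8.csv;
main() defaults to Oncor_Network_Definition_v1.csv):
  1. Read the VEN CSV (building data)
  2. Clean fields (maxheight 0/null -> 3.048, replace spaces/commas with '_')
  3. Build calproject_6-equivalent table:
     - bldg point (lon/lat), distances to A and B receivers
     - bldg mean elevation sampled from the per-path DEM tif, falling back to
       the GLO-30 VRT and then to the VEN meanelev value
     - a_line / b_line as WKT LINESTRING (lon/lat)
     - a_rcvr_path_uuid / b_rcvr_path_uuid
  4. Read the KML rectangle (4 corners)
  5. Split the rectangle perpendicular to its long axis through the A receiver and
     keep the half toward B -> ADirection;
     split again through the B receiver and keep the half toward A -> BDirection
  6. Keep buildings whose point falls in the corresponding half and that lie
     within 50 miles (MAX_DISTANCE_KM) of that direction's receiver
  7. Write DAT_A.csv and DAT_B.csv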

Requires: pandas, numpy, rasterio, shapely, pyproj
    pip install pandas numpy rasterio shapely pyproj --break-system-packages   (if needed)
"""

import math
import os
import re
import xml.etree.ElementTree as ET

import rasterio
import numpy as np
import pandas as pd
from pyproj import Transformer
from shapely.geometry import Point, Polygon, LineString
from shapely.ops import split as shapely_split, transform as shapely_transform

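# ---------------------------------------------------------------------------
# Configuration: distance cutoff and elevation data sources.
# (The CRS transformer 4326 (lon/lat) -> 2163 (US National Atlas Equal Area,
# meters), _to_2163, is defined further below, next to to_2163.)
# ---------------------------------------------------------------------------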
MAX_DISTANCE_KM = 1.609344 * 50  # 50 statute miles, expressed in km
GLO30_VRT = r"d:\att\profiles\glo30_remote.vrt"  # GLO-30 fallback elevation raster

# Directories holding the per-path DEM tifs (named <Path_ID>.tif).
# DEM_DIR is the one build_calproject_6 actually reads; point it at
# ONCOR_DEM_DIR for Oncor runs.
ATT_DEM_DIR = r"c:\users\public\clipped_dems"
ONCOR_DEM_DIR = r"d:\oncor\clipped_dems"
DEM_DIR = ATT_DEM_DIR


def _is_finite_coord(value):
    """Return True if *value* is a finite number usable as a raster coordinate."""
    try:
        return math.isfinite(value)
    except TypeError:  # None or non-numeric
        return False


def _nodata_mask(values, nodata):
    """
    Return a boolean mask (same shape as *values*) that is True for no-data.

    NaN always counts as no-data; when the raster declares a finite ``nodata``
    value, entries equal to it count as well.
    """
    arr = np.asarray(values)
    mask = np.isnan(arr)
    if nodata is not None and not np.isnan(nodata):
        mask = mask | (arr == nodata)
    return mask


def _window_bounds(center, radius, size):
    """Clamp [center - radius, center + radius] to a (start, stop) slice of an axis of length *size*."""
    # Clamp both ends: a negative stop would be read by numpy as "from the end"
    # and select almost the whole axis for points lying outside the raster.
    start = min(max(0, center - radius), size)
    stop = max(start, min(size, center + radius + 1))
    return start, stop


def sample_elevations(tif_path, lons, lats, search_radius=2, fallback_vrt=None):
    """
    Sample elevation for many (lon, lat) points from a single DEM tif.

    Lookup order for each point:
      1. the exact pixel of ``tif_path``;
      2. if that pixel is no-data, the mean of the valid pixels in the
         smallest square window (half-width 1 .. ``search_radius`` pixels)
         that contains any;
      3. if still unresolved, the pixel of ``fallback_vrt`` (GLO-30),
         rounded to 4 decimals.

    A pixel is no-data when it is NaN or equals the raster's declared nodata
    value. Points whose lon or lat is missing or non-finite are never sampled
    and stay None. Raster files that do not exist are skipped; errors while
    reading the fallback VRT are printed and otherwise ignored.

    Args:
        tif_path: path of the per-path DEM GeoTIFF (may not exist).
        lons, lats: equal-length lists of x/y coordinates. The caller passes
            lon/lat, so this assumes the rasters are in a lon/lat CRS — TODO
            confirm for each DEM source.
        search_radius: largest window half-width, in pixels, searched around a
            no-data pixel.
        fallback_vrt: optional path of a secondary raster (GLO-30 VRT).

    Returns:
        list with one float (or None) per input point.
    """
    results = [None] * len(lons)
    sampleable = [
        i for i, (lon, lat) in enumerate(zip(lons, lats))
        if _is_finite_coord(lon) and _is_finite_coord(lat)
    ]

    if sampleable and os.path.exists(tif_path):
        with rasterio.open(tif_path) as src:
            nodata = src.nodata
            pts = [(lons[i], lats[i]) for i in sampleable]
            samples = list(src.sample(pts))
            band = None  # full band is read lazily, only if a neighbourhood search is needed
            for i, (lon, lat), s in zip(sampleable, pts, samples):
                val = s[0]
                if not _nodata_mask(val, nodata):
                    results[i] = float(val)
                    continue

                if band is None:
                    band = src.read(1)
                row_idx, col_idx = src.index(lon, lat)
                for r in range(1, search_radius + 1):
                    r0, r1 = _window_bounds(row_idx, r, band.shape[0])
                    c0, c1 = _window_bounds(col_idx, r, band.shape[1])
                    window = band[r0:r1, c0:c1]
                    valid = window[~_nodata_mask(window, nodata)]
                    if valid.size > 0:
                        results[i] = float(valid.mean())
                        break

    # For any still-None (but sampleable) points, fall back to the GLO-30 VRT.
    still_blank = [i for i in sampleable if results[i] is None]
    if still_blank and fallback_vrt and os.path.exists(fallback_vrt):
        fb_pts = [(lons[i], lats[i]) for i in still_blank]
        try:
            with rasterio.open(fallback_vrt) as src:
                nodata = src.nodata
                samples = list(src.sample(fb_pts))
                for idx, s in zip(still_blank, samples):
                    val = s[0]
                    if not _nodata_mask(val, nodata):
                        results[idx] = round(float(val), 4)
        except Exception as e:  # best effort: a broken VRT must not abort the path
            print(f"    GLO-30 fallback error: {e}")

    return results

# Shared CRS transformer: EPSG:4326 (lon/lat) -> EPSG:2163 (US National Atlas
# Equal Area, meters). always_xy=True keeps (lon, lat) / (x, y) axis order.
_to_2163 = Transformer.from_crs(crs_from="EPSG:4326", crs_to="EPSG:2163", always_xy=True)


def to_2163(lon, lat):
    """Project a lon/lat (EPSG:4326) pair to EPSG:2163 and return ``(x, y)`` in meters."""
    easting, northing = _to_2163.transform(lon, lat)
    return (easting, northing)


def point_2163(lon, lat):
    """Return a shapely Point holding the EPSG:2163 projection of (lon, lat)."""
    return Point(*to_2163(lon, lat))


def dist_km_2163(lon1, lat1, lon2, lat2):
    """Planar distance in kilometres between two lon/lat points, measured in EPSG:2163."""
    meters = point_2163(lon1, lat1).distance(point_2163(lon2, lat2))
    return meters / 1000.0


# ---------------------------------------------------------------------------
# KML parsing
# ---------------------------------------------------------------------------
def _kml_local_name(tag):
    """Return *tag* without its '{namespace}' prefix ('' for non-element nodes)."""
    return tag.rsplit("}", 1)[-1] if isinstance(tag, str) else ""


def read_kml_rectangle(kml_path):
    """
    Read the corner coordinates of the first Polygon in a KML file.

    Elements are matched by local name, so files using the KML 2.2 namespace,
    another KML namespace (e.g. older Google Earth ones) or no namespace at all
    are all accepted.

    Returns:
        list of 4 (lon, lat) tuples: the first four vertices listed, after
        dropping a closing vertex that repeats the first one.

    Raises:
        ValueError: if no Polygon <coordinates> text is found, or fewer than
            4 vertices are listed.
    """
    root = ET.parse(kml_path).getroot()

    # First <coordinates> under a <Polygon>, in document order.
    coords_el = None
    for el in root.iter():
        if _kml_local_name(el.tag) != "Polygon":
            continue
        coords_el = next(
            (sub for sub in el.iter() if _kml_local_name(sub.tag) == "coordinates"),
            None,
        )
        if coords_el is not None:
            break

    text = coords_el.text.strip() if coords_el is not None and coords_el.text else ""
    if not text:
        raise ValueError(f"No Polygon <coordinates> found in KML: {kml_path}")

    # Each whitespace-separated token is "lon,lat[,alt]"; altitude is ignored.
    pts = []
    for tok in text.split():
        parts = tok.split(",")
        pts.append((float(parts[0]), float(parts[1])))

    # drop closing point if it duplicates the first
    if len(pts) >= 5 and pts[0] == pts[-1]:
        pts = pts[:-1]
    if len(pts) < 4:
        raise ValueError(f"KML polygon has {len(pts)} vertices, expected 4: {kml_path}")
    return pts[:4]


# ---------------------------------------------------------------------------
# split_bounding_box equivalent
# ---------------------------------------------------------------------------
# EPSG:2163 -> EPSG:4326 transformer used by split_bounding_box. It is created
# on first use and then re-used, instead of being rebuilt on every call
# (constructing a pyproj Transformer is comparatively expensive).
_to_4326_shared = None


def _get_to_4326():
    """Return the shared 2163 -> 4326 transformer, creating it on first use."""
    global _to_4326_shared
    if _to_4326_shared is None:
        _to_4326_shared = Transformer.from_crs("EPSG:2163", "EPSG:4326", always_xy=True)
    return _to_4326_shared


def split_bounding_box(box_points_lonlat, internal_point_lonlat):
    """
    Split a rectangle with a line through an internal point.

    box_points_lonlat: list of 4 (lon, lat) tuples, sequential rectangle corners
    internal_point_lonlat: (lon, lat) point inside the box

    The split line passes through internal_point, is perpendicular to the
    rectangle's longer side (compared between sides p1-p2 and p2-p3), and is
    made 10x the box's largest extent on each side so it crosses the whole box.
    All geometry is done in EPSG:2163 meters, mirroring the original PostGIS
    ST_Distance / ST_Azimuth / ST_Project calls.

    Returns list of shapely Polygons (in lon/lat, EPSG:4326) resulting from the
    split. If the line does not cross the box, shapely returns the box unsplit.
    """
    p_m = [point_2163(lon, lat) for lon, lat in box_points_lonlat]
    p1, p2, p3, p4 = p_m

    l1 = p1.distance(p2)
    l2 = p2.distance(p3)

    max_extent = max(l1, l2, p1.distance(p3), p2.distance(p4))
    far_dist = max_extent * 10.0

    # ST_Azimuth(a,b) = atan2(dx, dy) -- angle from north, clockwise (in radians)
    def azimuth(a, b):
        return math.atan2(b.x - a.x, b.y - a.y)

    # ST_Project(point, distance, azimuth) in a planar CRS
    def project(pt, distance, az):
        return Point(pt.x + distance * math.sin(az), pt.y + distance * math.cos(az))

    longest_az = azimuth(p1, p2) if l1 >= l2 else azimuth(p2, p3)
    perp_az = longest_az + math.pi / 2.0

    ip = point_2163(*internal_point_lonlat)
    split_line_m = LineString([
        project(ip, far_dist, perp_az),
        project(ip, far_dist, perp_az + math.pi),
    ])

    box_poly_m = Polygon([(p.x, p.y) for p in p_m] + [(p1.x, p1.y)])
    pieces_m = shapely_split(box_poly_m, split_line_m)

    # Transform polygon pieces back to lon/lat; the transformer's bound
    # ``transform`` method accepts the coordinate sequences shapely passes.
    to_lonlat = _get_to_4326().transform
    return [
        shapely_transform(to_lonlat, geom)
        for geom in pieces_m.geoms
        if geom.geom_type == "Polygon"
    ]


# ---------------------------------------------------------------------------
# VEN cleaning
# ---------------------------------------------------------------------------
def _underscore_chars(series, chars):
    """Replace every character in *chars* with '_' in a text column, keeping missing values missing."""
    present = series.notna()
    cleaned = series.astype(str)
    for ch in chars:
        cleaned = cleaned.str.replace(ch, "_", regex=False)
    # astype(str) would otherwise turn NaN/None into the literal text "nan"/"None".
    return cleaned.where(present)


def clean_ven(df, default_height=3.048):
    """
    Return a cleaned copy of a VEN building DataFrame (the input is not modified).

    - maxheight: coerced to numeric; 0, missing or unparseable values become
      ``default_height`` (3.048 by default, presumably 10 ft in meters).
    - address (optional column): spaces and commas become '_'; missing, empty,
      or the literal strings "nan"/"None" become None.
    - bld_uuid (optional column): commas become '_'; missing values stay
      missing rather than becoming the text "nan".
    - latitude, longitude, meanelev: coerced to numeric (unparseable -> NaN).
    """
    df = df.copy()

    heights = pd.to_numeric(df["maxheight"], errors="coerce")
    df["maxheight"] = heights.mask(heights == 0).fillna(default_height)

    if "address" in df.columns:
        df["address"] = _underscore_chars(df["address"], " ,")
        df.loc[df["address"].isin(["nan", "None", ""]), "address"] = None

    if "bld_uuid" in df.columns:
        df["bld_uuid"] = _underscore_chars(df["bld_uuid"], ",")

    for col in ("latitude", "longitude", "meanelev"):
        df[col] = pd.to_numeric(df[col], errors="coerce")

    return df


# ---------------------------------------------------------------------------
# Build calproject_6-equivalent dataframe for one path
# ---------------------------------------------------------------------------
def build_calproject_6(ven_df, row):
    """
    Build the calproject_6-equivalent building/receiver table for one path.

    Each VEN building becomes one row carrying its location, height, sampled
    mean elevation and WKT geometries, plus the A and B receiver attributes,
    building-to-receiver distances (km, EPSG:2163) and sight lines.

    Args:
        ven_df: raw VEN building DataFrame; cleaned here with clean_ven. Uses
            the gid, longitude, latitude, maxheight and meanelev columns, plus
            address when present.
        row: network-definition row providing Path_ID (selects
            <DEM_DIR>/<Path_ID>.tif) and the receiver fields A_Call_Sign,
            A_Longitude, A_Latitude, a_meanelev, a_raat, a_desc and
            b_Call_Sign, b_Longitude, b_Latitude, b_meanelev, b_raat, b_desc.

    Returns:
        tuple (table, bldg_pts_2163): the table with columns in calproject_6
        SELECT order, and the per-building (x, y) EPSG:2163 coordinates in
        the same row order.
    """
    df = clean_ven(ven_df)

    out = pd.DataFrame()
    out["bldg_id"] = df["gid"]
    out["bldg_long"] = df["longitude"].round(6)
    out["bldg_lat"] = df["latitude"].round(6)
    out["bldg_maxheight"] = df["maxheight"].round(4)
    # clean_ven treats "address" as optional; do the same here instead of
    # raising KeyError for a VEN file without an address column.
    out["bldg_address"] = df["address"] if "address" in df.columns else None

    # Mean elevation: per-path DEM TIF first (sample_elevations itself falls
    # back to the GLO-30 VRT), then the VEN file's own meanelev value.
    tif_path = os.path.join(DEM_DIR, f"{row['Path_ID']}.tif")
    dem_elevs = sample_elevations(
        tif_path, out["bldg_long"].tolist(), out["bldg_lat"].tolist(), fallback_vrt=GLO30_VRT
    )
    bldg_meanelev = []
    for dem_val, ven_val in zip(dem_elevs, df["meanelev"]):
        if dem_val is not None:
            bldg_meanelev.append(round(dem_val, 4))
        elif pd.notna(ven_val):
            bldg_meanelev.append(round(float(ven_val), 4))
        else:
            bldg_meanelev.append(None)
    out["bldg_meanelevation"] = bldg_meanelev

    out["bldg_point_4326"] = [
        f"POINT({lon} {lat})" for lon, lat in zip(out["bldg_long"], out["bldg_lat"])
    ]

    bldg_pts_2163 = [to_2163(lon, lat) for lon, lat in zip(out["bldg_long"], out["bldg_lat"])]
    out["bldg_point_2163"] = [f"POINT({x} {y})" for x, y in bldg_pts_2163]

    def add_receiver(side, call_sign, lon, lat, meanelev, raat, desc):
        # Adds every "<side>_rcvr_*", "bldg_<side>_rcvr_distance_km" and
        # "<side>_line_4326" column for one receiver (side is "a" or "b").
        rcvr_pt_2163 = point_2163(lon, lat)
        out[f"{side}_rcvr_call_sign"] = call_sign
        out[f"{side}_rcvr_long"] = lon
        out[f"{side}_rcvr_lat"] = lat
        out[f"{side}_rcvr_meanelev"] = meanelev
        out[f"{side}_rcvr_raat"] = raat
        out[f"{side}_rcvr_desc"] = desc
        out[f"bldg_{side}_rcvr_distance_km"] = [
            round(Point(x, y).distance(rcvr_pt_2163) / 1000.0, 6) for x, y in bldg_pts_2163
        ]
        out[f"{side}_line_4326"] = [
            f"LINESTRING({blon} {blat},{lon} {lat})"
            for blon, blat in zip(out["bldg_long"], out["bldg_lat"])
        ]
        out[f"{side}_rcvr_path_uuid"] = out["bldg_id"].astype(str) + "." + str(call_sign)
        out[f"{side}_rcvr_point_4326"] = f"POINT({lon} {lat})"
        rx, ry = to_2163(lon, lat)
        out[f"{side}_rcvr_point_2163"] = f"POINT({rx} {ry})"

    a_long = float(row["A_Longitude"])
    a_lat = float(row["A_Latitude"])
    b_long = float(row["b_Longitude"])
    b_lat = float(row["b_Latitude"])

    add_receiver("a", row["A_Call_Sign"], a_long, a_lat,
                 float(row["a_meanelev"]), float(row["a_raat"]), row["a_desc"])
    add_receiver("b", row["b_Call_Sign"], b_long, b_lat,
                 float(row["b_meanelev"]), float(row["b_raat"]), row["b_desc"])

    # final column order (matches calproject_6 SELECT order)
    cols = [
        "bldg_id", "bldg_long", "bldg_lat", "bldg_meanelevation", "bldg_maxheight", "bldg_address",
        "bldg_point_4326", "bldg_point_2163",
        "a_rcvr_call_sign", "a_rcvr_long", "a_rcvr_lat", "a_rcvr_meanelev", "a_rcvr_raat", "a_rcvr_desc",
        "bldg_a_rcvr_distance_km", "a_line_4326",
        "a_rcvr_path_uuid", "a_rcvr_point_4326", "a_rcvr_point_2163",
        "b_rcvr_path_uuid",
        "b_rcvr_call_sign", "b_rcvr_long", "b_rcvr_lat", "b_rcvr_meanelev", "b_rcvr_raat", "b_rcvr_desc",
        "bldg_b_rcvr_distance_km", "b_line_4326",
        "b_rcvr_point_4326", "b_rcvr_point_2163",
    ]
    return out[cols], bldg_pts_2163


# ---------------------------------------------------------------------------
# Process a single path
# ---------------------------------------------------------------------------
def _pick_half(pieces, target):
    """Return the split piece containing *target*, else the piece nearest to it (None if no pieces)."""
    if not pieces:
        return None
    for piece in pieces:
        if piece.intersects(target):
            return piece
    # target lies outside every piece (e.g. the receiver is outside the KML
    # box): the nearest half is the one facing it, not an arbitrary one.
    return min(pieces, key=target.distance)


def _rows_in(frame, mask):
    """Return a copy of the rows of *frame* selected by the boolean list *mask*."""
    # A numpy bool array is always a row mask; an empty Python list would be
    # taken as an (empty) column selection and drop every column.
    return frame.loc[np.asarray(mask, dtype=bool)].copy()


def process_path(row, verbose=True, test_output_dir=None):
    """
    Generate the DAT_A and DAT_B CSV files for one network path.

    ADirection: the KML box is split through receiver A and the half toward B
    is kept; BDirection: split through receiver B, keep the half toward A.
    Buildings in each half are kept if they lie within MAX_DISTANCE_KM of that
    direction's receiver.

    Args:
        row: network-definition row with Path_ID, ven_file, kml_file,
            dat_a_output, dat_b_output and the receiver fields used by
            build_calproject_6.
        verbose: print progress and zero-building warnings.
        test_output_dir: if set, write to
            <test_output_dir>/<Path_ID>/<Path_ID>_dat_{a,b}_1.csv instead of
            the row's dat_a_output / dat_b_output.

    Returns:
        False if the VEN or KML file is missing; otherwise the tuple
        (True, n_a, n_b, n_a, n_b), where n_a / n_b are the building counts
        written to DAT_A / DAT_B (the two counts appear twice).
    """
    path_id = row["Path_ID"]
    ven_file = row["ven_file"]
    kml_file = row["kml_file"]

    if test_output_dir:
        dat_a_out = os.path.join(test_output_dir, path_id, f"{path_id}_dat_a_1.csv")
        dat_b_out = os.path.join(test_output_dir, path_id, f"{path_id}_dat_b_1.csv")
    else:
        dat_a_out = row["dat_a_output"]
        dat_b_out = row["dat_b_output"]

    if verbose:
        print(f"Processing {path_id} ...")

    if not os.path.exists(ven_file):
        print(f"  SKIP: VEN file not found: {ven_file}")
        return False
    if not os.path.exists(kml_file):
        print(f"  SKIP: KML file not found: {kml_file}")
        return False

    ven_df = pd.read_csv(ven_file, dtype=str, encoding="utf-8-sig")
    cal6, _ = build_calproject_6(ven_df, row)

    box_pts = read_kml_rectangle(kml_file)  # [(lon,lat) x4]

    a_long = float(row["A_Longitude"])
    a_lat = float(row["A_Latitude"])
    b_long = float(row["b_Longitude"])
    b_lat = float(row["b_Latitude"])
    a_pt = Point(a_long, a_lat)
    b_pt = Point(b_long, b_lat)

    a_poly = _pick_half(split_bounding_box(box_pts, (a_long, a_lat)), b_pt)
    b_poly = _pick_half(split_bounding_box(box_pts, (b_long, b_lat)), a_pt)

    bldg_points_4326 = [Point(lon, lat) for lon, lat in zip(cal6["bldg_long"], cal6["bldg_lat"])]
    mask_a = [a_poly is not None and a_poly.intersects(pt) for pt in bldg_points_4326]
    mask_b = [b_poly is not None and b_poly.intersects(pt) for pt in bldg_points_4326]

    a_direction = _rows_in(cal6, mask_a)
    b_direction = _rows_in(cal6, mask_b)

    # Drop buildings beyond 50 miles from the respective antenna
    a_direction = a_direction[a_direction["bldg_a_rcvr_distance_km"] <= MAX_DISTANCE_KM]
    b_direction = b_direction[b_direction["bldg_b_rcvr_distance_km"] <= MAX_DISTANCE_KM]

    for out_path in (dat_a_out, dat_b_out):
        parent = os.path.dirname(out_path)
        if parent:  # bare file name -> current directory; os.makedirs("") would raise
            os.makedirs(parent, exist_ok=True)

    a_direction.to_csv(dat_a_out, index=False)
    b_direction.to_csv(dat_b_out, index=False)

    n_a = len(a_direction)
    n_b = len(b_direction)

    if verbose:
        print(f"  ADirection: {n_a} buildings -> {dat_a_out}")
        print(f"  BDirection: {n_b} buildings -> {dat_b_out}")
        if n_a == 0:
            print(f"  WARNING: DAT_A has zero buildings for {path_id}")
        if n_b == 0:
            print(f"  WARNING: DAT_B has zero buildings for {path_id}")

    return True, n_a, n_b, n_a, n_b


# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------
def main(network_csv="Oncor_Network_Definition_v1.csv", limit=None,
         test_output_dir=None, test_worst=False, blank_report_csv=None,
         workers=8, dem_dir=ONCOR_DEM_DIR):
    """
    Run process_path for every selected row of a network-definition CSV.

    Args:
        network_csv: path of the network-definition CSV (one row per path).
        limit: if truthy, process only the first ``limit`` rows.
        test_output_dir: forwarded to process_path to redirect output files.
        test_worst: together with blank_report_csv, process only the 10
            path_ids with the highest blank_count in that report (takes
            precedence over limit).
        blank_report_csv: CSV with path_id and blank_count columns.
        workers: number of worker threads; 1 or less runs sequentially.
        dem_dir: if truthy, replaces the module-level DEM_DIR used by
            build_calproject_6.

    Progress is printed and appended to dat_run_<timestamp>.log in the
    current working directory.
    """
    import concurrent.futures
    import datetime
    import threading

    global DEM_DIR

    df = pd.read_csv(network_csv, dtype=str)

    # Override the DEM directory read by build_calproject_6. Rebind this
    # module's global directly: importing the module by name and patching it
    # would modify a separate module object when this file runs as a script
    # (and fail outright if the file has a different name).
    if dem_dir:
        DEM_DIR = dem_dir

    if test_worst and blank_report_csv:
        report = pd.read_csv(blank_report_csv)
        worst_ids = report.sort_values("blank_count", ascending=False)["path_id"].unique()[:10]
        df = df[df["Path_ID"].isin(worst_ids)].copy()
        print(f"Test mode: running {len(df)} worst-offender paths -> {test_output_dir}")
    elif limit:
        df = df.head(limit)

    rows = [row for _, row in df.iterrows()]
    total = len(rows)

    log_file = f"dat_run_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
    log_lock = threading.Lock()

    def log(msg):
        ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        line = f"[{ts}] {msg}"
        # Print under the same lock as the file write so lines from different
        # worker threads are not interleaved on the console.
        with log_lock:
            print(line)
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(line + "\n")

    lock = threading.Lock()
    counters = {"ok": 0, "fail": 0, "done": 0, "warn": 0}

    def run_one(row):
        try:
            result = process_path(row, test_output_dir=test_output_dir, verbose=False)
        except Exception as e:  # one bad path must not stop the whole run
            log(f"  ERROR on {row['Path_ID']}: {e}")
            result = False

        # process_path returns False when it skips a path, or a tuple whose
        # first three items are (success, n_a, n_b); it currently repeats the
        # counts after those, so unpack only the leading three items.
        if isinstance(result, tuple):
            success, n_a, n_b = result[:3]
        else:
            success, n_a, n_b = result, None, None

        with lock:
            counters["done"] += 1
            if success:
                counters["ok"] += 1
            else:
                counters["fail"] += 1
            if n_a == 0 or n_b == 0:
                counters["warn"] += 1
            done = counters["done"]
            ok = counters["ok"]
            fail = counters["fail"]
            warn = counters["warn"]

        status = "OK" if success else "SKIP/FAIL"
        msg = f"  [{done:4d}/{total}] {row['Path_ID']} -> {status}  (ok={ok}, fail={fail}, warn={warn})"
        if success and n_a is not None:
            msg += f"  A={n_a} B={n_b}"
        if n_a == 0:
            msg += "  *** WARNING: DAT_A zero buildings ***"
        if n_b == 0:
            msg += "  *** WARNING: DAT_B zero buildings ***"
        log(msg)
        return success

    log(f"Starting run: {total} paths, {workers} workers, network={network_csv}")

    if workers > 1:
        log(f"Running {total} paths with {workers} parallel workers...\n")
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            list(executor.map(run_one, rows))
    else:
        log(f"Running {total} paths sequentially...\n")
        for row in rows:
            run_one(row)

    summary = f"Done. {counters['ok']} succeeded, {counters['fail']} failed/skipped, {counters['warn']} zero-building warnings out of {total}."
    log(summary)
    log(f"Log saved to: {log_file}")


if __name__ == "__main__":
    main()
