"""
att_worker.py
-------------
Headless worker process. Runs all queries in parallel using ProcessPoolExecutor.
Appends results to a JSON Lines file (one JSON object per line) that
att_monitor.py reads, and keeps a JSON status snapshot alongside it.
No GUI, no tkinter, no Qt — just pure processing.

Called by att_monitor.py automatically. Do not run directly.
"""

import csv
import json
import os
import sys
import logging
import multiprocessing
import concurrent.futures
from pathlib import Path
from datetime import datetime
from time import perf_counter

# Fix PROJ path: export PROJ_DATA / PROJ_LIB / GDAL_DATA pointing at
# "<root>/Library/share/<sub>", where <root> is two directory levels above the
# first `python` found on PATH. Best effort: a variable is only set when that
# directory exists, and nothing is changed if no python can be located.
import subprocess as _sp
try:
    _r = _sp.run(["where", "python"], capture_output=True, text=True)
    _py = _r.stdout.strip().split("\n")[0] if _r.stdout.strip() else ""
except OSError:
    # `where` only exists on Windows; without it leave the environment alone
    # rather than failing at import time.
    _py = ""
_cr = str(Path(_py).parent.parent) if _py else ""
for _k, _sub in [("PROJ_DATA","proj"),("PROJ_LIB","proj"),("GDAL_DATA","gdal")]:
    _p = os.path.join(_cr, "Library", "share", _sub)
    # Require a root: with _cr == "" the join would be a cwd-relative path.
    if _cr and os.path.exists(_p):
        os.environ[_k] = _p

import warnings

# Raise this process to HIGH_PRIORITY_CLASS through the Windows API.
# Best effort: where ctypes has no `windll` (non-Windows) or any call fails,
# the exception is ignored and the process keeps its default priority.
try:
    import ctypes
    _k32 = ctypes.windll.kernel32
    handle = _k32.OpenProcess(0x1F0FFF, False, _k32.GetCurrentProcessId())
    _k32.SetPriorityClass(handle, 0x00000080)  # HIGH_PRIORITY_CLASS
    _k32.CloseHandle(handle)
except Exception:
    pass

# Suppress every warning for the remainder of this process.
warnings.filterwarnings("ignore")

# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────

# -- Inputs -----------------------------------------------------------------
PATH_ID_CSV    = r"C:\Users\public\pathids.csv"   # CSV listing the path IDs
PATH_ID_COLUMN = "path_id"                        # ID column header (None = first column, no header)
DATA_ROOT      = r"D:\ATT\complete"               # holds <path_id>/<path_id>_DAT_*/_VEN_* CSVs
TIF_DIR        = r"C:\Users\public\clipped_dems"  # DEM GeoTIFFs, matched by file stem

# -- Outputs ----------------------------------------------------------------
OUTPUT_ROOT    = r"D:\ATT\complete"
QUERY_DIR      = r"D:\ATT\queries"

# -- Files shared with att_monitor.py --------------------------------------
RESULTS_FILE   = r"D:\ATT\complete\_results.jsonl"  # appended JSON lines
STATUS_FILE    = r"D:\ATT\complete\_status.json"    # overwritten snapshot

# -- Tuning -----------------------------------------------------------------
MAX_WORKERS    = 8    # size of the process pool
STEP_M         = 80   # spacing (metres) between terrain samples

# ─────────────────────────────────────────────
# LOGGING
# ─────────────────────────────────────────────

os.makedirs(OUTPUT_ROOT, exist_ok=True)
os.makedirs(QUERY_DIR,   exist_ok=True)

# Per-run log file; main() later prepends the run summary to it.
log_path = os.path.join(OUTPUT_ROOT,
                        f"run_log_{datetime.now():%Y%m%d_%H%M%S}.log")
log = logging.getLogger(__name__)

# With the "spawn" start method (the Windows default) every pool worker
# re-imports this module. Only the parent process configures handlers, so the
# workers do not each create their own empty run_log_*.log file.
if multiprocessing.parent_process() is None:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s  %(levelname)-8s  %(message)s",
        handlers=[
            logging.FileHandler(log_path, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
    )

# ─────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────

def _fmt(seconds):
    if seconds is None: return "ERROR"
    s = int(seconds)
    hrs, rem = divmod(s, 3600)
    mins, secs = divmod(rem, 60)
    return f"{hrs:02}:{mins:02}:{secs:02}"


def _fmt_precise(seconds):
    """Format with hundredths of a second, e.g. 00:00:00.07
    Used for individual query timings so fast queries don't
    appear as 00:00:00 (which looks like a failure)."""
    if seconds is None: return "ERROR"
    s = int(seconds)
    hrs, rem = divmod(s, 3600)
    mins, secs = divmod(rem, 60)
    frac = seconds - int(seconds)
    return f"{hrs:02}:{mins:02}:{secs:02}.{int(frac*100):02}"


def load_path_ids(csv_path, column):
    """Read path IDs from a CSV file.

    Args:
        csv_path: CSV file to read. A leading UTF-8 BOM (as Excel writes) is
            tolerated.
        column: header name of the ID column, or None to take the first
            field of every non-empty row (no header row is assumed).

    Returns:
        Stripped, non-empty IDs in file order, with repeats removed (the
        first occurrence wins).

    Raises:
        KeyError: if `column` is given but is not in the CSV header.
    """
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        if column is None:
            raw_ids = [row[0] for row in csv.reader(f) if row]
        else:
            reader = csv.DictReader(f)
            header = reader.fieldnames or []
            if column not in header:
                raise KeyError(
                    f"column {column!r} not found in {csv_path} (header: {header})")
            # DictReader fills fields missing from short rows with None.
            raw_ids = [row[column] or "" for row in reader]

    ids, seen = [], set()
    for raw in raw_ids:
        pid = raw.strip()
        # A repeated ID would queue the same queries twice, with two worker
        # processes writing (or deleting) the same output CSV concurrently.
        if pid and pid not in seen:
            seen.add(pid)
            ids.append(pid)
    return ids


def write_result(result):
    """Append `result` to RESULTS_FILE as a single JSON line (JSON Lines).

    The file is opened in append mode and closed again on every call, so each
    record reaches the OS before the function returns.
    """
    with open(RESULTS_FILE, mode="a", encoding="utf-8") as sink:
        line = json.dumps(result)
        sink.write(f"{line}\n")


def write_status(status, path=None):
    """Replace the status file with `status` serialised as JSON.

    The JSON is first written to a sibling ``.tmp`` file and then moved into
    place with os.replace(). A concurrent reader (presumably att_monitor.py)
    therefore sees either the previous document or the new one, never a
    truncated file.

    Args:
        status: JSON-serialisable dict.
        path: target file; defaults to STATUS_FILE.
    """
    import time

    target = STATUS_FILE if path is None else path
    tmp = f"{target}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(status, f)

    # On Windows, os.replace() raises PermissionError while another process
    # has the target open. Retry briefly, then fall back to a direct write so
    # status updates are never lost.
    for _ in range(5):
        try:
            os.replace(tmp, target)
            return
        except PermissionError:
            time.sleep(0.05)

    with open(target, "w", encoding="utf-8") as f:
        json.dump(status, f)
    try:
        os.remove(tmp)
    except OSError:
        pass


def find_tif(path_id):
    """Return the first ``*.tif`` in TIF_DIR whose stem equals `path_id`.

    The comparison is case-insensitive. Returns the file path as a string,
    or None when no file matches.
    """
    matches = (
        str(candidate)
        for candidate in Path(TIF_DIR).glob("*.tif")
        if candidate.stem.lower() == path_id.lower()
    )
    return next(matches, None)


def dat_path(path_id, receiver):
    """Return <DATA_ROOT>/<path_id>/<path_id>_DAT_<receiver>_1.csv as a string.

    `receiver` is presumably "a" or "b" (the worker derives it that way).
    """
    file_name = f"{path_id}_DAT_{receiver}_1.csv"
    return str(Path(DATA_ROOT).joinpath(path_id, file_name))


def ven_path(path_id):
    """Return <DATA_ROOT>/<path_id>/<path_id>_VEN_1.csv as a string."""
    file_name = f"{path_id}_VEN_1.csv"
    return str(Path(DATA_ROOT).joinpath(path_id, file_name))


# ─────────────────────────────────────────────
# WORKER FUNCTION — fully self-contained process
# ─────────────────────────────────────────────

def worker(path_id, query_name, out_file, step_m, tif_dir, data_root):
    """Run one query for one path ID and write its output CSV.

    Runs inside a ProcessPoolExecutor child, so it imports everything it
    needs locally and takes all configuration as arguments.

    Args:
        path_id: path identifier. Selects <data_root>/<path_id>/ input CSVs and
            the <path_id>.tif DEM (case-insensitive stem match) in tif_dir.
        query_name: "A_TER"/"B_TER" (terrain profile sampled from the DEM) or
            any other name, treated as an INT query (DAT rows intersected
            with VEN geometries). The first letter picks the a/b receiver
            columns.
        out_file: destination CSV. Its parent directory is created, and the
            file is deleted again if the query fails.
        step_m: spacing in metres between terrain samples (TER queries).
        tif_dir: directory searched for the DEM GeoTIFF.
        data_root: root directory of the per-path DAT/VEN CSVs.

    Returns:
        dict with keys path_id, query_name, elapsed (seconds), rows (rows
        written) and error (None on success, otherwise a message). Exceptions
        raised while running the query are caught and reported via "error".
        An A_INT/B_INT query that writes no rows is also reported via
        "error", as a WARNING.
    """
    import csv, os, warnings
    warnings.filterwarnings("ignore")
    import subprocess as sp
    from pathlib import Path as P

    # Point PROJ/GDAL at "<root>/Library/share/<sub>", two levels above the
    # first `python` on PATH. This is the same best-effort fix-up the module
    # does at import time, repeated here so the worker is self-contained.
    try:
        r2 = sp.run(["where", "python"], capture_output=True, text=True)
        py2 = r2.stdout.strip().split("\n")[0] if r2.stdout.strip() else ""
    except OSError:
        # `where` exists only on Windows; skip the fix-up elsewhere.
        py2 = ""
    cr = str(P(py2).parent.parent) if py2 else ""
    for k, sub in [("PROJ_DATA","proj"),("PROJ_LIB","proj"),("GDAL_DATA","gdal")]:
        p = os.path.join(cr, "Library", "share", sub)
        if cr and os.path.exists(p):
            os.environ[k] = p

    # Set HIGH priority on this worker process using the Windows API
    # (best effort; silently skipped where ctypes.windll is unavailable).
    try:
        import ctypes
        handle = ctypes.windll.kernel32.OpenProcess(0x1F0FFF, False,
                    ctypes.windll.kernel32.GetCurrentProcessId())
        ctypes.windll.kernel32.SetPriorityClass(handle, 0x00000080)  # HIGH_PRIORITY_CLASS
        ctypes.windll.kernel32.CloseHandle(handle)
    except Exception:
        pass

    import numpy as np
    import pandas as pd
    import geopandas as gpd
    import rasterio
    from rasterio.errors import NotGeoreferencedWarning
    warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
    import pyproj
    from time import perf_counter

    geod = pyproj.Geod(ellps="WGS84")

    # Factor applied to the metre inputs for every *_ft output column.
    # NOTE(review): 3.28 approximates 3.28084 ft/m -- confirm the rounding is
    # intended before changing it (doing so alters every *_ft value).
    M_TO_FT = 3.28

    def hkm(lon1, lat1, lon2, lat2):
        """Geodesic (WGS84) distance in km between two lon/lat points."""
        _, _, d = geod.inv(lon1, lat1, lon2, lat2)
        return d / 1000.0

    def segmentize(lon1, lat1, lon2, lat2, sm=80):
        """Return interior points spaced ~sm metres along the geodesic.

        Each point is a (lon, lat, km_to_(lon2, lat2)) tuple.
        """
        _, _, total = geod.inv(lon1, lat1, lon2, lat2)
        n = int(total / sm) - 1
        if n < 1:
            # Building is too close to the receiver for interior vertices at
            # this step size: sample the midpoint so at least one terrain
            # point is always recorded.
            mid = geod.npts(lon1, lat1, lon2, lat2, 1)
            vlon, vlat = mid[0]
            return [(vlon, vlat, hkm(vlon, vlat, lon2, lat2))]
        pts = geod.npts(lon1, lat1, lon2, lat2, n)
        return [(vlon, vlat, hkm(vlon, vlat, lon2, lat2)) for vlon, vlat in pts]

    def fv(v, d=0.0):
        """Return float(v), or d when v is None, "" or not numeric."""
        # NOTE(review): pd.read_csv(dtype=str) yields NaN (not "") for empty
        # cells. float(NaN) succeeds, so NaN passes through rather than
        # becoming d -- confirm which is wanted.
        try:
            return float(v) if v is not None and v != "" else d
        except (TypeError, ValueError):
            return d

    def find_tif_local(pid):
        """First *.tif in tif_dir whose stem matches pid (case-insensitive)."""
        for f in P(tif_dir).glob("*.tif"):
            if f.stem.lower() == pid.lower():
                return str(f)
        return None

    def dat(pid, r):
        return str(P(data_root) / pid / f"{pid}_DAT_{r}_1.csv")

    def ven(pid):
        return str(P(data_root) / pid / f"{pid}_VEN_1.csv")

    t0 = perf_counter()
    rows_written = 0
    error = None

    try:
        P(out_file).parent.mkdir(parents=True, exist_ok=True)
        r = query_name[0].lower()

        if query_name in ("A_TER", "B_TER"):
            rcvr_long = f"{r}_rcvr_long"
            rcvr_lat  = f"{r}_rcvr_lat"
            rcvr_elev = f"{r}_rcvr_meanelev"
            rcvr_raat = f"{r}_rcvr_raat"
            rcvr_cs   = f"{r}_rcvr_call_sign"
            dist_col  = f"bldg_{r}_rcvr_distance_km"
            other_cs  = "b_rcvr_call_sign" if r == "a" else "a_rcvr_call_sign"

            df  = pd.read_csv(dat(path_id, r), dtype=str)
            tif = find_tif_local(path_id)
            if tif is None:
                raise FileNotFoundError(f"No TIF for {path_id}")

            with rasterio.open(tif) as src, \
                 open(out_file, "w", newline="", encoding="utf-8") as f2:
                # sample() yields numeric arrays, so ea[0] below is never None.
                # Void cells come back as the dataset's nodata value (or NaN
                # for float rasters) and must be compared against it explicitly.
                nodata = src.nodata
                w = csv.writer(f2)
                w.writerow(["unique_path_id","bldg_number",
                    "bldg_long","bldg_lat","bldg_mean_elev_ft",
                    "bldg_maxheight_ft","bldg_address",
                    f"{r}_rcvr_call_sign",f"{r}_rcvr_long",
                    f"{r}_rcvr_lat",f"{r}_rcvr_grd_elev_me",
                    f"{r}_rcvr_raat_me","incr_long","incr_lat",
                    "incr_elev_me","incr_dist_km","path_distance_km"])
                for _, row in df.iterrows():
                    bid   = row["bldg_id"]
                    blon  = fv(row["bldg_long"])
                    blat  = fv(row["bldg_lat"])
                    belev = fv(row["bldg_meanelevation"]) * M_TO_FT
                    bht   = fv(row["bldg_maxheight"]) * M_TO_FT
                    badr  = row.get("bldg_address", "")
                    rlon  = fv(row[rcvr_long])
                    rlat  = fv(row[rcvr_lat])
                    relev = fv(row[rcvr_elev])
                    rraat = fv(row[rcvr_raat])
                    rcs   = row.get(rcvr_cs, "")
                    ocs   = row.get(other_cs, "")
                    pkm   = fv(row[dist_col])
                    uid   = f"{bid}.{rcs}.{ocs}"
                    verts = segmentize(blon, blat, rlon, rlat, step_m)
                    if not verts:
                        continue
                    # NOTE(review): points are passed as lon/lat, which assumes
                    # the DEM uses a geographic CRS -- confirm.
                    try:
                        elevs = list(src.sample([(v[0], v[1]) for v in verts]))
                    except Exception:
                        # Best effort: skip a building whose samples cannot be read.
                        continue
                    vrows = []
                    for (vlon, vlat, dk), ea in zip(verts, elevs):
                        e = float(ea[0])
                        if np.isnan(e) or (nodata is not None and e == nodata):
                            continue  # DEM void: no elevation for this vertex
                        vrows.append((uid, bid,
                            round(blon, 6), round(blat, 6),
                            round(belev, 4), round(bht, 4), badr,
                            rcs, round(rlon, 6), round(rlat, 6),
                            round(relev, 4), round(rraat, 4),
                            round(vlon, 6), round(vlat, 6),
                            round(e, 4), round(dk, 6), round(pkm, 6)))
                    # Farthest-from-receiver first (index 15 = incr_dist_km).
                    vrows.sort(key=lambda x: x[15], reverse=True)
                    w.writerows(vrows)
                    rows_written += len(vrows)

        else:  # INT: DAT lines/points intersected with VEN geometries
            rcvr_long = f"{r}_rcvr_long"
            rcvr_lat  = f"{r}_rcvr_lat"
            rcvr_elev = f"{r}_rcvr_meanelev"
            rcvr_raat = f"{r}_rcvr_raat"
            rcvr_cs   = f"{r}_rcvr_call_sign"
            rcvr_desc = f"{r}_rcvr_desc"
            line_col  = f"{r}_line_4326"
            bldg_pt   = "bldg_point_4326"
            other_cs  = "b_rcvr_call_sign" if r == "a" else "a_rcvr_call_sign"

            df = pd.read_csv(dat(path_id, r), dtype=str)
            vf = ven(path_id)
            if not P(vf).exists():
                raise FileNotFoundError(f"VEN file not found: {vf}")
            vdf = pd.read_csv(vf, dtype=str)
            vdf.columns = [c.lower() for c in vdf.columns]
            # Use the first candidate geometry column that parses as WKB.
            vgeom = None
            for gcol in ["geometry", "geom", "wkb_geometry", "shape"]:
                if gcol in vdf.columns:
                    try:
                        vgeom = gpd.GeoSeries.from_wkb(vdf[gcol])
                        break
                    except Exception:
                        pass  # not WKB; try the next candidate column
            if vgeom is None:
                raise ValueError(f"No geometry in VEN CSV: {vf}")
            vgdf = gpd.GeoDataFrame(vdf, geometry=vgeom, crs="EPSG:4326")

            def pg(s):
                """Parse a column as WKB, falling back to WKT, then all-None."""
                try:
                    return gpd.GeoSeries.from_wkb(s)
                except Exception:
                    try:
                        return gpd.GeoSeries.from_wkt(s)
                    except Exception:
                        return gpd.GeoSeries([None] * len(s))

            dl = gpd.GeoDataFrame(df, geometry=pg(df[line_col]), crs="EPSG:4326")
            db = gpd.GeoDataFrame(df, geometry=pg(df[bldg_pt]),  crs="EPSG:4326")
            # vgdf was built with crs="EPSG:4326" above, so this is effectively
            # a no-op; kept as a safeguard.
            if vgdf.crs is None:
                vgdf = vgdf.set_crs("EPSG:4326")
            else:
                vgdf = vgdf.to_crs("EPSG:4326")

            vs   = vgdf[["geometry","lbcs_act_1","latitude","longitude"]]
            cols = list(df.columns) + ["lbcs_act_1","latitude","longitude"]

            # Process in chunks to avoid large memory allocations.
            CHUNK_SIZE = 5000
            results = []
            for gdf in (dl, db):
                for start in range(0, len(gdf), CHUNK_SIZE):
                    chunk = gdf.iloc[start:start + CHUNK_SIZE]
                    try:
                        j = gpd.sjoin(chunk, vs, how="inner",
                                      predicate="intersects")
                        if len(j) > 0:
                            results.append(j[cols])
                    except Exception as exc:
                        # Do not swallow: a skipped chunk silently loses
                        # intersections and can masquerade as "zero rows".
                        raise RuntimeError(
                            f"spatial join failed for rows "
                            f"{start}-{start + len(chunk) - 1}: {exc}") from exc

            if not results:
                combined = pd.DataFrame(columns=cols)
            else:
                combined = pd.concat(results).drop_duplicates(
                    subset=["bldg_id","longitude","latitude"])

            from pyproj import Transformer
            tr = Transformer.from_crs("EPSG:4326", "EPSG:2163", always_xy=True)

            def dkm(row):
                """Planar EPSG:2163 distance in km, VEN point -> receiver (None on failure)."""
                try:
                    vx, vy = tr.transform(fv(row["longitude"]), fv(row["latitude"]))
                    rx, ry = tr.transform(fv(row[rcvr_long]), fv(row[rcvr_lat]))
                    return round(((vx - rx) ** 2 + (vy - ry) ** 2) ** 0.5 / 1000, 6)
                except Exception:
                    return None

            combined["rcvr_dist_km"] = combined.apply(dkm, axis=1)
            combined["pid_col"] = (combined["bldg_id"].astype(str) + "." +
                                   combined[rcvr_cs].astype(str) + "." +
                                   combined[other_cs].astype(str))
            combined = combined.sort_values(["pid_col","rcvr_dist_km"],
                                            ascending=[True, False])

            with open(out_file, "w", newline="", encoding="utf-8") as f2:
                w = csv.writer(f2)
                w.writerow([f"{r}_rcvr_path_id","bldg_number",
                    "bldg_long","bldg_lat","bldg_maxheight_ft",
                    "bldg_meanelev_ft","bldg_desc","bldg_type",
                    f"{r}_rcvr_call_sign",f"{r}_rcvr_desc",
                    f"{r}_rcvr_long",f"{r}_rcvr_lat",
                    f"{r}_rcvr_grd_elev_me",f"{r}_rcvr_ant_raat_me",
                    f"{r}_rcvr_dist_km"])
                for _, row in combined.iterrows():
                    w.writerow([row["pid_col"], row["bldg_id"],
                        round(fv(row["bldg_long"]), 6),
                        round(fv(row["bldg_lat"]), 6),
                        round(fv(row["bldg_maxheight"]) * M_TO_FT, 4),
                        round(fv(row["bldg_meanelevation"]) * M_TO_FT, 4),
                        row.get("bldg_address", ""),
                        row.get("lbcs_act_1", ""),
                        row.get(rcvr_cs, ""),
                        row.get(rcvr_desc, ""),
                        round(fv(row[rcvr_long]), 6),
                        round(fv(row[rcvr_lat]), 6),
                        round(fv(row[rcvr_elev]), 4),
                        round(fv(row[rcvr_raat]), 4),
                        row["rcvr_dist_km"]])
                    rows_written += 1

    except Exception as exc:
        error = str(exc)
        # Remove any partially written output so it is not mistaken for a result.
        try:
            P(out_file).unlink(missing_ok=True)
        except OSError:
            pass

    # Warn if INT query produced no rows
    if error is None and rows_written == 0 and query_name in ("A_INT", "B_INT"):
        error = "WARNING: zero rows — no building intersections found along this path"

    return {
        "path_id":    path_id,
        "query_name": query_name,
        "elapsed":    perf_counter() - t0,
        "rows":       rows_written,
        "error":      error,
    }


# ─────────────────────────────────────────────
# MAIN
# ─────────────────────────────────────────────

def main():
    """Run every query for every path ID in parallel and report progress.

    1. Load the path IDs and queue the four queries (A_TER, A_INT, B_TER,
       B_INT) for each one.
    2. Truncate RESULTS_FILE. Append a "started" record for each submitted
       query and a result record for each finished one, and rewrite
       STATUS_FILE after every completion.
    3. Write timing_summary_<timestamp>.csv (one row per path ID; FAILED if
       any of its queries reported an error).
    4. Log a run summary and prepend it to this run's log file.
    """
    started_dt = datetime.now()
    # One fixed start timestamp for every status write, so readers of the
    # status file see when the run began, not when it was last updated.
    started_at = started_dt.isoformat()
    query_types = ["A_TER", "A_INT", "B_TER", "B_INT"]

    log.info("=" * 60)
    log.info("ATT Worker  (Postgres-free)")
    log.info(f"Workers: {MAX_WORKERS}")
    log.info("=" * 60)

    path_ids = load_path_ids(PATH_ID_CSV, PATH_ID_COLUMN)
    log.info(f"Loaded {len(path_ids)} path IDs")

    query_list = [(pid, qn) for pid in path_ids for qn in query_types]
    total = len(query_list)

    # Clear results file
    open(RESULTS_FILE, "w").close()

    # Write initial status
    write_status({
        "total":      total,
        "path_ids":   path_ids,
        "done":       0,
        "started_at": started_at,
        "running":    True,
    })

    run_start = perf_counter()
    done = 0
    timing = {}  # (path_id, query_name) -> elapsed seconds, None if it errored
    row_totals = {qn: 0 for qn in query_types}

    work = [
        (pid, qn,
         str(Path(OUTPUT_ROOT) / pid / f"{pid}_{qn}.csv"),
         STEP_M, TIF_DIR, DATA_ROOT)
        for pid, qn in query_list
    ]

    with concurrent.futures.ProcessPoolExecutor(
            max_workers=MAX_WORKERS) as ex:
        futs = {}
        for a in work:
            fut = ex.submit(worker, *a)
            futs[fut] = a[:2]
            # Written at submission time (the query is queued, not
            # necessarily running yet).
            write_result({"started": True,
                          "path_id": a[0], "query_name": a[1]})

        for fut in concurrent.futures.as_completed(futs):
            try:
                res = fut.result()
            except Exception as exc:
                # The worker call itself raised (e.g. a crashed pool process).
                pid, qn = futs[fut]
                res = {"path_id": pid, "query_name": qn,
                       "elapsed": 0, "rows": 0, "error": str(exc)}

            done += 1
            timing[(res["path_id"], res["query_name"])] = \
                res["elapsed"] if not res["error"] else None

            wall = perf_counter() - run_start
            res["wall"] = wall
            res["done"] = done
            res["total"] = total
            write_result(res)
            write_status({"total": total, "path_ids": path_ids,
                          "done": done, "running": True,
                          "started_at": started_at})

            if res["error"]:
                log.error(f"  X  {res['path_id']}/{res['query_name']}  "
                          f"{_fmt_precise(res['elapsed'])}  {res['error']}")
            else:
                log.info(f"  OK {res['path_id']}/{res['query_name']}  "
                         f"{_fmt_precise(res['elapsed'])}  {res['rows']} rows")
                row_totals[res["query_name"]] += res["rows"]

    # Write timing summary
    rows = []
    for pid in path_ids:
        pt = {qn: _fmt_precise(timing.get((pid, qn))) for qn in query_types}
        ok = all(timing.get((pid, qn)) is not None for qn in query_types)
        rows.append({"path_id": pid, **pt,
                     "status": "OK" if ok else "FAILED"})
    tc = Path(OUTPUT_ROOT) / \
         f"timing_summary_{datetime.now():%Y%m%d_%H%M%S}.csv"
    # Explicit field names: rows[0] does not exist when no IDs were loaded.
    fieldnames = ["path_id", *query_types, "status"]
    with open(tc, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)

    write_status({"total": total, "path_ids": path_ids,
                  "done": done, "running": False,
                  "started_at": started_at,
                  "timing_summary": str(tc)})
    write_result({"all_done": True, "timing_summary": str(tc)})

    # ── Build summary block ──
    finished_dt = datetime.now()
    elapsed_secs = perf_counter() - run_start
    pass_count = sum(1 for r in rows if r["status"] == "OK")
    fail_count = len(rows) - pass_count

    # Per-query-type stats over successfully-timed queries
    type_vals = {}
    type_totals = {}
    for qn in query_types:
        vals = sorted(t for (_, q), t in timing.items()
                      if q == qn and t is not None)
        type_vals[qn] = vals
        type_totals[qn] = sum(vals)

    grand_total_secs = sum(type_totals.values())

    def _median(vals):
        """Median of an already-sorted list (0 for an empty list)."""
        n = len(vals)
        if n == 0:
            return 0
        mid = n // 2
        if n % 2 == 1:
            return vals[mid]
        return (vals[mid-1] + vals[mid]) / 2

    summary_lines = []
    summary_lines.append("=" * 60)
    summary_lines.append("RUN SUMMARY")
    summary_lines.append("=" * 60)
    summary_lines.append(f"Start time:       {started_dt:%Y-%m-%d %H:%M:%S}")
    summary_lines.append(f"Completion time:  {finished_dt:%Y-%m-%d %H:%M:%S}")
    summary_lines.append(f"Total elapsed:    {_fmt(elapsed_secs)} (hhh:mm:ss)")
    summary_lines.append(f"Workers:          {MAX_WORKERS}")
    summary_lines.append(f"Total path IDs:   {len(rows)}")
    summary_lines.append(f"Pass:             {pass_count} of {len(rows)}")
    summary_lines.append(f"Fail:             {fail_count} of {len(rows)}")
    summary_lines.append("-" * 60)
    summary_lines.append("Per-query-type worker-time totals (hhh:mm:ss):")
    for qn in query_types:
        secs = type_totals[qn]
        pct = (100 * secs / grand_total_secs) if grand_total_secs else 0
        summary_lines.append(f"  {qn:<8}  {_fmt(secs):>12}  ({pct:5.1f}%)")
    summary_lines.append(f"  {'TOTAL':<8}  {_fmt(grand_total_secs):>12}")
    summary_lines.append("-" * 60)
    summary_lines.append("Per-query-type statistics (min / median / mean / max):")
    for qn in query_types:
        vals = type_vals[qn]
        if vals:
            vmin = vals[0]
            vmax = vals[-1]
            vmed = _median(vals)
            vmean = sum(vals) / len(vals)
        else:
            vmin = vmax = vmed = vmean = 0
        summary_lines.append(
            f"  {qn:<8}  min {_fmt_precise(vmin):>13}  "
            f"median {_fmt_precise(vmed):>13}  "
            f"mean {_fmt_precise(vmean):>13}  "
            f"max {_fmt_precise(vmax):>13}")
    summary_lines.append("-" * 60)
    summary_lines.append("Total rows written by query type:")
    for qn in query_types:
        summary_lines.append(f"  {qn:<8}  {row_totals[qn]:>15,}")
    summary_lines.append(f"  {'TOTAL':<8}  {sum(row_totals.values()):>15,}")
    summary_lines.append("=" * 60)
    summary_lines.append("")

    summary_text = "\n".join(summary_lines)

    log.info("\n" + summary_text)

    # Prepend summary block to the top of the log file
    try:
        with open(log_path, "r", encoding="utf-8") as f:
            existing_log = f.read()
        with open(log_path, "w", encoding="utf-8") as f:
            f.write(summary_text + "\n" + existing_log)
    except Exception as exc:
        log.warning(f"Could not prepend summary to log file: {exc}")

    log.info(f"Done. Timing: {tc}")


if __name__ == "__main__":
    # Script entry point. freeze_support() only has an effect in a frozen
    # Windows executable, where it lets spawned pool workers start correctly.
    multiprocessing.freeze_support()
    main()
