"""
batch_kizer_krunch_v2.py
------------------------
Parallel batch runner for kizer_krunch_code.py.

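Each worker process loads the GPD file at startup via an initializer and reuses
it for the paths it is assigned. Workers are recycled after a fixed number of
paths (currently 5) to release memory; each replacement worker reloads the GPD
file, so RAM stays bounded at the cost of periodic reloads.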

Usage:
    python batch_kizer_krunch_v2.py                 # defaults to Both (I+T), 8 workers
    python batch_kizer_krunch_v2.py --mode I        # Interference only
    python batch_kizer_krunch_v2.py --mode T        # Terrain only
    python batch_kizer_krunch_v2.py --mode B        # Both (default)
    python batch_kizer_krunch_v2.py --workers 10    # use 10 parallel workers
    python batch_kizer_krunch_v2.py --start 1000    # start at path_num 1000
    python batch_kizer_krunch_v2.py --end 1100      # stop after path_num 1100
    python batch_kizer_krunch_v2.py --paths 3 7 12  # run specific path numbers only

Edit the CONFIGURATION section below to match your environment.
"""

import argparse
import time
from datetime import datetime
import traceback
import pandas as pd
from concurrent.futures import ProcessPoolExecutor, as_completed

# ── Adjust these settings to match your environment ──────────────────────────
# Raw strings keep the Windows backslashes literal.
ROSETTA_PATH    = r'D:\ATT\ATT_Rosetta_populated.csv'  # CSV with a "path_num" column; read in main()
GPD_PATH        = r'D:\ATT\ATT_GPD_ALL.xlsx'           # GPD workbook; each worker process loads it in worker_init()
DEFAULT_WORKERS = 8                                     # default value of --workers
# ──────────────────────────────────────────────────────────────────────────────

from kizer_krunch_code import (
    FNC_COLUMNS,
    GENERALPATHDATA_COLUMNS,
    run_interference_calculation,
    run_terrain_calculation,
)

# ── Worker-level global (one copy per worker process, not per path) ───────────
# Set by worker_init() inside each worker process and read by process_one_path().
# Keys are (tx_call_sign, rx_call_sign) string pairs; values are dicts of GPD
# attributes. Stays None in the parent process.
_worker_gpd_lookup: dict | None = None

def worker_init(gpd_path: str):
    """
    Load the GPD workbook into this worker's ``_worker_gpd_lookup`` global.

    Used as the ``ProcessPoolExecutor`` initializer, so it runs once in every
    worker process (including replacement workers started after recycling).
    All tasks executed by that worker then share the same lookup.

    The lookup maps ``(tx_call_sign, rx_call_sign)``, both converted to
    stripped strings, to a dict with keys ``channels``, ``bandwidth``,
    ``frequency``, ``gain``, ``tilt``, ``bearing`` and ``ant_model``
    (``ant_model`` is stripped; the others are stored as read). When a pair
    appears on several rows, the last row wins.

    Args:
        gpd_path: Path to the GPD Excel workbook.
    """
    global _worker_gpd_lookup
    gpd_df = pd.read_excel(gpd_path, dtype=str)

    def column(key: str) -> pd.Series:
        # Resolve a logical field name to its GPD column via the shared map.
        return gpd_df[GENERALPATHDATA_COLUMNS[key]]

    # Zip the columns rather than using DataFrame.iterrows(): iterrows builds
    # a Series object per row, which dominates load time on large workbooks,
    # and this load is repeated every time a worker is recycled.
    rows = zip(
        column("tx_call_sign"),
        column("rx_call_sign"),
        column("channels"),
        column("bandwidth"),
        column("fr_frequency_assigned_MHz"),
        column("rxan_gain_dBi"),
        column("RxTilt_deg"),
        column("RxBearing_deg"),
        column("AltRcvAntenna"),
    )

    lookup = {}
    for tx, rx, channels, bandwidth, frequency, gain, tilt, bearing, ant_model in rows:
        lookup[(str(tx).strip(), str(rx).strip())] = {
            "channels":  channels,
            "bandwidth": bandwidth,
            "frequency": frequency,
            "gain":      gain,
            "tilt":      tilt,
            "bearing":   bearing,
            "ant_model": str(ant_model).strip(),
        }

    # Publish only the fully built lookup.
    _worker_gpd_lookup = lookup


# ── Helpers ───────────────────────────────────────────────────────────────────

def build_paths_for_row(filenaming: pd.DataFrame, rownumber: int) -> dict:
    """
    Collect identifiers and input/output file paths for one Rosetta row.

    ``rownumber`` is an index label, looked up with ``.loc``. Output files live
    in the directory formed by joining drive, folder (surrounding backslashes
    removed) and subfolder with Windows backslashes. They are named
    ``<subfolder>_<A|B>_<function>.csv``, where ``<function>`` comes from the
    row's interference or terrain function column.

    Returns:
        dict of strings: receivers, drive/folder/subfolder, and the four input
        and four output file paths.
    """
    row = filenaming.loc[rownumber]

    def cell(key: str) -> str:
        # Value of a logical Rosetta column for this row, as a string.
        return str(row[FNC_COLUMNS[key]])

    drive = cell("drive").strip()
    folder = cell("folder").strip().strip('\\')
    subfolder = cell("subfolder").strip()
    int_function = cell("AIntFunction")
    terr_function = cell("ATerrFunction")
    out_dir = "\\".join((drive, folder, subfolder))

    def output_file(side: str, function: str) -> str:
        # Output CSV path for the given receiver side and calculation function.
        return f"{out_dir}\\{subfolder}_{side}_{function}.csv"

    return {
        "AReceiver":       cell("AReceiver"),
        "BReceiver":       cell("BReceiver"),
        "drive":           drive,
        "folder":          folder,
        "subfolder":       subfolder,
        "AIntInputfile":   cell("fileAInt"),
        "BIntInputfile":   cell("fileBInt"),
        "ATerrInputfile":  cell("fileATerr"),
        "BTerrInputfile":  cell("fileBTerr"),
        "AIntOutputfile":  output_file("A", int_function),
        "BIntOutputfile":  output_file("B", int_function),
        "ATerrOutputfile": output_file("A", terr_function),
        "BTerrOutputfile": output_file("B", terr_function),
    }


def process_one_path(args: tuple) -> tuple:
    """
    Execute the enabled calculations for a single path (runs in a worker).

    ``args`` is ``(path_num, paths, run_interference, run_terrain)``, where
    ``paths`` is the dict produced by ``build_paths_for_row``. Interference
    runs receive this worker's ``_worker_gpd_lookup`` (set by ``worker_init``).
    A failure in one A/B calculation is recorded together with its traceback
    and does not stop the remaining calculations.

    Returns:
        ``(path_num, elapsed_seconds, errors)``; ``errors`` is a list of
        strings and is empty when every enabled calculation succeeded.
    """
    path_num, paths, run_interference, run_terrain = args
    started_at = time.time()
    failures = []

    if run_interference:
        interference_jobs = [
            ("A Interference", paths["AIntInputfile"], paths["AIntOutputfile"]),
            ("B Interference", paths["BIntInputfile"], paths["BIntOutputfile"]),
        ]
        for job_label, source, target in interference_jobs:
            try:
                run_interference_calculation(
                    source, target, paths["drive"], paths["folder"], _worker_gpd_lookup
                )
            except Exception as exc:
                failures.append(
                    f"Path {path_num} | {job_label}: {exc}\n{traceback.format_exc()}"
                )

    if run_terrain:
        terrain_jobs = [
            ("A Terrain", paths["ATerrInputfile"], paths["ATerrOutputfile"]),
            ("B Terrain", paths["BTerrInputfile"], paths["BTerrOutputfile"]),
        ]
        for job_label, source, target in terrain_jobs:
            try:
                run_terrain_calculation(source, target)
            except Exception as exc:
                failures.append(
                    f"Path {path_num} | {job_label}: {exc}\n{traceback.format_exc()}"
                )

    return path_num, time.time() - started_at, failures


# ── Main ──────────────────────────────────────────────────────────────────────

def _parse_args() -> argparse.Namespace:
    """Define the command-line interface, validate it, and return parsed arguments."""
    parser = argparse.ArgumentParser(description="Parallel batch runner for kizer_krunch_code.py")
    parser.add_argument(
        "--mode", choices=["I", "T", "B", "i", "t", "b"],
        default="B",
        help="I=Interference only, T=Terrain only, B=Both (default: B)",
    )
    parser.add_argument(
        "--workers", type=int, default=DEFAULT_WORKERS,
        help=f"Number of parallel worker processes (default: {DEFAULT_WORKERS})",
    )
    parser.add_argument(
        "--start", type=int, default=None,
        help="First path_num to process (inclusive).",
    )
    parser.add_argument(
        "--end", type=int, default=None,
        help="Last path_num to process (inclusive).",
    )
    parser.add_argument(
        "--paths", type=int, nargs="+", default=None,
        help="Explicit list of path numbers to run (overrides --start/--end).",
    )
    parser.add_argument(
        "--max-tasks-per-child", type=int, default=5,
        help="Paths each worker runs before it is replaced (the replacement "
             "reloads GPD); bounds RAM growth (default: 5)",
    )
    args = parser.parse_args()
    # Fail fast with a usage message instead of erroring after Rosetta loads.
    if args.workers < 1:
        parser.error("--workers must be at least 1")
    if args.max_tasks_per_child < 1:
        parser.error("--max-tasks-per-child must be at least 1")
    return args


def _select_target_paths(args: argparse.Namespace, all_path_nums: list) -> list:
    """
    Return the path numbers to run, in --paths (or Rosetta) order, without duplicates.

    Duplicates are dropped because two tasks for the same path_num would
    write the same output files concurrently. With --start/--end, path_num
    values that are not integers (e.g. "nan" from blank rows) are excluded
    with a warning instead of aborting the run.
    """
    if args.paths:
        known = set(all_path_nums)
        requested = list(dict.fromkeys(str(p) for p in args.paths))
        missing = [p for p in requested if p not in known]
        if missing:
            print(f"WARNING: These path numbers not found in Rosetta and will be skipped: {missing}")
        return [p for p in requested if p in known]

    selected = list(dict.fromkeys(all_path_nums))
    if args.start is None and args.end is None:
        return selected

    in_range, non_numeric = [], []
    for p in selected:
        try:
            number = int(p)
        except ValueError:
            non_numeric.append(p)
            continue
        if args.start is not None and number < args.start:
            continue
        if args.end is not None and number > args.end:
            continue
        in_range.append(p)
    if non_numeric:
        print(f"WARNING: Non-numeric path_num values ignored by --start/--end: {non_numeric}")
    return in_range


def _build_work_items(
    filenaming: pd.DataFrame,
    path_nums: pd.Series,
    target_paths: list,
    run_interference: bool,
    run_terrain: bool,
) -> tuple:
    """
    Turn target path numbers into ``process_one_path`` argument tuples.

    Each path_num resolves to the first Rosetta row that carries it.
    Returns ``(work_items, skipped)``; paths whose file names cannot be built
    are reported and added to ``skipped``.
    """
    # One pass over the normalized column, instead of re-normalizing the
    # whole column for every target path (which was O(paths x rows)).
    first_row = {}
    for label, num in path_nums.items():
        first_row.setdefault(num, label)

    work_items, skipped = [], []
    for path_num in target_paths:
        if path_num not in first_row:
            print(f"  SKIP: path_num '{path_num}' not found in Rosetta.")
            skipped.append(path_num)
            continue
        try:
            paths = build_paths_for_row(filenaming, first_row[path_num])
        except Exception as e:
            print(f"  SKIP path {path_num}: could not build file paths: {e}")
            skipped.append(path_num)
            continue
        work_items.append((path_num, paths, run_interference, run_terrain))
    return work_items, skipped


def _run_batch(work_items: list, workers: int, max_tasks_per_child: int) -> list:
    """
    Run ``work_items`` in a process pool and report each result as it completes.

    Returns a list of ``(path_num, message)`` tuples, one per error or crash.
    """
    all_errors = []
    completed = 0
    total = len(work_items)

    with ProcessPoolExecutor(
        max_workers=workers,
        initializer=worker_init,        # called once per worker process
        initargs=(GPD_PATH,),           # passes the file path, not the data
        max_tasks_per_child=max_tasks_per_child,  # recycle workers to flush RAM
    ) as executor:
        futures = {executor.submit(process_one_path, item): item[0] for item in work_items}
        for future in as_completed(futures):
            path_num = futures[future]
            completed += 1
            try:
                pn, elapsed, errors = future.result()
            except Exception as e:
                msg = f"Worker crashed: {e}"
                print(f"  [{completed:>4}/{total}] Path {path_num:<8} CRASHED: {msg}")
                all_errors.append((path_num, msg))
                continue
            status = "OK" if not errors else f"ERRORS({len(errors)})"
            print(f"  [{completed:>4}/{total}] Path {pn:<8} {status}  ({elapsed:.1f}s)")
            for e in errors:
                print(f"    ERROR: {e}")
                all_errors.append((pn, e))
    return all_errors


def _print_summary(
    work_items: list,
    skipped: list,
    all_errors: list,
    batch_start: float,
    batch_start_dt: datetime,
) -> None:
    """Print the end-of-batch report. ``batch_start`` is a ``time.monotonic()`` reading."""
    batch_end_dt  = datetime.now()
    total_elapsed = time.monotonic() - batch_start
    hours, rem    = divmod(int(total_elapsed), 3600)
    minutes, secs = divmod(rem, 60)

    print(f"\n{'='*60}")
    print("BATCH COMPLETE")
    print(f"  Total paths attempted : {len(work_items)}")
    print(f"  Skipped               : {len(skipped)}")
    print(f"  Paths with errors     : {len(set(pn for pn, _ in all_errors))}")
    print(f"  Start time            : {batch_start_dt.strftime('%H:%M:%S')}")
    print(f"  End time              : {batch_end_dt.strftime('%H:%M:%S')}")
    print(f"  Elapsed time          : {hours:03d}:{minutes:02d}:{secs:02d}")

    if all_errors:
        print("\nERRORS SUMMARY:")
        for path_num, msg in all_errors:
            print(f"  Path {path_num}: {msg}")

    if skipped:
        print(f"\nSKIPPED PATHS: {skipped}")

    print(f"{'='*60}")


def main():
    """
    Command-line entry point: select paths from Rosetta and run them in parallel.

    Reads ROSETTA_PATH, chooses path numbers from --paths or --start/--end,
    builds per-path file names, runs ``process_one_path`` for each in a
    process pool whose workers load GPD_PATH, and prints a summary.
    """
    args = _parse_args()

    mode = args.mode.upper()
    run_interference = mode in ("I", "B")
    run_terrain      = mode in ("T", "B")

    # ── Load Rosetta ──────────────────────────────────────────────────────────
    print(f"Loading Rosetta from: {ROSETTA_PATH}")
    filenaming = pd.read_csv(ROSETTA_PATH)
    # Normalize once; reused for both selection and row lookup.
    path_nums = filenaming["path_num"].astype(str).str.strip()
    all_path_nums = path_nums.tolist()
    print(f"  Rosetta loaded: {len(all_path_nums)} paths.\n")

    # ── Determine which paths to run ─────────────────────────────────────────
    target_paths = _select_target_paths(args, all_path_nums)
    if not target_paths:
        print("No paths to process. Check --start / --end / --paths arguments.")
        return

    # ── Build work items (no gpd_lookup — workers load it themselves) ────────
    work_items, skipped = _build_work_items(
        filenaming, path_nums, target_paths, run_interference, run_terrain
    )

    # ── Launch parallel workers ───────────────────────────────────────────────
    mode_label = {"I": "Interference only", "T": "Terrain only", "B": "Interference + Terrain"}[mode]
    span = f"  ({work_items[0][0]} … {work_items[-1][0]})" if work_items else ""
    print(f"{'='*60}")
    print(f"BATCH RUN — mode: {mode_label}  |  workers: {args.workers}")
    print(f"Paths to process: {len(work_items)}{span}")
    print(
        f"Each worker loads GPD at startup and is replaced after "
        f"{args.max_tasks_per_child} paths (reloading GPD) to keep RAM bounded."
    )
    print(f"{'='*60}")

    # Monotonic clock for elapsed time; wall clock only for display.
    batch_start    = time.monotonic()
    batch_start_dt = datetime.now()
    print(f"  Start time : {batch_start_dt.strftime('%H:%M:%S')}\n")

    if work_items:
        all_errors = _run_batch(work_items, args.workers, args.max_tasks_per_child)
    else:
        print("  Nothing to run: every selected path was skipped.")
        all_errors = []

    # ── Summary ───────────────────────────────────────────────────────────────
    _print_summary(work_items, skipped, all_errors, batch_start, batch_start_dt)


# Entry-point guard. Worker processes are started with the "spawn" method
# (the Windows default, and the default ProcessPoolExecutor uses when
# max_tasks_per_child is set), which re-imports this module in every worker.
# Without the guard, each worker would re-run main().
if __name__ == "__main__":
    main()
