"""
batch_kizer_krunch.py
---------------------
Batch runner for kizer_krunch_code.py.

Iterates over every path_num in the Rosetta CSV and runs Interference,
Terrain, or Both for each path — no interactive prompts required.

Usage:
    python batch_kizer_krunch.py              # defaults to Both (I+T)
    python batch_kizer_krunch.py --mode I     # Interference only
    python batch_kizer_krunch.py --mode T     # Terrain only
    python batch_kizer_krunch.py --mode B     # Both (default)
    python batch_kizer_krunch.py --start 5    # start at path_num 5
    python batch_kizer_krunch.py --end 20     # stop after path_num 20
    python batch_kizer_krunch.py --paths 3 7 12  # run specific path numbers only

Edit ROSETTA_PATH and GPD_PATH (near the top of this file) to match your environment.
"""

import argparse
import time
from datetime import datetime
import traceback
import pandas as pd

# ── Adjust these two paths to match your environment ──────────────────────────
# Rosetta CSV: main() reads its "path_num" column to choose which rows to run,
# and build_paths_for_row() reads the other columns through FNC_COLUMNS.
ROSETTA_PATH = r'D:\ATT\ATT_Rosetta_populated.csv'
# GeneralPathData workbook: loaded once per batch by load_gpd() and shared by
# every interference calculation.
GPD_PATH     = r'D:\ATT\ATT_GPD_ALL.xlsx'
# ──────────────────────────────────────────────────────────────────────────────

# Import the calculation functions and column-name maps from the main module.
# The module-level code is guarded by `if __name__ == "__main__"` so importing
# it does NOT trigger run_program().
from kizer_krunch_code import (
    FNC_COLUMNS,
    GENERALPATHDATA_COLUMNS,
    run_interference_calculation,
    run_terrain_calculation,
)


# ── Helpers ───────────────────────────────────────────────────────────────────

def load_gpd(gpd_path: str) -> dict:
    """
    Read the GeneralPathData workbook and index it by call-sign pair.

    Every cell is read as text (``dtype=str``). Each row becomes one entry
    keyed by ``(tx_call_sign, rx_call_sign)``, both whitespace-stripped; if
    two rows share the same pair, the later row overwrites the earlier one.

    Args:
        gpd_path: path to the GeneralPathData Excel file.

    Returns:
        dict mapping (tx, rx) -> {"channels", "bandwidth", "frequency",
        "gain", "tilt", "bearing", "ant_model"}. Only "ant_model" is
        stringified and stripped; the other values are copied as read.
    """
    print(f"Loading General Path Data from: {gpd_path}")
    sheet = pd.read_excel(gpd_path, dtype=str)

    def cell(record, key):
        # Translate the logical key to the workbook's column name, then read it.
        return record[GENERALPATHDATA_COLUMNS[key]]

    lookup = {}
    for _, record in sheet.iterrows():
        pair = (
            str(cell(record, "tx_call_sign")).strip(),
            str(cell(record, "rx_call_sign")).strip(),
        )
        lookup[pair] = {
            "channels":  cell(record, "channels"),
            "bandwidth": cell(record, "bandwidth"),
            "frequency": cell(record, "fr_frequency_assigned_MHz"),
            "gain":      cell(record, "rxan_gain_dBi"),
            "tilt":      cell(record, "RxTilt_deg"),
            "bearing":   cell(record, "RxBearing_deg"),
            "ant_model": str(cell(record, "AltRcvAntenna")).strip(),
        }

    print(f"  GPD loaded: {len(lookup)} paths.\n")
    return lookup


def build_paths_for_row(filenaming: pd.DataFrame, rownumber: int) -> dict:
    """
    Gather receivers, input files and output files for one Rosetta row.

    Args:
        filenaming: the Rosetta table.
        rownumber:  index label of the row to read (looked up with ``.loc``).

    Returns:
        dict containing the two receiver names, the stripped drive / folder /
        subfolder, the four input file paths copied verbatim from Rosetta, and
        four output file paths. Each output path is the drive, folder and
        subfolder joined with backslashes, followed by
        "<subfolder>_<A|B>_<function>.csv".
    """
    record = filenaming.loc[rownumber]

    def cell(key):
        # Every value is stringified, so a NaN cell becomes the text "nan".
        return str(record[FNC_COLUMNS[key]])

    drive = cell("drive").strip()
    folder = cell("folder").strip().strip("\\")  # drop stray edge backslashes
    subfolder = cell("subfolder").strip()
    out_dir = "\\".join([drive, folder, subfolder])

    int_function = cell("AIntFunction")
    terr_function = cell("ATerrFunction")
    # NOTE: The original script read AIntFunction / ATerrFunction for the B side
    # as well (BIntFunction / BTerrFunction are never read). This is kept on
    # purpose so output file names stay identical to the existing behaviour.

    def output_file(side, function):
        return f"{out_dir}\\{subfolder}_{side}_{function}.csv"

    return {
        "AReceiver":       cell("AReceiver"),
        "BReceiver":       cell("BReceiver"),
        "drive":           drive,
        "folder":          folder,
        "subfolder":       subfolder,
        # Inputs: full paths are stored directly in Rosetta.
        "AIntInputfile":   cell("fileAInt"),
        "BIntInputfile":   cell("fileBInt"),
        "ATerrInputfile":  cell("fileATerr"),
        "BTerrInputfile":  cell("fileBTerr"),
        # Outputs: derived from drive/folder/subfolder.
        "AIntOutputfile":  output_file("A", int_function),
        "BIntOutputfile":  output_file("B", int_function),
        "ATerrOutputfile": output_file("A", terr_function),
        "BTerrOutputfile": output_file("B", terr_function),
    }


def run_path(path_num: str, paths: dict, gpd_lookup: dict,
             run_interference: bool, run_terrain: bool) -> list[str]:
    """
    Execute the requested calculations (A side, then B side) for one path.

    Interference jobs run before terrain jobs. Each job is isolated. If a job
    raises, its traceback is printed and the message
    "Path <path_num> | <label>: <exception>" is recorded. The next job
    still runs.

    Returns:
        The recorded error messages; an empty list means every job succeeded.
    """
    failures = []

    def attempt(label, source, target, calculate):
        print(f"  [{label}] {source}  →  {target}")
        try:
            calculate(source, target)
        except Exception as exc:
            message = f"Path {path_num} | {label}: {exc}"
            print(f"  ERROR: {message}")
            traceback.print_exc()
            failures.append(message)

    # Thin wrappers so the calculation functions are looked up inside the
    # try block at call time.
    def interference(source, target):
        run_interference_calculation(source, target, paths["drive"], paths["folder"], gpd_lookup)

    def terrain(source, target):
        run_terrain_calculation(source, target)

    if run_interference:
        jobs = [(side, paths[f"{side}IntInputfile"], paths[f"{side}IntOutputfile"])
                for side in ("A", "B")]
        for side, source, target in jobs:
            attempt(f"{side} Interference", source, target, interference)

    if run_terrain:
        jobs = [(side, paths[f"{side}TerrInputfile"], paths[f"{side}TerrOutputfile"])
                for side in ("A", "B")]
        for side, source, target in jobs:
            attempt(f"{side} Terrain", source, target, terrain)

    return failures


# ── Main batch loop ───────────────────────────────────────────────────────────

def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Batch runner for kizer_krunch_code.py")
    parser.add_argument(
        "--mode", choices=["I", "T", "B", "i", "t", "b"],
        default="B",
        help="I=Interference only, T=Terrain only, B=Both (default: B)",
    )
    parser.add_argument(
        "--start", type=int, default=None,
        help="First path_num to process (inclusive). Skips earlier rows.",
    )
    parser.add_argument(
        "--end", type=int, default=None,
        help="Last path_num to process (inclusive). Stops after this row.",
    )
    parser.add_argument(
        "--paths", type=int, nargs="+", default=None,
        help="Explicit list of path numbers to run (overrides --start/--end).",
    )
    return parser.parse_args()


def _path_num_as_int(path_num: str):
    """Return *path_num* as an int, or None when it is not an integer string."""
    try:
        return int(path_num)
    except ValueError:
        return None


def _select_targets(path_nums: list, row_labels: list,
                    args: argparse.Namespace) -> list:
    """
    Decide which Rosetta rows to run, honouring --paths / --start / --end.

    Args:
        path_nums:  stripped string path_num of every Rosetta row, in file order.
        row_labels: DataFrame index label of each of those rows (same order).
        args:       parsed CLI arguments (reads .paths, .start and .end).

    Returns:
        (path_num, row_label) pairs in run order. Each path_num stays paired
        with its own row, so a path_num that is duplicated in Rosetta runs
        every one of its rows once. Previously the first matching row was
        re-run once per duplicate.
    """
    if args.paths:
        # --paths: exact string match. For a duplicated path_num the first
        # Rosetta row is used, as in the original lookup.
        first_row = {}
        for path_num, label in zip(path_nums, row_labels):
            first_row.setdefault(path_num, label)
        requested = [str(p) for p in args.paths]
        missing = [p for p in requested if p not in first_row]
        if missing:
            print(f"WARNING: These path numbers were not found in Rosetta and will be skipped: {missing}")
        return [(p, first_row[p]) for p in requested if p in first_row]

    targets = list(zip(path_nums, row_labels))
    if args.start is None and args.end is None:
        return targets

    # --start/--end compare numerically. A non-integer path_num (e.g. "nan"
    # from a blank cell) used to abort the whole batch in int(); it is now
    # reported and excluded from the range instead.
    selected = []
    non_numeric = []
    for path_num, label in targets:
        value = _path_num_as_int(path_num)
        if value is None:
            non_numeric.append(path_num)
            continue
        if args.start is not None and value < args.start:
            continue
        if args.end is not None and value > args.end:
            continue
        selected.append((path_num, label))
    if non_numeric:
        print(f"WARNING: These path_num values are not integers and were excluded by --start/--end: {non_numeric}")
    return selected


def _print_summary(total: int, skipped: list, all_errors: list,
                   batch_start: float, batch_start_dt: datetime) -> None:
    """Print the end-of-batch report: counts, timings, errors and skipped paths."""
    batch_end_dt  = datetime.now()
    total_elapsed = time.time() - batch_start
    hours, rem    = divmod(int(total_elapsed), 3600)
    minutes, secs = divmod(rem, 60)

    print(f"\n{'='*60}")
    print("BATCH COMPLETE")
    print(f"  Total paths attempted : {total}")
    print(f"  Skipped               : {len(skipped)}")
    print(f"  Paths with errors     : {len(set(pn for pn, _ in all_errors))}")
    print(f"  Start time            : {batch_start_dt.strftime('%H:%M:%S')}")
    print(f"  End time              : {batch_end_dt.strftime('%H:%M:%S')}")
    print(f"  Elapsed time          : {hours:03d}:{minutes:02d}:{secs:02d}")

    if all_errors:
        print("\nERRORS SUMMARY:")
        for path_num, msg in all_errors:
            print(f"  Path {path_num}: {msg}")

    if skipped:
        print(f"\nSKIPPED PATHS: {skipped}")

    print(f"{'='*60}")


def main():
    """
    Entry point: parse CLI options, select Rosetta rows and run each one.

    Loads the Rosetta CSV (ROSETTA_PATH), selects rows via _select_targets(),
    and loads the GeneralPathData workbook (GPD_PATH) once. It then calls
    run_path() for every selected row. A failure on one path is recorded and
    the batch continues; a summary is printed at the end.
    """
    args = _parse_args()

    mode = args.mode.upper()
    run_interference = mode in ("I", "B")
    run_terrain      = mode in ("T", "B")

    # ── Load Rosetta ──────────────────────────────────────────────────────────
    print(f"Loading Rosetta from: {ROSETTA_PATH}")
    filenaming = pd.read_csv(ROSETTA_PATH)
    # Normalise path_num once; previously it was recomputed for every path.
    all_path_nums = filenaming["path_num"].astype(str).str.strip().tolist()
    print(f"  Rosetta loaded: {len(all_path_nums)} paths.\n")

    # ── Determine which rows to run ───────────────────────────────────────────
    targets = _select_targets(all_path_nums, list(filenaming.index), args)
    if not targets:
        print("No paths to process. Check --start / --end / --paths arguments.")
        return

    # ── Load GPD once ────────────────────────────────────────────────────────
    gpd_lookup = load_gpd(GPD_PATH)

    # ── Batch loop ────────────────────────────────────────────────────────────
    mode_label = {"I": "Interference only", "T": "Terrain only", "B": "Interference + Terrain"}[mode]
    print(f"{'='*60}")
    print(f"BATCH RUN — mode: {mode_label}")
    print(f"Paths to process: {len(targets)}  ({targets[0][0]} … {targets[-1][0]})")
    print(f"{'='*60}\n")

    batch_start = time.time()
    batch_start_dt = datetime.now()
    print(f"  Start time : {batch_start_dt.strftime('%H:%M:%S')}")
    all_errors = []          # (path_num, error_message)
    skipped    = []          # path_nums we couldn't build paths for

    for idx, (path_num, rownumber) in enumerate(targets, start=1):
        print(f"\n── Path {path_num}  ({idx}/{len(targets)}) ──────────────────────────")

        try:
            paths = build_paths_for_row(filenaming, rownumber)
        except Exception as e:
            msg = f"Could not build file paths: {e}"
            print(f"  SKIP: {msg}")
            skipped.append(path_num)
            all_errors.append((path_num, msg))
            continue

        path_start = time.time()
        errs = run_path(path_num, paths, gpd_lookup, run_interference, run_terrain)
        elapsed = time.time() - path_start

        all_errors.extend((path_num, err) for err in errs)

        status = "COMPLETED" if not errs else f"COMPLETED WITH {len(errs)} ERROR(S)"
        print(f"  {status}  ({elapsed:.1f}s)")

    # ── Summary ───────────────────────────────────────────────────────────────
    _print_summary(len(targets), skipped, all_errors, batch_start, batch_start_dt)


# Run the batch only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
