"""
Fills a_meanelev / b_meanelev by sampling the DEM TIF for each receiver's call sign
at that receiver's lat/lon.

TIF files expected at: c:\\users\\public\\clipped_dems\\<call_sign>.tif  (CRS = EPSG:4326)
"""

import pandas as pd
import rasterio
import os

# Directory holding the clipped DEM GeoTIFFs.
DEM_DIR = r"c:\users\public\clipped_dems"

# Every column is read as text; only the endpoint coordinates are converted below.
df = pd.read_csv(
    "ATT_Mobility_Network_Definition_v7.csv",
    dtype=str,
)

# errors="coerce": blank or unparsable coordinates become NaN instead of raising.
for col in ("A_Latitude", "A_Longitude", "b_Latitude", "b_Longitude"):
    df[col] = df[col].pipe(pd.to_numeric, errors="coerce")


import numpy as np

def _span(center, radius, size):
    """Return ``(lo, hi)``: the half-open pixel range within *radius* of *center*.

    Both bounds are clipped to ``[0, size]`` and ``lo <= hi``, so the range can be
    used as a slice without negative (wrap-around) indices. ``lo == hi`` means the
    range lies entirely outside the raster.
    """
    lo = min(max(0, center - radius), size)
    hi = max(lo, min(size, center + radius + 1))
    return lo, hi


def _valid_pixels(values, nodata):
    """Return the pixels of *values* (flattened) that are neither NaN nor equal to *nodata*."""
    keep = np.ones(values.shape, dtype=bool)
    if np.issubdtype(values.dtype, np.floating):
        # NaN is never a usable elevation, whether or not it is the declared nodata.
        keep &= ~np.isnan(values)
    if nodata is not None and not np.isnan(nodata):
        keep &= values != nodata
    return values[keep]


def sample_elev(tif_path, lon, lat, search_radius=2):
    """Return the DEM elevation at ``(lon, lat)`` read from band 1 of *tif_path*, or None.

    The pixel containing the point is used when it holds valid data. Otherwise
    square windows of growing half-width (1 .. *search_radius* pixels) centred on
    that pixel are tried in turn. The mean of the valid pixels in the first window
    that contains any is returned. A pixel is invalid when it is NaN or equals the
    raster's nodata value.

    Args:
        tif_path: path to the DEM GeoTIFF. Coordinates are interpreted in the
            raster's CRS (x = lon, y = lat).
        lon, lat: point coordinates. NaN or None yields None.
        search_radius: largest half-width, in pixels, of the fallback window.
            Negative values are treated as 0 (center pixel only).

    Returns:
        The elevation as a float, or None when the file does not exist, a
        coordinate is missing, or no valid pixel lies within *search_radius*
        pixels of the point (this includes points far outside the raster).
    """
    search_radius = max(0, search_radius)
    # Coerced CSV values arrive as NaN; they cannot be mapped to a pixel index.
    if pd.isna(lon) or pd.isna(lat) or not os.path.exists(tif_path):
        return None

    from rasterio.windows import Window

    with rasterio.open(tif_path) as src:
        nodata = src.nodata
        height, width = src.height, src.width
        row, col = src.index(lon, lat)
        # Read only the largest search window, clipped to the raster. The point
        # itself may lie outside the raster, so row/col can be negative or too large.
        r0, r1 = _span(row, search_radius, height)
        c0, c1 = _span(col, search_radius, width)
        if r0 == r1 or c0 == c1:
            return None
        block = src.read(1, window=Window(c0, r0, c1 - c0, r1 - r0))

    # radius 0 is the pixel containing the point; larger radii are the fallback.
    for radius in range(search_radius + 1):
        lo_r, hi_r = _span(row, radius, height)
        lo_c, hi_c = _span(col, radius, width)
        valid = _valid_pixels(block[lo_r - r0:hi_r - r0, lo_c - c0:hi_c - c0], nodata)
        if valid.size:
            return float(valid.mean())
    return None


# Each path has its own DEM, <DEM_DIR>/<Path_ID>.tif, sampled at both endpoints.
# Plain list comprehensions are used instead of df.apply(axis=1). On a zero-row
# frame, apply returns a DataFrame, and assigning that to one column raises.
tif_paths = [os.path.join(DEM_DIR, f"{path_id}.tif") for path_id in df["Path_ID"]]
df["a_meanelev"] = [
    sample_elev(tif, lon, lat)
    for tif, lon, lat in zip(tif_paths, df["A_Longitude"], df["A_Latitude"])
]
df["b_meanelev"] = [
    sample_elev(tif, lon, lat)
    for tif, lon, lat in zip(tif_paths, df["b_Longitude"], df["b_Latitude"])
]

df.to_csv("ATT_Mobility_Network_Definition_v8.csv", index=False)

# Report how many distinct Path_IDs lack an elevation (listing up to 10 of each).
missing_a = df.loc[df["a_meanelev"].isna(), "Path_ID"].unique()
missing_b = df.loc[df["b_meanelev"].isna(), "Path_ID"].unique()
print(f"Missing A elevations: {len(missing_a)} -> {list(missing_a[:10])}")
print(f"Missing B elevations: {len(missing_b)} -> {list(missing_b[:10])}")
