import geopandas as gpd
import py3dep
import rioxarray
import rasterio
from rasterio.merge import merge
import fiona
import pyproj
import numpy as np
import os
import sys
import time
import logging
import tempfile
import math
import concurrent.futures
import requests
from pathlib import Path
from threading import Lock
from shapely.geometry import mapping, Polygon, box
from shapely.ops import transform, unary_union
import shapely

# ---------------------------------------------
# CONFIGURATION
# ---------------------------------------------
# Input KML corridors, output GeoTIFFs and the run log (Windows paths).
KML_DIR    = r"D:\KinderMorgan\kml_files"
OUTPUT_DIR = r"D:\KinderMorgan\clipped_dems"
LOG_FILE   = r"D:\KinderMorgan\dem_download.log"

# 3DEP request tuning.
RESOLUTION    = 10    # resolution passed to py3dep.get_dem (metres, per py3dep docs)
SLEEP_BETWEEN = 0.2   # seconds paused before every 3DEP request
MAX_RETRIES   = 3     # 3DEP attempts per chunk before giving up
RETRY_DELAY   = 5     # seconds waited after a failed 3DEP attempt

MAX_WORKERS   = 24    # corridors processed concurrently (thread-pool size)
CHUNK_MILES   = 10.0  # target chunk length along each corridor's long axis

# Copernicus GLO-30 tiles: local download cache and public S3 bucket root.
GLO30_CACHE = r"D:\KinderMorgan\glo30_cache"
GLO30_BASE  = "https://copernicus-dem-30m.s3.amazonaws.com"

# A corridor whose bounding box reaches south of MEXICO_LAT or north of
# CANADA_LAT is fetched entirely from GLO-30 instead of 3DEP.
# NOTE(review): the US-Mexico border lies well north of 25.5 N outside the
# lower Rio Grande (about 31-32.5 N in West Texas/NM/AZ/CA), and the
# US-Canada border runs south of 49 N east of Lake of the Woods; corridors
# near those stretches are not flagged -- confirm this is intended.
MEXICO_LAT = 25.5
CANADA_LAT = 49.0

# ---------------------------------------------
# LOGGING SETUP
# ---------------------------------------------
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
# Re-running the script in the same interpreter (e.g. IPython ``%run``)
# returns this same logger object; drop handlers left by a previous run so
# every message is not emitted twice.
for _old_handler in list(log.handlers):
    log.removeHandler(_old_handler)
    _old_handler.close()
formatter = logging.Formatter("%(asctime)s  %(levelname)-8s  %(message)s", datefmt="%H:%M:%S")
# FileHandler opens LOG_FILE immediately, so its folder must already exist;
# otherwise the script dies with a raw FileNotFoundError before logging starts.
os.makedirs(os.path.dirname(LOG_FILE) or ".", exist_ok=True)
fh = logging.FileHandler(LOG_FILE, encoding="utf-8")
fh.setFormatter(formatter)
log.addHandler(fh)
# Force UTF-8 on the console where supported (reconfigure is absent when
# stdout has been replaced by an object without that method).
try:
    sys.stdout.reconfigure(encoding="utf-8")
except AttributeError:
    pass
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
log.addHandler(sh)

# ---------------------------------------------
# SETUP
# ---------------------------------------------
# Create the output and GLO-30 tile-cache folders up front (no-op if present).
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
Path(GLO30_CACHE).mkdir(parents=True, exist_ok=True)
# Enable fiona's KML drivers in read/write mode so gpd.read_file(..., driver="KML")
# can open the corridor files.
fiona.drvsupport.supported_drivers.update({"KML": "rw", "LIBKML": "rw"})

METERS_PER_MILE          = 1609.344  # international mile
APPROX_METERS_PER_DEGREE = 111000.0  # rough ground length of one degree

# Per-outcome corridor tallies, shared by all worker threads and guarded by
# counter_lock.
counter_lock = Lock()
counts = dict.fromkeys(("success", "skipped", "failed"), 0)

def increment(key):
    """Add one to ``counts[key]`` under ``counter_lock``.

    Raises KeyError for a key that is not already in ``counts``.
    """
    counter_lock.acquire()
    try:
        counts[key] += 1
    finally:
        counter_lock.release()

# ---------------------------------------------
# FIND ALL KML FILES
# ---------------------------------------------
kml_dir_path = Path(KML_DIR)
log.info(f"Looking for KML files in: {kml_dir_path.resolve()}")
# is_dir() rather than exists(): a plain file at KML_DIR is just as unusable.
if not kml_dir_path.is_dir():
    log.error(f"Directory '{KML_DIR}' not found or not a directory.")
    sys.exit(1)
# Match the extension case-insensitively so "*.KML" exports are not missed on
# case-sensitive file systems; sorting gives a deterministic submission order.
kml_files = sorted(
    p for p in kml_dir_path.iterdir()
    if p.is_file() and p.suffix.lower() == ".kml"
)
log.info(f"Found {len(kml_files)} KML files")
if not kml_files:
    log.error("No KML files found.")
    sys.exit(1)

# ---------------------------------------------
# GEOMETRY HELPERS
# ---------------------------------------------
def reproject_geom(geom, from_crs, to_crs):
    """Return a copy of *geom* transformed from *from_crs* to *to_crs*.

    ``always_xy=True`` makes pyproj treat coordinates as (x, y) -- i.e.
    (lon, lat) for geographic CRSs -- regardless of the CRS's axis order.
    """
    transformer = pyproj.Transformer.from_crs(from_crs, to_crs, always_xy=True)
    return transform(transformer.transform, geom)

def _rectangle_corners(geom):
    """Return the four 2-D corners of *geom*'s (minimum rotated) rectangle.

    A polygon whose exterior ring already has exactly four vertices is used
    as-is. Anything else (extra vertices, MultiPolygon, ...) is replaced by
    its minimum rotated rectangle, so stray vertices are never silently
    dropped. Z values are discarded.
    """
    exterior = getattr(geom, "exterior", None)
    coords = list(exterior.coords)[:-1] if exterior is not None else []
    if len(coords) != 4:
        envelope = shapely.oriented_envelope(geom)
        exterior = getattr(envelope, "exterior", None)
        coords = list(exterior.coords)[:-1] if exterior is not None else []
        if len(coords) != 4:
            raise ValueError("geometry has no usable rectangular envelope")
    return [np.array(c[:2], dtype=float) for c in coords]

def get_long_axis(polygon):
    """Return the centre-line (long axis) of a rectangular corridor polygon.

    The rectangle's four edges are sorted by length. The two shortest are
    the corridor ends, and the line joining their midpoints is the long
    axis. Coordinates are used unprojected, so lengths are in the polygon's
    own units (degrees for EPSG:4326 input).

    Returns:
        (mid0, mid1, long_len, short_len, axis_unit, short_unit):
        mid0/mid1 are the 2-D midpoints of the two short edges.
        long_len is the longest edge length.
        short_len is the third-longest edge length (a short edge).
        axis_unit is the unit vector from mid0 to mid1.
        short_unit is axis_unit rotated +90 degrees.

    Raises:
        ValueError: if the geometry is degenerate (no rectangle can be
            derived or both end midpoints coincide).
    """
    corners = _rectangle_corners(polygon)
    edges = []
    for i in range(4):
        p1 = corners[i]
        p2 = corners[(i + 1) % 4]
        edges.append((np.linalg.norm(p2 - p1), p1, p2))
    # Longest first; the last two entries are the rectangle's short ends.
    edges.sort(key=lambda e: e[0], reverse=True)
    long_len   = edges[0][0]
    short_len  = edges[2][0]
    mid0       = (edges[2][1] + edges[2][2]) / 2.0
    mid1       = (edges[3][1] + edges[3][2]) / 2.0
    axis_vec   = mid1 - mid0
    axis_norm  = np.linalg.norm(axis_vec)
    if not np.isfinite(axis_norm) or axis_norm == 0:
        raise ValueError("degenerate corridor polygon: zero-length long axis")
    axis_unit  = axis_vec / axis_norm
    short_unit = np.array([-axis_unit[1], axis_unit[0]])
    return mid0, mid1, long_len, short_len, axis_unit, short_unit

def make_chunk_polygons(polygon_4326, chunk_size_deg):
    """Split a rectangular corridor into equal slices along its long axis.

    The number of slices is ceil(long_len / chunk_size_deg), with a minimum
    of one. Each slice spans the corridor's full width and has length
    long_len / count, so together the slices tile the rectangle returned by
    get_long_axis.

    Args:
        polygon_4326: corridor polygon (lon/lat coordinates).
        chunk_size_deg: target slice length, in the polygon's coordinate units.

    Returns:
        list of shapely Polygons, ordered from the first end midpoint onward.
    """
    start, _end, length, width, along, across = get_long_axis(polygon_4326)
    offset = across * (width / 2.0)
    count  = max(1, int(np.ceil(length / chunk_size_deg)))
    stride = length / count

    def _slice(k):
        near = start + along * (k * stride)
        far  = start + along * ((k + 1) * stride)
        first = near + offset
        return Polygon([first, near - offset, far - offset, far + offset, first])

    return [_slice(k) for k in range(count)]

# ---------------------------------------------
# GLO-30 TILE HELPERS
# ---------------------------------------------
def glo30_tile_name(lat, lon):
    """Build the Copernicus GLO-30 COG tile stem for integer degree *lat*, *lon*.

    Example: (45, -75) -> "Copernicus_DSM_COG_10_N45_00_W075_00_DEM".
    *lat*/*lon* must be ints (formatted with ``d``).
    """
    ns = "S" if lat < 0 else "N"
    ew = "W" if lon < 0 else "E"
    return "Copernicus_DSM_COG_10_{}{:02d}_00_{}{:03d}_00_DEM".format(ns, abs(lat), ew, abs(lon))

# Registry of one Lock per GLO-30 tile name, so at most one thread downloads a
# given tile at a time; tile_locks_lock guards the registry dict itself.
tile_locks = {}
tile_locks_lock = Lock()

def get_tile_lock(tile_name):
    """Return the Lock dedicated to *tile_name*, creating it on first request."""
    with tile_locks_lock:
        lock = tile_locks.get(tile_name)
        if lock is None:
            lock = tile_locks[tile_name] = Lock()
    return lock

def download_glo30_tile(tile_name):
    """Return the local path of GLO-30 tile *tile_name*, downloading it if needed.

    Tiles are cached in GLO30_CACHE as ``<tile_name>.tif``. The download is
    streamed into ``<tile_name>.tif.part`` and renamed into place only when
    complete. This gives two guarantees:
      * an interrupted download never leaves a truncated tile that later
        runs would treat as cached;
      * the lock-free fast-path ``exists()`` check never sees a half-written
        file being filled by another thread.

    Returns:
        str path on success, or None if the download failed. The error is
        logged; a tile that does not exist on the server also ends up here.
    """
    local_path = Path(GLO30_CACHE) / f"{tile_name}.tif"
    if local_path.exists():
        return str(local_path)

    # Serialize downloads of the same tile across threads
    with get_tile_lock(tile_name):
        # Re-check after acquiring lock in case another thread downloaded it
        if local_path.exists():
            return str(local_path)
        url = f"{GLO30_BASE}/{tile_name}/{tile_name}.tif"
        part_path = local_path.with_name(local_path.name + ".part")
        try:
            # The context manager releases the streamed connection on all paths.
            with requests.get(url, timeout=120, stream=True) as resp:
                resp.raise_for_status()
                with open(part_path, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
            os.replace(part_path, local_path)
            log.info(f"Downloaded GLO-30 tile: {tile_name}")
            return str(local_path)
        except Exception as e:
            # Best effort: a missing tile must not abort the whole chunk.
            log.error(f"Failed to download GLO-30 tile {tile_name}: {e}")
            try:
                part_path.unlink()
            except OSError:
                pass
            return None

def tiles_for_bounds(bounds):
    """List the GLO-30 tile stems covering *bounds* plus a 0.1 degree margin.

    Args:
        bounds: (minx, miny, maxx, maxy) in lon/lat degrees.

    Returns:
        Tile stems ordered by latitude, then by longitude within each latitude.
    """
    west, south, east, north = bounds
    margin = 0.1
    lat_span = range(math.floor(south - margin), math.ceil(north + margin))
    lon_span = range(math.floor(west - margin), math.ceil(east + margin))
    return [glo30_tile_name(lat, lon) for lat in lat_span for lon in lon_span]

# ---------------------------------------------
# FETCH ONE CHUNK FROM GLO-30
# ---------------------------------------------
def _unlink_with_retry(path, attempts=5, delay=0.3):
    """Delete *path*, retrying briefly if the OS still has it open.

    A missing file counts as success. A final failure is logged rather than
    raised, so cleanup never masks the caller's own result.
    """
    for _ in range(attempts):
        try:
            os.unlink(path)
            return
        except FileNotFoundError:
            return
        except OSError:
            # Presumably a lingering GDAL/Windows handle; give it a moment.
            time.sleep(delay)
    log.warning(f"Could not delete temporary file: {path}")

def _glo30_mosaic_tempfile(tile_paths):
    """Merge several GLO-30 tiles into a new temporary GeoTIFF; return its path.

    The caller owns the returned file and must delete it.
    """
    src_files = []
    try:
        for p in tile_paths:
            src_files.append(rasterio.open(p))
        mosaic_data, mosaic_transform = merge(src_files)
        mosaic_crs  = src_files[0].crs
        mosaic_meta = src_files[0].meta.copy()
    finally:
        # Close every opened tile even if a later open or the merge fails.
        for s in src_files:
            s.close()

    mosaic_meta.update({
        "driver": "GTiff", "height": mosaic_data.shape[1],
        "width": mosaic_data.shape[2], "transform": mosaic_transform,
        "crs": mosaic_crs
    })
    fd, tmp_path = tempfile.mkstemp(suffix=".tif")
    os.close(fd)
    try:
        with rasterio.open(tmp_path, "w", **mosaic_meta) as dst:
            dst.write(mosaic_data)
    except Exception:
        _unlink_with_retry(tmp_path)
        raise
    return tmp_path

def fetch_chunk_glo30(chunk_4326, kml_name, chunk_idx):
    """Fetch one corridor chunk from Copernicus GLO-30 tiles.

    Downloads (or reuses cached) tiles covering the chunk's bounds, mosaics
    them if there is more than one, clips to the chunk polygon
    (all_touched=True), and returns the result in EPSG:4326.

    Returns:
        The clipped raster (an in-memory DataArray), or None if no tile
        was available or processing failed. Failures are logged. Every
        temporary file and open dataset is released on all paths.
    """
    tile_paths = []
    for tile_name in tiles_for_bounds(chunk_4326.bounds):
        path = download_glo30_tile(tile_name)
        if path:
            tile_paths.append(path)

    if not tile_paths:
        log.error(f"[{kml_name}] chunk {chunk_idx}: no GLO-30 tiles available")
        return None

    tmp_path = None
    try:
        if len(tile_paths) > 1:
            tmp_path  = _glo30_mosaic_tempfile(tile_paths)
            read_path = tmp_path
        else:
            read_path = tile_paths[0]

        da = rioxarray.open_rasterio(read_path)
        try:
            da      = da.rio.write_crs("EPSG:4326", inplace=True)
            clipped = da.rio.clip([mapping(chunk_4326)], crs="EPSG:4326", all_touched=True, drop=True)
            try:
                # The reprojected result is a new array; it must be produced
                # before the source is closed and the temp mosaic deleted.
                result = clipped.rio.reproject("EPSG:4326")
            finally:
                clipped.close()
        finally:
            da.close()
    except Exception as e:
        log.error(f"[{kml_name}] chunk {chunk_idx}: GLO-30 failed: {e}")
        return None
    finally:
        if tmp_path:
            _unlink_with_retry(tmp_path)

    log.info(f"[{kml_name}] chunk {chunk_idx}: GLO-30 OK (30m)")
    return result

# ---------------------------------------------
# FETCH ONE CHUNK FROM 3DEP
# ---------------------------------------------
def fetch_chunk_3dep(chunk_4326, kml_name, chunk_idx):
    """Fetch one corridor chunk from USGS 3DEP via py3dep.

    Each attempt works as follows:
      1. Sleep SLEEP_BETWEEN seconds.
      2. Request the DEM at RESOLUTION.
      3. Clip it to the chunk polygon in the DEM's native CRS (all_touched=True).
      4. Reproject the clipped raster to EPSG:4326.
    Up to MAX_RETRIES attempts are made, waiting RETRY_DELAY seconds after
    each failure.

    Returns:
        The clipped EPSG:4326 raster, or None when every attempt failed. The
        last exception is included in the error log.
    """
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            time.sleep(SLEEP_BETWEEN)
            dem         = py3dep.get_dem(chunk_4326, resolution=RESOLUTION)
            chunk_crs   = dem.rio.crs
            # Clip in the DEM's own CRS so no raster resampling happens before clipping.
            chunk_match = reproject_geom(chunk_4326, "EPSG:4326", chunk_crs)
            clipped     = dem.rio.clip(
                [mapping(chunk_match)], crs=chunk_crs, all_touched=True, drop=True
            )
            return clipped.rio.reproject("EPSG:4326")
        except Exception as e:
            if attempt < MAX_RETRIES:
                log.warning(f"[{kml_name}] chunk {chunk_idx} attempt {attempt} failed: {e} -- retrying")
                time.sleep(RETRY_DELAY)
            else:
                log.error(f"[{kml_name}] chunk {chunk_idx}: 3DEP failed after {MAX_RETRIES} attempts: {e}")
    # Reached when all attempts failed (or MAX_RETRIES < 1).
    return None

# ---------------------------------------------
# MAIN PER-CORRIDOR FUNCTION
# ---------------------------------------------
def _merge_chunks_to_file(chunk_arrays, tmpdir):
    """Write the chunk rasters into *tmpdir*, merge them, return (path, crs).

    The merged mosaic is written to ``mosaic.tif`` inside *tmpdir* as an
    LZW-compressed GeoTIFF. Its CRS and metadata come from the first chunk.
    """
    tmp_paths = []
    for i, arr in enumerate(chunk_arrays):
        tmp_path = os.path.join(tmpdir, f"chunk_{i:04d}.tif")
        arr.rio.to_raster(tmp_path)
        tmp_paths.append(tmp_path)

    src_files = []
    try:
        for p in tmp_paths:
            src_files.append(rasterio.open(p))
        mosaic, mosaic_transform = merge(src_files)
        mosaic_crs  = src_files[0].crs
        mosaic_meta = src_files[0].meta.copy()
    finally:
        # Release the handles even on failure so tmpdir can be removed.
        for s in src_files:
            s.close()

    mosaic_meta.update({
        "driver": "GTiff", "height": mosaic.shape[1],
        "width": mosaic.shape[2], "transform": mosaic_transform,
        "crs": mosaic_crs, "compress": "lzw"
    })
    mosaic_path = os.path.join(tmpdir, "mosaic.tif")
    with rasterio.open(mosaic_path, "w", **mosaic_meta) as dest:
        dest.write(mosaic)
    return mosaic_path, mosaic_crs

def _clip_mosaic_to_output(mosaic_path, mosaic_crs, geom_4326, out_path):
    """Clip the mosaic to the corridor polygon and publish it at *out_path*.

    The raster is written to a sibling ``<stem>.partial.tif`` first and only
    renamed onto *out_path* once complete. An interrupted run therefore never
    leaves a truncated file that the "already exists" check would then skip.
    """
    part_path = out_path.with_name(f"{out_path.stem}.partial.tif")
    try:
        mosaic_da = rioxarray.open_rasterio(mosaic_path)
        try:
            mosaic_da    = mosaic_da.rio.write_crs(mosaic_crs, inplace=True)
            clipped_4326 = mosaic_da.rio.clip(
                [mapping(geom_4326)], crs="EPSG:4326", all_touched=True, drop=True
            )
            try:
                # Compress the deliverable just like the intermediate mosaic.
                clipped_4326.rio.to_raster(str(part_path), compress="LZW")
            finally:
                clipped_4326.close()
        finally:
            mosaic_da.close()
        os.replace(part_path, out_path)
    except Exception:
        try:
            os.unlink(part_path)
        except OSError:
            pass
        raise

def _rmtree_with_retry(path, attempts=5, delay=0.5):
    """Remove directory *path*, retrying while files in it are still locked.

    After *attempts* failures, remove whatever is possible and log a warning
    if anything is left. This never raises.
    """
    import shutil  # not imported at module level

    for _ in range(attempts):
        try:
            shutil.rmtree(path)
            return
        except FileNotFoundError:
            return
        except OSError:
            time.sleep(delay)
    shutil.rmtree(path, ignore_errors=True)
    if os.path.exists(path):
        log.warning(f"Could not fully remove temporary directory: {path}")

def download_and_clip(kml_path):
    """Produce ``OUTPUT_DIR/<kml stem>.tif``: a DEM clipped to the KML corridor.

    Steps:
      1. Skip if the output already exists.
      2. Read the first KML feature in EPSG:4326.
      3. Pick GLO-30 (border corridors) or 3DEP as the source.
      4. Split the corridor into ~CHUNK_MILES chunks and fetch each.
      5. Mosaic the chunks and clip to the corridor.
      6. Write the output atomically.

    The outcome is recorded via increment() as "skipped", "success" or
    "failed". Any exception is logged with its traceback and counted as
    "failed" rather than propagated.
    """
    try:
        out_path = Path(OUTPUT_DIR) / kml_path.with_suffix(".tif").name

        if out_path.exists():
            log.info(f"SKIP -- already exists: {out_path.name}")
            increment("skipped")
            return

        gdf = gpd.read_file(str(kml_path), driver="KML").to_crs("EPSG:4326")
        # Only the first feature is used -- presumably one corridor polygon per
        # KML; TODO confirm. A null/empty geometry is reported like an empty file.
        raw_geom = None if gdf.empty else gdf.geometry.iloc[0]
        if raw_geom is None or raw_geom.is_empty:
            log.error(f"EMPTY -- {kml_path.name}")
            increment("failed")
            return

        geom_4326 = shapely.force_2d(raw_geom)
        bounds    = geom_4326.bounds

        # Decide data source for the entire corridor from its bbox latitudes.
        near_border = bounds[1] < MEXICO_LAT or bounds[3] > CANADA_LAT
        if near_border:
            log.info(f"[{kml_path.name}] border corridor -- using GLO-30 for all chunks")
            fetch_fn = fetch_chunk_glo30
        else:
            fetch_fn = fetch_chunk_3dep

        chunk_size_deg = (CHUNK_MILES * METERS_PER_MILE) / APPROX_METERS_PER_DEGREE
        chunks_4326    = make_chunk_polygons(geom_4326, chunk_size_deg)
        n_chunks       = len(chunks_4326)
        log.info(f"[{kml_path.name}] {n_chunks} chunk(s)")

        chunk_arrays = []
        for i, chunk_4326 in enumerate(chunks_4326):
            result = fetch_fn(chunk_4326, kml_path.name, i)
            if result is not None:
                chunk_arrays.append(result)

        if not chunk_arrays:
            log.error(f"FAILED -- all chunks failed: {kml_path.name}")
            increment("failed")
            return

        if len(chunk_arrays) < n_chunks:
            log.warning(f"[{kml_path.name}] {len(chunk_arrays)}/{n_chunks} chunks OK -- output may have gaps")

        tmpdir = tempfile.mkdtemp()
        try:
            mosaic_path, mosaic_crs = _merge_chunks_to_file(chunk_arrays, tmpdir)
            _clip_mosaic_to_output(mosaic_path, mosaic_crs, geom_4326, out_path)
        finally:
            _rmtree_with_retry(tmpdir)

        log.info(f"OK -- {kml_path.name} -> {out_path.name}")
        increment("success")

    except Exception as e:
        log.error(f"EXCEPTION in {kml_path.name}: {e}", exc_info=True)
        increment("failed")

# ---------------------------------------------
# PARALLEL EXECUTION
# ---------------------------------------------
log.info(f"Starting downloads with {MAX_WORKERS} workers")
start_time = time.time()

# Run one download_and_clip per KML on a thread pool. Anything that still
# escapes the worker's own handler is logged here and counted as a failure.
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
    pending = {pool.submit(download_and_clip, path): path for path in kml_files}
    for finished in concurrent.futures.as_completed(pending):
        try:
            finished.result()
        except Exception as exc:
            log.error(f"UNHANDLED EXCEPTION for {pending[finished].name}: {exc}", exc_info=True)
            increment("failed")

# ---------------------------------------------
# SUMMARY
# ---------------------------------------------
elapsed = time.time() - start_time
log.info(f"\n{'-'*40}")
log.info(f"Done in {elapsed/60:.1f} minutes")
log.info(f"  Success : {counts['success']}")
log.info(f"  Skipped : {counts['skipped']}")
log.info(f"  Failed  : {counts['failed']}")
log.info(f"  Log     : {LOG_FILE}")
if counts["failed"] > 0:
    # Corridor failures are logged at ERROR level under several different
    # prefixes ("FAILED --", "EMPTY --", "EXCEPTION in", "UNHANDLED EXCEPTION"),
    # so point at the level name rather than a single prefix.
    log.info(f"Search '{LOG_FILE}' for 'ERROR' to find problem rectangles")
