"""Dynamic World baseline loading for inference.

STEP 5: DW Baseline loader - loads and clips Dynamic World baseline COGs
from MinIO.

Per AGENTS.md:
- Bucket: geocrop-baselines
- Prefix: dw/zim/summer/
- Files: DW_Zim_HighestConf_<year>_<year+1>-*.tif
- Efficient: Use windowed reads to avoid downloading entire tiles
- CRS: Must transform AOI bbox to tile CRS before windowing
"""

from __future__ import annotations

import math
import time
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np

# rasterio is optional at import time so the module (and its pure helpers)
# can still be loaded where it is not installed.
try:
    import rasterio
    from rasterio.windows import Window, from_bounds
    from rasterio.warp import transform_bounds, transform

    HAS_RASTERIO = True
except ImportError:
    HAS_RASTERIO = False


# DW class mapping: 9 classes, indexed by class ID 0-8.
DW_CLASS_NAMES = [
    "water",
    "trees",
    "grass",
    "flooded_vegetation",
    "crops",
    "shrub_and_scrub",
    "built",
    "bare",
    "snow_and_ice",
]

# Display colors, parallel to DW_CLASS_NAMES.
DW_CLASS_COLORS = [
    "#419BDF",  # water
    "#397D49",  # trees
    "#88B53E",  # grass
    "#FFAA5D",  # flooded_vegetation
    "#DA913D",  # crops
    "#919636",  # shrub_and_scrub
    "#B9B9B9",  # built
    "#D6D6D6",  # bare
    "#FFFFFF",  # snow_and_ice
]

# DW bucket configuration
DW_BUCKET = "geocrop-baselines"


def list_dw_objects(
    storage,
    year: int,
    season: str = "summer",
    dw_type: str = "HighestConf",
    bucket: str = DW_BUCKET,
) -> List[str]:
    """List matching DW baseline objects from MinIO.

    Args:
        storage: Object exposing ``list_objects(bucket, prefix)`` that yields
            object keys (e.g. a MinIOStorage instance).
        year: Growing season year (e.g., 2022 for 2022_2023 season)
        season: Season (summer/winter); selects the ``dw/zim/<season>/`` prefix
        dw_type: Type - "HighestConf", "Agreement", or "Mode"
        bucket: MinIO bucket name

    Returns:
        List of ``.tif`` object keys whose name contains
        ``DW_Zim_<dw_type>_<year>_<year+1>``.
    """
    prefix = f"dw/zim/{season}/"
    pattern = f"DW_Zim_{dw_type}_{year}_{year + 1}"

    return [
        obj
        for obj in storage.list_objects(bucket, prefix)
        if pattern in obj and obj.endswith(".tif")
    ]


def get_dw_tile_window(
    src_path: str,
    aoi_bbox_wgs84: List[float],
) -> Tuple[Optional[Window], Optional[dict], Optional[np.ndarray]]:
    """Read the part of a single tile that overlaps the AOI.

    Args:
        src_path: Path or URL to tile (can be presigned URL)
        aoi_bbox_wgs84: AOI bounding box [min_lon, min_lat, max_lon, max_lat]
            in WGS84

    Returns:
        Tuple of (window, profile, data):
        - window: The pixel window that was read
        - profile: rasterio profile describing the window
        - data: Band 1 of the tile, read over the window
        All three are None when the AOI does not overlap the tile.

    Raises:
        ImportError: If rasterio is not installed.
    """
    if not HAS_RASTERIO:
        raise ImportError("rasterio is required for DW baseline loading")

    min_lon, min_lat, max_lon, max_lat = aoi_bbox_wgs84

    with rasterio.open(src_path) as src:
        src_crs = src.crs

        # Reproject the whole bbox (with edge densification) rather than two
        # corner points, so the result still encloses the AOI when the tile
        # CRS is not axis-aligned with lon/lat.
        left, bottom, right, top = transform_bounds(
            "EPSG:4326", src_crs, min_lon, min_lat, max_lon, max_lat
        )

        # The inverse affine maps world (x, y) -> fractional (col, row).
        # Doing this explicitly avoids the (row, col) return order of
        # src.index(), which is easy to unpack the wrong way round.
        inverse = ~src.transform
        corners = [(left, top), (right, top), (right, bottom), (left, bottom)]
        cols, rows = zip(*(inverse * corner for corner in corners))

        # Floor the near edge and ceil the far edge so partially covered
        # pixels are included, then clamp to the tile extent.
        col_min = max(0, math.floor(min(cols)))
        row_min = max(0, math.floor(min(rows)))
        col_max = min(src.width, math.ceil(max(cols)))
        row_max = min(src.height, math.ceil(max(rows)))

        # Skip if no overlap
        if col_max <= col_min or row_max <= row_min:
            return None, None, None

        window = Window(col_min, row_min, col_max - col_min, row_max - row_min)
        data = src.read(1, window=window)

        profile = {
            "driver": "GTiff",
            "height": data.shape[0],
            "width": data.shape[1],
            "count": 1,
            "dtype": rasterio.int16,
            # NOTE(review): 0 is declared as nodata here while DW_CLASS_NAMES
            # maps ID 0 to "water" - confirm the encoding of the baseline files.
            "nodata": 0,
            "crs": src_crs,
            "transform": src.window_transform(window),
            "compress": "deflate",
        }

        return window, profile, data


def mosaic_windows(
    windows_data: List[Tuple[Window, np.ndarray, dict]],
    aoi_bbox_wgs84: List[float],
    target_crs: str,
) -> Tuple[np.ndarray, dict]:
    """Mosaic multiple tile windows into single array.

    No reprojection or resampling is performed: windows are pasted onto a
    common grid using the pixel size of the first window, so all windows are
    expected to share one north-up grid.

    Args:
        windows_data: List of (window, data, profile) tuples
        aoi_bbox_wgs84: Original AOI bbox in WGS84 (currently unused)
        target_crs: Accepted for interface compatibility; the output CRS is
            always taken from the first window's profile.

    Returns:
        Tuple of (mosaic_array, profile). With a single window its data and
        profile are returned unchanged; otherwise the mosaic is int16 with 0
        where no window contributed.

    Raises:
        ValueError: If ``windows_data`` is empty or holds no non-empty data.
    """
    if not windows_data:
        raise ValueError("No windows to mosaic")

    if len(windows_data) == 1:
        # Single tile - nothing to merge
        _, data, profile = windows_data[0]
        return data, profile

    _, _, first_profile = windows_data[0]
    target_crs = first_profile["crs"]

    # Affine coefficients are ordered (a, b, c, d, e, f): index 0 is pixel
    # width, index 4 is pixel height (negative for north-up rasters) and
    # indices 2 / 5 are the x / y of the top-left corner.
    first_t = first_profile["transform"]
    res_x = abs(first_t[0])
    res_y = abs(first_t[4])

    # Collect (data, left, top, right, bottom) for every usable window.
    placed = []
    for _, data, profile in windows_data:
        if data is None or data.size == 0:
            continue
        t = profile["transform"]
        rows, cols = data.shape
        x_left, y_top = t[2], t[5]
        x_right = x_left + cols * t[0]
        y_bottom = y_top + rows * t[4]
        placed.append((data, x_left, y_top, x_right, y_bottom))

    if not placed:
        raise ValueError("No valid data in windows")

    # Union extent; min()/max() on each axis keeps this independent of the
    # sign of the pixel height.
    min_x = min(min(p[1], p[3]) for p in placed)
    max_x = max(max(p[1], p[3]) for p in placed)
    min_y = min(min(p[2], p[4]) for p in placed)
    max_y = max(max(p[2], p[4]) for p in placed)

    # round() instead of int(): the ratios are whole pixel counts up to
    # float error, and truncation would drop a row/column on 99.999...
    out_width = int(round((max_x - min_x) / res_x))
    out_height = int(round((max_y - min_y) / res_y))

    mosaic = np.zeros((out_height, out_width), dtype=np.int16)

    for data, x_left, y_top, _x_right, _y_bottom in placed:
        # Offsets are measured from the mosaic's top-left corner.
        col_off = int(round((x_left - min_x) / res_x))
        row_off = int(round((max_y - y_top) / res_y))

        rows, cols = data.shape
        end_row = min(row_off + rows, out_height)
        end_col = min(col_off + cols, out_width)

        if end_row > row_off and end_col > col_off:
            mosaic[row_off:end_row, col_off:end_col] = data[
                : end_row - row_off, : end_col - col_off
            ]

    # Imported here so the single-window path above works without rasterio.
    from rasterio.transform import from_origin

    out_transform = from_origin(min_x, max_y, res_x, res_y)

    profile = {
        "driver": "GTiff",
        "height": out_height,
        "width": out_width,
        "count": 1,
        "dtype": rasterio.int16,
        "nodata": 0,
        "crs": target_crs,
        "transform": out_transform,
        "compress": "deflate",
    }

    return mosaic, profile


def load_dw_baseline_window(
    storage,
    year: int,
    aoi_bbox_wgs84: List[float],
    season: str = "summer",
    dw_type: str = "HighestConf",
    bucket: str = DW_BUCKET,
    max_retries: int = 3,
) -> Tuple[np.ndarray, dict]:
    """Load DW baseline clipped to AOI window from MinIO.

    Uses efficient windowed reads to avoid downloading entire tiles.

    Args:
        storage: MinIOStorage instance with ``list_objects`` and
            ``presign_get`` methods
        year: Growing season year (e.g., 2022 for 2022_2023 season)
        aoi_bbox_wgs84: AOI bounding box [min_lon, min_lat, max_lon, max_lat]
            in WGS84
        season: Season (summer/winter) - maps to prefix
        dw_type: Type - "HighestConf", "Agreement", or "Mode"
        bucket: MinIO bucket name
        max_retries: Maximum attempts per tile; values below 1 are treated
            as a single attempt.

    Returns:
        Tuple of:
        - dw_arr: baseline raster clipped to AOI window (the tile's own dtype
          for a single tile, int16 when several tiles are mosaicked)
        - profile: rasterio profile for writing outputs aligned to this window

    Raises:
        ImportError: If rasterio is not installed
        FileNotFoundError: If no matching DW tile found
        RuntimeError: If no tile could be read after retries, or no tile
            overlaps the AOI
    """
    if not HAS_RASTERIO:
        raise ImportError("rasterio is required for DW baseline loading")

    # Step 1: List matching objects
    matching_keys = list_dw_objects(storage, year, season, dw_type, bucket)

    if not matching_keys:
        prefix = f"dw/zim/{season}/"
        raise FileNotFoundError(
            f"No DW baseline found for year={year}, type={dw_type}, "
            f"season={season}. Searched prefix: {prefix}"
        )

    # Step 2: For each tile, get presigned URL and read window
    attempts = max(1, max_retries)
    windows_data = []
    last_error = None

    for key in matching_keys:
        for attempt in range(attempts):
            try:
                # Re-sign on every attempt so a retry never reuses a stale URL.
                url = storage.presign_get(bucket, key, expires=3600)
                window, profile, data = get_dw_tile_window(url, aoi_bbox_wgs84)
            except Exception as e:  # best-effort: any read failure is retried
                last_error = e
                if attempt < attempts - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff
                continue

            # A tile without overlap yields None data; that is not an error.
            if data is not None and data.size > 0:
                windows_data.append((window, data, profile))
            break  # Done with this tile

    if not windows_data:
        if last_error is None:
            # Every tile was readable but none intersected the AOI.
            raise RuntimeError(
                f"No DW tile overlaps the requested AOI {aoi_bbox_wgs84} "
                f"(checked {len(matching_keys)} tile(s))."
            )
        raise RuntimeError(
            f"Failed to read any DW tiles after {max_retries} retries. "
            f"Last error: {last_error}"
        ) from last_error

    # Step 3: Mosaic if needed. The windows stay in the tiles' own CRS, so
    # that CRS (not the bucket name) is what gets passed as the target.
    tile_crs = windows_data[0][2]["crs"]
    dw_arr, profile = mosaic_windows(windows_data, aoi_bbox_wgs84, tile_crs)

    return dw_arr, profile


def get_dw_class_name(class_id: int) -> str:
    """Get DW class name from class ID.

    Args:
        class_id: DW class ID (0-8)

    Returns:
        Class name, or "unknown" for IDs outside the table
    """
    if 0 <= class_id < len(DW_CLASS_NAMES):
        return DW_CLASS_NAMES[class_id]
    return "unknown"


def get_dw_class_color(class_id: int) -> str:
    """Get DW class color from class ID.

    Args:
        class_id: DW class ID (0-8)

    Returns:
        Hex color code, or "#000000" for IDs outside the table
    """
    if 0 <= class_id < len(DW_CLASS_COLORS):
        return DW_CLASS_COLORS[class_id]
    return "#000000"


# ==========================================
# Self-Test
# ==========================================

if __name__ == "__main__":
    print("=== DW Baseline Loader Test ===")

    if not HAS_RASTERIO:
        print("rasterio not installed - skipping full test")
        print("Import test: PASS (module loads)")
    else:
        # Test object listing (without real storage)
        print("\n1. Testing DW object pattern...")
        year = 2018
        season = "summer"
        dw_type = "HighestConf"

        # Simulate what list_dw_objects would return based on known files
        print(f" Year: {year}, Type: {dw_type}, Season: {season}")
        print(f" Expected pattern: DW_Zim_{dw_type}_{year}_{year+1}-*.tif")
        print(f" This would search prefix: dw/zim/{season}/")

        # Check if we can import storage
        try:
            from storage import MinIOStorage

            print("\n2. Testing MinIOStorage...")

            # Try to list objects (will fail without real MinIO)
            storage = MinIOStorage()
            objects = storage.list_objects(DW_BUCKET, f"dw/zim/{season}/")

            # Filter for our year
            pattern = f"DW_Zim_{dw_type}_{year}_{year + 1}"
            matching = [o for o in objects if pattern in o and o.endswith(".tif")]

            print(f" Found {len(matching)} matching objects")
            for obj in matching[:5]:
                print(f" {obj}")
        except Exception as e:
            print(f" MinIO not available: {e}")
            print(" (This is expected outside Kubernetes)")

    print("\n=== DW Baseline Test Complete ===")