# geocrop-platform/apps/worker/dw_baseline.py
#
# 420 lines
# 13 KiB
# Python

"""Dynamic World baseline loading for inference.
STEP 5: DW Baseline loader - loads and clips Dynamic World baseline COGs from MinIO.
Per AGENTS.md:
- Bucket: geocrop-baselines
- Prefix: dw/zim/summer/
- Files: DW_Zim_HighestConf_<year>_<year+1>-<tile_row>-<tile_col>.tif
- Efficient: Use windowed reads to avoid downloading entire tiles
- CRS: Must transform AOI bbox to tile CRS before windowing
"""
from __future__ import annotations

import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np

# Try to import rasterio
try:
    import rasterio
    from rasterio.windows import Window, from_bounds
    from rasterio.warp import transform_bounds, transform
    HAS_RASTERIO = True
except ImportError:
    HAS_RASTERIO = False
# DW class mapping: list index == Dynamic World class ID.
# Dynamic World defines 9 land-cover classes (IDs 0-8), in this order.
# NOTE(review): the loader profiles in this module mark 0 as nodata, which is
# also the "water" ID below -- confirm whether the exported baseline tiles
# shift class IDs (e.g. by +1) so that 0 is free to mean nodata.
DW_CLASS_NAMES = [
    "water",
    "trees",
    "grass",
    "flooded_vegetation",
    "crops",
    "shrub_and_scrub",
    "built",
    "bare",
    "snow_and_ice",
]
# Display colours, parallel to DW_CLASS_NAMES (same index -> same class).
DW_CLASS_COLORS = [
    "#419BDF",  # water
    "#397D49",  # trees
    "#88B53E",  # grass
    "#FFAA5D",  # flooded_vegetation
    "#DA913D",  # crops
    "#919636",  # shrub_and_scrub
    "#B9B9B9",  # built
    "#D6D6D6",  # bare
    "#FFFFFF",  # snow_and_ice
]
# DW bucket configuration: MinIO bucket holding the baseline COGs.
DW_BUCKET = "geocrop-baselines"
def list_dw_objects(
    storage,
    year: int,
    season: str = "summer",
    dw_type: str = "HighestConf",
    bucket: str = DW_BUCKET,
) -> List[str]:
    """Return the keys of DW baseline GeoTIFFs for one season, year and product.

    Objects are listed under ``dw/zim/<season>/`` and kept when their key
    contains ``DW_Zim_<dw_type>_<year>_<year+1>`` and ends in ``.tif``.
    The order is whatever ``storage.list_objects`` yields.

    Args:
        storage: MinIOStorage instance (must provide ``list_objects``).
        year: First year of the growing season (2022 -> the 2022_2023 season).
        season: Season folder name (summer/winter).
        dw_type: Product type - "HighestConf", "Agreement", or "Mode".
        bucket: MinIO bucket name.

    Returns:
        Matching object keys.
    """
    season_prefix = f"dw/zim/{season}/"
    wanted = f"DW_Zim_{dw_type}_{year}_{year + 1}"
    keys: List[str] = []
    for key in storage.list_objects(bucket, season_prefix):
        if wanted in key and key.endswith(".tif"):
            keys.append(key)
    return keys
def get_dw_tile_window(
    src_path: str,
    aoi_bbox_wgs84: List[float],
) -> Tuple[Optional[Window], Optional[dict], Optional[np.ndarray]]:
    """Read the part of a single DW tile that covers an AOI.

    The AOI bbox is reprojected into the tile's CRS with densified edges, so a
    box whose edges bend or rotate in a projected CRS is still fully covered.
    The result is expanded outward to whole pixels and clipped to the tile.

    Args:
        src_path: Path or URL to the tile (can be a presigned URL).
        aoi_bbox_wgs84: AOI bounding box ``[min_lon, min_lat, max_lon, max_lat]``
            in WGS84 (EPSG:4326).

    Returns:
        Tuple ``(window, profile, data)``:

        - window: the pixel window that was read from the tile.
        - profile: rasterio profile aligned to the window (GTiff, 1 band,
          int16, nodata 0, deflate), intended for writing outputs.
        - data: band 1 of the tile inside the window, in the tile's own dtype.

        ``(None, None, None)`` when the AOI does not overlap the tile.

    Raises:
        ImportError: If rasterio is not installed.
    """
    if not HAS_RASTERIO:
        raise ImportError("rasterio is required for DW baseline loading")

    # Normalise the bbox so swapped corners still describe the same box.
    lon_a, lat_a, lon_b, lat_b = aoi_bbox_wgs84
    min_lon, max_lon = min(lon_a, lon_b), max(lon_a, lon_b)
    min_lat, max_lat = min(lat_a, lat_b), max(lat_a, lat_b)

    with rasterio.open(src_path) as src:
        src_crs = src.crs
        # Reproject the densified bbox. Transforming only two corners can
        # under-cover the AOI once it is expressed in a projected CRS.
        left, bottom, right, top = transform_bounds(
            "EPSG:4326", src_crs, min_lon, min_lat, max_lon, max_lat, densify_pts=21
        )

        # Fractional pixel window for the AOI. Floor the start and ceil the
        # stop so every pixel touching the AOI is kept, then clip to the tile.
        aoi_win = from_bounds(left, bottom, right, top, transform=src.transform)
        col_start = max(0, int(np.floor(aoi_win.col_off)))
        row_start = max(0, int(np.floor(aoi_win.row_off)))
        col_stop = min(src.width, int(np.ceil(aoi_win.col_off + aoi_win.width)))
        row_stop = min(src.height, int(np.ceil(aoi_win.row_off + aoi_win.height)))

        # Skip if no overlap
        if col_stop <= col_start or row_stop <= row_start:
            return None, None, None

        window = Window(col_start, row_start, col_stop - col_start, row_stop - row_start)
        data = src.read(1, window=window)

        # Profile for this window. dtype is int16 regardless of the tile's
        # dtype because the profile is meant for writing outputs.
        profile = {
            "driver": "GTiff",
            "height": data.shape[0],
            "width": data.shape[1],
            "count": 1,
            "dtype": rasterio.int16,
            "nodata": 0,  # DW uses 0 as nodata
            "crs": src_crs,
            "transform": src.window_transform(window),
            "compress": "deflate",
        }
    return window, profile, data
def _window_bounds(data: np.ndarray, window_transform) -> Tuple[float, float, float, float]:
    """Return ``(left, bottom, right, top)`` covered by ``data`` under ``window_transform``."""
    height, width = data.shape
    # Affine coefficients as a tuple: [0]=a (x pixel size), [2]=c (x origin),
    # [4]=e (y pixel size, negative for north-up), [5]=f (y origin).
    x0, y0 = window_transform[2], window_transform[5]
    x1 = x0 + width * window_transform[0]
    y1 = y0 + height * window_transform[4]
    return min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)


def mosaic_windows(
    windows_data: List[Tuple[Window, np.ndarray, dict]],
    aoi_bbox_wgs84: List[float],
    target_crs: str,
) -> Tuple[np.ndarray, dict]:
    """Mosaic multiple tile windows into a single array.

    A single entry is returned unchanged, as ``(data, profile)``. Several
    entries are pasted into a new int16 array covering the union of their
    extents. Pixels equal to 0 (nodata) never overwrite data already pasted
    from another window.

    No reprojection is done. All windows must share the CRS of the first
    valid window. ``aoi_bbox_wgs84`` and ``target_crs`` are accepted for API
    compatibility but are not used.

    Args:
        windows_data: List of ``(window, data, profile)`` tuples; entries whose
            ``data`` is None or empty are ignored.
        aoi_bbox_wgs84: Original AOI bbox in WGS84 (unused).
        target_crs: Requested output CRS (unused; see above).

    Returns:
        Tuple of ``(mosaic_array, profile)``.

    Raises:
        ValueError: If there is nothing to mosaic, no window holds data, or
            the windows do not share a CRS.
    """
    if not windows_data:
        raise ValueError("No windows to mosaic")

    if len(windows_data) == 1:
        # Single tile - return it exactly as read.
        _, data, profile = windows_data[0]
        return data, profile

    valid = [
        (data, profile)
        for _, data, profile in windows_data
        if data is not None and data.size > 0
    ]
    if not valid:
        raise ValueError("No valid data in windows")

    first_profile = valid[0][1]
    out_crs = first_profile["crs"]
    if any(profile["crs"] != out_crs for _, profile in valid[1:]):
        raise ValueError("Cannot mosaic DW windows with differing CRS")

    first_t = first_profile["transform"]
    xres = abs(first_t[0])
    yres = abs(first_t[4])

    bounds = [_window_bounds(data, profile["transform"]) for data, profile in valid]
    min_x = min(b[0] for b in bounds)
    min_y = min(b[1] for b in bounds)
    max_x = max(b[2] for b in bounds)
    max_y = max(b[3] for b in bounds)

    # Round rather than truncate, so float error in the bounds cannot drop a
    # whole output row or column.
    out_width = int(round((max_x - min_x) / xres))
    out_height = int(round((max_y - min_y) / yres))
    mosaic = np.zeros((out_height, out_width), dtype=np.int16)

    for (data, _), (left, _, _, top) in zip(valid, bounds):
        # Output row 0 sits at the top edge (max_y); rows count downward.
        col_off = max(0, int(round((left - min_x) / xres)))
        row_off = max(0, int(round((max_y - top) / yres)))
        rows = min(data.shape[0], out_height - row_off)
        cols = min(data.shape[1], out_width - col_off)
        if rows <= 0 or cols <= 0:
            continue
        chunk = data[:rows, :cols]
        target = mosaic[row_off:row_off + rows, col_off:col_off + cols]
        # Copy only non-nodata pixels. Otherwise a tile's empty margin would
        # erase another tile's data where the windows overlap.
        has_data = chunk != 0
        target[has_data] = chunk[has_data]

    from rasterio.transform import from_origin

    profile = {
        "driver": "GTiff",
        "height": out_height,
        "width": out_width,
        "count": 1,
        "dtype": rasterio.int16,
        "nodata": 0,
        "crs": out_crs,
        "transform": from_origin(min_x, max_y, xres, yres),
        "compress": "deflate",
    }
    return mosaic, profile
def _read_dw_tile_window(
    storage,
    bucket: str,
    key: str,
    aoi_bbox_wgs84: List[float],
    attempts: int,
):
    """Presign one tile and read its AOI window, retrying with exponential backoff.

    Sleeps 1s, 2s, 4s, ... between attempts and re-raises the last error once
    ``attempts`` (must be >= 1) tries have all failed. Returns the
    ``(window, profile, data)`` tuple from ``get_dw_tile_window``.
    """
    last_error: Optional[Exception] = None
    for attempt in range(attempts):
        if attempt:
            time.sleep(2 ** (attempt - 1))  # exponential backoff
        try:
            # Presign on every attempt so a retry uses a fresh URL.
            url = storage.presign_get(bucket, key, expires=3600)
            return get_dw_tile_window(url, aoi_bbox_wgs84)
        except Exception as exc:  # any presign/read failure is retried
            last_error = exc
    raise last_error


def load_dw_baseline_window(
    storage,
    year: int,
    aoi_bbox_wgs84: List[float],
    season: str = "summer",
    dw_type: str = "HighestConf",
    bucket: str = DW_BUCKET,
    max_retries: int = 3,
) -> Tuple[np.ndarray, dict]:
    """Load the DW baseline clipped to an AOI window from MinIO.

    Uses windowed reads through presigned URLs, so only the AOI part of each
    tile is fetched. A tile that still fails after all retries is skipped
    with a logged warning. The result is then built from the remaining tiles,
    so it may contain nodata (0) where the skipped tile would have been.

    Args:
        storage: MinIOStorage instance with ``list_objects`` and ``presign_get``.
        year: First year of the growing season (2022 -> the 2022_2023 season).
        aoi_bbox_wgs84: AOI bounding box ``[min_lon, min_lat, max_lon, max_lat]``
            in WGS84.
        season: Season (summer/winter); selects the ``dw/zim/<season>/`` prefix.
        dw_type: Product type - "HighestConf", "Agreement", or "Mode".
        bucket: MinIO bucket name.
        max_retries: Attempts per tile (values below 1 are treated as 1).

    Returns:
        Tuple of:

        - dw_arr: baseline raster clipped to the AOI window. It keeps the
          tile's own dtype when a single tile overlaps; it is int16 when
          several tiles are mosaicked.
        - profile: rasterio profile for writing outputs aligned to this window.

    Raises:
        ImportError: If rasterio is not installed.
        FileNotFoundError: If no matching DW tile is found.
        RuntimeError: If no tile could be read, or the AOI overlaps no tile.
    """
    if not HAS_RASTERIO:
        raise ImportError("rasterio is required for DW baseline loading")

    # Step 1: List matching objects
    matching_keys = list_dw_objects(storage, year, season, dw_type, bucket)
    if not matching_keys:
        prefix = f"dw/zim/{season}/"
        raise FileNotFoundError(
            f"No DW baseline found for year={year}, type={dw_type}, "
            f"season={season}. Searched prefix: {prefix}"
        )

    # Step 2: For each tile, read the AOI window (with retries)
    attempts = max(1, max_retries)
    windows_data = []
    last_error: Optional[Exception] = None
    for key in matching_keys:
        try:
            window, profile, data = _read_dw_tile_window(
                storage, bucket, key, aoi_bbox_wgs84, attempts
            )
        except Exception as exc:
            last_error = exc
            logging.getLogger(__name__).warning(
                "Skipping DW tile %s after %d failed attempt(s): %s",
                key, attempts, exc,
            )
            continue
        if data is not None and data.size > 0:
            windows_data.append((window, data, profile))

    if not windows_data:
        if last_error is None:
            # Every read succeeded; the AOI simply lies outside all tiles.
            raise RuntimeError(
                f"AOI {aoi_bbox_wgs84} does not overlap any of the "
                f"{len(matching_keys)} DW tile(s) for year={year}, "
                f"type={dw_type}, season={season}"
            )
        raise RuntimeError(
            f"Failed to read any DW tiles after {max_retries} retries. "
            f"Last error: {last_error}"
        ) from last_error

    # Step 3: Mosaic if needed (tiles' own CRS; no reprojection)
    dw_arr, profile = mosaic_windows(
        windows_data, aoi_bbox_wgs84, windows_data[0][2]["crs"]
    )
    return dw_arr, profile
def get_dw_class_name(class_id: int) -> str:
    """Get the DW class name for a class ID.

    Args:
        class_id: DW class ID, an index into ``DW_CLASS_NAMES`` (0-8).

    Returns:
        Class name, or "unknown" for IDs outside that range.
    """
    if 0 <= class_id < len(DW_CLASS_NAMES):
        return DW_CLASS_NAMES[class_id]
    return "unknown"
def get_dw_class_color(class_id: int) -> str:
    """Get the DW display colour for a class ID.

    Args:
        class_id: DW class ID, an index into ``DW_CLASS_COLORS`` (0-8).

    Returns:
        Hex colour code, or "#000000" for IDs outside that range.
    """
    if 0 <= class_id < len(DW_CLASS_COLORS):
        return DW_CLASS_COLORS[class_id]
    return "#000000"
# ==========================================
# Self-Test
# ==========================================
if __name__ == "__main__":
    print("=== DW Baseline Loader Test ===")
    if not HAS_RASTERIO:
        print("rasterio not installed - skipping full test")
        print("Import test: PASS (module loads)")
    else:
        # Test object listing (without real storage)
        print("\n1. Testing DW object pattern...")
        year = 2018
        season = "summer"
        dw_type = "HighestConf"
        # Show the pattern list_dw_objects searches for
        print(f" Year: {year}, Type: {dw_type}, Season: {season}")
        print(f" Expected pattern: DW_Zim_{dw_type}_{year}_{year+1}-*.tif")
        print(f" This would search prefix: dw/zim/{season}/")
        # Check if we can reach storage (fails outside the cluster)
        try:
            from storage import MinIOStorage
            print("\n2. Testing MinIOStorage...")
            storage = MinIOStorage()
            # Use the loader's own listing/filtering so the self-test
            # exercises the real code path instead of a copy of it.
            matching = list_dw_objects(storage, year, season, dw_type)
            print(f" Found {len(matching)} matching objects")
            for obj in matching[:5]:
                print(f" {obj}")
        except Exception as e:
            print(f" MinIO not available: {e}")
            print(" (This is expected outside Kubernetes)")
    print("\n=== DW Baseline Test Complete ===")