420 lines
13 KiB
Python
420 lines
13 KiB
Python
"""Dynamic World baseline loading for inference.
|
|
|
|
STEP 5: DW Baseline loader - loads and clips Dynamic World baseline COGs from MinIO.
|
|
|
|
Per AGENTS.md:
|
|
- Bucket: geocrop-baselines
|
|
- Prefix: dw/zim/summer/
|
|
- Files: DW_Zim_HighestConf_<year>_<year+1>-<tile_row>-<tile_col>.tif
|
|
- Efficient: Use windowed reads to avoid downloading entire tiles
|
|
- CRS: Must transform AOI bbox to tile CRS before windowing
|
|
"""
|
|
|
|
from __future__ import annotations

import time
import warnings
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np

# Try to import rasterio
try:
    import rasterio
    from rasterio.windows import Window, from_bounds
    from rasterio.warp import transform_bounds, transform
    HAS_RASTERIO = True
except ImportError:
    HAS_RASTERIO = False
|
|
|
|
|
|
# Dynamic World class names, indexed by class ID. Dynamic World defines 9
# land-cover classes, so valid IDs are 0-8.
# NOTE(review): the tile profiles built in this module declare 0 as nodata,
# which collides with ID 0 ("water") below. Confirm whether the exported
# HighestConf COGs store DW labels as-is (0-8) or shifted by +1 (1-9).
DW_CLASS_NAMES = [
    "water",
    "trees",
    "grass",
    "flooded_vegetation",
    "crops",
    "shrub_and_scrub",
    "built",
    "bare",
    "snow_and_ice",
]

# Display colours (hex), parallel to DW_CLASS_NAMES.
DW_CLASS_COLORS = [
    "#419BDF",  # water
    "#397D49",  # trees
    "#88B53E",  # grass
    "#FFAA5D",  # flooded_vegetation
    "#DA913D",  # crops
    "#919636",  # shrub_and_scrub
    "#B9B9B9",  # built
    "#D6D6D6",  # bare
    "#FFFFFF",  # snow_and_ice
]

# MinIO bucket holding the DW baseline COGs (objects live under dw/zim/<season>/).
DW_BUCKET = "geocrop-baselines"
|
|
|
|
|
|
def list_dw_objects(
    storage,
    year: int,
    season: str = "summer",
    dw_type: str = "HighestConf",
    bucket: str = DW_BUCKET,
) -> List[str]:
    """Return the keys of DW baseline GeoTIFFs for one season and product type.

    Every key under ``dw/zim/<season>/`` in *bucket* is listed. A key is kept
    when it contains ``DW_Zim_<dw_type>_<year>_<year+1>`` and ends in ``.tif``.

    Args:
        storage: MinIOStorage-like object exposing ``list_objects(bucket, prefix)``.
        year: First year of the growing season (2022 selects the 2022_2023 season).
        season: Season folder name, e.g. "summer" or "winter".
        dw_type: Product variant: "HighestConf", "Agreement" or "Mode".
        bucket: MinIO bucket name.

    Returns:
        The matching keys, in the order ``storage.list_objects`` yielded them.
    """
    season_tag = f"DW_Zim_{dw_type}_{year}_{year + 1}"
    keys = storage.list_objects(bucket, f"dw/zim/{season}/")

    def _wanted(key: str) -> bool:
        # Substring test first, then extension, mirroring the original check order.
        return season_tag in key and key.endswith(".tif")

    return list(filter(_wanted, keys))
|
|
|
|
|
|
def get_dw_tile_window(
    src_path: str,
    aoi_bbox_wgs84: List[float],
) -> Tuple[Optional[Window], Optional[dict], Optional[np.ndarray]]:
    """Read the part of a single DW tile that covers an AOI.

    The WGS84 bbox is reprojected into the tile's CRS, with edges densified so
    a non-axis-aligned projected footprint is fully enclosed. The result is
    converted to a pixel window, snapped outward to whole pixels, and clipped
    to the tile extent.

    Args:
        src_path: Local path or URL to the tile (e.g. a presigned MinIO URL).
        aoi_bbox_wgs84: AOI bounding box [min_lon, min_lat, max_lon, max_lat]
            in EPSG:4326.

    Returns:
        Tuple of (window, profile, data):
        - window: pixel window that was read from the tile
        - profile: rasterio profile describing ``data`` (int16, nodata 0,
          transform aligned to the window)
        - data: band 1 inside the window, cast to int16
        All three are None when the AOI does not overlap the tile.

    Raises:
        ImportError: If rasterio is not installed.
    """
    if not HAS_RASTERIO:
        raise ImportError("rasterio is required for DW baseline loading")

    min_lon, min_lat, max_lon, max_lat = aoi_bbox_wgs84

    with rasterio.open(src_path) as src:
        src_crs = src.crs

        # Reproject the whole bbox envelope rather than two corner points: in a
        # projected CRS the AOI is not axis-aligned, and corner-only transforms
        # can clip its edges.
        left, bottom, right, top = transform_bounds(
            "EPSG:4326", src_crs, min_lon, min_lat, max_lon, max_lat,
            densify_pts=21,
        )

        # Fractional pixel window of the bbox in this tile's grid. Sort the
        # edges so the result does not depend on the grid's axis orientation.
        aoi_win = from_bounds(left, bottom, right, top, transform=src.transform)
        col_lo, col_hi = sorted((aoi_win.col_off, aoi_win.col_off + aoi_win.width))
        row_lo, row_hi = sorted((aoi_win.row_off, aoi_win.row_off + aoi_win.height))

        # Snap outward (floor start / ceil stop) so partially covered edge
        # pixels are included, then clamp to the tile extent.
        col_min = max(0, int(np.floor(col_lo)))
        row_min = max(0, int(np.floor(row_lo)))
        col_max = min(src.width, int(np.ceil(col_hi)))
        row_max = min(src.height, int(np.ceil(row_hi)))

        # No overlap between AOI and this tile
        if col_max <= col_min or row_max <= row_min:
            return None, None, None

        window = Window(col_min, row_min, col_max - col_min, row_max - row_min)

        # Cast so the array actually matches the int16 dtype declared below.
        data = src.read(1, window=window).astype(np.int16, copy=False)

        profile = {
            "driver": "GTiff",
            "height": data.shape[0],
            "width": data.shape[1],
            "count": 1,
            "dtype": rasterio.int16,
            "nodata": 0,  # DW uses 0 as nodata
            "crs": src_crs,
            "transform": src.window_transform(window),
            "compress": "deflate",
        }

    return window, profile, data
|
|
|
|
|
|
def mosaic_windows(
    windows_data: List[Tuple[Window, np.ndarray, dict]],
    aoi_bbox_wgs84: List[float],
    target_crs: str,
) -> Tuple[np.ndarray, dict]:
    """Mosaic per-tile AOI windows into one array on a common pixel grid.

    No reprojection or resampling is done, so all windows must share one CRS
    and pixel size and be north-up (true for tiles cut from a single DW
    export). The output grid spans the union of the windows' extents and uses
    the first valid window's CRS and pixel size. It is filled with 0 (nodata)
    wherever no window has data. Where windows overlap, a later window's
    non-zero pixels overwrite earlier ones, but its nodata (0) pixels do not.

    Args:
        windows_data: List of (window, data, profile) tuples. Entries whose
            data is None or empty are skipped.
        aoi_bbox_wgs84: Original AOI bbox in WGS84 (currently unused).
        target_crs: Requested output CRS (currently unused; the output keeps
            the tiles' own CRS because nothing is reprojected).

    Returns:
        Tuple of (mosaic_array, profile). With exactly one window, its data
        and profile are returned unchanged. Otherwise the mosaic is int16 and
        the profile describes it.

    Raises:
        ValueError: If windows_data is empty, contains no valid data, or
            mixes CRSs.
    """
    if not windows_data:
        raise ValueError("No windows to mosaic")

    if len(windows_data) == 1:
        # Single tile - just return
        _, data, profile = windows_data[0]
        return data, profile

    tiles = [
        (data, profile)
        for _, data, profile in windows_data
        if data is not None and data.size > 0
    ]
    if not tiles:
        raise ValueError("No valid data in windows")

    ref_profile = tiles[0][1]
    out_crs = ref_profile["crs"]
    # Affine coefficients: t[0] = pixel width, t[4] = pixel height (negative
    # for north-up), t[2] / t[5] = x / y of the top-left corner.
    ref_t = ref_profile["transform"]
    xres = abs(ref_t[0])
    yres = abs(ref_t[4])

    # Extent of each window as (left, bottom, right, top) in tile CRS units.
    extents = []
    for data, profile in tiles:
        if profile["crs"] != out_crs:
            raise ValueError(
                f"Cannot mosaic DW windows in different CRSs "
                f"({profile['crs']} vs {out_crs}); reproject them first"
            )
        t = profile["transform"]
        h, w = data.shape
        extents.append((t[2], t[5] + h * t[4], t[2] + w * t[0], t[5]))

    min_x = min(e[0] for e in extents)
    min_y = min(e[1] for e in extents)
    max_x = max(e[2] for e in extents)
    max_y = max(e[3] for e in extents)

    # round() rather than int(): float division can land just below an
    # integer (e.g. 99.9999999) and truncation would drop a pixel.
    out_width = int(round((max_x - min_x) / xres))
    out_height = int(round((max_y - min_y) / yres))

    nodata = 0
    mosaic = np.full((out_height, out_width), nodata, dtype=np.int16)

    for data, profile in tiles:
        t = profile["transform"]
        # Rows count downward from the union's top edge (max_y).
        col_off = int(round((t[2] - min_x) / xres))
        row_off = int(round((max_y - t[5]) / yres))

        h, w = data.shape
        row_end = min(row_off + h, out_height)
        col_end = min(col_off + w, out_width)
        if row_end <= row_off or col_end <= col_off:
            continue

        patch = data[: row_end - row_off, : col_end - col_off]
        dest = mosaic[row_off:row_end, col_off:col_end]  # view into mosaic
        has_data = patch != nodata
        dest[has_data] = patch[has_data]

    from rasterio.transform import from_origin

    out_transform = from_origin(min_x, max_y, xres, yres)

    profile = {
        "driver": "GTiff",
        "height": out_height,
        "width": out_width,
        "count": 1,
        "dtype": rasterio.int16,
        "nodata": nodata,
        "crs": out_crs,
        "transform": out_transform,
        "compress": "deflate",
    }

    return mosaic, profile
|
|
|
|
|
|
def _read_tile_window_with_retries(
    storage,
    bucket: str,
    key: str,
    aoi_bbox_wgs84: List[float],
    max_retries: int,
):
    """Read one tile's AOI window, retrying failures with exponential backoff.

    A fresh presigned URL is requested on every attempt. At least one attempt
    is always made, even if max_retries <= 0.

    Returns:
        The (window, profile, data) tuple from ``get_dw_tile_window``. All
        three are None when the tile does not overlap the AOI.

    Raises:
        Exception: The error from the final attempt if every attempt fails.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            url = storage.presign_get(bucket, key, expires=3600)
            return get_dw_tile_window(url, aoi_bbox_wgs84)
        except Exception:
            if attempt == attempts - 1:
                raise
            time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, 4s, ...
    return None, None, None  # unreachable: the loop returns or raises


def load_dw_baseline_window(
    storage,
    year: int,
    aoi_bbox_wgs84: List[float],
    season: str = "summer",
    dw_type: str = "HighestConf",
    bucket: str = DW_BUCKET,
    max_retries: int = 3,
) -> Tuple[np.ndarray, dict]:
    """Load the DW baseline clipped to an AOI window from MinIO.

    Only the AOI window of each tile is read, through presigned URLs, so
    whole tiles are never downloaded. Windows from several tiles are
    mosaicked into one array.

    Args:
        storage: MinIOStorage instance with ``list_objects`` and ``presign_get``.
        year: Growing season year (e.g., 2022 for 2022_2023 season).
        aoi_bbox_wgs84: AOI bounding box [min_lon, min_lat, max_lon, max_lat]
            in WGS84.
        season: Season (summer/winter); maps to the ``dw/zim/<season>/`` prefix.
        dw_type: Type - "HighestConf", "Agreement", or "Mode".
        bucket: MinIO bucket name.
        max_retries: Attempts per tile before giving up on it (minimum 1).

    Returns:
        Tuple of:
        - dw_arr: DW class raster clipped to the AOI window (int16 when
          several tiles are mosaicked; a single tile's window is returned
          as read)
        - profile: rasterio profile for writing outputs aligned to this window

    Raises:
        ImportError: If rasterio is not installed.
        FileNotFoundError: If no matching DW tile is found.
        RuntimeError: If no tile could be read after retries, or if none of
            the matching tiles overlaps the AOI.

    Warns:
        RuntimeWarning: If some tiles failed while others succeeded. The
            returned mosaic may then have nodata gaps.
    """
    if not HAS_RASTERIO:
        raise ImportError("rasterio is required for DW baseline loading")

    # Step 1: List matching objects
    matching_keys = list_dw_objects(storage, year, season, dw_type, bucket)
    if not matching_keys:
        prefix = f"dw/zim/{season}/"
        raise FileNotFoundError(
            f"No DW baseline found for year={year}, type={dw_type}, "
            f"season={season}. Searched prefix: {prefix}"
        )

    # Step 2: Read the AOI window from every tile, remembering hard failures
    windows_data = []
    failed = {}
    for key in matching_keys:
        try:
            window, profile, data = _read_tile_window_with_retries(
                storage, bucket, key, aoi_bbox_wgs84, max_retries
            )
        except Exception as exc:  # keep going: other tiles may still cover the AOI
            failed[key] = exc
            continue
        if data is not None and data.size > 0:
            windows_data.append((window, data, profile))

    if not windows_data:
        if failed:
            last_error = list(failed.values())[-1]
            raise RuntimeError(
                f"Failed to read any DW tiles after {max_retries} retries. "
                f"Last error: {last_error}"
            ) from last_error
        raise RuntimeError(
            f"None of the {len(matching_keys)} DW tiles for year={year}, "
            f"type={dw_type}, season={season} overlap the AOI {aoi_bbox_wgs84}"
        )

    if failed:
        warnings.warn(
            f"Skipped {len(failed)} DW tile(s) that could not be read after "
            f"{max_retries} retries; the baseline may have nodata gaps: "
            f"{sorted(failed)}",
            RuntimeWarning,
            stacklevel=2,
        )

    # Step 3: Mosaic if needed. Nothing is reprojected, so the output stays in
    # the tiles' own CRS; pass that rather than an unrelated value.
    tile_crs = windows_data[0][2]["crs"]
    return mosaic_windows(windows_data, aoi_bbox_wgs84, tile_crs)
|
|
|
|
|
|
def get_dw_class_name(class_id: int) -> str:
    """Return the Dynamic World class name for a class ID.

    NOTE(review): the tile profiles in this module declare 0 as nodata, while
    ID 0 maps to "water" here. Confirm how the baseline rasters encode labels
    before using this on raw pixel values.

    Args:
        class_id: Index into DW_CLASS_NAMES; valid IDs are 0-8.

    Returns:
        The class name, or "unknown" for IDs outside the valid range.
    """
    if 0 <= class_id < len(DW_CLASS_NAMES):
        return DW_CLASS_NAMES[class_id]
    return "unknown"
|
|
|
|
|
|
def get_dw_class_color(class_id: int) -> str:
    """Return the display colour for a Dynamic World class ID.

    Args:
        class_id: Index into DW_CLASS_COLORS; valid IDs are 0-8.

    Returns:
        Hex colour code such as "#419BDF", or "#000000" for IDs outside the
        valid range.
    """
    if 0 <= class_id < len(DW_CLASS_COLORS):
        return DW_CLASS_COLORS[class_id]
    return "#000000"
|
|
|
|
|
|
# ==========================================
# Self-Test
# ==========================================

if __name__ == "__main__":
    print("=== DW Baseline Loader Test ===")

    if not HAS_RASTERIO:
        print("rasterio not installed - skipping full test")
        print("Import test: PASS (module loads)")
    else:
        # Test object listing (without real storage)
        print("\n1. Testing DW object pattern...")
        year = 2018
        season = "summer"
        dw_type = "HighestConf"

        # Simulate what list_dw_objects would return based on known files
        print(f" Year: {year}, Type: {dw_type}, Season: {season}")
        print(f" Expected pattern: DW_Zim_{dw_type}_{year}_{year+1}-*.tif")
        print(f" This would search prefix: dw/zim/{season}/")

        # Check if we can reach real storage
        try:
            from storage import MinIOStorage

            print("\n2. Testing MinIOStorage...")

            # Will fail without a reachable MinIO instance
            storage = MinIOStorage()

            # Reuse the production lookup so the self-test exercises the same
            # prefix and filter logic as load_dw_baseline_window.
            matching = list_dw_objects(storage, year, season, dw_type)

            print(f" Found {len(matching)} matching objects")
            for obj in matching[:5]:
                print(f" {obj}")

        except Exception as e:
            print(f" MinIO not available: {e}")
            print(" (This is expected outside Kubernetes)")

    print("\n=== DW Baseline Test Complete ===")
|