# geocrop-platform/apps/worker/contracts.py

"""Worker contracts: Job payload, output schema, and validation.
This module defines the data contracts for the inference worker pipeline.
It is designed to be tolerant of missing fields with sensible defaults.
STEP 1: Contracts module for job payloads and results.
"""
from __future__ import annotations
import sys
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
# Names of the worker pipeline stages, in the order listed here.
STAGES = [
    "fetch_stac",
    "build_features",
    "load_dw",
    "infer",
    "smooth",
    "export_cog",
    "upload",
    "done",
]

# Model names accepted by validate_model (matched case-sensitively).
VALID_MODELS = ["Ensemble", "RandomForest", "XGBoost", "LightGBM", "CatBoost"]

# Smoothing kernel sizes accepted by validate_kernel.
VALID_KERNEL_SIZES = [3, 5, 7]

# Valid year range, bounded below by Dynamic World availability.
MIN_YEAR = 2015
# NOTE: evaluated once, at import time. validate_year recomputes the current
# year on every call instead of reading this constant.
MAX_YEAR = datetime.now().year

# Default class names (TEMPORARY V1 - until fully dynamic).
# These match the trained model's CLASSES_V1 from training, so the strings
# must stay exactly as trained (do not "correct" spellings such as "Sunhem").
CLASSES_V1 = [
    "Avocado",
    "Banana",
    "Bare Surface",
    "Blueberry",
    "Built-Up",
    "Cabbage",
    "Chilli",
    "Citrus",
    "Cotton",
    "Cowpea",
    "Finger Millet",
    "Forest",
    "Grassland",
    "Groundnut",
    "Macadamia",
    "Maize",
    "Pasture Legume",
    "Pearl Millet",
    "Peas",
    "Potato",
    "Roundnut",
    "Sesame",
    "Shrubland",
    "Sorghum",
    "Soyabean",
    "Sugarbean",
    "Sugarcane",
    "Sunflower",
    "Sunhem",
    "Sweet Potato",
    "Tea",
    "Tobacco",
    "Tomato",
    "Water",
    "Woodland",
]
# Alias of the same list object (not a copy); resolve_class_names hands out
# copies so callers cannot mutate it.
DEFAULT_CLASS_NAMES = CLASSES_V1
# ==========================================
# Job Payload
# ==========================================
@dataclass
class AOI:
    """Area of interest: a centre point plus a radius.

    Attributes:
        lon: Longitude of the centre point.
        lat: Latitude of the centre point.
        radius_m: Radius around the centre, in meters.
    """

    lon: float
    lat: float
    radius_m: int

    def to_tuple(self) -> tuple[float, float, int]:
        """Return ``(lon, lat, radius_m)``, the tuple form features.py consumes."""
        centre = (self.lon, self.lat)
        return (*centre, self.radius_m)
@dataclass
class OutputOptions:
    """Output options for the inference job.

    Each boolean toggles one output product; ``indices`` names the index
    products to export (e.g. ``"ndvi_peak"``).
    """

    refined: bool = True
    dw_baseline: bool = True
    true_color: bool = True
    indices: List[str] = field(default_factory=lambda: ["ndvi_peak", "evi_peak", "savi_peak"])


@dataclass
class STACOptions:
    """STAC query options (optional overrides)."""

    cloud_cover_lt: int = 20
    max_items: int = 60


@dataclass
class JobPayload:
    """Job payload from API/queue.

    Tolerant of missing fields: anything absent (or explicitly ``None``) in
    the incoming dict falls back to the defaults declared below.
    """

    job_id: str
    user_id: Optional[str] = None
    lat: float = 0.0
    lon: float = 0.0
    radius_m: int = 2000
    year: int = 2022
    season: str = "summer"
    model: str = "Ensemble"
    smoothing_kernel: int = 5
    outputs: OutputOptions = field(default_factory=OutputOptions)
    stac: Optional[STACOptions] = None

    @classmethod
    def from_dict(cls, data: dict) -> JobPayload:
        """Create a JobPayload from a raw dict, filling defaults for missing fields.

        AOI fields (``lat``, ``lon``, ``radius_m``) may be given at the top
        level or nested under ``"aoi"``; nested values take precedence. A
        non-dict ``"aoi"``/``"outputs"``/``"stac"`` value is ignored. A key
        whose value is ``None`` is treated like a missing key, so JSON nulls
        never reach validation as ``None``.

        Args:
            data: Decoded job message.

        Returns:
            A new JobPayload. This method performs no validation; see
            validate_payload.
        """

        def pick(source: dict, key: str, default: Any) -> Any:
            # dict.get(key, default) only covers absent keys; map an explicit
            # None to the default as well.
            value = source.get(key)
            return default if value is None else value

        # A null or otherwise non-dict "aoi" used to crash on .get(); fall back
        # to the top-level fields instead.
        aoi_data = data.get("aoi")
        if not isinstance(aoi_data, dict):
            aoi_data = {}

        # Parse outputs; OutputOptions() supplies the defaults.
        outputs = OutputOptions()
        outputs_data = data.get("outputs")
        if isinstance(outputs_data, dict):
            indices = outputs_data.get("indices")
            if indices is None:
                indices = outputs.indices
            elif isinstance(indices, str):
                # A single index name, not an iterable of characters.
                indices = [indices]
            elif isinstance(indices, (list, tuple)):
                # Copy so the payload never aliases the caller's list.
                indices = list(indices)
            # Anything else is passed through unchanged for validation later.
            outputs = OutputOptions(
                refined=pick(outputs_data, "refined", outputs.refined),
                dw_baseline=pick(outputs_data, "dw_baseline", outputs.dw_baseline),
                true_color=pick(outputs_data, "true_color", outputs.true_color),
                indices=indices,
            )

        # Parse STAC options (only when an override dict is supplied).
        stac = None
        stac_data = data.get("stac")
        if isinstance(stac_data, dict):
            stac_defaults = STACOptions()
            stac = STACOptions(
                cloud_cover_lt=pick(stac_data, "cloud_cover_lt", stac_defaults.cloud_cover_lt),
                max_items=pick(stac_data, "max_items", stac_defaults.max_items),
            )

        # Field defaults are read back from the class (dataclasses keep plain
        # defaults as class attributes), so each default is declared only once.
        return cls(
            job_id=pick(data, "job_id", ""),
            user_id=data.get("user_id"),
            lat=pick(aoi_data, "lat", pick(data, "lat", cls.lat)),
            lon=pick(aoi_data, "lon", pick(data, "lon", cls.lon)),
            radius_m=pick(aoi_data, "radius_m", pick(data, "radius_m", cls.radius_m)),
            year=pick(data, "year", cls.year),
            season=pick(data, "season", cls.season),
            model=pick(data, "model", cls.model),
            smoothing_kernel=pick(data, "smoothing_kernel", cls.smoothing_kernel),
            outputs=outputs,
            stac=stac,
        )

    def get_aoi(self) -> AOI:
        """Return the payload's centre point and radius as an AOI."""
        return AOI(lon=self.lon, lat=self.lat, radius_m=self.radius_m)
# ==========================================
# Worker Result / Output Schema
# ==========================================
@dataclass
class Artifact:
    """One file produced by the worker, with two ways to address it.

    Attributes:
        s3_uri: Object-store URI of the file.
        url: URL for the same file (presumably an HTTP link for clients —
            confirm against the upload stage).
    """

    s3_uri: str
    url: str
@dataclass
class WorkerResult:
    """Result from worker pipeline.

    Attributes:
        status: ``"success"`` or ``"error"``.
        job_id: ID of the job this result belongs to.
        stage: Stage name; ``"done"`` by default for success results, the
            caller-supplied stage for error results.
        message: Error message; empty for success results.
        artifacts: Produced files, keyed by name.
        metadata: Free-form extra information.
    """

    status: str  # "success" or "error"
    job_id: str
    stage: str
    message: str = ""
    artifacts: Dict[str, Artifact] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def success(
        cls,
        job_id: str,
        stage: str = "done",
        artifacts: Optional[Dict[str, Artifact]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> WorkerResult:
        """Create a success result.

        ``None`` (or an empty mapping) for ``artifacts``/``metadata`` becomes
        a fresh empty dict, so results never share a default container.
        """
        return cls(
            status="success",
            job_id=job_id,
            stage=stage,
            message="",
            artifacts=artifacts or {},
            metadata=metadata or {},
        )

    @classmethod
    def error(cls, job_id: str, stage: str, message: str) -> WorkerResult:
        """Create an error result for the stage that failed, with no artifacts."""
        return cls(
            status="error",
            job_id=job_id,
            stage=stage,
            message=message,
            artifacts={},
            metadata={},
        )
# ==========================================
# Validation Helpers
# ==========================================
def validate_radius(radius_m: int, max_radius_m: int = 5000) -> int:
    """Validate radius is within ``(0, max_radius_m]``.

    Args:
        radius_m: Radius in meters.
        max_radius_m: Inclusive upper bound in meters (default 5000).

    Returns:
        ``radius_m`` unchanged.

    Raises:
        ValueError: If ``radius_m`` is <= 0, > ``max_radius_m``, or NaN.
    """
    # Written as a positive range test so NaN fails it (json.loads accepts
    # "NaN"); the old `r <= 0 or r > max` form is False for NaN and let it pass.
    if not 0 < radius_m <= max_radius_m:
        raise ValueError(f"radius_m must be in (0, {max_radius_m}], got {radius_m}")
    return radius_m
def validate_kernel(kernel: int) -> int:
    """Check that ``kernel`` is an allowed smoothing-kernel size.

    The allowed sizes are those in ``VALID_KERNEL_SIZES`` (3, 5 or 7).

    Args:
        kernel: Kernel size.

    Returns:
        ``kernel`` unchanged.

    Raises:
        ValueError: If ``kernel`` is not in ``VALID_KERNEL_SIZES``.
    """
    if kernel in VALID_KERNEL_SIZES:
        return kernel
    raise ValueError(f"kernel must be one of {VALID_KERNEL_SIZES}, got {kernel}")
def validate_year(year: int) -> int:
    """Validate year is in ``[MIN_YEAR, current calendar year]``.

    The upper bound is recomputed on every call (not read from the
    import-time MAX_YEAR), so a long-running worker picks up a new year.

    Args:
        year: Year.

    Returns:
        ``year`` unchanged.

    Raises:
        ValueError: If ``year`` is outside the range or is NaN.
    """
    current_year = datetime.now().year
    # Positive range test so NaN is rejected too; the previous
    # `year < MIN or year > current` form evaluates False for NaN.
    if not MIN_YEAR <= year <= current_year:
        raise ValueError(f"year must be in [{MIN_YEAR}, {current_year}], got {year}")
    return year
def validate_model(model: str) -> str:
    """Validate and normalize a model name.

    Surrounding whitespace is stripped; the result is matched
    case-sensitively against ``VALID_MODELS``. No suffix is added.

    Args:
        model: Model name.

    Returns:
        The stripped model name.

    Raises:
        ValueError: If ``model`` is not a string or not in ``VALID_MODELS``.
    """
    # Payloads are parsed leniently, so a null/number can arrive here; report
    # it as a validation error instead of an AttributeError from .strip().
    if not isinstance(model, str):
        raise ValueError(f"model must be a string, got {type(model).__name__}")
    # Normalize: strip whitespace, preserve case
    model = model.strip()
    if model not in VALID_MODELS:
        raise ValueError(f"model must be one of {VALID_MODELS}, got {model}")
    return model
def validate_aoi_zimbabwe_quick(aoi: AOI) -> AOI:
    """Cheap bounding-box pre-check that the AOI centre lies in Zimbabwe.

    Only the centre point is tested against rough bounds (lon 25.2..33.1,
    lat -22.5..-15.6, inclusive); the radius is not considered. A strict
    polygon check is still TODO.

    Args:
        aoi: AOI to validate.

    Returns:
        The same ``aoi`` object.

    Raises:
        ValueError: If the centre falls outside the rough Zimbabwe bbox.
    """
    lon_min, lon_max = 25.2, 33.1
    lat_min, lat_max = -22.5, -15.6
    if lon_min <= aoi.lon <= lon_max and lat_min <= aoi.lat <= lat_max:
        return aoi
    raise ValueError(f"AOI ({aoi.lon}, {aoi.lat}) outside Zimbabwe bounds")
def validate_payload(payload: JobPayload) -> JobPayload:
    """Validate all payload fields, applying any normalization in place.

    Runs the radius, kernel, year, model and quick AOI checks in that order.
    The model name returned by validate_model (whitespace-stripped) is written
    back to ``payload.model``.

    Args:
        payload: Job payload to validate (mutated in place for normalization).

    Returns:
        The same payload object.

    Raises:
        ValueError: If any validation fails.
    """
    validate_radius(payload.radius_m)
    validate_kernel(payload.smoothing_kernel)
    validate_year(payload.year)
    # validate_model normalizes the name; previously the result was discarded,
    # so " XGBoost " passed validation but reached model loading unstripped.
    payload.model = validate_model(payload.model)
    # Quick AOI check (bbox only for now)
    validate_aoi_zimbabwe_quick(payload.get_aoi())
    return payload
# ==========================================
# Class Resolution Helper
# ==========================================
def resolve_class_names(model_obj: Any) -> List[str]:
    """Return the class names exposed by a trained model, or the V1 defaults.

    Attributes are probed in order: ``classes_`` (set on fitted sklearn
    estimators), then ``class_names``, ``labels`` and ``classes``. The first
    one that exists and is not ``None`` is converted to a list, using its
    ``tolist()`` method when it has one (e.g. numpy arrays).

    TEMPORARY V1: falls back to a copy of DEFAULT_CLASS_NAMES when none of
    these attributes is usable. Later we will make this fully dynamic.

    Args:
        model_obj: Trained model object (sklearn-compatible).

    Returns:
        List of class names.
    """
    for attr_name in ("classes_", "class_names", "labels", "classes"):
        if not hasattr(model_obj, attr_name):
            continue
        found = getattr(model_obj, attr_name)
        if found is None:
            continue
        return found.tolist() if hasattr(found, "tolist") else list(found)
    # Fallback to default (TEMPORARY); copied so callers cannot mutate it.
    return DEFAULT_CLASS_NAMES.copy()
# ==========================================
# Test / Sanity Check
# ==========================================
if __name__ == "__main__":
    # Quick sanity test of the contracts; exits non-zero on the first failure.
    print("Running contracts sanity test...")

    def _expect_value_error(check_name, ok_note, func, *args):
        """Print an OK line if func(*args) raises ValueError, otherwise exit(1)."""
        try:
            func(*args)
        except ValueError:
            print(f" {check_name}: OK ({ok_note})")
            return
        print(f" ERROR: {check_name} should have raised")
        sys.exit(1)

    # Minimal payload: unspecified fields must fall back to their defaults.
    minimal = {
        "job_id": "test-123",
        "lat": -17.8,
        "lon": 31.0,
        "radius_m": 2000,
        "year": 2022,
    }
    payload = JobPayload.from_dict(minimal)
    print(f" Minimal payload: job_id={payload.job_id}, model={payload.model}, season={payload.season}")
    assert payload.model == "Ensemble"
    assert payload.season == "summer"
    assert payload.outputs.refined is True
    # The minimal payload is a valid job (centre inside the Zimbabwe bbox).
    validate_payload(payload)
    print(" validate_payload: OK (accepted minimal payload)")

    # Full payload with a nested AOI and explicit outputs.
    full = {
        "job_id": "test-456",
        "user_id": "user-789",
        "aoi": {"lon": 31.0, "lat": -17.8, "radius_m": 3000},
        "year": 2023,
        "season": "summer",
        "model": "XGBoost",
        "smoothing_kernel": 7,
        "outputs": {
            "refined": True,
            "dw_baseline": False,
            "true_color": True,
            "indices": ["ndvi_peak"],
        },
    }
    payload2 = JobPayload.from_dict(full)
    print(f" Full payload: model={payload2.model}, kernel={payload2.smoothing_kernel}")
    assert payload2.model == "XGBoost"
    assert payload2.smoothing_kernel == 7
    assert payload2.outputs.indices == ["ndvi_peak"]

    # Validation helpers must reject bad input with ValueError.
    _expect_value_error("validate_radius", "rejected >5000", validate_radius, 10000)
    _expect_value_error("validate_kernel", "rejected even", validate_kernel, 4)
    _expect_value_error("validate_year", "rejected pre-2015", validate_year, MIN_YEAR - 1)
    _expect_value_error("validate_model", "rejected unknown name", validate_model, "NotAModel")
    _expect_value_error(
        "validate_aoi_zimbabwe_quick",
        "rejected point outside bbox",
        validate_aoi_zimbabwe_quick,
        AOI(lon=0.0, lat=0.0, radius_m=1000),
    )

    # Class resolution: fallback to defaults, then the model's own classes_.
    class MockModel:
        pass

    model = MockModel()
    classes = resolve_class_names(model)
    print(f" resolve_class_names (no attr): {len(classes)} classes")
    assert classes == DEFAULT_CLASS_NAMES
    model.classes_ = ["Apple", "Banana", "Cherry"]
    classes2 = resolve_class_names(model)
    print(f" resolve_class_names (with attr): {classes2}")
    assert classes2 == ["Apple", "Banana", "Cherry"]

    print("\n✅ All contracts tests passed!")