"""Worker contracts: Job payload, output schema, and validation. This module defines the data contracts for the inference worker pipeline. It is designed to be tolerant of missing fields with sensible defaults. STEP 1: Contracts module for job payloads and results. """ from __future__ import annotations import sys from dataclasses import dataclass, field from datetime import datetime from typing import Any, Dict, List, Optional # Pipeline stage names STAGES = [ "fetch_stac", "build_features", "load_dw", "infer", "smooth", "export_cog", "upload", "done", ] # Acceptable model names VALID_MODELS = ["Ensemble", "RandomForest", "XGBoost", "LightGBM", "CatBoost"] # Valid smoothing kernel sizes VALID_KERNEL_SIZES = [3, 5, 7] # Valid year range (Dynamic World availability) MIN_YEAR = 2015 MAX_YEAR = datetime.now().year # Default class names (TEMPORARY V1 - until fully dynamic) # These match the trained model's CLASSES_V1 from training CLASSES_V1 = [ "Avocado", "Banana", "Bare Surface", "Blueberry", "Built-Up", "Cabbage", "Chilli", "Citrus", "Cotton", "Cowpea", "Finger Millet", "Forest", "Grassland", "Groundnut", "Macadamia", "Maize", "Pasture Legume", "Pearl Millet", "Peas", "Potato", "Roundnut", "Sesame", "Shrubland", "Sorghum", "Soyabean", "Sugarbean", "Sugarcane", "Sunflower", "Sunhem", "Sweet Potato", "Tea", "Tobacco", "Tomato", "Water", "Woodland" ] DEFAULT_CLASS_NAMES = CLASSES_V1 # ========================================== # Job Payload # ========================================== @dataclass class AOI: """Area of Interest specification.""" lon: float lat: float radius_m: int def to_tuple(self) -> tuple[float, float, int]: """Convert to (lon, lat, radius_m) tuple for features.py.""" return (self.lon, self.lat, self.radius_m) @dataclass class OutputOptions: """Output options for the inference job.""" refined: bool = True dw_baseline: bool = True true_color: bool = True indices: List[str] = field(default_factory=lambda: ["ndvi_peak", "evi_peak", "savi_peak"]) 
@dataclass class STACOptions: """STAC query options (optional overrides).""" cloud_cover_lt: int = 20 max_items: int = 60 @dataclass class JobPayload: """Job payload from API/queue. This dataclass is tolerant of missing fields and fills defaults. """ job_id: str user_id: Optional[str] = None lat: float = 0.0 lon: float = 0.0 radius_m: int = 2000 year: int = 2022 season: str = "summer" model: str = "Ensemble" smoothing_kernel: int = 5 outputs: OutputOptions = field(default_factory=OutputOptions) stac: Optional[STACOptions] = None @classmethod def from_dict(cls, data: dict) -> JobPayload: """Create JobPayload from dictionary, filling defaults for missing fields.""" # Extract AOI fields if "aoi" in data: aoi_data = data["aoi"] lat = aoi_data.get("lat", data.get("lat", 0.0)) lon = aoi_data.get("lon", data.get("lon", 0.0)) radius_m = aoi_data.get("radius_m", data.get("radius_m", 2000)) else: lat = data.get("lat", 0.0) lon = data.get("lon", 0.0) radius_m = data.get("radius_m", 2000) # Parse outputs outputs_data = data.get("outputs", {}) if isinstance(outputs_data, dict): outputs = OutputOptions( refined=outputs_data.get("refined", True), dw_baseline=outputs_data.get("dw_baseline", True), true_color=outputs_data.get("true_color", True), indices=outputs_data.get("indices", ["ndvi_peak", "evi_peak", "savi_peak"]), ) else: outputs = OutputOptions() # Parse STAC options stac_data = data.get("stac") if isinstance(stac_data, dict): stac = STACOptions( cloud_cover_lt=stac_data.get("cloud_cover_lt", 20), max_items=stac_data.get("max_items", 60), ) else: stac = None return cls( job_id=data.get("job_id", ""), user_id=data.get("user_id"), lat=lat, lon=lon, radius_m=radius_m, year=data.get("year", 2022), season=data.get("season", "summer"), model=data.get("model", "Ensemble"), smoothing_kernel=data.get("smoothing_kernel", 5), outputs=outputs, stac=stac, ) def get_aoi(self) -> AOI: """Get AOI object.""" return AOI(lon=self.lon, lat=self.lat, radius_m=self.radius_m) # 
# ==========================================
# Worker Result / Output Schema
# ==========================================

@dataclass
class Artifact:
    """Single artifact (file) result."""

    s3_uri: str
    url: str


@dataclass
class WorkerResult:
    """Result from worker pipeline."""

    status: str  # "success" or "error"
    job_id: str
    stage: str  # presumably one of STAGES; not enforced here
    message: str = ""
    artifacts: Dict[str, Artifact] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def success(
        cls,
        job_id: str,
        stage: str = "done",
        artifacts: Optional[Dict[str, Artifact]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> WorkerResult:
        """Create a success result.

        Args:
            job_id: Identifier of the job that finished.
            stage: Stage reached; defaults to ``"done"``.
            artifacts: Produced artifacts by name; ``None`` or empty yields ``{}``.
            metadata: Extra metadata; ``None`` or empty yields ``{}``.

        Returns:
            A WorkerResult with ``status="success"`` and an empty message.
        """
        return cls(
            status="success",
            job_id=job_id,
            stage=stage,
            message="",
            artifacts=artifacts or {},
            metadata=metadata or {},
        )

    @classmethod
    def error(cls, job_id: str, stage: str, message: str) -> WorkerResult:
        """Create an error result.

        Args:
            job_id: Identifier of the job that failed.
            stage: Stage at which the failure happened.
            message: Human-readable failure description.

        Returns:
            A WorkerResult with ``status="error"`` and no artifacts/metadata.
        """
        return cls(
            status="error",
            job_id=job_id,
            stage=stage,
            message=message,
            artifacts={},
            metadata={},
        )


# ==========================================
# Validation Helpers
# ==========================================

def _within(value: Any, low: float, high: float, *, low_inclusive: bool = True) -> bool:
    """Return True when *value* lies between *low* and *high* (high inclusive).

    Written as a positive chained comparison on purpose: every ordering
    comparison against NaN is False, so NaN is reported as out of range
    instead of slipping past a ``value < low or value > high`` test. Values
    that cannot be ordered against numbers (``None``, ``str``, ...) are also
    reported as out of range rather than raising ``TypeError``.
    """
    try:
        if low_inclusive:
            return bool(low <= value <= high)
        return bool(low < value <= high)
    except TypeError:
        return False


def validate_radius(radius_m: int) -> int:
    """Validate radius is within bounds.

    Args:
        radius_m: Radius in meters

    Returns:
        Validated radius

    Raises:
        ValueError: If radius is not a number in (0, 5000] (this includes
            NaN, ``None`` and strings)
    """
    if not _within(radius_m, 0, 5000, low_inclusive=False):
        raise ValueError(f"radius_m must be in (0, 5000], got {radius_m}")
    return radius_m


def validate_kernel(kernel: int) -> int:
    """Validate smoothing kernel is one of VALID_KERNEL_SIZES.

    Args:
        kernel: Kernel size

    Returns:
        Validated kernel

    Raises:
        ValueError: If kernel not in VALID_KERNEL_SIZES
    """
    if kernel not in VALID_KERNEL_SIZES:
        raise ValueError(f"kernel must be one of {VALID_KERNEL_SIZES}, got {kernel}")
    return kernel


def validate_year(year: int) -> int:
    """Validate year is in valid range.

    The upper bound is the current year, read from the clock on every call.

    Args:
        year: Year

    Returns:
        Validated year

    Raises:
        ValueError: If year is not a number in MIN_YEAR..current year
    """
    current_year = datetime.now().year
    if not _within(year, MIN_YEAR, current_year):
        raise ValueError(f"year must be in [{MIN_YEAR}, {current_year}], got {year}")
    return year


def validate_model(model: str) -> str:
    """Validate model name.

    Args:
        model: Model name

    Returns:
        Validated model name with surrounding whitespace stripped

    Raises:
        ValueError: If model (after stripping) is not in VALID_MODELS, or is
            not a string at all
    """
    # Normalize: strip whitespace, preserve case. Non-strings are left as-is
    # so they fail the membership test below with the documented ValueError.
    normalized = model.strip() if isinstance(model, str) else model

    # Check if valid (case-sensitive from VALID_MODELS)
    if normalized not in VALID_MODELS:
        raise ValueError(f"model must be one of {VALID_MODELS}, got {normalized}")
    return normalized


def validate_aoi_zimbabwe_quick(aoi: AOI) -> AOI:
    """Quick bbox check for AOI in Zimbabwe.

    This is a quick pre-check using rough bounds.
    For strict validation, use polygon check (TODO).

    Args:
        aoi: AOI to validate

    Returns:
        Validated AOI

    Raises:
        ValueError: If AOI outside rough Zimbabwe bbox
    """
    # Rough bbox for Zimbabwe (cheap pre-check)
    # Lon: 25.2 to 33.1, Lat: -22.5 to -15.6
    if not (_within(aoi.lon, 25.2, 33.1) and _within(aoi.lat, -22.5, -15.6)):
        raise ValueError(f"AOI ({aoi.lon}, {aoi.lat}) outside Zimbabwe bounds")
    return aoi


def validate_payload(payload: JobPayload) -> JobPayload:
    """Validate all payload fields.

    Args:
        payload: Job payload to validate

    Returns:
        The same payload object; its ``model`` is replaced by the
        whitespace-stripped name that passed validation

    Raises:
        ValueError: If any validation fails
    """
    validate_radius(payload.radius_m)
    validate_kernel(payload.smoothing_kernel)
    validate_year(payload.year)

    # validate_model() strips whitespace; keep the normalized name so that
    # later stages see exactly the value that was validated.
    payload.model = validate_model(payload.model)

    # Quick AOI check (bbox only for now)
    validate_aoi_zimbabwe_quick(payload.get_aoi())

    return payload


# ==========================================
# Class Resolution Helper
# ==========================================

def resolve_class_names(model_obj: Any) -> List[str]:
    """Resolve class names from model object.

    TEMPORARY V1: Uses DEFAULT_CLASS_NAMES if model doesn't expose classes.
    Later we will make this fully dynamic.

    Args:
        model_obj: Trained model object (sklearn-compatible)

    Returns:
        List of class names: the first non-None attribute among ``classes_``,
        ``class_names``, ``labels`` and ``classes`` converted to a list,
        otherwise a copy of DEFAULT_CLASS_NAMES
    """
    # Probe in priority order: the sklearn-style attribute first, then
    # common alternative names.
    for attr in ("classes_", "class_names", "labels", "classes"):
        found = getattr(model_obj, attr, None)
        if found is None:
            continue
        # Handle both numpy arrays and lists
        if hasattr(found, "tolist"):
            return found.tolist()
        return list(found)

    # Fallback to default (TEMPORARY); a copy so callers cannot mutate the
    # module-level list.
    return list(DEFAULT_CLASS_NAMES)


# ==========================================
# Test / Sanity Check
# ==========================================

def _run_sanity_checks() -> None:
    """Run the quick built-in sanity checks; exits with status 1 on failure."""
    print("Running contracts sanity test...")

    # Test minimal payload
    minimal = {
        "job_id": "test-123",
        "lat": -17.8,
        "lon": 31.0,
        "radius_m": 2000,
        "year": 2022,
    }
    payload = JobPayload.from_dict(minimal)
    print(f" Minimal payload: job_id={payload.job_id}, model={payload.model}, season={payload.season}")
    assert payload.model == "Ensemble"
    assert payload.season == "summer"
    assert payload.outputs.refined is True

    # Test full payload
    full = {
        "job_id": "test-456",
        "user_id": "user-789",
        "aoi": {"lon": 31.0, "lat": -17.8, "radius_m": 3000},
        "year": 2023,
        "season": "summer",
        "model": "XGBoost",
        "smoothing_kernel": 7,
        "outputs": {
            "refined": True,
            "dw_baseline": False,
            "true_color": True,
            "indices": ["ndvi_peak"],
        },
    }
    payload2 = JobPayload.from_dict(full)
    print(f" Full payload: model={payload2.model}, kernel={payload2.smoothing_kernel}")
    assert payload2.model == "XGBoost"
    assert payload2.smoothing_kernel == 7
    assert payload2.outputs.indices == ["ndvi_peak"]

    # Test validation
    try:
        validate_radius(10000)
        print(" ERROR: validate_radius should have raised")
        sys.exit(1)
    except ValueError:
        print(" validate_radius: OK (rejected >5000)")

    try:
        validate_kernel(4)
        print(" ERROR: validate_kernel should have raised")
        sys.exit(1)
    except ValueError:
        print(" validate_kernel: OK (rejected even)")

    # Test class resolution
    class MockModel:
        pass

    model = MockModel()
    classes = resolve_class_names(model)
    print(f" resolve_class_names (no attr): {len(classes)} classes")
    assert classes == DEFAULT_CLASS_NAMES

    model.classes_ = ["Apple", "Banana", "Cherry"]
    classes2 = resolve_class_names(model)
    print(f" resolve_class_names (with attr): {classes2}")
    assert classes2 == ["Apple", "Banana", "Cherry"]

    print("\n✅ All contracts tests passed!")


if __name__ == "__main__":
    _run_sanity_checks()