442 lines
12 KiB
Python
442 lines
12 KiB
Python
"""Worker contracts: Job payload, output schema, and validation.
|
|
|
|
This module defines the data contracts for the inference worker pipeline.
|
|
It is designed to be tolerant of missing fields with sensible defaults.
|
|
|
|
STEP 1: Contracts module for job payloads and results.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
# Pipeline stage names
STAGES = [
    "fetch_stac",
    "build_features",
    "load_dw",
    "infer",
    "smooth",
    "export_cog",
    "upload",
    "done",
]

# Model names accepted in a job payload
VALID_MODELS = [
    "Ensemble",
    "RandomForest",
    "XGBoost",
    "LightGBM",
    "CatBoost",
]

# Smoothing kernel sizes accepted in a job payload
VALID_KERNEL_SIZES = [3, 5, 7]

# Valid year range (Dynamic World availability)
MIN_YEAR = 2015
MAX_YEAR = datetime.now().year  # evaluated once, when the module is imported

# Default class names (TEMPORARY V1 - until fully dynamic)
# These match the trained model's CLASSES_V1 from training
CLASSES_V1 = [
    "Avocado",
    "Banana",
    "Bare Surface",
    "Blueberry",
    "Built-Up",
    "Cabbage",
    "Chilli",
    "Citrus",
    "Cotton",
    "Cowpea",
    "Finger Millet",
    "Forest",
    "Grassland",
    "Groundnut",
    "Macadamia",
    "Maize",
    "Pasture Legume",
    "Pearl Millet",
    "Peas",
    "Potato",
    "Roundnut",
    "Sesame",
    "Shrubland",
    "Sorghum",
    "Soyabean",
    "Sugarbean",
    "Sugarcane",
    "Sunflower",
    "Sunhem",
    "Sweet Potato",
    "Tea",
    "Tobacco",
    "Tomato",
    "Water",
    "Woodland",
]

# Alias of CLASSES_V1: the same list object, not a copy.
DEFAULT_CLASS_NAMES = CLASSES_V1
|
|
|
|
|
|
# ==========================================
|
|
# Job Payload
|
|
# ==========================================
|
|
|
|
@dataclass
class AOI:
    """Area of interest described by ``lon``, ``lat`` and ``radius_m``."""

    lon: float
    lat: float
    radius_m: int

    def to_tuple(self) -> tuple[float, float, int]:
        """Return the fields as a ``(lon, lat, radius_m)`` tuple for features.py."""
        return self.lon, self.lat, self.radius_m
|
|
|
|
|
|
@dataclass
class OutputOptions:
    """Output options for an inference job; every field has a default."""

    refined: bool = True
    dw_baseline: bool = True
    true_color: bool = True
    # Built by a factory so each instance gets its own list object.
    indices: List[str] = field(
        default_factory=lambda: ["ndvi_peak", "evi_peak", "savi_peak"]
    )
|
|
|
|
|
|
@dataclass
class STACOptions:
    """Optional overrides for the STAC query."""

    cloud_cover_lt: int = 20
    max_items: int = 60
|
|
|
|
|
|
@dataclass
class JobPayload:
    """Job payload from API/queue.

    This dataclass is tolerant of missing fields and fills defaults.
    It performs no value validation and no type coercion; values taken
    from the input dictionary are stored as given.
    """

    job_id: str
    user_id: Optional[str] = None
    lat: float = 0.0
    lon: float = 0.0
    radius_m: int = 2000
    year: int = 2022
    season: str = "summer"
    model: str = "Ensemble"
    smoothing_kernel: int = 5
    outputs: OutputOptions = field(default_factory=OutputOptions)
    stac: Optional[STACOptions] = None

    @classmethod
    def from_dict(cls, data: dict) -> JobPayload:
        """Create JobPayload from dictionary, filling defaults for missing fields.

        ``lat``, ``lon`` and ``radius_m`` may be supplied either inside a
        nested ``"aoi"`` dictionary or as top-level keys; a value found in
        ``"aoi"`` takes precedence over the top-level one.

        Args:
            data: Raw payload dictionary

        Returns:
            A JobPayload; ``job_id`` is ``""`` when the key is absent
        """
        # Extract AOI fields. A missing, null or non-dict "aoi" is treated
        # as absent (same tolerance as "outputs" and "stac" below) rather
        # than failing with AttributeError on `.get`.
        aoi_data = data.get("aoi")
        if not isinstance(aoi_data, dict):
            aoi_data = {}
        lat = aoi_data.get("lat", data.get("lat", 0.0))
        lon = aoi_data.get("lon", data.get("lon", 0.0))
        radius_m = aoi_data.get("radius_m", data.get("radius_m", 2000))

        # Parse outputs; anything that is not a dict yields the defaults
        outputs_data = data.get("outputs", {})
        if isinstance(outputs_data, dict):
            outputs = OutputOptions(
                refined=outputs_data.get("refined", True),
                dw_baseline=outputs_data.get("dw_baseline", True),
                true_color=outputs_data.get("true_color", True),
                indices=outputs_data.get(
                    "indices", ["ndvi_peak", "evi_peak", "savi_peak"]
                ),
            )
        else:
            outputs = OutputOptions()

        # Parse STAC options; they stay None unless a dict is supplied
        stac_data = data.get("stac")
        if isinstance(stac_data, dict):
            stac = STACOptions(
                cloud_cover_lt=stac_data.get("cloud_cover_lt", 20),
                max_items=stac_data.get("max_items", 60),
            )
        else:
            stac = None

        return cls(
            job_id=data.get("job_id", ""),
            user_id=data.get("user_id"),
            lat=lat,
            lon=lon,
            radius_m=radius_m,
            year=data.get("year", 2022),
            season=data.get("season", "summer"),
            model=data.get("model", "Ensemble"),
            smoothing_kernel=data.get("smoothing_kernel", 5),
            outputs=outputs,
            stac=stac,
        )

    def get_aoi(self) -> AOI:
        """Get AOI object built from this payload's lon, lat and radius_m."""
        return AOI(lon=self.lon, lat=self.lat, radius_m=self.radius_m)
|
|
|
|
|
|
# ==========================================
|
|
# Worker Result / Output Schema
|
|
# ==========================================
|
|
|
|
@dataclass
class Artifact:
    """A single result file, referenced by ``s3_uri`` and ``url``."""

    s3_uri: str
    url: str
|
|
|
|
|
|
@dataclass
class WorkerResult:
    """Result from worker pipeline.

    Use the ``success`` and ``error`` constructors rather than building
    instances by hand; they set ``status`` consistently.
    """

    status: str  # "success" or "error"
    job_id: str
    stage: str
    message: str = ""
    artifacts: Dict[str, Artifact] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def success(
        cls,
        job_id: str,
        stage: str = "done",
        artifacts: Optional[Dict[str, Artifact]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> WorkerResult:
        """Create a success result.

        Args:
            job_id: Identifier of the job
            stage: Stage to report (defaults to "done")
            artifacts: Artifacts keyed by name; None means no artifacts
            metadata: Extra metadata; None means no metadata

        Returns:
            WorkerResult with status "success" and an empty message
        """
        # None is the "not supplied" sentinel; each result then gets its
        # own fresh dict instead of sharing a mutable default.
        return cls(
            status="success",
            job_id=job_id,
            stage=stage,
            message="",
            artifacts={} if artifacts is None else artifacts,
            metadata={} if metadata is None else metadata,
        )

    @classmethod
    def error(cls, job_id: str, stage: str, message: str) -> WorkerResult:
        """Create an error result.

        Args:
            job_id: Identifier of the job
            stage: Stage to report for the failure
            message: Error description

        Returns:
            WorkerResult with status "error" and empty artifacts/metadata
        """
        return cls(
            status="error",
            job_id=job_id,
            stage=stage,
            message=message,
            artifacts={},
            metadata={},
        )
|
|
|
|
|
|
# ==========================================
|
|
# Validation Helpers
|
|
# ==========================================
|
|
|
|
def validate_radius(radius_m: int) -> int:
    """Validate that the radius lies in the interval (0, 5000].

    Args:
        radius_m: Radius in meters

    Returns:
        The radius, unchanged

    Raises:
        ValueError: If radius_m is not greater than 0 and at most 5000
            (a NaN value is rejected as well)
    """
    # Phrased as a positive range test on purpose: every comparison with
    # NaN is False, so "x <= 0 or x > 5000" would let NaN slip through.
    if not (0 < radius_m <= 5000):
        raise ValueError(f"radius_m must be in (0, 5000], got {radius_m}")
    return radius_m
|
|
|
|
|
|
def validate_kernel(kernel: int) -> int:
    """Check that the smoothing kernel size is listed in VALID_KERNEL_SIZES.

    Args:
        kernel: Kernel size

    Returns:
        The kernel size, unchanged

    Raises:
        ValueError: If kernel is not one of VALID_KERNEL_SIZES
    """
    if kernel in VALID_KERNEL_SIZES:
        return kernel
    raise ValueError(f"kernel must be one of {VALID_KERNEL_SIZES}, got {kernel}")
|
|
|
|
|
|
def validate_year(year: int) -> int:
    """Check that the year lies between MIN_YEAR and the current year.

    The upper bound is read from the system clock on every call.

    Args:
        year: Year

    Returns:
        The year, unchanged

    Raises:
        ValueError: If year is below MIN_YEAR or above the current year
    """
    current_year = datetime.now().year
    out_of_range = year < MIN_YEAR or year > current_year
    if out_of_range:
        raise ValueError(
            f"year must be in [{MIN_YEAR}, {current_year}], got {year}"
        )
    return year
|
|
|
|
|
|
def validate_model(model: str) -> str:
    """Strip surrounding whitespace from a model name and check it is known.

    The lookup in VALID_MODELS is case-sensitive; the name's case is
    left untouched.

    Args:
        model: Model name

    Returns:
        The model name with leading/trailing whitespace removed

    Raises:
        ValueError: If the stripped name is not in VALID_MODELS
    """
    model = model.strip()
    if model in VALID_MODELS:
        return model
    raise ValueError(f"model must be one of {VALID_MODELS}, got {model}")
|
|
|
|
|
|
def validate_aoi_zimbabwe_quick(aoi: AOI) -> AOI:
    """Cheap bounding-box pre-check that an AOI lies in Zimbabwe.

    Only ``aoi.lon`` and ``aoi.lat`` are compared, against rough bounds;
    the radius is not taken into account.
    For strict validation, use polygon check (TODO).

    Args:
        aoi: AOI to validate

    Returns:
        The same AOI object

    Raises:
        ValueError: If the AOI is outside the rough Zimbabwe bbox
    """
    # Rough bbox for Zimbabwe — Lon: 25.2 to 33.1, Lat: -22.5 to -15.6.
    # Bounds are inclusive; latitude is only looked at when longitude passes.
    within_bbox = 25.2 <= aoi.lon <= 33.1 and -22.5 <= aoi.lat <= -15.6
    if within_bbox:
        return aoi
    raise ValueError(f"AOI ({aoi.lon}, {aoi.lat}) outside Zimbabwe bounds")
|
|
|
|
|
|
def validate_payload(payload: JobPayload) -> JobPayload:
    """Validate all payload fields.

    Checks run in this order: radius, smoothing kernel, year, model,
    then the quick AOI bounding-box check. The first failure raises.

    Args:
        payload: Job payload to validate

    Returns:
        The same payload object; its ``model`` attribute is updated in
        place to the whitespace-stripped name returned by validate_model

    Raises:
        ValueError: If any validation fails
    """
    # Validate radius
    validate_radius(payload.radius_m)

    # Validate kernel
    validate_kernel(payload.smoothing_kernel)

    # Validate year
    validate_year(payload.year)

    # Validate model. validate_model strips whitespace before the lookup,
    # so store its result: otherwise a name such as " XGBoost " would pass
    # validation while the payload still carried the unstripped value.
    payload.model = validate_model(payload.model)

    # Quick AOI check (bbox only for now)
    validate_aoi_zimbabwe_quick(payload.get_aoi())

    return payload
|
|
|
|
|
|
# ==========================================
|
|
# Class Resolution Helper
|
|
# ==========================================
|
|
|
|
def resolve_class_names(model_obj: Any) -> List[str]:
    """Work out the class names for a model object.

    TEMPORARY V1: falls back to DEFAULT_CLASS_NAMES if the model doesn't
    expose its classes. Later we will make this fully dynamic.

    Args:
        model_obj: Trained model object (sklearn-compatible)

    Returns:
        List of class names: the first non-None value among the
        attributes ``classes_``, ``class_names``, ``labels`` and
        ``classes`` (tried in that order), otherwise a copy of
        DEFAULT_CLASS_NAMES
    """
    # `classes_` comes first, followed by the other common attribute names.
    for attr in ("classes_", "class_names", "labels", "classes"):
        found = getattr(model_obj, attr, None)
        if found is None:
            continue
        # Objects offering tolist() (e.g. numpy arrays) convert themselves;
        # any other iterable is materialised with list().
        return found.tolist() if hasattr(found, "tolist") else list(found)

    # Fallback to default (TEMPORARY)
    return DEFAULT_CLASS_NAMES.copy()
|
|
|
|
|
|
# ==========================================
|
|
# Test / Sanity Check
|
|
# ==========================================
|
|
|
|
if __name__ == "__main__":
    # Smoke test of the contracts; runs only when executed as a script.
    print("Running contracts sanity test...")

    # Payload with only a few top-level keys: the rest must be defaulted.
    payload = JobPayload.from_dict(
        {
            "job_id": "test-123",
            "lat": -17.8,
            "lon": 31.0,
            "radius_m": 2000,
            "year": 2022,
        }
    )
    print(f" Minimal payload: job_id={payload.job_id}, model={payload.model}, season={payload.season}")
    assert payload.model == "Ensemble"
    assert payload.season == "summer"
    assert payload.outputs.refined == True

    # Payload with a nested "aoi" and explicit output options.
    payload2 = JobPayload.from_dict(
        {
            "job_id": "test-456",
            "user_id": "user-789",
            "aoi": {"lon": 31.0, "lat": -17.8, "radius_m": 3000},
            "year": 2023,
            "season": "summer",
            "model": "XGBoost",
            "smoothing_kernel": 7,
            "outputs": {
                "refined": True,
                "dw_baseline": False,
                "true_color": True,
                "indices": ["ndvi_peak"],
            },
        }
    )
    print(f" Full payload: model={payload2.model}, kernel={payload2.smoothing_kernel}")
    assert payload2.model == "XGBoost"
    assert payload2.smoothing_kernel == 7
    assert payload2.outputs.indices == ["ndvi_peak"]

    # The validators must reject out-of-contract values.
    try:
        validate_radius(10000)
    except ValueError:
        print(" validate_radius: OK (rejected >5000)")
    else:
        print(" ERROR: validate_radius should have raised")
        sys.exit(1)

    try:
        validate_kernel(4)
    except ValueError:
        print(" validate_kernel: OK (rejected even)")
    else:
        print(" ERROR: validate_kernel should have raised")
        sys.exit(1)

    # Class resolution: default fallback first, then an explicit classes_.
    class MockModel:
        pass

    mock_model = MockModel()
    classes = resolve_class_names(mock_model)
    print(f" resolve_class_names (no attr): {len(classes)} classes")
    assert classes == DEFAULT_CLASS_NAMES

    mock_model.classes_ = ["Apple", "Banana", "Cherry"]
    classes2 = resolve_class_names(mock_model)
    print(f" resolve_class_names (with attr): {classes2}")
    assert classes2 == ["Apple", "Banana", "Cherry"]

    print("\n✅ All contracts tests passed!")
|