refactor: remove sklearn inference.py, add async DW baseline loading
Build and Push Docker Images / build-and-push (push) Failing after 4m27s Details

- Deleted inference.py (sklearn path) in favor of hybrid_inference.py
- Worker now uses ThreadPoolExecutor for async DW baseline loading
- DW baseline URL sent to client as soon as ready, parallel to inference
- Removed sklearn model fallback (only Hybrid_SpatioTemporal supported)
- Updated docstring to reflect current module dependencies
This commit is contained in:
fchinembiri 2026-05-04 18:46:36 +02:00
parent 18aa966dc8
commit e2cfec586b
8 changed files with 162 additions and 925 deletions

View File

@ -1,650 +0,0 @@
"""GeoCrop inference pipeline (worker-side).
This module is designed to be called by your RQ worker.
Given a job payload (AOI, year, model choice), it:
1) Loads the correct model artifact from MinIO (or local cache).
2) Loads/clips the DW baseline COG for the requested season/year.
3) Queries Digital Earth Africa STAC for imagery and builds feature stack.
- IMPORTANT: Uses exact feature engineering from train.py:
- Savitzky-Golay smoothing (window=5, polyorder=2)
- Phenology metrics (amplitude, AUC, peak, slope)
- Harmonic features (1st/2nd order sin/cos)
- Seasonal window statistics (Early/Peak/Late)
4) Runs per-pixel inference to produce refined classes at 10m.
5) Applies neighborhood smoothing (majority filter).
6) Writes output GeoTIFF (COG recommended) to MinIO.
IMPORTANT: This implementation supports the current MinIO model format:
- Zimbabwe_Ensemble_Raw_Model.pkl (no scaler needed)
- Zimbabwe_Ensemble_Model.pkl (scaler needed)
- etc.
"""
from __future__ import annotations
import json
import os
import tempfile
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple, List
# Try to import required dependencies
try:
import joblib
except ImportError:
joblib = None
try:
import numpy as np
except ImportError:
np = None
try:
import rasterio
from rasterio import windows
from rasterio.enums import Resampling
except ImportError:
rasterio = None
windows = None
Resampling = None
try:
from config import InferenceConfig
except ImportError:
InferenceConfig = None
try:
from features import (
build_feature_stack_from_dea,
clip_raster_to_aoi,
load_dw_baseline_window,
majority_filter,
validate_aoi_zimbabwe,
)
except ImportError:
pass
# ==========================================
# STEP 6: Model Loading and Raster Prediction
# ==========================================
def load_model(storage, model_name: str):
    """Download a trained model from storage and deserialize it with joblib.

    Args:
        storage: Object exposing ``download_model_file(model_name, dest_dir)``
            that returns the local path of the downloaded file. Name-to-file
            mapping is left to that method.
        model_name: Name of model (e.g., "RandomForest", "XGBoost", "Ensemble").

    Returns:
        The object produced by ``joblib.load``.

    Raises:
        ImportError: If joblib is not installed (the module-level import
            fell back to ``None``).
        ValueError: If the model exposes ``n_features_in_`` and the value
            matches neither ``FEATURE_ORDER_V1`` nor ``FEATURE_ORDER_V2``.
    """
    # The module swallows a failed joblib import; fail here with a clear
    # message instead of "'NoneType' object has no attribute 'load'".
    if joblib is None:
        raise ImportError("joblib is required to load model artifacts but is not installed")

    # The file is only needed long enough to deserialize it, so the download
    # goes into a throw-away directory that is removed on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = storage.download_model_file(model_name, Path(tmp_dir))
        # NOTE(review): joblib.load unpickles arbitrary objects; the bucket
        # must only contain trusted artifacts.
        model = joblib.load(model_path)

    # Validate model compatibility (only possible when the estimator records
    # how many features it was fitted on).
    if hasattr(model, "n_features_in_"):
        from feature_computation import FEATURE_ORDER_V1, FEATURE_ORDER_V2

        actual_features = model.n_features_in_
        if actual_features == len(FEATURE_ORDER_V1):
            print(f"Detected V1 model ({actual_features} features)")
        elif actual_features == len(FEATURE_ORDER_V2):
            print(f"Detected V2 model ({actual_features} features)")
        else:
            raise ValueError(
                f"Model feature mismatch: model expects {actual_features} features. "
                f"Available versions: V1 ({len(FEATURE_ORDER_V1)}), V2 ({len(FEATURE_ORDER_V2)})."
            )
    return model
def predict_raster(
    model,
    feature_cube: np.ndarray,
    feature_order: List[str],
) -> np.ndarray:
    """Classify every pixel of a feature cube.

    Args:
        model: Object with a ``predict(X)`` method taking a 2D array.
        feature_cube: 3D array laid out as (H, W, C).
        feature_order: Feature names; only its length is used, to check it
            against the last axis of ``feature_cube``.

    Returns:
        Array of shape (H, W) holding ``model.predict`` output, with 0
        written wherever every feature of the pixel was exactly zero.

    Raises:
        ValueError: If the last axis of ``feature_cube`` differs in length
            from ``feature_order``.
    """
    n_expected = len(feature_order)
    n_actual = feature_cube.shape[-1]
    if n_actual != n_expected:
        raise ValueError(
            f"Feature dimension mismatch: feature_cube has {n_actual} features "
            f"but feature_order has {n_expected}. "
            f"feature_cube shape: {feature_cube.shape}, feature_order length: {len(feature_order)}. "
            f"Expected 51 features matching FEATURE_ORDER_V1."
        )

    rows, cols, channels = feature_cube.shape
    # One row per pixel: (H, W, C) -> (H*W, C)
    flat = feature_cube.reshape(-1, channels)

    # A pixel whose features are all zero is treated as nodata.
    empty = np.all(flat == 0, axis=1)
    has_empty = bool(empty.any())

    # Work on a copy so the caller's cube is never modified.
    model_input = flat.copy()
    if has_empty:
        # Feed a tiny non-zero value instead of zeros; the prediction for
        # these rows is discarded below.
        model_input[empty] = np.full(channels, 1e-6)

    labels = model.predict(model_input)
    if has_empty:
        # Class 0 is assumed to be reserved for nodata.
        labels[empty] = 0

    return labels.reshape(rows, cols)
# ==========================================
# Legacy functions (kept for backward compatibility)
# ==========================================
# Lookup table: model name as given in a job payload -> artifact filename in MinIO.
# Filenames follow "Zimbabwe_<ModelName>_Model.pkl" or "Zimbabwe_<ModelName>_Raw_Model.pkl".
# NOTE(review): the "<Name>_Raw" keys of the single models resolve to the same
# non-raw filename as "<Name>" - confirm that no separate raw artifact exists.
_SINGLE_MODEL_NAMES = ("RandomForest", "XGBoost", "LightGBM", "CatBoost")
MODEL_NAME_MAPPING = {
    # Ensemble models; the bare name resolves to the raw artifact.
    "Ensemble": "Zimbabwe_Ensemble_Raw_Model.pkl",
    "Ensemble_Raw": "Zimbabwe_Ensemble_Raw_Model.pkl",
    "Ensemble_Scaled": "Zimbabwe_Ensemble_Model.pkl",
    # Individual models
    **{name: f"Zimbabwe_{name}_Model.pkl" for name in _SINGLE_MODEL_NAMES},
    # Legacy/raw variants
    **{f"{name}_Raw": f"Zimbabwe_{name}_Model.pkl" for name in _SINGLE_MODEL_NAMES},
}
# Fallback class names, used when no label encoder ships with the model.
# Based on typical Zimbabwe crop classification. Position is significant:
# code in this module indexes the list by predicted class id.
DEFAULT_CLASSES: List[str] = [
    "cropland_rainfed",    # 0
    "cropland_irrigated",  # 1
    "tree_crop",           # 2
    "grassland",           # 3
    "shrubland",           # 4
    "urban",               # 5
    "water",               # 6
    "bare",                # 7
]
@dataclass
class InferenceResult:
    """Outcome of one inference job, as returned by ``run_inference_job``."""

    # Identifier of the job this result belongs to.
    job_id: str
    # Job state; ``run_inference_job`` sets it to "done".
    status: str
    # Output name -> URI returned by the storage upload.
    outputs: Dict[str, str]
    # Run metadata (year, season, model, classes, feature names, ...).
    meta: Dict
def _local_artifact_cache_dir() -> Path:
d = Path(os.getenv("GEOCROP_CACHE_DIR", "/tmp/geocrop-cache"))
d.mkdir(parents=True, exist_ok=True)
return d
def get_model_filename(model_name: str) -> str:
    """Get the MinIO filename for a given model name.

    Resolution order: exact key in ``MODEL_NAME_MAPPING``, then a
    case-insensitive key match, then a filename synthesized from the name.

    Args:
        model_name: Model name from job payload (e.g., "Ensemble", "Ensemble_Scaled")

    Returns:
        MinIO filename (e.g., "Zimbabwe_Ensemble_Raw_Model.pkl")
    """
    import re  # local import: only the fallback branch needs it

    # Direct lookup
    if model_name in MODEL_NAME_MAPPING:
        return MODEL_NAME_MAPPING[model_name]

    # Case-insensitive lookup
    model_lower = model_name.lower()
    for key, value in MODEL_NAME_MAPPING.items():
        if key.lower() == model_lower:
            return value

    # Unknown name: synthesize a filename following the bucket convention.
    if "_raw" in model_lower:
        # Strip the marker with the same case-insensitivity as the check
        # above; a case-sensitive strip left "foo_raw" intact and produced
        # "Zimbabwe_Foo_Raw_Raw_Model.pkl".
        base = re.sub("_raw", "", model_name, flags=re.IGNORECASE)
        return f"Zimbabwe_{base.title()}_Raw_Model.pkl"
    return f"Zimbabwe_{model_name.title()}_Model.pkl"
def needs_scaler(model_name: str) -> bool:
    """Tell whether features must be scaled before feeding this model.

    Matching is case-insensitive. A name containing "_raw", or the bare
    name "ensemble", means no scaling; every other name requires it.

    Args:
        model_name: Model name from job payload

    Returns:
        True if scaler should be applied
    """
    lowered = model_name.lower()
    is_raw_variant = "_raw" in lowered
    # "Ensemble" without any suffix is treated as the raw variant.
    is_bare_ensemble = lowered == "ensemble"
    return not (is_raw_variant or is_bare_ensemble)
def load_model_artifacts(cfg: InferenceConfig, model_name: str) -> Tuple[object, object, Optional[object], Optional[List[str]]]:
    """Load model, label encoder, optional scaler, and feature list.

    Artifacts are read from a per-model directory under the local cache.
    When ``model.pkl`` is absent there, ``cfg.storage.download_model_bundle``
    is called first to populate that directory.

    Supports current MinIO format:
    - Zimbabwe_*_Raw_Model.pkl (no scaler)
    - Zimbabwe_*_Model.pkl (needs scaler)

    Args:
        cfg: Inference configuration; its ``storage`` attribute performs the download.
        model_name: Name of the model to load

    Returns:
        Tuple of (model, label_encoder, scaler, selected_features). ``scaler``
        is None for models that need no scaling; ``selected_features`` is
        None when no feature list is cached.

    Raises:
        FileNotFoundError: If the model variant requires a scaler but no
            ``scaler.pkl`` is present in the cache directory.
    """
    # NOTE(review): model_name becomes part of a filesystem path; presumably
    # it is validated upstream - confirm.
    cache = _local_artifact_cache_dir() / model_name.replace(" ", "_")
    cache.mkdir(parents=True, exist_ok=True)

    model_key = f"models/{get_model_filename(model_name)}"  # Prefix in bucket

    model_p = cache / "model.pkl"
    le_p = cache / "label_encoder.pkl"
    scaler_p = cache / "scaler.pkl"
    feats_p = cache / "selected_features.json"

    # Download only on a cache miss
    if not model_p.exists():
        print(f"📥 Downloading model from MinIO: {model_key}")
        cfg.storage.download_model_bundle(model_key, cache)

    model = joblib.load(model_p)

    # Load or create label encoder
    if le_p.exists():
        label_encoder = joblib.load(le_p)
    else:
        print("⚠️ Label encoder not found, creating default")
        from sklearn.preprocessing import LabelEncoder

        # NOTE(review): LabelEncoder sorts its classes, so the encoded indices
        # follow alphabetical order, not the order of DEFAULT_CLASSES - confirm
        # this matches how the model was trained.
        label_encoder = LabelEncoder()
        label_encoder.fit(DEFAULT_CLASSES)

    # Load scaler if needed
    scaler = None
    if needs_scaler(model_name):
        if not scaler_p.exists():
            # An unfitted StandardScaler is not a no-op: transform() raises
            # on it. Fail here, where the cause is obvious.
            raise FileNotFoundError(
                f"Scaler required for model '{model_name}' but {scaler_p} does not exist; "
                f"upload scaler.pkl with the model bundle"
            )
        scaler = joblib.load(scaler_p)

    # Load selected features
    if feats_p.exists():
        selected_features = json.loads(feats_p.read_text())
    else:
        print("⚠️ Selected features not found, will use all computed features")
        selected_features = None

    return model, label_encoder, scaler, selected_features
def run_inference_job(cfg: InferenceConfig, job: Dict) -> InferenceResult:
    """Main worker entry: run the pipeline for one job and upload the outputs.

    Steps: validate the AOI, load the model artifacts, load the DW baseline,
    build the feature stack, predict per pixel, optionally apply a majority
    filter, then write and upload the GeoTIFFs.

    job payload example:
    {
        "job_id": "...",
        "user_id": "...",
        "lat": -17.8,
        "lon": 31.0,
        "radius_m": 2000,
        "year": 2022,
        "season": "summer",
        "model": "Ensemble"  # or "Ensemble_Scaled", "RandomForest", etc.
    }

    Args:
        cfg: Inference configuration. This function reads ``max_radius_m``,
            ``season_dates``, ``smoothing_kernel``, ``smoothing_enabled``
            and ``storage`` from it.
        job: Job payload; "lat", "lon", "radius_m" and "year" are required.

    Returns:
        InferenceResult with status "done", the uploaded output URIs and
        run metadata.
    """
    from datetime import timezone  # local: the module imports only ``datetime``

    job_id = str(job.get("job_id"))

    # 1) Validate AOI constraints
    aoi = (float(job["lon"]), float(job["lat"]), float(job["radius_m"]))
    validate_aoi_zimbabwe(aoi, max_radius_m=cfg.max_radius_m)

    year = int(job["year"])
    season = str(job.get("season", "summer")).lower()
    # Training window (Sep -> May)
    start_date, end_date = cfg.season_dates(year=year, season=season)

    model_name = str(job.get("model", "Ensemble"))
    print(f"🤖 Loading model: {model_name}")
    model, le, scaler, selected_features = load_model_artifacts(cfg, model_name)

    # Scaling happens only when the variant asks for it and a scaler exists.
    use_scaler = scaler is not None and needs_scaler(model_name)
    print(f" Scaler required: {use_scaler}")

    # 2) Load DW baseline for this year/season (already converted to COGs).
    # It doubles as the "DW baseline toggle" layer in the outputs.
    dw_arr, dw_profile = load_dw_baseline_window(
        cfg=cfg,
        year=year,
        season=season,
        aoi=aoi,
    )

    # 3) Build EO feature stack from DEA STAC, aligned to the DW grid.
    print("📡 Building feature stack from DEA STAC...")
    feat_arr, feat_profile, feat_names, aux_layers = build_feature_stack_from_dea(
        cfg=cfg,
        aoi=aoi,
        start_date=start_date,
        end_date=end_date,
        target_profile=dw_profile,
    )
    print(f" Computed {len(feat_names)} features")
    print(f" Feature array shape: {feat_arr.shape}")

    # 4) Prepare model input: (H,W,C) -> (N,C)
    H, W, C = feat_arr.shape
    X = feat_arr.reshape(-1, C)

    # Reorder/select columns so they match the order used at training time.
    if selected_features is not None:
        name_to_idx = {n: i for i, n in enumerate(feat_names)}
        missing = [n for n in selected_features if n not in name_to_idx]
        keep_idx = [name_to_idx[n] for n in selected_features if n in name_to_idx]
        if not keep_idx:
            print("⚠️ No matching features found, using all computed features")
        else:
            if missing:
                # Surface the gap: the model will see fewer columns than it
                # was trained on, which otherwise shows up as a shape error.
                print(f"⚠️ {len(missing)} selected features were not computed and are skipped: {missing}")
            print(f" Using {len(keep_idx)} selected features")
            X = X[:, keep_idx]
    else:
        print(" Using all computed features (no selection)")

    if use_scaler:
        print(" Applying StandardScaler")
        X = scaler.transform(X)

    # Handle NaNs (common with clouds/no-data)
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

    # 5) Predict
    print("🔮 Running prediction...")
    y_pred = model.predict(X).astype(np.int32)

    # Back to string labels (the refined classes)
    try:
        refined_labels = le.inverse_transform(y_pred)
    except Exception as e:
        print(f"⚠️ Label inverse_transform failed: {e}")
        # Fallback: index the default class list, wrapping out-of-range ids.
        refined_labels = np.array([DEFAULT_CLASSES[i % len(DEFAULT_CLASSES)] for i in y_pred])
    refined_labels = refined_labels.reshape(H, W)

    # 6) Neighborhood smoothing (majority filter)
    smoothing_kernel = job.get("smoothing_kernel", cfg.smoothing_kernel)
    if cfg.smoothing_enabled and smoothing_kernel > 1:
        print(f"🧼 Applying majority filter (k={smoothing_kernel})")
        refined_labels = majority_filter(refined_labels, k=smoothing_kernel)

    # 7) Write outputs (GeoTIFF only; COG recommended for tiling)
    # datetime.utcnow() is deprecated; the aware equivalent formats identically.
    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    out_name = f"refined_{season}_{year}_{job_id}_{ts}.tif"
    baseline_name = f"dw_{season}_{year}_{job_id}_{ts}.tif"

    with tempfile.TemporaryDirectory() as tmp:
        refined_path = Path(tmp) / out_name
        dw_path = Path(tmp) / baseline_name

        # DW baseline
        with rasterio.open(dw_path, "w", **dw_profile) as dst:
            dst.write(dw_arr, 1)

        # Refined map is stored as a uint16 index raster; the index -> class
        # legend travels in ``meta``.
        classes = le.classes_.tolist() if hasattr(le, 'classes_') else DEFAULT_CLASSES
        class_to_idx = {c: i for i, c in enumerate(classes)}

        if refined_labels.dtype.kind in ['U', 'O', 'S']:
            # String labels: translate each class name to its index.
            idx_raster = np.zeros((H, W), dtype=np.uint16)
            for i, cls in enumerate(classes):
                idx_raster[refined_labels == cls] = i
        else:
            # Numeric labels already
            idx_raster = refined_labels.astype(np.uint16)

        refined_profile = dw_profile.copy()
        refined_profile.update({"dtype": "uint16", "count": 1})
        with rasterio.open(refined_path, "w", **refined_profile) as dst:
            dst.write(idx_raster, 1)

        # Upload while the temporary files still exist.
        refined_uri = cfg.storage.upload_result(local_path=refined_path, key=f"results/{out_name}")
        dw_uri = cfg.storage.upload_result(local_path=dw_path, key=f"results/{baseline_name}")

        # Optionally upload aux layers (true color, NDVI/EVI/SAVI)
        aux_uris = {}
        for layer_name, layer in aux_layers.items():
            # layer: (H,W) or (H,W,3)
            aux_path = Path(tmp) / f"{layer_name}_{season}_{year}_{job_id}_{ts}.tif"
            is_rgb = layer.ndim == 3 and layer.shape[2] == 3
            aux_profile = dw_profile.copy()
            aux_profile.update({"count": 3 if is_rgb else 1, "dtype": str(layer.dtype)})
            with rasterio.open(aux_path, "w", **aux_profile) as dst:
                if is_rgb:
                    # rasterio expects band-first order: (H,W,3) -> (3,H,W)
                    dst.write(layer.transpose(2, 0, 1), [1, 2, 3])
                else:
                    dst.write(layer, 1)
            aux_uris[layer_name] = cfg.storage.upload_result(
                local_path=aux_path, key=f"results/{aux_path.name}"
            )

    meta = {
        "job_id": job_id,
        "year": year,
        "season": season,
        "start_date": start_date,
        "end_date": end_date,
        "model": model_name,
        "scaler_used": use_scaler,
        "classes": classes,
        "class_index": class_to_idx,
        "features_computed": feat_names,
        "n_features": len(feat_names),
        "smoothing": {"enabled": cfg.smoothing_enabled, "kernel": smoothing_kernel},
    }
    outputs = {
        "refined_geotiff": refined_uri,
        "dw_baseline_geotiff": dw_uri,
        **aux_uris,
    }
    return InferenceResult(job_id=job_id, status="done", outputs=outputs, meta=meta)
# ==========================================
# Self-Test
# ==========================================
def _run_self_test():
    """Smoke-test predict_raster and load_model, printing the outcome of each check."""
    print("=== Inference Module Self-Test ===")

    # Probe optional dependencies by attempting to import each one.
    missing_deps = []
    for dep in ('joblib', 'sklearn'):
        try:
            __import__(dep)
        except ImportError:
            missing_deps.append(dep)
    if missing_deps:
        print(f"\n⚠️ Missing dependencies: {missing_deps}")
        print(" These will be available in the container environment.")
        print(" Running syntax validation only...")

    # Test 1: predict_raster with dummy data (only if sklearn available)
    print("\n1. Testing predict_raster with dummy feature cube...")
    H, W, C = 10, 10, 51
    dummy_cube = np.random.rand(H, W, C).astype(np.float32)
    from feature_computation import FEATURE_ORDER_V1
    feature_order = FEATURE_ORDER_V1
    print(f" Feature cube shape: {dummy_cube.shape}")
    print(f" Feature order length: {len(feature_order)}")

    if 'sklearn' in missing_deps:
        print(" (sklearn not available - skipping)")
    else:
        from sklearn.ensemble import RandomForestClassifier

        # Small throw-away model fitted on random data.
        X_train = np.random.rand(100, C)
        y_train = np.random.randint(0, 8, 100)
        dummy_model = RandomForestClassifier(n_estimators=10, random_state=42)
        dummy_model.fit(X_train, y_train)
        print(f" Model n_features_in_: {dummy_model.n_features_in_}")

        try:
            result = predict_raster(dummy_model, dummy_cube, feature_order)
            print(f" Prediction result shape: {result.shape}")
            print(f" Expected shape: ({H}, {W})")
            if result.shape != (H, W):
                print(" ✗ predict_raster test FAILED - wrong shape")
            else:
                print(" ✓ predict_raster test PASSED")
        except Exception as e:
            print(f" ✗ predict_raster test FAILED: {e}")

        # Test 2: predict_raster with nodata handling
        print("\n2. Testing nodata handling...")
        nodata_cube = np.zeros((5, 5, C), dtype=np.float32)
        nodata_cube[2, 2, :] = 1.0  # One valid pixel
        result_nodata = predict_raster(dummy_model, nodata_cube, feature_order)
        print(f" Nodata pixel value at [2,2]: {result_nodata[2, 2]}")
        print(f" Nodata pixels (should be 0): {result_nodata[0, 0]}")
        corners_are_nodata = result_nodata[0, 0] == 0 and result_nodata[0, 1] == 0
        if corners_are_nodata:
            print(" ✓ Nodata handling test PASSED")
        else:
            print(" ✗ Nodata handling test FAILED")

        # Test 3: Feature mismatch detection
        print("\n3. Testing feature mismatch detection...")
        wrong_cube = np.random.rand(5, 5, 50).astype(np.float32)  # 50 features, not 51
        try:
            predict_raster(dummy_model, wrong_cube, feature_order)
            print(" ✗ Feature mismatch test FAILED - should have raised ValueError")
        except ValueError as e:
            if "Feature dimension mismatch" in str(e):
                print(" ✓ Feature mismatch test PASSED")
            else:
                print(f" ✗ Wrong error: {e}")

    # Test 4: Try loading model from MinIO (will fail without real storage)
    print("\n4. Testing load_model from MinIO...")
    try:
        from storage import MinIOStorage
        storage = MinIOStorage()
        # Any failure here is reported as expected rather than as a test failure.
        model = load_model(storage, "RandomForest")
        print(" Model loaded successfully")
        print(" ✓ load_model test PASSED")
    except Exception as e:
        print(f" (Expected) MinIO/storage not available: {e}")
        print(" ✓ load_model test handled gracefully")

    print("\n=== Inference Module Test Complete ===")


if __name__ == "__main__":
    _run_self_test()

View File

@ -8,8 +8,7 @@ This module wires together all the step modules:
- stac_client.py (DEA STAC search)
- feature_computation.py (51-feature extraction)
- dw_baseline.py (windowed DW baseline)
- inference.py (model loading + prediction)
- postprocess.py (majority filter smoothing)
- hybrid_inference.py (CNN + CatBoost ensemble inference)
- cog.py (COG export)
"""
@ -20,17 +19,17 @@ import os
import sys
import tempfile
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
# Redis/RQ for job queue
from redis import Redis
from rq import Queue
import numpy as np
import rasterio
from rasterio.io import MemoryFile
# ==========================================
# Redis Configuration
@ -48,11 +47,9 @@ def _get_redis_conn():
redis_host = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local")
redis_port_str = os.getenv("REDIS_PORT", "6379")
# Handle case where REDIS_PORT might be a full URL
try:
redis_port = int(redis_port_str)
except ValueError:
# If it's a URL, extract the port
if "://" in redis_port_str:
import urllib.parse
parsed = urllib.parse.urlparse(redis_port_str)
@ -60,7 +57,6 @@ def _get_redis_conn():
else:
redis_port = 6379
# MUST NOT use decode_responses=True because RQ uses pickle (binary)
return Redis(host=redis_host, port=redis_port)
@ -85,17 +81,7 @@ def update_status(
outputs: Optional[Dict] = None,
error: Optional[Dict] = None,
) -> None:
"""Update job status in Redis.
Args:
job_id: Job identifier
status: Overall status (queued, running, failed, done)
stage: Current pipeline stage
progress: Progress percentage (0-100)
message: Human-readable message
outputs: Output file URLs (when done)
error: Error details (on failure)
"""
"""Update job status in Redis."""
key = f"job:{job_id}:status"
status_data = {
@ -113,8 +99,7 @@ def update_status(
status_data["error"] = error
try:
redis_conn.set(key, json.dumps(status_data), ex=86400) # 24h expiry
# Also update the job metadata in RQ if possible
redis_conn.set(key, json.dumps(status_data), ex=86400)
from rq import get_current_job
job = get_current_job()
if job:
@ -126,39 +111,65 @@ def update_status(
print(f"Warning: Failed to update Redis status: {e}")
def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func):
    """Publish the DW baseline if its background load has finished.

    Never blocks: when the future is missing or still running the function
    returns immediately. On success the raster is written to a temporary
    GeoTIFF, uploaded, and the presigned URL is reported via ``update_func``.

    Args:
        dw_future: Future resolving to ``(array, profile)`` or ``None``; may
            itself be ``None``.
        storage: Object exposing ``upload_result(path, key)`` and
            ``presign_get(bucket, key)``.
        job_id: Job identifier, used in the object key and in log lines.
        payload: Job payload; ``year`` and ``season`` are read from it.
        update_func: Status callback, called as ``update_func(job_id,
            "running", "dw_ready", 30, message, outputs={...})``.

    Returns:
        The presigned URL, or ``None`` when the baseline is not ready, is
        unavailable, or processing failed (failures are logged, not raised).
    """
    if dw_future is None or not dw_future.done():
        return None

    try:
        dw_result = dw_future.result()
        if dw_result is None:
            return None
        dw_arr, dw_profile = dw_result

        import rasterio

        dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"

        # tempfile.mktemp is race-prone and left the file behind; a temporary
        # directory gives a private path and is removed once uploaded.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dw_temp_path = Path(tmp_dir) / "dw_baseline.tif"
            with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
                # NOTE(review): other code in this module unpacks dw_arr.shape
                # as (H, W); a 2-D array must be written with an explicit band
                # index, so handle both layouts.
                if dw_arr.ndim == 2:
                    dst.write(dw_arr, 1)
                else:
                    dst.write(dw_arr)
            storage.upload_result(dw_temp_path, dw_key)

        # NOTE(review): bucket name is hard-coded - confirm it is the bucket
        # upload_result writes to.
        dw_url = storage.presign_get("geocrop-baselines", dw_key)
        print(f"[{job_id}] DW baseline URL ready: {dw_url[:80]}...")

        # Notify client
        update_func(
            job_id, "running", "dw_ready", 30,
            "Dynamic World baseline ready",
            outputs={"dw_baseline_url": dw_url},
        )
        return dw_url
    except Exception as e:
        # Best effort: the baseline is optional, so log and carry on.
        print(f"[{job_id}] DW baseline processing failed: {e}")
    return None
# ==========================================
# Payload Validation
# ==========================================
def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]:
"""Parse and validate job payload.
Args:
payload: Raw job payload dict
Returns:
Tuple of (validated_payload, list_of_errors)
"""
"""Parse and validate job payload."""
errors = []
# Required fields
required = ["job_id", "lat", "lon", "radius_m", "year"]
for field in required:
if field not in payload:
errors.append(f"Missing required field: {field}")
# Validate AOI
if "lat" in payload and "lon" in payload:
lat = float(payload["lat"])
lon = float(payload["lon"])
# Zimbabwe bounds check
if not (-22.5 <= lat <= -15.6):
errors.append(f"Latitude {lat} outside Zimbabwe bounds")
if not (25.2 <= lon <= 33.1):
errors.append(f"Longitude {lon} outside Zimbabwe bounds")
# Validate radius
if "radius_m" in payload:
radius = int(payload["radius_m"])
if radius > 5000:
@ -166,26 +177,22 @@ def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]:
if radius < 100:
errors.append(f"Radius {radius}m below min 100m")
# Validate year
if "year" in payload:
year = int(payload["year"])
current_year = datetime.now().year
if year < 2015 or year > current_year:
errors.append(f"Year {year} outside valid range (2015-{current_year})")
# Validate model
if "model" in payload:
from contracts import VALID_MODELS
if payload["model"] not in VALID_MODELS:
errors.append(f"Invalid model: {payload['model']}. Must be one of {VALID_MODELS}")
# Validate kernel
if "smoothing_kernel" in payload:
kernel = int(payload["smoothing_kernel"])
if kernel not in [3, 5, 7]:
errors.append(f"Invalid smoothing_kernel: {kernel}. Must be 3, 5, or 7")
# Set defaults
validated = {
"job_id": payload.get("job_id", "unknown"),
"lat": float(payload.get("lat", 0)),
@ -207,30 +214,47 @@ def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]:
# ==========================================
# Main Job Runner
# Async DW Loading Helper
# ==========================================
def _load_dw_async(storage, bbox, year, season) -> Optional[Tuple[np.ndarray, dict]]:
    """Load the DW baseline window, returning ``None`` instead of raising.

    Intended as the callable submitted to an executor: the load itself is a
    plain blocking call to ``dw_baseline.load_dw_baseline_window``.

    Args:
        storage: Storage handle forwarded to the loader.
        bbox: AOI bounding box, forwarded as ``aoi_bbox_wgs84``.
        year: Year forwarded to the loader.
        season: Season forwarded to the loader.

    Returns:
        ``(array, profile)`` on success, ``None`` if the loader raised.
    """
    from dw_baseline import load_dw_baseline_window

    try:
        arr, profile = load_dw_baseline_window(
            storage=storage,
            aoi_bbox_wgs84=bbox,
            year=year,
            season=season,
        )
        print(f"[_dw_load] DW baseline loaded: shape={arr.shape}")
    except Exception as exc:
        # Failure is reported to the caller as None, not as an exception.
        print(f"[_dw_load] DW baseline failed: {exc}")
        return None
    return arr, profile
# ==========================================
# Main Job Runner (Async)
# ==========================================
def run_job(payload_dict: dict) -> dict:
"""Main job runner function.
"""Main job runner with async DW baseline loading.
This is the RQ task function that orchestrates the full pipeline.
DW baseline loads in background while hybrid inference runs.
DW URL is sent to client as soon as it's ready, parallel to inference.
"""
from rq import get_current_job
current_job = get_current_job()
# Extract job_id from payload or RQ
job_id = payload_dict.get("job_id")
if not job_id and current_job:
job_id = current_job.id
if not job_id:
job_id = "unknown"
# Ensure job_id is in payload for validation
payload_dict["job_id"] = job_id
# Standardize payload from API format to worker format
# API sends: radius_km, model_name
# Worker expects: radius_m, model
if "radius_km" in payload_dict and "radius_m" not in payload_dict:
payload_dict["radius_m"] = int(float(payload_dict["radius_km"]) * 1000)
@ -249,7 +273,6 @@ def run_job(payload_dict: dict) -> dict:
)
return {"status": "failed", "error": str(e)}
# Parse and validate payload
payload, errors = parse_and_validate_payload(payload_dict)
if errors:
update_status(
@ -259,220 +282,97 @@ def run_job(payload_dict: dict) -> dict:
)
return {"status": "failed", "errors": errors}
# Update initial status
update_status(job_id, "running", "fetch_stac", 5, "Fetching STAC items...")
update_status(job_id, "running", "init", 5, "Starting inference pipeline...")
missing_outputs = []
output_urls = {}
dw_baseline_url = None
output_urls = {}
missing_outputs = []
try:
# ==========================================
# Stage 1: Fetch STAC
# ==========================================
print(f"[{job_id}] Fetching STAC items for {payload['year']} {payload['season']}...")
from stac_client import DEASTACClient
# Get config and AOI bbox
from config import InferenceConfig, MinIOStorage as ConfigMinIO
from dw_baseline import load_dw_baseline_window
cfg = InferenceConfig()
# Initialize storage adapter for inference.py
cfg.storage = ConfigMinIO()
# Get season dates
start_date, end_date = cfg.season_dates(payload['year'], payload['season'])
# Calculate AOI bbox
lat, lon, radius = payload['lat'], payload['lon'], payload['radius_m']
# Rough bbox from radius (in degrees)
radius_deg = radius / 111000 # ~111km per degree
bbox = [
lon - radius_deg, # min_lon
lat - radius_deg, # min_lat
lon + radius_deg, # max_lon
lat + radius_deg, # max_lat
]
# Search STAC
stac_client = DEASTACClient()
try:
items = stac_client.search_items(
bbox=bbox,
start_date=start_date,
end_date=end_date,
)
print(f"[{job_id}] Found {len(items)} STAC items")
except Exception as e:
print(f"[{job_id}] STAC search failed: {e}")
# Continue but note that features may be limited
radius_deg = radius / 111000
bbox = [lon - radius_deg, lat - radius_deg, lon + radius_deg, lat + radius_deg]
# ==========================================
# Stage 2: Load DW Baseline
# Start DW baseline loading in background
# ==========================================
update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline...")
update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline (async)...")
print(f"[{job_id}] Starting async DW baseline load...")
print(f"[{job_id}] Loading Dynamic World baseline for {payload['year']} {payload['season']}...")
try:
# Load DW baseline for the AOI
dw_arr, dw_profile = load_dw_baseline_window(
storage=storage,
aoi_bbox_wgs84=bbox,
year=payload['year'],
season=payload['season'],
)
print(f"[{job_id}] DW baseline loaded: shape={dw_arr.shape}")
# Save to temporary TIF file
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
dst.write(dw_arr)
print(f"[{job_id}] DW baseline saved to temp file: {dw_temp_path}")
# Upload to MinIO
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
storage.upload_result(dw_temp_path, dw_key)
print(f"[{job_id}] DW baseline uploaded to: {dw_key}")
# Generate presigned URL
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
print(f"[{job_id}] DW baseline URL: {dw_baseline_url[:80]}...")
# Immediately update job status with DW baseline URL
update_status(
job_id, "running", "load_dw", 15,
"DW baseline loaded and uploaded",
outputs={"dw_baseline_url": dw_baseline_url},
with ThreadPoolExecutor(max_workers=1) as dw_executor:
dw_future = dw_executor.submit(
_load_dw_async,
storage, bbox, payload['year'], payload['season']
)
except Exception as e:
print(f"[{job_id}] Failed to load DW baseline: {e}")
# Continue without DW baseline - not critical for inference
dw_arr = None
dw_profile = None
update_status(job_id, "running", "build_features", 20, "Building feature cube...")
# ==========================================
# Stage 3: Build Feature Cube
# ==========================================
print(f"[{job_id}] Building feature cube...")
from feature_computation import FEATURE_ORDER_V1
feature_order = FEATURE_ORDER_V1
expected_features = len(feature_order) # Should be 51
print(f"[{job_id}] Expected {expected_features} features (FEATURE_ORDER_V1)")
# Check if we have an existing feature builder in features.py
feature_cube = None
use_synthetic = False
try:
from features import build_feature_stack_from_dea
print(f"[{job_id}] Trying build_feature_stack_from_dea for feature extraction...")
# ==========================================
# Start hybrid inference immediately (in parallel)
# ==========================================
update_status(job_id, "running", "load_model", 20, "Loading model artifacts...")
# Try to call it - this requires stackstac and DEA STAC access
try:
feature_cube = build_feature_stack_from_dea(
items=items,
bbox=bbox,
start_date=start_date,
end_date=end_date,
)
print(f"[{job_id}] Feature cube built successfully: {feature_cube.shape if feature_cube is not None else 'None'}")
except Exception as e:
print(f"[{job_id}] Feature stack building failed: {e}")
print(f"[{job_id}] Falling back to synthetic features for testing")
use_synthetic = True
except ImportError as e:
print(f"[{job_id}] Feature builder not available: {e}")
print(f"[{job_id}] Using synthetic features for testing")
use_synthetic = True
# Generate synthetic features for testing when real data isn't available
if feature_cube is None:
print(f"[{job_id}] Generating synthetic features for pipeline test...")
model_dir = Path(tempfile.mkdtemp())
print(f"[{job_id}] Downloading model artifacts...")
# Determine raster dimensions from DW baseline if loaded
if dw_arr is not None:
H, W = dw_arr.shape
else:
# Default size for testing
H, W = 100, 100
# Generate synthetic features: shape (H, W, 51)
# Use year as seed for reproducible but varied features
np.random.seed(payload['year'] + int(payload.get('lon', 0) * 100) + int(payload.get('lat', 0) * 100))
# Generate realistic-looking features (normalized values)
feature_cube = np.random.rand(H, W, expected_features).astype(np.float32)
# Add some structure - make center pixels different from edges
y, x = np.ogrid[:H, :W]
center_y, center_x = H // 2, W // 2
dist = np.sqrt((y - center_y)**2 + (x - center_x)**2)
max_dist = np.sqrt(center_y**2 + center_x**2)
# Add a gradient based on distance from center (simulating field pattern)
for i in range(min(10, expected_features)):
feature_cube[:, :, i] = (1 - dist / max_dist) * 0.5 + feature_cube[:, :, i] * 0.5
print(f"[{job_id}] Synthetic feature cube shape: {feature_cube.shape}")
# ==========================================
# Stage 4: Load Model Artifacts
# ==========================================
update_status(job_id, "running", "load_model", 40, "Loading model artifacts...")
is_hybrid = "hybrid" in payload['model'].lower() or "spatiotemporal" in payload['model'].lower()
model_dir = Path(tempfile.mkdtemp())
if is_hybrid:
print(f"[{job_id}] Model type: Hybrid Spatio-Temporal. Downloading artifacts...")
# Expected files in MinIO: pipeline_meta.pkl, Temporal_FCN.pth, calibrated_hybrid_cb.pkl
# Download model artifacts
for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]:
try:
storage.download_file(storage.bucket_models, artifact, model_dir / artifact)
print(f"[{job_id}] Downloaded {artifact}")
except Exception as e:
print(f"[{job_id}] Failed to download {artifact}: {e}")
# Try with 'hybrid/' prefix if direct fails
try:
storage.download_file(storage.bucket_models, f"hybrid/{artifact}", model_dir / artifact)
print(f"[{job_id}] Downloaded {artifact} (from hybrid/ prefix)")
except Exception as e2:
raise FileNotFoundError(f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}")
raise FileNotFoundError(
f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}"
)
update_status(job_id, "running", "fetch_stac", 30, "Fetching spatio-temporal data...")
# ==========================================
# Stage 5: Fetch Spatio-Temporal Data
# ==========================================
update_status(job_id, "running", "fetch_stac", 50, "Fetching spatio-temporal indices...")
from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline
stac_wrapper = DEAfricaSTACWrapper()
# Calculate ranges for wrapper
lat_range = (bbox[1], bbox[3])
lon_range = (bbox[0], bbox[2])
time_range = (start_date, end_date)
print(f"[{job_id}] Fetching STAC data from DEA...")
unseen_pixel_df = stac_wrapper.fetch_and_format_data(
lat_range=lat_range,
lon_range=lon_range,
time_range=time_range
)
print(f"[{job_id}] STAC data fetched: {len(unseen_pixel_df)} pixels")
# Check if DW is ready while processing STAC
if dw_future.done():
dw_result = dw_future.result()
if dw_result is not None:
dw_arr, dw_profile = dw_result
import rasterio
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
dst.write(dw_arr)
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
storage.upload_result(dw_temp_path, dw_key)
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
update_status(
job_id, "running", "dw_ready", 35,
"Dynamic World baseline ready",
outputs={"dw_baseline_url": dw_baseline_url},
)
update_status(job_id, "running", "infer", 50, "Running Hybrid Inference (CNN + CatBoost)...")
print(f"[{job_id}] Running hybrid inference...")
# ==========================================
# Stage 6: Hybrid Inference
# ==========================================
update_status(job_id, "running", "infer", 70, "Running Hybrid Inference (CNN + CatBoost)...")
pipeline = CropInferencePipeline(model_dir=str(model_dir))
mapped_crops_df = pipeline.predict(
@ -480,17 +380,19 @@ def run_job(payload_dict: dict) -> dict:
apply_spatial_smoothing=True,
coord_cols=['lat', 'lon']
)
print(f"[{job_id}] Inference complete, exporting results...")
# ==========================================
# Stage 7: Export and Upload
# Export and Upload Results
# ==========================================
update_status(job_id, "running", "export_cog", 90, "Exporting results...")
update_status(job_id, "running", "export_cog", 80, "Exporting results...")
output_dir = Path(tempfile.mkdtemp())
output_path = output_dir / "refined.tif"
pipeline.export_to_geotiff(mapped_crops_df, output_path=str(output_path))
output_urls = {}
# Upload results
for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]:
local_f = output_dir / filename
if local_f.exists():
@ -498,44 +400,47 @@ def run_job(payload_dict: dict) -> dict:
storage.upload_result(local_f, result_key)
output_urls[filename.replace(".","_url")] = storage.presign_get("geocrop-results", result_key)
else:
# Fallback to Legacy/Standard logic
print(f"[{job_id}] Using standard/ensemble inference logic...")
from inference import run_inference_job
# Check DW one more time (may have finished during inference)
if dw_baseline_url is None and dw_future.done():
dw_result = dw_future.result()
if dw_result is not None:
dw_arr, dw_profile = dw_result
import rasterio
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
dst.write(dw_arr)
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
storage.upload_result(dw_temp_path, dw_key)
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
# Create a mock job dict compatible with run_inference_job
job_payload = {
"job_id": job_id,
"lat": payload["lat"],
"lon": payload["lon"],
"radius_m": payload["radius_m"],
"year": payload["year"],
"season": payload["season"],
"model": payload["model"],
"smoothing_kernel": payload["smoothing_kernel"]
}
inference_result = run_inference_job(cfg, job_payload)
output_urls = inference_result.outputs
# Note: indices and true_color not yet implemented
# Wait for DW if still running
if dw_baseline_url is None:
print(f"[{job_id}] Waiting for DW baseline to finish...")
dw_result = dw_future.result(timeout=60)
if dw_result is not None:
dw_arr, dw_profile = dw_result
import rasterio
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
dst.write(dw_arr)
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
storage.upload_result(dw_temp_path, dw_key)
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
# ==========================================
# Final Status
# ==========================================
final_outputs = dict(output_urls)
if dw_baseline_url:
final_outputs["dw_baseline_url"] = dw_baseline_url
if payload['outputs'].get('indices'):
missing_outputs.append("indices: not implemented")
if payload['outputs'].get('true_color'):
missing_outputs.append("true_color: not implemented")
# ==========================================
# Stage 8: Final Status
# ==========================================
final_status = "partial" if missing_outputs else "done"
final_message = f"Inference complete"
if missing_outputs:
final_message += f" (partial: {', '.join(missing_outputs)})"
# Include DW baseline URL in final outputs if available
final_outputs = dict(output_urls)
if dw_baseline_url:
final_outputs["dw_baseline_url"] = dw_baseline_url
final_message = f"Inference complete" + (f" ({', '.join(missing_outputs)})" if missing_outputs else "")
update_status(
job_id,
@ -556,7 +461,6 @@ def run_job(payload_dict: dict) -> dict:
}
except Exception as e:
# Catch-all for any unexpected errors
error_trace = traceback.format_exc()
print(f"[{job_id}] Error: {e}")
print(error_trace)
@ -573,9 +477,10 @@ def run_job(payload_dict: dict) -> dict:
"job_id": job_id,
}
# Alias for the API: `run_inference` is a second name bound to the same
# function object as `run_job`, so calling either name runs the same code.
run_inference = run_job
# ==========================================
# RQ Worker Entry Point
# ==========================================
@ -585,7 +490,6 @@ def start_rq_worker():
from rq import Worker
import signal
# Ensure /app is in sys.path so we can import modules
if '/app' not in sys.path:
sys.path.insert(0, '/app')
@ -594,9 +498,7 @@ def start_rq_worker():
print(f"=== GeoCrop RQ Worker Starting ===")
print(f"Listening on queue: {queue_name}")
print(f"Redis: {os.getenv('REDIS_HOST', 'redis.geocrop.svc.cluster.local')}:{os.getenv('REDIS_PORT', '6379')}")
print(f"Python path: {sys.path[:3]}")
# Handle graceful shutdown
def signal_handler(signum, frame):
print("\nReceived shutdown signal, exiting gracefully...")
sys.exit(0)
@ -624,22 +526,17 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.test or not args.worker:
# Syntax-level self-test
print("=== GeoCrop Worker Syntax Test ===")
# Test imports
try:
from contracts import STAGES, VALID_MODELS
from storage import MinIOStorage
from feature_computation import FEATURE_ORDER_V1
print(f"✓ Imports OK")
print(f" STAGES: {STAGES}")
print(f" VALID_MODELS: {VALID_MODELS}")
print(f" FEATURE_ORDER length: {len(FEATURE_ORDER_V1)}")
except ImportError as e:
print(f"⚠ Some imports missing (expected outside container): {e}")
print(f"⚠ Some imports missing: {e}")
# Test payload parsing
print("\n--- Payload Parsing Test ---")
test_payload = {
"job_id": "test-123",
@ -647,7 +544,7 @@ if __name__ == "__main__":
"lon": 31.0,
"radius_m": 2000,
"year": 2022,
"model": "Ensemble",
"model": "Hybrid_SpatioTemporal",
"smoothing_kernel": 5,
"outputs": {"refined": True, "dw_baseline": True},
}
@ -660,18 +557,8 @@ if __name__ == "__main__":
print(f" job_id: {validated['job_id']}")
print(f" AOI: ({validated['lat']}, {validated['lon']}) radius={validated['radius_m']}m")
print(f" model: {validated['model']}")
print(f" kernel: {validated['smoothing_kernel']}")
# Show what would run
print("\n--- Pipeline Overview ---")
print("Pipeline stages:")
for i, stage in enumerate(STAGES):
print(f" {i+1}. {stage}")
print("\nNote: This is a syntax-level test.")
print("Full execution requires Redis, MinIO, and STAC access in the container.")
print("\n=== Worker Syntax Test Complete ===")
if args.worker:
start_rq_worker()
start_rq_worker()