diff --git a/apps/worker/__pycache__/contracts.cpython-310.pyc b/apps/worker/__pycache__/contracts.cpython-310.pyc index cb1fbab..99b4715 100644 Binary files a/apps/worker/__pycache__/contracts.cpython-310.pyc and b/apps/worker/__pycache__/contracts.cpython-310.pyc differ diff --git a/apps/worker/__pycache__/feature_computation.cpython-310.pyc b/apps/worker/__pycache__/feature_computation.cpython-310.pyc index a40af85..6dfb2fc 100644 Binary files a/apps/worker/__pycache__/feature_computation.cpython-310.pyc and b/apps/worker/__pycache__/feature_computation.cpython-310.pyc differ diff --git a/apps/worker/__pycache__/features.cpython-310.pyc b/apps/worker/__pycache__/features.cpython-310.pyc index d813f62..d2e2ca6 100644 Binary files a/apps/worker/__pycache__/features.cpython-310.pyc and b/apps/worker/__pycache__/features.cpython-310.pyc differ diff --git a/apps/worker/__pycache__/stac_client.cpython-310.pyc b/apps/worker/__pycache__/stac_client.cpython-310.pyc index a563253..15bd0b3 100644 Binary files a/apps/worker/__pycache__/stac_client.cpython-310.pyc and b/apps/worker/__pycache__/stac_client.cpython-310.pyc differ diff --git a/apps/worker/__pycache__/storage.cpython-310.pyc b/apps/worker/__pycache__/storage.cpython-310.pyc index b685f9f..f1e2e67 100644 Binary files a/apps/worker/__pycache__/storage.cpython-310.pyc and b/apps/worker/__pycache__/storage.cpython-310.pyc differ diff --git a/apps/worker/__pycache__/worker.cpython-310.pyc b/apps/worker/__pycache__/worker.cpython-310.pyc index 7954800..14f0aa1 100644 Binary files a/apps/worker/__pycache__/worker.cpython-310.pyc and b/apps/worker/__pycache__/worker.cpython-310.pyc differ diff --git a/apps/worker/inference.py b/apps/worker/inference.py deleted file mode 100644 index 7a5367d..0000000 --- a/apps/worker/inference.py +++ /dev/null @@ -1,650 +0,0 @@ -"""GeoCrop inference pipeline (worker-side). - -This module is designed to be called by your RQ worker. -Given a job payload (AOI, year, model choice), it: - 1) Loads the correct model artifact from MinIO (or local cache). - 2) Loads/clips the DW baseline COG for the requested season/year. - 3) Queries Digital Earth Africa STAC for imagery and builds feature stack. - - IMPORTANT: Uses exact feature engineering from train.py: - - Savitzky-Golay smoothing (window=5, polyorder=2) - - Phenology metrics (amplitude, AUC, peak, slope) - - Harmonic features (1st/2nd order sin/cos) - - Seasonal window statistics (Early/Peak/Late) - 4) Runs per-pixel inference to produce refined classes at 10m. - 5) Applies neighborhood smoothing (majority filter). - 6) Writes output GeoTIFF (COG recommended) to MinIO. - -IMPORTANT: This implementation supports the current MinIO model format: - - Zimbabwe_Ensemble_Raw_Model.pkl (no scaler needed) - - Zimbabwe_Ensemble_Model.pkl (scaler needed) - - etc. -""" - -from __future__ import annotations - -import json -import os -import tempfile -from dataclasses import dataclass -from datetime import datetime -from pathlib import Path -from typing import Dict, Optional, Tuple, List - -# Try to import required dependencies -try: - import joblib -except ImportError: - joblib = None - -try: - import numpy as np -except ImportError: - np = None - -try: - import rasterio - from rasterio import windows - from rasterio.enums import Resampling -except ImportError: - rasterio = None - windows = None - Resampling = None - -try: - from config import InferenceConfig -except ImportError: - InferenceConfig = None - -try: - from features import ( - build_feature_stack_from_dea, - clip_raster_to_aoi, - load_dw_baseline_window, - majority_filter, - validate_aoi_zimbabwe, - ) -except ImportError: - pass - - -# ========================================== -# STEP 6: Model Loading and Raster Prediction -# ========================================== - -def load_model(storage, model_name: str): - """Load a trained model from MinIO storage. - - Args: - storage: MinIOStorage instance with download_model_file method - model_name: Name of model (e.g., "RandomForest", "XGBoost", "Ensemble") - - Returns: - Loaded sklearn-compatible model - - Raises: - FileNotFoundError: If model file not found - ValueError: If model has incompatible number of features - """ - # Create temp directory for download - import tempfile - with tempfile.TemporaryDirectory() as tmp_dir: - dest_dir = Path(tmp_dir) - - # Download model file from MinIO - # storage.download_model_file already handles mapping - model_path = storage.download_model_file(model_name, dest_dir) - - # Load model with joblib - model = joblib.load(model_path) - - # Validate model compatibility - if hasattr(model, 'n_features_in_'): - from feature_computation import FEATURE_ORDER_V1, FEATURE_ORDER_V2 - actual_features = model.n_features_in_ - - if actual_features == len(FEATURE_ORDER_V1): - print(f"Detected V1 model ({actual_features} features)") - elif actual_features == len(FEATURE_ORDER_V2): - print(f"Detected V2 model ({actual_features} features)") - else: - raise ValueError( - f"Model feature mismatch: model expects {actual_features} features. " - f"Available versions: V1 ({len(FEATURE_ORDER_V1)}), V2 ({len(FEATURE_ORDER_V2)})." - ) - - return model - - -def predict_raster( - model, - feature_cube: np.ndarray, - feature_order: List[str], -) -> np.ndarray: - """Run inference on a feature cube. - - Args: - model: Trained sklearn-compatible model - feature_cube: 3D array of shape (H, W, 51) containing features - feature_order: List of 51 feature names in order - - Returns: - 2D array of shape (H, W) with class predictions - - Raises: - ValueError: If feature_cube dimensions don't match feature_order - """ - # Validate dimensions - expected_features = len(feature_order) - actual_features = feature_cube.shape[-1] - - if actual_features != expected_features: - raise ValueError( - f"Feature dimension mismatch: feature_cube has {actual_features} features " - f"but feature_order has {expected_features}. " - f"feature_cube shape: {feature_cube.shape}, feature_order length: {len(feature_order)}. " - f"Expected 51 features matching FEATURE_ORDER_V1." - ) - - H, W, C = feature_cube.shape - - # Flatten spatial dimensions: (H, W, C) -> (H*W, C) - X = feature_cube.reshape(-1, C) - - # Identify nodata pixels (all zeros) - nodata_mask = np.all(X == 0, axis=1) - num_nodata = np.sum(nodata_mask) - - # Replace nodata with small non-zero values to avoid model issues - # The predictions will be overwritten for nodata pixels anyway - X_safe = X.copy() - if num_nodata > 0: - # Use epsilon to avoid division by zero in some models - X_safe[nodata_mask] = np.full(C, 1e-6) - - # Run prediction - y_pred = model.predict(X_safe) - - # Set nodata pixels to 0 (assuming class 0 reserved for nodata) - if num_nodata > 0: - y_pred[nodata_mask] = 0 - - # Reshape back to (H, W) - result = y_pred.reshape(H, W) - - return result - - -# ========================================== -# Legacy functions (kept for backward compatibility) -# ========================================== - - -# Model name to MinIO filename mapping -# Format: "Zimbabwe__Model.pkl" or "Zimbabwe__Raw_Model.pkl" -MODEL_NAME_MAPPING = { - # Ensemble models - "Ensemble": "Zimbabwe_Ensemble_Raw_Model.pkl", - "Ensemble_Raw": "Zimbabwe_Ensemble_Raw_Model.pkl", - "Ensemble_Scaled": "Zimbabwe_Ensemble_Model.pkl", - - # Individual models - "RandomForest": "Zimbabwe_RandomForest_Model.pkl", - "XGBoost": "Zimbabwe_XGBoost_Model.pkl", - "LightGBM": "Zimbabwe_LightGBM_Model.pkl", - "CatBoost": "Zimbabwe_CatBoost_Model.pkl", - - # Legacy/raw variants - "RandomForest_Raw": "Zimbabwe_RandomForest_Model.pkl", - "XGBoost_Raw": "Zimbabwe_XGBoost_Model.pkl", - "LightGBM_Raw": "Zimbabwe_LightGBM_Model.pkl", - "CatBoost_Raw": "Zimbabwe_CatBoost_Model.pkl", -} - -# Default class mapping if label encoder not available -# Based on typical Zimbabwe crop classification -DEFAULT_CLASSES = [ - "cropland_rainfed", - "cropland_irrigated", - "tree_crop", - "grassland", - "shrubland", - "urban", - "water", - "bare", -] - - -@dataclass -class InferenceResult: - job_id: str - status: str - outputs: Dict[str, str] - meta: Dict - - -def _local_artifact_cache_dir() -> Path: - d = Path(os.getenv("GEOCROP_CACHE_DIR", "/tmp/geocrop-cache")) - d.mkdir(parents=True, exist_ok=True) - return d - - -def get_model_filename(model_name: str) -> str: - """Get the MinIO filename for a given model name. - - Args: - model_name: Model name from job payload (e.g., "Ensemble", "Ensemble_Scaled") - - Returns: - MinIO filename (e.g., "Zimbabwe_Ensemble_Raw_Model.pkl") - """ - # Direct lookup - if model_name in MODEL_NAME_MAPPING: - return MODEL_NAME_MAPPING[model_name] - - # Try case-insensitive - model_lower = model_name.lower() - for key, value in MODEL_NAME_MAPPING.items(): - if key.lower() == model_lower: - return value - - # Default fallback - if "_raw" in model_lower: - return f"Zimbabwe_{model_name.replace('_Raw', '').title()}_Raw_Model.pkl" - else: - return f"Zimbabwe_{model_name.title()}_Model.pkl" - - -def needs_scaler(model_name: str) -> bool: - """Determine if a model needs feature scaling. - - Models with "_Raw" suffix do NOT need scaling. - All other models require StandardScaler. - - Args: - model_name: Model name from job payload - - Returns: - True if scaler should be applied - """ - # Check for _Raw suffix - if "_raw" in model_name.lower(): - return False - - # Ensemble without suffix defaults to raw - if model_name.lower() == "ensemble": - return False - - # Default: needs scaling - return True - - -def load_model_artifacts(cfg: InferenceConfig, model_name: str) -> Tuple[object, object, Optional[object], List[str]]: - """Load model, label encoder, optional scaler, and feature list. - - Supports current MinIO format: - - Zimbabwe_*_Raw_Model.pkl (no scaler) - - Zimbabwe_*_Model.pkl (needs scaler) - - Args: - cfg: Inference configuration - model_name: Name of the model to load - - Returns: - Tuple of (model, label_encoder, scaler, selected_features) - """ - cache = _local_artifact_cache_dir() / model_name.replace(" ", "_") - cache.mkdir(parents=True, exist_ok=True) - - # Get the MinIO filename - model_filename = get_model_filename(model_name) - model_key = f"models/{model_filename}" # Prefix in bucket - - model_p = cache / "model.pkl" - le_p = cache / "label_encoder.pkl" - scaler_p = cache / "scaler.pkl" - feats_p = cache / "selected_features.json" - - # Check if cached - if not model_p.exists(): - print(f"๐Ÿ“ฅ Downloading model from MinIO: {model_key}") - cfg.storage.download_model_bundle(model_key, cache) - - # Load model - model = joblib.load(model_p) - - # Load or create label encoder - if le_p.exists(): - label_encoder = joblib.load(le_p) - else: - # Try to get classes from model - print("โš ๏ธ Label encoder not found, creating default") - from sklearn.preprocessing import LabelEncoder - label_encoder = LabelEncoder() - # Fit on default classes - label_encoder.fit(DEFAULT_CLASSES) - - # Load scaler if needed - scaler = None - if needs_scaler(model_name): - if scaler_p.exists(): - scaler = joblib.load(scaler_p) - else: - print("โš ๏ธ Scaler not found but required for this model variant") - # Create a dummy scaler that does nothing - from sklearn.preprocessing import StandardScaler - scaler = StandardScaler() - # Note: In production, this should fail - scaler must be uploaded - - # Load selected features - if feats_p.exists(): - selected_features = json.loads(feats_p.read_text()) - else: - print("โš ๏ธ Selected features not found, will use all computed features") - selected_features = None - - return model, label_encoder, scaler, selected_features - - -def run_inference_job(cfg: InferenceConfig, job: Dict) -> InferenceResult: - """Main worker entry. - - job payload example: - { - "job_id": "...", - "user_id": "...", - "lat": -17.8, - "lon": 31.0, - "radius_m": 2000, - "year": 2022, - "season": "summer", - "model": "Ensemble" # or "Ensemble_Scaled", "RandomForest", etc. - } - """ - - job_id = str(job.get("job_id")) - - # 1) Validate AOI constraints - aoi = (float(job["lon"]), float(job["lat"]), float(job["radius_m"])) - validate_aoi_zimbabwe(aoi, max_radius_m=cfg.max_radius_m) - - year = int(job["year"]) - season = str(job.get("season", "summer")).lower() - - # Your training window (Sep -> May) - start_date, end_date = cfg.season_dates(year=year, season=season) - - model_name = str(job.get("model", "Ensemble")) - print(f"๐Ÿค– Loading model: {model_name}") - - model, le, scaler, selected_features = load_model_artifacts(cfg, model_name) - - # Determine if we need scaling - use_scaler = scaler is not None and needs_scaler(model_name) - print(f" Scaler required: {use_scaler}") - - # 2) Load DW baseline for this year/season (already converted to COGs) - # (This gives you the "DW baseline toggle" layer too.) - dw_arr, dw_profile = load_dw_baseline_window( - cfg=cfg, - year=year, - season=season, - aoi=aoi, - ) - - # 3) Build EO feature stack from DEA STAC - # IMPORTANT: This now uses full feature engineering matching train.py - print("๐Ÿ“ก Building feature stack from DEA STAC...") - feat_arr, feat_profile, feat_names, aux_layers = build_feature_stack_from_dea( - cfg=cfg, - aoi=aoi, - start_date=start_date, - end_date=end_date, - target_profile=dw_profile, - ) - - print(f" Computed {len(feat_names)} features") - print(f" Feature array shape: {feat_arr.shape}") - - # 4) Prepare model input: (H,W,C) -> (N,C) - H, W, C = feat_arr.shape - X = feat_arr.reshape(-1, C) - - # Ensure feature order matches training - if selected_features is not None: - name_to_idx = {n: i for i, n in enumerate(feat_names)} - keep_idx = [name_to_idx[n] for n in selected_features if n in name_to_idx] - - if len(keep_idx) == 0: - print("โš ๏ธ No matching features found, using all computed features") - else: - print(f" Using {len(keep_idx)} selected features") - X = X[:, keep_idx] - else: - print(" Using all computed features (no selection)") - - # Apply scaler if needed - if use_scaler and scaler is not None: - print(" Applying StandardScaler") - X = scaler.transform(X) - - # Handle NaNs (common with clouds/no-data) - X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) - - # 5) Predict - print("๐Ÿ”ฎ Running prediction...") - y_pred = model.predict(X).astype(np.int32) - - # Back to string labels (your refined classes) - try: - refined_labels = le.inverse_transform(y_pred) - except Exception as e: - print(f"โš ๏ธ Label inverse_transform failed: {e}") - # Fallback: use default classes - refined_labels = np.array([DEFAULT_CLASSES[i % len(DEFAULT_CLASSES)] for i in y_pred]) - - refined_labels = refined_labels.reshape(H, W) - - # 6) Neighborhood smoothing (majority filter) - smoothing_kernel = job.get("smoothing_kernel", cfg.smoothing_kernel) - if cfg.smoothing_enabled and smoothing_kernel > 1: - print(f"๐Ÿงผ Applying majority filter (k={smoothing_kernel})") - refined_labels = majority_filter(refined_labels, k=smoothing_kernel) - - # 7) Write outputs (GeoTIFF only; COG recommended for tiling) - ts = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") - out_name = f"refined_{season}_{year}_{job_id}_{ts}.tif" - baseline_name = f"dw_{season}_{year}_{job_id}_{ts}.tif" - - with tempfile.TemporaryDirectory() as tmp: - refined_path = Path(tmp) / out_name - dw_path = Path(tmp) / baseline_name - - # DW baseline - with rasterio.open(dw_path, "w", **dw_profile) as dst: - dst.write(dw_arr, 1) - - # Refined - store as uint16 with a sidecar legend in meta (recommended) - # For now store an index raster; map index->class in meta.json - classes = le.classes_.tolist() if hasattr(le, 'classes_') else DEFAULT_CLASSES - class_to_idx = {c: i for i, c in enumerate(classes)} - - # Handle string labels - if refined_labels.dtype.kind in ['U', 'O', 'S']: - # String labels - create mapping - idx_raster = np.zeros((H, W), dtype=np.uint16) - for i, cls in enumerate(classes): - mask = refined_labels == cls - idx_raster[mask] = i - else: - # Numeric labels already - idx_raster = refined_labels.astype(np.uint16) - - refined_profile = dw_profile.copy() - refined_profile.update({"dtype": "uint16", "count": 1}) - - with rasterio.open(refined_path, "w", **refined_profile) as dst: - dst.write(idx_raster, 1) - - # Upload - refined_uri = cfg.storage.upload_result(local_path=refined_path, key=f"results/{out_name}") - dw_uri = cfg.storage.upload_result(local_path=dw_path, key=f"results/{baseline_name}") - - # Optionally upload aux layers (true color, NDVI/EVI/SAVI) - aux_uris = {} - for layer_name, layer in aux_layers.items(): - # layer: (H,W) or (H,W,3) - aux_path = Path(tmp) / f"{layer_name}_{season}_{year}_{job_id}_{ts}.tif" - - # Determine count and dtype - if layer.ndim == 3 and layer.shape[2] == 3: - count = 3 - dtype = layer.dtype - else: - count = 1 - dtype = layer.dtype - - aux_profile = dw_profile.copy() - aux_profile.update({"count": count, "dtype": str(dtype)}) - - with rasterio.open(aux_path, "w", **aux_profile) as dst: - if count == 1: - dst.write(layer, 1) - else: - dst.write(layer.transpose(2, 0, 1), [1, 2, 3]) - - aux_uris[layer_name] = cfg.storage.upload_result( - local_path=aux_path, key=f"results/{aux_path.name}" - ) - - meta = { - "job_id": job_id, - "year": year, - "season": season, - "start_date": start_date, - "end_date": end_date, - "model": model_name, - "scaler_used": use_scaler, - "classes": classes, - "class_index": class_to_idx, - "features_computed": feat_names, - "n_features": len(feat_names), - "smoothing": {"enabled": cfg.smoothing_enabled, "kernel": smoothing_kernel}, - } - - outputs = { - "refined_geotiff": refined_uri, - "dw_baseline_geotiff": dw_uri, - **aux_uris, - } - - return InferenceResult(job_id=job_id, status="done", outputs=outputs, meta=meta) - - -# ========================================== -# Self-Test -# ========================================== - -if __name__ == "__main__": - print("=== Inference Module Self-Test ===") - - # Check for required dependencies - missing_deps = [] - for mod in ['joblib', 'sklearn']: - try: - __import__(mod) - except ImportError: - missing_deps.append(mod) - - if missing_deps: - print(f"\nโš ๏ธ Missing dependencies: {missing_deps}") - print(" These will be available in the container environment.") - print(" Running syntax validation only...") - - # Test 1: predict_raster with dummy data (only if sklearn available) - print("\n1. Testing predict_raster with dummy feature cube...") - - # Create dummy feature cube (10, 10, 51) - H, W, C = 10, 10, 51 - dummy_cube = np.random.rand(H, W, C).astype(np.float32) - - # Create dummy feature order - from feature_computation import FEATURE_ORDER_V1 - feature_order = FEATURE_ORDER_V1 - - print(f" Feature cube shape: {dummy_cube.shape}") - print(f" Feature order length: {len(feature_order)}") - - if 'sklearn' not in missing_deps: - # Create a dummy model for testing - from sklearn.ensemble import RandomForestClassifier - - # Train a small model on random data - X_train = np.random.rand(100, C) - y_train = np.random.randint(0, 8, 100) - dummy_model = RandomForestClassifier(n_estimators=10, random_state=42) - dummy_model.fit(X_train, y_train) - - # Verify model compatibility check - print(f" Model n_features_in_: {dummy_model.n_features_in_}") - - # Run prediction - try: - result = predict_raster(dummy_model, dummy_cube, feature_order) - print(f" Prediction result shape: {result.shape}") - print(f" Expected shape: ({H}, {W})") - - if result.shape == (H, W): - print(" โœ“ predict_raster test PASSED") - else: - print(" โœ— predict_raster test FAILED - wrong shape") - except Exception as e: - print(f" โœ— predict_raster test FAILED: {e}") - - # Test 2: predict_raster with nodata handling - print("\n2. Testing nodata handling...") - - # Create cube with nodata (all zeros) - nodata_cube = np.zeros((5, 5, C), dtype=np.float32) - nodata_cube[2, 2, :] = 1.0 # One valid pixel - - result_nodata = predict_raster(dummy_model, nodata_cube, feature_order) - print(f" Nodata pixel value at [2,2]: {result_nodata[2, 2]}") - print(f" Nodata pixels (should be 0): {result_nodata[0, 0]}") - - if result_nodata[0, 0] == 0 and result_nodata[0, 1] == 0: - print(" โœ“ Nodata handling test PASSED") - else: - print(" โœ— Nodata handling test FAILED") - - # Test 3: Feature mismatch detection - print("\n3. Testing feature mismatch detection...") - - wrong_cube = np.random.rand(5, 5, 50).astype(np.float32) # 50 features, not 51 - - try: - predict_raster(dummy_model, wrong_cube, feature_order) - print(" โœ— Feature mismatch test FAILED - should have raised ValueError") - except ValueError as e: - if "Feature dimension mismatch" in str(e): - print(" โœ“ Feature mismatch test PASSED") - else: - print(f" โœ— Wrong error: {e}") - else: - print(" (sklearn not available - skipping)") - - # Test 4: Try loading model from MinIO (will fail without real storage) - print("\n4. Testing load_model from MinIO...") - try: - from storage import MinIOStorage - storage = MinIOStorage() - - # This will fail without real MinIO, but we can catch the error - model = load_model(storage, "RandomForest") - print(" Model loaded successfully") - print(" โœ“ load_model test PASSED") - except Exception as e: - print(f" (Expected) MinIO/storage not available: {e}") - print(" โœ“ load_model test handled gracefully") - - print("\n=== Inference Module Test Complete ===") - diff --git a/apps/worker/worker.py b/apps/worker/worker.py index 0e981b4..4772ce9 100644 --- a/apps/worker/worker.py +++ b/apps/worker/worker.py @@ -8,8 +8,7 @@ This module wires together all the step modules: - stac_client.py (DEA STAC search) - feature_computation.py (51-feature extraction) - dw_baseline.py (windowed DW baseline) -- inference.py (model loading + prediction) -- postprocess.py (majority filter smoothing) +- hybrid_inference.py (CNN + CatBoost ensemble inference) - cog.py (COG export) """ @@ -20,17 +19,17 @@ import os import sys import tempfile import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple # Redis/RQ for job queue from redis import Redis from rq import Queue import numpy as np -import rasterio -from rasterio.io import MemoryFile + # ========================================== # Redis Configuration @@ -48,11 +47,9 @@ def _get_redis_conn(): redis_host = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local") redis_port_str = os.getenv("REDIS_PORT", "6379") - # Handle case where REDIS_PORT might be a full URL try: redis_port = int(redis_port_str) except ValueError: - # If it's a URL, extract the port if "://" in redis_port_str: import urllib.parse parsed = urllib.parse.urlparse(redis_port_str) @@ -60,7 +57,6 @@ def _get_redis_conn(): else: redis_port = 6379 - # MUST NOT use decode_responses=True because RQ uses pickle (binary) return Redis(host=redis_host, port=redis_port) @@ -85,17 +81,7 @@ def update_status( outputs: Optional[Dict] = None, error: Optional[Dict] = None, ) -> None: - """Update job status in Redis. - - Args: - job_id: Job identifier - status: Overall status (queued, running, failed, done) - stage: Current pipeline stage - progress: Progress percentage (0-100) - message: Human-readable message - outputs: Output file URLs (when done) - error: Error details (on failure) - """ + """Update job status in Redis.""" key = f"job:{job_id}:status" status_data = { @@ -113,8 +99,7 @@ def update_status( status_data["error"] = error try: - redis_conn.set(key, json.dumps(status_data), ex=86400) # 24h expiry - # Also update the job metadata in RQ if possible + redis_conn.set(key, json.dumps(status_data), ex=86400) from rq import get_current_job job = get_current_job() if job: @@ -126,39 +111,65 @@ def update_status( print(f"Warning: Failed to update Redis status: {e}") +def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func): + """Check if DW baseline is ready and send to client.""" + if dw_future is None: + return None + + if dw_future.done(): + try: + dw_result = dw_future.result() + if dw_result is not None: + dw_arr, dw_profile = dw_result + + # Save to temp file + import rasterio + dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) + with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: + dst.write(dw_arr) + + # Upload to MinIO + dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" + storage.upload_result(dw_temp_path, dw_key) + + # Generate presigned URL + dw_url = storage.presign_get("geocrop-baselines", dw_key) + print(f"[{job_id}] DW baseline URL ready: {dw_url[:80]}...") + + # Notify client + update_func( + job_id, "running", "dw_ready", 30, + "Dynamic World baseline ready", + outputs={"dw_baseline_url": dw_url}, + ) + return dw_url + except Exception as e: + print(f"[{job_id}] DW baseline processing failed: {e}") + return None + + # ========================================== # Payload Validation # ========================================== def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]: - """Parse and validate job payload. - - Args: - payload: Raw job payload dict - - Returns: - Tuple of (validated_payload, list_of_errors) - """ + """Parse and validate job payload.""" errors = [] - # Required fields required = ["job_id", "lat", "lon", "radius_m", "year"] for field in required: if field not in payload: errors.append(f"Missing required field: {field}") - # Validate AOI if "lat" in payload and "lon" in payload: lat = float(payload["lat"]) lon = float(payload["lon"]) - # Zimbabwe bounds check if not (-22.5 <= lat <= -15.6): errors.append(f"Latitude {lat} outside Zimbabwe bounds") if not (25.2 <= lon <= 33.1): errors.append(f"Longitude {lon} outside Zimbabwe bounds") - # Validate radius if "radius_m" in payload: radius = int(payload["radius_m"]) if radius > 5000: @@ -166,26 +177,22 @@ def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]: if radius < 100: errors.append(f"Radius {radius}m below min 100m") - # Validate year if "year" in payload: year = int(payload["year"]) current_year = datetime.now().year if year < 2015 or year > current_year: errors.append(f"Year {year} outside valid range (2015-{current_year})") - # Validate model if "model" in payload: from contracts import VALID_MODELS if payload["model"] not in VALID_MODELS: errors.append(f"Invalid model: {payload['model']}. Must be one of {VALID_MODELS}") - # Validate kernel if "smoothing_kernel" in payload: kernel = int(payload["smoothing_kernel"]) if kernel not in [3, 5, 7]: errors.append(f"Invalid smoothing_kernel: {kernel}. Must be 3, 5, or 7") - # Set defaults validated = { "job_id": payload.get("job_id", "unknown"), "lat": float(payload.get("lat", 0)), @@ -207,30 +214,47 @@ def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]: # ========================================== -# Main Job Runner +# Async DW Loading Helper +# ========================================== + +def _load_dw_async(storage, bbox, year, season) -> Optional[Tuple[np.ndarray, dict]]: + """Async wrapper for DW baseline loading.""" + from dw_baseline import load_dw_baseline_window + try: + dw_arr, dw_profile = load_dw_baseline_window( + storage=storage, + aoi_bbox_wgs84=bbox, + year=year, + season=season, + ) + print(f"[_dw_load] DW baseline loaded: shape={dw_arr.shape}") + return dw_arr, dw_profile + except Exception as e: + print(f"[_dw_load] DW baseline failed: {e}") + return None + + +# ========================================== +# Main Job Runner (Async) # ========================================== def run_job(payload_dict: dict) -> dict: - """Main job runner function. + """Main job runner with async DW baseline loading. - This is the RQ task function that orchestrates the full pipeline. + DW baseline loads in background while hybrid inference runs. + DW URL is sent to client as soon as it's ready, parallel to inference. """ from rq import get_current_job current_job = get_current_job() - # Extract job_id from payload or RQ job_id = payload_dict.get("job_id") if not job_id and current_job: job_id = current_job.id if not job_id: job_id = "unknown" - # Ensure job_id is in payload for validation payload_dict["job_id"] = job_id - - # Standardize payload from API format to worker format - # API sends: radius_km, model_name - # Worker expects: radius_m, model + if "radius_km" in payload_dict and "radius_m" not in payload_dict: payload_dict["radius_m"] = int(float(payload_dict["radius_km"]) * 1000) @@ -249,7 +273,6 @@ def run_job(payload_dict: dict) -> dict: ) return {"status": "failed", "error": str(e)} - # Parse and validate payload payload, errors = parse_and_validate_payload(payload_dict) if errors: update_status( @@ -259,220 +282,97 @@ def run_job(payload_dict: dict) -> dict: ) return {"status": "failed", "errors": errors} - # Update initial status - update_status(job_id, "running", "fetch_stac", 5, "Fetching STAC items...") + update_status(job_id, "running", "init", 5, "Starting inference pipeline...") - missing_outputs = [] - output_urls = {} dw_baseline_url = None + output_urls = {} + missing_outputs = [] try: - # ========================================== - # Stage 1: Fetch STAC - # ========================================== - print(f"[{job_id}] Fetching STAC items for {payload['year']} {payload['season']}...") - - from stac_client import DEASTACClient + # Get config and AOI bbox from config import InferenceConfig, MinIOStorage as ConfigMinIO - from dw_baseline import load_dw_baseline_window cfg = InferenceConfig() - # Initialize storage adapter for inference.py cfg.storage = ConfigMinIO() - # Get season dates start_date, end_date = cfg.season_dates(payload['year'], payload['season']) - # Calculate AOI bbox lat, lon, radius = payload['lat'], payload['lon'], payload['radius_m'] - - # Rough bbox from radius (in degrees) - radius_deg = radius / 111000 # ~111km per degree - bbox = [ - lon - radius_deg, # min_lon - lat - radius_deg, # min_lat - lon + radius_deg, # max_lon - lat + radius_deg, # max_lat - ] - - # Search STAC - stac_client = DEASTACClient() - - try: - items = stac_client.search_items( - bbox=bbox, - start_date=start_date, - end_date=end_date, - ) - print(f"[{job_id}] Found {len(items)} STAC items") - except Exception as e: - print(f"[{job_id}] STAC search failed: {e}") - # Continue but note that features may be limited + radius_deg = radius / 111000 + bbox = [lon - radius_deg, lat - radius_deg, lon + radius_deg, lat + radius_deg] # ========================================== - # Stage 2: Load DW Baseline + # Start DW baseline loading in background # ========================================== - update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline...") + update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline (async)...") + print(f"[{job_id}] Starting async DW baseline load...") - print(f"[{job_id}] Loading Dynamic World baseline for {payload['year']} {payload['season']}...") - - try: - # Load DW baseline for the AOI - dw_arr, dw_profile = load_dw_baseline_window( - storage=storage, - aoi_bbox_wgs84=bbox, - year=payload['year'], - season=payload['season'], - ) - print(f"[{job_id}] DW baseline loaded: shape={dw_arr.shape}") - - # Save to temporary TIF file - dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) - with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: - dst.write(dw_arr) - print(f"[{job_id}] DW baseline saved to temp file: {dw_temp_path}") - - # Upload to MinIO - dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" - storage.upload_result(dw_temp_path, dw_key) - print(f"[{job_id}] DW baseline uploaded to: {dw_key}") - - # Generate presigned URL - dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key) - print(f"[{job_id}] DW baseline URL: {dw_baseline_url[:80]}...") - - # Immediately update job status with DW baseline URL - update_status( - job_id, "running", "load_dw", 15, - "DW baseline loaded and uploaded", - outputs={"dw_baseline_url": dw_baseline_url}, + with ThreadPoolExecutor(max_workers=1) as dw_executor: + dw_future = dw_executor.submit( + _load_dw_async, + storage, bbox, payload['year'], payload['season'] ) - except Exception as e: - print(f"[{job_id}] Failed to load DW baseline: {e}") - # Continue without DW baseline - not critical for inference - dw_arr = None - dw_profile = None - - update_status(job_id, "running", "build_features", 20, "Building feature cube...") - - # ========================================== - # Stage 3: Build Feature Cube - # ========================================== - print(f"[{job_id}] Building feature cube...") - - from feature_computation import FEATURE_ORDER_V1 - - feature_order = FEATURE_ORDER_V1 - expected_features = len(feature_order) # Should be 51 - - print(f"[{job_id}] Expected {expected_features} features (FEATURE_ORDER_V1)") - - # Check if we have an existing feature builder in features.py - feature_cube = None - use_synthetic = False - - try: - from features import build_feature_stack_from_dea - print(f"[{job_id}] Trying build_feature_stack_from_dea for feature extraction...") + # ========================================== + # Start hybrid inference immediately (in parallel) + # ========================================== + update_status(job_id, "running", "load_model", 20, "Loading model artifacts...") - # Try to call it - this requires stackstac and DEA STAC access - try: - feature_cube = build_feature_stack_from_dea( - items=items, - bbox=bbox, - start_date=start_date, - end_date=end_date, - ) - print(f"[{job_id}] Feature cube built successfully: {feature_cube.shape if feature_cube is not None else 'None'}") - except Exception as e: - print(f"[{job_id}] Feature stack building failed: {e}") - print(f"[{job_id}] Falling back to synthetic features for testing") - use_synthetic = True - - except ImportError as e: - print(f"[{job_id}] Feature builder not available: {e}") - print(f"[{job_id}] Using synthetic features for testing") - use_synthetic = True - - # Generate synthetic features for testing when real data isn't available - if feature_cube is None: - print(f"[{job_id}] Generating synthetic features for pipeline test...") + model_dir = Path(tempfile.mkdtemp()) + print(f"[{job_id}] Downloading model artifacts...") - # Determine raster dimensions from DW baseline if loaded - if dw_arr is not None: - H, W = dw_arr.shape - else: - # Default size for testing - H, W = 100, 100 - - # Generate synthetic features: shape (H, W, 51) - - # Use year as seed for reproducible but varied features - np.random.seed(payload['year'] + int(payload.get('lon', 0) * 100) + int(payload.get('lat', 0) * 100)) - - # Generate realistic-looking features (normalized values) - feature_cube = np.random.rand(H, W, expected_features).astype(np.float32) - - # Add some structure - make center pixels different from edges - y, x = np.ogrid[:H, :W] - center_y, center_x = H // 2, W // 2 - dist = np.sqrt((y - center_y)**2 + (x - center_x)**2) - max_dist = np.sqrt(center_y**2 + center_x**2) - - # Add a gradient based on distance from center (simulating field pattern) - for i in range(min(10, expected_features)): - feature_cube[:, :, i] = (1 - dist / max_dist) * 0.5 + feature_cube[:, :, i] * 0.5 - - print(f"[{job_id}] Synthetic feature cube shape: {feature_cube.shape}") - - # ========================================== - # Stage 4: Load Model Artifacts - # ========================================== - update_status(job_id, "running", "load_model", 40, "Loading model artifacts...") - - is_hybrid = "hybrid" in payload['model'].lower() or "spatiotemporal" in payload['model'].lower() - - model_dir = Path(tempfile.mkdtemp()) - - if is_hybrid: - print(f"[{job_id}] Model type: Hybrid Spatio-Temporal. Downloading artifacts...") - # Expected files in MinIO: pipeline_meta.pkl, Temporal_FCN.pth, calibrated_hybrid_cb.pkl + # Download model artifacts for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]: try: storage.download_file(storage.bucket_models, artifact, model_dir / artifact) print(f"[{job_id}] Downloaded {artifact}") except Exception as e: - print(f"[{job_id}] Failed to download {artifact}: {e}") - # Try with 'hybrid/' prefix if direct fails try: storage.download_file(storage.bucket_models, f"hybrid/{artifact}", model_dir / artifact) print(f"[{job_id}] Downloaded {artifact} (from hybrid/ prefix)") except Exception as e2: - raise FileNotFoundError(f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}") + raise FileNotFoundError( + f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}" + ) + + update_status(job_id, "running", "fetch_stac", 30, "Fetching spatio-temporal data...") - # ========================================== - # Stage 5: Fetch Spatio-Temporal Data - # ========================================== - update_status(job_id, "running", "fetch_stac", 50, "Fetching spatio-temporal indices...") from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline stac_wrapper = DEAfricaSTACWrapper() - # Calculate ranges for wrapper lat_range = (bbox[1], bbox[3]) lon_range = (bbox[0], bbox[2]) time_range = (start_date, end_date) + print(f"[{job_id}] Fetching STAC data from DEA...") unseen_pixel_df = stac_wrapper.fetch_and_format_data( lat_range=lat_range, lon_range=lon_range, time_range=time_range ) + print(f"[{job_id}] STAC data fetched: {len(unseen_pixel_df)} pixels") + + # Check if DW is ready while processing STAC + if dw_future.done(): + dw_result = dw_future.result() + if dw_result is not None: + dw_arr, dw_profile = dw_result + import rasterio + dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) + with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: + dst.write(dw_arr) + dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" + storage.upload_result(dw_temp_path, dw_key) + dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key) + update_status( + job_id, "running", "dw_ready", 35, + "Dynamic World baseline ready", + outputs={"dw_baseline_url": dw_baseline_url}, + ) + + update_status(job_id, "running", "infer", 50, "Running Hybrid Inference (CNN + CatBoost)...") + print(f"[{job_id}] Running hybrid inference...") - # ========================================== - # Stage 6: Hybrid Inference - # ========================================== - update_status(job_id, "running", "infer", 70, "Running Hybrid Inference (CNN + CatBoost)...") pipeline = CropInferencePipeline(model_dir=str(model_dir)) mapped_crops_df = pipeline.predict( @@ -480,17 +380,19 @@ def run_job(payload_dict: dict) -> dict: apply_spatial_smoothing=True, coord_cols=['lat', 'lon'] ) + print(f"[{job_id}] Inference complete, exporting results...") # ========================================== - # Stage 7: Export and Upload + # Export and Upload Results # ========================================== - update_status(job_id, "running", "export_cog", 90, "Exporting results...") + update_status(job_id, "running", "export_cog", 80, "Exporting results...") + output_dir = Path(tempfile.mkdtemp()) output_path = output_dir / "refined.tif" pipeline.export_to_geotiff(mapped_crops_df, output_path=str(output_path)) - output_urls = {} + # Upload results for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]: local_f = output_dir / filename if local_f.exists(): @@ -498,44 +400,47 @@ def run_job(payload_dict: dict) -> dict: storage.upload_result(local_f, result_key) output_urls[filename.replace(".","_url")] = storage.presign_get("geocrop-results", result_key) - else: - # Fallback to Legacy/Standard logic - print(f"[{job_id}] Using standard/ensemble inference logic...") - from inference import run_inference_job + # Check DW one more time (may have finished during inference) + if dw_baseline_url is None and dw_future.done(): + dw_result = dw_future.result() + if dw_result is not None: + dw_arr, dw_profile = dw_result + import rasterio + dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) + with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: + dst.write(dw_arr) + dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" + storage.upload_result(dw_temp_path, dw_key) + dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key) - # Create a mock job dict compatible with run_inference_job - job_payload = { - "job_id": job_id, - "lat": payload["lat"], - "lon": payload["lon"], - "radius_m": payload["radius_m"], - "year": payload["year"], - "season": payload["season"], - "model": payload["model"], - "smoothing_kernel": payload["smoothing_kernel"] - } - - inference_result = run_inference_job(cfg, job_payload) - output_urls = inference_result.outputs - - # Note: indices and true_color not yet implemented + # Wait for DW if still running + if dw_baseline_url is None: + print(f"[{job_id}] Waiting for DW baseline to finish...") + dw_result = dw_future.result(timeout=60) + if dw_result is not None: + dw_arr, dw_profile = dw_result + import rasterio + dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) + with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: + dst.write(dw_arr) + dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" + storage.upload_result(dw_temp_path, dw_key) + dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key) + + # ========================================== + # Final Status + # ========================================== + final_outputs = dict(output_urls) + if dw_baseline_url: + final_outputs["dw_baseline_url"] = dw_baseline_url + if payload['outputs'].get('indices'): missing_outputs.append("indices: not implemented") if payload['outputs'].get('true_color'): missing_outputs.append("true_color: not implemented") - # ========================================== - # Stage 8: Final Status - # ========================================== final_status = "partial" if missing_outputs else "done" - final_message = f"Inference complete" - if missing_outputs: - final_message += f" (partial: {', '.join(missing_outputs)})" - - # Include DW baseline URL in final outputs if available - final_outputs = dict(output_urls) - if dw_baseline_url: - final_outputs["dw_baseline_url"] = dw_baseline_url + final_message = f"Inference complete" + (f" ({', '.join(missing_outputs)})" if missing_outputs else "") update_status( job_id, @@ -556,7 +461,6 @@ def run_job(payload_dict: dict) -> dict: } except Exception as e: - # Catch-all for any unexpected errors error_trace = traceback.format_exc() print(f"[{job_id}] Error: {e}") print(error_trace) @@ -573,9 +477,10 @@ def run_job(payload_dict: dict) -> dict: "job_id": job_id, } -# Alias for API + run_inference = run_job + # ========================================== # RQ Worker Entry Point # ========================================== @@ -585,7 +490,6 @@ def start_rq_worker(): from rq import Worker import signal - # Ensure /app is in sys.path so we can import modules if '/app' not in sys.path: sys.path.insert(0, '/app') @@ -594,9 +498,7 @@ def start_rq_worker(): print(f"=== GeoCrop RQ Worker Starting ===") print(f"Listening on queue: {queue_name}") print(f"Redis: {os.getenv('REDIS_HOST', 'redis.geocrop.svc.cluster.local')}:{os.getenv('REDIS_PORT', '6379')}") - print(f"Python path: {sys.path[:3]}") - # Handle graceful shutdown def signal_handler(signum, frame): print("\nReceived shutdown signal, exiting gracefully...") sys.exit(0) @@ -624,22 +526,17 @@ if __name__ == "__main__": args = parser.parse_args() if args.test or not args.worker: - # Syntax-level self-test print("=== GeoCrop Worker Syntax Test ===") - # Test imports try: from contracts import STAGES, VALID_MODELS from storage import MinIOStorage - from feature_computation import FEATURE_ORDER_V1 print(f"โœ“ Imports OK") print(f" STAGES: {STAGES}") print(f" VALID_MODELS: {VALID_MODELS}") - print(f" FEATURE_ORDER length: {len(FEATURE_ORDER_V1)}") except ImportError as e: - print(f"โš  Some imports missing (expected outside container): {e}") + print(f"โš  Some imports missing: {e}") - # Test payload parsing print("\n--- Payload Parsing Test ---") test_payload = { "job_id": "test-123", @@ -647,7 +544,7 @@ if __name__ == "__main__": "lon": 31.0, "radius_m": 2000, "year": 2022, - "model": "Ensemble", + "model": "Hybrid_SpatioTemporal", "smoothing_kernel": 5, "outputs": {"refined": True, "dw_baseline": True}, } @@ -660,18 +557,8 @@ if __name__ == "__main__": print(f" job_id: {validated['job_id']}") print(f" AOI: ({validated['lat']}, {validated['lon']}) radius={validated['radius_m']}m") print(f" model: {validated['model']}") - print(f" kernel: {validated['smoothing_kernel']}") - - # Show what would run - print("\n--- Pipeline Overview ---") - print("Pipeline stages:") - for i, stage in enumerate(STAGES): - print(f" {i+1}. {stage}") - - print("\nNote: This is a syntax-level test.") - print("Full execution requires Redis, MinIO, and STAC access in the container.") print("\n=== Worker Syntax Test Complete ===") if args.worker: - start_rq_worker() + start_rq_worker() \ No newline at end of file