"""GeoCrop Worker - RQ task runner for inference jobs. STEP 9: Real end-to-end pipeline orchestration. This module wires together all the step modules: - contracts.py (validation, payload parsing) - storage.py (MinIO adapter) - stac_client.py (DEA STAC search) - feature_computation.py (51-feature extraction) - dw_baseline.py (windowed DW baseline) - inference.py (model loading + prediction) - postprocess.py (majority filter smoothing) - cog.py (COG export) """ from __future__ import annotations import json import os import sys import tempfile import traceback from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List, Optional # Redis/RQ for job queue from redis import Redis from rq import Queue # ========================================== # Redis Configuration # ========================================== def _get_redis_conn(): """Create Redis connection, handling both simple and URL formats.""" redis_url = os.getenv("REDIS_URL") if redis_url: # Handle REDIS_URL format (e.g., redis://host:6379) # MUST NOT use decode_responses=True because RQ uses pickle (binary) return Redis.from_url(redis_url) # Handle separate REDIS_HOST and REDIS_PORT redis_host = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local") redis_port_str = os.getenv("REDIS_PORT", "6379") # Handle case where REDIS_PORT might be a full URL try: redis_port = int(redis_port_str) except ValueError: # If it's a URL, extract the port if "://" in redis_port_str: import urllib.parse parsed = urllib.parse.urlparse(redis_port_str) redis_port = parsed.port or 6379 else: redis_port = 6379 # MUST NOT use decode_responses=True because RQ uses pickle (binary) return Redis(host=redis_host, port=redis_port) redis_conn = _get_redis_conn() # ========================================== # Status Update Helpers # ========================================== def safe_now_iso() -> str: """Get current UTC time as ISO string.""" return datetime.now(timezone.utc).isoformat() def update_status( job_id: str, status: str, stage: str, progress: int, message: str, outputs: Optional[Dict] = None, error: Optional[Dict] = None, ) -> None: """Update job status in Redis. Args: job_id: Job identifier status: Overall status (queued, running, failed, done) stage: Current pipeline stage progress: Progress percentage (0-100) message: Human-readable message outputs: Output file URLs (when done) error: Error details (on failure) """ key = f"job:{job_id}:status" status_data = { "status": status, "stage": stage, "progress": progress, "message": message, "updated_at": safe_now_iso(), } if outputs: status_data["outputs"] = outputs if error: status_data["error"] = error try: redis_conn.set(key, json.dumps(status_data), ex=86400) # 24h expiry # Also update the job metadata in RQ if possible from rq import get_current_job job = get_current_job() if job: job.meta['progress'] = progress job.meta['stage'] = stage job.meta['status_message'] = message job.save_meta() except Exception as e: print(f"Warning: Failed to update Redis status: {e}") # ========================================== # Payload Validation # ========================================== def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]: """Parse and validate job payload. Args: payload: Raw job payload dict Returns: Tuple of (validated_payload, list_of_errors) """ errors = [] # Required fields required = ["job_id", "lat", "lon", "radius_m", "year"] for field in required: if field not in payload: errors.append(f"Missing required field: {field}") # Validate AOI if "lat" in payload and "lon" in payload: lat = float(payload["lat"]) lon = float(payload["lon"]) # Zimbabwe bounds check if not (-22.5 <= lat <= -15.6): errors.append(f"Latitude {lat} outside Zimbabwe bounds") if not (25.2 <= lon <= 33.1): errors.append(f"Longitude {lon} outside Zimbabwe bounds") # Validate radius if "radius_m" in payload: radius = int(payload["radius_m"]) if radius > 5000: errors.append(f"Radius {radius}m exceeds max 5000m") if radius < 100: errors.append(f"Radius {radius}m below min 100m") # Validate year if "year" in payload: year = int(payload["year"]) current_year = datetime.now().year if year < 2015 or year > current_year: errors.append(f"Year {year} outside valid range (2015-{current_year})") # Validate model if "model" in payload: from contracts import VALID_MODELS if payload["model"] not in VALID_MODELS: errors.append(f"Invalid model: {payload['model']}. Must be one of {VALID_MODELS}") # Validate kernel if "smoothing_kernel" in payload: kernel = int(payload["smoothing_kernel"]) if kernel not in [3, 5, 7]: errors.append(f"Invalid smoothing_kernel: {kernel}. Must be 3, 5, or 7") # Set defaults validated = { "job_id": payload.get("job_id", "unknown"), "lat": float(payload.get("lat", 0)), "lon": float(payload.get("lon", 0)), "radius_m": int(payload.get("radius_m", 2000)), "year": int(payload.get("year", 2022)), "season": payload.get("season", "summer"), "model": payload.get("model", "Ensemble"), "smoothing_kernel": int(payload.get("smoothing_kernel", 5)), "outputs": { "refined": payload.get("outputs", {}).get("refined", True), "dw_baseline": payload.get("outputs", {}).get("dw_baseline", False), "true_color": payload.get("outputs", {}).get("true_color", False), "indices": payload.get("outputs", {}).get("indices", []), }, } return validated, errors # ========================================== # Main Job Runner # ========================================== def run_job(payload_dict: dict) -> dict: """Main job runner function. This is the RQ task function that orchestrates the full pipeline. """ from rq import get_current_job current_job = get_current_job() # Extract job_id from payload or RQ job_id = payload_dict.get("job_id") if not job_id and current_job: job_id = current_job.id if not job_id: job_id = "unknown" # Ensure job_id is in payload for validation payload_dict["job_id"] = job_id # Standardize payload from API format to worker format # API sends: radius_km, model_name # Worker expects: radius_m, model if "radius_km" in payload_dict and "radius_m" not in payload_dict: payload_dict["radius_m"] = int(float(payload_dict["radius_km"]) * 1000) if "model_name" in payload_dict and "model" not in payload_dict: payload_dict["model"] = payload_dict["model_name"] # Initialize storage try: from storage import MinIOStorage storage = MinIOStorage() except Exception as e: update_status( job_id, "failed", "init", 0, f"Failed to initialize storage: {e}", error={"type": "StorageError", "message": str(e)} ) return {"status": "failed", "error": str(e)} # Parse and validate payload payload, errors = parse_and_validate_payload(payload_dict) if errors: update_status( job_id, "failed", "validation", 0, f"Validation failed: {errors}", error={"type": "ValidationError", "message": "; ".join(errors)} ) return {"status": "failed", "errors": errors} # Update initial status update_status(job_id, "running", "fetch_stac", 5, "Fetching STAC items...") missing_outputs = [] output_urls = {} try: # ========================================== # Stage 1: Fetch STAC # ========================================== print(f"[{job_id}] Fetching STAC items for {payload['year']} {payload['season']}...") from stac_client import DEASTACClient from config import InferenceConfig, MinIOStorage as ConfigMinIO cfg = InferenceConfig() # Initialize storage adapter for inference.py cfg.storage = ConfigMinIO() # Get season dates start_date, end_date = cfg.season_dates(payload['year'], payload['season']) # Calculate AOI bbox lat, lon, radius = payload['lat'], payload['lon'], payload['radius_m'] # Rough bbox from radius (in degrees) radius_deg = radius / 111000 # ~111km per degree bbox = [ lon - radius_deg, # min_lon lat - radius_deg, # min_lat lon + radius_deg, # max_lon lat + radius_deg, # max_lat ] # Search STAC stac_client = DEASTACClient() try: items = stac_client.search_items( bbox=bbox, start_date=start_date, end_date=end_date, ) print(f"[{job_id}] Found {len(items)} STAC items") except Exception as e: print(f"[{job_id}] STAC search failed: {e}") # Continue but note that features may be limited update_status(job_id, "running", "build_features", 20, "Building feature cube...") # ========================================== # Stage 2: Build Feature Cube # ========================================== print(f"[{job_id}] Building feature cube...") from feature_computation import FEATURE_ORDER_V1 feature_order = FEATURE_ORDER_V1 expected_features = len(feature_order) # Should be 51 print(f"[{job_id}] Expected {expected_features} features (FEATURE_ORDER_V1)") # Check if we have an existing feature builder in features.py feature_cube = None use_synthetic = False try: from features import build_feature_stack_from_dea print(f"[{job_id}] Trying build_feature_stack_from_dea for feature extraction...") # Try to call it - this requires stackstac and DEA STAC access try: feature_cube = build_feature_stack_from_dea( items=items, bbox=bbox, start_date=start_date, end_date=end_date, ) print(f"[{job_id}] Feature cube built successfully: {feature_cube.shape if feature_cube is not None else 'None'}") except Exception as e: print(f"[{job_id}] Feature stack building failed: {e}") print(f"[{job_id}] Falling back to synthetic features for testing") use_synthetic = True except ImportError as e: print(f"[{job_id}] Feature builder not available: {e}") print(f"[{job_id}] Using synthetic features for testing") use_synthetic = True # Generate synthetic features for testing when real data isn't available if feature_cube is None: print(f"[{job_id}] Generating synthetic features for pipeline test...") # Determine raster dimensions from DW baseline if loaded if 'dw_arr' in dir() and dw_arr is not None: H, W = dw_arr.shape else: # Default size for testing H, W = 100, 100 # Generate synthetic features: shape (H, W, 51) import numpy as np # Use year as seed for reproducible but varied features np.random.seed(payload['year'] + int(payload.get('lon', 0) * 100) + int(payload.get('lat', 0) * 100)) # Generate realistic-looking features (normalized values) feature_cube = np.random.rand(H, W, expected_features).astype(np.float32) # Add some structure - make center pixels different from edges y, x = np.ogrid[:H, :W] center_y, center_x = H // 2, W // 2 dist = np.sqrt((y - center_y)**2 + (x - center_x)**2) max_dist = np.sqrt(center_y**2 + center_x**2) # Add a gradient based on distance from center (simulating field pattern) for i in range(min(10, expected_features)): feature_cube[:, :, i] = (1 - dist / max_dist) * 0.5 + feature_cube[:, :, i] * 0.5 print(f"[{job_id}] Synthetic feature cube shape: {feature_cube.shape}") # ========================================== # Stage 3: Load Model Artifacts # ========================================== update_status(job_id, "running", "load_model", 40, "Loading model artifacts...") is_hybrid = "hybrid" in payload['model'].lower() or "spatiotemporal" in payload['model'].lower() model_dir = Path(tempfile.mkdtemp()) if is_hybrid: print(f"[{job_id}] Model type: Hybrid Spatio-Temporal. Downloading artifacts...") # Expected files in MinIO: pipeline_meta.pkl, Temporal_FCN.pth, calibrated_hybrid_cb.pkl for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]: try: storage.download_file(storage.bucket_models, artifact, model_dir / artifact) print(f"[{job_id}] Downloaded {artifact}") except Exception as e: print(f"[{job_id}] Failed to download {artifact}: {e}") # Try with 'hybrid/' prefix if direct fails try: storage.download_file(storage.bucket_models, f"hybrid/{artifact}", model_dir / artifact) print(f"[{job_id}] Downloaded {artifact} (from hybrid/ prefix)") except Exception as e2: raise FileNotFoundError(f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}") # ========================================== # Stage 4: Fetch Spatio-Temporal Data # ========================================== update_status(job_id, "running", "fetch_stac", 50, "Fetching spatio-temporal indices...") from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline stac_wrapper = DEAfricaSTACWrapper() # Calculate ranges for wrapper lat_range = (bbox[1], bbox[3]) lon_range = (bbox[0], bbox[2]) time_range = (start_date, end_date) unseen_pixel_df = stac_wrapper.fetch_and_format_data( lat_range=lat_range, lon_range=lon_range, time_range=time_range ) # ========================================== # Stage 5: Hybrid Inference # ========================================== update_status(job_id, "running", "infer", 70, "Running Hybrid Inference (CNN + CatBoost)...") pipeline = CropInferencePipeline(model_dir=str(model_dir)) mapped_crops_df = pipeline.predict( unseen_pixel_df, apply_spatial_smoothing=True, coord_cols=['lat', 'lon'] ) # ========================================== # Stage 6: Export and Upload # ========================================== update_status(job_id, "running", "export_cog", 90, "Exporting results...") output_dir = Path(tempfile.mkdtemp()) output_path = output_dir / "refined.tif" pipeline.export_to_geotiff(mapped_crops_df, output_path=str(output_path)) output_urls = {} for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]: local_f = output_dir / filename if local_f.exists(): result_key = f"results/{job_id}/{filename}" storage.upload_result(local_f, result_key) output_urls[filename.replace(".","_url")] = storage.presign_get("geocrop-results", result_key) else: # Fallback to Legacy/Standard logic print(f"[{job_id}] Using standard/ensemble inference logic...") from inference import run_inference_job # Create a mock job dict compatible with run_inference_job job_payload = { "job_id": job_id, "lat": payload["lat"], "lon": payload["lon"], "radius_m": payload["radius_m"], "year": payload["year"], "season": payload["season"], "model": payload["model"], "smoothing_kernel": payload["smoothing_kernel"] } inference_result = run_inference_job(cfg, job_payload) output_urls = inference_result.outputs # Note: indices and true_color not yet implemented if payload['outputs'].get('indices'): missing_outputs.append("indices: not implemented") if payload['outputs'].get('true_color'): missing_outputs.append("true_color: not implemented") # ========================================== # Stage 7: Final Status # ========================================== final_status = "partial" if missing_outputs else "done" final_message = f"Inference complete" if missing_outputs: final_message += f" (partial: {', '.join(missing_outputs)})" update_status( job_id, final_status, "done", 100, final_message, outputs=output_urls, ) print(f"[{job_id}] Job complete: {final_status}") return { "status": final_status, "job_id": job_id, "outputs": output_urls, "missing": missing_outputs if missing_outputs else None, } except Exception as e: # Catch-all for any unexpected errors error_trace = traceback.format_exc() print(f"[{job_id}] Error: {e}") print(error_trace) update_status( job_id, "failed", "error", 0, f"Unexpected error: {e}", error={"type": type(e).__name__, "message": str(e), "trace": error_trace} ) return { "status": "failed", "error": str(e), "job_id": job_id, } # Alias for API run_inference = run_job # ========================================== # RQ Worker Entry Point # ========================================== def start_rq_worker(): """Start the RQ worker to listen for jobs on the geocrop_tasks queue.""" from rq import Worker import signal # Ensure /app is in sys.path so we can import modules if '/app' not in sys.path: sys.path.insert(0, '/app') queue_name = os.getenv("RQ_QUEUE_NAME", "geocrop_tasks") print(f"=== GeoCrop RQ Worker Starting ===") print(f"Listening on queue: {queue_name}") print(f"Redis: {os.getenv('REDIS_HOST', 'redis.geocrop.svc.cluster.local')}:{os.getenv('REDIS_PORT', '6379')}") print(f"Python path: {sys.path[:3]}") # Handle graceful shutdown def signal_handler(signum, frame): print("\nReceived shutdown signal, exiting gracefully...") sys.exit(0) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: q = Queue(queue_name, connection=redis_conn) w = Worker([q], connection=redis_conn) w.work() except KeyboardInterrupt: print("\nWorker interrupted, shutting down...") except Exception as e: print(f"Worker error: {e}") raise if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="GeoCrop Worker") parser.add_argument("--test", action="store_true", help="Run syntax test only") parser.add_argument("--worker", action="store_true", help="Start RQ worker") args = parser.parse_args() if args.test or not args.worker: # Syntax-level self-test print("=== GeoCrop Worker Syntax Test ===") # Test imports try: from contracts import STAGES, VALID_MODELS from storage import MinIOStorage from feature_computation import FEATURE_ORDER_V1 print(f"✓ Imports OK") print(f" STAGES: {STAGES}") print(f" VALID_MODELS: {VALID_MODELS}") print(f" FEATURE_ORDER length: {len(FEATURE_ORDER_V1)}") except ImportError as e: print(f"⚠ Some imports missing (expected outside container): {e}") # Test payload parsing print("\n--- Payload Parsing Test ---") test_payload = { "job_id": "test-123", "lat": -17.8, "lon": 31.0, "radius_m": 2000, "year": 2022, "model": "Ensemble", "smoothing_kernel": 5, "outputs": {"refined": True, "dw_baseline": True}, } validated, errors = parse_and_validate_payload(test_payload) if errors: print(f"✗ Validation errors: {errors}") else: print(f"✓ Payload validation passed") print(f" job_id: {validated['job_id']}") print(f" AOI: ({validated['lat']}, {validated['lon']}) radius={validated['radius_m']}m") print(f" model: {validated['model']}") print(f" kernel: {validated['smoothing_kernel']}") # Show what would run print("\n--- Pipeline Overview ---") print("Pipeline stages:") for i, stage in enumerate(STAGES): print(f" {i+1}. {stage}") print("\nNote: This is a syntax-level test.") print("Full execution requires Redis, MinIO, and STAC access in the container.") print("\n=== Worker Syntax Test Complete ===") if args.worker: start_rq_worker()