"""GeoCrop Worker - RQ task runner for inference jobs. # CI_BUILD_TRIGGER STEP 9: Real end-to-end pipeline orchestration. This module wires together all the step modules: - contracts.py (validation, payload parsing) - storage.py (MinIO adapter) - stac_client.py (DEA STAC search) - feature_computation.py (51-feature extraction) - dw_baseline.py (windowed DW baseline) - hybrid_inference.py (CNN + CatBoost ensemble inference) - cog.py (COG export) """ from __future__ import annotations import json import os import sys import tempfile import traceback from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List, Optional, Tuple # Redis/RQ for job queue from redis import Redis from rq import Queue import numpy as np # ========================================== # Redis Configuration # ========================================== def _get_redis_conn(): """Create Redis connection, handling both simple and URL formats.""" redis_url = os.getenv("REDIS_URL") if redis_url: # Handle REDIS_URL format (e.g., redis://host:6379) # MUST NOT use decode_responses=True because RQ uses pickle (binary) return Redis.from_url(redis_url) # Handle separate REDIS_HOST and REDIS_PORT redis_host = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local") redis_port_str = os.getenv("REDIS_PORT", "6379") try: redis_port = int(redis_port_str) except ValueError: if "://" in redis_port_str: import urllib.parse parsed = urllib.parse.urlparse(redis_port_str) redis_port = parsed.port or 6379 else: redis_port = 6379 return Redis(host=redis_host, port=redis_port) redis_conn = _get_redis_conn() # ========================================== # Status Update Helpers # ========================================== # ========================================== # Distributed Locking (Idempotency Safeguard) # ========================================== LOCK_TTL = 1800 # 30 minutes - max expected job runtime def acquire_job_lock(job_id: str, timeout: int = LOCK_TTL) -> bool: """Acquire distributed lock for job processing. Prevents worker retry collisions - only one worker can process a job. Uses Redis SETNX with TTL for safe lock acquisition. Returns: True if lock acquired, False if already locked by another worker """ lock_key = f"lock:job:{job_id}" acquired = redis_conn.set(lock_key, "1", nx=True, ex=timeout) return bool(acquired) def release_job_lock(job_id: str) -> None: """Release distributed lock for job.""" lock_key = f"lock:job:{job_id}" redis_conn.delete(lock_key) def is_job_locked(job_id: str) -> bool: """Check if job is currently locked by another worker.""" lock_key = f"lock:job:{job_id}" return redis_conn.exists(lock_key) > 0 # ========================================== # Status Update Helpers (Atomic) # ========================================== def safe_now_iso() -> str: """Get current UTC time as ISO string.""" return datetime.now(timezone.utc).isoformat() def update_status( job_id: str, status: str, stage: str, progress: int, message: str, outputs: Optional[Dict] = None, error: Optional[Dict] = None, ) -> None: """Update job status in Redis atomically. Uses Redis pipeline to update both the status hash and RQ job meta in a single atomic operation. """ key = f"job:{job_id}:status" status_data = { "status": status, "stage": stage, "progress": progress, "message": message, "updated_at": safe_now_iso(), } if outputs: status_data["outputs"] = outputs if error: status_data["error"] = error try: pipe = redis_conn.pipeline() pipe.set(key, json.dumps(status_data), ex=86400) from rq import get_current_job job = get_current_job() if job: job.meta['progress'] = progress job.meta['stage'] = stage job.meta['status_message'] = message job.save_meta() pipe.execute() except Exception as e: print(f"Warning: Failed to update Redis status: {e}") def check_existing_outputs(job_id: str, storage) -> Optional[Dict[str, str]]: """Check if job outputs already exist in MinIO (skip reprocessing). Args: job_id: Job ID storage: MinIOStorage instance Returns: Dict of output URLs if all expected outputs exist, None otherwise. """ expected_files = [ "refined_url", "refined_confidence_url", "refined_cloud_mask_url", "refined_legend_url", ] outputs = {} for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]: key = f"results/{job_id}/{filename}" try: if storage.head_object(storage.bucket_results, key): url_key = filename.replace(".", "_url") outputs[url_key] = storage.presign_get(storage.bucket_results, key) except Exception: pass return outputs if len(outputs) >= 3 else None def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func): """Check if DW baseline is ready and send to client.""" if dw_future is None: return None if dw_future.done(): try: dw_result = dw_future.result() if dw_result is not None: dw_arr, dw_profile = dw_result # Save to temp file import rasterio dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: dst.write(dw_arr) # Upload to MinIO dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" storage.upload_result(dw_temp_path, dw_key) # Generate presigned URL dw_url = storage.presign_get("geocrop-baselines", dw_key) print(f"[{job_id}] DW baseline URL ready: {dw_url[:80]}...") # Notify client update_func( job_id, "running", "dw_ready", 30, "Dynamic World baseline ready", outputs={"dw_baseline_url": dw_url}, ) return dw_url except Exception as e: print(f"[{job_id}] DW baseline processing failed: {e}") return None # ========================================== # Payload Validation # ========================================== def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]: """Parse and validate job payload.""" errors = [] required = ["job_id", "lat", "lon", "radius_m", "year"] for field in required: if field not in payload: errors.append(f"Missing required field: {field}") if "lat" in payload and "lon" in payload: lat = float(payload["lat"]) lon = float(payload["lon"]) if not (-22.5 <= lat <= -15.6): errors.append(f"Latitude {lat} outside Zimbabwe bounds") if not (25.2 <= lon <= 33.1): errors.append(f"Longitude {lon} outside Zimbabwe bounds") if "radius_m" in payload: radius = int(payload["radius_m"]) if radius > 5000: errors.append(f"Radius {radius}m exceeds max 5000m") if radius < 100: errors.append(f"Radius {radius}m below min 100m") if "year" in payload: year = int(payload["year"]) current_year = datetime.now().year if year < 2015 or year > current_year: errors.append(f"Year {year} outside valid range (2015-{current_year})") if "model" in payload: from contracts import VALID_MODELS if payload["model"] not in VALID_MODELS: errors.append(f"Invalid model: {payload['model']}. Must be one of {VALID_MODELS}") if "smoothing_kernel" in payload: kernel = int(payload["smoothing_kernel"]) if kernel not in [3, 5, 7]: errors.append(f"Invalid smoothing_kernel: {kernel}. Must be 3, 5, or 7") validated = { "job_id": payload.get("job_id", "unknown"), "lat": float(payload.get("lat", 0)), "lon": float(payload.get("lon", 0)), "radius_m": int(payload.get("radius_m", 2000)), "year": int(payload.get("year", 2022)), "season": payload.get("season", "summer"), "model": payload.get("model", "Ensemble"), "smoothing_kernel": int(payload.get("smoothing_kernel", 5)), "outputs": { "refined": payload.get("outputs", {}).get("refined", True), "dw_baseline": payload.get("outputs", {}).get("dw_baseline", False), "true_color": payload.get("outputs", {}).get("true_color", False), "indices": payload.get("outputs", {}).get("indices", []), }, } return validated, errors # ========================================== # Async DW Loading Helper # ========================================== def _load_dw_async(storage, bbox, year, season) -> Optional[Tuple[np.ndarray, dict]]: """Async wrapper for DW baseline loading.""" from dw_baseline import load_dw_baseline_window try: dw_arr, dw_profile = load_dw_baseline_window( storage=storage, aoi_bbox_wgs84=bbox, year=year, season=season, ) print(f"[_dw_load] DW baseline loaded: shape={dw_arr.shape}") return dw_arr, dw_profile except Exception as e: print(f"[_dw_load] DW baseline failed: {e}") return None # ========================================== # Main Job Runner (Async) # ========================================== def run_job(payload_dict: dict) -> dict: """Main job runner with idempotency safeguards and async DW baseline loading. Safeguards: - Distributed locking prevents worker retry collisions - Checks existing outputs to skip reprocessing - DW baseline loads in background while hybrid inference runs DW baseline loads in background while hybrid inference runs. DW URL is sent to client as soon as it's ready, parallel to inference. """ from rq import get_current_job current_job = get_current_job() job_id = payload_dict.get("job_id") if not job_id and current_job: job_id = current_job.id if not job_id: job_id = "unknown" payload_dict["job_id"] = job_id if "radius_km" in payload_dict and "radius_m" not in payload_dict: payload_dict["radius_m"] = int(float(payload_dict["radius_km"]) * 1000) if "model_name" in payload_dict and "model" not in payload_dict: payload_dict["model"] = payload_dict["model_name"] if "season" not in payload_dict: payload_dict["season"] = "summer" if not acquire_job_lock(job_id): return { "status": "skipped", "job_id": job_id, "message": "Job already being processed by another worker" } try: try: return _execute_inference_job(job_id, payload_dict) except Exception as e: error_trace = traceback.format_exc() print(f"[{job_id}] Worker crashed: {e}") print(error_trace) update_status( job_id, "failed", "error", 0, f"Worker crashed: {str(e)}", error={"type": type(e).__name__, "message": str(e), "trace": error_trace} ) return {"status": "failed", "error": str(e), "job_id": job_id} finally: release_job_lock(job_id) def _execute_inference_job(job_id: str, payload_dict: dict) -> dict: # Initialize storage try: from storage import MinIOStorage storage = MinIOStorage() except Exception as e: update_status( job_id, "failed", "init", 0, f"Failed to initialize storage: {e}", error={"type": "StorageError", "message": str(e)} ) return {"status": "failed", "error": str(e)} payload, errors = parse_and_validate_payload(payload_dict) if errors: update_status( job_id, "failed", "validation", 0, f"Validation failed: {errors}", error={"type": "ValidationError", "message": "; ".join(errors)} ) return {"status": "failed", "errors": errors} existing_outputs = check_existing_outputs(job_id, storage) if existing_outputs: update_status( job_id, "complete", "cached", 100, "Results already exist, skipping reprocessing", outputs=existing_outputs ) print(f"[{job_id}] Found existing outputs in MinIO, skipping inference") return { "status": "cached", "job_id": job_id, "outputs": existing_outputs, "cached": True } update_status(job_id, "running", "init", 5, "Starting inference pipeline...") dw_baseline_url = None output_urls = {} missing_outputs = [] try: # Get config and AOI bbox from config import InferenceConfig, MinIOStorage as ConfigMinIO cfg = InferenceConfig() cfg.storage = ConfigMinIO() start_date, end_date = cfg.season_dates(payload['year'], payload['season']) lat, lon, radius = payload['lat'], payload['lon'], payload['radius_m'] radius_deg = radius / 111000 bbox = [lon - radius_deg, lat - radius_deg, lon + radius_deg, lat + radius_deg] # ========================================== # Start DW baseline loading in background # ========================================== update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline (async)...") print(f"[{job_id}] Starting async DW baseline load...") with ThreadPoolExecutor(max_workers=1) as dw_executor: dw_future = dw_executor.submit( _load_dw_async, storage, bbox, payload['year'], payload['season'] ) # ========================================== # Start hybrid inference immediately (in parallel) # ========================================== update_status(job_id, "running", "load_model", 20, "Loading model artifacts...") # Use persistent model cache model_dir = Path("/app/model_cache") model_dir.mkdir(parents=True, exist_ok=True) print(f"[{job_id}] Model cache directory: {model_dir}") # Download model artifacts if missing for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]: target_path = model_dir / artifact if target_path.exists(): print(f"[{job_id}] Found {artifact} in cache, skipping download") continue print(f"[{job_id}] Downloading {artifact} to cache...") try: storage.download_file(storage.bucket_models, artifact, target_path) print(f"[{job_id}] Downloaded {artifact}") except Exception as e: try: storage.download_file(storage.bucket_models, f"models/{artifact}", target_path) print(f"[{job_id}] Downloaded {artifact} (from models/ prefix)") except Exception as e2: # Clean up failed download to prevent corrupted cache if target_path.exists(): target_path.unlink() raise FileNotFoundError( f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}" ) update_status(job_id, "running", "fetch_stac", 30, "Fetching spatio-temporal data...") from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline stac_wrapper = DEAfricaSTACWrapper() lat_range = (bbox[1], bbox[3]) lon_range = (bbox[0], bbox[2]) time_range = (start_date, end_date) print(f"[{job_id}] Fetching STAC data from DEA...") unseen_pixel_df = stac_wrapper.fetch_and_format_data( lat_range=lat_range, lon_range=lon_range, time_range=time_range ) print(f"[{job_id}] STAC data fetched: {len(unseen_pixel_df)} pixels") # Check if DW is ready while processing STAC if dw_future.done(): dw_result = dw_future.result() if dw_result is not None: dw_arr, dw_profile = dw_result import rasterio dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: dst.write(dw_arr) dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" storage.upload_result(dw_temp_path, dw_key) dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key) update_status( job_id, "running", "dw_ready", 35, "Dynamic World baseline ready", outputs={"dw_baseline_url": dw_baseline_url}, ) update_status(job_id, "running", "infer", 50, "Running Hybrid Inference (CNN + CatBoost)...") print(f"[{job_id}] Running hybrid inference...") pipeline = CropInferencePipeline(model_dir=str(model_dir)) mapped_crops_df = pipeline.predict( unseen_pixel_df, apply_spatial_smoothing=True, coord_cols=['lat', 'lon'] ) print(f"[{job_id}] Inference complete, exporting results...") # ========================================== # Export and Upload Results # ========================================== update_status(job_id, "running", "export_cog", 80, "Exporting results...") output_dir = Path(tempfile.mkdtemp()) output_path = output_dir / "refined.tif" pipeline.export_to_geotiff(mapped_crops_df, output_path=str(output_path)) # Upload results for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]: local_f = output_dir / filename if local_f.exists(): result_key = f"results/{job_id}/{filename}" storage.upload_result(local_f, result_key) output_urls[filename.replace(".","_url")] = storage.presign_get("geocrop-results", result_key) # Check DW one more time (may have finished during inference) if dw_baseline_url is None and dw_future.done(): dw_result = dw_future.result() if dw_result is not None: dw_arr, dw_profile = dw_result import rasterio dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: dst.write(dw_arr) dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" storage.upload_result(dw_temp_path, dw_key) dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key) # DW is optional - if it hasn't finished yet and inference is done, continue without it # The dw_baseline_url is a nice-to-have, not a blocking requirement if dw_baseline_url is None: print(f"[{job_id}] DW baseline not ready after inference, continuing without it") # ========================================== # Final Status # ========================================== final_outputs = dict(output_urls) if dw_baseline_url: final_outputs["dw_baseline_url"] = dw_baseline_url if payload['outputs'].get('indices'): missing_outputs.append("indices: not implemented") if payload['outputs'].get('true_color'): missing_outputs.append("true_color: not implemented") final_status = "partial" if missing_outputs else "done" final_message = f"Inference complete" + (f" ({', '.join(missing_outputs)})" if missing_outputs else "") update_status( job_id, final_status, "done", 100, final_message, outputs=final_outputs if final_outputs else None, ) print(f"[{job_id}] Job complete: {final_status}") return { "status": final_status, "job_id": job_id, "outputs": final_outputs, "missing": missing_outputs if missing_outputs else None, } except Exception as e: error_trace = traceback.format_exc() print(f"[{job_id}] Error: {e}") print(error_trace) update_status( job_id, "failed", "error", 0, f"Unexpected error: {e}", error={"type": type(e).__name__, "message": str(e), "trace": error_trace} ) return { "status": "failed", "error": str(e), "job_id": job_id, } run_inference = run_job # ========================================== # RQ Worker Entry Point # ========================================== def start_rq_worker(): """Start the RQ worker to listen for jobs on the geocrop_tasks queue.""" from rq import Worker import signal if '/app' not in sys.path: sys.path.insert(0, '/app') queue_name = os.getenv("RQ_QUEUE_NAME", "geocrop_tasks") print(f"=== GeoCrop RQ Worker Starting ===") print(f"Listening on queue: {queue_name}") print(f"Redis: {os.getenv('REDIS_HOST', 'redis.geocrop.svc.cluster.local')}:{os.getenv('REDIS_PORT', '6379')}") def signal_handler(signum, frame): print("\nReceived shutdown signal, exiting gracefully...") sys.exit(0) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: q = Queue(queue_name, connection=redis_conn) w = Worker([q], connection=redis_conn) w.work() except KeyboardInterrupt: print("\nWorker interrupted, shutting down...") except Exception as e: print(f"Worker error: {e}") raise if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="GeoCrop Worker") parser.add_argument("--test", action="store_true", help="Run syntax test only") parser.add_argument("--worker", action="store_true", help="Start RQ worker") args = parser.parse_args() if args.test or not args.worker: print("=== GeoCrop Worker Syntax Test ===") try: from contracts import STAGES, VALID_MODELS from storage import MinIOStorage print(f"✓ Imports OK") print(f" STAGES: {STAGES}") print(f" VALID_MODELS: {VALID_MODELS}") except ImportError as e: print(f"⚠ Some imports missing: {e}") print("\n--- Payload Parsing Test ---") test_payload = { "job_id": "test-123", "lat": -17.8, "lon": 31.0, "radius_m": 2000, "year": 2022, "model": "Hybrid_SpatioTemporal", "smoothing_kernel": 5, "outputs": {"refined": True, "dw_baseline": True}, } validated, errors = parse_and_validate_payload(test_payload) if errors: print(f"✗ Validation errors: {errors}") else: print(f"✓ Payload validation passed") print(f" job_id: {validated['job_id']}") print(f" AOI: ({validated['lat']}, {validated['lon']}) radius={validated['radius_m']}m") print(f" model: {validated['model']}") print("\n=== Worker Syntax Test Complete ===") if args.worker: start_rq_worker()