"""GeoCrop Worker - RQ task runner for inference jobs. # CI_BUILD_TRIGGER STEP 9: Real end-to-end pipeline orchestration. This module wires together all the step modules: - contracts.py (validation, payload parsing) - storage.py (MinIO adapter) - stac_client.py (DEA STAC search) - feature_computation.py (51-feature extraction) - dw_baseline.py (windowed DW baseline) - hybrid_inference.py (CNN + CatBoost ensemble inference) - cog.py (COG export) """ from __future__ import annotations import json import os import sys import tempfile import traceback from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List, Optional, Tuple # Redis/RQ for job queue from redis import Redis from rq import Queue import numpy as np # ========================================== # Redis Configuration # ========================================== def _get_redis_conn(): """Create Redis connection, handling both simple and URL formats.""" redis_url = os.getenv("REDIS_URL") if redis_url: # Handle REDIS_URL format (e.g., redis://host:6379) # MUST NOT use decode_responses=True because RQ uses pickle (binary) return Redis.from_url(redis_url) # Handle separate REDIS_HOST and REDIS_PORT redis_host = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local") redis_port_str = os.getenv("REDIS_PORT", "6379") try: redis_port = int(redis_port_str) except ValueError: if "://" in redis_port_str: import urllib.parse parsed = urllib.parse.urlparse(redis_port_str) redis_port = parsed.port or 6379 else: redis_port = 6379 return Redis(host=redis_host, port=redis_port) redis_conn = _get_redis_conn() # ========================================== # Status Update Helpers # ========================================== def safe_now_iso() -> str: """Get current UTC time as ISO string.""" return datetime.now(timezone.utc).isoformat() def update_status( job_id: str, status: str, stage: str, progress: int, message: str, outputs: Optional[Dict] = None, error: Optional[Dict] = None, ) -> None: """Update job status in Redis.""" key = f"job:{job_id}:status" status_data = { "status": status, "stage": stage, "progress": progress, "message": message, "updated_at": safe_now_iso(), } if outputs: status_data["outputs"] = outputs if error: status_data["error"] = error try: redis_conn.set(key, json.dumps(status_data), ex=86400) from rq import get_current_job job = get_current_job() if job: job.meta['progress'] = progress job.meta['stage'] = stage job.meta['status_message'] = message job.save_meta() except Exception as e: print(f"Warning: Failed to update Redis status: {e}") def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func): """Check if DW baseline is ready and send to client.""" if dw_future is None: return None if dw_future.done(): try: dw_result = dw_future.result() if dw_result is not None: dw_arr, dw_profile = dw_result # Save to temp file import rasterio dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: dst.write(dw_arr) # Upload to MinIO dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" storage.upload_result(dw_temp_path, dw_key) # Generate presigned URL dw_url = storage.presign_get("geocrop-baselines", dw_key) print(f"[{job_id}] DW baseline URL ready: {dw_url[:80]}...") # Notify client update_func( job_id, "running", "dw_ready", 30, "Dynamic World baseline ready", outputs={"dw_baseline_url": dw_url}, ) return dw_url except Exception as e: print(f"[{job_id}] DW baseline processing failed: {e}") return None # ========================================== # Payload Validation # ========================================== def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]: """Parse and validate job payload.""" errors = [] required = ["job_id", "lat", "lon", "radius_m", "year"] for field in required: if field not in payload: errors.append(f"Missing required field: {field}") if "lat" in payload and "lon" in payload: lat = float(payload["lat"]) lon = float(payload["lon"]) if not (-22.5 <= lat <= -15.6): errors.append(f"Latitude {lat} outside Zimbabwe bounds") if not (25.2 <= lon <= 33.1): errors.append(f"Longitude {lon} outside Zimbabwe bounds") if "radius_m" in payload: radius = int(payload["radius_m"]) if radius > 5000: errors.append(f"Radius {radius}m exceeds max 5000m") if radius < 100: errors.append(f"Radius {radius}m below min 100m") if "year" in payload: year = int(payload["year"]) current_year = datetime.now().year if year < 2015 or year > current_year: errors.append(f"Year {year} outside valid range (2015-{current_year})") if "model" in payload: from contracts import VALID_MODELS if payload["model"] not in VALID_MODELS: errors.append(f"Invalid model: {payload['model']}. Must be one of {VALID_MODELS}") if "smoothing_kernel" in payload: kernel = int(payload["smoothing_kernel"]) if kernel not in [3, 5, 7]: errors.append(f"Invalid smoothing_kernel: {kernel}. Must be 3, 5, or 7") validated = { "job_id": payload.get("job_id", "unknown"), "lat": float(payload.get("lat", 0)), "lon": float(payload.get("lon", 0)), "radius_m": int(payload.get("radius_m", 2000)), "year": int(payload.get("year", 2022)), "season": payload.get("season", "summer"), "model": payload.get("model", "Ensemble"), "smoothing_kernel": int(payload.get("smoothing_kernel", 5)), "outputs": { "refined": payload.get("outputs", {}).get("refined", True), "dw_baseline": payload.get("outputs", {}).get("dw_baseline", False), "true_color": payload.get("outputs", {}).get("true_color", False), "indices": payload.get("outputs", {}).get("indices", []), }, } return validated, errors # ========================================== # Async DW Loading Helper # ========================================== def _load_dw_async(storage, bbox, year, season) -> Optional[Tuple[np.ndarray, dict]]: """Async wrapper for DW baseline loading.""" from dw_baseline import load_dw_baseline_window try: dw_arr, dw_profile = load_dw_baseline_window( storage=storage, aoi_bbox_wgs84=bbox, year=year, season=season, ) print(f"[_dw_load] DW baseline loaded: shape={dw_arr.shape}") return dw_arr, dw_profile except Exception as e: print(f"[_dw_load] DW baseline failed: {e}") return None # ========================================== # Main Job Runner (Async) # ========================================== def run_job(payload_dict: dict) -> dict: """Main job runner with async DW baseline loading. DW baseline loads in background while hybrid inference runs. DW URL is sent to client as soon as it's ready, parallel to inference. """ from rq import get_current_job current_job = get_current_job() job_id = payload_dict.get("job_id") if not job_id and current_job: job_id = current_job.id if not job_id: job_id = "unknown" payload_dict["job_id"] = job_id if "radius_km" in payload_dict and "radius_m" not in payload_dict: payload_dict["radius_m"] = int(float(payload_dict["radius_km"]) * 1000) if "model_name" in payload_dict and "model" not in payload_dict: payload_dict["model"] = payload_dict["model_name"] # Initialize storage try: from storage import MinIOStorage storage = MinIOStorage() except Exception as e: update_status( job_id, "failed", "init", 0, f"Failed to initialize storage: {e}", error={"type": "StorageError", "message": str(e)} ) return {"status": "failed", "error": str(e)} payload, errors = parse_and_validate_payload(payload_dict) if errors: update_status( job_id, "failed", "validation", 0, f"Validation failed: {errors}", error={"type": "ValidationError", "message": "; ".join(errors)} ) return {"status": "failed", "errors": errors} update_status(job_id, "running", "init", 5, "Starting inference pipeline...") dw_baseline_url = None output_urls = {} missing_outputs = [] try: # Get config and AOI bbox from config import InferenceConfig, MinIOStorage as ConfigMinIO cfg = InferenceConfig() cfg.storage = ConfigMinIO() start_date, end_date = cfg.season_dates(payload['year'], payload['season']) lat, lon, radius = payload['lat'], payload['lon'], payload['radius_m'] radius_deg = radius / 111000 bbox = [lon - radius_deg, lat - radius_deg, lon + radius_deg, lat + radius_deg] # ========================================== # Start DW baseline loading in background # ========================================== update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline (async)...") print(f"[{job_id}] Starting async DW baseline load...") with ThreadPoolExecutor(max_workers=1) as dw_executor: dw_future = dw_executor.submit( _load_dw_async, storage, bbox, payload['year'], payload['season'] ) # ========================================== # Start hybrid inference immediately (in parallel) # ========================================== update_status(job_id, "running", "load_model", 20, "Loading model artifacts...") model_dir = Path(tempfile.mkdtemp()) print(f"[{job_id}] Downloading model artifacts...") # Download model artifacts for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]: try: storage.download_file(storage.bucket_models, artifact, model_dir / artifact) print(f"[{job_id}] Downloaded {artifact}") except Exception as e: try: storage.download_file(storage.bucket_models, f"models/{artifact}", model_dir / artifact) print(f"[{job_id}] Downloaded {artifact} (from models/ prefix)") except Exception as e2: raise FileNotFoundError( f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}" ) update_status(job_id, "running", "fetch_stac", 30, "Fetching spatio-temporal data...") from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline stac_wrapper = DEAfricaSTACWrapper() lat_range = (bbox[1], bbox[3]) lon_range = (bbox[0], bbox[2]) time_range = (start_date, end_date) print(f"[{job_id}] Fetching STAC data from DEA...") unseen_pixel_df = stac_wrapper.fetch_and_format_data( lat_range=lat_range, lon_range=lon_range, time_range=time_range ) print(f"[{job_id}] STAC data fetched: {len(unseen_pixel_df)} pixels") # Check if DW is ready while processing STAC if dw_future.done(): dw_result = dw_future.result() if dw_result is not None: dw_arr, dw_profile = dw_result import rasterio dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: dst.write(dw_arr) dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" storage.upload_result(dw_temp_path, dw_key) dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key) update_status( job_id, "running", "dw_ready", 35, "Dynamic World baseline ready", outputs={"dw_baseline_url": dw_baseline_url}, ) update_status(job_id, "running", "infer", 50, "Running Hybrid Inference (CNN + CatBoost)...") print(f"[{job_id}] Running hybrid inference...") pipeline = CropInferencePipeline(model_dir=str(model_dir)) mapped_crops_df = pipeline.predict( unseen_pixel_df, apply_spatial_smoothing=True, coord_cols=['lat', 'lon'] ) print(f"[{job_id}] Inference complete, exporting results...") # ========================================== # Export and Upload Results # ========================================== update_status(job_id, "running", "export_cog", 80, "Exporting results...") output_dir = Path(tempfile.mkdtemp()) output_path = output_dir / "refined.tif" pipeline.export_to_geotiff(mapped_crops_df, output_path=str(output_path)) # Upload results for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]: local_f = output_dir / filename if local_f.exists(): result_key = f"results/{job_id}/{filename}" storage.upload_result(local_f, result_key) output_urls[filename.replace(".","_url")] = storage.presign_get("geocrop-results", result_key) # Check DW one more time (may have finished during inference) if dw_baseline_url is None and dw_future.done(): dw_result = dw_future.result() if dw_result is not None: dw_arr, dw_profile = dw_result import rasterio dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: dst.write(dw_arr) dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" storage.upload_result(dw_temp_path, dw_key) dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key) # Wait for DW if still running if dw_baseline_url is None: print(f"[{job_id}] Waiting for DW baseline to finish...") dw_result = dw_future.result(timeout=60) if dw_result is not None: dw_arr, dw_profile = dw_result import rasterio dw_temp_path = Path(tempfile.mktemp(suffix=".tif")) with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst: dst.write(dw_arr) dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif" storage.upload_result(dw_temp_path, dw_key) dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key) # ========================================== # Final Status # ========================================== final_outputs = dict(output_urls) if dw_baseline_url: final_outputs["dw_baseline_url"] = dw_baseline_url if payload['outputs'].get('indices'): missing_outputs.append("indices: not implemented") if payload['outputs'].get('true_color'): missing_outputs.append("true_color: not implemented") final_status = "partial" if missing_outputs else "done" final_message = f"Inference complete" + (f" ({', '.join(missing_outputs)})" if missing_outputs else "") update_status( job_id, final_status, "done", 100, final_message, outputs=final_outputs if final_outputs else None, ) print(f"[{job_id}] Job complete: {final_status}") return { "status": final_status, "job_id": job_id, "outputs": final_outputs, "missing": missing_outputs if missing_outputs else None, } except Exception as e: error_trace = traceback.format_exc() print(f"[{job_id}] Error: {e}") print(error_trace) update_status( job_id, "failed", "error", 0, f"Unexpected error: {e}", error={"type": type(e).__name__, "message": str(e), "trace": error_trace} ) return { "status": "failed", "error": str(e), "job_id": job_id, } run_inference = run_job # ========================================== # RQ Worker Entry Point # ========================================== def start_rq_worker(): """Start the RQ worker to listen for jobs on the geocrop_tasks queue.""" from rq import Worker import signal if '/app' not in sys.path: sys.path.insert(0, '/app') queue_name = os.getenv("RQ_QUEUE_NAME", "geocrop_tasks") print(f"=== GeoCrop RQ Worker Starting ===") print(f"Listening on queue: {queue_name}") print(f"Redis: {os.getenv('REDIS_HOST', 'redis.geocrop.svc.cluster.local')}:{os.getenv('REDIS_PORT', '6379')}") def signal_handler(signum, frame): print("\nReceived shutdown signal, exiting gracefully...") sys.exit(0) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: q = Queue(queue_name, connection=redis_conn) w = Worker([q], connection=redis_conn) w.work() except KeyboardInterrupt: print("\nWorker interrupted, shutting down...") except Exception as e: print(f"Worker error: {e}") raise if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="GeoCrop Worker") parser.add_argument("--test", action="store_true", help="Run syntax test only") parser.add_argument("--worker", action="store_true", help="Start RQ worker") args = parser.parse_args() if args.test or not args.worker: print("=== GeoCrop Worker Syntax Test ===") try: from contracts import STAGES, VALID_MODELS from storage import MinIOStorage print(f"✓ Imports OK") print(f" STAGES: {STAGES}") print(f" VALID_MODELS: {VALID_MODELS}") except ImportError as e: print(f"⚠ Some imports missing: {e}") print("\n--- Payload Parsing Test ---") test_payload = { "job_id": "test-123", "lat": -17.8, "lon": 31.0, "radius_m": 2000, "year": 2022, "model": "Hybrid_SpatioTemporal", "smoothing_kernel": 5, "outputs": {"refined": True, "dw_baseline": True}, } validated, errors = parse_and_validate_payload(test_payload) if errors: print(f"✗ Validation errors: {errors}") else: print(f"✓ Payload validation passed") print(f" job_id: {validated['job_id']}") print(f" AOI: ({validated['lat']}, {validated['lon']}) radius={validated['radius_m']}m") print(f" model: {validated['model']}") print("\n=== Worker Syntax Test Complete ===") if args.worker: start_rq_worker()