geocrop-platform/apps/worker/worker.py

689 lines
25 KiB
Python

"""GeoCrop Worker - RQ task runner for inference jobs.
# CI_BUILD_TRIGGER
STEP 9: Real end-to-end pipeline orchestration.
This module wires together all the step modules:
- contracts.py (validation, payload parsing)
- storage.py (MinIO adapter)
- stac_client.py (DEA STAC search)
- feature_computation.py (51-feature extraction)
- dw_baseline.py (windowed DW baseline)
- hybrid_inference.py (CNN + CatBoost ensemble inference)
- cog.py (COG export)
"""
from __future__ import annotations
import json
import os
import sys
import tempfile
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# Redis/RQ for job queue
from redis import Redis
from rq import Queue
import numpy as np
# ==========================================
# Redis Configuration
# ==========================================
def _get_redis_conn():
    """Return a Redis client configured from the environment.

    ``REDIS_URL`` wins when it is set. Otherwise ``REDIS_HOST`` (default
    "redis.geocrop.svc.cluster.local") and ``REDIS_PORT`` (default "6379")
    are used. A ``REDIS_PORT`` value that is not an integer but contains
    "://" (e.g. "tcp://10.0.0.5:6379") contributes only its port; any other
    non-integer value falls back to 6379.

    decode_responses is deliberately left off: RQ stores pickled (binary) data.
    """
    url = os.getenv("REDIS_URL")
    if url:
        return Redis.from_url(url)

    host = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local")
    port_setting = os.getenv("REDIS_PORT", "6379")
    try:
        port = int(port_setting)
    except ValueError:
        port = 6379
        if "://" in port_setting:
            import urllib.parse

            port = urllib.parse.urlparse(port_setting).port or 6379
    return Redis(host=host, port=port)
# Module-wide Redis client, created once at import time from the environment
# (see _get_redis_conn). Shared by the job locks, the status keys and the RQ
# queue/worker below.
redis_conn = _get_redis_conn()
# ==========================================
# Distributed Locking (Idempotency Safeguard)
# ==========================================
# Lock expiry in seconds (passed as `ex=` to Redis SET). Should exceed the
# longest expected job runtime: once it expires mid-job, another worker can
# acquire the same job's lock.
LOCK_TTL = 1800 # 30 minutes - max expected job runtime
def _job_lock_key(job_id: str) -> str:
    """Return the Redis key that guards processing of ``job_id``."""
    return f"lock:job:{job_id}"


# Random tokens of the locks this process currently holds, keyed by job_id.
# A unique token (rather than a constant value) lets release_job_lock prove
# ownership before deleting the key.
_held_lock_tokens: Dict[str, str] = {}

# Server-side compare-and-delete: remove the lock only if it still holds our
# token. Redis runs the script atomically, so a lock that expired (LOCK_TTL)
# and was re-acquired by another worker is never deleted by the old owner.
_RELEASE_LOCK_LUA = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
end
return 0
"""


def acquire_job_lock(job_id: str, timeout: int = LOCK_TTL) -> bool:
    """Acquire the distributed processing lock for a job.

    Prevents worker retry collisions: only one worker can process a job at a
    time. Uses Redis ``SET NX EX`` with a random per-acquisition token as the
    value, so the lock can later be released safely.

    Args:
        job_id: Job to lock.
        timeout: Lock expiry in seconds. The lock frees itself after this even
            if the holder dies without releasing it.

    Returns:
        True if the lock was acquired, False if it is already held.
    """
    import uuid

    token = uuid.uuid4().hex
    acquired = redis_conn.set(_job_lock_key(job_id), token, nx=True, ex=timeout)
    if acquired:
        _held_lock_tokens[job_id] = token
    return bool(acquired)


def release_job_lock(job_id: str) -> None:
    """Release the job's lock if, and only if, this process still owns it.

    If this process never recorded a token for the job (it did not acquire
    the lock here), the key is deleted unconditionally. This keeps the
    force-release semantics for such callers.
    """
    lock_key = _job_lock_key(job_id)
    token = _held_lock_tokens.pop(job_id, None)
    if token is None:
        redis_conn.delete(lock_key)
        return
    redis_conn.eval(_RELEASE_LOCK_LUA, 1, lock_key, token)


def is_job_locked(job_id: str) -> bool:
    """Return True if any worker (including this process) holds the job's lock."""
    return redis_conn.exists(_job_lock_key(job_id)) > 0
# ==========================================
# Status Update Helpers
# ==========================================
def safe_now_iso() -> str:
    """Return the current instant in UTC as an ISO-8601 string (``+00:00`` offset)."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.isoformat()
def update_status(
    job_id: str,
    status: str,
    stage: str,
    progress: int,
    message: str,
    outputs: Optional[Dict] = None,
    error: Optional[Dict] = None,
) -> None:
    """Publish a job's current status for API consumers (best-effort).

    Performs two independent writes (not one atomic operation):

    1. A JSON document at ``job:{job_id}:status`` containing status, stage,
       progress, message, ``updated_at`` and, when non-empty, ``outputs`` /
       ``error``. It replaces the previous document and expires after 24 h.
    2. When called inside an RQ job, ``progress``, ``stage`` and
       ``status_message`` are mirrored into that job's ``meta``.

    The status key is written first, so a failure to save RQ meta cannot
    prevent it from being published. Errors are logged and never raised,
    because status reporting must not abort the job itself.

    Args:
        job_id: Job whose status key is written.
        status: Overall state, e.g. "running", "failed", "done".
        stage: Pipeline stage name.
        progress: Progress value (this module uses 0-100).
        message: Human-readable status message.
        outputs: Optional output URLs to include.
        error: Optional error details to include.
    """
    status_data: Dict[str, Any] = {
        "status": status,
        "stage": stage,
        "progress": progress,
        "message": message,
        "updated_at": safe_now_iso(),
    }
    if outputs:
        status_data["outputs"] = outputs
    if error:
        status_data["error"] = error

    try:
        redis_conn.set(f"job:{job_id}:status", json.dumps(status_data), ex=86400)
    except Exception as e:
        print(f"Warning: Failed to update Redis status: {e}")

    try:
        from rq import get_current_job

        job = get_current_job()
        if job:
            job.meta['progress'] = progress
            job.meta['stage'] = stage
            job.meta['status_message'] = message
            job.save_meta()
    except Exception as e:
        print(f"Warning: Failed to update RQ job meta: {e}")
def check_existing_outputs(job_id: str, storage) -> Optional[Dict[str, str]]:
    """Return presigned URLs for a job's outputs if they already exist in MinIO.

    Lets a retried job skip reprocessing. Each known output file under
    ``results/{job_id}/`` in ``storage.bucket_results`` is probed with
    ``head_object``. Any lookup error counts as "not present", so this is a
    best-effort check.

    Result keys are ``filename.replace(".", "_url")`` (e.g. "refined.tif" ->
    "refined_urltif"). _execute_inference_job uses the same naming when it
    uploads fresh results, so cached and fresh responses match.
    NOTE(review): these key names look unintended (cf. "refined_url"). Changing
    them requires updating the upload step and API consumers together.

    Args:
        job_id: Job ID.
        storage: MinIOStorage instance.

    Returns:
        Dict of output URLs when the primary ``refined.tif`` plus at least two
        other outputs exist, otherwise None.
    """
    primary_file = "refined.tif"
    outputs: Dict[str, str] = {}
    for filename in (primary_file, "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"):
        key = f"results/{job_id}/{filename}"
        try:
            if storage.head_object(storage.bucket_results, key):
                outputs[filename.replace(".", "_url")] = storage.presign_get(storage.bucket_results, key)
        except Exception:
            # Missing object or transient storage error: treat as not cached.
            continue
    # Never report a cache hit without the main classification raster.
    if primary_file.replace(".", "_url") in outputs and len(outputs) >= 3:
        return outputs
    return None
def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func):
    """Upload a finished Dynamic World baseline and notify the client.

    Non-blocking: returns None immediately when ``dw_future`` is None or not
    done yet. When the future holds an ``(array, profile)`` result, it is
    written to a temporary GeoTIFF and uploaded to
    ``baselines/{job_id}/dw_baseline_{year}_{season}.tif``. A presigned URL is
    then sent through ``update_func`` with stage "dw_ready" and progress 30.
    The temporary file is always removed.

    Args:
        dw_future: Future produced by the background DW load, or None.
        storage: MinIOStorage instance.
        job_id: Job ID.
        payload: Validated payload (``year`` and ``season`` are read).
        update_func: Status callback with the update_status signature.

    Returns:
        The presigned DW URL, or None if not ready, empty, or failed (failures
        are logged, not raised).
    """
    if dw_future is None or not dw_future.done():
        return None
    dw_temp_path: Optional[Path] = None
    try:
        dw_result = dw_future.result()
        if dw_result is None:
            return None
        dw_arr, dw_profile = dw_result
        import rasterio

        # mkstemp creates the file securely (unlike the race-prone mktemp).
        fd, tmp_name = tempfile.mkstemp(suffix=".tif")
        os.close(fd)
        dw_temp_path = Path(tmp_name)
        with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
            dst.write(dw_arr)
        dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
        storage.upload_result(dw_temp_path, dw_key)
        # NOTE(review): uploaded via upload_result() but presigned from the
        # "geocrop-baselines" bucket - confirm upload_result targets that bucket.
        dw_url = storage.presign_get("geocrop-baselines", dw_key)
        print(f"[{job_id}] DW baseline URL ready: {dw_url[:80]}...")
        update_func(
            job_id, "running", "dw_ready", 30,
            "Dynamic World baseline ready",
            outputs={"dw_baseline_url": dw_url},
        )
        return dw_url
    except Exception as e:
        print(f"[{job_id}] DW baseline processing failed: {e}")
        return None
    finally:
        if dw_temp_path is not None:
            try:
                dw_temp_path.unlink()
            except OSError:
                pass
# ==========================================
# Payload Validation
# ==========================================
def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]:
    """Parse and validate a raw job payload.

    Checks the following:
    - required fields;
    - the Zimbabwe lat/lon bounding box (when both lat and lon are given);
    - radius (100-5000 m);
    - year (2015 to the current year);
    - model name and smoothing kernel (3, 5 or 7).

    Malformed values (e.g. ``"lat": "abc"`` or ``"outputs": None``) are
    reported as validation errors instead of raising, so callers can fail the
    job cleanly.

    Args:
        payload: Raw job payload dict.

    Returns:
        ``(validated, errors)``. ``validated`` holds normalized values with
        defaults filled in; a default is also used for any field that could
        not be converted. ``errors`` is empty when the payload is valid.
    """
    errors: List[str] = []

    def _convert(field: str, cast, default):
        """Return cast(payload[field] or default); on failure record an error and return None."""
        raw = payload.get(field, default)
        try:
            return cast(raw)
        except (TypeError, ValueError, OverflowError):
            errors.append(f"Invalid {field}: {raw!r} is not a valid number")
            return None

    for field in ("job_id", "lat", "lon", "radius_m", "year"):
        if field not in payload:
            errors.append(f"Missing required field: {field}")

    lat = _convert("lat", float, 0)
    lon = _convert("lon", float, 0)
    if "lat" in payload and "lon" in payload:
        if lat is not None and not (-22.5 <= lat <= -15.6):
            errors.append(f"Latitude {lat} outside Zimbabwe bounds")
        if lon is not None and not (25.2 <= lon <= 33.1):
            errors.append(f"Longitude {lon} outside Zimbabwe bounds")

    radius = _convert("radius_m", int, 2000)
    if "radius_m" in payload and radius is not None:
        if radius > 5000:
            errors.append(f"Radius {radius}m exceeds max 5000m")
        if radius < 100:
            errors.append(f"Radius {radius}m below min 100m")

    year = _convert("year", int, 2022)
    if "year" in payload and year is not None:
        current_year = datetime.now().year
        if year < 2015 or year > current_year:
            errors.append(f"Year {year} outside valid range (2015-{current_year})")

    if "model" in payload:
        from contracts import VALID_MODELS
        if payload["model"] not in VALID_MODELS:
            errors.append(f"Invalid model: {payload['model']}. Must be one of {VALID_MODELS}")

    kernel = _convert("smoothing_kernel", int, 5)
    if "smoothing_kernel" in payload and kernel is not None and kernel not in (3, 5, 7):
        errors.append(f"Invalid smoothing_kernel: {kernel}. Must be 3, 5, or 7")

    # Tolerate an explicit null; reject non-object values instead of crashing.
    requested = payload.get("outputs") or {}
    if not isinstance(requested, dict):
        errors.append(f"Invalid outputs: expected an object, got {type(requested).__name__}")
        requested = {}

    validated = {
        "job_id": payload.get("job_id", "unknown"),
        "lat": lat if lat is not None else 0.0,
        "lon": lon if lon is not None else 0.0,
        "radius_m": radius if radius is not None else 2000,
        "year": year if year is not None else 2022,
        "season": payload.get("season", "summer"),
        "model": payload.get("model", "Ensemble"),
        "smoothing_kernel": kernel if kernel is not None else 5,
        "outputs": {
            "refined": requested.get("refined", True),
            "dw_baseline": requested.get("dw_baseline", False),
            "true_color": requested.get("true_color", False),
            "indices": requested.get("indices", []),
        },
    }
    return validated, errors
# ==========================================
# Async DW Loading Helper
# ==========================================
def _load_dw_async(storage, bbox, year, season) -> Optional[Tuple[np.ndarray, dict]]:
"""Async wrapper for DW baseline loading."""
from dw_baseline import load_dw_baseline_window
try:
dw_arr, dw_profile = load_dw_baseline_window(
storage=storage,
aoi_bbox_wgs84=bbox,
year=year,
season=season,
)
print(f"[_dw_load] DW baseline loaded: shape={dw_arr.shape}")
return dw_arr, dw_profile
except Exception as e:
print(f"[_dw_load] DW baseline failed: {e}")
return None
# ==========================================
# Main Job Runner (Async)
# ==========================================
def run_job(payload_dict: dict) -> dict:
    """Main RQ entry point for an inference job.

    Safeguards:
    - Distributed locking prevents worker retry collisions.
    - Existing outputs in MinIO are reused instead of reprocessing.
    - The Dynamic World (DW) baseline loads in the background while hybrid
      inference runs. Its URL is sent to the client as soon as it is ready.

    Args:
        payload_dict: Raw job payload. It is normalized on a copy, so the
            caller's dict is not mutated. ``job_id`` falls back to the RQ job
            id, then "unknown". ``radius_km`` maps to ``radius_m``,
            ``model_name`` maps to ``model``, and ``season`` defaults to
            "summer".

    Returns:
        Result dict with at least ``status`` and ``job_id``. The status is
        "skipped" when another worker holds the job's lock, and "failed"
        (with the error reported via update_status) on any crash.
    """
    from rq import get_current_job

    current_job = get_current_job()
    payload_dict = dict(payload_dict)
    job_id = (
        payload_dict.get("job_id")
        or (current_job.id if current_job else None)
        or "unknown"
    )
    payload_dict["job_id"] = job_id

    if not acquire_job_lock(job_id):
        return {
            "status": "skipped",
            "job_id": job_id,
            "message": "Job already being processed by another worker",
        }
    try:
        # Legacy-alias normalization runs inside the guard so a malformed
        # value (e.g. radius_km="abc") is reported through update_status.
        if "radius_km" in payload_dict and "radius_m" not in payload_dict:
            payload_dict["radius_m"] = int(float(payload_dict["radius_km"]) * 1000)
        if "model_name" in payload_dict and "model" not in payload_dict:
            payload_dict["model"] = payload_dict["model_name"]
        payload_dict.setdefault("season", "summer")
        return _execute_inference_job(job_id, payload_dict)
    except Exception as e:
        error_trace = traceback.format_exc()
        print(f"[{job_id}] Worker crashed: {e}")
        print(error_trace)
        update_status(
            job_id, "failed", "error", 0,
            f"Worker crashed: {str(e)}",
            error={"type": type(e).__name__, "message": str(e), "trace": error_trace},
        )
        return {"status": "failed", "error": str(e), "job_id": job_id}
    finally:
        release_job_lock(job_id)
def _aoi_bbox_wgs84(lat: float, lon: float, radius_m: float) -> List[float]:
    """Return ``[min_lon, min_lat, max_lon, max_lat]`` of a square AOI around a point.

    The latitude half-width uses ~111 km per degree. A degree of longitude
    spans only ``cos(lat)`` of that distance, so the longitude half-width is
    widened accordingly. Without that, the AOI would be narrower east-west
    than ``radius_m`` (roughly 4-8% at Zimbabwe's latitudes).
    """
    import math

    lat_half_deg = radius_m / 111000
    # Guard against division by ~0 near the poles (unreachable for the
    # Zimbabwe bounds enforced by parse_and_validate_payload).
    lon_half_deg = radius_m / (111000 * max(math.cos(math.radians(lat)), 1e-6))
    return [lon - lon_half_deg, lat - lat_half_deg, lon + lon_half_deg, lat + lat_half_deg]


def _ensure_model_artifacts(storage, job_id: str, model_dir: Path) -> None:
    """Make sure every model artifact exists in the persistent ``model_dir`` cache.

    Missing artifacts are fetched from ``storage.bucket_models``, first at the
    bucket root and then under ``models/``. Each download goes to a
    per-process temporary file that is atomically renamed into place. An
    interrupted or concurrent download therefore never leaves a truncated
    file that later jobs would mistake for a cached artifact.

    Raises:
        FileNotFoundError: if an artifact cannot be downloaded from either location.
    """
    model_dir.mkdir(parents=True, exist_ok=True)
    print(f"[{job_id}] Model cache directory: {model_dir}")
    for artifact in ("pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"):
        target_path = model_dir / artifact
        if target_path.exists():
            print(f"[{job_id}] Found {artifact} in cache, skipping download")
            continue
        print(f"[{job_id}] Downloading {artifact} to cache...")
        partial_path = model_dir / f".{artifact}.{os.getpid()}.part"
        try:
            try:
                storage.download_file(storage.bucket_models, artifact, partial_path)
                print(f"[{job_id}] Downloaded {artifact}")
            except Exception:
                try:
                    storage.download_file(storage.bucket_models, f"models/{artifact}", partial_path)
                    print(f"[{job_id}] Downloaded {artifact} (from models/ prefix)")
                except Exception as e2:
                    raise FileNotFoundError(
                        f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}"
                    ) from e2
            os.replace(partial_path, target_path)
        finally:
            # Remove any leftover partial download so it never pollutes the cache.
            if partial_path.exists():
                partial_path.unlink()


def _publish_dw_baseline(dw_future, storage, job_id: str, payload: dict) -> Optional[str]:
    """Upload a finished Dynamic World baseline and return its presigned URL.

    Returns None without blocking when the background load has not finished
    or produced no result. It also returns None if anything fails while
    writing or uploading the baseline: DW is optional, so errors are logged
    instead of failing the job. The temporary GeoTIFF is always removed.
    """
    if not dw_future.done():
        return None
    dw_temp_path: Optional[Path] = None
    try:
        dw_result = dw_future.result()
        if dw_result is None:
            return None
        import rasterio

        dw_arr, dw_profile = dw_result
        # mkstemp creates the file securely (unlike the race-prone mktemp).
        fd, tmp_name = tempfile.mkstemp(suffix=".tif")
        os.close(fd)
        dw_temp_path = Path(tmp_name)
        with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
            dst.write(dw_arr)
        dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
        storage.upload_result(dw_temp_path, dw_key)
        # NOTE(review): uploaded via upload_result() but presigned from the
        # "geocrop-baselines" bucket - confirm upload_result targets that bucket.
        return storage.presign_get("geocrop-baselines", dw_key)
    except Exception as e:
        print(f"[{job_id}] DW baseline processing failed: {e}")
        return None
    finally:
        if dw_temp_path is not None:
            try:
                dw_temp_path.unlink()
            except OSError:
                pass


def _execute_inference_job(job_id: str, payload_dict: dict) -> dict:
    """Run the inference pipeline for one job, publishing progress via update_status.

    Stages, in order:
    1. storage init;
    2. payload validation;
    3. cached-output short-circuit;
    4. background DW baseline load;
    5. model artifact download;
    6. STAC fetch;
    7. hybrid inference;
    8. GeoTIFF export and upload;
    9. final status.

    Args:
        job_id: Job identifier, used in status keys and object paths.
        payload_dict: Raw payload; validated by parse_and_validate_payload.

    Returns:
        Result dict whose ``status`` is "failed", "cached", "partial" or
        "done". Every variant carries ``job_id``.
    """
    # Initialize storage
    try:
        from storage import MinIOStorage
        storage = MinIOStorage()
    except Exception as e:
        update_status(
            job_id, "failed", "init", 0,
            f"Failed to initialize storage: {e}",
            error={"type": "StorageError", "message": str(e)},
        )
        return {"status": "failed", "error": str(e), "job_id": job_id}

    payload, errors = parse_and_validate_payload(payload_dict)
    if errors:
        update_status(
            job_id, "failed", "validation", 0,
            f"Validation failed: {errors}",
            error={"type": "ValidationError", "message": "; ".join(errors)},
        )
        return {"status": "failed", "errors": errors, "job_id": job_id}

    existing_outputs = check_existing_outputs(job_id, storage)
    if existing_outputs:
        update_status(
            job_id, "complete", "cached", 100,
            "Results already exist, skipping reprocessing",
            outputs=existing_outputs,
        )
        print(f"[{job_id}] Found existing outputs in MinIO, skipping inference")
        return {
            "status": "cached",
            "job_id": job_id,
            "outputs": existing_outputs,
            "cached": True,
        }

    update_status(job_id, "running", "init", 5, "Starting inference pipeline...")
    dw_baseline_url = None
    output_urls: Dict[str, str] = {}
    missing_outputs: List[str] = []
    # Managed explicitly rather than with `with`: leaving a `with` block calls
    # shutdown(wait=True). That would stall the finished job on a DW load
    # that is optional and whose late result would be discarded anyway.
    dw_executor = ThreadPoolExecutor(max_workers=1)
    try:
        # Get config and AOI bbox
        from config import InferenceConfig, MinIOStorage as ConfigMinIO
        cfg = InferenceConfig()
        cfg.storage = ConfigMinIO()
        start_date, end_date = cfg.season_dates(payload['year'], payload['season'])
        bbox = _aoi_bbox_wgs84(payload['lat'], payload['lon'], payload['radius_m'])

        # Start DW baseline loading in the background.
        update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline (async)...")
        print(f"[{job_id}] Starting async DW baseline load...")
        dw_future = dw_executor.submit(
            _load_dw_async, storage, bbox, payload['year'], payload['season']
        )

        # Hybrid inference proceeds in parallel with the DW load.
        update_status(job_id, "running", "load_model", 20, "Loading model artifacts...")
        model_dir = Path("/app/model_cache")  # persistent model cache
        _ensure_model_artifacts(storage, job_id, model_dir)

        update_status(job_id, "running", "fetch_stac", 30, "Fetching spatio-temporal data...")
        from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline
        stac_wrapper = DEAfricaSTACWrapper()
        print(f"[{job_id}] Fetching STAC data from DEA...")
        unseen_pixel_df = stac_wrapper.fetch_and_format_data(
            lat_range=(bbox[1], bbox[3]),
            lon_range=(bbox[0], bbox[2]),
            time_range=(start_date, end_date),
        )
        print(f"[{job_id}] STAC data fetched: {len(unseen_pixel_df)} pixels")

        # Publish DW early if it finished while STAC data was being fetched.
        dw_baseline_url = _publish_dw_baseline(dw_future, storage, job_id, payload)
        if dw_baseline_url:
            update_status(
                job_id, "running", "dw_ready", 35,
                "Dynamic World baseline ready",
                outputs={"dw_baseline_url": dw_baseline_url},
            )

        update_status(job_id, "running", "infer", 50, "Running Hybrid Inference (CNN + CatBoost)...")
        print(f"[{job_id}] Running hybrid inference...")
        pipeline = CropInferencePipeline(model_dir=str(model_dir))
        mapped_crops_df = pipeline.predict(
            unseen_pixel_df,
            apply_spatial_smoothing=True,
            coord_cols=['lat', 'lon'],
        )
        print(f"[{job_id}] Inference complete, exporting results...")

        # Export and upload results; the temp directory is removed afterwards.
        update_status(job_id, "running", "export_cog", 80, "Exporting results...")
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_dir = Path(tmp_dir)
            pipeline.export_to_geotiff(mapped_crops_df, output_path=str(output_dir / "refined.tif"))
            for filename in ("refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"):
                local_f = output_dir / filename
                if not local_f.exists():
                    continue
                result_key = f"results/{job_id}/{filename}"
                storage.upload_result(local_f, result_key)
                # Same key naming as check_existing_outputs (e.g. "refined_urltif").
                output_urls[filename.replace(".", "_url")] = storage.presign_get("geocrop-results", result_key)

        # Check DW one more time (it may have finished during inference).
        if dw_baseline_url is None:
            dw_baseline_url = _publish_dw_baseline(dw_future, storage, job_id, payload)
        # DW is optional: if it is still not ready, finish without waiting for it.
        if dw_baseline_url is None:
            print(f"[{job_id}] DW baseline not ready after inference, continuing without it")

        # Final status
        final_outputs = dict(output_urls)
        if dw_baseline_url:
            final_outputs["dw_baseline_url"] = dw_baseline_url
        if payload['outputs'].get('indices'):
            missing_outputs.append("indices: not implemented")
        if payload['outputs'].get('true_color'):
            missing_outputs.append("true_color: not implemented")
        final_status = "partial" if missing_outputs else "done"
        final_message = "Inference complete" + (f" ({', '.join(missing_outputs)})" if missing_outputs else "")
        update_status(
            job_id,
            final_status,
            "done",
            100,
            final_message,
            outputs=final_outputs if final_outputs else None,
        )
        print(f"[{job_id}] Job complete: {final_status}")
        return {
            "status": final_status,
            "job_id": job_id,
            "outputs": final_outputs,
            "missing": missing_outputs if missing_outputs else None,
        }
    except Exception as e:
        error_trace = traceback.format_exc()
        print(f"[{job_id}] Error: {e}")
        print(error_trace)
        update_status(
            job_id, "failed", "error", 0,
            f"Unexpected error: {e}",
            error={"type": type(e).__name__, "message": str(e), "trace": error_trace},
        )
        return {
            "status": "failed",
            "error": str(e),
            "job_id": job_id,
        }
    finally:
        # Don't block on an unfinished DW load; its result is optional.
        dw_executor.shutdown(wait=False)
# Alias so code that refers to `run_inference` (presumably older enqueue call
# sites - confirm) still reaches run_job.
run_inference = run_job
# ==========================================
# RQ Worker Entry Point
# ==========================================
def start_rq_worker():
    """Run an RQ worker on this module's Redis connection until it stops.

    The queue name comes from ``RQ_QUEUE_NAME`` (default "geocrop_tasks").
    ``/app`` is prepended to ``sys.path``, presumably so the project modules
    imported lazily by the task functions (storage, config, ...) resolve.
    Unexpected worker errors are logged and re-raised.
    """
    from rq import Worker
    import signal

    if '/app' not in sys.path:
        sys.path.insert(0, '/app')
    queue_name = os.getenv("RQ_QUEUE_NAME", "geocrop_tasks")
    print(f"=== GeoCrop RQ Worker Starting ===")
    print(f"Listening on queue: {queue_name}")
    # Log the endpoint the shared connection actually uses. It may come from
    # REDIS_URL, which REDIS_HOST/REDIS_PORT do not reflect. Only the
    # host:port or socket path is printed, never credentials.
    conn_kwargs = redis_conn.connection_pool.connection_kwargs
    if "path" in conn_kwargs:
        redis_endpoint = conn_kwargs["path"]
    else:
        redis_endpoint = f"{conn_kwargs.get('host')}:{conn_kwargs.get('port')}"
    print(f"Redis: {redis_endpoint}")

    def signal_handler(signum, frame):
        print("\nReceived shutdown signal, exiting gracefully...")
        sys.exit(0)

    # NOTE(review): RQ's Worker.work() is expected to install its own
    # SIGINT/SIGTERM handlers, so these likely only cover startup - confirm.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        q = Queue(queue_name, connection=redis_conn)
        w = Worker([q], connection=redis_conn)
        w.work()
    except KeyboardInterrupt:
        print("\nWorker interrupted, shutting down...")
    except Exception as e:
        print(f"Worker error: {e}")
        raise
if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser(description="GeoCrop Worker")
    arg_parser.add_argument("--test", action="store_true", help="Run syntax test only")
    arg_parser.add_argument("--worker", action="store_true", help="Start RQ worker")
    cli_args = arg_parser.parse_args()

    # The self-test runs with --test, and also by default when --worker is
    # absent. Passing both flags runs the self-test first, then the worker.
    if cli_args.test or not cli_args.worker:
        print("=== GeoCrop Worker Syntax Test ===")
        try:
            from contracts import STAGES, VALID_MODELS
            from storage import MinIOStorage
        except ImportError as import_err:
            print(f"⚠ Some imports missing: {import_err}")
        else:
            print("✓ Imports OK")
            print(f" STAGES: {STAGES}")
            print(f" VALID_MODELS: {VALID_MODELS}")

        print("\n--- Payload Parsing Test ---")
        sample_payload = {
            "job_id": "test-123",
            "lat": -17.8,
            "lon": 31.0,
            "radius_m": 2000,
            "year": 2022,
            "model": "Hybrid_SpatioTemporal",
            "smoothing_kernel": 5,
            "outputs": {"refined": True, "dw_baseline": True},
        }
        checked, problems = parse_and_validate_payload(sample_payload)
        if not problems:
            print("✓ Payload validation passed")
            print(f" job_id: {checked['job_id']}")
            print(f" AOI: ({checked['lat']}, {checked['lon']}) radius={checked['radius_m']}m")
            print(f" model: {checked['model']}")
        else:
            print(f"✗ Validation errors: {problems}")
        print("\n=== Worker Syntax Test Complete ===")

    if cli_args.worker:
        start_rq_worker()