# geocrop-platform./apps/worker/worker.py
#
# 564 lines
# 21 KiB
# Python

"""GeoCrop Worker - RQ task runner for inference jobs.
STEP 9: Real end-to-end pipeline orchestration.
This module wires together all the step modules:
- contracts.py (validation, payload parsing)
- storage.py (MinIO adapter)
- stac_client.py (DEA STAC search)
- feature_computation.py (51-feature extraction)
- dw_baseline.py (windowed DW baseline)
- hybrid_inference.py (CNN + CatBoost ensemble inference)
- cog.py (COG export)
"""
from __future__ import annotations
import json
import os
import sys
import tempfile
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# Redis/RQ for job queue
from redis import Redis
from rq import Queue
import numpy as np
# ==========================================
# Redis Configuration
# ==========================================
def _get_redis_conn():
    """Create the Redis connection from environment configuration.

    Sources, in order of precedence:

    1. ``REDIS_URL`` - a full connection URL (e.g. ``redis://host:6379``).
    2. ``REDIS_HOST`` / ``REDIS_PORT`` - separate host and port. ``REDIS_PORT``
       may also hold a URL-like value containing ``://`` (presumably the
       ``tcp://ip:port`` form injected by Kubernetes service links - confirm),
       in which case the port is extracted from it.

    Any port value that cannot be interpreted falls back to 6379.

    Returns:
        A ``redis.Redis`` client created without ``decode_responses``.
    """
    default_port = 6379

    redis_url = os.getenv("REDIS_URL")
    if redis_url:
        # MUST NOT use decode_responses=True because RQ uses pickle (binary)
        return Redis.from_url(redis_url)

    # Handle separate REDIS_HOST and REDIS_PORT
    redis_host = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local")
    redis_port_str = os.getenv("REDIS_PORT", "6379")
    try:
        redis_port = int(redis_port_str)
    except ValueError:
        redis_port = default_port
        if "://" in redis_port_str:
            import urllib.parse

            try:
                # urlparse() and the .port property both raise ValueError on
                # malformed input (bad brackets, non-numeric or out-of-range
                # port); treat that the same as "no usable port".
                redis_port = urllib.parse.urlparse(redis_port_str).port or default_port
            except ValueError:
                redis_port = default_port
    return Redis(host=redis_host, port=redis_port)
# Module-level Redis connection, created once at import time; it is shared by
# update_status() and start_rq_worker() below.
redis_conn = _get_redis_conn()
# ==========================================
# Status Update Helpers
# ==========================================
def safe_now_iso() -> str:
    """Return the current moment in UTC, formatted as an ISO 8601 string."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.isoformat()
def update_status(
    job_id: str,
    status: str,
    stage: str,
    progress: int,
    message: str,
    outputs: Optional[Dict] = None,
    error: Optional[Dict] = None,
) -> None:
    """Publish the latest status of a job to Redis (best effort).

    The status record is stored as JSON under ``job:<job_id>:status`` with a
    24 hour expiry. When called from inside an RQ job, progress, stage and
    message are also mirrored into that job's ``meta``.

    Args:
        job_id: Identifier used to build the Redis key.
        status: Overall job state stored in the record.
        stage: Pipeline stage stored in the record.
        progress: Progress value stored in the record.
        message: Human-readable status message.
        outputs: Optional outputs mapping; only stored when truthy.
        error: Optional error mapping; only stored when truthy.

    Failures while writing are printed as a warning and never raised.
    """
    record: Dict[str, Any] = dict(
        status=status,
        stage=stage,
        progress=progress,
        message=message,
        updated_at=safe_now_iso(),
    )
    # Optional sections are only included when they carry content.
    for extra_name, extra_value in (("outputs", outputs), ("error", error)):
        if extra_value:
            record[extra_name] = extra_value

    redis_key = f"job:{job_id}:status"
    one_day_seconds = 86400
    try:
        redis_conn.set(redis_key, json.dumps(record), ex=one_day_seconds)

        from rq import get_current_job

        active_job = get_current_job()
        if not active_job:
            return
        active_job.meta['progress'] = progress
        active_job.meta['stage'] = stage
        active_job.meta['status_message'] = message
        active_job.save_meta()
    except Exception as exc:
        # Status reporting must never break the job itself.
        print(f"Warning: Failed to update Redis status: {exc}")
def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func):
    """Publish the Dynamic World baseline if its background load has finished.

    When ``dw_future`` is done and produced a result, the raster is written to
    a temporary GeoTIFF, uploaded via ``storage``, a presigned URL is
    generated, and ``update_func`` is called to notify the client.

    Args:
        dw_future: Future returning ``(array, rasterio_profile)`` or ``None``;
            may itself be ``None`` when no load was started.
        storage: Storage adapter providing ``upload_result`` and ``presign_get``.
        job_id: Job identifier used in the object key and log lines.
        payload: Validated payload; ``year`` and ``season`` are read from it.
        update_func: Callable with the ``update_status`` signature.

    Returns:
        The presigned URL, or ``None`` if the baseline is not ready, the load
        returned nothing, or any step failed (failures are logged, not raised).
    """
    if dw_future is None or not dw_future.done():
        return None

    try:
        dw_result = dw_future.result()
        if dw_result is None:
            return None
        dw_arr, dw_profile = dw_result

        import rasterio

        dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
        # A TemporaryDirectory replaces the deprecated, race-prone
        # tempfile.mktemp() and guarantees the scratch file is removed.
        # Assumes upload_result() completes the upload before returning - confirm.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dw_temp_path = Path(tmp_dir) / "dw_baseline.tif"
            with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
                dst.write(dw_arr)
            # Upload to MinIO
            storage.upload_result(dw_temp_path, dw_key)

        # NOTE(review): the object is uploaded through upload_result() but the
        # URL is presigned against "geocrop-baselines" - confirm both refer to
        # the same bucket.
        dw_url = storage.presign_get("geocrop-baselines", dw_key)
        print(f"[{job_id}] DW baseline URL ready: {dw_url[:80]}...")

        # Notify client
        update_func(
            job_id, "running", "dw_ready", 30,
            "Dynamic World baseline ready",
            outputs={"dw_baseline_url": dw_url},
        )
        return dw_url
    except Exception as e:
        # The baseline is a secondary output: log and carry on.
        print(f"[{job_id}] DW baseline processing failed: {e}")
        return None
# ==========================================
# Payload Validation
# ==========================================
def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]:
    """Validate a raw job payload and normalise it into typed fields.

    Checks performed:
      * ``job_id``, ``lat``, ``lon``, ``radius_m`` and ``year`` are present;
      * ``lat``/``lon`` fall inside the Zimbabwe bounding box (only checked
        when both are supplied);
      * ``radius_m`` is within 100-5000 m;
      * ``year`` is between 2015 and the current year;
      * ``model`` (if given) is one of ``contracts.VALID_MODELS``;
      * ``smoothing_kernel`` (if given) is 3, 5 or 7.

    Values that cannot be converted to the expected numeric type are reported
    as validation errors instead of raising, and the default for that field is
    used in the returned dict.

    Args:
        payload: Raw job payload as received from the queue.

    Returns:
        ``(validated, errors)``. ``validated`` always contains every field
        (with defaults filled in); ``errors`` is a list of human-readable
        problems and is empty when the payload is valid.
    """
    errors: List[str] = []

    def _coerce(field, caster):
        """Cast payload[field]; record an error and return None if it cannot be cast."""
        if field not in payload:
            return None
        raw = payload[field]
        try:
            return caster(raw)
        except (TypeError, ValueError, OverflowError):
            errors.append(f"Invalid value for {field}: {raw!r}")
            return None

    for field in ("job_id", "lat", "lon", "radius_m", "year"):
        if field not in payload:
            errors.append(f"Missing required field: {field}")

    lat = _coerce("lat", float)
    lon = _coerce("lon", float)
    if "lat" in payload and "lon" in payload:
        # A NaN coordinate fails both comparisons and is reported as out of bounds.
        if lat is not None and not (-22.5 <= lat <= -15.6):
            errors.append(f"Latitude {lat} outside Zimbabwe bounds")
        if lon is not None and not (25.2 <= lon <= 33.1):
            errors.append(f"Longitude {lon} outside Zimbabwe bounds")

    radius = _coerce("radius_m", int)
    if radius is not None:
        if radius > 5000:
            errors.append(f"Radius {radius}m exceeds max 5000m")
        if radius < 100:
            errors.append(f"Radius {radius}m below min 100m")

    year = _coerce("year", int)
    if year is not None:
        current_year = datetime.now().year
        if year < 2015 or year > current_year:
            errors.append(f"Year {year} outside valid range (2015-{current_year})")

    if "model" in payload:
        # Imported lazily so payloads without a model do not need contracts.
        from contracts import VALID_MODELS
        if payload["model"] not in VALID_MODELS:
            errors.append(f"Invalid model: {payload['model']}. Must be one of {VALID_MODELS}")

    kernel = _coerce("smoothing_kernel", int)
    if kernel is not None and kernel not in (3, 5, 7):
        errors.append(f"Invalid smoothing_kernel: {kernel}. Must be 3, 5, or 7")

    requested = payload.get("outputs", {})
    if requested is None:
        requested = {}
    elif not isinstance(requested, dict):
        errors.append(
            f"Invalid outputs: expected a mapping, got {type(requested).__name__}"
        )
        requested = {}

    validated = {
        "job_id": payload.get("job_id", "unknown"),
        "lat": lat if lat is not None else 0.0,
        "lon": lon if lon is not None else 0.0,
        "radius_m": radius if radius is not None else 2000,
        "year": year if year is not None else 2022,
        "season": payload.get("season", "summer"),
        "model": payload.get("model", "Ensemble"),
        "smoothing_kernel": kernel if kernel is not None else 5,
        "outputs": {
            "refined": requested.get("refined", True),
            "dw_baseline": requested.get("dw_baseline", False),
            "true_color": requested.get("true_color", False),
            "indices": requested.get("indices", []),
        },
    }
    return validated, errors
# ==========================================
# Async DW Loading Helper
# ==========================================
def _load_dw_async(storage, bbox, year, season) -> Optional[Tuple[np.ndarray, dict]]:
"""Async wrapper for DW baseline loading."""
from dw_baseline import load_dw_baseline_window
try:
dw_arr, dw_profile = load_dw_baseline_window(
storage=storage,
aoi_bbox_wgs84=bbox,
year=year,
season=season,
)
print(f"[_dw_load] DW baseline loaded: shape={dw_arr.shape}")
return dw_arr, dw_profile
except Exception as e:
print(f"[_dw_load] DW baseline failed: {e}")
return None
# ==========================================
# Main Job Runner (Async)
# ==========================================
def _publish_dw_baseline(storage, job_id: str, payload: dict, dw_result) -> Optional[str]:
    """Write a loaded DW baseline to a GeoTIFF, upload it, and return a presigned URL.

    Returns None when ``dw_result`` is None (the loader reported a failure).
    """
    if dw_result is None:
        return None

    import rasterio

    dw_arr, dw_profile = dw_result
    dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
    # Scratch file lives in a TemporaryDirectory so it is always removed
    # (replaces the deprecated tempfile.mktemp(), which also leaked the file).
    # Assumes upload_result() finishes the upload before returning - confirm.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dw_temp_path = Path(tmp_dir) / "dw_baseline.tif"
        with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
            dst.write(dw_arr)
        storage.upload_result(dw_temp_path, dw_key)
    return storage.presign_get("geocrop-baselines", dw_key)


def run_job(payload_dict: dict) -> dict:
    """Main job runner with async DW baseline loading.

    The Dynamic World (DW) baseline loads in a background thread while model
    download, STAC fetch and hybrid inference run on the calling thread. The
    DW URL is published to the client as soon as it is found to be ready.

    Args:
        payload_dict: Raw job payload. ``radius_km`` and ``model_name`` are
            accepted as aliases for ``radius_m`` and ``model``. The dict is
            updated in place with the resolved ``job_id`` and aliases.

    Returns:
        On success ``{"status": "done" | "partial", "job_id", "outputs",
        "missing"}``; on validation failure ``{"status": "failed", "errors"}``;
        on any other failure ``{"status": "failed", "error"[, "job_id"]}``.
        Failures are reported through ``update_status`` rather than raised.
    """
    from concurrent.futures import TimeoutError as FuturesTimeoutError
    from rq import get_current_job

    current_job = get_current_job()
    job_id = payload_dict.get("job_id")
    if not job_id and current_job:
        job_id = current_job.id
    if not job_id:
        job_id = "unknown"
    payload_dict["job_id"] = job_id

    # Initialize storage
    try:
        from storage import MinIOStorage
        storage = MinIOStorage()
    except Exception as e:
        update_status(
            job_id, "failed", "init", 0,
            f"Failed to initialize storage: {e}",
            error={"type": "StorageError", "message": str(e)}
        )
        return {"status": "failed", "error": str(e)}

    # Alias normalisation and validation both convert user-supplied values;
    # a malformed value must end up as a reported validation failure, not as
    # an exception that leaves the job without a status.
    try:
        if "radius_km" in payload_dict and "radius_m" not in payload_dict:
            payload_dict["radius_m"] = int(float(payload_dict["radius_km"]) * 1000)
        if "model_name" in payload_dict and "model" not in payload_dict:
            payload_dict["model"] = payload_dict["model_name"]
        payload, errors = parse_and_validate_payload(payload_dict)
    except (TypeError, ValueError, AttributeError, OverflowError) as e:
        payload, errors = {}, [f"Malformed payload: {e}"]
    if errors:
        update_status(
            job_id, "failed", "validation", 0,
            f"Validation failed: {errors}",
            error={"type": "ValidationError", "message": "; ".join(errors)}
        )
        return {"status": "failed", "errors": errors}

    update_status(job_id, "running", "init", 5, "Starting inference pipeline...")
    dw_baseline_url = None
    output_urls = {}
    missing_outputs = []
    try:
        # Get config and AOI bbox
        from config import InferenceConfig, MinIOStorage as ConfigMinIO
        cfg = InferenceConfig()
        cfg.storage = ConfigMinIO()
        start_date, end_date = cfg.season_dates(payload['year'], payload['season'])
        lat, lon, radius = payload['lat'], payload['lon'], payload['radius_m']
        # Metres -> degrees using ~111 km per degree for both axes.
        radius_deg = radius / 111000
        bbox = [lon - radius_deg, lat - radius_deg, lon + radius_deg, lat + radius_deg]

        # ==========================================
        # Start DW baseline loading in background
        # ==========================================
        update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline (async)...")
        print(f"[{job_id}] Starting async DW baseline load...")
        # work_dir holds model artifacts and exported rasters and is removed
        # when the block exits. The executor exits first and waits for the DW
        # thread, so nothing is still running when the directory is deleted.
        with tempfile.TemporaryDirectory() as work_dir, \
                ThreadPoolExecutor(max_workers=1) as dw_executor:
            dw_future = dw_executor.submit(
                _load_dw_async,
                storage, bbox, payload['year'], payload['season']
            )

            # ==========================================
            # Start hybrid inference immediately (in parallel)
            # ==========================================
            update_status(job_id, "running", "load_model", 20, "Loading model artifacts...")
            model_dir = Path(work_dir) / "model"
            model_dir.mkdir()
            print(f"[{job_id}] Downloading model artifacts...")
            for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]:
                try:
                    storage.download_file(storage.bucket_models, artifact, model_dir / artifact)
                    print(f"[{job_id}] Downloaded {artifact}")
                except Exception:
                    # Fall back to the "models/" prefix before giving up.
                    try:
                        storage.download_file(storage.bucket_models, f"models/{artifact}", model_dir / artifact)
                        print(f"[{job_id}] Downloaded {artifact} (from models/ prefix)")
                    except Exception as e2:
                        raise FileNotFoundError(
                            f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}"
                        ) from e2

            update_status(job_id, "running", "fetch_stac", 30, "Fetching spatio-temporal data...")
            from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline
            stac_wrapper = DEAfricaSTACWrapper()
            lat_range = (bbox[1], bbox[3])
            lon_range = (bbox[0], bbox[2])
            time_range = (start_date, end_date)
            print(f"[{job_id}] Fetching STAC data from DEA...")
            unseen_pixel_df = stac_wrapper.fetch_and_format_data(
                lat_range=lat_range,
                lon_range=lon_range,
                time_range=time_range
            )
            print(f"[{job_id}] STAC data fetched: {len(unseen_pixel_df)} pixels")

            # Publish the DW baseline early if it finished during the STAC fetch.
            if dw_future.done():
                dw_baseline_url = _publish_dw_baseline(storage, job_id, payload, dw_future.result())
                if dw_baseline_url is not None:
                    update_status(
                        job_id, "running", "dw_ready", 35,
                        "Dynamic World baseline ready",
                        outputs={"dw_baseline_url": dw_baseline_url},
                    )

            update_status(job_id, "running", "infer", 50, "Running Hybrid Inference (CNN + CatBoost)...")
            print(f"[{job_id}] Running hybrid inference...")
            pipeline = CropInferencePipeline(model_dir=str(model_dir))
            mapped_crops_df = pipeline.predict(
                unseen_pixel_df,
                apply_spatial_smoothing=True,
                coord_cols=['lat', 'lon']
            )
            print(f"[{job_id}] Inference complete, exporting results...")

            # ==========================================
            # Export and Upload Results
            # ==========================================
            update_status(job_id, "running", "export_cog", 80, "Exporting results...")
            output_dir = Path(work_dir) / "output"
            output_dir.mkdir()
            output_path = output_dir / "refined.tif"
            pipeline.export_to_geotiff(mapped_crops_df, output_path=str(output_path))
            for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]:
                local_f = output_dir / filename
                if local_f.exists():
                    result_key = f"results/{job_id}/{filename}"
                    storage.upload_result(local_f, result_key)
                    # NOTE(review): this yields keys such as "refined_urltif" and
                    # "refined_legend_urljson", which looks unintended; kept
                    # unchanged because clients may already read these keys.
                    output_urls[filename.replace(".","_url")] = storage.presign_get("geocrop-results", result_key)

            # Collect the DW baseline if it was not published earlier, waiting
            # up to 60s for a load that is still running.
            if dw_baseline_url is None:
                if not dw_future.done():
                    print(f"[{job_id}] Waiting for DW baseline to finish...")
                try:
                    dw_result = dw_future.result(timeout=60)
                except FuturesTimeoutError:
                    # Inference results are already uploaded; a slow baseline
                    # must not turn the whole job into a failure.
                    print(f"[{job_id}] DW baseline did not finish within 60s; continuing without it")
                    if payload['outputs'].get('dw_baseline'):
                        missing_outputs.append("dw_baseline: timed out")
                else:
                    dw_baseline_url = _publish_dw_baseline(storage, job_id, payload, dw_result)

            # ==========================================
            # Final Status
            # ==========================================
            final_outputs = dict(output_urls)
            if dw_baseline_url:
                final_outputs["dw_baseline_url"] = dw_baseline_url
            if payload['outputs'].get('indices'):
                missing_outputs.append("indices: not implemented")
            if payload['outputs'].get('true_color'):
                missing_outputs.append("true_color: not implemented")
            final_status = "partial" if missing_outputs else "done"
            final_message = f"Inference complete" + (f" ({', '.join(missing_outputs)})" if missing_outputs else "")
            update_status(
                job_id,
                final_status,
                "done",
                100,
                final_message,
                outputs=final_outputs if final_outputs else None,
            )
            print(f"[{job_id}] Job complete: {final_status}")
            return {
                "status": final_status,
                "job_id": job_id,
                "outputs": final_outputs,
                "missing": missing_outputs if missing_outputs else None,
            }
    except Exception as e:
        # Top-level boundary of the job: report the failure instead of raising.
        error_trace = traceback.format_exc()
        print(f"[{job_id}] Error: {e}")
        print(error_trace)
        update_status(
            job_id, "failed", "error", 0,
            f"Unexpected error: {e}",
            error={"type": type(e).__name__, "message": str(e), "trace": error_trace}
        )
        return {
            "status": "failed",
            "error": str(e),
            "job_id": job_id,
        }
# Alias exposing run_job under a second name (presumably so jobs enqueued as
# "run_inference" still resolve - confirm against the enqueuing side).
run_inference = run_job
# ==========================================
# RQ Worker Entry Point
# ==========================================
def start_rq_worker():
    """Start the RQ worker and block while it processes jobs.

    The queue name comes from ``RQ_QUEUE_NAME`` (default ``geocrop_tasks``).
    ``/app`` is added to ``sys.path`` if it is not already there. The worker
    uses the module-level ``redis_conn``.

    Raises:
        Exception: any error from the worker other than ``KeyboardInterrupt``
            is logged and re-raised.
    """
    from rq import Worker
    import signal

    if '/app' not in sys.path:
        sys.path.insert(0, '/app')

    queue_name = os.getenv("RQ_QUEUE_NAME", "geocrop_tasks")
    print(f"=== GeoCrop RQ Worker Starting ===")
    print(f"Listening on queue: {queue_name}")
    if os.getenv("REDIS_URL"):
        # REDIS_URL takes precedence when the connection is created, so
        # REDIS_HOST/REDIS_PORT would be misleading here. The URL itself is
        # not echoed because it may embed credentials.
        print("Redis: configured via REDIS_URL")
    else:
        print(f"Redis: {os.getenv('REDIS_HOST', 'redis.geocrop.svc.cluster.local')}:{os.getenv('REDIS_PORT', '6379')}")

    def signal_handler(signum, frame):
        print("\nReceived shutdown signal, exiting gracefully...")
        sys.exit(0)

    # NOTE(review): RQ's Worker.work() presumably installs its own SIGINT/SIGTERM
    # handlers once it starts, replacing these - confirm whether they are needed.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        q = Queue(queue_name, connection=redis_conn)
        w = Worker([q], connection=redis_conn)
        w.work()
    except KeyboardInterrupt:
        print("\nWorker interrupted, shutting down...")
    except Exception as e:
        print(f"Worker error: {e}")
        raise
if __name__ == "__main__":
    import argparse

    # Command line: --worker starts the RQ worker; --test (or no flag at all)
    # runs the import and payload-parsing self-check. Both may be combined.
    arg_parser = argparse.ArgumentParser(description="GeoCrop Worker")
    arg_parser.add_argument("--test", action="store_true", help="Run syntax test only")
    arg_parser.add_argument("--worker", action="store_true", help="Start RQ worker")
    cli_args = arg_parser.parse_args()

    run_self_check = cli_args.test or not cli_args.worker
    if run_self_check:
        print("=== GeoCrop Worker Syntax Test ===")

        # Step 1: make sure the sibling modules can be imported.
        try:
            from contracts import STAGES, VALID_MODELS
            from storage import MinIOStorage
        except ImportError as import_err:
            print(f"⚠ Some imports missing: {import_err}")
        else:
            print(f"✓ Imports OK")
            print(f"  STAGES: {STAGES}")
            print(f"  VALID_MODELS: {VALID_MODELS}")

        # Step 2: run a known-good payload through the validator.
        print("\n--- Payload Parsing Test ---")
        sample_payload = {
            "job_id": "test-123",
            "lat": -17.8,
            "lon": 31.0,
            "radius_m": 2000,
            "year": 2022,
            "model": "Hybrid_SpatioTemporal",
            "smoothing_kernel": 5,
            "outputs": {"refined": True, "dw_baseline": True},
        }
        validated, errors = parse_and_validate_payload(sample_payload)
        if not errors:
            print(f"✓ Payload validation passed")
            print(f"  job_id: {validated['job_id']}")
            print(f"  AOI: ({validated['lat']}, {validated['lon']}) radius={validated['radius_m']}m")
            print(f"  model: {validated['model']}")
        else:
            print(f"✗ Validation errors: {errors}")
        print("\n=== Worker Syntax Test Complete ===")

    if cli_args.worker:
        start_rq_worker()