feat: insert load_dw stage after STAC search

Co-authored-by: aider (openrouter/minimax/minimax-m2.7) <aider@aider.chat>
This commit is contained in:
fchinembiri 2026-05-04 17:53:05 +02:00
parent 44b9220369
commit c2cc58d7ce
1 changed file with 65 additions and 10 deletions

View File

@ -28,6 +28,10 @@ from typing import Any, Dict, List, Optional
from redis import Redis from redis import Redis
from rq import Queue from rq import Queue
import numpy as np
import rasterio
from rasterio.io import MemoryFile
# ========================================== # ==========================================
# Redis Configuration # Redis Configuration
# ========================================== # ==========================================
@ -260,6 +264,7 @@ def run_job(payload_dict: dict) -> dict:
missing_outputs = [] missing_outputs = []
output_urls = {} output_urls = {}
dw_baseline_url = None
try: try:
# ========================================== # ==========================================
@ -269,6 +274,7 @@ def run_job(payload_dict: dict) -> dict:
from stac_client import DEASTACClient from stac_client import DEASTACClient
from config import InferenceConfig, MinIOStorage as ConfigMinIO from config import InferenceConfig, MinIOStorage as ConfigMinIO
from dw_baseline import load_dw_baseline_window
cfg = InferenceConfig() cfg = InferenceConfig()
# Initialize storage adapter for inference.py # Initialize storage adapter for inference.py
@ -303,10 +309,55 @@ def run_job(payload_dict: dict) -> dict:
print(f"[{job_id}] STAC search failed: {e}") print(f"[{job_id}] STAC search failed: {e}")
# Continue but note that features may be limited # Continue but note that features may be limited
# ==========================================
# Stage 2: Load DW Baseline
# ==========================================
update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline...")
print(f"[{job_id}] Loading Dynamic World baseline for {payload['year']} {payload['season']}...")
try:
# Load DW baseline for the AOI
dw_arr, dw_profile = load_dw_baseline_window(
storage=storage,
aoi_bbox_wgs84=bbox,
year=payload['year'],
season=payload['season'],
)
print(f"[{job_id}] DW baseline loaded: shape={dw_arr.shape}")
# Save to temporary TIF file
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
dst.write(dw_arr)
print(f"[{job_id}] DW baseline saved to temp file: {dw_temp_path}")
# Upload to MinIO
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
storage.upload_result(dw_temp_path, dw_key)
print(f"[{job_id}] DW baseline uploaded to: {dw_key}")
# Generate presigned URL
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
print(f"[{job_id}] DW baseline URL: {dw_baseline_url[:80]}...")
# Immediately update job status with DW baseline URL
update_status(
job_id, "running", "load_dw", 15,
"DW baseline loaded and uploaded",
outputs={"dw_baseline_url": dw_baseline_url},
)
except Exception as e:
print(f"[{job_id}] Failed to load DW baseline: {e}")
# Continue without DW baseline - not critical for inference
dw_arr = None
dw_profile = None
update_status(job_id, "running", "build_features", 20, "Building feature cube...") update_status(job_id, "running", "build_features", 20, "Building feature cube...")
# ========================================== # ==========================================
# Stage 2: Build Feature Cube # Stage 3: Build Feature Cube
# ========================================== # ==========================================
print(f"[{job_id}] Building feature cube...") print(f"[{job_id}] Building feature cube...")
@ -349,14 +400,13 @@ def run_job(payload_dict: dict) -> dict:
print(f"[{job_id}] Generating synthetic features for pipeline test...") print(f"[{job_id}] Generating synthetic features for pipeline test...")
# Determine raster dimensions from DW baseline if loaded # Determine raster dimensions from DW baseline if loaded
if 'dw_arr' in dir() and dw_arr is not None: if dw_arr is not None:
H, W = dw_arr.shape H, W = dw_arr.shape
else: else:
# Default size for testing # Default size for testing
H, W = 100, 100 H, W = 100, 100
# Generate synthetic features: shape (H, W, 51) # Generate synthetic features: shape (H, W, 51)
import numpy as np
# Use year as seed for reproducible but varied features # Use year as seed for reproducible but varied features
np.random.seed(payload['year'] + int(payload.get('lon', 0) * 100) + int(payload.get('lat', 0) * 100)) np.random.seed(payload['year'] + int(payload.get('lon', 0) * 100) + int(payload.get('lat', 0) * 100))
@ -377,7 +427,7 @@ def run_job(payload_dict: dict) -> dict:
print(f"[{job_id}] Synthetic feature cube shape: {feature_cube.shape}") print(f"[{job_id}] Synthetic feature cube shape: {feature_cube.shape}")
# ========================================== # ==========================================
# Stage 3: Load Model Artifacts # Stage 4: Load Model Artifacts
# ========================================== # ==========================================
update_status(job_id, "running", "load_model", 40, "Loading model artifacts...") update_status(job_id, "running", "load_model", 40, "Loading model artifacts...")
@ -402,7 +452,7 @@ def run_job(payload_dict: dict) -> dict:
raise FileNotFoundError(f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}") raise FileNotFoundError(f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}")
# ========================================== # ==========================================
# Stage 4: Fetch Spatio-Temporal Data # Stage 5: Fetch Spatio-Temporal Data
# ========================================== # ==========================================
update_status(job_id, "running", "fetch_stac", 50, "Fetching spatio-temporal indices...") update_status(job_id, "running", "fetch_stac", 50, "Fetching spatio-temporal indices...")
from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline
@ -420,7 +470,7 @@ def run_job(payload_dict: dict) -> dict:
) )
# ========================================== # ==========================================
# Stage 5: Hybrid Inference # Stage 6: Hybrid Inference
# ========================================== # ==========================================
update_status(job_id, "running", "infer", 70, "Running Hybrid Inference (CNN + CatBoost)...") update_status(job_id, "running", "infer", 70, "Running Hybrid Inference (CNN + CatBoost)...")
pipeline = CropInferencePipeline(model_dir=str(model_dir)) pipeline = CropInferencePipeline(model_dir=str(model_dir))
@ -432,7 +482,7 @@ def run_job(payload_dict: dict) -> dict:
) )
# ========================================== # ==========================================
# Stage 6: Export and Upload # Stage 7: Export and Upload
# ========================================== # ==========================================
update_status(job_id, "running", "export_cog", 90, "Exporting results...") update_status(job_id, "running", "export_cog", 90, "Exporting results...")
output_dir = Path(tempfile.mkdtemp()) output_dir = Path(tempfile.mkdtemp())
@ -475,20 +525,25 @@ def run_job(payload_dict: dict) -> dict:
missing_outputs.append("true_color: not implemented") missing_outputs.append("true_color: not implemented")
# ========================================== # ==========================================
# Stage 7: Final Status # Stage 8: Final Status
# ========================================== # ==========================================
final_status = "partial" if missing_outputs else "done" final_status = "partial" if missing_outputs else "done"
final_message = f"Inference complete" final_message = f"Inference complete"
if missing_outputs: if missing_outputs:
final_message += f" (partial: {', '.join(missing_outputs)})" final_message += f" (partial: {', '.join(missing_outputs)})"
# Include DW baseline URL in final outputs if available
final_outputs = dict(output_urls)
if dw_baseline_url:
final_outputs["dw_baseline_url"] = dw_baseline_url
update_status( update_status(
job_id, job_id,
final_status, final_status,
"done", "done",
100, 100,
final_message, final_message,
outputs=output_urls, outputs=final_outputs if final_outputs else None,
) )
print(f"[{job_id}] Job complete: {final_status}") print(f"[{job_id}] Job complete: {final_status}")
@ -496,7 +551,7 @@ def run_job(payload_dict: dict) -> dict:
return { return {
"status": final_status, "status": final_status,
"job_id": job_id, "job_id": job_id,
"outputs": output_urls, "outputs": final_outputs,
"missing": missing_outputs if missing_outputs else None, "missing": missing_outputs if missing_outputs else None,
} }