feat: insert load_dw stage after STAC search

Co-authored-by: aider (openrouter/minimax/minimax-m2.7) <aider@aider.chat>
This commit is contained in:
fchinembiri 2026-05-04 17:53:05 +02:00
parent 44b9220369
commit c2cc58d7ce
1 changed file with 65 additions and 10 deletions

View File

@ -28,6 +28,10 @@ from typing import Any, Dict, List, Optional
from redis import Redis
from rq import Queue
import numpy as np
import rasterio
from rasterio.io import MemoryFile
# ==========================================
# Redis Configuration
# ==========================================
@ -260,6 +264,7 @@ def run_job(payload_dict: dict) -> dict:
missing_outputs = []
output_urls = {}
dw_baseline_url = None
try:
# ==========================================
@ -269,6 +274,7 @@ def run_job(payload_dict: dict) -> dict:
from stac_client import DEASTACClient
from config import InferenceConfig, MinIOStorage as ConfigMinIO
from dw_baseline import load_dw_baseline_window
cfg = InferenceConfig()
# Initialize storage adapter for inference.py
@ -303,10 +309,55 @@ def run_job(payload_dict: dict) -> dict:
print(f"[{job_id}] STAC search failed: {e}")
# Continue but note that features may be limited
# ==========================================
# Stage 2: Load DW Baseline
# ==========================================
update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline...")
print(f"[{job_id}] Loading Dynamic World baseline for {payload['year']} {payload['season']}...")
try:
# Load DW baseline for the AOI
dw_arr, dw_profile = load_dw_baseline_window(
storage=storage,
aoi_bbox_wgs84=bbox,
year=payload['year'],
season=payload['season'],
)
print(f"[{job_id}] DW baseline loaded: shape={dw_arr.shape}")
# Save to temporary TIF file
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
dst.write(dw_arr)
print(f"[{job_id}] DW baseline saved to temp file: {dw_temp_path}")
# Upload to MinIO
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
storage.upload_result(dw_temp_path, dw_key)
print(f"[{job_id}] DW baseline uploaded to: {dw_key}")
# Generate presigned URL
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
print(f"[{job_id}] DW baseline URL: {dw_baseline_url[:80]}...")
# Immediately update job status with DW baseline URL
update_status(
job_id, "running", "load_dw", 15,
"DW baseline loaded and uploaded",
outputs={"dw_baseline_url": dw_baseline_url},
)
except Exception as e:
print(f"[{job_id}] Failed to load DW baseline: {e}")
# Continue without DW baseline - not critical for inference
dw_arr = None
dw_profile = None
update_status(job_id, "running", "build_features", 20, "Building feature cube...")
# ==========================================
# Stage 2: Build Feature Cube
# Stage 3: Build Feature Cube
# ==========================================
print(f"[{job_id}] Building feature cube...")
@ -349,14 +400,13 @@ def run_job(payload_dict: dict) -> dict:
print(f"[{job_id}] Generating synthetic features for pipeline test...")
# Determine raster dimensions from DW baseline if loaded
if 'dw_arr' in dir() and dw_arr is not None:
if dw_arr is not None:
H, W = dw_arr.shape
else:
# Default size for testing
H, W = 100, 100
# Generate synthetic features: shape (H, W, 51)
import numpy as np
# Use year as seed for reproducible but varied features
np.random.seed(payload['year'] + int(payload.get('lon', 0) * 100) + int(payload.get('lat', 0) * 100))
@ -377,7 +427,7 @@ def run_job(payload_dict: dict) -> dict:
print(f"[{job_id}] Synthetic feature cube shape: {feature_cube.shape}")
# ==========================================
# Stage 3: Load Model Artifacts
# Stage 4: Load Model Artifacts
# ==========================================
update_status(job_id, "running", "load_model", 40, "Loading model artifacts...")
@ -402,7 +452,7 @@ def run_job(payload_dict: dict) -> dict:
raise FileNotFoundError(f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}")
# ==========================================
# Stage 4: Fetch Spatio-Temporal Data
# Stage 5: Fetch Spatio-Temporal Data
# ==========================================
update_status(job_id, "running", "fetch_stac", 50, "Fetching spatio-temporal indices...")
from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline
@ -420,7 +470,7 @@ def run_job(payload_dict: dict) -> dict:
)
# ==========================================
# Stage 5: Hybrid Inference
# Stage 6: Hybrid Inference
# ==========================================
update_status(job_id, "running", "infer", 70, "Running Hybrid Inference (CNN + CatBoost)...")
pipeline = CropInferencePipeline(model_dir=str(model_dir))
@ -432,7 +482,7 @@ def run_job(payload_dict: dict) -> dict:
)
# ==========================================
# Stage 6: Export and Upload
# Stage 7: Export and Upload
# ==========================================
update_status(job_id, "running", "export_cog", 90, "Exporting results...")
output_dir = Path(tempfile.mkdtemp())
@ -475,20 +525,25 @@ def run_job(payload_dict: dict) -> dict:
missing_outputs.append("true_color: not implemented")
# ==========================================
# Stage 7: Final Status
# Stage 8: Final Status
# ==========================================
final_status = "partial" if missing_outputs else "done"
final_message = f"Inference complete"
if missing_outputs:
final_message += f" (partial: {', '.join(missing_outputs)})"
# Include DW baseline URL in final outputs if available
final_outputs = dict(output_urls)
if dw_baseline_url:
final_outputs["dw_baseline_url"] = dw_baseline_url
update_status(
job_id,
final_status,
"done",
100,
final_message,
outputs=output_urls,
outputs=final_outputs if final_outputs else None,
)
print(f"[{job_id}] Job complete: {final_status}")
@ -496,7 +551,7 @@ def run_job(payload_dict: dict) -> dict:
return {
"status": final_status,
"job_id": job_id,
"outputs": output_urls,
"outputs": final_outputs,
"missing": missing_outputs if missing_outputs else None,
}