feat: insert load_dw stage after STAC search
Co-authored-by: aider (openrouter/minimax/minimax-m2.7) <aider@aider.chat>
This commit is contained in:
parent
44b9220369
commit
c2cc58d7ce
|
|
@ -28,6 +28,10 @@ from typing import Any, Dict, List, Optional
|
|||
from redis import Redis
|
||||
from rq import Queue
|
||||
|
||||
import numpy as np
|
||||
import rasterio
|
||||
from rasterio.io import MemoryFile
|
||||
|
||||
# ==========================================
|
||||
# Redis Configuration
|
||||
# ==========================================
|
||||
|
|
@ -260,6 +264,7 @@ def run_job(payload_dict: dict) -> dict:
|
|||
|
||||
missing_outputs = []
|
||||
output_urls = {}
|
||||
dw_baseline_url = None
|
||||
|
||||
try:
|
||||
# ==========================================
|
||||
|
|
@ -269,6 +274,7 @@ def run_job(payload_dict: dict) -> dict:
|
|||
|
||||
from stac_client import DEASTACClient
|
||||
from config import InferenceConfig, MinIOStorage as ConfigMinIO
|
||||
from dw_baseline import load_dw_baseline_window
|
||||
|
||||
cfg = InferenceConfig()
|
||||
# Initialize storage adapter for inference.py
|
||||
|
|
@ -303,10 +309,55 @@ def run_job(payload_dict: dict) -> dict:
|
|||
print(f"[{job_id}] STAC search failed: {e}")
|
||||
# Continue but note that features may be limited
|
||||
|
||||
# ==========================================
|
||||
# Stage 2: Load DW Baseline
|
||||
# ==========================================
|
||||
update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline...")
|
||||
|
||||
print(f"[{job_id}] Loading Dynamic World baseline for {payload['year']} {payload['season']}...")
|
||||
|
||||
try:
|
||||
# Load DW baseline for the AOI
|
||||
dw_arr, dw_profile = load_dw_baseline_window(
|
||||
storage=storage,
|
||||
aoi_bbox_wgs84=bbox,
|
||||
year=payload['year'],
|
||||
season=payload['season'],
|
||||
)
|
||||
print(f"[{job_id}] DW baseline loaded: shape={dw_arr.shape}")
|
||||
|
||||
# Save to temporary TIF file
|
||||
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
|
||||
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
|
||||
dst.write(dw_arr)
|
||||
print(f"[{job_id}] DW baseline saved to temp file: {dw_temp_path}")
|
||||
|
||||
# Upload to MinIO
|
||||
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
|
||||
storage.upload_result(dw_temp_path, dw_key)
|
||||
print(f"[{job_id}] DW baseline uploaded to: {dw_key}")
|
||||
|
||||
# Generate presigned URL
|
||||
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
|
||||
print(f"[{job_id}] DW baseline URL: {dw_baseline_url[:80]}...")
|
||||
|
||||
# Immediately update job status with DW baseline URL
|
||||
update_status(
|
||||
job_id, "running", "load_dw", 15,
|
||||
"DW baseline loaded and uploaded",
|
||||
outputs={"dw_baseline_url": dw_baseline_url},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{job_id}] Failed to load DW baseline: {e}")
|
||||
# Continue without DW baseline - not critical for inference
|
||||
dw_arr = None
|
||||
dw_profile = None
|
||||
|
||||
update_status(job_id, "running", "build_features", 20, "Building feature cube...")
|
||||
|
||||
# ==========================================
|
||||
# Stage 2: Build Feature Cube
|
||||
# Stage 3: Build Feature Cube
|
||||
# ==========================================
|
||||
print(f"[{job_id}] Building feature cube...")
|
||||
|
||||
|
|
@ -349,14 +400,13 @@ def run_job(payload_dict: dict) -> dict:
|
|||
print(f"[{job_id}] Generating synthetic features for pipeline test...")
|
||||
|
||||
# Determine raster dimensions from DW baseline if loaded
|
||||
if 'dw_arr' in dir() and dw_arr is not None:
|
||||
if dw_arr is not None:
|
||||
H, W = dw_arr.shape
|
||||
else:
|
||||
# Default size for testing
|
||||
H, W = 100, 100
|
||||
|
||||
# Generate synthetic features: shape (H, W, 51)
|
||||
import numpy as np
|
||||
|
||||
# Seed from the year plus scaled lon/lat so features are reproducible but vary by year and location
|
||||
np.random.seed(payload['year'] + int(payload.get('lon', 0) * 100) + int(payload.get('lat', 0) * 100))
|
||||
|
|
@ -377,7 +427,7 @@ def run_job(payload_dict: dict) -> dict:
|
|||
print(f"[{job_id}] Synthetic feature cube shape: {feature_cube.shape}")
|
||||
|
||||
# ==========================================
|
||||
# Stage 3: Load Model Artifacts
|
||||
# Stage 4: Load Model Artifacts
|
||||
# ==========================================
|
||||
update_status(job_id, "running", "load_model", 40, "Loading model artifacts...")
|
||||
|
||||
|
|
@ -402,7 +452,7 @@ def run_job(payload_dict: dict) -> dict:
|
|||
raise FileNotFoundError(f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}")
|
||||
|
||||
# ==========================================
|
||||
# Stage 4: Fetch Spatio-Temporal Data
|
||||
# Stage 5: Fetch Spatio-Temporal Data
|
||||
# ==========================================
|
||||
update_status(job_id, "running", "fetch_stac", 50, "Fetching spatio-temporal indices...")
|
||||
from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline
|
||||
|
|
@ -420,7 +470,7 @@ def run_job(payload_dict: dict) -> dict:
|
|||
)
|
||||
|
||||
# ==========================================
|
||||
# Stage 5: Hybrid Inference
|
||||
# Stage 6: Hybrid Inference
|
||||
# ==========================================
|
||||
update_status(job_id, "running", "infer", 70, "Running Hybrid Inference (CNN + CatBoost)...")
|
||||
pipeline = CropInferencePipeline(model_dir=str(model_dir))
|
||||
|
|
@ -432,7 +482,7 @@ def run_job(payload_dict: dict) -> dict:
|
|||
)
|
||||
|
||||
# ==========================================
|
||||
# Stage 6: Export and Upload
|
||||
# Stage 7: Export and Upload
|
||||
# ==========================================
|
||||
update_status(job_id, "running", "export_cog", 90, "Exporting results...")
|
||||
output_dir = Path(tempfile.mkdtemp())
|
||||
|
|
@ -475,20 +525,25 @@ def run_job(payload_dict: dict) -> dict:
|
|||
missing_outputs.append("true_color: not implemented")
|
||||
|
||||
# ==========================================
|
||||
# Stage 7: Final Status
|
||||
# Stage 8: Final Status
|
||||
# ==========================================
|
||||
final_status = "partial" if missing_outputs else "done"
|
||||
final_message = f"Inference complete"
|
||||
if missing_outputs:
|
||||
final_message += f" (partial: {', '.join(missing_outputs)})"
|
||||
|
||||
# Include DW baseline URL in final outputs if available
|
||||
final_outputs = dict(output_urls)
|
||||
if dw_baseline_url:
|
||||
final_outputs["dw_baseline_url"] = dw_baseline_url
|
||||
|
||||
update_status(
|
||||
job_id,
|
||||
final_status,
|
||||
"done",
|
||||
100,
|
||||
final_message,
|
||||
outputs=output_urls,
|
||||
outputs=final_outputs if final_outputs else None,
|
||||
)
|
||||
|
||||
print(f"[{job_id}] Job complete: {final_status}")
|
||||
|
|
@ -496,7 +551,7 @@ def run_job(payload_dict: dict) -> dict:
|
|||
return {
|
||||
"status": final_status,
|
||||
"job_id": job_id,
|
||||
"outputs": output_urls,
|
||||
"outputs": final_outputs,
|
||||
"missing": missing_outputs if missing_outputs else None,
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue