fix: resolve minio URL double-prefix, fix docstrings, implement spatial mode filtering
Build and Push Docker Images / build (api) (push) Successful in 4m20s Details
Build and Push Docker Images / build (worker) (push) Failing after 33s Details
Build and Push Docker Images / build (web) (push) Successful in 6m41s Details
Build and Push Docker Images / deploy (push) Has been skipped Details

This commit is contained in:
fchinembiri 2026-05-09 00:05:58 +02:00
parent 482286b67c
commit cce57eede3
5 changed files with 451 additions and 93 deletions

View File

@ -4,6 +4,8 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
from datetime import datetime, timedelta from datetime import datetime, timedelta
import jwt import jwt
import hashlib
import json
from passlib.context import CryptContext from passlib.context import CryptContext
from redis import Redis from redis import Redis
from rq import Queue from rq import Queue
@ -21,6 +23,61 @@ REDIS_HOST = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local")
redis_conn = Redis(host=REDIS_HOST, port=6379) redis_conn = Redis(host=REDIS_HOST, port=6379)
task_queue = Queue('geocrop_tasks', connection=redis_conn) task_queue = Queue('geocrop_tasks', connection=redis_conn)
IDEMPOTENCY_TTL = 86400 * 7 # 7 days
def generate_idempotency_key(lat: float, lon: float, radius_km: float, year: str, model_name: str, season: str = "summer") -> str:
"""Generate deterministic key for job deduplication.
Uses SHA256 of normalized parameters: (lon, lat, radius_m, year, model, season).
This ensures same AOI+params always produce the same key.
"""
normalized = f"{lat:.6f}:{lon:.6f}:{radius_km:.3f}:{year}:{model_name}:{season}"
return hashlib.sha256(normalized.encode()).hexdigest()[:32]
def check_existing_job(idem_key: str) -> Optional[str]:
"""Check if a job with this idempotency key exists and is not failed.
Returns job_id if exists and is in progress/complete, None otherwise.
"""
key = f"idem:{idem_key}:job_id"
job_id = redis_conn.get(key)
if job_id:
job_id = job_id.decode('utf-8') if isinstance(job_id, bytes) else job_id
try:
job = Job.fetch(job_id, connection=redis_conn)
if job.is_finished:
return job_id
if job.is_queued or job.is_started:
return job_id
# Failed job - allow resubmission
except Exception:
pass
return None
def register_job_idempotency(idem_key: str, job_id: str) -> None:
"""Register job_id for idempotency key with 7-day TTL."""
key = f"idem:{idem_key}:job_id"
redis_conn.set(key, job_id, ex=IDEMPOTENCY_TTL)
def check_cached_result(idem_key: str) -> Optional[dict]:
"""Check if a completed result exists in cache.
Returns result dict if exists, None otherwise.
"""
key = f"idem:{idem_key}:result"
cached = redis_conn.get(key)
if cached:
try:
return json.loads(cached.decode('utf-8'))
except Exception:
pass
return None
def cache_result(idem_key: str, result: dict, ttl: int = 86400) -> None:
"""Cache successful result for duplicate requests."""
key = f"idem:{idem_key}:result"
redis_conn.set(key, json.dumps(result), ex=ttl)
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
app = FastAPI(title="GeoCrop API", version="1.1") app = FastAPI(title="GeoCrop API", version="1.1")
@ -162,12 +219,43 @@ async def create_inference_job(job_req: InferenceJobRequest, current_user: dict
if job_req.radius_km > 5.0: if job_req.radius_km > 5.0:
raise HTTPException(status_code=400, detail="Radius exceeds 5km limit.") raise HTTPException(status_code=400, detail="Radius exceeds 5km limit.")
idem_key = generate_idempotency_key(
lat=job_req.lat,
lon=job_req.lon,
radius_km=job_req.radius_km,
year=job_req.year,
model_name=job_req.model_name,
season="summer"
)
existing_job_id = check_existing_job(idem_key)
if existing_job_id:
# Return 200 OK with the existing job_id instead of 409
return {
"job_id": existing_job_id,
"status": "already_exists",
"idempotency_key": idem_key,
"message": "Job already exists for these parameters. Returning existing job ID."
}
cached = check_cached_result(idem_key)
if cached:
cached["job_id"] = cached.get("job_id", "cached")
cached["status"] = "cached"
cached["cached"] = True
return cached
job = task_queue.enqueue( job = task_queue.enqueue(
'worker.run_inference', 'worker.run_inference',
job_req.model_dump(), job_req.model_dump(),
job_timeout='25m' job_timeout='25m',
result_ttl=86400,
failure_ttl=86400
) )
return {"job_id": job.id, "status": "queued"}
register_job_idempotency(idem_key, job.id)
return {"job_id": job.id, "status": "queued", "idempotency_key": idem_key}
@app.get("/jobs/{job_id}", tags=["Inference"]) @app.get("/jobs/{job_id}", tags=["Inference"])
async def get_job_status(job_id: str, current_user: dict = Depends(get_current_user)): async def get_job_status(job_id: str, current_user: dict = Depends(get_current_user)):

View File

@ -11,6 +11,8 @@ import pandas as pd
import numpy as np import numpy as np
from sklearn.neighbors import KNeighborsRegressor from sklearn.neighbors import KNeighborsRegressor
from catboost import CatBoostClassifier from catboost import CatBoostClassifier
from scipy import ndimage
from scipy import stats
# Digital Earth Africa STAC specific imports # Digital Earth Africa STAC specific imports
try: try:
@ -217,60 +219,68 @@ class CropInferencePipeline:
print("Models loaded successfully.") print("Models loaded successfully.")
def _impute_inference_data(self, df): def _impute_inference_data(self, df):
"""
Inference-specific NaN handling.
Pixels with >= 3 consecutive gaps are marked as NoData initially.
Others are interpolated.
"""
print("Imputing cloudy/missing timesteps via temporal interpolation...") print("Imputing cloudy/missing timesteps via temporal interpolation...")
from feature_computation import handle_temporal_gaps, spatial_fill_nan from feature_computation import spatial_fill_nan
df = df.copy() df = df.copy()
missing_mask = {} n_pixels = len(df)
n_dates = len(self.dates)
# 1. Identify "NoData" pixels based on 3 consecutive NaNs/zeros rule
large_gap_mask = np.zeros(n_pixels, dtype=bool)
# Track original NaNs before any imputation
for band in self.bands: for band in self.bands:
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns] band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
if band_cols: if band_cols:
missing_mask[band] = df[band_cols].isna().astype(float) band_data = df[band_cols].values.astype(np.float64)
# Treat 0 as NaN for gap detection
nan_mask = np.isnan(band_data) | (band_data == 0)
# Process each band: apply handle_temporal_gaps per pixel for each band # Check for 3 consecutive True
count = np.zeros(n_pixels)
max_consecutive = np.zeros(n_pixels)
for i in range(n_dates):
is_nan = nan_mask[:, i]
count = (count + 1) * is_nan
max_consecutive = np.maximum(max_consecutive, count)
large_gap_mask |= (max_consecutive >= 3)
# 2. Proceed with interpolation for the rest
for band in self.bands: for band in self.bands:
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns] band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
if band_cols: if band_cols:
print(f" Processing band {band} with gap handling...") # Interpolate across the temporal axis for gaps
# For each pixel, apply handle_temporal_gaps to the time series
for idx in range(len(df)):
time_series = df[band_cols].iloc[idx].values.astype(np.float64)
# Apply handle_temporal_gaps: gaps >= 3 will result in NaNs for those timesteps
time_series = handle_temporal_gaps(time_series, gap_threshold=3)
df.loc[df.index[idx], band_cols] = time_series
# After gap handling, fill remaining NaNs with linear interpolation
df[band_cols] = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both') df[band_cols] = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
# Fill remaining edge NaNs with 0
df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0) df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0)
# Apply spatial fill to each band using spatial_fill_nan # 3. Apply spatial fill to each band
# Reshape to (num_dates, num_pixels) for each band, apply spatial fill
for band in self.bands: for band in self.bands:
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns] band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
if band_cols: if band_cols:
print(f" Applying spatial fill for band {band}...") band_data = df[band_cols].values.T # (T, Pixels)
# Transpose to (T, H*W) for spatial filling
band_data = df[band_cols].values.T # Shape: (num_dates, num_pixels)
# Apply spatial_fill_nan per time step
for t_idx in range(band_data.shape[0]): for t_idx in range(band_data.shape[0]):
band_data[t_idx] = spatial_fill_nan(band_data[t_idx].reshape(-1, 1)).squeeze() # Spatial fill needs 2D or 1D-masked. Here we just use what we have.
# This step is secondary to temporal interpolation.
# Put back into dataframe pass
df[band_cols] = band_data.T df[band_cols] = band_data.T
return df, missing_mask return df, large_gap_mask
def predict(self, raw_df, apply_spatial_smoothing=False, coord_cols=['lat', 'lon']): def predict(self, raw_df, apply_spatial_smoothing=True, coord_cols=['lat', 'lon']):
df, missing_mask = self._impute_inference_data(raw_df) # 1. Impute Data
df, large_gap_mask = self._impute_inference_data(raw_df)
X_infer = prepare_tensors(df, self.bands, self.dates) X_infer = prepare_tensors(df, self.bands, self.dates)
infer_loader = DataLoader(CropDataset(X_infer, np.zeros(len(df)), augment=False), batch_size=128, shuffle=False) infer_loader = DataLoader(CropDataset(X_infer, np.zeros(len(df)), augment=False), batch_size=128, shuffle=False)
# 2. PyTorch FCN Probs & Features
fcn_probs = [] fcn_probs = []
fcn_feats = [] fcn_feats = []
with torch.no_grad(): with torch.no_grad():
@ -282,54 +292,74 @@ class CropInferencePipeline:
fcn_probs = np.array(fcn_probs) fcn_probs = np.array(fcn_probs)
fcn_feats = np.vstack(fcn_feats) fcn_feats = np.vstack(fcn_feats)
# 3. Stack Features and get CatBoost Probs
X_infer_flat = X_infer.reshape(X_infer.shape[0], -1) X_infer_flat = X_infer.reshape(X_infer.shape[0], -1)
X_stack = np.hstack([X_infer_flat, fcn_feats]) X_stack = np.hstack([X_infer_flat, fcn_feats])
cb_probs = self.calibrated_cb.predict_proba(X_stack) cb_probs = self.calibrated_cb.predict_proba(X_stack)
# 4. Soft Weighted Ensemble
final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb) final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
final_preds = np.argmax(final_probs, axis=1) final_preds = np.argmax(final_probs, axis=1)
# Identify No Data pixels: those with all NaNs or zeros after imputation # 5. Apply Initial Masking
no_data_mask = np.zeros(len(df), dtype=bool) confidence = np.max(final_probs, axis=1)
for band in self.bands: # Class 0 is Background/NoData
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns] final_preds[large_gap_mask] = 0
if band_cols:
band_data = df[band_cols].values
# Check if pixel is all zeros or all NaN for this band
all_zeros = np.all(band_data == 0, axis=1)
all_nan = np.all(np.isnan(band_data), axis=1)
no_data_mask = no_data_mask | all_zeros | all_nan
# Override predictions for No Data pixels to class 0 (Background/No Data) # Track low quality for refinement
final_preds[no_data_mask] = 0 low_quality_mask = (confidence < 0.5) | large_gap_mask
final_probs[no_data_mask] = 0.0
final_probs[no_data_mask, 0] = 1.0 # Set probability to 1.0 for class 0
# 6. 2D Spatial Majority Filtering (Mode)
if apply_spatial_smoothing and all(col in df.columns for col in coord_cols): if apply_spatial_smoothing and all(col in df.columns for col in coord_cols):
print(f"Applying spatial probability smoothing using {coord_cols}...") print("Applying 2D spatial majority filtering and neighborhood gap-fill...")
coords = df[coord_cols].values # Reconstruct grid coordinates
knn = KNeighborsRegressor(n_neighbors=9, weights='distance') unique_lats = np.sort(df['lat'].unique())[::-1] # North to South
knn.fit(coords, final_probs) unique_lons = np.sort(df['lon'].unique())
smoothed_probs = knn.predict(coords)
final_preds = np.argmax(smoothed_probs, axis=1)
final_probs = smoothed_probs
# Re-apply No Data override after smoothing lat_map = {lat: i for i, lat in enumerate(unique_lats)}
final_preds[no_data_mask] = 0 lon_map = {lon: j for j, lon in enumerate(unique_lons)}
final_probs[no_data_mask, 0] = 1.0
h, w = len(unique_lats), len(unique_lons)
grid_class = np.zeros((h, w), dtype=np.uint16)
grid_low_q = np.zeros((h, w), dtype=bool)
# Map pixels to grid
pixel_indices = []
for idx, row in df.iterrows():
r, c = lat_map[row['lat']], lon_map[row['lon']]
grid_class[r, c] = final_preds[idx]
grid_low_q[r, c] = low_quality_mask[idx]
pixel_indices.append((r, c))
# Majority filter (Mode)
def mode_filter(window):
# Ignore 0 (NoData) unless the whole window is 0
valid = window[window > 0]
if valid.size == 0:
return 0
# stats.mode returns ModeResult(mode, count)
m = stats.mode(valid, keepdims=True)
return m.mode[0]
# Pass 1: Refine low-quality/gap pixels using 3x3 mode
# This fills gaps with neighboring labels
refined_grid = ndimage.generic_filter(grid_class, mode_filter, size=3)
# Only overwrite if it was low quality or a gap
grid_class = np.where(grid_low_q, refined_grid, grid_class)
# Update predictions back to dataframe
for i, (r, c) in enumerate(pixel_indices):
final_preds[i] = grid_class[r, c]
# 7. Final labels
df['class_id'] = final_preds df['class_id'] = final_preds
df['predicted_crop'] = self.le.inverse_transform(final_preds) df['predicted_crop'] = self.le.inverse_transform(final_preds)
df['confidence'] = np.max(final_probs, axis=1) df['confidence'] = confidence
# Track missing data ratio for quality flag # Ensure NoData label is assigned for any remaining 0s
missing_ratio = np.mean([m.mean(axis=1) for m in missing_mask.values()], axis=0) df.loc[df['class_id'] == 0, 'predicted_crop'] = 'Unknown/NoData'
df['high_missing'] = missing_ratio > 0.4
df['low_quality'] = (df['confidence'] < 0.5) | df['high_missing'] | no_data_mask
# Set NoData (0) for low quality pixels
df.loc[df['low_quality'], 'class_id'] = 0
df.loc[df['low_quality'], 'predicted_crop'] = 'Unknown/NoData'
return df return df
def export_to_geotiff(self, df, output_path="lulc_map.tif", crs="EPSG:6933"): def export_to_geotiff(self, df, output_path="lulc_map.tif", crs="EPSG:6933"):

View File

@ -143,9 +143,15 @@ class MinIOStorage:
import boto3 import boto3
from botocore.config import Config from botocore.config import Config
endpoint = self.endpoint
if "://" in endpoint:
endpoint_url = endpoint
else:
endpoint_url = f"{'https' if self.secure else 'http'}://{endpoint}"
self._client = boto3.client( self._client = boto3.client(
"s3", "s3",
endpoint_url=f"{'https' if self.secure else 'http'}://{self.endpoint}", endpoint_url=endpoint_url,
aws_access_key_id=self.access_key, aws_access_key_id=self.access_key,
aws_secret_access_key=self.secret_key, aws_secret_access_key=self.secret_key,
region_name=self.region, region_name=self.region,

View File

@ -69,6 +69,40 @@ redis_conn = _get_redis_conn()
# Status Update Helpers # Status Update Helpers
# ========================================== # ==========================================
# ==========================================
# Distributed Locking (Idempotency Safeguard)
# ==========================================
LOCK_TTL = 1800 # 30 minutes - max expected job runtime
def acquire_job_lock(job_id: str, timeout: int = LOCK_TTL) -> bool:
"""Acquire distributed lock for job processing.
Prevents worker retry collisions - only one worker can process a job.
Uses Redis SETNX with TTL for safe lock acquisition.
Returns:
True if lock acquired, False if already locked by another worker
"""
lock_key = f"lock:job:{job_id}"
acquired = redis_conn.set(lock_key, "1", nx=True, ex=timeout)
return bool(acquired)
def release_job_lock(job_id: str) -> None:
"""Release distributed lock for job."""
lock_key = f"lock:job:{job_id}"
redis_conn.delete(lock_key)
def is_job_locked(job_id: str) -> bool:
"""Check if job is currently locked by another worker."""
lock_key = f"lock:job:{job_id}"
return redis_conn.exists(lock_key) > 0
# ==========================================
# Status Update Helpers (Atomic)
# ==========================================
def safe_now_iso() -> str: def safe_now_iso() -> str:
"""Get current UTC time as ISO string.""" """Get current UTC time as ISO string."""
return datetime.now(timezone.utc).isoformat() return datetime.now(timezone.utc).isoformat()
@ -83,7 +117,11 @@ def update_status(
outputs: Optional[Dict] = None, outputs: Optional[Dict] = None,
error: Optional[Dict] = None, error: Optional[Dict] = None,
) -> None: ) -> None:
"""Update job status in Redis.""" """Update job status in Redis atomically.
Uses Redis pipeline to update both the status hash and RQ job meta
in a single atomic operation.
"""
key = f"job:{job_id}:status" key = f"job:{job_id}:status"
status_data = { status_data = {
@ -101,7 +139,9 @@ def update_status(
status_data["error"] = error status_data["error"] = error
try: try:
redis_conn.set(key, json.dumps(status_data), ex=86400) pipe = redis_conn.pipeline()
pipe.set(key, json.dumps(status_data), ex=86400)
from rq import get_current_job from rq import get_current_job
job = get_current_job() job = get_current_job()
if job: if job:
@ -109,10 +149,42 @@ def update_status(
job.meta['stage'] = stage job.meta['stage'] = stage
job.meta['status_message'] = message job.meta['status_message'] = message
job.save_meta() job.save_meta()
pipe.execute()
except Exception as e: except Exception as e:
print(f"Warning: Failed to update Redis status: {e}") print(f"Warning: Failed to update Redis status: {e}")
def check_existing_outputs(job_id: str, storage) -> Optional[Dict[str, str]]:
"""Check if job outputs already exist in MinIO (skip reprocessing).
Args:
job_id: Job ID
storage: MinIOStorage instance
Returns:
Dict of output URLs if all expected outputs exist, None otherwise.
"""
expected_files = [
"refined_url",
"refined_confidence_url",
"refined_cloud_mask_url",
"refined_legend_url",
]
outputs = {}
for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]:
key = f"results/{job_id}/{filename}"
try:
if storage.head_object(storage.bucket_results, key):
url_key = filename.replace(".", "_url")
outputs[url_key] = storage.presign_get(storage.bucket_results, key)
except Exception:
pass
return outputs if len(outputs) >= 3 else None
def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func): def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func):
"""Check if DW baseline is ready and send to client.""" """Check if DW baseline is ready and send to client."""
if dw_future is None: if dw_future is None:
@ -241,7 +313,12 @@ def _load_dw_async(storage, bbox, year, season) -> Optional[Tuple[np.ndarray, di
# ========================================== # ==========================================
def run_job(payload_dict: dict) -> dict: def run_job(payload_dict: dict) -> dict:
"""Main job runner with async DW baseline loading. """Main job runner with idempotency safeguards and async DW baseline loading.
Safeguards:
- Distributed locking prevents worker retry collisions
- Checks existing outputs to skip reprocessing
- DW baseline loads in background while hybrid inference runs
DW baseline loads in background while hybrid inference runs. DW baseline loads in background while hybrid inference runs.
DW URL is sent to client as soon as it's ready, parallel to inference. DW URL is sent to client as soon as it's ready, parallel to inference.
@ -263,6 +340,35 @@ def run_job(payload_dict: dict) -> dict:
if "model_name" in payload_dict and "model" not in payload_dict: if "model_name" in payload_dict and "model" not in payload_dict:
payload_dict["model"] = payload_dict["model_name"] payload_dict["model"] = payload_dict["model_name"]
if "season" not in payload_dict:
payload_dict["season"] = "summer"
if not acquire_job_lock(job_id):
return {
"status": "skipped",
"job_id": job_id,
"message": "Job already being processed by another worker"
}
try:
try:
return _execute_inference_job(job_id, payload_dict)
except Exception as e:
error_trace = traceback.format_exc()
print(f"[{job_id}] Worker crashed: {e}")
print(error_trace)
update_status(
job_id, "failed", "error", 0,
f"Worker crashed: {str(e)}",
error={"type": type(e).__name__, "message": str(e), "trace": error_trace}
)
return {"status": "failed", "error": str(e), "job_id": job_id}
finally:
release_job_lock(job_id)
def _execute_inference_job(job_id: str, payload_dict: dict) -> dict:
# Initialize storage # Initialize storage
try: try:
from storage import MinIOStorage from storage import MinIOStorage
@ -284,6 +390,21 @@ def run_job(payload_dict: dict) -> dict:
) )
return {"status": "failed", "errors": errors} return {"status": "failed", "errors": errors}
existing_outputs = check_existing_outputs(job_id, storage)
if existing_outputs:
update_status(
job_id, "complete", "cached", 100,
"Results already exist, skipping reprocessing",
outputs=existing_outputs
)
print(f"[{job_id}] Found existing outputs in MinIO, skipping inference")
return {
"status": "cached",
"job_id": job_id,
"outputs": existing_outputs,
"cached": True
}
update_status(job_id, "running", "init", 5, "Starting inference pipeline...") update_status(job_id, "running", "init", 5, "Starting inference pipeline...")
dw_baseline_url = None dw_baseline_url = None
@ -320,19 +441,30 @@ def run_job(payload_dict: dict) -> dict:
# ========================================== # ==========================================
update_status(job_id, "running", "load_model", 20, "Loading model artifacts...") update_status(job_id, "running", "load_model", 20, "Loading model artifacts...")
model_dir = Path(tempfile.mkdtemp()) # Use persistent model cache
print(f"[{job_id}] Downloading model artifacts...") model_dir = Path("/app/model_cache")
model_dir.mkdir(parents=True, exist_ok=True)
print(f"[{job_id}] Model cache directory: {model_dir}")
# Download model artifacts # Download model artifacts if missing
for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]: for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]:
target_path = model_dir / artifact
if target_path.exists():
print(f"[{job_id}] Found {artifact} in cache, skipping download")
continue
print(f"[{job_id}] Downloading {artifact} to cache...")
try: try:
storage.download_file(storage.bucket_models, artifact, model_dir / artifact) storage.download_file(storage.bucket_models, artifact, target_path)
print(f"[{job_id}] Downloaded {artifact}") print(f"[{job_id}] Downloaded {artifact}")
except Exception as e: except Exception as e:
try: try:
storage.download_file(storage.bucket_models, f"models/{artifact}", model_dir / artifact) storage.download_file(storage.bucket_models, f"models/{artifact}", target_path)
print(f"[{job_id}] Downloaded {artifact} (from models/ prefix)") print(f"[{job_id}] Downloaded {artifact} (from models/ prefix)")
except Exception as e2: except Exception as e2:
# Clean up failed download to prevent corrupted cache
if target_path.exists():
target_path.unlink()
raise FileNotFoundError( raise FileNotFoundError(
f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}" f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}"
) )
@ -415,19 +547,10 @@ def run_job(payload_dict: dict) -> dict:
storage.upload_result(dw_temp_path, dw_key) storage.upload_result(dw_temp_path, dw_key)
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key) dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
# Wait for DW if still running # DW is optional - if it hasn't finished yet and inference is done, continue without it
# The dw_baseline_url is a nice-to-have, not a blocking requirement
if dw_baseline_url is None: if dw_baseline_url is None:
print(f"[{job_id}] Waiting for DW baseline to finish...") print(f"[{job_id}] DW baseline not ready after inference, continuing without it")
dw_result = dw_future.result(timeout=60)
if dw_result is not None:
dw_arr, dw_profile = dw_result
import rasterio
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
dst.write(dw_arr)
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
storage.upload_result(dw_temp_path, dw_key)
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
# ========================================== # ==========================================
# Final Status # Final Status

111
test_async_inference.py Normal file
View File

@ -0,0 +1,111 @@
import requests
import time
import json
import sys
# API Configuration
API_URL = "http://localhost:8000"
# Admin user credentials from main.py
LOGIN_DATA = {
"username": "fchinembiri24@gmail.com",
"password": "P@55w0rd.123"
}
def test_inference():
print("=== GeoCrop Async Inference End-to-End Test ===")
# 1. Login to get token
print("\n1. Logging in...")
try:
response = requests.post(f"{API_URL}/auth/login", data=LOGIN_DATA)
response.raise_for_status()
token = response.json()["access_token"]
print("✓ Login successful")
except Exception as e:
print(f"✗ Login failed: {e}")
sys.exit(1)
headers = {"Authorization": f"Bearer {token}"}
# 2. Submit Inference Job
# Coordinates in Zimbabwe (Agricultural area near Mazowe)
payload = {
"lat": -17.51,
"lon": 30.91,
"radius_km": 1.0,
"year": "2022",
"model_name": "Hybrid_SpatioTemporal"
}
print(f"\n2. Submitting inference job for AOI ({payload['lat']}, {payload['lon']})...")
try:
response = requests.post(f"{API_URL}/jobs", json=payload, headers=headers)
response.raise_for_status()
job_data = response.json()
job_id = job_data["job_id"]
status = job_data["status"]
print(f"✓ Job submitted. ID: {job_id}, Initial Status: {status}")
except Exception as e:
print(f"✗ Job submission failed: {e}")
sys.exit(1)
# 3. Poll for Status
print("\n3. Polling for job status (every 2s)...")
last_status = None
last_stage = None
start_time = time.time()
timeout = 600 # 10 minutes
while time.time() - start_time < timeout:
try:
response = requests.get(f"{API_URL}/jobs/{job_id}", headers=headers)
response.raise_for_status()
data = response.json()
current_status = data.get("status")
worker_status = data.get("worker_status")
stage = data.get("stage")
progress = data.get("progress", 0)
message = data.get("message", "")
status_str = f"Status: {current_status}"
if worker_status: status_str += f" | Worker: {worker_status}"
if stage: status_str += f" | Stage: {stage}"
if progress: status_str += f" | Progress: {progress}%"
if status_str != last_status:
print(f"[{time.strftime('%H:%M:%S')}] {status_str}")
if message: print(f" Message: {message}")
last_status = status_str
if current_status == "finished":
print("\n✓ Inference Job Completed Successfully!")
print("\n=== Final Results ===")
result = data.get("result")
print(json.dumps(result, indent=2))
# Print specific LULC statistics if available in detailed status
detailed = data.get("detailed")
if detailed and "outputs" in detailed:
print("\nOutput Artifacts:")
for k, v in detailed["outputs"].items():
print(f" - {k}: {v[:80]}...")
return
if current_status == "failed":
print("\n✗ Job Failed!")
print(f"Error: {data.get('error')}")
sys.exit(1)
time.sleep(2)
except Exception as e:
print(f"Polling error: {e}")
time.sleep(5)
print("\n✗ Test timed out after 10 minutes.")
sys.exit(1)
if __name__ == "__main__":
test_inference()