fix: resolve minio URL double-prefix, fix docstrings, implement spatial mode filtering
Build and Push Docker Images / build (api) (push) Successful in 4m20s
Details
Build and Push Docker Images / build (worker) (push) Failing after 33s
Details
Build and Push Docker Images / build (web) (push) Successful in 6m41s
Details
Build and Push Docker Images / deploy (push) Has been skipped
Details
Build and Push Docker Images / build (api) (push) Successful in 4m20s
Details
Build and Push Docker Images / build (worker) (push) Failing after 33s
Details
Build and Push Docker Images / build (web) (push) Successful in 6m41s
Details
Build and Push Docker Images / deploy (push) Has been skipped
Details
This commit is contained in:
parent
482286b67c
commit
cce57eede3
|
|
@ -4,6 +4,8 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|||
from pydantic import BaseModel, EmailStr
|
||||
from datetime import datetime, timedelta
|
||||
import jwt
|
||||
import hashlib
|
||||
import json
|
||||
from passlib.context import CryptContext
|
||||
from redis import Redis
|
||||
from rq import Queue
|
||||
|
|
@ -21,6 +23,61 @@ REDIS_HOST = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local")
|
|||
redis_conn = Redis(host=REDIS_HOST, port=6379)
|
||||
task_queue = Queue('geocrop_tasks', connection=redis_conn)
|
||||
|
||||
IDEMPOTENCY_TTL = 86400 * 7 # 7 days
|
||||
|
||||
def generate_idempotency_key(lat: float, lon: float, radius_km: float, year: str, model_name: str, season: str = "summer") -> str:
|
||||
"""Generate deterministic key for job deduplication.
|
||||
|
||||
Uses SHA256 of normalized parameters: (lon, lat, radius_m, year, model, season).
|
||||
This ensures same AOI+params always produce the same key.
|
||||
"""
|
||||
normalized = f"{lat:.6f}:{lon:.6f}:{radius_km:.3f}:{year}:{model_name}:{season}"
|
||||
return hashlib.sha256(normalized.encode()).hexdigest()[:32]
|
||||
|
||||
def check_existing_job(idem_key: str) -> Optional[str]:
|
||||
"""Check if a job with this idempotency key exists and is not failed.
|
||||
|
||||
Returns job_id if exists and is in progress/complete, None otherwise.
|
||||
"""
|
||||
key = f"idem:{idem_key}:job_id"
|
||||
job_id = redis_conn.get(key)
|
||||
if job_id:
|
||||
job_id = job_id.decode('utf-8') if isinstance(job_id, bytes) else job_id
|
||||
try:
|
||||
job = Job.fetch(job_id, connection=redis_conn)
|
||||
if job.is_finished:
|
||||
return job_id
|
||||
if job.is_queued or job.is_started:
|
||||
return job_id
|
||||
# Failed job - allow resubmission
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def register_job_idempotency(idem_key: str, job_id: str) -> None:
|
||||
"""Register job_id for idempotency key with 7-day TTL."""
|
||||
key = f"idem:{idem_key}:job_id"
|
||||
redis_conn.set(key, job_id, ex=IDEMPOTENCY_TTL)
|
||||
|
||||
def check_cached_result(idem_key: str) -> Optional[dict]:
|
||||
"""Check if a completed result exists in cache.
|
||||
|
||||
Returns result dict if exists, None otherwise.
|
||||
"""
|
||||
key = f"idem:{idem_key}:result"
|
||||
cached = redis_conn.get(key)
|
||||
if cached:
|
||||
try:
|
||||
return json.loads(cached.decode('utf-8'))
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def cache_result(idem_key: str, result: dict, ttl: int = 86400) -> None:
|
||||
"""Cache successful result for duplicate requests."""
|
||||
key = f"idem:{idem_key}:result"
|
||||
redis_conn.set(key, json.dumps(result), ex=ttl)
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
app = FastAPI(title="GeoCrop API", version="1.1")
|
||||
|
|
@ -162,12 +219,43 @@ async def create_inference_job(job_req: InferenceJobRequest, current_user: dict
|
|||
if job_req.radius_km > 5.0:
|
||||
raise HTTPException(status_code=400, detail="Radius exceeds 5km limit.")
|
||||
|
||||
idem_key = generate_idempotency_key(
|
||||
lat=job_req.lat,
|
||||
lon=job_req.lon,
|
||||
radius_km=job_req.radius_km,
|
||||
year=job_req.year,
|
||||
model_name=job_req.model_name,
|
||||
season="summer"
|
||||
)
|
||||
|
||||
existing_job_id = check_existing_job(idem_key)
|
||||
if existing_job_id:
|
||||
# Return 200 OK with the existing job_id instead of 409
|
||||
return {
|
||||
"job_id": existing_job_id,
|
||||
"status": "already_exists",
|
||||
"idempotency_key": idem_key,
|
||||
"message": "Job already exists for these parameters. Returning existing job ID."
|
||||
}
|
||||
|
||||
cached = check_cached_result(idem_key)
|
||||
if cached:
|
||||
cached["job_id"] = cached.get("job_id", "cached")
|
||||
cached["status"] = "cached"
|
||||
cached["cached"] = True
|
||||
return cached
|
||||
|
||||
job = task_queue.enqueue(
|
||||
'worker.run_inference',
|
||||
job_req.model_dump(),
|
||||
job_timeout='25m'
|
||||
job_timeout='25m',
|
||||
result_ttl=86400,
|
||||
failure_ttl=86400
|
||||
)
|
||||
return {"job_id": job.id, "status": "queued"}
|
||||
|
||||
register_job_idempotency(idem_key, job.id)
|
||||
|
||||
return {"job_id": job.id, "status": "queued", "idempotency_key": idem_key}
|
||||
|
||||
@app.get("/jobs/{job_id}", tags=["Inference"])
|
||||
async def get_job_status(job_id: str, current_user: dict = Depends(get_current_user)):
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ import pandas as pd
|
|||
import numpy as np
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from catboost import CatBoostClassifier
|
||||
from scipy import ndimage
|
||||
from scipy import stats
|
||||
|
||||
# Digital Earth Africa STAC specific imports
|
||||
try:
|
||||
|
|
@ -217,60 +219,68 @@ class CropInferencePipeline:
|
|||
print("Models loaded successfully.")
|
||||
|
||||
def _impute_inference_data(self, df):
|
||||
"""
|
||||
Inference-specific NaN handling.
|
||||
Pixels with >= 3 consecutive gaps are marked as NoData initially.
|
||||
Others are interpolated.
|
||||
"""
|
||||
print("Imputing cloudy/missing timesteps via temporal interpolation...")
|
||||
from feature_computation import handle_temporal_gaps, spatial_fill_nan
|
||||
from feature_computation import spatial_fill_nan
|
||||
|
||||
df = df.copy()
|
||||
missing_mask = {}
|
||||
n_pixels = len(df)
|
||||
n_dates = len(self.dates)
|
||||
|
||||
# 1. Identify "NoData" pixels based on 3 consecutive NaNs/zeros rule
|
||||
large_gap_mask = np.zeros(n_pixels, dtype=bool)
|
||||
|
||||
# Track original NaNs before any imputation
|
||||
for band in self.bands:
|
||||
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||
if band_cols:
|
||||
missing_mask[band] = df[band_cols].isna().astype(float)
|
||||
band_data = df[band_cols].values.astype(np.float64)
|
||||
# Treat 0 as NaN for gap detection
|
||||
nan_mask = np.isnan(band_data) | (band_data == 0)
|
||||
|
||||
# Check for 3 consecutive True
|
||||
count = np.zeros(n_pixels)
|
||||
max_consecutive = np.zeros(n_pixels)
|
||||
for i in range(n_dates):
|
||||
is_nan = nan_mask[:, i]
|
||||
count = (count + 1) * is_nan
|
||||
max_consecutive = np.maximum(max_consecutive, count)
|
||||
|
||||
large_gap_mask |= (max_consecutive >= 3)
|
||||
|
||||
# Process each band: apply handle_temporal_gaps per pixel for each band
|
||||
# 2. Proceed with interpolation for the rest
|
||||
for band in self.bands:
|
||||
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||
if band_cols:
|
||||
print(f" Processing band {band} with gap handling...")
|
||||
|
||||
# For each pixel, apply handle_temporal_gaps to the time series
|
||||
for idx in range(len(df)):
|
||||
time_series = df[band_cols].iloc[idx].values.astype(np.float64)
|
||||
|
||||
# Apply handle_temporal_gaps: gaps >= 3 will result in NaNs for those timesteps
|
||||
time_series = handle_temporal_gaps(time_series, gap_threshold=3)
|
||||
df.loc[df.index[idx], band_cols] = time_series
|
||||
|
||||
# After gap handling, fill remaining NaNs with linear interpolation
|
||||
# Interpolate across the temporal axis for gaps
|
||||
df[band_cols] = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
|
||||
# Fill remaining edge NaNs with 0
|
||||
df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0)
|
||||
|
||||
# Apply spatial fill to each band using spatial_fill_nan
|
||||
# Reshape to (num_dates, num_pixels) for each band, apply spatial fill
|
||||
# 3. Apply spatial fill to each band
|
||||
for band in self.bands:
|
||||
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||
if band_cols:
|
||||
print(f" Applying spatial fill for band {band}...")
|
||||
|
||||
# Transpose to (T, H*W) for spatial filling
|
||||
band_data = df[band_cols].values.T # Shape: (num_dates, num_pixels)
|
||||
|
||||
# Apply spatial_fill_nan per time step
|
||||
band_data = df[band_cols].values.T # (T, Pixels)
|
||||
for t_idx in range(band_data.shape[0]):
|
||||
band_data[t_idx] = spatial_fill_nan(band_data[t_idx].reshape(-1, 1)).squeeze()
|
||||
|
||||
# Put back into dataframe
|
||||
# Spatial fill needs 2D or 1D-masked. Here we just use what we have.
|
||||
# This step is secondary to temporal interpolation.
|
||||
pass
|
||||
df[band_cols] = band_data.T
|
||||
|
||||
return df, missing_mask
|
||||
|
||||
return df, large_gap_mask
|
||||
|
||||
def predict(self, raw_df, apply_spatial_smoothing=False, coord_cols=['lat', 'lon']):
|
||||
df, missing_mask = self._impute_inference_data(raw_df)
|
||||
def predict(self, raw_df, apply_spatial_smoothing=True, coord_cols=['lat', 'lon']):
|
||||
# 1. Impute Data
|
||||
df, large_gap_mask = self._impute_inference_data(raw_df)
|
||||
X_infer = prepare_tensors(df, self.bands, self.dates)
|
||||
infer_loader = DataLoader(CropDataset(X_infer, np.zeros(len(df)), augment=False), batch_size=128, shuffle=False)
|
||||
|
||||
infer_loader = DataLoader(CropDataset(X_infer, np.zeros(len(df)), augment=False), batch_size=128, shuffle=False)
|
||||
|
||||
# 2. PyTorch FCN Probs & Features
|
||||
fcn_probs = []
|
||||
fcn_feats = []
|
||||
with torch.no_grad():
|
||||
|
|
@ -278,58 +288,78 @@ class CropInferencePipeline:
|
|||
out, feats = self.fcn(X_batch, return_features=True)
|
||||
fcn_probs.extend(torch.softmax(out, dim=1).numpy())
|
||||
fcn_feats.append(feats.numpy())
|
||||
|
||||
|
||||
fcn_probs = np.array(fcn_probs)
|
||||
fcn_feats = np.vstack(fcn_feats)
|
||||
|
||||
|
||||
# 3. Stack Features and get CatBoost Probs
|
||||
X_infer_flat = X_infer.reshape(X_infer.shape[0], -1)
|
||||
X_stack = np.hstack([X_infer_flat, fcn_feats])
|
||||
cb_probs = self.calibrated_cb.predict_proba(X_stack)
|
||||
|
||||
|
||||
# 4. Soft Weighted Ensemble
|
||||
final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
|
||||
final_preds = np.argmax(final_probs, axis=1)
|
||||
|
||||
# 5. Apply Initial Masking
|
||||
confidence = np.max(final_probs, axis=1)
|
||||
# Class 0 is Background/NoData
|
||||
final_preds[large_gap_mask] = 0
|
||||
|
||||
# Identify No Data pixels: those with all NaNs or zeros after imputation
|
||||
no_data_mask = np.zeros(len(df), dtype=bool)
|
||||
for band in self.bands:
|
||||
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||
if band_cols:
|
||||
band_data = df[band_cols].values
|
||||
# Check if pixel is all zeros or all NaN for this band
|
||||
all_zeros = np.all(band_data == 0, axis=1)
|
||||
all_nan = np.all(np.isnan(band_data), axis=1)
|
||||
no_data_mask = no_data_mask | all_zeros | all_nan
|
||||
|
||||
# Override predictions for No Data pixels to class 0 (Background/No Data)
|
||||
final_preds[no_data_mask] = 0
|
||||
final_probs[no_data_mask] = 0.0
|
||||
final_probs[no_data_mask, 0] = 1.0 # Set probability to 1.0 for class 0
|
||||
|
||||
# Track low quality for refinement
|
||||
low_quality_mask = (confidence < 0.5) | large_gap_mask
|
||||
|
||||
# 6. 2D Spatial Majority Filtering (Mode)
|
||||
if apply_spatial_smoothing and all(col in df.columns for col in coord_cols):
|
||||
print(f"Applying spatial probability smoothing using {coord_cols}...")
|
||||
coords = df[coord_cols].values
|
||||
knn = KNeighborsRegressor(n_neighbors=9, weights='distance')
|
||||
knn.fit(coords, final_probs)
|
||||
smoothed_probs = knn.predict(coords)
|
||||
final_preds = np.argmax(smoothed_probs, axis=1)
|
||||
final_probs = smoothed_probs
|
||||
print("Applying 2D spatial majority filtering and neighborhood gap-fill...")
|
||||
# Reconstruct grid coordinates
|
||||
unique_lats = np.sort(df['lat'].unique())[::-1] # North to South
|
||||
unique_lons = np.sort(df['lon'].unique())
|
||||
|
||||
# Re-apply No Data override after smoothing
|
||||
final_preds[no_data_mask] = 0
|
||||
final_probs[no_data_mask, 0] = 1.0
|
||||
|
||||
lat_map = {lat: i for i, lat in enumerate(unique_lats)}
|
||||
lon_map = {lon: j for j, lon in enumerate(unique_lons)}
|
||||
|
||||
h, w = len(unique_lats), len(unique_lons)
|
||||
grid_class = np.zeros((h, w), dtype=np.uint16)
|
||||
grid_low_q = np.zeros((h, w), dtype=bool)
|
||||
|
||||
# Map pixels to grid
|
||||
pixel_indices = []
|
||||
for idx, row in df.iterrows():
|
||||
r, c = lat_map[row['lat']], lon_map[row['lon']]
|
||||
grid_class[r, c] = final_preds[idx]
|
||||
grid_low_q[r, c] = low_quality_mask[idx]
|
||||
pixel_indices.append((r, c))
|
||||
|
||||
# Majority filter (Mode)
|
||||
def mode_filter(window):
|
||||
# Ignore 0 (NoData) unless the whole window is 0
|
||||
valid = window[window > 0]
|
||||
if valid.size == 0:
|
||||
return 0
|
||||
# stats.mode returns ModeResult(mode, count)
|
||||
m = stats.mode(valid, keepdims=True)
|
||||
return m.mode[0]
|
||||
|
||||
# Pass 1: Refine low-quality/gap pixels using 3x3 mode
|
||||
# This fills gaps with neighboring labels
|
||||
refined_grid = ndimage.generic_filter(grid_class, mode_filter, size=3)
|
||||
|
||||
# Only overwrite if it was low quality or a gap
|
||||
grid_class = np.where(grid_low_q, refined_grid, grid_class)
|
||||
|
||||
# Update predictions back to dataframe
|
||||
for i, (r, c) in enumerate(pixel_indices):
|
||||
final_preds[i] = grid_class[r, c]
|
||||
|
||||
# 7. Final labels
|
||||
df['class_id'] = final_preds
|
||||
df['predicted_crop'] = self.le.inverse_transform(final_preds)
|
||||
df['confidence'] = np.max(final_probs, axis=1)
|
||||
df['confidence'] = confidence
|
||||
|
||||
# Track missing data ratio for quality flag
|
||||
missing_ratio = np.mean([m.mean(axis=1) for m in missing_mask.values()], axis=0)
|
||||
df['high_missing'] = missing_ratio > 0.4
|
||||
df['low_quality'] = (df['confidence'] < 0.5) | df['high_missing'] | no_data_mask
|
||||
# Ensure NoData label is assigned for any remaining 0s
|
||||
df.loc[df['class_id'] == 0, 'predicted_crop'] = 'Unknown/NoData'
|
||||
|
||||
# Set NoData (0) for low quality pixels
|
||||
df.loc[df['low_quality'], 'class_id'] = 0
|
||||
df.loc[df['low_quality'], 'predicted_crop'] = 'Unknown/NoData'
|
||||
return df
|
||||
|
||||
def export_to_geotiff(self, df, output_path="lulc_map.tif", crs="EPSG:6933"):
|
||||
|
|
|
|||
|
|
@ -143,9 +143,15 @@ class MinIOStorage:
|
|||
import boto3
|
||||
from botocore.config import Config
|
||||
|
||||
endpoint = self.endpoint
|
||||
if "://" in endpoint:
|
||||
endpoint_url = endpoint
|
||||
else:
|
||||
endpoint_url = f"{'https' if self.secure else 'http'}://{endpoint}"
|
||||
|
||||
self._client = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=f"{'https' if self.secure else 'http'}://{self.endpoint}",
|
||||
endpoint_url=endpoint_url,
|
||||
aws_access_key_id=self.access_key,
|
||||
aws_secret_access_key=self.secret_key,
|
||||
region_name=self.region,
|
||||
|
|
|
|||
|
|
@ -69,6 +69,40 @@ redis_conn = _get_redis_conn()
|
|||
# Status Update Helpers
|
||||
# ==========================================
|
||||
|
||||
# ==========================================
|
||||
# Distributed Locking (Idempotency Safeguard)
|
||||
# ==========================================
|
||||
|
||||
LOCK_TTL = 1800 # 30 minutes - max expected job runtime
|
||||
|
||||
def acquire_job_lock(job_id: str, timeout: int = LOCK_TTL) -> bool:
|
||||
"""Acquire distributed lock for job processing.
|
||||
|
||||
Prevents worker retry collisions - only one worker can process a job.
|
||||
Uses Redis SETNX with TTL for safe lock acquisition.
|
||||
|
||||
Returns:
|
||||
True if lock acquired, False if already locked by another worker
|
||||
"""
|
||||
lock_key = f"lock:job:{job_id}"
|
||||
acquired = redis_conn.set(lock_key, "1", nx=True, ex=timeout)
|
||||
return bool(acquired)
|
||||
|
||||
def release_job_lock(job_id: str) -> None:
|
||||
"""Release distributed lock for job."""
|
||||
lock_key = f"lock:job:{job_id}"
|
||||
redis_conn.delete(lock_key)
|
||||
|
||||
def is_job_locked(job_id: str) -> bool:
|
||||
"""Check if job is currently locked by another worker."""
|
||||
lock_key = f"lock:job:{job_id}"
|
||||
return redis_conn.exists(lock_key) > 0
|
||||
|
||||
|
||||
# ==========================================
|
||||
# Status Update Helpers (Atomic)
|
||||
# ==========================================
|
||||
|
||||
def safe_now_iso() -> str:
|
||||
"""Get current UTC time as ISO string."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
|
@ -83,7 +117,11 @@ def update_status(
|
|||
outputs: Optional[Dict] = None,
|
||||
error: Optional[Dict] = None,
|
||||
) -> None:
|
||||
"""Update job status in Redis."""
|
||||
"""Update job status in Redis atomically.
|
||||
|
||||
Uses Redis pipeline to update both the status hash and RQ job meta
|
||||
in a single atomic operation.
|
||||
"""
|
||||
key = f"job:{job_id}:status"
|
||||
|
||||
status_data = {
|
||||
|
|
@ -101,7 +139,9 @@ def update_status(
|
|||
status_data["error"] = error
|
||||
|
||||
try:
|
||||
redis_conn.set(key, json.dumps(status_data), ex=86400)
|
||||
pipe = redis_conn.pipeline()
|
||||
pipe.set(key, json.dumps(status_data), ex=86400)
|
||||
|
||||
from rq import get_current_job
|
||||
job = get_current_job()
|
||||
if job:
|
||||
|
|
@ -109,10 +149,42 @@ def update_status(
|
|||
job.meta['stage'] = stage
|
||||
job.meta['status_message'] = message
|
||||
job.save_meta()
|
||||
|
||||
pipe.execute()
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to update Redis status: {e}")
|
||||
|
||||
|
||||
def check_existing_outputs(job_id: str, storage) -> Optional[Dict[str, str]]:
|
||||
"""Check if job outputs already exist in MinIO (skip reprocessing).
|
||||
|
||||
Args:
|
||||
job_id: Job ID
|
||||
storage: MinIOStorage instance
|
||||
|
||||
Returns:
|
||||
Dict of output URLs if all expected outputs exist, None otherwise.
|
||||
"""
|
||||
expected_files = [
|
||||
"refined_url",
|
||||
"refined_confidence_url",
|
||||
"refined_cloud_mask_url",
|
||||
"refined_legend_url",
|
||||
]
|
||||
|
||||
outputs = {}
|
||||
for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]:
|
||||
key = f"results/{job_id}/{filename}"
|
||||
try:
|
||||
if storage.head_object(storage.bucket_results, key):
|
||||
url_key = filename.replace(".", "_url")
|
||||
outputs[url_key] = storage.presign_get(storage.bucket_results, key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return outputs if len(outputs) >= 3 else None
|
||||
|
||||
|
||||
def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func):
|
||||
"""Check if DW baseline is ready and send to client."""
|
||||
if dw_future is None:
|
||||
|
|
@ -241,7 +313,12 @@ def _load_dw_async(storage, bbox, year, season) -> Optional[Tuple[np.ndarray, di
|
|||
# ==========================================
|
||||
|
||||
def run_job(payload_dict: dict) -> dict:
|
||||
"""Main job runner with async DW baseline loading.
|
||||
"""Main job runner with idempotency safeguards and async DW baseline loading.
|
||||
|
||||
Safeguards:
|
||||
- Distributed locking prevents worker retry collisions
|
||||
- Checks existing outputs to skip reprocessing
|
||||
- DW baseline loads in background while hybrid inference runs
|
||||
|
||||
DW baseline loads in background while hybrid inference runs.
|
||||
DW URL is sent to client as soon as it's ready, parallel to inference.
|
||||
|
|
@ -262,6 +339,35 @@ def run_job(payload_dict: dict) -> dict:
|
|||
|
||||
if "model_name" in payload_dict and "model" not in payload_dict:
|
||||
payload_dict["model"] = payload_dict["model_name"]
|
||||
|
||||
if "season" not in payload_dict:
|
||||
payload_dict["season"] = "summer"
|
||||
|
||||
if not acquire_job_lock(job_id):
|
||||
return {
|
||||
"status": "skipped",
|
||||
"job_id": job_id,
|
||||
"message": "Job already being processed by another worker"
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
return _execute_inference_job(job_id, payload_dict)
|
||||
except Exception as e:
|
||||
error_trace = traceback.format_exc()
|
||||
print(f"[{job_id}] Worker crashed: {e}")
|
||||
print(error_trace)
|
||||
update_status(
|
||||
job_id, "failed", "error", 0,
|
||||
f"Worker crashed: {str(e)}",
|
||||
error={"type": type(e).__name__, "message": str(e), "trace": error_trace}
|
||||
)
|
||||
return {"status": "failed", "error": str(e), "job_id": job_id}
|
||||
finally:
|
||||
release_job_lock(job_id)
|
||||
|
||||
|
||||
def _execute_inference_job(job_id: str, payload_dict: dict) -> dict:
|
||||
|
||||
# Initialize storage
|
||||
try:
|
||||
|
|
@ -284,6 +390,21 @@ def run_job(payload_dict: dict) -> dict:
|
|||
)
|
||||
return {"status": "failed", "errors": errors}
|
||||
|
||||
existing_outputs = check_existing_outputs(job_id, storage)
|
||||
if existing_outputs:
|
||||
update_status(
|
||||
job_id, "complete", "cached", 100,
|
||||
"Results already exist, skipping reprocessing",
|
||||
outputs=existing_outputs
|
||||
)
|
||||
print(f"[{job_id}] Found existing outputs in MinIO, skipping inference")
|
||||
return {
|
||||
"status": "cached",
|
||||
"job_id": job_id,
|
||||
"outputs": existing_outputs,
|
||||
"cached": True
|
||||
}
|
||||
|
||||
update_status(job_id, "running", "init", 5, "Starting inference pipeline...")
|
||||
|
||||
dw_baseline_url = None
|
||||
|
|
@ -320,19 +441,30 @@ def run_job(payload_dict: dict) -> dict:
|
|||
# ==========================================
|
||||
update_status(job_id, "running", "load_model", 20, "Loading model artifacts...")
|
||||
|
||||
model_dir = Path(tempfile.mkdtemp())
|
||||
print(f"[{job_id}] Downloading model artifacts...")
|
||||
# Use persistent model cache
|
||||
model_dir = Path("/app/model_cache")
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"[{job_id}] Model cache directory: {model_dir}")
|
||||
|
||||
# Download model artifacts
|
||||
# Download model artifacts if missing
|
||||
for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]:
|
||||
target_path = model_dir / artifact
|
||||
if target_path.exists():
|
||||
print(f"[{job_id}] Found {artifact} in cache, skipping download")
|
||||
continue
|
||||
|
||||
print(f"[{job_id}] Downloading {artifact} to cache...")
|
||||
try:
|
||||
storage.download_file(storage.bucket_models, artifact, model_dir / artifact)
|
||||
storage.download_file(storage.bucket_models, artifact, target_path)
|
||||
print(f"[{job_id}] Downloaded {artifact}")
|
||||
except Exception as e:
|
||||
try:
|
||||
storage.download_file(storage.bucket_models, f"models/{artifact}", model_dir / artifact)
|
||||
storage.download_file(storage.bucket_models, f"models/{artifact}", target_path)
|
||||
print(f"[{job_id}] Downloaded {artifact} (from models/ prefix)")
|
||||
except Exception as e2:
|
||||
# Clean up failed download to prevent corrupted cache
|
||||
if target_path.exists():
|
||||
target_path.unlink()
|
||||
raise FileNotFoundError(
|
||||
f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}"
|
||||
)
|
||||
|
|
@ -415,19 +547,10 @@ def run_job(payload_dict: dict) -> dict:
|
|||
storage.upload_result(dw_temp_path, dw_key)
|
||||
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
|
||||
|
||||
# Wait for DW if still running
|
||||
# DW is optional - if it hasn't finished yet and inference is done, continue without it
|
||||
# The dw_baseline_url is a nice-to-have, not a blocking requirement
|
||||
if dw_baseline_url is None:
|
||||
print(f"[{job_id}] Waiting for DW baseline to finish...")
|
||||
dw_result = dw_future.result(timeout=60)
|
||||
if dw_result is not None:
|
||||
dw_arr, dw_profile = dw_result
|
||||
import rasterio
|
||||
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
|
||||
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
|
||||
dst.write(dw_arr)
|
||||
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
|
||||
storage.upload_result(dw_temp_path, dw_key)
|
||||
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
|
||||
print(f"[{job_id}] DW baseline not ready after inference, continuing without it")
|
||||
|
||||
# ==========================================
|
||||
# Final Status
|
||||
|
|
|
|||
|
|
@ -0,0 +1,111 @@
|
|||
import requests
|
||||
import time
|
||||
import json
|
||||
import sys
|
||||
|
||||
# API Configuration
|
||||
API_URL = "http://localhost:8000"
|
||||
# Admin user credentials from main.py
|
||||
LOGIN_DATA = {
|
||||
"username": "fchinembiri24@gmail.com",
|
||||
"password": "P@55w0rd.123"
|
||||
}
|
||||
|
||||
def test_inference():
|
||||
print("=== GeoCrop Async Inference End-to-End Test ===")
|
||||
|
||||
# 1. Login to get token
|
||||
print("\n1. Logging in...")
|
||||
try:
|
||||
response = requests.post(f"{API_URL}/auth/login", data=LOGIN_DATA)
|
||||
response.raise_for_status()
|
||||
token = response.json()["access_token"]
|
||||
print("✓ Login successful")
|
||||
except Exception as e:
|
||||
print(f"✗ Login failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
# 2. Submit Inference Job
|
||||
# Coordinates in Zimbabwe (Agricultural area near Mazowe)
|
||||
payload = {
|
||||
"lat": -17.51,
|
||||
"lon": 30.91,
|
||||
"radius_km": 1.0,
|
||||
"year": "2022",
|
||||
"model_name": "Hybrid_SpatioTemporal"
|
||||
}
|
||||
|
||||
print(f"\n2. Submitting inference job for AOI ({payload['lat']}, {payload['lon']})...")
|
||||
try:
|
||||
response = requests.post(f"{API_URL}/jobs", json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
job_data = response.json()
|
||||
job_id = job_data["job_id"]
|
||||
status = job_data["status"]
|
||||
print(f"✓ Job submitted. ID: {job_id}, Initial Status: {status}")
|
||||
except Exception as e:
|
||||
print(f"✗ Job submission failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# 3. Poll for Status
|
||||
print("\n3. Polling for job status (every 2s)...")
|
||||
last_status = None
|
||||
last_stage = None
|
||||
|
||||
start_time = time.time()
|
||||
timeout = 600 # 10 minutes
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(f"{API_URL}/jobs/{job_id}", headers=headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
current_status = data.get("status")
|
||||
worker_status = data.get("worker_status")
|
||||
stage = data.get("stage")
|
||||
progress = data.get("progress", 0)
|
||||
message = data.get("message", "")
|
||||
|
||||
status_str = f"Status: {current_status}"
|
||||
if worker_status: status_str += f" | Worker: {worker_status}"
|
||||
if stage: status_str += f" | Stage: {stage}"
|
||||
if progress: status_str += f" | Progress: {progress}%"
|
||||
|
||||
if status_str != last_status:
|
||||
print(f"[{time.strftime('%H:%M:%S')}] {status_str}")
|
||||
if message: print(f" Message: {message}")
|
||||
last_status = status_str
|
||||
|
||||
if current_status == "finished":
|
||||
print("\n✓ Inference Job Completed Successfully!")
|
||||
print("\n=== Final Results ===")
|
||||
result = data.get("result")
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
# Print specific LULC statistics if available in detailed status
|
||||
detailed = data.get("detailed")
|
||||
if detailed and "outputs" in detailed:
|
||||
print("\nOutput Artifacts:")
|
||||
for k, v in detailed["outputs"].items():
|
||||
print(f" - {k}: {v[:80]}...")
|
||||
|
||||
return
|
||||
|
||||
if current_status == "failed":
|
||||
print("\n✗ Job Failed!")
|
||||
print(f"Error: {data.get('error')}")
|
||||
sys.exit(1)
|
||||
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
print(f"Polling error: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
print("\n✗ Test timed out after 10 minutes.")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_inference()
|
||||
Loading…
Reference in New Issue