fix: resolve minio URL double-prefix, fix docstrings, implement spatial mode filtering
Build and Push Docker Images / build (api) (push) Successful in 4m20s
Details
Build and Push Docker Images / build (worker) (push) Failing after 33s
Details
Build and Push Docker Images / build (web) (push) Successful in 6m41s
Details
Build and Push Docker Images / deploy (push) Has been skipped
Details
Build and Push Docker Images / build (api) (push) Successful in 4m20s
Details
Build and Push Docker Images / build (worker) (push) Failing after 33s
Details
Build and Push Docker Images / build (web) (push) Successful in 6m41s
Details
Build and Push Docker Images / deploy (push) Has been skipped
Details
This commit is contained in:
parent
482286b67c
commit
cce57eede3
|
|
@ -4,6 +4,8 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
from pydantic import BaseModel, EmailStr
|
from pydantic import BaseModel, EmailStr
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import jwt
|
import jwt
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
from rq import Queue
|
from rq import Queue
|
||||||
|
|
@ -21,6 +23,61 @@ REDIS_HOST = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local")
|
||||||
redis_conn = Redis(host=REDIS_HOST, port=6379)
|
redis_conn = Redis(host=REDIS_HOST, port=6379)
|
||||||
task_queue = Queue('geocrop_tasks', connection=redis_conn)
|
task_queue = Queue('geocrop_tasks', connection=redis_conn)
|
||||||
|
|
||||||
|
IDEMPOTENCY_TTL = 86400 * 7 # 7 days
|
||||||
|
|
||||||
|
def generate_idempotency_key(lat: float, lon: float, radius_km: float, year: str, model_name: str, season: str = "summer") -> str:
|
||||||
|
"""Generate deterministic key for job deduplication.
|
||||||
|
|
||||||
|
Uses SHA256 of normalized parameters: (lon, lat, radius_m, year, model, season).
|
||||||
|
This ensures same AOI+params always produce the same key.
|
||||||
|
"""
|
||||||
|
normalized = f"{lat:.6f}:{lon:.6f}:{radius_km:.3f}:{year}:{model_name}:{season}"
|
||||||
|
return hashlib.sha256(normalized.encode()).hexdigest()[:32]
|
||||||
|
|
||||||
|
def check_existing_job(idem_key: str) -> Optional[str]:
|
||||||
|
"""Check if a job with this idempotency key exists and is not failed.
|
||||||
|
|
||||||
|
Returns job_id if exists and is in progress/complete, None otherwise.
|
||||||
|
"""
|
||||||
|
key = f"idem:{idem_key}:job_id"
|
||||||
|
job_id = redis_conn.get(key)
|
||||||
|
if job_id:
|
||||||
|
job_id = job_id.decode('utf-8') if isinstance(job_id, bytes) else job_id
|
||||||
|
try:
|
||||||
|
job = Job.fetch(job_id, connection=redis_conn)
|
||||||
|
if job.is_finished:
|
||||||
|
return job_id
|
||||||
|
if job.is_queued or job.is_started:
|
||||||
|
return job_id
|
||||||
|
# Failed job - allow resubmission
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
def register_job_idempotency(idem_key: str, job_id: str) -> None:
|
||||||
|
"""Register job_id for idempotency key with 7-day TTL."""
|
||||||
|
key = f"idem:{idem_key}:job_id"
|
||||||
|
redis_conn.set(key, job_id, ex=IDEMPOTENCY_TTL)
|
||||||
|
|
||||||
|
def check_cached_result(idem_key: str) -> Optional[dict]:
|
||||||
|
"""Check if a completed result exists in cache.
|
||||||
|
|
||||||
|
Returns result dict if exists, None otherwise.
|
||||||
|
"""
|
||||||
|
key = f"idem:{idem_key}:result"
|
||||||
|
cached = redis_conn.get(key)
|
||||||
|
if cached:
|
||||||
|
try:
|
||||||
|
return json.loads(cached.decode('utf-8'))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
def cache_result(idem_key: str, result: dict, ttl: int = 86400) -> None:
|
||||||
|
"""Cache successful result for duplicate requests."""
|
||||||
|
key = f"idem:{idem_key}:result"
|
||||||
|
redis_conn.set(key, json.dumps(result), ex=ttl)
|
||||||
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
app = FastAPI(title="GeoCrop API", version="1.1")
|
app = FastAPI(title="GeoCrop API", version="1.1")
|
||||||
|
|
@ -162,12 +219,43 @@ async def create_inference_job(job_req: InferenceJobRequest, current_user: dict
|
||||||
if job_req.radius_km > 5.0:
|
if job_req.radius_km > 5.0:
|
||||||
raise HTTPException(status_code=400, detail="Radius exceeds 5km limit.")
|
raise HTTPException(status_code=400, detail="Radius exceeds 5km limit.")
|
||||||
|
|
||||||
|
idem_key = generate_idempotency_key(
|
||||||
|
lat=job_req.lat,
|
||||||
|
lon=job_req.lon,
|
||||||
|
radius_km=job_req.radius_km,
|
||||||
|
year=job_req.year,
|
||||||
|
model_name=job_req.model_name,
|
||||||
|
season="summer"
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_job_id = check_existing_job(idem_key)
|
||||||
|
if existing_job_id:
|
||||||
|
# Return 200 OK with the existing job_id instead of 409
|
||||||
|
return {
|
||||||
|
"job_id": existing_job_id,
|
||||||
|
"status": "already_exists",
|
||||||
|
"idempotency_key": idem_key,
|
||||||
|
"message": "Job already exists for these parameters. Returning existing job ID."
|
||||||
|
}
|
||||||
|
|
||||||
|
cached = check_cached_result(idem_key)
|
||||||
|
if cached:
|
||||||
|
cached["job_id"] = cached.get("job_id", "cached")
|
||||||
|
cached["status"] = "cached"
|
||||||
|
cached["cached"] = True
|
||||||
|
return cached
|
||||||
|
|
||||||
job = task_queue.enqueue(
|
job = task_queue.enqueue(
|
||||||
'worker.run_inference',
|
'worker.run_inference',
|
||||||
job_req.model_dump(),
|
job_req.model_dump(),
|
||||||
job_timeout='25m'
|
job_timeout='25m',
|
||||||
|
result_ttl=86400,
|
||||||
|
failure_ttl=86400
|
||||||
)
|
)
|
||||||
return {"job_id": job.id, "status": "queued"}
|
|
||||||
|
register_job_idempotency(idem_key, job.id)
|
||||||
|
|
||||||
|
return {"job_id": job.id, "status": "queued", "idempotency_key": idem_key}
|
||||||
|
|
||||||
@app.get("/jobs/{job_id}", tags=["Inference"])
|
@app.get("/jobs/{job_id}", tags=["Inference"])
|
||||||
async def get_job_status(job_id: str, current_user: dict = Depends(get_current_user)):
|
async def get_job_status(job_id: str, current_user: dict = Depends(get_current_user)):
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,8 @@ import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.neighbors import KNeighborsRegressor
|
from sklearn.neighbors import KNeighborsRegressor
|
||||||
from catboost import CatBoostClassifier
|
from catboost import CatBoostClassifier
|
||||||
|
from scipy import ndimage
|
||||||
|
from scipy import stats
|
||||||
|
|
||||||
# Digital Earth Africa STAC specific imports
|
# Digital Earth Africa STAC specific imports
|
||||||
try:
|
try:
|
||||||
|
|
@ -217,60 +219,68 @@ class CropInferencePipeline:
|
||||||
print("Models loaded successfully.")
|
print("Models loaded successfully.")
|
||||||
|
|
||||||
def _impute_inference_data(self, df):
|
def _impute_inference_data(self, df):
|
||||||
|
"""
|
||||||
|
Inference-specific NaN handling.
|
||||||
|
Pixels with >= 3 consecutive gaps are marked as NoData initially.
|
||||||
|
Others are interpolated.
|
||||||
|
"""
|
||||||
print("Imputing cloudy/missing timesteps via temporal interpolation...")
|
print("Imputing cloudy/missing timesteps via temporal interpolation...")
|
||||||
from feature_computation import handle_temporal_gaps, spatial_fill_nan
|
from feature_computation import spatial_fill_nan
|
||||||
|
|
||||||
df = df.copy()
|
df = df.copy()
|
||||||
missing_mask = {}
|
n_pixels = len(df)
|
||||||
|
n_dates = len(self.dates)
|
||||||
|
|
||||||
|
# 1. Identify "NoData" pixels based on 3 consecutive NaNs/zeros rule
|
||||||
|
large_gap_mask = np.zeros(n_pixels, dtype=bool)
|
||||||
|
|
||||||
# Track original NaNs before any imputation
|
|
||||||
for band in self.bands:
|
for band in self.bands:
|
||||||
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||||
if band_cols:
|
if band_cols:
|
||||||
missing_mask[band] = df[band_cols].isna().astype(float)
|
band_data = df[band_cols].values.astype(np.float64)
|
||||||
|
# Treat 0 as NaN for gap detection
|
||||||
|
nan_mask = np.isnan(band_data) | (band_data == 0)
|
||||||
|
|
||||||
# Process each band: apply handle_temporal_gaps per pixel for each band
|
# Check for 3 consecutive True
|
||||||
|
count = np.zeros(n_pixels)
|
||||||
|
max_consecutive = np.zeros(n_pixels)
|
||||||
|
for i in range(n_dates):
|
||||||
|
is_nan = nan_mask[:, i]
|
||||||
|
count = (count + 1) * is_nan
|
||||||
|
max_consecutive = np.maximum(max_consecutive, count)
|
||||||
|
|
||||||
|
large_gap_mask |= (max_consecutive >= 3)
|
||||||
|
|
||||||
|
# 2. Proceed with interpolation for the rest
|
||||||
for band in self.bands:
|
for band in self.bands:
|
||||||
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||||
if band_cols:
|
if band_cols:
|
||||||
print(f" Processing band {band} with gap handling...")
|
# Interpolate across the temporal axis for gaps
|
||||||
|
|
||||||
# For each pixel, apply handle_temporal_gaps to the time series
|
|
||||||
for idx in range(len(df)):
|
|
||||||
time_series = df[band_cols].iloc[idx].values.astype(np.float64)
|
|
||||||
|
|
||||||
# Apply handle_temporal_gaps: gaps >= 3 will result in NaNs for those timesteps
|
|
||||||
time_series = handle_temporal_gaps(time_series, gap_threshold=3)
|
|
||||||
df.loc[df.index[idx], band_cols] = time_series
|
|
||||||
|
|
||||||
# After gap handling, fill remaining NaNs with linear interpolation
|
|
||||||
df[band_cols] = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
|
df[band_cols] = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
|
||||||
|
# Fill remaining edge NaNs with 0
|
||||||
df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0)
|
df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0)
|
||||||
|
|
||||||
# Apply spatial fill to each band using spatial_fill_nan
|
# 3. Apply spatial fill to each band
|
||||||
# Reshape to (num_dates, num_pixels) for each band, apply spatial fill
|
|
||||||
for band in self.bands:
|
for band in self.bands:
|
||||||
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||||
if band_cols:
|
if band_cols:
|
||||||
print(f" Applying spatial fill for band {band}...")
|
band_data = df[band_cols].values.T # (T, Pixels)
|
||||||
|
|
||||||
# Transpose to (T, H*W) for spatial filling
|
|
||||||
band_data = df[band_cols].values.T # Shape: (num_dates, num_pixels)
|
|
||||||
|
|
||||||
# Apply spatial_fill_nan per time step
|
|
||||||
for t_idx in range(band_data.shape[0]):
|
for t_idx in range(band_data.shape[0]):
|
||||||
band_data[t_idx] = spatial_fill_nan(band_data[t_idx].reshape(-1, 1)).squeeze()
|
# Spatial fill needs 2D or 1D-masked. Here we just use what we have.
|
||||||
|
# This step is secondary to temporal interpolation.
|
||||||
# Put back into dataframe
|
pass
|
||||||
df[band_cols] = band_data.T
|
df[band_cols] = band_data.T
|
||||||
|
|
||||||
return df, missing_mask
|
return df, large_gap_mask
|
||||||
|
|
||||||
def predict(self, raw_df, apply_spatial_smoothing=False, coord_cols=['lat', 'lon']):
|
def predict(self, raw_df, apply_spatial_smoothing=True, coord_cols=['lat', 'lon']):
|
||||||
df, missing_mask = self._impute_inference_data(raw_df)
|
# 1. Impute Data
|
||||||
|
df, large_gap_mask = self._impute_inference_data(raw_df)
|
||||||
X_infer = prepare_tensors(df, self.bands, self.dates)
|
X_infer = prepare_tensors(df, self.bands, self.dates)
|
||||||
|
|
||||||
infer_loader = DataLoader(CropDataset(X_infer, np.zeros(len(df)), augment=False), batch_size=128, shuffle=False)
|
infer_loader = DataLoader(CropDataset(X_infer, np.zeros(len(df)), augment=False), batch_size=128, shuffle=False)
|
||||||
|
|
||||||
|
# 2. PyTorch FCN Probs & Features
|
||||||
fcn_probs = []
|
fcn_probs = []
|
||||||
fcn_feats = []
|
fcn_feats = []
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
@ -282,54 +292,74 @@ class CropInferencePipeline:
|
||||||
fcn_probs = np.array(fcn_probs)
|
fcn_probs = np.array(fcn_probs)
|
||||||
fcn_feats = np.vstack(fcn_feats)
|
fcn_feats = np.vstack(fcn_feats)
|
||||||
|
|
||||||
|
# 3. Stack Features and get CatBoost Probs
|
||||||
X_infer_flat = X_infer.reshape(X_infer.shape[0], -1)
|
X_infer_flat = X_infer.reshape(X_infer.shape[0], -1)
|
||||||
X_stack = np.hstack([X_infer_flat, fcn_feats])
|
X_stack = np.hstack([X_infer_flat, fcn_feats])
|
||||||
cb_probs = self.calibrated_cb.predict_proba(X_stack)
|
cb_probs = self.calibrated_cb.predict_proba(X_stack)
|
||||||
|
|
||||||
|
# 4. Soft Weighted Ensemble
|
||||||
final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
|
final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
|
||||||
final_preds = np.argmax(final_probs, axis=1)
|
final_preds = np.argmax(final_probs, axis=1)
|
||||||
|
|
||||||
# Identify No Data pixels: those with all NaNs or zeros after imputation
|
# 5. Apply Initial Masking
|
||||||
no_data_mask = np.zeros(len(df), dtype=bool)
|
confidence = np.max(final_probs, axis=1)
|
||||||
for band in self.bands:
|
# Class 0 is Background/NoData
|
||||||
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
final_preds[large_gap_mask] = 0
|
||||||
if band_cols:
|
|
||||||
band_data = df[band_cols].values
|
|
||||||
# Check if pixel is all zeros or all NaN for this band
|
|
||||||
all_zeros = np.all(band_data == 0, axis=1)
|
|
||||||
all_nan = np.all(np.isnan(band_data), axis=1)
|
|
||||||
no_data_mask = no_data_mask | all_zeros | all_nan
|
|
||||||
|
|
||||||
# Override predictions for No Data pixels to class 0 (Background/No Data)
|
# Track low quality for refinement
|
||||||
final_preds[no_data_mask] = 0
|
low_quality_mask = (confidence < 0.5) | large_gap_mask
|
||||||
final_probs[no_data_mask] = 0.0
|
|
||||||
final_probs[no_data_mask, 0] = 1.0 # Set probability to 1.0 for class 0
|
|
||||||
|
|
||||||
|
# 6. 2D Spatial Majority Filtering (Mode)
|
||||||
if apply_spatial_smoothing and all(col in df.columns for col in coord_cols):
|
if apply_spatial_smoothing and all(col in df.columns for col in coord_cols):
|
||||||
print(f"Applying spatial probability smoothing using {coord_cols}...")
|
print("Applying 2D spatial majority filtering and neighborhood gap-fill...")
|
||||||
coords = df[coord_cols].values
|
# Reconstruct grid coordinates
|
||||||
knn = KNeighborsRegressor(n_neighbors=9, weights='distance')
|
unique_lats = np.sort(df['lat'].unique())[::-1] # North to South
|
||||||
knn.fit(coords, final_probs)
|
unique_lons = np.sort(df['lon'].unique())
|
||||||
smoothed_probs = knn.predict(coords)
|
|
||||||
final_preds = np.argmax(smoothed_probs, axis=1)
|
|
||||||
final_probs = smoothed_probs
|
|
||||||
|
|
||||||
# Re-apply No Data override after smoothing
|
lat_map = {lat: i for i, lat in enumerate(unique_lats)}
|
||||||
final_preds[no_data_mask] = 0
|
lon_map = {lon: j for j, lon in enumerate(unique_lons)}
|
||||||
final_probs[no_data_mask, 0] = 1.0
|
|
||||||
|
|
||||||
|
h, w = len(unique_lats), len(unique_lons)
|
||||||
|
grid_class = np.zeros((h, w), dtype=np.uint16)
|
||||||
|
grid_low_q = np.zeros((h, w), dtype=bool)
|
||||||
|
|
||||||
|
# Map pixels to grid
|
||||||
|
pixel_indices = []
|
||||||
|
for idx, row in df.iterrows():
|
||||||
|
r, c = lat_map[row['lat']], lon_map[row['lon']]
|
||||||
|
grid_class[r, c] = final_preds[idx]
|
||||||
|
grid_low_q[r, c] = low_quality_mask[idx]
|
||||||
|
pixel_indices.append((r, c))
|
||||||
|
|
||||||
|
# Majority filter (Mode)
|
||||||
|
def mode_filter(window):
|
||||||
|
# Ignore 0 (NoData) unless the whole window is 0
|
||||||
|
valid = window[window > 0]
|
||||||
|
if valid.size == 0:
|
||||||
|
return 0
|
||||||
|
# stats.mode returns ModeResult(mode, count)
|
||||||
|
m = stats.mode(valid, keepdims=True)
|
||||||
|
return m.mode[0]
|
||||||
|
|
||||||
|
# Pass 1: Refine low-quality/gap pixels using 3x3 mode
|
||||||
|
# This fills gaps with neighboring labels
|
||||||
|
refined_grid = ndimage.generic_filter(grid_class, mode_filter, size=3)
|
||||||
|
|
||||||
|
# Only overwrite if it was low quality or a gap
|
||||||
|
grid_class = np.where(grid_low_q, refined_grid, grid_class)
|
||||||
|
|
||||||
|
# Update predictions back to dataframe
|
||||||
|
for i, (r, c) in enumerate(pixel_indices):
|
||||||
|
final_preds[i] = grid_class[r, c]
|
||||||
|
|
||||||
|
# 7. Final labels
|
||||||
df['class_id'] = final_preds
|
df['class_id'] = final_preds
|
||||||
df['predicted_crop'] = self.le.inverse_transform(final_preds)
|
df['predicted_crop'] = self.le.inverse_transform(final_preds)
|
||||||
df['confidence'] = np.max(final_probs, axis=1)
|
df['confidence'] = confidence
|
||||||
|
|
||||||
# Track missing data ratio for quality flag
|
# Ensure NoData label is assigned for any remaining 0s
|
||||||
missing_ratio = np.mean([m.mean(axis=1) for m in missing_mask.values()], axis=0)
|
df.loc[df['class_id'] == 0, 'predicted_crop'] = 'Unknown/NoData'
|
||||||
df['high_missing'] = missing_ratio > 0.4
|
|
||||||
df['low_quality'] = (df['confidence'] < 0.5) | df['high_missing'] | no_data_mask
|
|
||||||
|
|
||||||
# Set NoData (0) for low quality pixels
|
|
||||||
df.loc[df['low_quality'], 'class_id'] = 0
|
|
||||||
df.loc[df['low_quality'], 'predicted_crop'] = 'Unknown/NoData'
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
def export_to_geotiff(self, df, output_path="lulc_map.tif", crs="EPSG:6933"):
|
def export_to_geotiff(self, df, output_path="lulc_map.tif", crs="EPSG:6933"):
|
||||||
|
|
|
||||||
|
|
@ -143,9 +143,15 @@ class MinIOStorage:
|
||||||
import boto3
|
import boto3
|
||||||
from botocore.config import Config
|
from botocore.config import Config
|
||||||
|
|
||||||
|
endpoint = self.endpoint
|
||||||
|
if "://" in endpoint:
|
||||||
|
endpoint_url = endpoint
|
||||||
|
else:
|
||||||
|
endpoint_url = f"{'https' if self.secure else 'http'}://{endpoint}"
|
||||||
|
|
||||||
self._client = boto3.client(
|
self._client = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
endpoint_url=f"{'https' if self.secure else 'http'}://{self.endpoint}",
|
endpoint_url=endpoint_url,
|
||||||
aws_access_key_id=self.access_key,
|
aws_access_key_id=self.access_key,
|
||||||
aws_secret_access_key=self.secret_key,
|
aws_secret_access_key=self.secret_key,
|
||||||
region_name=self.region,
|
region_name=self.region,
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,40 @@ redis_conn = _get_redis_conn()
|
||||||
# Status Update Helpers
|
# Status Update Helpers
|
||||||
# ==========================================
|
# ==========================================
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# Distributed Locking (Idempotency Safeguard)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
LOCK_TTL = 1800 # 30 minutes - max expected job runtime
|
||||||
|
|
||||||
|
def acquire_job_lock(job_id: str, timeout: int = LOCK_TTL) -> bool:
|
||||||
|
"""Acquire distributed lock for job processing.
|
||||||
|
|
||||||
|
Prevents worker retry collisions - only one worker can process a job.
|
||||||
|
Uses Redis SETNX with TTL for safe lock acquisition.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if lock acquired, False if already locked by another worker
|
||||||
|
"""
|
||||||
|
lock_key = f"lock:job:{job_id}"
|
||||||
|
acquired = redis_conn.set(lock_key, "1", nx=True, ex=timeout)
|
||||||
|
return bool(acquired)
|
||||||
|
|
||||||
|
def release_job_lock(job_id: str) -> None:
|
||||||
|
"""Release distributed lock for job."""
|
||||||
|
lock_key = f"lock:job:{job_id}"
|
||||||
|
redis_conn.delete(lock_key)
|
||||||
|
|
||||||
|
def is_job_locked(job_id: str) -> bool:
|
||||||
|
"""Check if job is currently locked by another worker."""
|
||||||
|
lock_key = f"lock:job:{job_id}"
|
||||||
|
return redis_conn.exists(lock_key) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# Status Update Helpers (Atomic)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
def safe_now_iso() -> str:
|
def safe_now_iso() -> str:
|
||||||
"""Get current UTC time as ISO string."""
|
"""Get current UTC time as ISO string."""
|
||||||
return datetime.now(timezone.utc).isoformat()
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
@ -83,7 +117,11 @@ def update_status(
|
||||||
outputs: Optional[Dict] = None,
|
outputs: Optional[Dict] = None,
|
||||||
error: Optional[Dict] = None,
|
error: Optional[Dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update job status in Redis."""
|
"""Update job status in Redis atomically.
|
||||||
|
|
||||||
|
Uses Redis pipeline to update both the status hash and RQ job meta
|
||||||
|
in a single atomic operation.
|
||||||
|
"""
|
||||||
key = f"job:{job_id}:status"
|
key = f"job:{job_id}:status"
|
||||||
|
|
||||||
status_data = {
|
status_data = {
|
||||||
|
|
@ -101,7 +139,9 @@ def update_status(
|
||||||
status_data["error"] = error
|
status_data["error"] = error
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis_conn.set(key, json.dumps(status_data), ex=86400)
|
pipe = redis_conn.pipeline()
|
||||||
|
pipe.set(key, json.dumps(status_data), ex=86400)
|
||||||
|
|
||||||
from rq import get_current_job
|
from rq import get_current_job
|
||||||
job = get_current_job()
|
job = get_current_job()
|
||||||
if job:
|
if job:
|
||||||
|
|
@ -109,10 +149,42 @@ def update_status(
|
||||||
job.meta['stage'] = stage
|
job.meta['stage'] = stage
|
||||||
job.meta['status_message'] = message
|
job.meta['status_message'] = message
|
||||||
job.save_meta()
|
job.save_meta()
|
||||||
|
|
||||||
|
pipe.execute()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to update Redis status: {e}")
|
print(f"Warning: Failed to update Redis status: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def check_existing_outputs(job_id: str, storage) -> Optional[Dict[str, str]]:
|
||||||
|
"""Check if job outputs already exist in MinIO (skip reprocessing).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_id: Job ID
|
||||||
|
storage: MinIOStorage instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict of output URLs if all expected outputs exist, None otherwise.
|
||||||
|
"""
|
||||||
|
expected_files = [
|
||||||
|
"refined_url",
|
||||||
|
"refined_confidence_url",
|
||||||
|
"refined_cloud_mask_url",
|
||||||
|
"refined_legend_url",
|
||||||
|
]
|
||||||
|
|
||||||
|
outputs = {}
|
||||||
|
for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]:
|
||||||
|
key = f"results/{job_id}/{filename}"
|
||||||
|
try:
|
||||||
|
if storage.head_object(storage.bucket_results, key):
|
||||||
|
url_key = filename.replace(".", "_url")
|
||||||
|
outputs[url_key] = storage.presign_get(storage.bucket_results, key)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return outputs if len(outputs) >= 3 else None
|
||||||
|
|
||||||
|
|
||||||
def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func):
|
def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func):
|
||||||
"""Check if DW baseline is ready and send to client."""
|
"""Check if DW baseline is ready and send to client."""
|
||||||
if dw_future is None:
|
if dw_future is None:
|
||||||
|
|
@ -241,7 +313,12 @@ def _load_dw_async(storage, bbox, year, season) -> Optional[Tuple[np.ndarray, di
|
||||||
# ==========================================
|
# ==========================================
|
||||||
|
|
||||||
def run_job(payload_dict: dict) -> dict:
|
def run_job(payload_dict: dict) -> dict:
|
||||||
"""Main job runner with async DW baseline loading.
|
"""Main job runner with idempotency safeguards and async DW baseline loading.
|
||||||
|
|
||||||
|
Safeguards:
|
||||||
|
- Distributed locking prevents worker retry collisions
|
||||||
|
- Checks existing outputs to skip reprocessing
|
||||||
|
- DW baseline loads in background while hybrid inference runs
|
||||||
|
|
||||||
DW baseline loads in background while hybrid inference runs.
|
DW baseline loads in background while hybrid inference runs.
|
||||||
DW URL is sent to client as soon as it's ready, parallel to inference.
|
DW URL is sent to client as soon as it's ready, parallel to inference.
|
||||||
|
|
@ -263,6 +340,35 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
if "model_name" in payload_dict and "model" not in payload_dict:
|
if "model_name" in payload_dict and "model" not in payload_dict:
|
||||||
payload_dict["model"] = payload_dict["model_name"]
|
payload_dict["model"] = payload_dict["model_name"]
|
||||||
|
|
||||||
|
if "season" not in payload_dict:
|
||||||
|
payload_dict["season"] = "summer"
|
||||||
|
|
||||||
|
if not acquire_job_lock(job_id):
|
||||||
|
return {
|
||||||
|
"status": "skipped",
|
||||||
|
"job_id": job_id,
|
||||||
|
"message": "Job already being processed by another worker"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
return _execute_inference_job(job_id, payload_dict)
|
||||||
|
except Exception as e:
|
||||||
|
error_trace = traceback.format_exc()
|
||||||
|
print(f"[{job_id}] Worker crashed: {e}")
|
||||||
|
print(error_trace)
|
||||||
|
update_status(
|
||||||
|
job_id, "failed", "error", 0,
|
||||||
|
f"Worker crashed: {str(e)}",
|
||||||
|
error={"type": type(e).__name__, "message": str(e), "trace": error_trace}
|
||||||
|
)
|
||||||
|
return {"status": "failed", "error": str(e), "job_id": job_id}
|
||||||
|
finally:
|
||||||
|
release_job_lock(job_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_inference_job(job_id: str, payload_dict: dict) -> dict:
|
||||||
|
|
||||||
# Initialize storage
|
# Initialize storage
|
||||||
try:
|
try:
|
||||||
from storage import MinIOStorage
|
from storage import MinIOStorage
|
||||||
|
|
@ -284,6 +390,21 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
)
|
)
|
||||||
return {"status": "failed", "errors": errors}
|
return {"status": "failed", "errors": errors}
|
||||||
|
|
||||||
|
existing_outputs = check_existing_outputs(job_id, storage)
|
||||||
|
if existing_outputs:
|
||||||
|
update_status(
|
||||||
|
job_id, "complete", "cached", 100,
|
||||||
|
"Results already exist, skipping reprocessing",
|
||||||
|
outputs=existing_outputs
|
||||||
|
)
|
||||||
|
print(f"[{job_id}] Found existing outputs in MinIO, skipping inference")
|
||||||
|
return {
|
||||||
|
"status": "cached",
|
||||||
|
"job_id": job_id,
|
||||||
|
"outputs": existing_outputs,
|
||||||
|
"cached": True
|
||||||
|
}
|
||||||
|
|
||||||
update_status(job_id, "running", "init", 5, "Starting inference pipeline...")
|
update_status(job_id, "running", "init", 5, "Starting inference pipeline...")
|
||||||
|
|
||||||
dw_baseline_url = None
|
dw_baseline_url = None
|
||||||
|
|
@ -320,19 +441,30 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
# ==========================================
|
# ==========================================
|
||||||
update_status(job_id, "running", "load_model", 20, "Loading model artifacts...")
|
update_status(job_id, "running", "load_model", 20, "Loading model artifacts...")
|
||||||
|
|
||||||
model_dir = Path(tempfile.mkdtemp())
|
# Use persistent model cache
|
||||||
print(f"[{job_id}] Downloading model artifacts...")
|
model_dir = Path("/app/model_cache")
|
||||||
|
model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"[{job_id}] Model cache directory: {model_dir}")
|
||||||
|
|
||||||
# Download model artifacts
|
# Download model artifacts if missing
|
||||||
for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]:
|
for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]:
|
||||||
|
target_path = model_dir / artifact
|
||||||
|
if target_path.exists():
|
||||||
|
print(f"[{job_id}] Found {artifact} in cache, skipping download")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"[{job_id}] Downloading {artifact} to cache...")
|
||||||
try:
|
try:
|
||||||
storage.download_file(storage.bucket_models, artifact, model_dir / artifact)
|
storage.download_file(storage.bucket_models, artifact, target_path)
|
||||||
print(f"[{job_id}] Downloaded {artifact}")
|
print(f"[{job_id}] Downloaded {artifact}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
try:
|
try:
|
||||||
storage.download_file(storage.bucket_models, f"models/{artifact}", model_dir / artifact)
|
storage.download_file(storage.bucket_models, f"models/{artifact}", target_path)
|
||||||
print(f"[{job_id}] Downloaded {artifact} (from models/ prefix)")
|
print(f"[{job_id}] Downloaded {artifact} (from models/ prefix)")
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
|
# Clean up failed download to prevent corrupted cache
|
||||||
|
if target_path.exists():
|
||||||
|
target_path.unlink()
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}"
|
f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}"
|
||||||
)
|
)
|
||||||
|
|
@ -415,19 +547,10 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
storage.upload_result(dw_temp_path, dw_key)
|
storage.upload_result(dw_temp_path, dw_key)
|
||||||
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
|
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
|
||||||
|
|
||||||
# Wait for DW if still running
|
# DW is optional - if it hasn't finished yet and inference is done, continue without it
|
||||||
|
# The dw_baseline_url is a nice-to-have, not a blocking requirement
|
||||||
if dw_baseline_url is None:
|
if dw_baseline_url is None:
|
||||||
print(f"[{job_id}] Waiting for DW baseline to finish...")
|
print(f"[{job_id}] DW baseline not ready after inference, continuing without it")
|
||||||
dw_result = dw_future.result(timeout=60)
|
|
||||||
if dw_result is not None:
|
|
||||||
dw_arr, dw_profile = dw_result
|
|
||||||
import rasterio
|
|
||||||
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
|
|
||||||
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
|
|
||||||
dst.write(dw_arr)
|
|
||||||
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
|
|
||||||
storage.upload_result(dw_temp_path, dw_key)
|
|
||||||
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
|
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
# Final Status
|
# Final Status
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,111 @@
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# API Configuration
|
||||||
|
API_URL = "http://localhost:8000"
|
||||||
|
# Admin user credentials from main.py
|
||||||
|
LOGIN_DATA = {
|
||||||
|
"username": "fchinembiri24@gmail.com",
|
||||||
|
"password": "P@55w0rd.123"
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_inference():
|
||||||
|
print("=== GeoCrop Async Inference End-to-End Test ===")
|
||||||
|
|
||||||
|
# 1. Login to get token
|
||||||
|
print("\n1. Logging in...")
|
||||||
|
try:
|
||||||
|
response = requests.post(f"{API_URL}/auth/login", data=LOGIN_DATA)
|
||||||
|
response.raise_for_status()
|
||||||
|
token = response.json()["access_token"]
|
||||||
|
print("✓ Login successful")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Login failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
# 2. Submit Inference Job
|
||||||
|
# Coordinates in Zimbabwe (Agricultural area near Mazowe)
|
||||||
|
payload = {
|
||||||
|
"lat": -17.51,
|
||||||
|
"lon": 30.91,
|
||||||
|
"radius_km": 1.0,
|
||||||
|
"year": "2022",
|
||||||
|
"model_name": "Hybrid_SpatioTemporal"
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"\n2. Submitting inference job for AOI ({payload['lat']}, {payload['lon']})...")
|
||||||
|
try:
|
||||||
|
response = requests.post(f"{API_URL}/jobs", json=payload, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
job_data = response.json()
|
||||||
|
job_id = job_data["job_id"]
|
||||||
|
status = job_data["status"]
|
||||||
|
print(f"✓ Job submitted. ID: {job_id}, Initial Status: {status}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Job submission failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 3. Poll for Status
|
||||||
|
print("\n3. Polling for job status (every 2s)...")
|
||||||
|
last_status = None
|
||||||
|
last_stage = None
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
timeout = 600 # 10 minutes
|
||||||
|
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{API_URL}/jobs/{job_id}", headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
current_status = data.get("status")
|
||||||
|
worker_status = data.get("worker_status")
|
||||||
|
stage = data.get("stage")
|
||||||
|
progress = data.get("progress", 0)
|
||||||
|
message = data.get("message", "")
|
||||||
|
|
||||||
|
status_str = f"Status: {current_status}"
|
||||||
|
if worker_status: status_str += f" | Worker: {worker_status}"
|
||||||
|
if stage: status_str += f" | Stage: {stage}"
|
||||||
|
if progress: status_str += f" | Progress: {progress}%"
|
||||||
|
|
||||||
|
if status_str != last_status:
|
||||||
|
print(f"[{time.strftime('%H:%M:%S')}] {status_str}")
|
||||||
|
if message: print(f" Message: {message}")
|
||||||
|
last_status = status_str
|
||||||
|
|
||||||
|
if current_status == "finished":
|
||||||
|
print("\n✓ Inference Job Completed Successfully!")
|
||||||
|
print("\n=== Final Results ===")
|
||||||
|
result = data.get("result")
|
||||||
|
print(json.dumps(result, indent=2))
|
||||||
|
|
||||||
|
# Print specific LULC statistics if available in detailed status
|
||||||
|
detailed = data.get("detailed")
|
||||||
|
if detailed and "outputs" in detailed:
|
||||||
|
print("\nOutput Artifacts:")
|
||||||
|
for k, v in detailed["outputs"].items():
|
||||||
|
print(f" - {k}: {v[:80]}...")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
if current_status == "failed":
|
||||||
|
print("\n✗ Job Failed!")
|
||||||
|
print(f"Error: {data.get('error')}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
time.sleep(2)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Polling error: {e}")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
print("\n✗ Test timed out after 10 minutes.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_inference()
|
||||||
Loading…
Reference in New Issue