Compare commits
7 Commits
76a5d155d7
...
e2cfec586b
| Author | SHA1 | Date |
|---|---|---|
|
|
e2cfec586b | |
|
|
18aa966dc8 | |
|
|
5cbda32e1e | |
|
|
c2cc58d7ce | |
|
|
44b9220369 | |
|
|
609f9c5892 | |
|
|
a406a28a13 |
|
|
@ -0,0 +1 @@
|
||||||
|
Aider MCP Integration Verified
|
||||||
|
|
@ -231,4 +231,7 @@ async def get_job_status(job_id: str, current_user: dict = Depends(get_current_u
|
||||||
"progress": detailed_status.get("progress"),
|
"progress": detailed_status.get("progress"),
|
||||||
"message": detailed_status.get("message"),
|
"message": detailed_status.get("message"),
|
||||||
})
|
})
|
||||||
|
# Include intermediate outputs (e.g., dw_baseline_url) even if job not finished yet
|
||||||
|
if "outputs" in detailed_status:
|
||||||
|
response["outputs"] = detailed_status["outputs"]
|
||||||
return response
|
return response
|
||||||
|
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -8,6 +8,7 @@ This module provides:
|
||||||
- Harmonic/Fourier features
|
- Harmonic/Fourier features
|
||||||
- Index computations (NDVI, NDRE, EVI, SAVI, CI_RE, NDWI)
|
- Index computations (NDVI, NDRE, EVI, SAVI, CI_RE, NDWI)
|
||||||
- Per-pixel feature builder
|
- Per-pixel feature builder
|
||||||
|
- Gap handling for temporal and spatial missing data
|
||||||
|
|
||||||
NOTE: Seasonal window summaries come in Step 4B.
|
NOTE: Seasonal window summaries come in Step 4B.
|
||||||
"""
|
"""
|
||||||
|
|
@ -15,7 +16,7 @@ NOTE: Seasonal window summaries come in Step 4B.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
@ -140,6 +141,195 @@ def smooth_series(y: np.ndarray) -> np.ndarray:
|
||||||
return savgol_smooth_1d(y_filled, window=5, polyorder=2)
|
return savgol_smooth_1d(y_filled, window=5, polyorder=2)
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# Gap Handling for Missing Data
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
def handle_temporal_gaps(y: np.ndarray, gap_threshold: int = 3) -> np.ndarray:
    """Interpolate short NaN runs in a 1D series, keeping long runs as NaN.

    Zeros lying between the first and last non-zero samples are treated as
    missing observations and converted to NaN first. Runs of consecutive NaNs
    shorter than ``gap_threshold`` are then filled by linear interpolation
    (edge runs take the nearest valid value, as ``np.interp`` does), while
    runs of length >= ``gap_threshold`` are left as NaN.

    Args:
        y: 1D time series (may contain NaN values). The input is not modified.
        gap_threshold: Minimum number of consecutive NaNs that counts as a
            "significant gap" and is therefore NOT filled.

    Returns:
        A new float64 array with small gaps filled and large gaps preserved
        as NaN. An empty input is returned as an empty array; a series with
        no valid samples is returned as all-NaN.
    """
    y = np.array(y, dtype=np.float64).copy()
    n = len(y)

    if n == 0:
        return y

    # Interior zeros are considered missing. NaN compares unequal to 0, so NaN
    # samples count as "non-zero" when locating the interior span.
    zero_mask = y == 0
    nonzero_idx = np.flatnonzero(~zero_mask)
    if nonzero_idx.size:
        interior_zeros = zero_mask.copy()
        interior_zeros[: nonzero_idx[0]] = False
        interior_zeros[nonzero_idx[-1] + 1:] = False
        y[interior_zeros] = np.nan

    nan_mask = np.isnan(y)
    valid_mask = ~nan_mask
    if not valid_mask.any():
        return y  # All NaN: nothing to interpolate from

    # Locate NaN runs in a single pass: pad with False on both sides so every
    # run (including ones touching the array ends) has a rising and a falling
    # edge. ``run_ends`` are exclusive indices.
    padded = np.concatenate(([False], nan_mask, [False]))
    edges = np.diff(padded.astype(np.int8))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)

    # Runs at least ``gap_threshold`` long must stay NaN after interpolation.
    large_gap_mask = np.zeros(n, dtype=bool)
    for start, end in zip(run_starts, run_ends):
        if end - start >= gap_threshold:
            large_gap_mask[start:end] = True

    # Interpolate everything, then restore the large gaps as NaN.
    x = np.arange(n)
    y_interp = np.interp(x, x[valid_mask], y[valid_mask])
    y_interp[large_gap_mask] = np.nan

    return y_interp
|
||||||
|
|
||||||
|
|
||||||
|
def spatial_fill_nan(data_2d: np.ndarray, max_iterations: int = 3) -> np.ndarray:
    """Fill NaN cells of a 2D raster from their 4-connected neighbours.

    Each pass scans the raster in row-major order and replaces every NaN cell
    that has at least one non-NaN neighbour (up, down, left, right) with the
    median of those neighbours. Fills are applied in place during the scan,
    so a cell filled earlier in a pass can feed cells visited later in the
    same pass. Passes stop early once a pass fills nothing. Any NaN left
    after the passes is replaced by the NaN-aware median of the whole array.

    Args:
        data_2d: 2D numpy array (H, W) with possible NaN values. Not modified.
        max_iterations: Maximum number of passes (more passes reach further
            into large NaN regions).

    Returns:
        A copy of the input with NaN values filled.
    """
    filled = data_2d.copy()
    n_rows, n_cols = filled.shape

    if not np.isnan(filled).any():
        return filled  # Nothing to fill

    # Up, down, left, right.
    neighbour_offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))

    for _ in range(max_iterations):
        filled_any = False

        # Cells that are NaN at the start of the pass, in row-major order;
        # a cell can only stop being NaN when it is itself visited.
        for r, c in np.argwhere(np.isnan(filled)):
            candidates = []
            for dr, dc in neighbour_offsets:
                rr, cc = r + dr, c + dc
                if 0 <= rr < n_rows and 0 <= cc < n_cols:
                    value = filled[rr, cc]
                    if not np.isnan(value):
                        candidates.append(value)

            if candidates:
                filled[r, c] = np.median(candidates)
                filled_any = True

        if not filled_any:
            break  # A full pass changed nothing

    # Whatever is still NaN gets the global median of the valid cells.
    still_nan = np.isnan(filled)
    if still_nan.any():
        filled = np.where(still_nan, np.nanmedian(filled), filled)

    return filled
|
||||||
|
|
||||||
|
|
||||||
|
def compute_gap_mask(y: np.ndarray, gap_threshold: int = 3) -> np.ndarray:
    """Flag the samples that belong to long runs of consecutive NaNs.

    Args:
        y: 1D time series (may contain NaN values).
        gap_threshold: Minimum run length of consecutive NaNs for the run to
            count as a "significant gap".

    Returns:
        Boolean array of the same length as ``y``; True marks every sample
        inside a NaN run whose length is >= ``gap_threshold``.
    """
    series = np.array(y, dtype=np.float64)
    is_missing = np.isnan(series)
    significant = np.zeros(len(series), dtype=bool)

    # Pad with False so runs touching either end still produce both a rising
    # edge (run start) and a falling edge (exclusive run end).
    padded = np.concatenate(([False], is_missing, [False]))
    transitions = np.diff(padded.astype(np.int8))
    starts = np.flatnonzero(transitions == 1)
    stops = np.flatnonzero(transitions == -1)

    for begin, stop in zip(starts, stops):
        if stop - begin >= gap_threshold:
            significant[begin:stop] = True

    return significant
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
# Index Computations
|
# Index Computations
|
||||||
# ==========================================
|
# ==========================================
|
||||||
|
|
@ -327,7 +517,7 @@ def build_features_for_pixel(
|
||||||
Args:
|
Args:
|
||||||
ts: Dict of index name -> 1D array time series
|
ts: Dict of index name -> 1D array time series
|
||||||
Keys: "ndvi", "ndre", "evi", "savi", "ci_re", "ndwi"
|
Keys: "ndvi", "ndre", "evi", "savi", "ci_re", "ndwi"
|
||||||
step_days: Days between observations
|
step_days: Days between observations (for AUC calculation)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with ONLY scalar computed features (no arrays):
|
Dict with ONLY scalar computed features (no arrays):
|
||||||
|
|
@ -602,6 +792,7 @@ def build_features_v2_for_pixel(
|
||||||
"""
|
"""
|
||||||
from scipy.stats import skew, kurtosis
|
from scipy.stats import skew, kurtosis
|
||||||
from scipy.integrate import simpson
|
from scipy.integrate import simpson
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
features = {}
|
features = {}
|
||||||
dt_dates = pd.to_datetime(dates, format='%Y%m%d')
|
dt_dates = pd.to_datetime(dates, format='%Y%m%d')
|
||||||
|
|
@ -838,6 +1029,29 @@ if __name__ == "__main__":
|
||||||
assert len(features) == 51, f"Expected 51 features in dict, got {len(features)}"
|
assert len(features) == 51, f"Expected 51 features in dict, got {len(features)}"
|
||||||
assert vector.shape == (51,), f"Expected shape (51,), got {vector.shape}"
|
assert vector.shape == (51,), f"Expected shape (51,), got {vector.shape}"
|
||||||
|
|
||||||
|
print("\n8. Testing gap handling functions...")
|
||||||
|
|
||||||
|
# Create time series with gaps
|
||||||
|
gap_series = np.array([0.5, 0.6, np.nan, np.nan, np.nan, 0.7, 0.8, np.nan, 0.9, 0.4])
|
||||||
|
|
||||||
|
# Test handle_temporal_gaps with threshold=3
|
||||||
|
filled_series = handle_temporal_gaps(gap_series, gap_threshold=3)
|
||||||
|
print(f" Original: {gap_series}")
|
||||||
|
print(f" After gap handling (threshold=3): {filled_series}")
|
||||||
|
|
||||||
|
# Test compute_gap_mask
|
||||||
|
gap_mask = compute_gap_mask(gap_series, gap_threshold=3)
|
||||||
|
print(f" Gap mask (threshold=3): {gap_mask}")
|
||||||
|
|
||||||
|
# Test spatial_fill_nan
|
||||||
|
spatial_arr = np.array([[0.5, 0.6, np.nan, 0.8],
|
||||||
|
[0.7, np.nan, 0.9, 0.4],
|
||||||
|
[np.nan, 0.3, 0.2, np.nan],
|
||||||
|
[0.1, 0.2, 0.3, 0.4]])
|
||||||
|
filled_spatial = spatial_fill_nan(spatial_arr, max_iterations=2)
|
||||||
|
print(f" Original spatial (2D):\n{spatial_arr}")
|
||||||
|
print(f" After spatial fill:\n{filled_spatial}")
|
||||||
|
|
||||||
print("\n=== STEP 4B All Tests Passed ===")
|
print("\n=== STEP 4B All Tests Passed ===")
|
||||||
print(f" Total features: {len(features)}")
|
print(f" Total features: {len(features)}")
|
||||||
print(f" Feature order length: {len(FEATURE_ORDER_V1)}")
|
print(f" Feature order length: {len(FEATURE_ORDER_V1)}")
|
||||||
|
|
|
||||||
|
|
@ -207,14 +207,52 @@ class CropInferencePipeline:
|
||||||
|
|
||||||
def _impute_inference_data(self, df):
|
def _impute_inference_data(self, df):
|
||||||
print("Imputing cloudy/missing timesteps via temporal interpolation...")
|
print("Imputing cloudy/missing timesteps via temporal interpolation...")
|
||||||
|
from feature_computation import handle_temporal_gaps, spatial_fill_nan
|
||||||
|
|
||||||
df = df.copy()
|
df = df.copy()
|
||||||
missing_mask = {}
|
missing_mask = {}
|
||||||
|
|
||||||
|
# Track original NaNs before any imputation
|
||||||
for band in self.bands:
|
for band in self.bands:
|
||||||
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||||
if band_cols:
|
if band_cols:
|
||||||
missing_mask[band] = df[band_cols].isna().astype(float)
|
missing_mask[band] = df[band_cols].isna().astype(float)
|
||||||
|
|
||||||
|
# Process each band: apply handle_temporal_gaps per pixel for each band
|
||||||
|
for band in self.bands:
|
||||||
|
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||||
|
if band_cols:
|
||||||
|
print(f" Processing band {band} with gap handling...")
|
||||||
|
|
||||||
|
# For each pixel, apply handle_temporal_gaps to the time series
|
||||||
|
for idx in range(len(df)):
|
||||||
|
time_series = df[band_cols].iloc[idx].values.astype(np.float64)
|
||||||
|
|
||||||
|
# Apply handle_temporal_gaps: gaps >= 3 will result in NaNs for those timesteps
|
||||||
|
time_series = handle_temporal_gaps(time_series, gap_threshold=3)
|
||||||
|
df.loc[df.index[idx], band_cols] = time_series
|
||||||
|
|
||||||
|
# After gap handling, fill remaining NaNs with linear interpolation
|
||||||
df[band_cols] = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
|
df[band_cols] = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
|
||||||
df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0)
|
df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0)
|
||||||
|
|
||||||
|
# Apply spatial fill to each band using spatial_fill_nan
|
||||||
|
# Reshape to (num_dates, num_pixels) for each band, apply spatial fill
|
||||||
|
for band in self.bands:
|
||||||
|
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||||
|
if band_cols:
|
||||||
|
print(f" Applying spatial fill for band {band}...")
|
||||||
|
|
||||||
|
# Transpose to (T, H*W) for spatial filling
|
||||||
|
band_data = df[band_cols].values.T # Shape: (num_dates, num_pixels)
|
||||||
|
|
||||||
|
# Apply spatial_fill_nan per time step
|
||||||
|
for t_idx in range(band_data.shape[0]):
|
||||||
|
band_data[t_idx] = spatial_fill_nan(band_data[t_idx].reshape(-1, 1)).squeeze()
|
||||||
|
|
||||||
|
# Put back into dataframe
|
||||||
|
df[band_cols] = band_data.T
|
||||||
|
|
||||||
return df, missing_mask
|
return df, missing_mask
|
||||||
|
|
||||||
def predict(self, raw_df, apply_spatial_smoothing=False, coord_cols=['lat', 'lon']):
|
def predict(self, raw_df, apply_spatial_smoothing=False, coord_cols=['lat', 'lon']):
|
||||||
|
|
@ -240,6 +278,22 @@ class CropInferencePipeline:
|
||||||
final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
|
final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
|
||||||
final_preds = np.argmax(final_probs, axis=1)
|
final_preds = np.argmax(final_probs, axis=1)
|
||||||
|
|
||||||
|
# Identify No Data pixels: those with all NaNs or zeros after imputation
|
||||||
|
no_data_mask = np.zeros(len(df), dtype=bool)
|
||||||
|
for band in self.bands:
|
||||||
|
band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
|
||||||
|
if band_cols:
|
||||||
|
band_data = df[band_cols].values
|
||||||
|
# Check if pixel is all zeros or all NaN for this band
|
||||||
|
all_zeros = np.all(band_data == 0, axis=1)
|
||||||
|
all_nan = np.all(np.isnan(band_data), axis=1)
|
||||||
|
no_data_mask = no_data_mask | all_zeros | all_nan
|
||||||
|
|
||||||
|
# Override predictions for No Data pixels to class 0 (Background/No Data)
|
||||||
|
final_preds[no_data_mask] = 0
|
||||||
|
final_probs[no_data_mask] = 0.0
|
||||||
|
final_probs[no_data_mask, 0] = 1.0 # Set probability to 1.0 for class 0
|
||||||
|
|
||||||
if apply_spatial_smoothing and all(col in df.columns for col in coord_cols):
|
if apply_spatial_smoothing and all(col in df.columns for col in coord_cols):
|
||||||
print(f"Applying spatial probability smoothing using {coord_cols}...")
|
print(f"Applying spatial probability smoothing using {coord_cols}...")
|
||||||
coords = df[coord_cols].values
|
coords = df[coord_cols].values
|
||||||
|
|
@ -249,15 +303,20 @@ class CropInferencePipeline:
|
||||||
final_preds = np.argmax(smoothed_probs, axis=1)
|
final_preds = np.argmax(smoothed_probs, axis=1)
|
||||||
final_probs = smoothed_probs
|
final_probs = smoothed_probs
|
||||||
|
|
||||||
|
# Re-apply No Data override after smoothing
|
||||||
|
final_preds[no_data_mask] = 0
|
||||||
|
final_probs[no_data_mask, 0] = 1.0
|
||||||
|
|
||||||
df['class_id'] = final_preds
|
df['class_id'] = final_preds
|
||||||
df['predicted_crop'] = self.le.inverse_transform(final_preds)
|
df['predicted_crop'] = self.le.inverse_transform(final_preds)
|
||||||
df['confidence'] = np.max(final_probs, axis=1)
|
df['confidence'] = np.max(final_probs, axis=1)
|
||||||
|
|
||||||
|
# Track missing data ratio for quality flag
|
||||||
missing_ratio = np.mean([m.mean(axis=1) for m in missing_mask.values()], axis=0)
|
missing_ratio = np.mean([m.mean(axis=1) for m in missing_mask.values()], axis=0)
|
||||||
df['high_missing'] = missing_ratio > 0.4
|
df['high_missing'] = missing_ratio > 0.4
|
||||||
df['low_quality'] = (df['confidence'] < 0.5) | df['high_missing']
|
df['low_quality'] = (df['confidence'] < 0.5) | df['high_missing'] | no_data_mask
|
||||||
|
|
||||||
# Set NoData (0) for low quality
|
# Set NoData (0) for low quality pixels
|
||||||
df.loc[df['low_quality'], 'class_id'] = 0
|
df.loc[df['low_quality'], 'class_id'] = 0
|
||||||
df.loc[df['low_quality'], 'predicted_crop'] = 'Unknown/NoData'
|
df.loc[df['low_quality'], 'predicted_crop'] = 'Unknown/NoData'
|
||||||
return df
|
return df
|
||||||
|
|
|
||||||
|
|
@ -1,650 +0,0 @@
|
||||||
"""GeoCrop inference pipeline (worker-side).
|
|
||||||
|
|
||||||
This module is designed to be called by your RQ worker.
|
|
||||||
Given a job payload (AOI, year, model choice), it:
|
|
||||||
1) Loads the correct model artifact from MinIO (or local cache).
|
|
||||||
2) Loads/clips the DW baseline COG for the requested season/year.
|
|
||||||
3) Queries Digital Earth Africa STAC for imagery and builds feature stack.
|
|
||||||
- IMPORTANT: Uses exact feature engineering from train.py:
|
|
||||||
- Savitzky-Golay smoothing (window=5, polyorder=2)
|
|
||||||
- Phenology metrics (amplitude, AUC, peak, slope)
|
|
||||||
- Harmonic features (1st/2nd order sin/cos)
|
|
||||||
- Seasonal window statistics (Early/Peak/Late)
|
|
||||||
4) Runs per-pixel inference to produce refined classes at 10m.
|
|
||||||
5) Applies neighborhood smoothing (majority filter).
|
|
||||||
6) Writes output GeoTIFF (COG recommended) to MinIO.
|
|
||||||
|
|
||||||
IMPORTANT: This implementation supports the current MinIO model format:
|
|
||||||
- Zimbabwe_Ensemble_Raw_Model.pkl (no scaler needed)
|
|
||||||
- Zimbabwe_Ensemble_Model.pkl (scaler needed)
|
|
||||||
- etc.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Optional, Tuple, List
|
|
||||||
|
|
||||||
# Try to import required dependencies
|
|
||||||
try:
|
|
||||||
import joblib
|
|
||||||
except ImportError:
|
|
||||||
joblib = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
import numpy as np
|
|
||||||
except ImportError:
|
|
||||||
np = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
import rasterio
|
|
||||||
from rasterio import windows
|
|
||||||
from rasterio.enums import Resampling
|
|
||||||
except ImportError:
|
|
||||||
rasterio = None
|
|
||||||
windows = None
|
|
||||||
Resampling = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from config import InferenceConfig
|
|
||||||
except ImportError:
|
|
||||||
InferenceConfig = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from features import (
|
|
||||||
build_feature_stack_from_dea,
|
|
||||||
clip_raster_to_aoi,
|
|
||||||
load_dw_baseline_window,
|
|
||||||
majority_filter,
|
|
||||||
validate_aoi_zimbabwe,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# STEP 6: Model Loading and Raster Prediction
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
def load_model(storage, model_name: str):
    """Download a trained model via ``storage`` and load it with joblib.

    The model file is fetched into a temporary directory that is removed
    once the model has been loaded into memory. If the loaded object exposes
    ``n_features_in_``, that count must equal the length of either
    ``FEATURE_ORDER_V1`` or ``FEATURE_ORDER_V2`` from ``feature_computation``.

    Args:
        storage: Object providing ``download_model_file(model_name, dest_dir)``
            which returns the path of the downloaded file.
        model_name: Name of model (e.g., "RandomForest", "XGBoost", "Ensemble").

    Returns:
        The object returned by ``joblib.load`` for the downloaded file.

    Raises:
        ValueError: If the model's ``n_features_in_`` matches neither known
            feature order.
    """
    with tempfile.TemporaryDirectory() as scratch:
        # The storage layer resolves the model name to the actual object.
        downloaded = storage.download_model_file(model_name, Path(scratch))
        model = joblib.load(downloaded)

    # Models without a recorded feature count cannot be validated here.
    if not hasattr(model, 'n_features_in_'):
        return model

    from feature_computation import FEATURE_ORDER_V1, FEATURE_ORDER_V2

    actual_features = model.n_features_in_
    if actual_features == len(FEATURE_ORDER_V1):
        print(f"Detected V1 model ({actual_features} features)")
        return model
    if actual_features == len(FEATURE_ORDER_V2):
        print(f"Detected V2 model ({actual_features} features)")
        return model

    raise ValueError(
        f"Model feature mismatch: model expects {actual_features} features. "
        f"Available versions: V1 ({len(FEATURE_ORDER_V1)}), V2 ({len(FEATURE_ORDER_V2)})."
    )
|
|
||||||
|
|
||||||
|
|
||||||
def predict_raster(
    model,
    feature_cube: np.ndarray,
    feature_order: List[str],
) -> np.ndarray:
    """Run per-pixel inference on a feature cube.

    Pixels whose feature vector is entirely zero are treated as nodata: they
    are fed to the model as a small constant (so the model never sees an
    all-zero row) and their prediction is overwritten with 0 afterwards.

    Args:
        model: Object with a ``predict(X)`` method accepting a 2D array of
            shape (n_pixels, n_features).
        feature_cube: 3D array of shape (H, W, C) containing features.
        feature_order: List of the C feature names, in channel order.

    Returns:
        2D array of shape (H, W) with class predictions; nodata pixels are 0
        (class 0 is assumed to be reserved for nodata).

    Raises:
        ValueError: If ``feature_cube`` is not 3D or its last dimension does
            not match ``len(feature_order)``.
    """
    if feature_cube.ndim != 3:
        raise ValueError(
            f"feature_cube must be 3D (H, W, C), got shape {feature_cube.shape}."
        )

    H, W, C = feature_cube.shape
    expected_features = len(feature_order)

    if C != expected_features:
        raise ValueError(
            f"Feature dimension mismatch: feature_cube has {C} features "
            f"but feature_order has {expected_features}. "
            f"feature_cube shape: {feature_cube.shape}, feature_order length: {len(feature_order)}. "
            f"Expected {expected_features} features matching feature_order."
        )

    # Flatten spatial dimensions: (H, W, C) -> (H*W, C)
    X = feature_cube.reshape(-1, C)

    # Identify nodata pixels (all zeros)
    nodata_mask = np.all(X == 0, axis=1)
    has_nodata = bool(nodata_mask.any())

    if has_nodata:
        # Copy only when we must modify, so the caller's cube is never touched.
        # Epsilon avoids division by zero in some models; these predictions
        # are overwritten below anyway.
        X = X.copy()
        X[nodata_mask] = 1e-6

    y_pred = np.asarray(model.predict(X))

    if has_nodata:
        y_pred[nodata_mask] = 0

    return y_pred.reshape(H, W)
|
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# Legacy functions (kept for backward compatibility)
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
|
|
||||||
# Model name to MinIO filename mapping
# Format: "Zimbabwe_<ModelName>_Model.pkl" or "Zimbabwe_<ModelName>_Raw_Model.pkl"
# Keys are the model names accepted in a job payload; values are the object
# filenames looked up by get_model_filename().
MODEL_NAME_MAPPING = {
    # Ensemble models
    # Plain "Ensemble" resolves to the same file as "Ensemble_Raw".
    "Ensemble": "Zimbabwe_Ensemble_Raw_Model.pkl",
    "Ensemble_Raw": "Zimbabwe_Ensemble_Raw_Model.pkl",
    "Ensemble_Scaled": "Zimbabwe_Ensemble_Model.pkl",

    # Individual models
    "RandomForest": "Zimbabwe_RandomForest_Model.pkl",
    "XGBoost": "Zimbabwe_XGBoost_Model.pkl",
    "LightGBM": "Zimbabwe_LightGBM_Model.pkl",
    "CatBoost": "Zimbabwe_CatBoost_Model.pkl",

    # Legacy/raw variants
    # These "_Raw" names map to the same files as the plain individual models.
    "RandomForest_Raw": "Zimbabwe_RandomForest_Model.pkl",
    "XGBoost_Raw": "Zimbabwe_XGBoost_Model.pkl",
    "LightGBM_Raw": "Zimbabwe_LightGBM_Model.pkl",
    "CatBoost_Raw": "Zimbabwe_CatBoost_Model.pkl",
}

# Default class mapping if label encoder not available
# Based on typical Zimbabwe crop classification
# NOTE(review): the default LabelEncoder is fitted on this list, so the class
# ids it produces may not follow the list order — confirm this matches the
# label ids used at training time.
DEFAULT_CLASSES = [
    "cropland_rainfed",
    "cropland_irrigated",
    "tree_crop",
    "grassland",
    "shrubland",
    "urban",
    "water",
    "bare",
]
|
|
||||||
|
|
||||||
|
|
||||||
# Result record returned by run_inference_job() for one inference job.
@dataclass
class InferenceResult:
    # Identifier of the job this result belongs to.
    job_id: str
    # Status string for the job (set of allowed values not visible here — TODO confirm).
    status: str
    # String-to-string map of produced outputs (presumably output name -> location; verify against caller).
    outputs: Dict[str, str]
    # Free-form metadata dictionary; schema is not defined in this block.
    meta: Dict
|
|
||||||
|
|
||||||
|
|
||||||
def _local_artifact_cache_dir() -> Path:
|
|
||||||
d = Path(os.getenv("GEOCROP_CACHE_DIR", "/tmp/geocrop-cache"))
|
|
||||||
d.mkdir(parents=True, exist_ok=True)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_filename(model_name: str) -> str:
    """Get the MinIO filename for a given model name.

    Resolution order:
      1. Exact key match in ``MODEL_NAME_MAPPING``.
      2. Case-insensitive key match in ``MODEL_NAME_MAPPING``.
      3. A generated fallback name: ``Zimbabwe_<Name>_Raw_Model.pkl`` when
         the name contains ``_raw`` (any case), otherwise
         ``Zimbabwe_<Name>_Model.pkl``, where ``<Name>`` is title-cased.

    Args:
        model_name: Model name from job payload (e.g., "Ensemble", "Ensemble_Scaled")

    Returns:
        MinIO filename (e.g., "Zimbabwe_Ensemble_Raw_Model.pkl")
    """
    import re

    # Direct lookup
    if model_name in MODEL_NAME_MAPPING:
        return MODEL_NAME_MAPPING[model_name]

    # Try case-insensitive
    model_lower = model_name.lower()
    for key, value in MODEL_NAME_MAPPING.items():
        if key.lower() == model_lower:
            return value

    # Default fallback
    if "_raw" in model_lower:
        # Strip the marker with the same case-insensitivity used to detect
        # it; a case-sensitive strip would leave e.g. "foo_raw" intact and
        # yield a doubled "..._Raw_Raw_Model.pkl".
        base_name = re.sub(r"_raw", "", model_name, flags=re.IGNORECASE)
        return f"Zimbabwe_{base_name.title()}_Raw_Model.pkl"
    return f"Zimbabwe_{model_name.title()}_Model.pkl"
|
|
||||||
|
|
||||||
|
|
||||||
def needs_scaler(model_name: str) -> bool:
    """Tell whether features must be scaled before using this model.

    Names containing "_raw" (any case) are raw variants and are not scaled.
    The bare name "ensemble" (any case) is also treated as raw. Every other
    name requires a scaler.

    Args:
        model_name: Model name from job payload

    Returns:
        True if scaler should be applied
    """
    lowered = model_name.lower()
    is_raw_variant = "_raw" in lowered
    is_default_ensemble = lowered == "ensemble"
    return not (is_raw_variant or is_default_ensemble)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_artifacts(cfg: InferenceConfig, model_name: str) -> Tuple[object, object, Optional[object], List[str]]:
    """Load model, label encoder, optional scaler, and feature list.

    Artifacts are read from a per-model directory under the local artifact
    cache. If ``model.pkl`` is not cached yet, the bundle is downloaded via
    ``cfg.storage.download_model_bundle`` first.

    Supports current MinIO format:
    - Zimbabwe_*_Raw_Model.pkl (no scaler)
    - Zimbabwe_*_Model.pkl (needs scaler)

    Args:
        cfg: Inference configuration; must expose ``storage`` with a
            ``download_model_bundle(key, dest_dir)`` method.
        model_name: Name of the model to load

    Returns:
        Tuple of (model, label_encoder, scaler, selected_features).
        ``scaler`` is None when ``needs_scaler(model_name)`` is False;
        ``selected_features`` is None when no feature list file is cached.
    """
    # One cache sub-directory per model name (spaces replaced by underscores).
    cache = _local_artifact_cache_dir() / model_name.replace(" ", "_")
    cache.mkdir(parents=True, exist_ok=True)

    # Get the MinIO filename
    model_filename = get_model_filename(model_name)
    model_key = f"models/{model_filename}"  # Prefix in bucket

    # Fixed filenames expected inside the cache directory.
    model_p = cache / "model.pkl"
    le_p = cache / "label_encoder.pkl"
    scaler_p = cache / "scaler.pkl"
    feats_p = cache / "selected_features.json"

    # Check if cached
    # Only the model file's presence is checked; the other artifacts are
    # treated as optional below.
    if not model_p.exists():
        print(f"📥 Downloading model from MinIO: {model_key}")
        cfg.storage.download_model_bundle(model_key, cache)

    # Load model
    model = joblib.load(model_p)

    # Load or create label encoder
    if le_p.exists():
        label_encoder = joblib.load(le_p)
    else:
        # Fall back to an encoder fitted on DEFAULT_CLASSES.
        print("⚠️ Label encoder not found, creating default")
        from sklearn.preprocessing import LabelEncoder
        label_encoder = LabelEncoder()
        # Fit on default classes
        label_encoder.fit(DEFAULT_CLASSES)

    # Load scaler if needed
    scaler = None
    if needs_scaler(model_name):
        if scaler_p.exists():
            scaler = joblib.load(scaler_p)
        else:
            print("⚠️ Scaler not found but required for this model variant")
            # Placeholder scaler so callers receive a non-None object.
            # NOTE(review): this StandardScaler is never fitted here, so it is
            # not a pass-through; sklearn estimators normally raise when used
            # unfitted — confirm how callers handle this path.
            from sklearn.preprocessing import StandardScaler
            scaler = StandardScaler()
            # Note: In production, this should fail - scaler must be uploaded

    # Load selected features
    if feats_p.exists():
        selected_features = json.loads(feats_p.read_text())
    else:
        print("⚠️ Selected features not found, will use all computed features")
        selected_features = None

    return model, label_encoder, scaler, selected_features
|
|
||||||
|
|
||||||
|
|
||||||
def run_inference_job(cfg: InferenceConfig, job: Dict) -> InferenceResult:
    """Run one inference job end to end and upload its rasters.

    Pipeline: validate the AOI, resolve the season window, load the model
    artifacts, load the Dynamic World (DW) baseline window, build the feature
    stack, predict per pixel, optionally apply a majority filter, then write
    and upload the refined raster, the DW baseline and any auxiliary layers.

    Args:
        cfg: Inference configuration. This function reads ``max_radius_m``,
            ``season_dates``, ``smoothing_kernel``, ``smoothing_enabled`` and
            ``storage.upload_result`` from it.
        job: Job payload, for example::

            {
                "job_id": "...",
                "user_id": "...",
                "lat": -17.8,
                "lon": 31.0,
                "radius_m": 2000,
                "year": 2022,
                "season": "summer",
                "model": "Ensemble"  # or "Ensemble_Scaled", "RandomForest", etc.
            }

            ``smoothing_kernel`` may also be supplied to override the
            configured kernel size.

    Returns:
        An ``InferenceResult`` with status ``"done"``, the uploaded output
        URIs and a ``meta`` dict (classes, class index legend, feature names,
        smoothing settings).

    Raises:
        KeyError: If ``lon``, ``lat``, ``radius_m`` or ``year`` is missing.
        ValueError: If a numeric payload field cannot be converted.
        Exception: Errors from validation, data loading, prediction or
            upload helpers are not caught here.
    """
    job_id = str(job.get("job_id"))

    # 1) Validate AOI constraints
    aoi = (float(job["lon"]), float(job["lat"]), float(job["radius_m"]))
    validate_aoi_zimbabwe(aoi, max_radius_m=cfg.max_radius_m)

    year = int(job["year"])
    season = str(job.get("season", "summer")).lower()

    # Training window (Sep -> May)
    start_date, end_date = cfg.season_dates(year=year, season=season)

    model_name = str(job.get("model", "Ensemble"))
    print(f"🤖 Loading model: {model_name}")

    model, le, scaler, selected_features = load_model_artifacts(cfg, model_name)

    # Scaling is applied only when the variant asks for it AND a scaler exists.
    use_scaler = scaler is not None and needs_scaler(model_name)
    print(f" Scaler required: {use_scaler}")

    # 2) Load DW baseline for this year/season (already converted to COGs).
    # This also provides the "DW baseline toggle" layer.
    dw_arr, dw_profile = load_dw_baseline_window(
        cfg=cfg,
        year=year,
        season=season,
        aoi=aoi,
    )

    # 3) Build EO feature stack from DEA STAC
    # IMPORTANT: This uses full feature engineering matching train.py
    print("📡 Building feature stack from DEA STAC...")
    feat_arr, feat_profile, feat_names, aux_layers = build_feature_stack_from_dea(
        cfg=cfg,
        aoi=aoi,
        start_date=start_date,
        end_date=end_date,
        target_profile=dw_profile,
    )

    print(f" Computed {len(feat_names)} features")
    print(f" Feature array shape: {feat_arr.shape}")

    # 4) Prepare model input: (H,W,C) -> (N,C)
    H, W, C = feat_arr.shape
    X = feat_arr.reshape(-1, C)

    # Ensure feature order matches training
    if selected_features is not None:
        name_to_idx = {n: i for i, n in enumerate(feat_names)}
        keep_idx = [name_to_idx[n] for n in selected_features if n in name_to_idx]
        missing = [n for n in selected_features if n not in name_to_idx]

        if not keep_idx:
            print("⚠️ No matching features found, using all computed features")
        else:
            if missing:
                # Surface partial matches: the model will otherwise receive
                # fewer columns than it was trained on with no explanation.
                print(f"⚠️ {len(missing)} selected features missing from computed stack: {missing}")
            print(f" Using {len(keep_idx)} selected features")
            X = X[:, keep_idx]
    else:
        print(" Using all computed features (no selection)")

    # Apply scaler if needed
    if use_scaler:
        print(" Applying StandardScaler")
        X = scaler.transform(X)

    # Handle NaNs (common with clouds/no-data)
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

    # 5) Predict
    print("🔮 Running prediction...")
    y_pred = model.predict(X).astype(np.int32)

    # Back to string labels (refined classes). Track which class list the
    # labels were drawn from so the index raster and legend stay consistent.
    try:
        refined_labels = le.inverse_transform(y_pred)
        classes = le.classes_.tolist() if hasattr(le, 'classes_') else DEFAULT_CLASSES
    except Exception as e:
        print(f"⚠️ Label inverse_transform failed: {e}")
        # Fallback: use default classes
        refined_labels = np.array([DEFAULT_CLASSES[i % len(DEFAULT_CLASSES)] for i in y_pred])
        classes = list(DEFAULT_CLASSES)

    refined_labels = refined_labels.reshape(H, W)

    # 6) Neighborhood smoothing (majority filter)
    smoothing_kernel = job.get("smoothing_kernel", cfg.smoothing_kernel)
    if cfg.smoothing_enabled:
        # Payload values may arrive as strings; coerce before comparing.
        smoothing_kernel = int(smoothing_kernel)
        if smoothing_kernel > 1:
            print(f"🧼 Applying majority filter (k={smoothing_kernel})")
            refined_labels = majority_filter(refined_labels, k=smoothing_kernel)

    # 7) Write outputs (GeoTIFF only; COG recommended for tiling)
    ts = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
    out_name = f"refined_{season}_{year}_{job_id}_{ts}.tif"
    baseline_name = f"dw_{season}_{year}_{job_id}_{ts}.tif"

    class_to_idx = {c: i for i, c in enumerate(classes)}

    with tempfile.TemporaryDirectory() as tmp:
        refined_path = Path(tmp) / out_name
        dw_path = Path(tmp) / baseline_name

        # DW baseline
        with rasterio.open(dw_path, "w", **dw_profile) as dst:
            dst.write(dw_arr, 1)

        # Refined - stored as a uint16 index raster; the index->class legend
        # is returned in meta.
        if refined_labels.dtype.kind in ['U', 'O', 'S']:
            # String labels - map each class name to its index
            idx_raster = np.zeros((H, W), dtype=np.uint16)
            for i, cls in enumerate(classes):
                idx_raster[refined_labels == cls] = i
        else:
            # Numeric labels already
            idx_raster = refined_labels.astype(np.uint16)

        refined_profile = dw_profile.copy()
        refined_profile.update({"dtype": "uint16", "count": 1})

        with rasterio.open(refined_path, "w", **refined_profile) as dst:
            dst.write(idx_raster, 1)

        # Upload
        refined_uri = cfg.storage.upload_result(local_path=refined_path, key=f"results/{out_name}")
        dw_uri = cfg.storage.upload_result(local_path=dw_path, key=f"results/{baseline_name}")

        # Optionally upload aux layers (true color, NDVI/EVI/SAVI)
        aux_uris = {}
        for layer_name, layer in aux_layers.items():
            # layer: (H,W) or (H,W,3)
            aux_path = Path(tmp) / f"{layer_name}_{season}_{year}_{job_id}_{ts}.tif"

            # Three-channel layers are written as 3 bands, everything else as 1.
            count = 3 if (layer.ndim == 3 and layer.shape[2] == 3) else 1

            aux_profile = dw_profile.copy()
            aux_profile.update({"count": count, "dtype": str(layer.dtype)})

            with rasterio.open(aux_path, "w", **aux_profile) as dst:
                if count == 1:
                    dst.write(layer, 1)
                else:
                    # rasterio expects band-first (C,H,W)
                    dst.write(layer.transpose(2, 0, 1), [1, 2, 3])

            aux_uris[layer_name] = cfg.storage.upload_result(
                local_path=aux_path, key=f"results/{aux_path.name}"
            )

    meta = {
        "job_id": job_id,
        "year": year,
        "season": season,
        "start_date": start_date,
        "end_date": end_date,
        "model": model_name,
        "scaler_used": use_scaler,
        "classes": classes,
        "class_index": class_to_idx,
        "features_computed": feat_names,
        "n_features": len(feat_names),
        "smoothing": {"enabled": cfg.smoothing_enabled, "kernel": smoothing_kernel},
    }

    outputs = {
        "refined_geotiff": refined_uri,
        "dw_baseline_geotiff": dw_uri,
        **aux_uris,
    }

    return InferenceResult(job_id=job_id, status="done", outputs=outputs, meta=meta)
|
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# Self-Test
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Smoke test for the inference helpers. Checks that genuinely fail are
    # collected so the process exits non-zero; an unreachable MinIO in test 4
    # is expected outside the cluster and is not counted as a failure.
    print("=== Inference Module Self-Test ===")

    failures = []

    # Check for required dependencies
    missing_deps = []
    for mod in ['joblib', 'sklearn']:
        try:
            __import__(mod)
        except ImportError:
            missing_deps.append(mod)

    if missing_deps:
        print(f"\n⚠️ Missing dependencies: {missing_deps}")
        print(" These will be available in the container environment.")
        print(" Running syntax validation only...")

    # Test 1: predict_raster with dummy data (only if sklearn available)
    print("\n1. Testing predict_raster with dummy feature cube...")

    # Seeded generator so every run sees the same dummy data.
    rng = np.random.default_rng(42)

    # Create dummy feature cube (10, 10, 51)
    H, W, C = 10, 10, 51
    dummy_cube = rng.random((H, W, C), dtype=np.float32)

    # Create dummy feature order
    from feature_computation import FEATURE_ORDER_V1
    feature_order = FEATURE_ORDER_V1

    print(f" Feature cube shape: {dummy_cube.shape}")
    print(f" Feature order length: {len(feature_order)}")

    if 'sklearn' not in missing_deps:
        # Create a dummy model for testing
        from sklearn.ensemble import RandomForestClassifier

        # Train a small model on random data
        X_train = rng.random((100, C))
        y_train = rng.integers(0, 8, 100)
        dummy_model = RandomForestClassifier(n_estimators=10, random_state=42)
        dummy_model.fit(X_train, y_train)

        # Verify model compatibility check
        print(f" Model n_features_in_: {dummy_model.n_features_in_}")

        # Run prediction
        try:
            result = predict_raster(dummy_model, dummy_cube, feature_order)
            print(f" Prediction result shape: {result.shape}")
            print(f" Expected shape: ({H}, {W})")

            if result.shape == (H, W):
                print(" ✓ predict_raster test PASSED")
            else:
                print(" ✗ predict_raster test FAILED - wrong shape")
                failures.append("predict_raster shape")
        except Exception as e:
            print(f" ✗ predict_raster test FAILED: {e}")
            failures.append("predict_raster")

        # Test 2: predict_raster with nodata handling
        print("\n2. Testing nodata handling...")

        # Create cube with nodata (all zeros)
        nodata_cube = np.zeros((5, 5, C), dtype=np.float32)
        nodata_cube[2, 2, :] = 1.0  # One valid pixel

        result_nodata = predict_raster(dummy_model, nodata_cube, feature_order)
        print(f" Nodata pixel value at [2,2]: {result_nodata[2, 2]}")
        print(f" Nodata pixels (should be 0): {result_nodata[0, 0]}")

        if result_nodata[0, 0] == 0 and result_nodata[0, 1] == 0:
            print(" ✓ Nodata handling test PASSED")
        else:
            print(" ✗ Nodata handling test FAILED")
            failures.append("nodata handling")

        # Test 3: Feature mismatch detection
        print("\n3. Testing feature mismatch detection...")

        wrong_cube = rng.random((5, 5, 50), dtype=np.float32)  # 50 features, not 51

        try:
            predict_raster(dummy_model, wrong_cube, feature_order)
            print(" ✗ Feature mismatch test FAILED - should have raised ValueError")
            failures.append("feature mismatch not detected")
        except ValueError as e:
            if "Feature dimension mismatch" in str(e):
                print(" ✓ Feature mismatch test PASSED")
            else:
                print(f" ✗ Wrong error: {e}")
                failures.append("feature mismatch wrong error")
    else:
        print(" (sklearn not available - skipping)")

    # Test 4: Try loading model from MinIO (will fail without real storage)
    print("\n4. Testing load_model from MinIO...")
    try:
        from storage import MinIOStorage
        storage = MinIOStorage()

        # This will fail without real MinIO, but we can catch the error
        model = load_model(storage, "RandomForest")
        print(" Model loaded successfully")
        print(" ✓ load_model test PASSED")
    except Exception as e:
        print(f" (Expected) MinIO/storage not available: {e}")
        print(" ✓ load_model test handled gracefully")

    print("\n=== Inference Module Test Complete ===")

    # Reflect failed checks in the exit status so callers/CI can detect them.
    if failures:
        raise SystemExit(1)
|
|
||||||
|
|
||||||
|
|
@ -8,8 +8,7 @@ This module wires together all the step modules:
|
||||||
- stac_client.py (DEA STAC search)
|
- stac_client.py (DEA STAC search)
|
||||||
- feature_computation.py (51-feature extraction)
|
- feature_computation.py (51-feature extraction)
|
||||||
- dw_baseline.py (windowed DW baseline)
|
- dw_baseline.py (windowed DW baseline)
|
||||||
- inference.py (model loading + prediction)
|
- hybrid_inference.py (CNN + CatBoost ensemble inference)
|
||||||
- postprocess.py (majority filter smoothing)
|
|
||||||
- cog.py (COG export)
|
- cog.py (COG export)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -20,14 +19,18 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
# Redis/RQ for job queue
|
# Redis/RQ for job queue
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
from rq import Queue
|
from rq import Queue
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
# Redis Configuration
|
# Redis Configuration
|
||||||
# ==========================================
|
# ==========================================
|
||||||
|
|
@ -44,11 +47,9 @@ def _get_redis_conn():
|
||||||
redis_host = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local")
|
redis_host = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local")
|
||||||
redis_port_str = os.getenv("REDIS_PORT", "6379")
|
redis_port_str = os.getenv("REDIS_PORT", "6379")
|
||||||
|
|
||||||
# Handle case where REDIS_PORT might be a full URL
|
|
||||||
try:
|
try:
|
||||||
redis_port = int(redis_port_str)
|
redis_port = int(redis_port_str)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# If it's a URL, extract the port
|
|
||||||
if "://" in redis_port_str:
|
if "://" in redis_port_str:
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
parsed = urllib.parse.urlparse(redis_port_str)
|
parsed = urllib.parse.urlparse(redis_port_str)
|
||||||
|
|
@ -56,7 +57,6 @@ def _get_redis_conn():
|
||||||
else:
|
else:
|
||||||
redis_port = 6379
|
redis_port = 6379
|
||||||
|
|
||||||
# MUST NOT use decode_responses=True because RQ uses pickle (binary)
|
|
||||||
return Redis(host=redis_host, port=redis_port)
|
return Redis(host=redis_host, port=redis_port)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -81,17 +81,7 @@ def update_status(
|
||||||
outputs: Optional[Dict] = None,
|
outputs: Optional[Dict] = None,
|
||||||
error: Optional[Dict] = None,
|
error: Optional[Dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update job status in Redis.
|
"""Update job status in Redis."""
|
||||||
|
|
||||||
Args:
|
|
||||||
job_id: Job identifier
|
|
||||||
status: Overall status (queued, running, failed, done)
|
|
||||||
stage: Current pipeline stage
|
|
||||||
progress: Progress percentage (0-100)
|
|
||||||
message: Human-readable message
|
|
||||||
outputs: Output file URLs (when done)
|
|
||||||
error: Error details (on failure)
|
|
||||||
"""
|
|
||||||
key = f"job:{job_id}:status"
|
key = f"job:{job_id}:status"
|
||||||
|
|
||||||
status_data = {
|
status_data = {
|
||||||
|
|
@ -109,8 +99,7 @@ def update_status(
|
||||||
status_data["error"] = error
|
status_data["error"] = error
|
||||||
|
|
||||||
try:
|
try:
|
||||||
redis_conn.set(key, json.dumps(status_data), ex=86400) # 24h expiry
|
redis_conn.set(key, json.dumps(status_data), ex=86400)
|
||||||
# Also update the job metadata in RQ if possible
|
|
||||||
from rq import get_current_job
|
from rq import get_current_job
|
||||||
job = get_current_job()
|
job = get_current_job()
|
||||||
if job:
|
if job:
|
||||||
|
|
@ -122,39 +111,65 @@ def update_status(
|
||||||
print(f"Warning: Failed to update Redis status: {e}")
|
print(f"Warning: Failed to update Redis status: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def send_dw_baseline_if_ready(dw_future, storage, job_id, payload, update_func):
    """Publish the Dynamic World baseline once its background load is done.

    This is a non-blocking check: if the future is missing or still running,
    nothing happens. When the load has finished with a result, the raster is
    written to a temporary GeoTIFF, uploaded, a presigned URL is generated,
    and ``update_func`` is called with a ``dw_ready`` status carrying that URL.

    Args:
        dw_future: Future whose result is ``(dw_arr, dw_profile)`` or ``None``;
            may itself be ``None`` when no load was scheduled.
        storage: Object providing ``upload_result(path, key)`` and
            ``presign_get(bucket, key)``.
        job_id: Job identifier, used in the object key and log lines.
        payload: Job payload; ``year`` and ``season`` are read for the key.
        update_func: Status callback, called as
            ``update_func(job_id, status, stage, progress, message, outputs=...)``.

    Returns:
        The presigned URL string on success, otherwise ``None`` (not ready,
        no baseline available, or any failure while publishing).
    """
    # Nothing scheduled, or the load is still in flight: nothing to publish.
    if dw_future is None or not dw_future.done():
        return None

    try:
        dw_result = dw_future.result()
        if dw_result is None:
            return None
        dw_arr, dw_profile = dw_result

        import rasterio

        dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"

        # A private temp directory avoids the mktemp() name race and
        # guarantees the file is removed after the upload.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dw_temp_path = Path(tmp_dir) / "dw_baseline.tif"
            with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
                if dw_arr.ndim == 2:
                    # A single-band (H, W) array needs an explicit band index.
                    dst.write(dw_arr, 1)
                else:
                    dst.write(dw_arr)

            # Upload to MinIO
            storage.upload_result(dw_temp_path, dw_key)

        # Generate presigned URL
        # NOTE(review): bucket name is hard-coded here; confirm it is the
        # bucket upload_result() actually writes to.
        dw_url = storage.presign_get("geocrop-baselines", dw_key)
        print(f"[{job_id}] DW baseline URL ready: {dw_url[:80]}...")

        # Notify client
        update_func(
            job_id, "running", "dw_ready", 30,
            "Dynamic World baseline ready",
            outputs={"dw_baseline_url": dw_url},
        )
        return dw_url
    except Exception as e:
        # Best effort: a baseline failure must not abort the inference job.
        print(f"[{job_id}] DW baseline processing failed: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
# Payload Validation
|
# Payload Validation
|
||||||
# ==========================================
|
# ==========================================
|
||||||
|
|
||||||
def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]:
|
def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]:
|
||||||
"""Parse and validate job payload.
|
"""Parse and validate job payload."""
|
||||||
|
|
||||||
Args:
|
|
||||||
payload: Raw job payload dict
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (validated_payload, list_of_errors)
|
|
||||||
"""
|
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
# Required fields
|
|
||||||
required = ["job_id", "lat", "lon", "radius_m", "year"]
|
required = ["job_id", "lat", "lon", "radius_m", "year"]
|
||||||
for field in required:
|
for field in required:
|
||||||
if field not in payload:
|
if field not in payload:
|
||||||
errors.append(f"Missing required field: {field}")
|
errors.append(f"Missing required field: {field}")
|
||||||
|
|
||||||
# Validate AOI
|
|
||||||
if "lat" in payload and "lon" in payload:
|
if "lat" in payload and "lon" in payload:
|
||||||
lat = float(payload["lat"])
|
lat = float(payload["lat"])
|
||||||
lon = float(payload["lon"])
|
lon = float(payload["lon"])
|
||||||
|
|
||||||
# Zimbabwe bounds check
|
|
||||||
if not (-22.5 <= lat <= -15.6):
|
if not (-22.5 <= lat <= -15.6):
|
||||||
errors.append(f"Latitude {lat} outside Zimbabwe bounds")
|
errors.append(f"Latitude {lat} outside Zimbabwe bounds")
|
||||||
if not (25.2 <= lon <= 33.1):
|
if not (25.2 <= lon <= 33.1):
|
||||||
errors.append(f"Longitude {lon} outside Zimbabwe bounds")
|
errors.append(f"Longitude {lon} outside Zimbabwe bounds")
|
||||||
|
|
||||||
# Validate radius
|
|
||||||
if "radius_m" in payload:
|
if "radius_m" in payload:
|
||||||
radius = int(payload["radius_m"])
|
radius = int(payload["radius_m"])
|
||||||
if radius > 5000:
|
if radius > 5000:
|
||||||
|
|
@ -162,26 +177,22 @@ def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]:
|
||||||
if radius < 100:
|
if radius < 100:
|
||||||
errors.append(f"Radius {radius}m below min 100m")
|
errors.append(f"Radius {radius}m below min 100m")
|
||||||
|
|
||||||
# Validate year
|
|
||||||
if "year" in payload:
|
if "year" in payload:
|
||||||
year = int(payload["year"])
|
year = int(payload["year"])
|
||||||
current_year = datetime.now().year
|
current_year = datetime.now().year
|
||||||
if year < 2015 or year > current_year:
|
if year < 2015 or year > current_year:
|
||||||
errors.append(f"Year {year} outside valid range (2015-{current_year})")
|
errors.append(f"Year {year} outside valid range (2015-{current_year})")
|
||||||
|
|
||||||
# Validate model
|
|
||||||
if "model" in payload:
|
if "model" in payload:
|
||||||
from contracts import VALID_MODELS
|
from contracts import VALID_MODELS
|
||||||
if payload["model"] not in VALID_MODELS:
|
if payload["model"] not in VALID_MODELS:
|
||||||
errors.append(f"Invalid model: {payload['model']}. Must be one of {VALID_MODELS}")
|
errors.append(f"Invalid model: {payload['model']}. Must be one of {VALID_MODELS}")
|
||||||
|
|
||||||
# Validate kernel
|
|
||||||
if "smoothing_kernel" in payload:
|
if "smoothing_kernel" in payload:
|
||||||
kernel = int(payload["smoothing_kernel"])
|
kernel = int(payload["smoothing_kernel"])
|
||||||
if kernel not in [3, 5, 7]:
|
if kernel not in [3, 5, 7]:
|
||||||
errors.append(f"Invalid smoothing_kernel: {kernel}. Must be 3, 5, or 7")
|
errors.append(f"Invalid smoothing_kernel: {kernel}. Must be 3, 5, or 7")
|
||||||
|
|
||||||
# Set defaults
|
|
||||||
validated = {
|
validated = {
|
||||||
"job_id": payload.get("job_id", "unknown"),
|
"job_id": payload.get("job_id", "unknown"),
|
||||||
"lat": float(payload.get("lat", 0)),
|
"lat": float(payload.get("lat", 0)),
|
||||||
|
|
@ -203,30 +214,47 @@ def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]:
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
# Main Job Runner
|
# Async DW Loading Helper
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
def _load_dw_async(storage, bbox, year, season) -> Optional[Tuple[np.ndarray, dict]]:
    """Load the DW baseline window, returning ``None`` instead of raising.

    Intended to be submitted to an executor so the baseline loads in the
    background; the function itself is an ordinary blocking call.

    Args:
        storage: Storage adapter forwarded to ``load_dw_baseline_window``.
        bbox: AOI bounding box, forwarded as ``aoi_bbox_wgs84``.
        year: Season year.
        season: Season name.

    Returns:
        ``(dw_arr, dw_profile)`` on success, or ``None`` if the loader could
        not be imported or raised.
    """
    try:
        # Imported inside the try so a missing/broken module is reported the
        # same way as any other load failure rather than escaping the wrapper.
        from dw_baseline import load_dw_baseline_window

        dw_arr, dw_profile = load_dw_baseline_window(
            storage=storage,
            aoi_bbox_wgs84=bbox,
            year=year,
            season=season,
        )
        print(f"[_dw_load] DW baseline loaded: shape={dw_arr.shape}")
        return dw_arr, dw_profile
    except Exception as e:
        print(f"[_dw_load] DW baseline failed: {e}")
        # Keep the stack trace; the message alone rarely identifies the cause.
        traceback.print_exc()
        return None
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# Main Job Runner (Async)
|
||||||
# ==========================================
|
# ==========================================
|
||||||
|
|
||||||
def run_job(payload_dict: dict) -> dict:
|
def run_job(payload_dict: dict) -> dict:
|
||||||
"""Main job runner function.
|
"""Main job runner with async DW baseline loading.
|
||||||
|
|
||||||
This is the RQ task function that orchestrates the full pipeline.
|
DW baseline loads in background while hybrid inference runs.
|
||||||
|
DW URL is sent to client as soon as it's ready, parallel to inference.
|
||||||
"""
|
"""
|
||||||
from rq import get_current_job
|
from rq import get_current_job
|
||||||
current_job = get_current_job()
|
current_job = get_current_job()
|
||||||
|
|
||||||
# Extract job_id from payload or RQ
|
|
||||||
job_id = payload_dict.get("job_id")
|
job_id = payload_dict.get("job_id")
|
||||||
if not job_id and current_job:
|
if not job_id and current_job:
|
||||||
job_id = current_job.id
|
job_id = current_job.id
|
||||||
if not job_id:
|
if not job_id:
|
||||||
job_id = "unknown"
|
job_id = "unknown"
|
||||||
|
|
||||||
# Ensure job_id is in payload for validation
|
|
||||||
payload_dict["job_id"] = job_id
|
payload_dict["job_id"] = job_id
|
||||||
|
|
||||||
# Standardize payload from API format to worker format
|
|
||||||
# API sends: radius_km, model_name
|
|
||||||
# Worker expects: radius_m, model
|
|
||||||
if "radius_km" in payload_dict and "radius_m" not in payload_dict:
|
if "radius_km" in payload_dict and "radius_m" not in payload_dict:
|
||||||
payload_dict["radius_m"] = int(float(payload_dict["radius_km"]) * 1000)
|
payload_dict["radius_m"] = int(float(payload_dict["radius_km"]) * 1000)
|
||||||
|
|
||||||
|
|
@ -245,7 +273,6 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
)
|
)
|
||||||
return {"status": "failed", "error": str(e)}
|
return {"status": "failed", "error": str(e)}
|
||||||
|
|
||||||
# Parse and validate payload
|
|
||||||
payload, errors = parse_and_validate_payload(payload_dict)
|
payload, errors = parse_and_validate_payload(payload_dict)
|
||||||
if errors:
|
if errors:
|
||||||
update_status(
|
update_status(
|
||||||
|
|
@ -255,174 +282,97 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
)
|
)
|
||||||
return {"status": "failed", "errors": errors}
|
return {"status": "failed", "errors": errors}
|
||||||
|
|
||||||
# Update initial status
|
update_status(job_id, "running", "init", 5, "Starting inference pipeline...")
|
||||||
update_status(job_id, "running", "fetch_stac", 5, "Fetching STAC items...")
|
|
||||||
|
|
||||||
missing_outputs = []
|
dw_baseline_url = None
|
||||||
output_urls = {}
|
output_urls = {}
|
||||||
|
missing_outputs = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# ==========================================
|
# Get config and AOI bbox
|
||||||
# Stage 1: Fetch STAC
|
|
||||||
# ==========================================
|
|
||||||
print(f"[{job_id}] Fetching STAC items for {payload['year']} {payload['season']}...")
|
|
||||||
|
|
||||||
from stac_client import DEASTACClient
|
|
||||||
from config import InferenceConfig, MinIOStorage as ConfigMinIO
|
from config import InferenceConfig, MinIOStorage as ConfigMinIO
|
||||||
|
|
||||||
cfg = InferenceConfig()
|
cfg = InferenceConfig()
|
||||||
# Initialize storage adapter for inference.py
|
|
||||||
cfg.storage = ConfigMinIO()
|
cfg.storage = ConfigMinIO()
|
||||||
|
|
||||||
# Get season dates
|
|
||||||
start_date, end_date = cfg.season_dates(payload['year'], payload['season'])
|
start_date, end_date = cfg.season_dates(payload['year'], payload['season'])
|
||||||
|
|
||||||
# Calculate AOI bbox
|
|
||||||
lat, lon, radius = payload['lat'], payload['lon'], payload['radius_m']
|
lat, lon, radius = payload['lat'], payload['lon'], payload['radius_m']
|
||||||
|
radius_deg = radius / 111000
|
||||||
|
bbox = [lon - radius_deg, lat - radius_deg, lon + radius_deg, lat + radius_deg]
|
||||||
|
|
||||||
# Rough bbox from radius (in degrees)
|
# ==========================================
|
||||||
radius_deg = radius / 111000 # ~111km per degree
|
# Start DW baseline loading in background
|
||||||
bbox = [
|
# ==========================================
|
||||||
lon - radius_deg, # min_lon
|
update_status(job_id, "running", "load_dw", 10, "Loading Dynamic World baseline (async)...")
|
||||||
lat - radius_deg, # min_lat
|
print(f"[{job_id}] Starting async DW baseline load...")
|
||||||
lon + radius_deg, # max_lon
|
|
||||||
lat + radius_deg, # max_lat
|
|
||||||
]
|
|
||||||
|
|
||||||
# Search STAC
|
with ThreadPoolExecutor(max_workers=1) as dw_executor:
|
||||||
stac_client = DEASTACClient()
|
dw_future = dw_executor.submit(
|
||||||
|
_load_dw_async,
|
||||||
try:
|
storage, bbox, payload['year'], payload['season']
|
||||||
items = stac_client.search_items(
|
|
||||||
bbox=bbox,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
)
|
)
|
||||||
print(f"[{job_id}] Found {len(items)} STAC items")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[{job_id}] STAC search failed: {e}")
|
|
||||||
# Continue but note that features may be limited
|
|
||||||
|
|
||||||
update_status(job_id, "running", "build_features", 20, "Building feature cube...")
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# Stage 2: Build Feature Cube
|
|
||||||
# ==========================================
|
|
||||||
print(f"[{job_id}] Building feature cube...")
|
|
||||||
|
|
||||||
from feature_computation import FEATURE_ORDER_V1
|
|
||||||
|
|
||||||
feature_order = FEATURE_ORDER_V1
|
|
||||||
expected_features = len(feature_order) # Should be 51
|
|
||||||
|
|
||||||
print(f"[{job_id}] Expected {expected_features} features (FEATURE_ORDER_V1)")
|
|
||||||
|
|
||||||
# Check if we have an existing feature builder in features.py
|
|
||||||
feature_cube = None
|
|
||||||
use_synthetic = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from features import build_feature_stack_from_dea
|
|
||||||
print(f"[{job_id}] Trying build_feature_stack_from_dea for feature extraction...")
|
|
||||||
|
|
||||||
# Try to call it - this requires stackstac and DEA STAC access
|
# ==========================================
|
||||||
try:
|
# Start hybrid inference immediately (in parallel)
|
||||||
feature_cube = build_feature_stack_from_dea(
|
# ==========================================
|
||||||
items=items,
|
update_status(job_id, "running", "load_model", 20, "Loading model artifacts...")
|
||||||
bbox=bbox,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
)
|
|
||||||
print(f"[{job_id}] Feature cube built successfully: {feature_cube.shape if feature_cube is not None else 'None'}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[{job_id}] Feature stack building failed: {e}")
|
|
||||||
print(f"[{job_id}] Falling back to synthetic features for testing")
|
|
||||||
use_synthetic = True
|
|
||||||
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"[{job_id}] Feature builder not available: {e}")
|
|
||||||
print(f"[{job_id}] Using synthetic features for testing")
|
|
||||||
use_synthetic = True
|
|
||||||
|
|
||||||
# Generate synthetic features for testing when real data isn't available
|
|
||||||
if feature_cube is None:
|
|
||||||
print(f"[{job_id}] Generating synthetic features for pipeline test...")
|
|
||||||
|
|
||||||
# Determine raster dimensions from DW baseline if loaded
|
model_dir = Path(tempfile.mkdtemp())
|
||||||
if 'dw_arr' in dir() and dw_arr is not None:
|
print(f"[{job_id}] Downloading model artifacts...")
|
||||||
H, W = dw_arr.shape
|
|
||||||
else:
|
|
||||||
# Default size for testing
|
|
||||||
H, W = 100, 100
|
|
||||||
|
|
||||||
# Generate synthetic features: shape (H, W, 51)
|
# Download model artifacts
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# Use year as seed for reproducible but varied features
|
|
||||||
np.random.seed(payload['year'] + int(payload.get('lon', 0) * 100) + int(payload.get('lat', 0) * 100))
|
|
||||||
|
|
||||||
# Generate realistic-looking features (normalized values)
|
|
||||||
feature_cube = np.random.rand(H, W, expected_features).astype(np.float32)
|
|
||||||
|
|
||||||
# Add some structure - make center pixels different from edges
|
|
||||||
y, x = np.ogrid[:H, :W]
|
|
||||||
center_y, center_x = H // 2, W // 2
|
|
||||||
dist = np.sqrt((y - center_y)**2 + (x - center_x)**2)
|
|
||||||
max_dist = np.sqrt(center_y**2 + center_x**2)
|
|
||||||
|
|
||||||
# Add a gradient based on distance from center (simulating field pattern)
|
|
||||||
for i in range(min(10, expected_features)):
|
|
||||||
feature_cube[:, :, i] = (1 - dist / max_dist) * 0.5 + feature_cube[:, :, i] * 0.5
|
|
||||||
|
|
||||||
print(f"[{job_id}] Synthetic feature cube shape: {feature_cube.shape}")
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# Stage 3: Load Model Artifacts
|
|
||||||
# ==========================================
|
|
||||||
update_status(job_id, "running", "load_model", 40, "Loading model artifacts...")
|
|
||||||
|
|
||||||
is_hybrid = "hybrid" in payload['model'].lower() or "spatiotemporal" in payload['model'].lower()
|
|
||||||
|
|
||||||
model_dir = Path(tempfile.mkdtemp())
|
|
||||||
|
|
||||||
if is_hybrid:
|
|
||||||
print(f"[{job_id}] Model type: Hybrid Spatio-Temporal. Downloading artifacts...")
|
|
||||||
# Expected files in MinIO: pipeline_meta.pkl, Temporal_FCN.pth, calibrated_hybrid_cb.pkl
|
|
||||||
for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]:
|
for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]:
|
||||||
try:
|
try:
|
||||||
storage.download_file(storage.bucket_models, artifact, model_dir / artifact)
|
storage.download_file(storage.bucket_models, artifact, model_dir / artifact)
|
||||||
print(f"[{job_id}] Downloaded {artifact}")
|
print(f"[{job_id}] Downloaded {artifact}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{job_id}] Failed to download {artifact}: {e}")
|
|
||||||
# Try with 'hybrid/' prefix if direct fails
|
|
||||||
try:
|
try:
|
||||||
storage.download_file(storage.bucket_models, f"hybrid/{artifact}", model_dir / artifact)
|
storage.download_file(storage.bucket_models, f"hybrid/{artifact}", model_dir / artifact)
|
||||||
print(f"[{job_id}] Downloaded {artifact} (from hybrid/ prefix)")
|
print(f"[{job_id}] Downloaded {artifact} (from hybrid/ prefix)")
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
raise FileNotFoundError(f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}")
|
raise FileNotFoundError(
|
||||||
|
f"Required artifact {artifact} not found in {storage.bucket_models}: {e2}"
|
||||||
|
)
|
||||||
|
|
||||||
|
update_status(job_id, "running", "fetch_stac", 30, "Fetching spatio-temporal data...")
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# Stage 4: Fetch Spatio-Temporal Data
|
|
||||||
# ==========================================
|
|
||||||
update_status(job_id, "running", "fetch_stac", 50, "Fetching spatio-temporal indices...")
|
|
||||||
from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline
|
from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline
|
||||||
|
|
||||||
stac_wrapper = DEAfricaSTACWrapper()
|
stac_wrapper = DEAfricaSTACWrapper()
|
||||||
# Calculate ranges for wrapper
|
|
||||||
lat_range = (bbox[1], bbox[3])
|
lat_range = (bbox[1], bbox[3])
|
||||||
lon_range = (bbox[0], bbox[2])
|
lon_range = (bbox[0], bbox[2])
|
||||||
time_range = (start_date, end_date)
|
time_range = (start_date, end_date)
|
||||||
|
|
||||||
|
print(f"[{job_id}] Fetching STAC data from DEA...")
|
||||||
unseen_pixel_df = stac_wrapper.fetch_and_format_data(
|
unseen_pixel_df = stac_wrapper.fetch_and_format_data(
|
||||||
lat_range=lat_range,
|
lat_range=lat_range,
|
||||||
lon_range=lon_range,
|
lon_range=lon_range,
|
||||||
time_range=time_range
|
time_range=time_range
|
||||||
)
|
)
|
||||||
|
print(f"[{job_id}] STAC data fetched: {len(unseen_pixel_df)} pixels")
|
||||||
|
|
||||||
|
# Check if DW is ready while processing STAC
|
||||||
|
if dw_future.done():
|
||||||
|
dw_result = dw_future.result()
|
||||||
|
if dw_result is not None:
|
||||||
|
dw_arr, dw_profile = dw_result
|
||||||
|
import rasterio
|
||||||
|
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
|
||||||
|
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
|
||||||
|
dst.write(dw_arr)
|
||||||
|
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
|
||||||
|
storage.upload_result(dw_temp_path, dw_key)
|
||||||
|
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
|
||||||
|
update_status(
|
||||||
|
job_id, "running", "dw_ready", 35,
|
||||||
|
"Dynamic World baseline ready",
|
||||||
|
outputs={"dw_baseline_url": dw_baseline_url},
|
||||||
|
)
|
||||||
|
|
||||||
|
update_status(job_id, "running", "infer", 50, "Running Hybrid Inference (CNN + CatBoost)...")
|
||||||
|
print(f"[{job_id}] Running hybrid inference...")
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# Stage 5: Hybrid Inference
|
|
||||||
# ==========================================
|
|
||||||
update_status(job_id, "running", "infer", 70, "Running Hybrid Inference (CNN + CatBoost)...")
|
|
||||||
pipeline = CropInferencePipeline(model_dir=str(model_dir))
|
pipeline = CropInferencePipeline(model_dir=str(model_dir))
|
||||||
|
|
||||||
mapped_crops_df = pipeline.predict(
|
mapped_crops_df = pipeline.predict(
|
||||||
|
|
@ -430,17 +380,19 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
apply_spatial_smoothing=True,
|
apply_spatial_smoothing=True,
|
||||||
coord_cols=['lat', 'lon']
|
coord_cols=['lat', 'lon']
|
||||||
)
|
)
|
||||||
|
print(f"[{job_id}] Inference complete, exporting results...")
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
# Stage 6: Export and Upload
|
# Export and Upload Results
|
||||||
# ==========================================
|
# ==========================================
|
||||||
update_status(job_id, "running", "export_cog", 90, "Exporting results...")
|
update_status(job_id, "running", "export_cog", 80, "Exporting results...")
|
||||||
|
|
||||||
output_dir = Path(tempfile.mkdtemp())
|
output_dir = Path(tempfile.mkdtemp())
|
||||||
output_path = output_dir / "refined.tif"
|
output_path = output_dir / "refined.tif"
|
||||||
|
|
||||||
pipeline.export_to_geotiff(mapped_crops_df, output_path=str(output_path))
|
pipeline.export_to_geotiff(mapped_crops_df, output_path=str(output_path))
|
||||||
|
|
||||||
output_urls = {}
|
# Upload results
|
||||||
for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]:
|
for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]:
|
||||||
local_f = output_dir / filename
|
local_f = output_dir / filename
|
||||||
if local_f.exists():
|
if local_f.exists():
|
||||||
|
|
@ -448,39 +400,47 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
storage.upload_result(local_f, result_key)
|
storage.upload_result(local_f, result_key)
|
||||||
output_urls[filename.replace(".","_url")] = storage.presign_get("geocrop-results", result_key)
|
output_urls[filename.replace(".","_url")] = storage.presign_get("geocrop-results", result_key)
|
||||||
|
|
||||||
else:
|
# Check DW one more time (may have finished during inference)
|
||||||
# Fallback to Legacy/Standard logic
|
if dw_baseline_url is None and dw_future.done():
|
||||||
print(f"[{job_id}] Using standard/ensemble inference logic...")
|
dw_result = dw_future.result()
|
||||||
from inference import run_inference_job
|
if dw_result is not None:
|
||||||
|
dw_arr, dw_profile = dw_result
|
||||||
|
import rasterio
|
||||||
|
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
|
||||||
|
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
|
||||||
|
dst.write(dw_arr)
|
||||||
|
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
|
||||||
|
storage.upload_result(dw_temp_path, dw_key)
|
||||||
|
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
|
||||||
|
|
||||||
# Create a mock job dict compatible with run_inference_job
|
# Wait for DW if still running
|
||||||
job_payload = {
|
if dw_baseline_url is None:
|
||||||
"job_id": job_id,
|
print(f"[{job_id}] Waiting for DW baseline to finish...")
|
||||||
"lat": payload["lat"],
|
dw_result = dw_future.result(timeout=60)
|
||||||
"lon": payload["lon"],
|
if dw_result is not None:
|
||||||
"radius_m": payload["radius_m"],
|
dw_arr, dw_profile = dw_result
|
||||||
"year": payload["year"],
|
import rasterio
|
||||||
"season": payload["season"],
|
dw_temp_path = Path(tempfile.mktemp(suffix=".tif"))
|
||||||
"model": payload["model"],
|
with rasterio.open(dw_temp_path, 'w', **dw_profile) as dst:
|
||||||
"smoothing_kernel": payload["smoothing_kernel"]
|
dst.write(dw_arr)
|
||||||
}
|
dw_key = f"baselines/{job_id}/dw_baseline_{payload['year']}_{payload['season']}.tif"
|
||||||
|
storage.upload_result(dw_temp_path, dw_key)
|
||||||
inference_result = run_inference_job(cfg, job_payload)
|
dw_baseline_url = storage.presign_get("geocrop-baselines", dw_key)
|
||||||
output_urls = inference_result.outputs
|
|
||||||
|
# ==========================================
|
||||||
# Note: indices and true_color not yet implemented
|
# Final Status
|
||||||
|
# ==========================================
|
||||||
|
final_outputs = dict(output_urls)
|
||||||
|
if dw_baseline_url:
|
||||||
|
final_outputs["dw_baseline_url"] = dw_baseline_url
|
||||||
|
|
||||||
if payload['outputs'].get('indices'):
|
if payload['outputs'].get('indices'):
|
||||||
missing_outputs.append("indices: not implemented")
|
missing_outputs.append("indices: not implemented")
|
||||||
if payload['outputs'].get('true_color'):
|
if payload['outputs'].get('true_color'):
|
||||||
missing_outputs.append("true_color: not implemented")
|
missing_outputs.append("true_color: not implemented")
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# Stage 7: Final Status
|
|
||||||
# ==========================================
|
|
||||||
final_status = "partial" if missing_outputs else "done"
|
final_status = "partial" if missing_outputs else "done"
|
||||||
final_message = f"Inference complete"
|
final_message = f"Inference complete" + (f" ({', '.join(missing_outputs)})" if missing_outputs else "")
|
||||||
if missing_outputs:
|
|
||||||
final_message += f" (partial: {', '.join(missing_outputs)})"
|
|
||||||
|
|
||||||
update_status(
|
update_status(
|
||||||
job_id,
|
job_id,
|
||||||
|
|
@ -488,7 +448,7 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
"done",
|
"done",
|
||||||
100,
|
100,
|
||||||
final_message,
|
final_message,
|
||||||
outputs=output_urls,
|
outputs=final_outputs if final_outputs else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"[{job_id}] Job complete: {final_status}")
|
print(f"[{job_id}] Job complete: {final_status}")
|
||||||
|
|
@ -496,12 +456,11 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
return {
|
return {
|
||||||
"status": final_status,
|
"status": final_status,
|
||||||
"job_id": job_id,
|
"job_id": job_id,
|
||||||
"outputs": output_urls,
|
"outputs": final_outputs,
|
||||||
"missing": missing_outputs if missing_outputs else None,
|
"missing": missing_outputs if missing_outputs else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Catch-all for any unexpected errors
|
|
||||||
error_trace = traceback.format_exc()
|
error_trace = traceback.format_exc()
|
||||||
print(f"[{job_id}] Error: {e}")
|
print(f"[{job_id}] Error: {e}")
|
||||||
print(error_trace)
|
print(error_trace)
|
||||||
|
|
@ -518,9 +477,10 @@ def run_job(payload_dict: dict) -> dict:
|
||||||
"job_id": job_id,
|
"job_id": job_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Alias for API
|
|
||||||
run_inference = run_job
|
run_inference = run_job
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
# RQ Worker Entry Point
|
# RQ Worker Entry Point
|
||||||
# ==========================================
|
# ==========================================
|
||||||
|
|
@ -530,7 +490,6 @@ def start_rq_worker():
|
||||||
from rq import Worker
|
from rq import Worker
|
||||||
import signal
|
import signal
|
||||||
|
|
||||||
# Ensure /app is in sys.path so we can import modules
|
|
||||||
if '/app' not in sys.path:
|
if '/app' not in sys.path:
|
||||||
sys.path.insert(0, '/app')
|
sys.path.insert(0, '/app')
|
||||||
|
|
||||||
|
|
@ -539,9 +498,7 @@ def start_rq_worker():
|
||||||
print(f"=== GeoCrop RQ Worker Starting ===")
|
print(f"=== GeoCrop RQ Worker Starting ===")
|
||||||
print(f"Listening on queue: {queue_name}")
|
print(f"Listening on queue: {queue_name}")
|
||||||
print(f"Redis: {os.getenv('REDIS_HOST', 'redis.geocrop.svc.cluster.local')}:{os.getenv('REDIS_PORT', '6379')}")
|
print(f"Redis: {os.getenv('REDIS_HOST', 'redis.geocrop.svc.cluster.local')}:{os.getenv('REDIS_PORT', '6379')}")
|
||||||
print(f"Python path: {sys.path[:3]}")
|
|
||||||
|
|
||||||
# Handle graceful shutdown
|
|
||||||
def signal_handler(signum, frame):
|
def signal_handler(signum, frame):
|
||||||
print("\nReceived shutdown signal, exiting gracefully...")
|
print("\nReceived shutdown signal, exiting gracefully...")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
@ -569,22 +526,17 @@ if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.test or not args.worker:
|
if args.test or not args.worker:
|
||||||
# Syntax-level self-test
|
|
||||||
print("=== GeoCrop Worker Syntax Test ===")
|
print("=== GeoCrop Worker Syntax Test ===")
|
||||||
|
|
||||||
# Test imports
|
|
||||||
try:
|
try:
|
||||||
from contracts import STAGES, VALID_MODELS
|
from contracts import STAGES, VALID_MODELS
|
||||||
from storage import MinIOStorage
|
from storage import MinIOStorage
|
||||||
from feature_computation import FEATURE_ORDER_V1
|
|
||||||
print(f"✓ Imports OK")
|
print(f"✓ Imports OK")
|
||||||
print(f" STAGES: {STAGES}")
|
print(f" STAGES: {STAGES}")
|
||||||
print(f" VALID_MODELS: {VALID_MODELS}")
|
print(f" VALID_MODELS: {VALID_MODELS}")
|
||||||
print(f" FEATURE_ORDER length: {len(FEATURE_ORDER_V1)}")
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"⚠ Some imports missing (expected outside container): {e}")
|
print(f"⚠ Some imports missing: {e}")
|
||||||
|
|
||||||
# Test payload parsing
|
|
||||||
print("\n--- Payload Parsing Test ---")
|
print("\n--- Payload Parsing Test ---")
|
||||||
test_payload = {
|
test_payload = {
|
||||||
"job_id": "test-123",
|
"job_id": "test-123",
|
||||||
|
|
@ -592,7 +544,7 @@ if __name__ == "__main__":
|
||||||
"lon": 31.0,
|
"lon": 31.0,
|
||||||
"radius_m": 2000,
|
"radius_m": 2000,
|
||||||
"year": 2022,
|
"year": 2022,
|
||||||
"model": "Ensemble",
|
"model": "Hybrid_SpatioTemporal",
|
||||||
"smoothing_kernel": 5,
|
"smoothing_kernel": 5,
|
||||||
"outputs": {"refined": True, "dw_baseline": True},
|
"outputs": {"refined": True, "dw_baseline": True},
|
||||||
}
|
}
|
||||||
|
|
@ -605,18 +557,8 @@ if __name__ == "__main__":
|
||||||
print(f" job_id: {validated['job_id']}")
|
print(f" job_id: {validated['job_id']}")
|
||||||
print(f" AOI: ({validated['lat']}, {validated['lon']}) radius={validated['radius_m']}m")
|
print(f" AOI: ({validated['lat']}, {validated['lon']}) radius={validated['radius_m']}m")
|
||||||
print(f" model: {validated['model']}")
|
print(f" model: {validated['model']}")
|
||||||
print(f" kernel: {validated['smoothing_kernel']}")
|
|
||||||
|
|
||||||
# Show what would run
|
|
||||||
print("\n--- Pipeline Overview ---")
|
|
||||||
print("Pipeline stages:")
|
|
||||||
for i, stage in enumerate(STAGES):
|
|
||||||
print(f" {i+1}. {stage}")
|
|
||||||
|
|
||||||
print("\nNote: This is a syntax-level test.")
|
|
||||||
print("Full execution requires Redis, MinIO, and STAC access in the container.")
|
|
||||||
|
|
||||||
print("\n=== Worker Syntax Test Complete ===")
|
print("\n=== Worker Syntax Test Complete ===")
|
||||||
|
|
||||||
if args.worker:
|
if args.worker:
|
||||||
start_rq_worker()
|
start_rq_worker()
|
||||||
Loading…
Reference in New Issue