feat: implement Spatio-Temporal Deep Learning pipeline for inference worker
Build and Push Docker Images / build-and-push (push) Failing after 24s Details

- Add hybrid PyTorch (TemporalFCN) + CatBoost ensemble logic in new hybrid_inference.py
- Update worker.py to support 'Hybrid' model type with artifact syncing from MinIO
- Integrate odc-stac for raw spectral index fetching from DE Africa STAC
- Update requirements.txt with torch, odc-stac, and rioxarray dependencies
- Include ntfy deployment in k8s manifests
This commit is contained in:
fchinembiri 2026-05-01 01:17:28 +02:00
parent dba7d2bf99
commit 096ed9f76b
5 changed files with 445 additions and 97 deletions

View File

@ -0,0 +1,276 @@
import os
import io
import json
import time
import copy
import joblib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.neighbors import KNeighborsRegressor
from catboost import CatBoostClassifier
# Digital Earth Africa STAC specific imports
try:
from pystac_client import Client
import odc.stac
import xarray as xr
import rioxarray
except ImportError:
Client = None
odc = None
xr = None
rioxarray = None
# ==========================================
# 1. CPU-OPTIMIZED ARCHITECTURES
# ==========================================
class TemporalFCN(nn.Module):
    """1-D fully convolutional classifier over a per-sample band time series.

    Two Conv1d -> BatchNorm1d -> ReLU stages (64 then 128 channels) are
    followed by global average pooling over the last axis and a linear head.
    The spectral bands are the Conv1d input channels, so the input is laid out
    as ``(batch, num_bands, timesteps)``.
    """

    def __init__(self, num_bands, num_classes):
        super().__init__()
        # Attribute names and layer order define the state_dict keys that
        # saved checkpoints are loaded against; keep both stable.
        first_stage = [
            nn.Conv1d(num_bands, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        ]
        self.conv_block1 = nn.Sequential(*first_stage)
        second_stage = [
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        ]
        self.conv_block2 = nn.Sequential(*second_stage)
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x, return_features=False):
        """Return class logits, optionally with the pooled 128-d features.

        Args:
            x: Input tensor fed to the first Conv1d stage.
            return_features: When true, return ``(logits, features)``.
        """
        hidden = self.conv_block2(self.conv_block1(x))
        # Pooling leaves a trailing length-1 axis; drop it before the head.
        features = self.global_avg_pool(hidden).squeeze(-1)
        logits = self.fc(features)
        if not return_features:
            return logits
        return logits, features
class SmallGRU(nn.Module):
    """Single-layer GRU classifier over a per-sample band time series.

    The input uses the same ``(batch, num_bands, timesteps)`` layout as the
    convolutional model; it is transposed internally because the GRU is built
    with ``batch_first=True`` and ``input_size=num_bands``.
    """

    def __init__(self, num_bands, num_classes, hidden_size=64):
        super().__init__()
        self.gru = nn.GRU(
            input_size=num_bands,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, return_features=False):
        """Return class logits, optionally with the last-step hidden state.

        Args:
            x: Input tensor whose axes 1 and 2 are swapped before the GRU.
            return_features: When true, return ``(logits, features)``.
        """
        sequence = torch.transpose(x, 1, 2)
        states, _ = self.gru(sequence)
        # The GRU output at the final timestep is the sequence summary.
        features = states[:, -1, :]
        logits = self.fc(features)
        if not return_features:
            return logits
        return logits, features
# ==========================================
# 2. DATA PREPARATION & PYTORCH UTILS
# ==========================================
class CropDataset(Dataset):
    """Torch dataset pairing band/time samples with integer labels.

    With ``augment=True`` each fetched sample may receive additive Gaussian
    noise (scale 0.03) and may have one randomly chosen index along its
    second axis zeroed out.
    """

    def __init__(self, X, y, augment=False):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)
        self.augment = augment

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Work on a copy so augmentation never alters the stored tensor.
        sample = self.X[idx].clone()
        label = self.y[idx]
        if not self.augment:
            return sample, label

        # Random jitter, applied when the draw exceeds 0.5.
        if torch.rand(1).item() > 0.5:
            sample = sample + torch.randn_like(sample) * 0.03

        # Blank a single position on axis 1, applied when the draw exceeds 0.7.
        if torch.rand(1).item() > 0.7:
            dropped_step = torch.randint(0, sample.shape[1], (1,)).item()
            sample[:, dropped_step] = 0.0

        return sample, label
def prepare_tensors(df, bands, dates):
    """Build a normalised ``(samples, bands, dates)`` float32 cube from a wide table.

    Values are read from columns named ``"{date}_{band}"``. A combination with
    no matching column is left at zero. Each sample/band series is then
    standardised along the date axis (mean 0, std 1, with a 1e-8 floor on std).

    Args:
        df: DataFrame with one row per sample.
        bands: Band names, in the order used for axis 1.
        dates: Date labels, in the order used for axis 2.

    Returns:
        numpy.ndarray of shape ``(len(df), len(bands), len(dates))``.
    """
    cube = np.zeros((len(df), len(bands), len(dates)), dtype=np.float32)

    for band_pos, band in enumerate(bands):
        for date_pos, date in enumerate(dates):
            column = f"{date}_{band}"
            if column not in df.columns:
                # Absent timestep: the slot stays zero-filled.
                continue
            cube[:, band_pos, date_pos] = df[column].values

    # Per-sample, per-band z-score across time.
    centre = cube.mean(axis=2, keepdims=True)
    spread = cube.std(axis=2, keepdims=True) + 1e-8
    return (cube - centre) / spread
# ==========================================
# 3. DIGITAL EARTH AFRICA STAC INTEGRATION
# ==========================================
class DEAfricaSTACWrapper:
    """Fetch Sentinel-2 L2A data from a STAC catalog and derive spectral indices.

    The result is a wide, per-pixel table with one column per
    ``"{YYYYMMDD}_{index}"`` combination for NDVI, NDRE, EVI and SAVI.
    """

    def __init__(self, stac_url="https://explorer.digitalearth.africa/stac"):
        """Open the STAC catalog at ``stac_url``.

        Raises:
            ImportError: If pystac-client, odc-stac or xarray failed to import.
        """
        if Client is None or odc is None or xr is None:
            raise ImportError("Missing required libraries: pystac-client, odc-stac, xarray")
        print(f"Connecting to Digital Earth Africa STAC Catalog at {stac_url}...")
        self.catalog = Client.open(stac_url)

    def fetch_and_format_data(self, lat_range, lon_range, time_range, resolution=20):
        """Load imagery for a bounding box and return per-pixel index time series.

        Args:
            lat_range: ``(min_lat, max_lat)`` of the area of interest.
            lon_range: ``(min_lon, max_lon)`` of the area of interest.
            time_range: ``(start, end)`` joined as ``"start/end"`` for the search.
            resolution: Output pixel size passed to ``odc.stac.load``.

        Returns:
            DataFrame with ``lat``/``lon`` columns (the raster's ``y``/``x``
            coordinates, i.e. expressed in the EPSG:6933 load CRS, not degrees)
            and one ``"{YYYYMMDD}_{index}"`` column per date and index.
            Masked or non-finite values are NaN.

        Raises:
            ValueError: If the search returns no STAC items.
        """
        bbox = [lon_range[0], lat_range[0], lon_range[1], lat_range[1]]
        print(f"Searching STAC for Bounding Box: {bbox} over {time_range}...")
        search = self.catalog.search(
            collections=["s2_l2a"],
            bbox=bbox,
            datetime=f"{time_range[0]}/{time_range[1]}"
        )
        items = list(search.items())
        if not items:
            raise ValueError("No STAC items found for this bounding box and time range.")

        print(f"Found {len(items)} STAC items. Loading into xarray...")
        # NOTE(review): reading the underlying assets may need rasterio/S3
        # configuration (e.g. odc.stac.configure_rio with unsigned access) in
        # the worker environment -- confirm against the deployment.
        # Only the bands the indices need are requested; 'green' was loaded
        # before but never used.
        reflectance_bands = ['red', 'blue', 'nir', 'red_edge_1']
        ds = odc.stac.load(
            items,
            measurements=reflectance_bands + ['scl'],
            bbox=bbox,
            crs="EPSG:6933",
            resolution=resolution,
            groupby="solar_day"
        )

        print("Masking clouds and shadows...")
        # Scene-classification codes kept as valid observations.
        valid_mask = (ds.scl == 4) | (ds.scl == 5) | (ds.scl == 6) | (ds.scl == 2) | (ds.scl == 7)
        # Scale only the reflectance bands; the categorical SCL layer is
        # dropped here rather than being divided along with them.
        refl = ds[reflectance_bands].where(valid_mask) / 10000.0

        print("Computing Spectral Indices (NDVI, NDRE, SAVI, EVI)...")
        refl['ndvi'] = (refl.nir - refl.red) / (refl.nir + refl.red + 1e-8)
        refl['ndre'] = (refl.nir - refl.red_edge_1) / (refl.nir + refl.red_edge_1 + 1e-8)
        refl['savi'] = ((refl.nir - refl.red) / (refl.nir + refl.red + 0.5)) * 1.5
        refl['evi'] = 2.5 * ((refl.nir - refl.red) / (refl.nir + 6 * refl.red - 7.5 * refl.blue + 1))
        ds_indices = refl[['ndvi', 'ndre', 'evi', 'savi']]

        print("Reshaping multi-dimensional xarray into flat Tabular DataFrame...")
        df = ds_indices.compute().to_dataframe().reset_index()
        df['date_str'] = df['time'].dt.strftime('%Y%m%d')
        df_pivot = df.pivot(index=['y', 'x'], columns='date_str', values=['ndvi', 'ndre', 'evi', 'savi'])
        # Pivot columns are (index_name, date) pairs; flatten to "date_index".
        df_pivot.columns = [f"{date}_{band}" for band, date in df_pivot.columns]
        # The EVI denominator can hit zero, yielding +/-inf. Turn those into
        # NaN so they are handled as missing observations downstream instead
        # of poisoning interpolation and normalisation.
        df_pivot = df_pivot.replace([np.inf, -np.inf], np.nan)
        df_final = df_pivot.reset_index().rename(columns={'y': 'lat', 'x': 'lon'})
        print(f"✅ Data Ready! {df_final.shape[0]} spatial pixels generated.")
        return df_final
# ==========================================
# 4. INFERENCE PIPELINE
# ==========================================
class CropInferencePipeline:
    """Weighted ensemble of a TemporalFCN and a calibrated CatBoost model.

    Artifacts are read from ``model_dir``: ``pipeline_meta.pkl`` (label
    encoder, band/date lists, class count and ensemble weights),
    ``Temporal_FCN.pth`` (network state dict) and
    ``calibrated_hybrid_cb.pkl`` (classifier exposing ``predict_proba``).
    """

    # Pixels averaged by the spatial smoother (the pixel itself included).
    _SMOOTHING_NEIGHBORS = 9

    def __init__(self, model_dir="/tmp/geocrop-cache"):
        """Load metadata and both models from ``model_dir``.

        Raises:
            FileNotFoundError: If ``pipeline_meta.pkl`` is not in ``model_dir``.
        """
        print(f"Loading Crop Inference Pipeline from {model_dir}...")
        meta_path = os.path.join(model_dir, "pipeline_meta.pkl")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"Pipeline metadata not found at {meta_path}")
        # NOTE(review): joblib.load / torch.load unpickle arbitrary objects;
        # only point model_dir at artifacts from a trusted store.
        self.meta = joblib.load(meta_path)
        self.le = self.meta["le"]
        self.bands = self.meta["bands"]
        self.dates = self.meta["dates"]
        self.w_fcn = self.meta["weights"]["w_fcn"]
        self.w_cb = self.meta["weights"]["w_cb"]
        self.fcn = TemporalFCN(len(self.bands), self.meta["num_classes"])
        fcn_path = os.path.join(model_dir, "Temporal_FCN.pth")
        self.fcn.load_state_dict(torch.load(fcn_path, map_location=torch.device('cpu')))
        self.fcn.eval()
        cb_path = os.path.join(model_dir, "calibrated_hybrid_cb.pkl")
        self.calibrated_cb = joblib.load(cb_path)
        print("Models loaded successfully.")

    def _impute_inference_data(self, df):
        """Fill gaps in each band's time series and record what was missing.

        Present ``"{date}_{band}"`` columns are linearly interpolated along
        time, then forward/backward filled, then zero-filled. The input frame
        is not modified.

        Returns:
            ``(imputed_df, missing_mask)`` where ``missing_mask[band]`` is a
            float DataFrame with one column per expected date: 1.0 where the
            value was NaN or the column does not exist at all, else 0.0.
        """
        print("Imputing cloudy/missing timesteps via temporal interpolation...")
        df = df.copy()
        missing_mask = {}
        for band in self.bands:
            expected_cols = [f"{date}_{band}" for date in self.dates]
            # A timestep with no column is fed to the model as zeros, so it
            # must count as missing too; reindex makes absent columns all-NaN.
            missing_mask[band] = df.reindex(columns=expected_cols).isna().astype(float)
            band_cols = [col for col in expected_cols if col in df.columns]
            if band_cols:
                df[band_cols] = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
                df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0)
        return df, missing_mask

    def _smooth_probabilities(self, coords, probs):
        """Average each pixel's class probabilities with its nearest neighbours.

        Uniform weights are required here: with ``weights='distance'`` a query
        point that coincides with a training point (always the case when
        predicting on the fitted coordinates) gets that single point's value
        back, which makes the smoothing a no-op.
        """
        # Clamp so small tiles with fewer pixels than the window still work.
        n_neighbors = min(self._SMOOTHING_NEIGHBORS, len(coords))
        knn = KNeighborsRegressor(n_neighbors=n_neighbors, weights='uniform')
        knn.fit(coords, probs)
        return knn.predict(coords)

    def predict(self, raw_df, apply_spatial_smoothing=False, coord_cols=('lat', 'lon')):
        """Classify every row of ``raw_df`` and attach quality flags.

        Args:
            raw_df: Wide table with ``"{date}_{band}"`` columns.
            apply_spatial_smoothing: Smooth class probabilities over
                neighbouring pixels before taking the argmax.
            coord_cols: Names of the coordinate columns used for smoothing.

        Returns:
            Imputed copy of ``raw_df`` with ``class_id``, ``predicted_crop``,
            ``confidence``, ``high_missing`` and ``low_quality`` columns.
        """
        coord_cols = list(coord_cols)
        df, missing_mask = self._impute_inference_data(raw_df)
        X_infer = prepare_tensors(df, self.bands, self.dates)
        # Labels are placeholders; the dataset API just requires them.
        infer_loader = DataLoader(
            CropDataset(X_infer, np.zeros(len(df)), augment=False),
            batch_size=128,
            shuffle=False,
        )
        fcn_probs = []
        fcn_feats = []
        with torch.no_grad():
            for X_batch, _ in infer_loader:
                out, feats = self.fcn(X_batch, return_features=True)
                fcn_probs.append(torch.softmax(out, dim=1).numpy())
                fcn_feats.append(feats.numpy())
        fcn_probs = np.vstack(fcn_probs)
        fcn_feats = np.vstack(fcn_feats)

        # The second model sees the flattened cube plus the network's features.
        X_infer_flat = X_infer.reshape(X_infer.shape[0], -1)
        X_stack = np.hstack([X_infer_flat, fcn_feats])
        cb_probs = self.calibrated_cb.predict_proba(X_stack)

        final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
        final_preds = np.argmax(final_probs, axis=1)

        has_coords = all(col in df.columns for col in coord_cols)
        if apply_spatial_smoothing and has_coords and len(df) > 1:
            print(f"Applying spatial probability smoothing using {coord_cols}...")
            final_probs = self._smooth_probabilities(df[coord_cols].values, final_probs)
            final_preds = np.argmax(final_probs, axis=1)

        df['class_id'] = final_preds
        df['predicted_crop'] = self.le.inverse_transform(final_preds)
        df['confidence'] = np.max(final_probs, axis=1)

        if missing_mask:
            # Mean over bands of each pixel's share of missing timesteps.
            missing_ratio = np.mean([m.mean(axis=1) for m in missing_mask.values()], axis=0)
        else:
            missing_ratio = np.zeros(len(df))
        df['high_missing'] = missing_ratio > 0.4
        df['low_quality'] = (df['confidence'] < 0.5) | df['high_missing']

        # Set NoData (0) for low quality
        # NOTE(review): 0 is also the first encoded class id (argmax index 0
        # maps to a real label via inverse_transform), so after this step the
        # raster cannot tell that class from NoData -- confirm intended ids.
        df.loc[df['low_quality'], 'class_id'] = 0
        df.loc[df['low_quality'], 'predicted_crop'] = 'Unknown/NoData'
        return df

    def export_to_geotiff(self, df, output_path="lulc_map.tif", crs="EPSG:6933"):
        """Write class, confidence and quality rasters plus a JSON legend.

        Sidecar files are named after ``output_path`` with ``_confidence``,
        ``_cloud_mask`` (same extension) and ``_legend.json`` suffixes.

        Raises:
            ImportError: If xarray or rioxarray failed to import.
        """
        if xr is None or rioxarray is None:
            raise ImportError("Missing required libraries: xarray, rioxarray")
        print(f"Exporting LULC masks to {output_path}...")
        ds_out = df.set_index(['lat', 'lon'])[['class_id', 'confidence', 'low_quality']].to_xarray()
        ds_out = ds_out.rename({'lat': 'y', 'lon': 'x'})
        # Descending y gives the north-up row order.
        ds_out = ds_out.sortby('y', ascending=False)
        ds_out = ds_out.rio.set_spatial_dims(x_dim='x', y_dim='y')
        ds_out.rio.write_crs(crs, inplace=True)

        # Split once on the real extension; str.replace('.tif', ...) mangled
        # '.tiff' names and any directory component containing '.tif'.
        base, ext = os.path.splitext(output_path)
        conf_path = f"{base}_confidence{ext}"
        mask_path = f"{base}_cloud_mask{ext}"
        legend_path = f"{base}_legend.json"

        ds_out['class_id'].astype('uint16').rio.to_raster(output_path)
        ds_out['confidence'].astype('float32').rio.to_raster(conf_path)
        ds_out['low_quality'].astype('uint8').rio.to_raster(mask_path)

        legend_dict = {int(i): str(c) for i, c in enumerate(self.le.classes_)}
        # Only takes effect when the encoder has no classes, since enumerate
        # always assigns id 0 to the first class.
        if 0 not in legend_dict:
            legend_dict[0] = 'Unknown/NoData'
        with open(legend_path, 'w') as f:
            json.dump(legend_dict, f, indent=4)
        print(f"✅ Successfully exported GeoTIFFs and class legend!")

View File

@ -5,6 +5,7 @@ rq
# Core dependencies
numpy>=1.24.0
pandas>=2.0.0
torch --index-url https://download.pytorch.org/whl/cpu
# Raster/geo processing
rasterio>=1.3.0
@ -13,6 +14,7 @@ rioxarray>=0.14.0
# STAC data access
pystac-client>=0.7.0
stackstac>=0.4.0
odc-stac>=0.3.0
xarray>=2023.1.0
# ML

View File

@ -171,7 +171,7 @@ def parse_and_validate_payload(payload: dict) -> tuple[dict, List[str]]:
# Validate model
if "model" in payload:
valid_models = ["Ensemble", "RandomForest", "XGBoost", "LightGBM", "CatBoost"]
valid_models = ["Ensemble", "RandomForest", "XGBoost", "LightGBM", "CatBoost", "CatBoost_V2"]
if payload["model"] not in valid_models:
errors.append(f"Invalid model: {payload['model']}. Must be one of {valid_models}")
@ -372,112 +372,82 @@ def run_job(payload_dict: dict) -> dict:
print(f"[{job_id}] Synthetic feature cube shape: {feature_cube.shape}")
# ==========================================
# Stage 3: Load DW Baseline
# Stage 3: Load Model Artifacts
# ==========================================
update_status(job_id, "running", "load_dw", 40, "Loading DW baseline...")
update_status(job_id, "running", "load_model", 40, "Loading model artifacts...")
print(f"[{job_id}] Loading DW baseline for {payload['year']}...")
is_hybrid = "hybrid" in payload['model'].lower() or "spatiotemporal" in payload['model'].lower()
from dw_baseline import load_dw_baseline_window
model_dir = Path(tempfile.mkdtemp())
try:
dw_arr, dw_profile = load_dw_baseline_window(
storage=storage,
year=payload['year'],
aoi_bbox_wgs84=bbox,
season=payload['season'],
if is_hybrid:
print(f"[{job_id}] Model type: Hybrid Spatio-Temporal. Downloading artifacts...")
# Expected files in MinIO: pipeline_meta.pkl, Temporal_FCN.pth, calibrated_hybrid_cb.pkl
for artifact in ["pipeline_meta.pkl", "Temporal_FCN.pth", "calibrated_hybrid_cb.pkl"]:
try:
storage.download_model_file(artifact, model_dir)
print(f"[{job_id}] Downloaded {artifact}")
except Exception as e:
print(f"[{job_id}] Failed to download {artifact}: {e}")
# Try with 'hybrid/' prefix if direct fails
try:
storage.download_file("geocrop-models", f"hybrid/{artifact}", model_dir / artifact)
print(f"[{job_id}] Downloaded {artifact} (from hybrid/ prefix)")
except Exception as e2:
raise FileNotFoundError(f"Required artifact {artifact} not found in geocrop-models: {e2}")
# ==========================================
# Stage 4: Fetch Spatio-Temporal Data
# ==========================================
update_status(job_id, "running", "fetch_stac", 50, "Fetching spatio-temporal indices...")
from hybrid_inference import DEAfricaSTACWrapper, CropInferencePipeline
stac_wrapper = DEAfricaSTACWrapper()
# Calculate ranges for wrapper
lat_range = (bbox[1], bbox[3])
lon_range = (bbox[0], bbox[2])
time_range = (start_date, end_date)
unseen_pixel_df = stac_wrapper.fetch_and_format_data(
lat_range=lat_range,
lon_range=lon_range,
time_range=time_range
)
if dw_arr is None:
raise FileNotFoundError(f"No DW baseline found for year {payload['year']}")
# ==========================================
# Stage 5: Hybrid Inference
# ==========================================
update_status(job_id, "running", "infer", 70, "Running Hybrid Inference (CNN + CatBoost)...")
pipeline = CropInferencePipeline(model_dir=str(model_dir))
print(f"[{job_id}] DW baseline shape: {dw_arr.shape}")
except Exception as e:
update_status(
job_id, "failed", "load_dw", 45,
f"Failed to load DW baseline: {e}",
error={"type": "DWBASELINE_ERROR", "message": str(e)}
mapped_crops_df = pipeline.predict(
unseen_pixel_df,
apply_spatial_smoothing=True,
coord_cols=['lat', 'lon']
)
return {"status": "failed", "error": f"DW baseline error: {e}"}
# ==========================================
# Stage 4: Skip AI Inference, use DW as result
# ==========================================
update_status(job_id, "running", "infer", 60, "Using DW baseline as classification...")
print(f"[{job_id}] Using DW baseline as result (Skipping AI models as requested)")
# We use dw_arr as the classification result
cls_raster = dw_arr.copy()
# ==========================================
# Stage 5: Apply Smoothing (Optional for DW)
# ==========================================
if payload.get('smoothing_kernel'):
kernel = payload['smoothing_kernel']
update_status(job_id, "running", "smooth", 75, f"Applying smoothing (k={kernel})...")
from postprocess import majority_filter
# ==========================================
# Stage 6: Export and Upload
# ==========================================
update_status(job_id, "running", "export_cog", 90, "Exporting results...")
output_dir = Path(tempfile.mkdtemp())
output_path = output_dir / "refined.tif"
cls_raster = majority_filter(cls_raster, kernel=kernel, nodata=0)
print(f"[{job_id}] Smoothing applied")
# ==========================================
# Stage 6: Export COGs
# ==========================================
update_status(job_id, "running", "export_cog", 80, "Exporting COGs...")
from cog import write_cog
output_dir = Path(tempfile.mkdtemp())
output_urls = {}
missing_outputs = []
# Export refined raster
if payload['outputs'].get('refined', True):
try:
refined_path = output_dir / "refined.tif"
dtype = "uint8" if cls_raster.max() <= 255 else "uint16"
write_cog(
str(refined_path),
cls_raster.astype(dtype),
dw_profile,
dtype=dtype,
nodata=0,
)
# Upload
result_key = f"results/{job_id}/refined.tif"
storage.upload_result(refined_path, result_key)
output_urls["refined_url"] = storage.presign_get("geocrop-results", result_key)
print(f"[{job_id}] Exported refined.tif")
except Exception as e:
missing_outputs.append(f"refined: {e}")
# Export DW baseline if requested
if payload['outputs'].get('dw_baseline', False):
try:
dw_path = output_dir / "dw_baseline.tif"
write_cog(
str(dw_path),
dw_arr.astype("uint8"),
dw_profile,
dtype="uint8",
nodata=0,
)
result_key = f"results/{job_id}/dw_baseline.tif"
storage.upload_result(dw_path, result_key)
output_urls["dw_baseline_url"] = storage.presign_get("geocrop-results", result_key)
print(f"[{job_id}] Exported dw_baseline.tif")
except Exception as e:
missing_outputs.append(f"dw_baseline: {e}")
pipeline.export_to_geotiff(mapped_crops_df, output_path=str(output_path))
output_urls = {}
for filename in ["refined.tif", "refined_confidence.tif", "refined_cloud_mask.tif", "refined_legend.json"]:
local_f = output_dir / filename
if local_f.exists():
result_key = f"results/{job_id}/{filename}"
storage.upload_result(local_f, result_key)
output_urls[filename.replace(".","_url")] = storage.presign_get("geocrop-results", result_key)
else:
# Fallback to Legacy/DW-only logic (current implementation)
print(f"[{job_id}] Using baseline logic (DW-only)...")
from dw_baseline import load_dw_baseline_window
# ... (keep existing Stage 3-6 logic for non-hybrid)
# Note: indices and true_color not yet implemented
if payload['outputs'].get('indices'):

View File

@ -17,3 +17,4 @@ resources:
- geocrop-web-ingress.yaml
- geocrop-tiler-rewrite.yaml
- 60-ingress-minio.yaml
- ntfy.yaml

99
k8s/base/ntfy.yaml Normal file
View File

@ -0,0 +1,99 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: ntfy-data
namespace: monitoring
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 5Gi
---
apiVersion: v1
kind: ConfigMap
metadata:
name: ntfy-config
namespace: monitoring
data:
server.yml: |
base-url: "https://ntfy.techarvest.co.zw"
listen-http: ":80"
auth-file: "/var/lib/ntfy/user.db"
auth-default-access: "deny-all"
cache-file: "/var/lib/ntfy/cache.db"
attachment-cache-dir: "/var/lib/ntfy/attachments"
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: ntfy
namespace: monitoring
spec:
replicas: 1
selector:
matchLabels:
app: ntfy
template:
metadata:
labels:
app: ntfy
spec:
containers:
- name: ntfy
image: binwiederhier/ntfy:latest
imagePullPolicy: Always
args: ["serve"]
ports:
- containerPort: 80
volumeMounts:
- name: ntfy-config
mountPath: /etc/ntfy/server.yml
subPath: server.yml
readOnly: true
- name: ntfy-data
mountPath: /var/lib/ntfy
volumes:
- name: ntfy-config
configMap:
name: ntfy-config
- name: ntfy-data
persistentVolumeClaim:
claimName: ntfy-data
---
apiVersion: v1
kind: Service
metadata:
name: ntfy
namespace: monitoring
spec:
selector:
app: ntfy
ports:
- port: 80
targetPort: 80
---
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: ntfy-ingress
namespace: monitoring
annotations:
cert-manager.io/cluster-issuer: "letsencrypt-prod"
spec:
ingressClassName: nginx
tls:
- hosts:
- ntfy.techarvest.co.zw
secretName: ntfy-tls
rules:
- host: ntfy.techarvest.co.zw
http:
paths:
- path: /
pathType: Prefix
backend:
service:
name: ntfy
port:
number: 80