# geocrop-platform./apps/worker/hybrid_inference.py
#
# 277 lines
# 11 KiB
# Python

import os
import io
import json
import time
import copy
import joblib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.neighbors import KNeighborsRegressor
from catboost import CatBoostClassifier
# Digital Earth Africa STAC specific imports
try:
from pystac_client import Client
import odc.stac
import xarray as xr
import rioxarray
except ImportError:
Client = None
odc = None
xr = None
rioxarray = None
# ==========================================
# 1. CPU-OPTIMIZED ARCHITECTURES
# ==========================================
class TemporalFCN(nn.Module):
    """Two-stage 1-D convolutional classifier with global average pooling.

    The first convolution takes ``num_bands`` input channels, so the input
    is expected to be laid out as ``(batch, num_bands, timesteps)``.
    """

    def __init__(self, num_bands, num_classes):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys of saved checkpoints.
        self.conv_block1 = self._conv_stage(num_bands, 64, kernel_size=5)
        self.conv_block2 = self._conv_stage(64, 128, kernel_size=3)
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, num_classes)

    @staticmethod
    def _conv_stage(in_channels, out_channels, kernel_size):
        """Build Conv1d -> BatchNorm1d -> ReLU with length-preserving padding."""
        return nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            ),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x, return_features=False):
        """Return class logits, optionally with the 128-d pooled features.

        Args:
            x: Input tensor fed to the first convolution block.
            return_features: When true, return ``(logits, features)``.
        """
        hidden = self.conv_block2(self.conv_block1(x))
        # Pool over the last axis and drop it, leaving one vector per sample.
        pooled = self.global_avg_pool(hidden).squeeze(-1)
        logits = self.fc(pooled)
        return (logits, pooled) if return_features else logits
class SmallGRU(nn.Module):
    """Single-layer GRU classifier that reads the last timestep's output."""

    def __init__(self, num_bands, num_classes, hidden_size=64):
        super().__init__()
        self.gru = nn.GRU(
            input_size=num_bands,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, return_features=False):
        """Return class logits, optionally with the final GRU output vector.

        Args:
            x: Input tensor; axes 1 and 2 are swapped before the GRU so that
                the ``num_bands`` axis becomes the GRU feature axis.
            return_features: When true, return ``(logits, features)``.
        """
        sequence = x.transpose(1, 2)
        outputs, _hidden = self.gru(sequence)
        # The output at the last sequence position summarises the series.
        last_step = outputs[:, -1, :]
        logits = self.fc(last_step)
        if not return_features:
            return logits
        return logits, last_step
# ==========================================
# 2. DATA PREPARATION & PYTORCH UTILS
# ==========================================
class CropDataset(Dataset):
    """Torch dataset over ``(X, y)`` with optional random augmentation.

    ``X`` is stored as a float tensor and ``y`` as a long tensor. Each item
    is a clone, so augmentation never alters the stored data.
    """

    def __init__(self, X, y, augment=False):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)
        self.augment = augment

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for ``idx``, augmented when enabled."""
        sample = self.X[idx].clone()
        label = self.y[idx]
        if not self.augment:
            return sample, label

        # Additive Gaussian jitter (std 0.03), applied when the draw exceeds 0.5.
        if torch.rand(1).item() > 0.5:
            sample = sample + torch.randn_like(sample) * 0.03

        # Zero out one random position along axis 1 when the draw exceeds 0.7.
        if torch.rand(1).item() > 0.7:
            dropped = torch.randint(0, sample.shape[1], (1,)).item()
            sample[:, dropped] = 0.0

        return sample, label
def prepare_tensors(df, bands, dates):
    """Pack ``{date}_{band}`` columns into a standardised float32 cube.

    Args:
        df: DataFrame with one row per sample.
        bands: Band names; they index axis 1 of the result.
        dates: Date labels; they index axis 2 of the result.

    Returns:
        Array of shape ``(len(df), len(bands), len(dates))``. Columns absent
        from ``df`` are left at zero before standardisation. Each
        (sample, band) series is centred and scaled by its own mean and
        standard deviation along the date axis.
    """
    cube = np.zeros((len(df), len(bands), len(dates)), dtype=np.float32)
    slots = (
        (band_pos, date_pos, f"{date}_{band}")
        for band_pos, band in enumerate(bands)
        for date_pos, date in enumerate(dates)
    )
    for band_pos, date_pos, column in slots:
        if column in df.columns:
            cube[:, band_pos, date_pos] = df[column].values

    center = cube.mean(axis=2, keepdims=True)
    # The small epsilon keeps constant series from dividing by zero.
    spread = cube.std(axis=2, keepdims=True) + 1e-8
    return (cube - center) / spread
# ==========================================
# 3. DIGITAL EARTH AFRICA STAC INTEGRATION
# ==========================================
class DEAfricaSTACWrapper:
    """Fetch imagery from a STAC catalog and flatten it to per-pixel index columns."""

    def __init__(self, stac_url="https://explorer.digitalearth.africa/stac"):
        """Open the STAC catalog at ``stac_url``.

        Raises:
            ImportError: If pystac-client, odc-stac or xarray failed to import.
        """
        dependencies_missing = Client is None or odc is None or xr is None
        if dependencies_missing:
            raise ImportError("Missing required libraries: pystac-client, odc-stac, xarray")
        print(f"Connecting to Digital Earth Africa STAC Catalog at {stac_url}...")
        self.catalog = Client.open(stac_url)

    def fetch_and_format_data(self, lat_range, lon_range, time_range, resolution=20):
        """Load the ``s2_l2a`` collection for an area and return a wide DataFrame.

        Args:
            lat_range: ``(min, max)`` latitude pair.
            lon_range: ``(min, max)`` longitude pair.
            time_range: ``(start, end)`` pair, joined with ``/`` for the search.
            resolution: Passed through to ``odc.stac.load``.

        Returns:
            DataFrame with one row per pixel, ``lat`` / ``lon`` columns
            (renamed from the dataset's ``y`` / ``x`` coordinates) and one
            ``{YYYYMMDD}_{index}`` column per date and spectral index.

        Raises:
            ValueError: If the search returns no items.
        """
        bbox = [lon_range[0], lat_range[0], lon_range[1], lat_range[1]]
        print(f"Searching STAC for Bounding Box: {bbox} over {time_range}...")
        search = self.catalog.search(
            collections=["s2_l2a"],
            bbox=bbox,
            datetime=f"{time_range[0]}/{time_range[1]}",
        )
        items = list(search.items())
        if len(items) == 0:
            raise ValueError("No STAC items found for this bounding box and time range.")

        print(f"Found {len(items)} STAC items. Loading into xarray...")
        load_options = {
            "measurements": ['red', 'green', 'blue', 'nir', 'red_edge_1', 'scl'],
            "bbox": bbox,
            "crs": "EPSG:6933",
            "resolution": resolution,
            "groupby": "solar_day",
        }
        ds = odc.stac.load(items, **load_options)

        print("Masking clouds and shadows...")
        # Keep only pixels whose SCL value is one of these codes.
        kept_scl_codes = (4, 5, 6, 2, 7)
        valid_mask = ds.scl == kept_scl_codes[0]
        for code in kept_scl_codes[1:]:
            valid_mask = valid_mask | (ds.scl == code)
        # Mask first, then scale every variable by 1/10000.
        ds = ds.where(valid_mask) / 10000.0

        print("Computing Spectral Indices (NDVI, NDRE, SAVI, EVI)...")
        nir, red, blue, red_edge = ds.nir, ds.red, ds.blue, ds.red_edge_1
        ds['ndvi'] = (nir - red) / (nir + red + 1e-8)
        ds['ndre'] = (nir - red_edge) / (nir + red_edge + 1e-8)
        ds['savi'] = ((nir - red) / (nir + red + 0.5)) * 1.5
        ds['evi'] = 2.5 * ((nir - red) / (nir + 6 * red - 7.5 * blue + 1))
        index_names = ['ndvi', 'ndre', 'evi', 'savi']
        ds_indices = ds[index_names]

        print("Reshaping multi-dimensional xarray into flat Tabular DataFrame...")
        long_df = ds_indices.compute().to_dataframe().reset_index()
        long_df['date_str'] = long_df['time'].dt.strftime('%Y%m%d')
        wide_df = long_df.pivot(index=['y', 'x'], columns='date_str', values=index_names)
        # Pivot columns are (index_name, date) pairs; flatten to "{date}_{index}".
        flat_names = []
        for band, date in wide_df.columns:
            flat_names.append(f"{date}_{band}")
        wide_df.columns = flat_names
        df_final = wide_df.reset_index().rename(columns={'y': 'lat', 'x': 'lon'})
        print(f"✅ Data Ready! {len(df_final)} spatial pixels generated.")
        return df_final
# ==========================================
# 4. INFERENCE PIPELINE
# ==========================================
class CropInferencePipeline:
    """Run the saved TemporalFCN + calibrated classifier ensemble on pixel tables.

    The pipeline loads its artefacts from ``model_dir``, imputes missing
    timesteps, blends the two models' class probabilities with the stored
    weights, flags low-quality pixels and can export the result as GeoTIFFs.
    """

    def __init__(self, model_dir="/tmp/geocrop-cache"):
        """Load metadata and both models from ``model_dir``.

        Reads ``pipeline_meta.pkl`` (keys ``le``, ``bands``, ``dates``,
        ``weights`` and ``num_classes``), ``Temporal_FCN.pth`` and
        ``calibrated_hybrid_cb.pkl``.

        Raises:
            FileNotFoundError: If ``pipeline_meta.pkl`` does not exist.
        """
        print(f"Loading Crop Inference Pipeline from {model_dir}...")
        meta_path = os.path.join(model_dir, "pipeline_meta.pkl")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"Pipeline metadata not found at {meta_path}")
        # SECURITY: joblib.load and torch.load unpickle the files they read.
        # Only point model_dir at artefacts from a trusted source.
        self.meta = joblib.load(meta_path)
        self.le = self.meta["le"]
        self.bands = self.meta["bands"]
        self.dates = self.meta["dates"]
        self.w_fcn = self.meta["weights"]["w_fcn"]
        self.w_cb = self.meta["weights"]["w_cb"]
        self.fcn = TemporalFCN(len(self.bands), self.meta["num_classes"])
        fcn_path = os.path.join(model_dir, "Temporal_FCN.pth")
        self.fcn.load_state_dict(torch.load(fcn_path, map_location=torch.device('cpu')))
        self.fcn.eval()
        cb_path = os.path.join(model_dir, "calibrated_hybrid_cb.pkl")
        self.calibrated_cb = joblib.load(cb_path)
        print("Models loaded successfully.")

    def _impute_inference_data(self, df):
        """Fill missing timesteps per band by interpolating along the date axis.

        Args:
            df: DataFrame holding ``{date}_{band}`` columns; it is not modified.

        Returns:
            ``(imputed_df, missing_mask)``. ``missing_mask`` maps each band
            that has at least one column in ``df`` to a float DataFrame
            (1.0 where the original value was NaN) over that band's present
            columns.
        """
        print("Imputing cloudy/missing timesteps via temporal interpolation...")
        df = df.copy()
        missing_mask = {}
        for band in self.bands:
            band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
            if not band_cols:
                continue
            # Record the gaps before they are filled.
            missing_mask[band] = df[band_cols].isna().astype(float)
            df[band_cols] = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
            # Rows with no observation at all stay NaN after interpolation; zero them.
            df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0)
        return df, missing_mask

    def _missing_ratio(self, missing_mask, num_rows):
        """Return, per row, the mean over bands of the fraction of missing timesteps.

        Timesteps whose column is absent from the input count as missing, as
        they reach the model as zero-filled placeholders rather than as
        observations. A band with no columns at all counts as fully missing.
        """
        num_dates = len(self.dates)
        if not self.bands or num_dates == 0:
            return np.zeros(num_rows)
        per_band = []
        for band in self.bands:
            mask = missing_mask.get(band)
            if mask is None:
                per_band.append(np.ones(num_rows))
                continue
            absent = num_dates - mask.shape[1]
            per_band.append((mask.sum(axis=1).to_numpy() + absent) / num_dates)
        return np.mean(per_band, axis=0)

    @staticmethod
    def _bounded_inverse_distance(dist):
        """Map neighbour distances to finite weights that decay with distance.

        scikit-learn's built-in ``weights='distance'`` gives all weight to a
        zero-distance neighbour. Every pixel is its own nearest neighbour
        when predicting on the training coordinates, which made the
        smoothing a no-op. This kernel gives the pixel itself weight 1 and
        scales the other weights by the median positive distance.
        """
        dist = np.asarray(dist, dtype=float)
        positive = dist[dist > 0]
        scale = float(np.median(positive)) if positive.size else 1.0
        return 1.0 / (1.0 + dist / scale)

    @staticmethod
    def _sidecar_path(output_path, suffix, ext=None):
        """Insert ``suffix`` before the extension of ``output_path``.

        ``ext`` replaces the original extension when given.
        """
        stem, original_ext = os.path.splitext(output_path)
        return f"{stem}{suffix}{original_ext if ext is None else ext}"

    def predict(self, raw_df, apply_spatial_smoothing=False, coord_cols=None):
        """Classify every row of ``raw_df`` and flag low-quality predictions.

        Args:
            raw_df: DataFrame with ``{date}_{band}`` columns (and optionally
                coordinate columns); it is not modified.
            apply_spatial_smoothing: When true and all ``coord_cols`` are
                present, average class probabilities over each pixel's
                nearest neighbours before taking the argmax.
            coord_cols: Coordinate column names; defaults to
                ``['lat', 'lon']``.

        Returns:
            Imputed copy of ``raw_df`` with added columns ``class_id``,
            ``predicted_crop``, ``confidence``, ``high_missing`` and
            ``low_quality``.

        Raises:
            ValueError: If ``raw_df`` has no rows.
        """
        coord_cols = ['lat', 'lon'] if coord_cols is None else list(coord_cols)
        if len(raw_df) == 0:
            raise ValueError("Cannot run inference on an empty DataFrame.")

        df, missing_mask = self._impute_inference_data(raw_df)
        X_infer = prepare_tensors(df, self.bands, self.dates)
        # Labels are unused at inference time; zeros satisfy the dataset API.
        dummy_labels = np.zeros(len(df), dtype=np.int64)
        infer_loader = DataLoader(
            CropDataset(X_infer, dummy_labels, augment=False),
            batch_size=128,
            shuffle=False,
        )

        fcn_probs = []
        fcn_feats = []
        with torch.no_grad():
            for X_batch, _ in infer_loader:
                out, feats = self.fcn(X_batch, return_features=True)
                fcn_probs.append(torch.softmax(out, dim=1).numpy())
                fcn_feats.append(feats.numpy())
        fcn_probs = np.vstack(fcn_probs)
        fcn_feats = np.vstack(fcn_feats)

        # The second model sees the flattened series plus the FCN features.
        X_infer_flat = X_infer.reshape(X_infer.shape[0], -1)
        X_stack = np.hstack([X_infer_flat, fcn_feats])
        cb_probs = self.calibrated_cb.predict_proba(X_stack)

        final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
        final_preds = np.argmax(final_probs, axis=1)

        if apply_spatial_smoothing and all(col in df.columns for col in coord_cols):
            print(f"Applying spatial probability smoothing using {coord_cols}...")
            coords = df[coord_cols].values
            knn = KNeighborsRegressor(
                # Clamp so small inputs do not ask for more neighbours than exist.
                n_neighbors=min(9, len(coords)),
                weights=self._bounded_inverse_distance,
            )
            knn.fit(coords, final_probs)
            final_probs = knn.predict(coords)
            final_preds = np.argmax(final_probs, axis=1)

        df['class_id'] = final_preds
        df['predicted_crop'] = self.le.inverse_transform(final_preds)
        df['confidence'] = np.max(final_probs, axis=1)

        missing_ratio = self._missing_ratio(missing_mask, len(df))
        df['high_missing'] = missing_ratio > 0.4
        df['low_quality'] = (df['confidence'] < 0.5) | df['high_missing']

        # Set NoData (0) for low quality.
        # NOTE(review): 0 is also the encoded id of the first label-encoder
        # class, so NoData and that class share a value in ``class_id``;
        # ``low_quality`` is what tells them apart. Confirm this is intended.
        df.loc[df['low_quality'], 'class_id'] = 0
        df.loc[df['low_quality'], 'predicted_crop'] = 'Unknown/NoData'
        return df

    def export_to_geotiff(self, df, output_path="lulc_map.tif", crs="EPSG:6933"):
        """Write class, confidence and quality rasters plus a JSON legend.

        Args:
            df: Output of ``predict``; must contain ``lat``, ``lon``,
                ``class_id``, ``confidence`` and ``low_quality``.
            output_path: Path of the class raster. Sidecar files are written
                next to it with ``_confidence``, ``_cloud_mask`` and
                ``_legend`` inserted before the extension.
            crs: CRS written onto the rasters.

        Raises:
            ImportError: If xarray or rioxarray failed to import.
        """
        if xr is None or rioxarray is None:
            raise ImportError("Missing required libraries: xarray, rioxarray")
        print(f"Exporting LULC masks to {output_path}...")

        ds_out = df.set_index(['lat', 'lon'])[['class_id', 'confidence', 'low_quality']].to_xarray()
        ds_out = ds_out.rename({'lat': 'y', 'lon': 'x'})
        # Sort y descending before writing the rasters.
        ds_out = ds_out.sortby('y', ascending=False)
        ds_out = ds_out.rio.set_spatial_dims(x_dim='x', y_dim='y')
        ds_out.rio.write_crs(crs, inplace=True)

        ds_out['class_id'].astype('uint16').rio.to_raster(output_path)

        # Derive sidecar names from the split extension so they can never
        # collide with output_path, whatever its extension.
        conf_path = self._sidecar_path(output_path, '_confidence')
        ds_out['confidence'].astype('float32').rio.to_raster(conf_path)
        mask_path = self._sidecar_path(output_path, '_cloud_mask')
        ds_out['low_quality'].astype('uint8').rio.to_raster(mask_path)

        legend_path = self._sidecar_path(output_path, '_legend', '.json')
        legend_dict = {int(i): str(c) for i, c in enumerate(self.le.classes_)}
        # NOTE(review): with at least one encoded class, key 0 is always
        # present, so this fallback only fires for an empty label encoder.
        if 0 not in legend_dict:
            legend_dict[0] = 'Unknown/NoData'
        with open(legend_path, 'w', encoding='utf-8') as f:
            json.dump(legend_dict, f, indent=4)
        print("✅ Successfully exported GeoTIFFs and class legend!")