import os import io import json import time import copy import joblib import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader import pandas as pd import numpy as np from sklearn.neighbors import KNeighborsRegressor from catboost import CatBoostClassifier # Digital Earth Africa STAC specific imports try: from pystac_client import Client import odc.stac import xarray as xr import rioxarray except ImportError: Client = None odc = None xr = None rioxarray = None # ========================================== # 1. CPU-OPTIMIZED ARCHITECTURES # ========================================== class TemporalFCN(nn.Module): def __init__(self, num_bands, num_classes): super().__init__() self.conv_block1 = nn.Sequential( nn.Conv1d(num_bands, 64, kernel_size=5, padding=2), nn.BatchNorm1d(64), nn.ReLU() ) self.conv_block2 = nn.Sequential( nn.Conv1d(64, 128, kernel_size=3, padding=1), nn.BatchNorm1d(128), nn.ReLU() ) self.global_avg_pool = nn.AdaptiveAvgPool1d(1) self.fc = nn.Linear(128, num_classes) def forward(self, x, return_features=False): x = self.conv_block1(x) x = self.conv_block2(x) features = self.global_avg_pool(x).squeeze(-1) out = self.fc(features) if return_features: return out, features return out class SmallGRU(nn.Module): def __init__(self, num_bands, num_classes, hidden_size=64): super().__init__() self.gru = nn.GRU(input_size=num_bands, hidden_size=hidden_size, num_layers=1, batch_first=True) self.fc = nn.Linear(hidden_size, num_classes) def forward(self, x, return_features=False): x = x.transpose(1, 2) out, _ = self.gru(x) features = out[:, -1, :] final_out = self.fc(features) if return_features: return final_out, features return final_out # ========================================== # 2. 
# DATA PREPARATION & PYTORCH UTILS
# ==========================================
class CropDataset(Dataset):
    """In-memory dataset of band time series with optional augmentation.

    ``X`` is converted with ``torch.FloatTensor`` and ``y`` with
    ``torch.LongTensor``. ``__getitem__`` treats axis 1 of each sample as
    the sequence axis, i.e. samples are indexed as (bands, time).
    """

    def __init__(self, X, y, augment=False):
        self.augment = augment
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(sample, label)``; the sample is a copy, never a view."""
        label = self.y[idx]
        sample = self.X[idx].clone()
        if not self.augment:
            return sample, label

        # Augmentation 1: additive Gaussian noise (std 0.03), applied when a
        # uniform draw exceeds 0.5.
        if torch.rand(1).item() > 0.5:
            sample = sample + torch.randn_like(sample) * 0.03

        # Augmentation 2: zero out one randomly chosen timestep across all
        # bands, applied when a uniform draw exceeds 0.7.
        if torch.rand(1).item() > 0.7:
            n_steps = sample.shape[1]
            dropped_step = torch.randint(0, n_steps, (1,)).item()
            sample[:, dropped_step] = 0.0

        return sample, label


def prepare_tensors(df, bands, dates):
    """Build a normalised (samples, bands, dates) float32 array from ``df``.

    Columns are looked up as ``"{date}_{band}"``. A column that is absent
    from ``df`` leaves its slot at zero. Each (sample, band) series is then
    standardised along the date axis using its own mean and standard
    deviation (plus 1e-8 to avoid division by zero).

    Args:
        df: DataFrame holding one row per sample.
        bands: Band names, in the order of the output band axis.
        dates: Date keys, in the order of the output date axis.

    Returns:
        The standardised array of shape ``(len(df), len(bands), len(dates))``.
    """
    cube = np.zeros((len(df), len(bands), len(dates)), dtype=np.float32)

    for band_pos, band_name in enumerate(bands):
        for date_pos, date_key in enumerate(dates):
            column = f"{date_key}_{band_name}"
            if column not in df.columns:
                continue
            cube[:, band_pos, date_pos] = df[column].values

    temporal_mean = cube.mean(axis=2, keepdims=True)
    temporal_std = cube.std(axis=2, keepdims=True) + 1e-8
    return (cube - temporal_mean) / temporal_std


# ==========================================
# 3. DIGITAL EARTH AFRICA STAC INTEGRATION
# ==========================================
class DEAfricaSTACWrapper:
    """Fetch ``s2_l2a`` items from a STAC catalog and flatten them to a table."""

    def __init__(self, stac_url="https://explorer.digitalearth.africa/stac"):
        """Open the STAC catalog at ``stac_url``.

        Raises:
            ImportError: If any of the optional STAC/xarray imports failed.
        """
        if Client is None or odc is None or xr is None:
            raise ImportError("Missing required libraries: pystac-client, odc-stac, xarray")
        print(f"Connecting to Digital Earth Africa STAC Catalog at {stac_url}...")
        self.catalog = Client.open(stac_url)

    def fetch_and_format_data(self, lat_range, lon_range, time_range, resolution=20):
        """Load imagery for a bounding box and return per-pixel index columns.

        Args:
            lat_range: Two-element sequence; elements 0 and 1 become the
                second and fourth bbox entries.
            lon_range: Two-element sequence; elements 0 and 1 become the
                first and third bbox entries.
            time_range: Two-element sequence joined as ``"start/end"`` for
                the STAC ``datetime`` filter.
            resolution: Passed straight through to ``odc.stac.load``.

        Returns:
            A DataFrame with one row per (y, x) pixel, columns ``lat`` and
            ``lon`` (renamed from ``y``/``x``) and one ``"{date}_{index}"``
            column per date for ndvi, ndre, evi and savi.

        Raises:
            ValueError: If the search returns no items.
        """
        # NOTE(review): data is loaded on the EPSG:6933 grid, so the 'lat'/'lon'
        # columns presumably carry projected y/x values rather than degrees —
        # confirm against downstream consumers.
        bbox = [lon_range[0], lat_range[0], lon_range[1], lat_range[1]]
        print(f"Searching STAC for Bounding Box: {bbox} over {time_range}...")

        datetime_filter = f"{time_range[0]}/{time_range[1]}"
        search = self.catalog.search(
            collections=["s2_l2a"],
            bbox=bbox,
            datetime=datetime_filter,
        )
        items = list(search.items())
        if not items:
            raise ValueError("No STAC items found for this bounding box and time range.")
        print(f"Found {len(items)} STAC items. Loading into xarray...")

        # Mapping for DE Africa S2 bands
        band_map = {
            'B04': 'red',
            'B03': 'green',
            'B02': 'blue',
            'B08': 'nir',
            'B05': 'red_edge_1',
            'SCL': 'scl',
        }
        requested_measurements = list(band_map.keys())
        ds = odc.stac.load(
            items,
            measurements=requested_measurements,
            bbox=bbox,
            crs="EPSG:6933",
            resolution=resolution,
            groupby="solar_day",
        )
        # Rename bands to expected names
        ds = ds.rename(band_map)

        print("Masking clouds and shadows...")
        # Keep only pixels whose scene-classification value is 4, 5, 6, 2 or 7;
        # everything else becomes NaN.
        valid_mask = ds.scl == 4
        for scl_code in (5, 6, 2, 7):
            valid_mask = valid_mask | (ds.scl == scl_code)
        ds = ds.where(valid_mask)
        ds = ds / 10000.0

        print("Computing Spectral Indices (NDVI, NDRE, SAVI, EVI)...")
        nir = ds.nir
        red = ds.red
        red_edge = ds.red_edge_1
        blue = ds.blue
        ds['ndvi'] = (nir - red) / (nir + red + 1e-8)
        ds['ndre'] = (nir - red_edge) / (nir + red_edge + 1e-8)
        ds['savi'] = ((nir - red) / (nir + red + 0.5)) * 1.5
        ds['evi'] = 2.5 * ((nir - red) / (nir + 6 * red - 7.5 * blue + 1))

        index_names = ['ndvi', 'ndre', 'evi', 'savi']
        ds_indices = ds[index_names]

        print("Reshaping multi-dimensional xarray into flat Tabular DataFrame...")
        long_df = ds_indices.compute().to_dataframe().reset_index()
        long_df['date_str'] = long_df['time'].dt.strftime('%Y%m%d')

        wide_df = long_df.pivot(index=['y', 'x'], columns='date_str', values=index_names)
        # Pivot yields (index_name, date) column tuples; flatten to "date_index".
        wide_df.columns = [f"{date}_{band}" for band, date in wide_df.columns]
        df_final = wide_df.reset_index().rename(columns={'y': 'lat', 'x': 'lon'})

        print(f"✅ Data Ready! {df_final.shape[0]} spatial pixels generated.")
        return df_final


# ==========================================
# 4.
# INFERENCE PIPELINE
# ==========================================
class CropInferencePipeline:
    """Load a trained FCN + calibrated classifier ensemble and run inference.

    The pipeline reads its configuration (label encoder, band and date
    lists, ensemble weights, class count) from ``pipeline_meta.pkl`` and the
    two model artefacts from the same directory.

    NOTE(review): class id 0 is used both for the first label-encoder class
    and for low-quality / no-data pixels. Confirm that class 0 was reserved
    for background when the models were trained.
    """

    # Number of neighbours averaged by the optional spatial smoothing.
    _SMOOTHING_NEIGHBORS = 9

    def __init__(self, model_dir="/tmp/geocrop-cache"):
        """Load metadata and both models from ``model_dir``.

        Raises:
            FileNotFoundError: If ``pipeline_meta.pkl`` is missing.
        """
        print(f"Loading Crop Inference Pipeline from {model_dir}...")
        meta_path = os.path.join(model_dir, "pipeline_meta.pkl")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"Pipeline metadata not found at {meta_path}")

        # NOTE(security): joblib.load and torch.load unpickle their input and
        # can execute arbitrary code. Only point model_dir at trusted files.
        self.meta = joblib.load(meta_path)
        self.le = self.meta["le"]
        self.bands = self.meta["bands"]
        self.dates = self.meta["dates"]
        self.w_fcn = self.meta["weights"]["w_fcn"]
        self.w_cb = self.meta["weights"]["w_cb"]

        self.fcn = TemporalFCN(len(self.bands), self.meta["num_classes"])
        fcn_path = os.path.join(model_dir, "Temporal_FCN.pth")
        self.fcn.load_state_dict(torch.load(fcn_path, map_location=torch.device('cpu')))
        self.fcn.eval()

        cb_path = os.path.join(model_dir, "calibrated_hybrid_cb.pkl")
        self.calibrated_cb = joblib.load(cb_path)
        print("Models loaded successfully.")

    def _band_columns(self, band, columns):
        """Return the ``"{date}_{band}"`` names present in ``columns``, in date order."""
        candidates = (f"{date}_{band}" for date in self.dates)
        return [name for name in candidates if name in columns]

    def _no_data_mask(self, df):
        """Flag rows where any band's series is entirely zero or entirely NaN."""
        mask = np.zeros(len(df), dtype=bool)
        for band in self.bands:
            band_cols = self._band_columns(band, df.columns)
            if not band_cols:
                continue
            band_data = df[band_cols].values
            all_zeros = np.all(band_data == 0, axis=1)
            all_nan = np.all(np.isnan(band_data), axis=1)
            mask |= all_zeros | all_nan
        return mask

    @staticmethod
    def _sidecar_path(output_path, suffix, ext=None):
        """Insert ``suffix`` before the extension of ``output_path``.

        ``ext`` replaces the original extension when given. Using
        ``os.path.splitext`` guarantees the result differs from
        ``output_path`` even when it does not contain ``.tif``, and leaves
        ``.tif`` occurrences in directory names untouched.
        """
        stem, original_ext = os.path.splitext(output_path)
        return f"{stem}{suffix}{original_ext if ext is None else ext}"

    def _impute_inference_data(self, df):
        """Fill missing timesteps and record where the input was missing.

        Args:
            df: Raw per-pixel DataFrame with ``"{date}_{band}"`` columns.

        Returns:
            ``(filled_df, missing_mask)`` where ``filled_df`` is a copy of
            ``df`` and ``missing_mask`` maps each band that has columns to a
            float DataFrame holding 1.0 where the raw value was NaN.
        """
        print("Imputing cloudy/missing timesteps via temporal interpolation...")
        from feature_computation import handle_temporal_gaps, spatial_fill_nan

        df = df.copy()
        cols_by_band = {band: self._band_columns(band, df.columns) for band in self.bands}

        # Track original NaNs before any imputation.
        missing_mask = {
            band: df[cols].isna().astype(float)
            for band, cols in cols_by_band.items()
            if cols
        }

        for band, band_cols in cols_by_band.items():
            if not band_cols:
                continue
            print(f" Processing band {band} with gap handling...")
            # Work on a private float64 array instead of assigning row by row
            # through df.loc: that was slow and label-based, so duplicate
            # index labels would have been mis-assigned.
            series = df[band_cols].to_numpy(dtype=np.float64, copy=True)
            for row in range(series.shape[0]):
                # Per the original note, gaps >= 3 come back as NaN for those
                # timesteps (helper contract is not visible in this file).
                series[row] = handle_temporal_gaps(series[row].copy(), gap_threshold=3)
            df[band_cols] = series

            # After gap handling, fill remaining NaNs with linear interpolation,
            # then edge-fill, and finally zero anything still missing.
            df[band_cols] = df[band_cols].interpolate(
                method='linear', axis=1, limit_direction='both'
            )
            df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0)

        # Apply spatial fill to each band using spatial_fill_nan.
        for band, band_cols in cols_by_band.items():
            if not band_cols:
                continue
            print(f" Applying spatial fill for band {band}...")
            # Explicit copy: the array is written in place below, and
            # DataFrame.values may be a read-only or shared buffer.
            band_data = df[band_cols].to_numpy(dtype=np.float64, copy=True).T
            for t_idx in range(band_data.shape[0]):
                band_data[t_idx] = spatial_fill_nan(
                    band_data[t_idx].reshape(-1, 1)
                ).squeeze()
            df[band_cols] = band_data.T

        return df, missing_mask

    def predict(self, raw_df, apply_spatial_smoothing=False, coord_cols=("lat", "lon")):
        """Classify every row of ``raw_df`` and attach quality flags.

        Args:
            raw_df: Per-pixel DataFrame with ``"{date}_{band}"`` columns.
            apply_spatial_smoothing: When true and all ``coord_cols`` exist,
                average each valid pixel's class probabilities with those of
                its nearest valid neighbours.
            coord_cols: Names of the coordinate columns used for smoothing.

        Returns:
            The imputed DataFrame with added columns ``class_id``,
            ``predicted_crop``, ``confidence``, ``high_missing`` and
            ``low_quality``.

        Raises:
            ValueError: If ``raw_df`` has no rows.
        """
        if len(raw_df) == 0:
            # Previously this surfaced as an opaque np.vstack ValueError.
            raise ValueError("raw_df contains no pixels to classify.")
        # A tuple is a single key to DataFrame indexing; normalise to a list.
        coord_cols = list(coord_cols)

        df, missing_mask = self._impute_inference_data(raw_df)
        X_infer = prepare_tensors(df, self.bands, self.dates)

        placeholder_labels = np.zeros(len(df), dtype=np.int64)
        infer_loader = DataLoader(
            CropDataset(X_infer, placeholder_labels, augment=False),
            batch_size=128,
            shuffle=False,
        )

        prob_batches = []
        feat_batches = []
        with torch.no_grad():
            for X_batch, _ in infer_loader:
                out, feats = self.fcn(X_batch, return_features=True)
                prob_batches.append(torch.softmax(out, dim=1).numpy())
                feat_batches.append(feats.numpy())
        fcn_probs = np.vstack(prob_batches)
        fcn_feats = np.vstack(feat_batches)

        # The second-stage classifier sees the flattened series plus FCN features.
        X_infer_flat = X_infer.reshape(X_infer.shape[0], -1)
        X_stack = np.hstack([X_infer_flat, fcn_feats])
        cb_probs = self.calibrated_cb.predict_proba(X_stack)

        final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
        final_preds = np.argmax(final_probs, axis=1)

        # Identify No Data pixels: those with all NaNs or zeros after imputation.
        no_data_mask = self._no_data_mask(df)
        self._force_no_data(final_preds, final_probs, no_data_mask)

        if apply_spatial_smoothing and all(col in df.columns for col in coord_cols):
            print(f"Applying spatial probability smoothing using {coord_cols}...")
            valid = ~no_data_mask
            if valid.any():
                coords = df[coord_cols].to_numpy()[valid]
                # Uniform weights are required here: with weights='distance',
                # predicting on the training coordinates returns each point's
                # own target (zero-distance neighbours get all the weight), so
                # the smoothing did nothing. No-data pixels are excluded so
                # their forced class-0 rows do not bleed into neighbours.
                n_neighbors = min(self._SMOOTHING_NEIGHBORS, len(coords))
                knn = KNeighborsRegressor(n_neighbors=n_neighbors, weights='uniform')
                knn.fit(coords, final_probs[valid])
                final_probs[valid] = knn.predict(coords)
            final_preds = np.argmax(final_probs, axis=1)
            # Re-apply the override, zeroing the whole row so it still sums to 1.
            self._force_no_data(final_preds, final_probs, no_data_mask)

        df['class_id'] = final_preds
        df['predicted_crop'] = self.le.inverse_transform(final_preds)
        df['confidence'] = np.max(final_probs, axis=1)

        # Track missing data ratio for quality flag (mean over bands of the
        # per-pixel fraction of originally missing timesteps).
        if missing_mask:
            per_band_ratio = [m.mean(axis=1).to_numpy() for m in missing_mask.values()]
            missing_ratio = np.mean(per_band_ratio, axis=0)
        else:
            missing_ratio = np.zeros(len(df))
        df['high_missing'] = missing_ratio > 0.4
        df['low_quality'] = (df['confidence'] < 0.5) | df['high_missing'] | no_data_mask

        # Set NoData (0) for low quality pixels.
        df.loc[df['low_quality'], 'class_id'] = 0
        df.loc[df['low_quality'], 'predicted_crop'] = 'Unknown/NoData'
        return df

    @staticmethod
    def _force_no_data(preds, probs, no_data_mask):
        """Set no-data rows to class 0 with probability 1.0, in place."""
        preds[no_data_mask] = 0
        probs[no_data_mask] = 0.0
        probs[no_data_mask, 0] = 1.0

    def export_to_geotiff(self, df, output_path="lulc_map.tif", crs="EPSG:6933"):
        """Write class, confidence and quality rasters plus a JSON legend.

        Sidecar files are named after ``output_path`` with ``_confidence``,
        ``_cloud_mask`` and ``_legend`` inserted before the extension (the
        legend uses ``.json``).

        Args:
            df: Output of ``predict``; needs ``lat``, ``lon``, ``class_id``,
                ``confidence`` and ``low_quality`` columns.
            output_path: Destination of the class raster.
            crs: CRS written onto the dataset before export.

        Raises:
            ImportError: If xarray or rioxarray could not be imported.
        """
        if xr is None or rioxarray is None:
            raise ImportError("Missing required libraries: xarray, rioxarray")
        print(f"Exporting LULC masks to {output_path}...")

        layers = ['class_id', 'confidence', 'low_quality']
        ds_out = df.set_index(['lat', 'lon'])[layers].to_xarray()
        ds_out = ds_out.rename({'lat': 'y', 'lon': 'x'})
        # Sort y descending, presumably so the raster is written north-up.
        ds_out = ds_out.sortby('y', ascending=False)
        ds_out = ds_out.rio.set_spatial_dims(x_dim='x', y_dim='y')
        ds_out.rio.write_crs(crs, inplace=True)

        ds_out['class_id'].astype('uint16').rio.to_raster(output_path)

        conf_path = self._sidecar_path(output_path, '_confidence')
        ds_out['confidence'].astype('float32').rio.to_raster(conf_path)

        mask_path = self._sidecar_path(output_path, '_cloud_mask')
        ds_out['low_quality'].astype('uint8').rio.to_raster(mask_path)

        legend_path = self._sidecar_path(output_path, '_legend', '.json')
        legend_dict = {int(i): str(c) for i, c in enumerate(self.le.classes_)}
        # Only reachable when the encoder has no classes: index 0 otherwise
        # always belongs to the first encoder class.
        if 0 not in legend_dict:
            legend_dict[0] = 'Unknown/NoData'
        with open(legend_path, 'w', encoding='utf-8') as f:
            json.dump(legend_dict, f, indent=4)

        print(f"✅ Successfully exported GeoTIFFs and class legend!")