import os import io import json import time import copy import joblib import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader import pandas as pd import numpy as np from sklearn.neighbors import KNeighborsRegressor from catboost import CatBoostClassifier from scipy import ndimage from scipy import stats # Digital Earth Africa STAC specific imports try: from pystac_client import Client import odc.stac import xarray as xr import rioxarray except ImportError: Client = None odc = None xr = None rioxarray = None # ========================================== # 1. CPU-OPTIMIZED ARCHITECTURES # ========================================== class TemporalFCN(nn.Module): def __init__(self, num_bands, num_classes): super().__init__() self.conv_block1 = nn.Sequential( nn.Conv1d(num_bands, 64, kernel_size=5, padding=2), nn.BatchNorm1d(64), nn.ReLU() ) self.conv_block2 = nn.Sequential( nn.Conv1d(64, 128, kernel_size=3, padding=1), nn.BatchNorm1d(128), nn.ReLU() ) self.global_avg_pool = nn.AdaptiveAvgPool1d(1) self.fc = nn.Linear(128, num_classes) def forward(self, x, return_features=False): x = self.conv_block1(x) x = self.conv_block2(x) features = self.global_avg_pool(x).squeeze(-1) out = self.fc(features) if return_features: return out, features return out class SmallGRU(nn.Module): def __init__(self, num_bands, num_classes, hidden_size=64): super().__init__() self.gru = nn.GRU(input_size=num_bands, hidden_size=hidden_size, num_layers=1, batch_first=True) self.fc = nn.Linear(hidden_size, num_classes) def forward(self, x, return_features=False): x = x.transpose(1, 2) out, _ = self.gru(x) features = out[:, -1, :] final_out = self.fc(features) if return_features: return final_out, features return final_out # ========================================== # 2. 
# DATA PREPARATION & PYTORCH UTILS
# ==========================================
class CropDataset(Dataset):
    """Torch dataset of per-pixel time series with optional augmentation."""

    def __init__(self, X, y, augment=False):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)
        self.augment = augment

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(sample, label)``; the sample is a clone, so the stored
        tensor is never modified by augmentation."""
        sample = self.X[idx].clone()
        label = self.y[idx]
        if not self.augment:
            return sample, label

        # Additive Gaussian noise (sigma 0.03), applied about half the time.
        if torch.rand(1).item() > 0.5:
            sample = sample + torch.randn_like(sample) * 0.03

        # Zero out one random position along axis 1, about 30% of the time.
        if torch.rand(1).item() > 0.7:
            dropped_step = torch.randint(0, sample.shape[1], (1,)).item()
            sample[:, dropped_step] = 0.0

        return sample, label


def prepare_tensors(df, bands, dates):
    """Build a normalised (samples, bands, dates) float32 array from `df`.

    Columns are looked up as ``"{date}_{band}"``; combinations that are absent
    from `df` stay at zero. Every (sample, band) series is then standardised
    along the date axis (a 1e-8 term keeps the division finite).
    """
    cube = np.zeros((len(df), len(bands), len(dates)), dtype=np.float32)

    for band_pos, band in enumerate(bands):
        for date_pos, date in enumerate(dates):
            column = f"{date}_{band}"
            if column not in df.columns:
                continue
            cube[:, band_pos, date_pos] = df[column].values

    temporal_mean = cube.mean(axis=2, keepdims=True)
    temporal_std = cube.std(axis=2, keepdims=True) + 1e-8
    return (cube - temporal_mean) / temporal_std


# ==========================================
# 3.
# DIGITAL EARTH AFRICA STAC INTEGRATION
# ==========================================
class DEAfricaSTACWrapper:
    """Fetch Sentinel-2 L2A data from the Digital Earth Africa STAC catalog
    and turn it into a flat per-pixel table of spectral-index time series."""

    def __init__(self, stac_url="https://explorer.digitalearth.africa/stac"):
        """Open the STAC catalog.

        Raises:
            ImportError: If pystac-client, odc-stac or xarray is unavailable.
        """
        if Client is None or odc is None or xr is None:
            raise ImportError("Missing required libraries: pystac-client, odc-stac, xarray")
        print(f"Connecting to Digital Earth Africa STAC Catalog at {stac_url}...")
        self.catalog = Client.open(stac_url)

    @staticmethod
    def _patch_s3_url(url: str) -> str:
        """Rewrite ``s3://deafrica-sentinel-2/...`` to its public HTTPS form.

        Only URLs in exactly that bucket are rewritten. Buckets whose name
        merely starts with the same characters (for example
        ``deafrica-sentinel-2-dev``) are returned unchanged, because a plain
        prefix replacement would splice the remainder of the bucket name
        onto the HTTPS host and produce an invalid URL.
        """
        bucket = "s3://deafrica-sentinel-2"
        https_root = "/vsicurl/https://deafrica-sentinel-2.s3.af-south-1.amazonaws.com"
        if url == bucket or url.startswith(bucket + "/"):
            # Swap only the leading scheme+bucket, never later occurrences.
            return https_root + url[len(bucket):]
        return url

    def fetch_and_format_data(self, lat_range, lon_range, time_range, resolution=20):
        """Search, load, cloud-mask and tabulate Sentinel-2 data.

        Args:
            lat_range: Two latitudes bounding the area (either order).
            lon_range: Two longitudes bounding the area (either order).
            time_range: ``(start, end)`` values joined as ``"start/end"`` for
                the STAC ``datetime`` filter.
            resolution: Output pixel size handed to ``odc.stac.load``.

        Returns:
            DataFrame with ``lat``/``lon`` columns (the y/x coordinates of
            the loaded grid, which is in EPSG:6933) plus one
            ``"{YYYYMMDD}_{index}"`` column per date and index
            (ndvi, ndre, evi, savi).

        Raises:
            ValueError: If the search returns no STAC items.
        """
        # STAC expects [min_lon, min_lat, max_lon, max_lat]; ordering the
        # endpoints makes the call insensitive to how the caller wrote them.
        lat_min, lat_max = sorted((lat_range[0], lat_range[1]))
        lon_min, lon_max = sorted((lon_range[0], lon_range[1]))
        bbox = [lon_min, lat_min, lon_max, lat_max]
        print(f"Searching STAC for Bounding Box: {bbox} over {time_range}...")

        search = self.catalog.search(
            collections=["s2_l2a"],
            bbox=bbox,
            datetime=f"{time_range[0]}/{time_range[1]}"
        )
        items = list(search.items())
        if not items:
            raise ValueError("No STAC items found for this bounding box and time range.")
        print(f"Found {len(items)} STAC items. Loading into xarray...")

        band_map = {
            'B04': 'red',
            'B03': 'green',
            'B02': 'blue',
            'B08': 'nir',
            'B05': 'red_edge_1',
            'SCL': 'scl'
        }

        # NOTE: process-wide GDAL setting, applied as a side effect of this call.
        os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR"

        ds = odc.stac.load(
            items,
            measurements=list(band_map.keys()),
            bbox=bbox,
            crs="EPSG:6933",
            resolution=resolution,
            groupby="solar_day",
            patch_url=self._patch_s3_url
        )

        # Rename bands to expected names
        ds = ds.rename(band_map)

        print("Masking clouds and shadows...")
        # Keep only pixels whose scene-classification value is 2, 4, 5, 6 or 7.
        valid_mask = (ds.scl == 4) | (ds.scl == 5) | (ds.scl == 6) | (ds.scl == 2) | (ds.scl == 7)
        ds = ds.where(valid_mask)
        ds = ds / 10000.0

        print("Computing Spectral Indices (NDVI, NDRE, SAVI, EVI)...")
        ds['ndvi'] = (ds.nir - ds.red) / (ds.nir + ds.red + 1e-8)
        ds['ndre'] = (ds.nir - ds.red_edge_1) / (ds.nir + ds.red_edge_1 + 1e-8)
        ds['savi'] = ((ds.nir - ds.red) / (ds.nir + ds.red + 0.5)) * 1.5
        ds['evi'] = 2.5 * ((ds.nir - ds.red) / (ds.nir + 6 * ds.red - 7.5 * ds.blue + 1))
        ds_indices = ds[['ndvi', 'ndre', 'evi', 'savi']]

        print("Reshaping multi-dimensional xarray into flat Tabular DataFrame...")
        df = ds_indices.compute().to_dataframe().reset_index()
        df['date_str'] = df['time'].dt.strftime('%Y%m%d')
        df_pivot = df.pivot(index=['y', 'x'], columns='date_str', values=['ndvi', 'ndre', 'evi', 'savi'])
        # Pivot columns are (index_name, date) pairs; flatten to "date_index".
        df_pivot.columns = [f"{date}_{band}" for band, date in df_pivot.columns]
        df_final = df_pivot.reset_index().rename(columns={'y': 'lat', 'x': 'lon'})

        print(f"✅ Data Ready! {df_final.shape[0]} spatial pixels generated.")
        return df_final


# ==========================================
# 4.
# INFERENCE PIPELINE
# ==========================================
class CropInferencePipeline:
    """Load the trained FCN + CatBoost ensemble and classify pixel tables.

    The input table is expected to carry one ``"{date}_{band}"`` column per
    date/band combination listed in the saved pipeline metadata.
    """

    def __init__(self, model_dir="/tmp/geocrop-cache"):
        """Load metadata, FCN weights and the CatBoost model from `model_dir`.

        Raises:
            FileNotFoundError: If ``pipeline_meta.pkl`` is not in `model_dir`.
        """
        print(f"Loading Crop Inference Pipeline from {model_dir}...")
        meta_path = os.path.join(model_dir, "pipeline_meta.pkl")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"Pipeline metadata not found at {meta_path}")

        # SECURITY: joblib.load unpickles arbitrary objects. Only point
        # `model_dir` at artefacts produced by a trusted training run.
        self.meta = joblib.load(meta_path)
        self.le = self.meta["le"]
        self.bands = self.meta["bands"]
        self.dates = self.meta["dates"]
        self.w_fcn = self.meta["weights"]["w_fcn"]
        self.w_cb = self.meta["weights"]["w_cb"]

        self.fcn = TemporalFCN(len(self.bands), self.meta["num_classes"])
        fcn_path = os.path.join(model_dir, "Temporal_FCN.pth")
        self.fcn.load_state_dict(torch.load(fcn_path, map_location=torch.device('cpu')))
        self.fcn.eval()

        cb_path = os.path.join(model_dir, "calibrated_hybrid_cb.pkl")
        # SECURITY: same pickle caveat as for the metadata file above.
        self.calibrated_cb = joblib.load(cb_path)
        print("Models loaded successfully.")

    def _impute_inference_data(self, df):
        """Flag large temporal gaps and interpolate the remaining ones.

        For every band, a pixel is flagged when it has three or more
        consecutive timesteps that are NaN or exactly 0 (counted over the
        date columns actually present in `df`). NaNs are then linearly
        interpolated along time, edge NaNs are forward/backward filled, and
        series with no valid value at all become 0.

        Args:
            df: Pixel table; it is copied, never modified in place.

        Returns:
            ``(imputed_df, large_gap_mask)`` where the mask is a boolean
            array with one entry per row of `df`.
        """
        print("Imputing cloudy/missing timesteps via temporal interpolation...")
        df = df.copy()
        n_pixels = len(df)
        large_gap_mask = np.zeros(n_pixels, dtype=bool)
        present = set(df.columns)

        for band in self.bands:
            band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in present]
            if not band_cols:
                continue

            # 1. Gap detection: zeros count as missing, like NaNs.
            band_data = df[band_cols].to_numpy(dtype=np.float64)
            gap = np.isnan(band_data) | (band_data == 0)
            run_length = np.zeros(n_pixels, dtype=np.int64)
            longest_run = np.zeros(n_pixels, dtype=np.int64)
            # Iterate over the columns that exist; some dates may be absent
            # from `df`, so len(self.dates) can exceed the array width.
            for t_idx in range(gap.shape[1]):
                # Extend the run where the step is a gap, reset it otherwise.
                run_length = (run_length + 1) * gap[:, t_idx]
                longest_run = np.maximum(longest_run, run_length)
            large_gap_mask |= longest_run >= 3

            # 2. Temporal interpolation (NaNs only; zeros are left as-is).
            filled = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
            df[band_cols] = filled.ffill(axis=1).bfill(axis=1).fillna(0)

        return df, large_gap_mask

    @staticmethod
    def _window_mode(window):
        """Most frequent non-zero label in `window`; 0 if all entries are 0.

        Ties resolve to the smallest label.
        """
        labelled = window[window > 0]
        if labelled.size == 0:
            return 0
        values, counts = np.unique(labelled, return_counts=True)
        return values[np.argmax(counts)]

    @staticmethod
    def _spatial_majority_refine(class_ids, low_quality, lat_values, lon_values):
        """Replace low-quality labels with the 3x3 neighbourhood majority.

        Pixels are placed on a grid by their coordinate values (rows ordered
        north to south), a 3x3 mode filter that ignores label 0 is computed,
        and only pixels flagged in `low_quality` take the filtered label.
        All arguments are 1-D and aligned by position.

        Returns:
            Array of refined labels in the same order and dtype as `class_ids`.
        """
        class_ids = np.asarray(class_ids)
        low_quality = np.asarray(low_quality, dtype=bool)
        unique_lats, lat_rank = np.unique(np.asarray(lat_values).ravel(), return_inverse=True)
        unique_lons, col_idx = np.unique(np.asarray(lon_values).ravel(), return_inverse=True)
        # np.unique sorts ascending; flip so row 0 is the northernmost row.
        row_idx = (len(unique_lats) - 1) - lat_rank

        grid_class = np.zeros((len(unique_lats), len(unique_lons)), dtype=np.uint16)
        grid_low_q = np.zeros(grid_class.shape, dtype=bool)
        grid_class[row_idx, col_idx] = class_ids
        grid_low_q[row_idx, col_idx] = low_quality

        refined_grid = ndimage.generic_filter(
            grid_class, CropInferencePipeline._window_mode, size=3
        )
        # Only overwrite pixels that were low quality or a gap.
        grid_class = np.where(grid_low_q, refined_grid, grid_class)
        return grid_class[row_idx, col_idx].astype(class_ids.dtype)

    def predict(self, raw_df, apply_spatial_smoothing=True, coord_cols=('lat', 'lon')):
        """Classify every row of `raw_df` with the weighted FCN/CatBoost ensemble.

        Args:
            raw_df: Pixel table with ``"{date}_{band}"`` feature columns.
            apply_spatial_smoothing: When True and the coordinate columns are
                present, low-confidence and large-gap pixels take the
                majority label of their 3x3 neighbourhood.
            coord_cols: Names of the (row, column) coordinate columns, in
                that order.

        Returns:
            Copy of the imputed table with added columns ``class_id``,
            ``predicted_crop``, ``confidence`` (max ensemble probability,
            computed before smoothing) and ``low_quality``.

        Raises:
            ValueError: If `raw_df` has no rows.
        """
        if len(raw_df) == 0:
            raise ValueError("raw_df contains no pixels to classify.")

        # 1. Impute Data
        df, large_gap_mask = self._impute_inference_data(raw_df)
        X_infer = prepare_tensors(df, self.bands, self.dates)
        placeholder_labels = np.zeros(len(df), dtype=np.int64)
        infer_loader = DataLoader(
            CropDataset(X_infer, placeholder_labels, augment=False),
            batch_size=128,
            shuffle=False,
        )

        # 2. PyTorch FCN Probs & Features
        prob_chunks = []
        feat_chunks = []
        with torch.no_grad():
            for X_batch, _ in infer_loader:
                out, feats = self.fcn(X_batch, return_features=True)
                prob_chunks.append(torch.softmax(out, dim=1).numpy())
                feat_chunks.append(feats.numpy())
        fcn_probs = np.vstack(prob_chunks)
        fcn_feats = np.vstack(feat_chunks)

        # 3. Stack Features and get CatBoost Probs
        X_infer_flat = X_infer.reshape(X_infer.shape[0], -1)
        X_stack = np.hstack([X_infer_flat, fcn_feats])
        cb_probs = self.calibrated_cb.predict_proba(X_stack)

        # 4. Soft Weighted Ensemble
        final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
        final_preds = np.argmax(final_probs, axis=1)

        # 5. Apply Initial Masking
        confidence = np.max(final_probs, axis=1)
        # Class 0 is Background/NoData
        # NOTE(review): id 0 is also le.classes_[0]; confirm the label
        # encoder reserves index 0 for a background class.
        final_preds[large_gap_mask] = 0
        # Track low quality for refinement
        low_quality_mask = (confidence < 0.5) | large_gap_mask

        # 6. 2D Spatial Majority Filtering (Mode)
        has_coords = len(coord_cols) >= 2 and all(col in df.columns for col in coord_cols)
        if apply_spatial_smoothing and has_coords:
            print("Applying 2D spatial majority filtering and neighborhood gap-fill...")
            # Work on positional arrays so a non-default DataFrame index
            # cannot misalign predictions and coordinates.
            final_preds = self._spatial_majority_refine(
                final_preds,
                low_quality_mask,
                df[coord_cols[0]].to_numpy(),
                df[coord_cols[1]].to_numpy(),
            )

        # 7. Final labels
        df['class_id'] = final_preds
        df['predicted_crop'] = self.le.inverse_transform(final_preds)
        df['confidence'] = confidence
        # export_to_geotiff() writes this column as the quality mask.
        df['low_quality'] = low_quality_mask
        # Ensure NoData label is assigned for any remaining 0s
        df.loc[df['class_id'] == 0, 'predicted_crop'] = 'Unknown/NoData'
        return df

    def export_to_geotiff(self, df, output_path="lulc_map.tif", crs="EPSG:6933"):
        """Write class, confidence and quality rasters plus a JSON legend.

        Sidecar files are named after `output_path` with ``_confidence``,
        ``_cloud_mask`` and ``_legend`` inserted before the extension. Grid
        cells with no row in `df` are written as class 0 and flagged in the
        quality mask; their confidence stays NaN.

        Args:
            df: Output of ``predict`` with ``lat``/``lon`` columns.
            output_path: Destination of the class raster.
            crs: Coordinate reference system written to the rasters.

        Raises:
            ImportError: If xarray or rioxarray is unavailable.
            KeyError: If `df` lacks a column required for the export.
        """
        if xr is None or rioxarray is None:
            raise ImportError("Missing required libraries: xarray, rioxarray")

        layers = ['class_id', 'confidence', 'low_quality']
        required = ['lat', 'lon'] + layers
        missing = [col for col in required if col not in df.columns]
        if missing:
            raise KeyError(
                f"DataFrame is missing columns required for export: {missing}. "
                "Run predict() first."
            )

        print(f"Exporting LULC masks to {output_path}...")
        export_df = df[required].copy()
        # A numeric flag survives the NaN promotion that to_xarray applies to
        # grid cells without data; a boolean column would become object dtype.
        export_df['low_quality'] = export_df['low_quality'].astype('uint8')

        ds_out = export_df.set_index(['lat', 'lon']).to_xarray()
        # NaN cannot be cast to an unsigned integer, so empty cells get
        # explicit values before the rasters are written.
        ds_out['class_id'] = ds_out['class_id'].fillna(0)
        ds_out['low_quality'] = ds_out['low_quality'].fillna(1)

        ds_out = ds_out.rename({'lat': 'y', 'lon': 'x'})
        ds_out = ds_out.sortby('y', ascending=False)
        ds_out = ds_out.rio.set_spatial_dims(x_dim='x', y_dim='y')
        ds_out.rio.write_crs(crs, inplace=True)

        # Split the extension once so sidecar names can never collide with
        # the main raster, whatever extension the caller chose.
        base, ext = os.path.splitext(output_path)
        ext = ext or '.tif'
        conf_path = f"{base}_confidence{ext}"
        mask_path = f"{base}_cloud_mask{ext}"
        legend_path = f"{base}_legend.json"

        ds_out['class_id'].astype('uint16').rio.to_raster(output_path)
        ds_out['confidence'].astype('float32').rio.to_raster(conf_path)
        ds_out['low_quality'].astype('uint8').rio.to_raster(mask_path)

        legend_dict = {int(i): str(c) for i, c in enumerate(self.le.classes_)}
        if 0 not in legend_dict:
            legend_dict[0] = 'Unknown/NoData'
        with open(legend_path, 'w', encoding='utf-8') as f:
            json.dump(legend_dict, f, indent=4)

        print("✅ Successfully exported GeoTIFFs and class legend!")