391 lines
15 KiB
Python
391 lines
15 KiB
Python
import os
|
|
import io
|
|
import json
|
|
import time
|
|
import copy
|
|
import joblib
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import pandas as pd
|
|
import numpy as np
|
|
from sklearn.neighbors import KNeighborsRegressor
|
|
from catboost import CatBoostClassifier
|
|
from scipy import ndimage
|
|
from scipy import stats
|
|
|
|
# Digital Earth Africa STAC specific imports
|
|
try:
|
|
from pystac_client import Client
|
|
import odc.stac
|
|
import xarray as xr
|
|
import rioxarray
|
|
except ImportError:
|
|
Client = None
|
|
odc = None
|
|
xr = None
|
|
rioxarray = None
|
|
|
|
# ==========================================
|
|
# 1. CPU-OPTIMIZED ARCHITECTURES
|
|
# ==========================================
|
|
|
|
class TemporalFCN(nn.Module):
    """1-D fully convolutional classifier over the time axis of a pixel series.

    Input is ``(batch, num_bands, time)`` (bands are the Conv1d channels).
    Two Conv-BN-ReLU stages widen the channels to 64 and then 128, adaptive
    average pooling collapses the time axis, and a linear head emits
    ``num_classes`` logits. The attribute names ``conv_block1``,
    ``conv_block2`` and ``fc`` are the saved ``state_dict`` keys and must not
    be renamed.
    """

    def __init__(self, num_bands, num_classes):
        super().__init__()
        self.conv_block1 = self._conv_bn_relu(num_bands, 64, kernel_size=5)
        self.conv_block2 = self._conv_bn_relu(64, 128, kernel_size=3)
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, num_classes)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size):
        """Return Conv1d (length-preserving for odd kernels) -> BatchNorm1d -> ReLU."""
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x, return_features=False):
        """Return logits, or ``(logits, pooled_features)`` when ``return_features``."""
        hidden = self.conv_block2(self.conv_block1(x))
        pooled = self.global_avg_pool(hidden).squeeze(-1)  # (batch, 128)
        logits = self.fc(pooled)
        if not return_features:
            return logits
        return logits, pooled
|
|
|
|
class SmallGRU(nn.Module):
    """Single-layer GRU classifier over a pixel time series.

    Takes the same ``(batch, num_bands, time)`` layout as ``TemporalFCN`` and
    reorders it to ``(batch, time, num_bands)`` for the batch-first GRU. The
    hidden output at the final timestep feeds a linear head.
    """

    def __init__(self, num_bands, num_classes, hidden_size=64):
        super().__init__()
        self.gru = nn.GRU(
            input_size=num_bands,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, return_features=False):
        """Return logits, or ``(logits, last_hidden)`` when ``return_features``."""
        hidden_seq, _ = self.gru(x.permute(0, 2, 1))
        last_hidden = hidden_seq[:, -1]  # (batch, hidden_size)
        logits = self.fc(last_hidden)
        return (logits, last_hidden) if return_features else logits
|
|
|
|
# ==========================================
|
|
# 2. DATA PREPARATION & PYTORCH UTILS
|
|
# ==========================================
|
|
|
|
class CropDataset(Dataset):
    """In-memory ``(features, label)`` dataset with optional time-series augmentation.

    ``X`` is stored as a float tensor and ``y`` as an int64 tensor. When
    ``augment`` is true, each fetched sample may get Gaussian noise (std 0.03)
    added, and may have one randomly chosen timestep (dimension 1) zeroed.
    The stored data is never mutated.
    """

    def __init__(self, X, y, augment=False):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)
        self.augment = augment

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        sample = self.X[idx].clone()  # clone so augmentation never touches self.X
        label = self.y[idx]
        if not self.augment:
            return sample, label

        # Jitter: applied when the uniform draw exceeds 0.5.
        if torch.rand(1).item() > 0.5:
            sample = sample + torch.randn_like(sample) * 0.03

        # Timestep dropout: applied when the uniform draw exceeds 0.7.
        if torch.rand(1).item() > 0.7:
            n_steps = sample.shape[1]
            dropped = torch.randint(0, n_steps, (1,)).item()
            sample[:, dropped] = 0.0

        return sample, label
|
|
|
|
def prepare_tensors(df, bands, dates):
    """Build a per-pixel standardised ``(pixels, bands, dates)`` float32 cube.

    Cell ``[:, i, j]`` is read from column ``f"{dates[j]}_{bands[i]}"``.
    Columns absent from ``df`` are left as 0. Each pixel/band series is then
    z-scored along the date axis, with ``1e-8`` added to the std to avoid
    division by zero. NaNs in ``df`` are not handled and propagate.
    """
    cube = np.zeros((len(df), len(bands), len(dates)), dtype=np.float32)
    for band_pos, band_name in enumerate(bands):
        for date_pos, date_key in enumerate(dates):
            column = f"{date_key}_{band_name}"
            if column not in df.columns:
                continue  # missing band/date: keep zeros
            cube[:, band_pos, date_pos] = df[column].values

    centre = cube.mean(axis=2, keepdims=True)
    spread = cube.std(axis=2, keepdims=True) + 1e-8
    return (cube - centre) / spread
|
|
|
|
# ==========================================
|
|
# 3. DIGITAL EARTH AFRICA STAC INTEGRATION
|
|
# ==========================================
|
|
|
|
class DEAfricaSTACWrapper:
    """Client for Digital Earth Africa's STAC catalog (Sentinel-2 ``s2_l2a``).

    Searches the collection and loads the matching scenes with ``odc.stac``.
    Pixels whose SCL class is not 2, 4, 5, 6 or 7 are masked, reflectances
    are scaled by 1/10000, and NDVI, NDRE, SAVI and EVI are flattened into
    one wide row per pixel with columns ``f"{YYYYMMDD}_{index}"``.
    """

    def __init__(self, stac_url="https://explorer.digitalearth.africa/stac"):
        """Open the STAC catalog at ``stac_url``.

        Raises:
            ImportError: If pystac-client, odc-stac or xarray failed to import.
        """
        if any(lib is None for lib in (Client, odc, xr)):
            raise ImportError("Missing required libraries: pystac-client, odc-stac, xarray")
        print(f"Connecting to Digital Earth Africa STAC Catalog at {stac_url}...")
        self.catalog = Client.open(stac_url)

    @staticmethod
    def _patch_s3_url(url: str) -> str:
        """Map ``s3://deafrica-sentinel-2`` hrefs to GDAL ``/vsicurl/`` HTTPS URLs.

        Any other URL is returned unchanged.
        """
        bucket_uri = "s3://deafrica-sentinel-2"
        if not url.startswith(bucket_uri):
            return url
        https_root = "/vsicurl/https://deafrica-sentinel-2.s3.af-south-1.amazonaws.com"
        return url.replace(bucket_uri, https_root)

    def fetch_and_format_data(self, lat_range, lon_range, time_range, resolution=20):
        """Search, load, cloud-mask and tabulate Sentinel-2 indices for a bounding box.

        Args:
            lat_range: ``(min, max)`` latitude bounds of the search box.
            lon_range: ``(min, max)`` longitude bounds of the search box. The
                STAC bbox is assembled as ``[lon_min, lat_min, lon_max, lat_max]``.
            time_range: ``(start, end)``, joined as ``"start/end"`` for the search.
            resolution: Output pixel size in EPSG:6933 units (metres).

        Returns:
            DataFrame with ``lat``/``lon`` columns plus one
            ``f"{YYYYMMDD}_{index}"`` column per date and index (ndvi, ndre,
            evi, savi). NOTE: ``lat``/``lon`` hold the projected EPSG:6933
            ``y``/``x`` coordinates, not degrees.

        Raises:
            ValueError: If the search finds no items.

        Side effect: sets ``GDAL_DISABLE_READDIR_ON_OPEN=EMPTY_DIR`` in ``os.environ``.
        """
        bbox = [lon_range[0], lat_range[0], lon_range[1], lat_range[1]]
        print(f"Searching STAC for Bounding Box: {bbox} over {time_range}...")
        start, end = time_range[0], time_range[1]
        found = self.catalog.search(
            collections=["s2_l2a"],
            bbox=bbox,
            datetime=f"{start}/{end}",
        )
        items = list(found.items())
        if not items:
            raise ValueError("No STAC items found for this bounding box and time range.")

        print(f"Found {len(items)} STAC items. Loading into xarray...")

        # STAC asset name -> variable name used for the rest of this method.
        asset_to_var = {
            'B04': 'red',
            'B03': 'green',
            'B02': 'blue',
            'B08': 'nir',
            'B05': 'red_edge_1',
            'SCL': 'scl',
        }

        # Avoid GDAL listing remote "directories" when opening each COG.
        os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR"

        ds = odc.stac.load(
            items,
            measurements=list(asset_to_var),
            bbox=bbox,
            crs="EPSG:6933",
            resolution=resolution,
            groupby="solar_day",
            patch_url=self._patch_s3_url,
        ).rename(asset_to_var)

        print("Masking clouds and shadows...")
        # Keep only SCL classes 2, 4, 5, 6 and 7; every other pixel becomes NaN.
        ds = ds.where(ds.scl.isin([2, 4, 5, 6, 7])) / 10000.0

        print("Computing Spectral Indices (NDVI, NDRE, SAVI, EVI)...")
        nir, red = ds.nir, ds.red
        ds['ndvi'] = (nir - red) / (nir + red + 1e-8)
        ds['ndre'] = (nir - ds.red_edge_1) / (nir + ds.red_edge_1 + 1e-8)
        ds['savi'] = ((nir - red) / (nir + red + 0.5)) * 1.5
        ds['evi'] = 2.5 * ((nir - red) / (nir + 6 * red - 7.5 * ds.blue + 1))

        index_vars = ['ndvi', 'ndre', 'evi', 'savi']
        print("Reshaping multi-dimensional xarray into flat Tabular DataFrame...")
        long_table = ds[index_vars].compute().to_dataframe().reset_index()
        long_table['date_str'] = long_table['time'].dt.strftime('%Y%m%d')

        # One row per (y, x) pixel; MultiIndex columns (index, date) -> "date_index".
        wide = long_table.pivot(index=['y', 'x'], columns='date_str', values=index_vars)
        wide.columns = [f"{date}_{band}" for band, date in wide.columns]

        df_final = wide.reset_index().rename(columns={'y': 'lat', 'x': 'lon'})
        print(f"✅ Data Ready! {df_final.shape[0]} spatial pixels generated.")
        return df_final
|
|
|
|
# ==========================================
|
|
# 4. INFERENCE PIPELINE
|
|
# ==========================================
|
|
|
|
class CropInferencePipeline:
    """Hybrid TemporalFCN + calibrated-classifier crop mapper for wide pixel tables.

    Loads from ``model_dir`` the pipeline metadata (label encoder, band and
    date lists, ensemble weights), a ``TemporalFCN`` state dict and a
    calibrated classifier exposing ``predict_proba``. The classifier is
    presumably CatBoost, per its file name. Rows of the input DataFrame are
    pixels, and columns are named ``f"{date}_{band}"``.
    """

    # Label given to class_id 0, which predict() treats as Background/NoData.
    _NODATA_LABEL = 'Unknown/NoData'

    def __init__(self, model_dir="/tmp/geocrop-cache"):
        """Load metadata and both models from ``model_dir``.

        Args:
            model_dir: Directory holding ``pipeline_meta.pkl``,
                ``Temporal_FCN.pth`` and ``calibrated_hybrid_cb.pkl``.

        Raises:
            FileNotFoundError: If ``pipeline_meta.pkl`` is missing.

        SECURITY NOTE(review): ``joblib.load`` unpickles arbitrary objects, and
        ``torch.load`` may too depending on the PyTorch version. The default
        directory is under world-writable ``/tmp``, so point this only at a
        trusted location.
        """
        print(f"Loading Crop Inference Pipeline from {model_dir}...")
        meta_path = os.path.join(model_dir, "pipeline_meta.pkl")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"Pipeline metadata not found at {meta_path}")

        self.meta = joblib.load(meta_path)
        self.le = self.meta["le"]
        self.bands = self.meta["bands"]
        self.dates = self.meta["dates"]
        self.w_fcn = self.meta["weights"]["w_fcn"]
        self.w_cb = self.meta["weights"]["w_cb"]

        self.fcn = TemporalFCN(len(self.bands), self.meta["num_classes"])
        fcn_path = os.path.join(model_dir, "Temporal_FCN.pth")
        self.fcn.load_state_dict(torch.load(fcn_path, map_location=torch.device('cpu')))
        self.fcn.eval()

        cb_path = os.path.join(model_dir, "calibrated_hybrid_cb.pkl")
        self.calibrated_cb = joblib.load(cb_path)
        print("Models loaded successfully.")

    def _band_columns(self, df, band):
        """Return the ``f"{date}_{band}"`` columns present in ``df``, in ``self.dates`` order."""
        return [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]

    def _impute_inference_data(self, df):
        """Fill cloudy/missing timesteps and flag pixels with long temporal gaps.

        A pixel is flagged when, for any band, 3 or more consecutive present
        date columns are missing (NaN or exactly 0). Each band is then
        linearly interpolated along its columns, using positional rather than
        calendar spacing, and edge gaps take the nearest value. Rows with no
        valid value at all become 0. Zeros count as gaps only for the flag,
        not for interpolation.

        Args:
            df: Wide per-pixel DataFrame. Not modified.

        Returns:
            ``(filled_df, large_gap_mask)``. ``large_gap_mask`` is a boolean
            array aligned positionally with the rows of ``df``.
        """
        print("Imputing cloudy/missing timesteps via temporal interpolation...")
        df = df.copy()
        n_pixels = len(df)

        # 1. Flag pixels with >= 3 consecutive missing timesteps in any band.
        large_gap_mask = np.zeros(n_pixels, dtype=bool)
        for band in self.bands:
            band_cols = self._band_columns(df, band)
            if not band_cols:
                continue
            band_data = df[band_cols].to_numpy(dtype=np.float64)
            missing = np.isnan(band_data) | (band_data == 0)
            run = np.zeros(n_pixels)
            longest = np.zeros(n_pixels)
            # Walk the columns actually present. Using len(self.dates) here
            # raised IndexError whenever a date column was absent from df.
            for t in range(missing.shape[1]):
                run = (run + 1) * missing[:, t]
                longest = np.maximum(longest, run)
            large_gap_mask |= longest >= 3

        # 2. Temporal interpolation. No spatial gap-fill is performed here;
        # spatial refinement happens in predict().
        for band in self.bands:
            band_cols = self._band_columns(df, band)
            if not band_cols:
                continue
            filled = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
            df[band_cols] = filled.ffill(axis=1).bfill(axis=1).fillna(0)

        return df, large_gap_mask

    def _ensemble_probs(self, X_infer, batch_size=128):
        """Return the weighted FCN + calibrated-classifier probabilities for ``X_infer``.

        The classifier sees the flattened normalised cube stacked with the
        FCN's pooled features.
        """
        loader = DataLoader(
            CropDataset(X_infer, np.zeros(len(X_infer)), augment=False),
            batch_size=batch_size,
            shuffle=False,
        )
        prob_batches, feat_batches = [], []
        with torch.no_grad():
            for X_batch, _ in loader:
                logits, feats = self.fcn(X_batch, return_features=True)
                prob_batches.append(torch.softmax(logits, dim=1).numpy())
                feat_batches.append(feats.numpy())
        fcn_probs = np.vstack(prob_batches)
        fcn_feats = np.vstack(feat_batches)

        X_stack = np.hstack([X_infer.reshape(X_infer.shape[0], -1), fcn_feats])
        cb_probs = self.calibrated_cb.predict_proba(X_stack)
        return (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)

    @staticmethod
    def _spatial_majority_fill(lats, lons, preds, low_quality_mask):
        """Replace low-quality predictions with the 3x3 majority of non-zero labels.

        Pixels are placed on a grid built from the unique coordinate values,
        with rows ordered by descending ``lats`` (north to south). Grid cells
        with no pixel hold 0, and 0 (NoData) is ignored in the vote unless the
        whole window is 0. Edges use ``scipy.ndimage`` 'reflect' mode.

        Returns:
            New label array aligned positionally with ``preds``.
        """
        lats = np.asarray(lats)
        lons = np.asarray(lons)
        preds = np.asarray(preds)

        lat_axis = np.unique(lats)[::-1]  # north -> south
        lon_axis = np.unique(lons)
        # Vectorised coordinate -> grid-index lookup. The coordinate axes are
        # unique by construction.
        rows = pd.Index(lat_axis).get_indexer(lats)
        cols = pd.Index(lon_axis).get_indexer(lons)

        grid_class = np.zeros((len(lat_axis), len(lon_axis)), dtype=np.uint16)
        grid_low_q = np.zeros(grid_class.shape, dtype=bool)
        grid_class[rows, cols] = preds
        grid_low_q[rows, cols] = low_quality_mask

        def majority_nonzero(window):
            valid = window[window > 0]
            if valid.size == 0:
                return 0
            return stats.mode(valid, keepdims=True).mode[0]

        refined = ndimage.generic_filter(grid_class, majority_nonzero, size=3)
        # Only low-quality or gap pixels take the neighbourhood label.
        grid_class = np.where(grid_low_q, refined, grid_class)
        return grid_class[rows, cols].astype(preds.dtype)

    def predict(self, raw_df, apply_spatial_smoothing=True, coord_cols=('lat', 'lon'),
                confidence_threshold=0.5):
        """Classify every row (pixel) of ``raw_df``.

        Steps:
            1. Temporal imputation.
            2. Per-pixel normalisation (``prepare_tensors``).
            3. Weighted FCN + classifier soft ensemble.
            4. Forcing long-gap pixels to class 0 (Background/NoData).
            5. Optionally, a 3x3 majority refill of low-quality pixels on the
               coordinate grid.

        Args:
            raw_df: Wide DataFrame with ``f"{date}_{band}"`` columns. Not
                modified. Any index type is supported.
            apply_spatial_smoothing: Run the majority refill when the
                coordinate columns are present.
            coord_cols: ``(lat_col, lon_col)`` used to rebuild the 2-D grid.
            confidence_threshold: Pixels whose top ensemble probability is
                below this value are low quality and eligible for the refill.

        Returns:
            The imputed copy of ``raw_df`` with ``class_id``,
            ``predicted_crop``, ``confidence`` and ``low_quality`` columns
            added. ``export_to_geotiff`` consumes all four.
        """
        df, large_gap_mask = self._impute_inference_data(raw_df)

        if len(df) == 0:
            # Nothing to classify; the batching/stacking below cannot handle 0 rows.
            df['class_id'] = np.zeros(0, dtype=np.int64)
            df['predicted_crop'] = np.array([], dtype=object)
            df['confidence'] = np.zeros(0, dtype=np.float64)
            df['low_quality'] = np.zeros(0, dtype=bool)
            return df

        X_infer = prepare_tensors(df, self.bands, self.dates)
        final_probs = self._ensemble_probs(X_infer)
        final_preds = np.argmax(final_probs, axis=1)
        confidence = np.max(final_probs, axis=1)

        # Class 0 is Background/NoData.
        final_preds[large_gap_mask] = 0
        low_quality_mask = (confidence < confidence_threshold) | large_gap_mask

        has_coords = len(coord_cols) >= 2 and all(col in df.columns for col in coord_cols)
        if apply_spatial_smoothing and has_coords:
            print("Applying 2D spatial majority filtering and neighborhood gap-fill...")
            lat_col, lon_col = coord_cols[0], coord_cols[1]
            final_preds = self._spatial_majority_fill(
                df[lat_col].to_numpy(), df[lon_col].to_numpy(), final_preds, low_quality_mask
            )

        df['class_id'] = final_preds
        df['predicted_crop'] = self.le.inverse_transform(final_preds)
        df['confidence'] = confidence
        df['low_quality'] = low_quality_mask
        df.loc[df['class_id'] == 0, 'predicted_crop'] = self._NODATA_LABEL
        return df

    @staticmethod
    def _sibling_path(output_path, suffix, ext=None):
        """Return ``output_path`` with ``suffix`` inserted before its final extension.

        ``ext`` replaces the extension when given. A path without an
        extension gets ``.tif``. Only the final extension is touched, so a
        non-empty ``suffix`` never reproduces ``output_path`` itself.
        """
        root, orig_ext = os.path.splitext(output_path)
        if ext is None:
            ext = orig_ext or '.tif'
        return f"{root}{suffix}{ext}"

    def export_to_geotiff(self, df, output_path="lulc_map.tif", crs="EPSG:6933",
                          coord_cols=('lat', 'lon')):
        """Write class, confidence and low-quality rasters plus a JSON legend.

        Files written, where ``<stem>``/``<ext>`` come from ``output_path``:
            ``output_path``                 class_id as uint16 (0 = NoData)
            ``<stem>_confidence<ext>``      top ensemble probability, float32
            ``<stem>_cloud_mask<ext>``      low_quality flag, uint8
            ``<stem>_legend.json``          {class_id: label}

        Args:
            df: Output of ``predict``. Needs ``class_id``, ``confidence``,
                ``low_quality`` and the coordinate columns.
            output_path: Path of the class raster.
            crs: CRS in which the coordinate columns are expressed.
            coord_cols: ``(y_col, x_col)`` coordinate column names.

        Raises:
            ImportError: If xarray or rioxarray is unavailable.
        """
        if xr is None or rioxarray is None:
            raise ImportError("Missing required libraries: xarray, rioxarray")
        print(f"Exporting LULC masks to {output_path}...")

        y_col, x_col = coord_cols[0], coord_cols[1]
        layers = df.set_index([y_col, x_col])[['class_id', 'confidence', 'low_quality']]
        ds_out = layers.to_xarray()
        ds_out = ds_out.rename({k: v for k, v in ((y_col, 'y'), (x_col, 'x')) if k != v})
        ds_out = ds_out.sortby('y', ascending=False)

        # to_xarray() expands to the full y/x product, so cells with no pixel
        # are NaN. Mark them NoData (class 0) and low quality (1) instead of
        # relying on an undefined NaN -> integer cast.
        ds_out['class_id'] = ds_out['class_id'].fillna(0)
        ds_out['low_quality'] = ds_out['low_quality'].fillna(1)

        ds_out = ds_out.rio.set_spatial_dims(x_dim='x', y_dim='y')
        ds_out.rio.write_crs(crs, inplace=True)

        ds_out['class_id'].astype('uint16').rio.to_raster(output_path)
        ds_out['confidence'].astype('float32').rio.to_raster(
            self._sibling_path(output_path, '_confidence')
        )
        ds_out['low_quality'].astype('uint8').rio.to_raster(
            self._sibling_path(output_path, '_cloud_mask')
        )

        legend_dict = {int(i): str(c) for i, c in enumerate(self.le.classes_)}
        # Match predict(), which labels every class_id 0 pixel (including
        # pixels forced to 0 for large gaps) as NoData.
        legend_dict[0] = self._NODATA_LABEL
        legend_path = self._sibling_path(output_path, '_legend', '.json')
        with open(legend_path, 'w', encoding='utf-8') as f:
            json.dump(legend_dict, f, indent=4)
        print("✅ Successfully exported GeoTIFFs and class legend!")
|