"""GeoCrop inference pipeline (worker-side).

This module is designed to be called by your RQ worker.
Given a job payload (AOI, year, model choice), it:

1) Loads the correct model artifact from MinIO (or local cache).
2) Loads/clips the DW baseline COG for the requested season/year.
3) Queries Digital Earth Africa STAC for imagery and builds the feature stack.
   IMPORTANT: Uses the exact feature engineering from train.py:
   - Savitzky-Golay smoothing (window=5, polyorder=2)
   - Phenology metrics (amplitude, AUC, peak, slope)
   - Harmonic features (1st/2nd order sin/cos)
   - Seasonal window statistics (Early/Peak/Late)
4) Runs per-pixel inference to produce refined classes at 10 m.
5) Applies neighborhood smoothing (majority filter).
6) Writes the output GeoTIFF (COG recommended) to MinIO.

IMPORTANT: This implementation supports the current MinIO model format:
- Zimbabwe_Ensemble_Raw_Model.pkl (no scaler needed)
- Zimbabwe_Ensemble_Model.pkl (scaler needed)
- etc.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import tempfile
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, Optional, Tuple, List
|
|
|
|
# Try to import required dependencies
|
|
try:
|
|
import joblib
|
|
except ImportError:
|
|
joblib = None
|
|
|
|
try:
|
|
import numpy as np
|
|
except ImportError:
|
|
np = None
|
|
|
|
try:
|
|
import rasterio
|
|
from rasterio import windows
|
|
from rasterio.enums import Resampling
|
|
except ImportError:
|
|
rasterio = None
|
|
windows = None
|
|
Resampling = None
|
|
|
|
try:
|
|
from config import InferenceConfig
|
|
except ImportError:
|
|
InferenceConfig = None
|
|
|
|
try:
|
|
from features import (
|
|
build_feature_stack_from_dea,
|
|
clip_raster_to_aoi,
|
|
load_dw_baseline_window,
|
|
majority_filter,
|
|
validate_aoi_zimbabwe,
|
|
)
|
|
except ImportError:
|
|
pass
|
|
|
|
|
|
# ==========================================
|
|
# STEP 6: Model Loading and Raster Prediction
|
|
# ==========================================
|
|
|
|
def load_model(storage, model_name: str):
    """Load a trained model from MinIO storage.

    Args:
        storage: MinIOStorage instance with a ``download_model_file`` method.
        model_name: Name of model (e.g., "RandomForest", "XGBoost", "Ensemble").

    Returns:
        Loaded sklearn-compatible model.

    Raises:
        FileNotFoundError: If the model file is not found in storage.
        ValueError: If the model was trained on a different feature count
            than the 51 features this worker computes.
    """
    # The worker's feature pipeline always produces exactly 51 features
    # (FEATURE_ORDER_V1); a model trained on any other count is unusable.
    expected_features = 51

    # Download into a throwaway directory; the deserialized model keeps no
    # reference to the file, so the directory can be cleaned up afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dest_dir = Path(tmp_dir)

        # storage.download_model_file handles the name -> object-key mapping.
        model_path = storage.download_model_file(model_name, dest_dir)

        # Deserialize while the temp file still exists.
        model = joblib.load(model_path)

    # Validate feature-count compatibility when the model exposes it.
    if hasattr(model, 'n_features_in_'):
        actual_features = model.n_features_in_
        if actual_features != expected_features:
            raise ValueError(
                f"Model feature mismatch: model '{model_name}' expects "
                f"{actual_features} features but the worker provides "
                f"{expected_features} features."
            )

    return model
|
|
|
|
|
|
def predict_raster(
    model,
    feature_cube: np.ndarray,
    feature_order: List[str],
) -> np.ndarray:
    """Run per-pixel inference on a feature cube.

    Args:
        model: Trained sklearn-compatible model (must expose ``predict``).
        feature_cube: 3D array of shape (H, W, C) containing features.
        feature_order: List of C feature names matching the cube's last axis.

    Returns:
        2D array of shape (H, W) with class predictions. Nodata pixels
        (all-zero feature vectors) are assigned class 0.

    Raises:
        ValueError: If the cube's channel count doesn't match feature_order.
    """
    # Validate dimensions against the declared feature order. The caller's
    # feature_order is authoritative — do not assume a fixed count of 51.
    expected_features = len(feature_order)
    actual_features = feature_cube.shape[-1]

    if actual_features != expected_features:
        raise ValueError(
            f"Feature dimension mismatch: feature_cube has {actual_features} features "
            f"but feature_order has {expected_features}. "
            f"feature_cube shape: {feature_cube.shape}, "
            f"feature_order length: {expected_features}."
        )

    H, W, C = feature_cube.shape

    # Flatten spatial dimensions: (H, W, C) -> (H*W, C)
    X = feature_cube.reshape(-1, C)

    # Nodata pixels are encoded as all-zero feature vectors.
    nodata_mask = np.all(X == 0, axis=1)
    num_nodata = int(np.sum(nodata_mask))

    # Patch nodata rows with a tiny epsilon so models that divide by feature
    # values don't blow up; those predictions are overwritten below anyway.
    # Only copy the (potentially large) matrix when there is something to patch.
    if num_nodata > 0:
        X_safe = X.copy()
        X_safe[nodata_mask] = 1e-6
    else:
        X_safe = X

    # Predict all pixels in one vectorized call.
    y_pred = model.predict(X_safe)

    # Class 0 is reserved for nodata.
    if num_nodata > 0:
        y_pred[nodata_mask] = 0

    # Reshape back to the raster grid.
    return y_pred.reshape(H, W)
|
|
|
|
|
|
# ==========================================
|
|
# Legacy functions (kept for backward compatibility)
|
|
# ==========================================
|
|
|
|
|
|
# Model name to MinIO filename mapping
# Format: "Zimbabwe_<ModelName>_Model.pkl" or "Zimbabwe_<ModelName>_Raw_Model.pkl"
# Keys are matched exactly first, then case-insensitively (see get_model_filename).
MODEL_NAME_MAPPING = {
    # Ensemble models
    # NOTE(review): bare "Ensemble" maps to the *raw* artifact — confirm this
    # default matches the artifacts actually uploaded to MinIO.
    "Ensemble": "Zimbabwe_Ensemble_Raw_Model.pkl",
    "Ensemble_Raw": "Zimbabwe_Ensemble_Raw_Model.pkl",
    "Ensemble_Scaled": "Zimbabwe_Ensemble_Model.pkl",

    # Individual models
    "RandomForest": "Zimbabwe_RandomForest_Model.pkl",
    "XGBoost": "Zimbabwe_XGBoost_Model.pkl",
    "LightGBM": "Zimbabwe_LightGBM_Model.pkl",
    "CatBoost": "Zimbabwe_CatBoost_Model.pkl",

    # Legacy/raw variants
    # NOTE(review): these "_Raw" keys point at filenames *without* "_Raw" in
    # them (the scaled artifacts) — verify against the bucket contents.
    "RandomForest_Raw": "Zimbabwe_RandomForest_Model.pkl",
    "XGBoost_Raw": "Zimbabwe_XGBoost_Model.pkl",
    "LightGBM_Raw": "Zimbabwe_LightGBM_Model.pkl",
    "CatBoost_Raw": "Zimbabwe_CatBoost_Model.pkl",
}
|
|
|
|
# Default class mapping if label encoder not available
# Based on typical Zimbabwe crop classification
# Index order matters: a class's list position becomes its raster pixel value.
# NOTE(review): predict_raster reserves pixel value 0 for nodata, yet index 0
# here is "cropland_rainfed" — confirm nodata and this class cannot collide.
DEFAULT_CLASSES = [
    "cropland_rainfed",
    "cropland_irrigated",
    "tree_crop",
    "grassland",
    "shrubland",
    "urban",
    "water",
    "bare",
]
|
|
|
|
|
|
@dataclass
class InferenceResult:
    """Result envelope returned by ``run_inference_job``."""

    # Job identifier echoed back from the job payload.
    job_id: str
    # Terminal job status; run_inference_job sets "done" on success.
    status: str
    # Output name -> uploaded object URI (refined raster, DW baseline, aux layers).
    outputs: Dict[str, str]
    # Free-form job metadata (model, classes, feature names, smoothing settings, ...).
    meta: Dict
|
|
|
|
|
|
def _local_artifact_cache_dir() -> Path:
|
|
d = Path(os.getenv("GEOCROP_CACHE_DIR", "/tmp/geocrop-cache"))
|
|
d.mkdir(parents=True, exist_ok=True)
|
|
return d
|
|
|
|
|
|
def get_model_filename(model_name: str) -> str:
    """Get the MinIO filename for a given model name.

    Args:
        model_name: Model name from job payload (e.g., "Ensemble", "Ensemble_Scaled")

    Returns:
        MinIO filename (e.g., "Zimbabwe_Ensemble_Raw_Model.pkl")
    """
    # Direct lookup for known models.
    if model_name in MODEL_NAME_MAPPING:
        return MODEL_NAME_MAPPING[model_name]

    # Case-insensitive lookup.
    model_lower = model_name.lower()
    for key, value in MODEL_NAME_MAPPING.items():
        if key.lower() == model_lower:
            return value

    # Fallback for unknown models: synthesize the conventional filename.
    # Strip a trailing "_Raw" suffix case-insensitively and preserve the
    # caller's casing. (The previous implementation used a case-sensitive
    # replace('_Raw', ''), which missed lowercase "_raw" and produced doubled
    # "_Raw_Raw" filenames, and str.title(), which mangled CamelCase names
    # such as "RandomForest" -> "Randomforest".)
    if model_lower.endswith("_raw"):
        base = model_name[: -len("_raw")]
        return f"Zimbabwe_{base}_Raw_Model.pkl"
    return f"Zimbabwe_{model_name}_Model.pkl"
|
|
|
|
|
|
def needs_scaler(model_name: str) -> bool:
    """Determine if a model needs feature scaling.

    Models whose name contains "_Raw" do NOT need scaling, and the bare
    "Ensemble" name defaults to the raw (unscaled) variant. Every other
    model requires a StandardScaler.

    Args:
        model_name: Model name from job payload

    Returns:
        True if scaler should be applied
    """
    lowered = model_name.lower()

    # Raw variants and the default ensemble run on unscaled features.
    is_raw_variant = "_raw" in lowered
    is_bare_ensemble = lowered == "ensemble"

    return not (is_raw_variant or is_bare_ensemble)
|
|
|
|
|
|
def load_model_artifacts(cfg: InferenceConfig, model_name: str) -> Tuple[object, object, Optional[object], Optional[List[str]]]:
    """Load model, label encoder, optional scaler, and feature list.

    Artifacts are cached on local disk under a per-model directory so repeat
    jobs skip the MinIO download.

    Supports current MinIO format:
    - Zimbabwe_*_Raw_Model.pkl (no scaler)
    - Zimbabwe_*_Model.pkl (needs scaler)

    Args:
        cfg: Inference configuration
        model_name: Name of the model to load

    Returns:
        Tuple of (model, label_encoder, scaler, selected_features); scaler is
        None for raw variants and selected_features is None when the feature
        list artifact is absent.
    """
    # Per-model cache directory (spaces sanitized for filesystem safety).
    cache = _local_artifact_cache_dir() / model_name.replace(" ", "_")
    cache.mkdir(parents=True, exist_ok=True)

    # Get the MinIO filename
    model_filename = get_model_filename(model_name)
    model_key = f"models/{model_filename}"  # Prefix in bucket

    # Expected artifact layout inside the cache directory.
    model_p = cache / "model.pkl"
    le_p = cache / "label_encoder.pkl"
    scaler_p = cache / "scaler.pkl"
    feats_p = cache / "selected_features.json"

    # Download the whole bundle only on a cache miss (keyed on model.pkl).
    if not model_p.exists():
        print(f"📥 Downloading model from MinIO: {model_key}")
        cfg.storage.download_model_bundle(model_key, cache)

    # Load model
    model = joblib.load(model_p)

    # Load or create label encoder
    if le_p.exists():
        label_encoder = joblib.load(le_p)
    else:
        # Try to get classes from model
        print("⚠️ Label encoder not found, creating default")
        from sklearn.preprocessing import LabelEncoder
        label_encoder = LabelEncoder()
        # Fit on default classes
        label_encoder.fit(DEFAULT_CLASSES)

    # Load scaler if needed
    scaler = None
    if needs_scaler(model_name):
        if scaler_p.exists():
            scaler = joblib.load(scaler_p)
        else:
            print("⚠️ Scaler not found but required for this model variant")
            # Create a dummy scaler that does nothing
            # NOTE(review): this StandardScaler is never fitted, so a later
            # scaler.transform() call will raise NotFittedError — per the
            # note below, production should fail loudly here instead.
            from sklearn.preprocessing import StandardScaler
            scaler = StandardScaler()
            # Note: In production, this should fail - scaler must be uploaded

    # Load selected features
    if feats_p.exists():
        selected_features = json.loads(feats_p.read_text())
    else:
        print("⚠️ Selected features not found, will use all computed features")
        selected_features = None

    return model, label_encoder, scaler, selected_features
|
|
|
|
|
|
def run_inference_job(cfg: InferenceConfig, job: Dict) -> InferenceResult:
    """Main worker entry.

    Orchestrates the pipeline: AOI validation, artifact loading, DW baseline
    clipping, DEA feature-stack construction, per-pixel prediction,
    majority-filter smoothing, and GeoTIFF upload to MinIO.

    job payload example:
    {
        "job_id": "...",
        "user_id": "...",
        "lat": -17.8,
        "lon": 31.0,
        "radius_m": 2000,
        "year": 2022,
        "season": "summer",
        "model": "Ensemble"  # or "Ensemble_Scaled", "RandomForest", etc.
    }

    Returns:
        InferenceResult with status "done", output URIs, and job metadata.
    """

    job_id = str(job.get("job_id"))

    # 1) Validate AOI constraints
    # AOI is (lon, lat, radius_m) — note lon comes first.
    aoi = (float(job["lon"]), float(job["lat"]), float(job["radius_m"]))
    validate_aoi_zimbabwe(aoi, max_radius_m=cfg.max_radius_m)

    year = int(job["year"])
    season = str(job.get("season", "summer")).lower()

    # Your training window (Sep -> May)
    start_date, end_date = cfg.season_dates(year=year, season=season)

    model_name = str(job.get("model", "Ensemble"))
    print(f"🤖 Loading model: {model_name}")

    model, le, scaler, selected_features = load_model_artifacts(cfg, model_name)

    # Determine if we need scaling
    use_scaler = scaler is not None and needs_scaler(model_name)
    print(f" Scaler required: {use_scaler}")

    # 2) Load DW baseline for this year/season (already converted to COGs)
    # (This gives you the "DW baseline toggle" layer too.)
    dw_arr, dw_profile = load_dw_baseline_window(
        cfg=cfg,
        year=year,
        season=season,
        aoi=aoi,
    )

    # 3) Build EO feature stack from DEA STAC
    # IMPORTANT: This now uses full feature engineering matching train.py
    # NOTE(review): feat_profile is unused below; all outputs reuse dw_profile.
    print("📡 Building feature stack from DEA STAC...")
    feat_arr, feat_profile, feat_names, aux_layers = build_feature_stack_from_dea(
        cfg=cfg,
        aoi=aoi,
        start_date=start_date,
        end_date=end_date,
        target_profile=dw_profile,
    )

    print(f" Computed {len(feat_names)} features")
    print(f" Feature array shape: {feat_arr.shape}")

    # 4) Prepare model input: (H,W,C) -> (N,C)
    H, W, C = feat_arr.shape
    X = feat_arr.reshape(-1, C)

    # Ensure feature order matches training: keep only (and exactly) the
    # columns named in the model's selected-features artifact.
    if selected_features is not None:
        name_to_idx = {n: i for i, n in enumerate(feat_names)}
        keep_idx = [name_to_idx[n] for n in selected_features if n in name_to_idx]

        if len(keep_idx) == 0:
            print("⚠️ No matching features found, using all computed features")
        else:
            print(f" Using {len(keep_idx)} selected features")
            X = X[:, keep_idx]
    else:
        print(" Using all computed features (no selection)")

    # Apply scaler if needed
    if use_scaler and scaler is not None:
        print(" Applying StandardScaler")
        X = scaler.transform(X)

    # Handle NaNs (common with clouds/no-data)
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

    # 5) Predict
    print("🔮 Running prediction...")
    y_pred = model.predict(X).astype(np.int32)

    # Back to string labels (your refined classes)
    try:
        refined_labels = le.inverse_transform(y_pred)
    except Exception as e:
        print(f"⚠️ Label inverse_transform failed: {e}")
        # Fallback: use default classes (modulo wrap keeps any class id valid)
        refined_labels = np.array([DEFAULT_CLASSES[i % len(DEFAULT_CLASSES)] for i in y_pred])

    refined_labels = refined_labels.reshape(H, W)

    # 6) Neighborhood smoothing (majority filter)
    # Per-job kernel override takes precedence over the config default.
    smoothing_kernel = job.get("smoothing_kernel", cfg.smoothing_kernel)
    if cfg.smoothing_enabled and smoothing_kernel > 1:
        print(f"🧼 Applying majority filter (k={smoothing_kernel})")
        refined_labels = majority_filter(refined_labels, k=smoothing_kernel)

    # 7) Write outputs (GeoTIFF only; COG recommended for tiling)
    # NOTE(review): datetime.utcnow() is naive (and deprecated in 3.12);
    # consider datetime.now(timezone.utc) — confirm downstream expectations.
    ts = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
    out_name = f"refined_{season}_{year}_{job_id}_{ts}.tif"
    baseline_name = f"dw_{season}_{year}_{job_id}_{ts}.tif"

    with tempfile.TemporaryDirectory() as tmp:
        refined_path = Path(tmp) / out_name
        dw_path = Path(tmp) / baseline_name

        # DW baseline
        with rasterio.open(dw_path, "w", **dw_profile) as dst:
            dst.write(dw_arr, 1)

        # Refined - store as uint16 with a sidecar legend in meta (recommended)
        # For now store an index raster; map index->class in meta.json
        classes = le.classes_.tolist() if hasattr(le, 'classes_') else DEFAULT_CLASSES
        class_to_idx = {c: i for i, c in enumerate(classes)}

        # Handle string labels
        if refined_labels.dtype.kind in ['U', 'O', 'S']:
            # String labels - create mapping to class indices
            idx_raster = np.zeros((H, W), dtype=np.uint16)
            for i, cls in enumerate(classes):
                mask = refined_labels == cls
                idx_raster[mask] = i
        else:
            # Numeric labels already
            idx_raster = refined_labels.astype(np.uint16)

        refined_profile = dw_profile.copy()
        refined_profile.update({"dtype": "uint16", "count": 1})

        with rasterio.open(refined_path, "w", **refined_profile) as dst:
            dst.write(idx_raster, 1)

        # Upload
        refined_uri = cfg.storage.upload_result(local_path=refined_path, key=f"results/{out_name}")
        dw_uri = cfg.storage.upload_result(local_path=dw_path, key=f"results/{baseline_name}")

        # Optionally upload aux layers (true color, NDVI/EVI/SAVI)
        aux_uris = {}
        for layer_name, layer in aux_layers.items():
            # layer: (H,W) or (H,W,3)
            aux_path = Path(tmp) / f"{layer_name}_{season}_{year}_{job_id}_{ts}.tif"

            # Determine count and dtype
            if layer.ndim == 3 and layer.shape[2] == 3:
                count = 3
                dtype = layer.dtype
            else:
                count = 1
                dtype = layer.dtype

            aux_profile = dw_profile.copy()
            aux_profile.update({"count": count, "dtype": str(dtype)})

            with rasterio.open(aux_path, "w", **aux_profile) as dst:
                if count == 1:
                    dst.write(layer, 1)
                else:
                    # rasterio expects (bands, H, W) ordering
                    dst.write(layer.transpose(2, 0, 1), [1, 2, 3])

            aux_uris[layer_name] = cfg.storage.upload_result(
                local_path=aux_path, key=f"results/{aux_path.name}"
            )

    # Sidecar metadata: legend, provenance, and reproducibility details.
    meta = {
        "job_id": job_id,
        "year": year,
        "season": season,
        "start_date": start_date,
        "end_date": end_date,
        "model": model_name,
        "scaler_used": use_scaler,
        "classes": classes,
        "class_index": class_to_idx,
        "features_computed": feat_names,
        "n_features": len(feat_names),
        "smoothing": {"enabled": cfg.smoothing_enabled, "kernel": smoothing_kernel},
    }

    outputs = {
        "refined_geotiff": refined_uri,
        "dw_baseline_geotiff": dw_uri,
        **aux_uris,
    }

    return InferenceResult(job_id=job_id, status="done", outputs=outputs, meta=meta)
|
|
|
|
|
|
# ==========================================
|
|
# Self-Test
|
|
# ==========================================
|
|
|
|
if __name__ == "__main__":
    # Smoke tests exercising predict_raster and load_model without requiring
    # the full worker environment (MinIO, DEA STAC, rasterio data).
    print("=== Inference Module Self-Test ===")

    # Check for required dependencies
    missing_deps = []
    for mod in ['joblib', 'sklearn']:
        try:
            __import__(mod)
        except ImportError:
            missing_deps.append(mod)

    if missing_deps:
        print(f"\n⚠️ Missing dependencies: {missing_deps}")
        print(" These will be available in the container environment.")
        print(" Running syntax validation only...")

    # Test 1: predict_raster with dummy data (only if sklearn available)
    print("\n1. Testing predict_raster with dummy feature cube...")

    # Create dummy feature cube (10, 10, 51)
    H, W, C = 10, 10, 51
    dummy_cube = np.random.rand(H, W, C).astype(np.float32)

    # Create dummy feature order
    # NOTE(review): this import is unguarded — if feature_computation is
    # missing, the self-test aborts here before missing_deps is consulted.
    from feature_computation import FEATURE_ORDER_V1
    feature_order = FEATURE_ORDER_V1

    print(f" Feature cube shape: {dummy_cube.shape}")
    print(f" Feature order length: {len(feature_order)}")

    if 'sklearn' not in missing_deps:
        # Create a dummy model for testing
        from sklearn.ensemble import RandomForestClassifier

        # Train a small model on random data
        X_train = np.random.rand(100, C)
        y_train = np.random.randint(0, 8, 100)
        dummy_model = RandomForestClassifier(n_estimators=10, random_state=42)
        dummy_model.fit(X_train, y_train)

        # Verify model compatibility check
        print(f" Model n_features_in_: {dummy_model.n_features_in_}")

        # Run prediction
        try:
            result = predict_raster(dummy_model, dummy_cube, feature_order)
            print(f" Prediction result shape: {result.shape}")
            print(f" Expected shape: ({H}, {W})")

            if result.shape == (H, W):
                print(" ✓ predict_raster test PASSED")
            else:
                print(" ✗ predict_raster test FAILED - wrong shape")
        except Exception as e:
            print(f" ✗ predict_raster test FAILED: {e}")

        # Test 2: predict_raster with nodata handling
        print("\n2. Testing nodata handling...")

        # Create cube with nodata (all zeros)
        nodata_cube = np.zeros((5, 5, C), dtype=np.float32)
        nodata_cube[2, 2, :] = 1.0  # One valid pixel

        # NOTE(review): [2,2] is the single *valid* pixel; the two print
        # labels below read as if swapped relative to what they show.
        result_nodata = predict_raster(dummy_model, nodata_cube, feature_order)
        print(f" Nodata pixel value at [2,2]: {result_nodata[2, 2]}")
        print(f" Nodata pixels (should be 0): {result_nodata[0, 0]}")

        if result_nodata[0, 0] == 0 and result_nodata[0, 1] == 0:
            print(" ✓ Nodata handling test PASSED")
        else:
            print(" ✗ Nodata handling test FAILED")

        # Test 3: Feature mismatch detection
        print("\n3. Testing feature mismatch detection...")

        wrong_cube = np.random.rand(5, 5, 50).astype(np.float32)  # 50 features, not 51

        try:
            predict_raster(dummy_model, wrong_cube, feature_order)
            print(" ✗ Feature mismatch test FAILED - should have raised ValueError")
        except ValueError as e:
            if "Feature dimension mismatch" in str(e):
                print(" ✓ Feature mismatch test PASSED")
            else:
                print(f" ✗ Wrong error: {e}")
    else:
        print(" (sklearn not available - skipping)")

    # Test 4: Try loading model from MinIO (will fail without real storage)
    print("\n4. Testing load_model from MinIO...")
    try:
        from storage import MinIOStorage
        storage = MinIOStorage()

        # This will fail without real MinIO, but we can catch the error
        model = load_model(storage, "RandomForest")
        print(" Model loaded successfully")
        print(" ✓ load_model test PASSED")
    except Exception as e:
        print(f" (Expected) MinIO/storage not available: {e}")
        print(" ✓ load_model test handled gracefully")

    print("\n=== Inference Module Test Complete ===")
|
|
|