# geocrop-platform/training/train.py
# (file-listing header — "310 lines, 9.8 KiB, Python" — converted to a
# comment so the module parses)

"""GeoCrop training entrypoint.
This is a cleaned and production-friendly version of your notebook/script.
It trains multiple models + a soft-voting ensemble, logs metrics, and uploads
artifacts (model, label encoder, selected feature list) for inference.
Notes:
- Keeps your sklearn 1.6+ compatibility wrapper.
- Stores metadata needed for inference (classes, features, scaling decision).
- If you want to keep both Raw and Scaled ensembles, run twice.
Usage:
python train.py --data /path/to/Zimbabwe_Crop_Engineered_Ready.csv --out ./artifacts --variant Raw
python train.py --data /path/to/Zimbabwe_Crop_Engineered_Ready.csv --out ./artifacts --variant Scaled
Optional (MinIO):
export MINIO_ENDPOINT=...
export MINIO_ACCESS_KEY=...
export MINIO_SECRET_KEY=...
export MINIO_BUCKET=geocrop-models
python train.py ... --upload-minio
"""
from __future__ import annotations
import argparse
import json
import os
import warnings
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List, Tuple
import joblib
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostClassifier
from config import TrainingConfig
from features import (
drop_junk_columns,
scout_feature_selection,
scale_numeric_features,
)
# -----------------------------
# Warnings
# -----------------------------
# Silence noisy-but-benign warnings (pandas fragmentation, library deprecation
# chatter) so training logs stay readable; nothing here changes numerics.
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# ==========================================
# 0. GENERIC COMPATIBILITY WRAPPER
# ==========================================
class Sklearn6CompatibilityWrapper(BaseEstimator, ClassifierMixin):
    """Wrap 3rd-party classifiers for sklearn 1.6+ compatibility.

    sklearn 1.6 changed the estimator-tag protocol (``__sklearn_tags__``);
    this shim exposes the wrapped booster (XGBoost/LightGBM/CatBoost) as a
    plain sklearn classifier so it can sit inside e.g. ``VotingClassifier``.

    Parameters are stored unfitted in ``kwargs``; the underlying model is
    instantiated fresh on every :meth:`fit`, which keeps clones independent.
    """

    _estimator_type = "classifier"

    def __init__(self, model_class=None, **kwargs):
        self.model_class = model_class
        self.kwargs = kwargs
        self.model = None

    def fit(self, X, y):
        # Re-instantiate on every fit so the wrapper is safely re-fittable
        # (e.g. after sklearn clones it inside an ensemble).
        self.model = self.model_class(**self.kwargs)
        self.model.fit(X, y)
        # Mirror classes_ so sklearn's classifier checks pass.
        if hasattr(self.model, "classes_"):
            self.classes_ = self.model.classes_
        return self

    def predict(self, X):
        return self.model.predict(X)

    def predict_proba(self, X):
        return self.model.predict_proba(X)

    @property
    def feature_importances_(self):
        # BUG FIX: use an explicit None check. Estimator truthiness is not
        # reliable — some estimators define __len__/__bool__ and may evaluate
        # falsy (or raise) even when fitted, which would wrongly return None.
        return self.model.feature_importances_ if self.model is not None else None

    def get_params(self, deep=True):
        # Flat parameter dict; booster kwargs are surfaced directly so
        # sklearn clone()/set_params() round-trips work.
        return {"model_class": self.model_class, **self.kwargs}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            if parameter == "model_class":
                self.model_class = value
            else:
                self.kwargs[parameter] = value
        return self

    def __sklearn_tags__(self):
        # sklearn 1.6+ tag protocol: advertise ourselves as a classifier.
        tags = super().__sklearn_tags__()
        tags.estimator_type = "classifier"
        return tags
# ==========================================
# Training
# ==========================================
def build_models(cfg: TrainingConfig) -> Dict[str, BaseEstimator]:
    """Instantiate the four base classifiers from the training config.

    RandomForest is used directly; the three boosters are wrapped in the
    sklearn-1.6 compatibility shim. All hyperparameters come from ``cfg``.
    """
    forest = RandomForestClassifier(
        n_estimators=cfg.rf_n_estimators,
        n_jobs=-1,
        random_state=cfg.random_state,
        class_weight="balanced",
    )
    xgboost_clf = Sklearn6CompatibilityWrapper(
        model_class=xgb.XGBClassifier,
        n_estimators=cfg.xgb_n_estimators,
        learning_rate=cfg.xgb_learning_rate,
        max_depth=cfg.xgb_max_depth,
        subsample=cfg.xgb_subsample,
        colsample_bytree=cfg.xgb_colsample_bytree,
        eval_metric="mlogloss",
        n_jobs=-1,
        random_state=cfg.random_state,
    )
    lightgbm_clf = Sklearn6CompatibilityWrapper(
        model_class=lgb.LGBMClassifier,
        n_estimators=cfg.lgb_n_estimators,
        learning_rate=cfg.lgb_learning_rate,
        num_leaves=cfg.lgb_num_leaves,
        subsample=cfg.lgb_subsample,
        colsample_bytree=cfg.lgb_colsample_bytree,
        min_child_samples=cfg.lgb_min_child_samples,
        class_weight="balanced",
        n_jobs=-1,
        random_state=cfg.random_state,
        verbose=-1,
    )
    catboost_clf = Sklearn6CompatibilityWrapper(
        model_class=CatBoostClassifier,
        iterations=cfg.cb_iterations,
        learning_rate=cfg.cb_learning_rate,
        depth=cfg.cb_depth,
        verbose=0,
        random_seed=cfg.random_state,
        auto_class_weights="Balanced",
        allow_writing_files=False,
    )
    return {
        "RandomForest": forest,
        "XGBoost": xgboost_clf,
        "LightGBM": lightgbm_clf,
        "CatBoost": catboost_clf,
    }
def evaluate(y_true: np.ndarray, y_pred: np.ndarray, label_names: List[str]) -> Dict:
    """Score predictions: accuracy, macro-F1, per-class report, confusion matrix.

    ``label_names`` supplies human-readable class names for the report; the
    confusion matrix is converted to nested lists so it is JSON-serializable.
    """
    per_class = classification_report(
        y_true, y_pred, target_names=label_names, output_dict=True, zero_division=0
    )
    return {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "f1_macro": float(f1_score(y_true, y_pred, average="macro")),
        "report": per_class,
        "confusion": confusion_matrix(y_true, y_pred).tolist(),
    }
def train_one_variant(
    df: pd.DataFrame,
    cfg: TrainingConfig,
    variant: str,
    out_dir: Path,
) -> Path:
    """Train the base models plus a soft-voting ensemble for one variant.

    Args:
        df: Full engineered dataset, including the label column.
        cfg: Hyperparameters and column configuration.
        variant: "Raw" or "Scaled" — whether numeric features are scaled.
        out_dir: Root directory under which artifacts are written.

    Returns:
        Path to ``out_dir / model_<variant>`` containing the ensemble model,
        label encoder, optional scaler, selected feature list, metadata and
        metrics — everything inference needs.

    Raises:
        ValueError: If the configured label column is missing.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    df_clean = drop_junk_columns(df, cfg.junk_cols)
    if cfg.label_col not in df_clean.columns:
        raise ValueError(f"Missing label column '{cfg.label_col}'")

    X = df_clean.drop(columns=[cfg.label_col])
    y = df_clean[cfg.label_col]

    # Encode labels to integers; the encoder is persisted for inference.
    le = LabelEncoder()
    y_enc = le.fit_transform(y)
    class_names = le.classes_.tolist()

    # Stratify so class balance is preserved in the held-out split.
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y_enc,
        test_size=cfg.test_size,
        random_state=cfg.random_state,
        stratify=y_enc,
    )

    # Feature selection is fit on the training split only (avoids leakage).
    selected_features = scout_feature_selection(
        X_train, y_train, n_estimators=cfg.scout_n_estimators, random_state=cfg.random_state
    )
    X_train = X_train[selected_features]
    X_test = X_test[selected_features]

    scaler = None
    if variant.lower() == "scaled":
        X_train, X_test, scaler = scale_numeric_features(X_train, X_test)

    models = build_models(cfg)
    metrics: Dict[str, Dict] = {}
    trained: Dict[str, BaseEstimator] = {}
    for name, model in models.items():
        model.fit(X_train, y_train)
        preds = model.predict(X_test)
        metrics[name] = evaluate(y_test, preds, class_names)
        trained[name] = model

    # NOTE: VotingClassifier clones and re-fits its estimators, so each base
    # model is trained twice (once standalone for per-model metrics, once
    # inside the ensemble). This preserves the original pipeline's behavior.
    ensemble = VotingClassifier(
        estimators=list(trained.items()), voting="soft", n_jobs=-1
    )
    ensemble.fit(X_train, y_train)
    ens_preds = ensemble.predict(X_test)
    metrics["Ensemble"] = evaluate(y_test, ens_preds, class_names)

    # Persist artifacts needed for inference
    artifact_dir = out_dir / f"model_{variant.lower()}"
    artifact_dir.mkdir(parents=True, exist_ok=True)
    joblib.dump(ensemble, artifact_dir / "model.joblib")
    joblib.dump(le, artifact_dir / "label_encoder.joblib")
    if scaler is not None:
        joblib.dump(scaler, artifact_dir / "scaler.joblib")
    (artifact_dir / "selected_features.json").write_text(
        json.dumps(selected_features, indent=2)
    )
    meta = {
        "variant": variant,
        # Explicit flag so inference can check scaling without parsing the
        # variant string (backward-compatible addition to meta.json).
        "scaled": scaler is not None,
        "class_names": class_names,
        "n_features": len(selected_features),
        "config": asdict(cfg),
    }
    (artifact_dir / "meta.json").write_text(json.dumps(meta, indent=2))
    (artifact_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
    return artifact_dir
def maybe_upload_to_minio(artifact_dir: Path, cfg: TrainingConfig):
if not cfg.upload_minio:
return
try:
import boto3
from botocore.client import Config
except Exception as e:
raise RuntimeError("boto3 is required for MinIO upload") from e
s3 = boto3.client(
"s3",
endpoint_url=cfg.minio_endpoint,
aws_access_key_id=cfg.minio_access_key,
aws_secret_access_key=cfg.minio_secret_key,
config=Config(signature_version="s3v4"),
region_name="us-east-1",
)
# Ensure bucket exists
try:
s3.head_bucket(Bucket=cfg.minio_bucket)
except Exception:
s3.create_bucket(Bucket=cfg.minio_bucket)
prefix = f"{cfg.minio_prefix}/{artifact_dir.name}"
for p in artifact_dir.rglob("*"):
if p.is_file():
key = f"{prefix}/{p.name}"
s3.upload_file(str(p), cfg.minio_bucket, key)
def main():
    """CLI entrypoint: parse arguments, train one variant, upload if asked."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True, help="CSV path: Zimbabwe_Crop_Engineered_Ready.csv")
    parser.add_argument("--out", required=True, help="Output directory for artifacts")
    parser.add_argument("--variant", choices=["Raw", "Scaled"], default="Raw")
    parser.add_argument("--upload-minio", action="store_true")
    args = parser.parse_args()

    # MinIO connection settings come from the environment; the CLI flag only
    # toggles whether the upload happens at all.
    env = os.getenv
    cfg = TrainingConfig(
        upload_minio=args.upload_minio,
        minio_endpoint=env("MINIO_ENDPOINT", ""),
        minio_access_key=env("MINIO_ACCESS_KEY", ""),
        minio_secret_key=env("MINIO_SECRET_KEY", ""),
        minio_bucket=env("MINIO_BUCKET", "geocrop-models"),
        minio_prefix=env("MINIO_PREFIX", "models"),
    )

    dataset = pd.read_csv(args.data)
    artifact_dir = train_one_variant(dataset, cfg, args.variant, Path(args.out))
    maybe_upload_to_minio(artifact_dir, cfg)
    print(f"Saved artifacts to: {artifact_dir}")
# Allow importing this module without side effects; train only as a script.
if __name__ == "__main__":
    main()