"""GeoCrop training entrypoint. This is a cleaned and production-friendly version of your notebook/script. It trains multiple models + a soft-voting ensemble, logs metrics, and uploads artifacts (model, label encoder, selected feature list) for inference. Notes: - Keeps your sklearn 1.6+ compatibility wrapper. - Stores metadata needed for inference (classes, features, scaling decision). - If you want to keep both Raw and Scaled ensembles, run twice. Usage: python train.py --data /path/to/Zimbabwe_Crop_Engineered_Ready.csv --out ./artifacts --variant Raw python train.py --data /path/to/Zimbabwe_Crop_Engineered_Ready.csv --out ./artifacts --variant Scaled Optional (MinIO): export MINIO_ENDPOINT=... export MINIO_ACCESS_KEY=... export MINIO_SECRET_KEY=... export MINIO_BUCKET=geocrop-models python train.py ... --upload-minio """ from __future__ import annotations import argparse import json import os import warnings from dataclasses import asdict from pathlib import Path from typing import Dict, List, Tuple import joblib import numpy as np import pandas as pd from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.ensemble import RandomForestClassifier, VotingClassifier from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score from sklearn.model_selection import train_test_split from sklearn.preprocessing import LabelEncoder, StandardScaler import xgboost as xgb import lightgbm as lgb from catboost import CatBoostClassifier from config import TrainingConfig from features import ( drop_junk_columns, scout_feature_selection, scale_numeric_features, ) # ----------------------------- # Warnings # ----------------------------- warnings.simplefilter(action="ignore", category=FutureWarning) warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning) warnings.filterwarnings("ignore", category=UserWarning) # ========================================== # 0. 
class Sklearn6CompatibilityWrapper(BaseEstimator, ClassifierMixin):
    """Wrap third-party classifiers for sklearn 1.6+ compatibility."""

    _estimator_type = "classifier"

    def __init__(self, model_class=None, **kwargs):
        self.model_class = model_class
        self.kwargs = kwargs
        self.model = None

    def fit(self, X, y):
        self.model = self.model_class(**self.kwargs)
        self.model.fit(X, y)
        if hasattr(self.model, "classes_"):
            self.classes_ = self.model.classes_
        return self

    def predict(self, X):
        return self.model.predict(X)

    def predict_proba(self, X):
        return self.model.predict_proba(X)

    @property
    def feature_importances_(self):
        return self.model.feature_importances_ if self.model else None

    def get_params(self, deep=True):
        return {"model_class": self.model_class, **self.kwargs}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            if parameter == "model_class":
                self.model_class = value
            else:
                self.kwargs[parameter] = value
        return self

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.estimator_type = "classifier"
        return tags


# ==========================================
# Training
# ==========================================
def build_models(cfg: TrainingConfig) -> Dict[str, BaseEstimator]:
    """Return the model dictionary matching the original settings."""
    return {
        "RandomForest": RandomForestClassifier(
            n_estimators=cfg.rf_n_estimators,
            n_jobs=-1,
            random_state=cfg.random_state,
            class_weight="balanced",
        ),
        "XGBoost": Sklearn6CompatibilityWrapper(
            model_class=xgb.XGBClassifier,
            n_estimators=cfg.xgb_n_estimators,
            learning_rate=cfg.xgb_learning_rate,
            max_depth=cfg.xgb_max_depth,
            subsample=cfg.xgb_subsample,
            colsample_bytree=cfg.xgb_colsample_bytree,
            eval_metric="mlogloss",
            n_jobs=-1,
            random_state=cfg.random_state,
        ),
        "LightGBM": Sklearn6CompatibilityWrapper(
            model_class=lgb.LGBMClassifier,
            n_estimators=cfg.lgb_n_estimators,
            learning_rate=cfg.lgb_learning_rate,
            num_leaves=cfg.lgb_num_leaves,
            subsample=cfg.lgb_subsample,
            colsample_bytree=cfg.lgb_colsample_bytree,
            min_child_samples=cfg.lgb_min_child_samples,
            class_weight="balanced",
            n_jobs=-1,
            random_state=cfg.random_state,
            verbose=-1,
        ),
        "CatBoost": Sklearn6CompatibilityWrapper(
            model_class=CatBoostClassifier,
            iterations=cfg.cb_iterations,
            learning_rate=cfg.cb_learning_rate,
            depth=cfg.cb_depth,
            verbose=0,
            random_seed=cfg.random_state,
            auto_class_weights="Balanced",
            allow_writing_files=False,
        ),
    }


def evaluate(y_true: np.ndarray, y_pred: np.ndarray, label_names: List[str]) -> Dict:
    """Compute accuracy, macro F1, a per-class report, and the confusion matrix."""
    acc = float(accuracy_score(y_true, y_pred))
    f1m = float(f1_score(y_true, y_pred, average="macro"))
    report = classification_report(
        y_true, y_pred, target_names=label_names, output_dict=True, zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred).tolist()
    return {"accuracy": acc, "f1_macro": f1m, "report": report, "confusion": cm}


def train_one_variant(
    df: pd.DataFrame,
    cfg: TrainingConfig,
    variant: str,
    out_dir: Path,
) -> Path:
    out_dir.mkdir(parents=True, exist_ok=True)

    df_clean = drop_junk_columns(df, cfg.junk_cols)
    if cfg.label_col not in df_clean.columns:
        raise ValueError(f"Missing label column '{cfg.label_col}'")

    X = df_clean.drop(columns=[cfg.label_col])
    y = df_clean[cfg.label_col]

    le = LabelEncoder()
    y_enc = le.fit_transform(y)
    class_names = le.classes_.tolist()

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y_enc,
        test_size=cfg.test_size,
        random_state=cfg.random_state,
        stratify=y_enc,
    )
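    # Feature selection runs on the training split only, so the test split
    # never influences which features are kept. `scout_feature_selection`
    # (see features.py) is assumed to fit a quick tree-based "scout" model
    # and return the informative feature names; the same list is persisted
    # below so inference can reproduce the exact columns and column order.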
    selected_features = scout_feature_selection(
        X_train,
        y_train,
        n_estimators=cfg.scout_n_estimators,
        random_state=cfg.random_state,
    )
    X_train = X_train[selected_features]
    X_test = X_test[selected_features]

    scaler = None
    if variant.lower() == "scaled":
        X_train, X_test, scaler = scale_numeric_features(X_train, X_test)

    models = build_models(cfg)
    metrics: Dict[str, Dict] = {}
    trained: Dict[str, BaseEstimator] = {}
    for name, model in models.items():
        model.fit(X_train, y_train)
        preds = model.predict(X_test)
        metrics[name] = evaluate(y_test, preds, class_names)
        trained[name] = model

    ensemble = VotingClassifier(
        estimators=list(trained.items()), voting="soft", n_jobs=-1
    )
    ensemble.fit(X_train, y_train)
    ens_preds = ensemble.predict(X_test)
    metrics["Ensemble"] = evaluate(y_test, ens_preds, class_names)

    # Persist artifacts needed for inference.
    artifact_dir = out_dir / f"model_{variant.lower()}"
    artifact_dir.mkdir(parents=True, exist_ok=True)
    joblib.dump(ensemble, artifact_dir / "model.joblib")
    joblib.dump(le, artifact_dir / "label_encoder.joblib")
    if scaler is not None:
        joblib.dump(scaler, artifact_dir / "scaler.joblib")
    (artifact_dir / "selected_features.json").write_text(
        json.dumps(selected_features, indent=2)
    )

    meta = {
        "variant": variant,
        "class_names": class_names,
        "n_features": len(selected_features),
        "config": asdict(cfg),
    }
    (artifact_dir / "meta.json").write_text(json.dumps(meta, indent=2))
    (artifact_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))

    return artifact_dir


def maybe_upload_to_minio(artifact_dir: Path, cfg: TrainingConfig):
    if not cfg.upload_minio:
        return
    try:
        import boto3
        from botocore.client import Config
    except Exception as e:
        raise RuntimeError("boto3 is required for MinIO upload") from e

    s3 = boto3.client(
        "s3",
        endpoint_url=cfg.minio_endpoint,
        aws_access_key_id=cfg.minio_access_key,
        aws_secret_access_key=cfg.minio_secret_key,
        config=Config(signature_version="s3v4"),
        region_name="us-east-1",
    )

    # Ensure the bucket exists.
    try:
        s3.head_bucket(Bucket=cfg.minio_bucket)
    except Exception:
        s3.create_bucket(Bucket=cfg.minio_bucket)

    prefix = f"{cfg.minio_prefix}/{artifact_dir.name}"
    for p in artifact_dir.rglob("*"):
        if p.is_file():
            key = f"{prefix}/{p.name}"
            s3.upload_file(str(p), cfg.minio_bucket, key)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True, help="CSV path: Zimbabwe_Crop_Engineered_Ready.csv")
    parser.add_argument("--out", required=True, help="Output directory for artifacts")
    parser.add_argument("--variant", choices=["Raw", "Scaled"], default="Raw")
    parser.add_argument("--upload-minio", action="store_true")
    args = parser.parse_args()

    cfg = TrainingConfig(
        upload_minio=args.upload_minio,
        minio_endpoint=os.getenv("MINIO_ENDPOINT", ""),
        minio_access_key=os.getenv("MINIO_ACCESS_KEY", ""),
        minio_secret_key=os.getenv("MINIO_SECRET_KEY", ""),
        minio_bucket=os.getenv("MINIO_BUCKET", "geocrop-models"),
        minio_prefix=os.getenv("MINIO_PREFIX", "models"),
    )

    df = pd.read_csv(args.data)
    out_dir = Path(args.out)
    artifact_dir = train_one_variant(df, cfg, args.variant, out_dir)
    maybe_upload_to_minio(artifact_dir, cfg)
    print(f"Saved artifacts to: {artifact_dir}")


if __name__ == "__main__":
    main()
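
# -----------------------------
# Inference sketch (reference only)
# -----------------------------
# A minimal, hedged sketch of consuming the artifacts written above. The paths
# and CSV name are illustrative; the "Scaled" variant would additionally load
# and apply scaler.joblib before predicting.
#
#   import json
#   import joblib
#   import pandas as pd
#   from pathlib import Path
#
#   art = Path("artifacts/model_raw")
#   model = joblib.load(art / "model.joblib")
#   le = joblib.load(art / "label_encoder.joblib")
#   features = json.loads((art / "selected_features.json").read_text())
#
#   X_new = pd.read_csv("new_fields.csv")[features]  # same columns, same order
#   labels = le.inverse_transform(model.predict(X_new))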