"""GeoCrop training entrypoint.
|
|
|
|
This is a cleaned and production-friendly version of your notebook/script.
|
|
It trains multiple models + a soft-voting ensemble, logs metrics, and uploads
|
|
artifacts (model, label encoder, selected feature list) for inference.
|
|
|
|
Notes:
|
|
- Keeps your sklearn 1.6+ compatibility wrapper.
|
|
- Stores metadata needed for inference (classes, features, scaling decision).
|
|
- If you want to keep both Raw and Scaled ensembles, run twice.
|
|
|
|
Usage:
|
|
python train.py --data /path/to/Zimbabwe_Crop_Engineered_Ready.csv --out ./artifacts --variant Raw
|
|
python train.py --data /path/to/Zimbabwe_Crop_Engineered_Ready.csv --out ./artifacts --variant Scaled
|
|
|
|
Optional (MinIO):
|
|
export MINIO_ENDPOINT=...
|
|
export MINIO_ACCESS_KEY=...
|
|
export MINIO_SECRET_KEY=...
|
|
export MINIO_BUCKET=geocrop-models
|
|
python train.py ... --upload-minio
|
|
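
Loading artifacts for inference (a minimal sketch, not the only way; "df" is
assumed to be a DataFrame that already holds the engineered columns, and the
path matches the default layout this script writes):
    import json, joblib
    art = "./artifacts/model_raw"
    model = joblib.load(f"{art}/model.joblib")
    le = joblib.load(f"{art}/label_encoder.joblib")
    features = json.loads(open(f"{art}/selected_features.json").read())
    crops = le.inverse_transform(model.predict(df[features]))
    # For the Scaled variant, load scaler.joblib and transform df[features] first.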
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import warnings
|
|
from dataclasses import asdict
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple
|
|
|
|
import joblib
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from sklearn.base import BaseEstimator, ClassifierMixin
|
|
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
|
|
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.preprocessing import LabelEncoder, StandardScaler
|
|
|
|
import xgboost as xgb
|
|
import lightgbm as lgb
|
|
from catboost import CatBoostClassifier
|
|
|
|
from config import TrainingConfig
|
|
from features import (
|
|
drop_junk_columns,
|
|
scout_feature_selection,
|
|
scale_numeric_features,
|
|
)
|
|
|
|
# -----------------------------
|
|
# Warnings
|
|
# -----------------------------
|
|
warnings.simplefilter(action="ignore", category=FutureWarning)
|
|
warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
|
|
warnings.filterwarnings("ignore", category=UserWarning)
|
|
|
|
|
|
# ==========================================
|
|
# 0. GENERIC COMPATIBILITY WRAPPER
|
|
# ==========================================
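# sklearn 1.6 reworked the estimator-tags API around `__sklearn_tags__`, which
# meta-estimators such as VotingClassifier now rely on. Third-party classifiers
# that have not adopted the new API can fail inside the ensemble, so this thin
# wrapper re-exposes them through a plain sklearn-compatible interface.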
class Sklearn6CompatibilityWrapper(BaseEstimator, ClassifierMixin):
    """Wrap 3rd-party classifiers for sklearn 1.6+ compatibility."""

    _estimator_type = "classifier"

    def __init__(self, model_class=None, **kwargs):
        self.model_class = model_class
        self.kwargs = kwargs
        self.model = None

    def fit(self, X, y):
        self.model = self.model_class(**self.kwargs)
        self.model.fit(X, y)
        if hasattr(self.model, "classes_"):
            self.classes_ = self.model.classes_
        return self

    def predict(self, X):
        return self.model.predict(X)

    def predict_proba(self, X):
        return self.model.predict_proba(X)

    @property
    def feature_importances_(self):
        return self.model.feature_importances_ if self.model else None

    def get_params(self, deep=True):
        return {"model_class": self.model_class, **self.kwargs}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            if parameter == "model_class":
                self.model_class = value
            else:
                self.kwargs[parameter] = value
        return self

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.estimator_type = "classifier"
        return tags


# ==========================================
# Training
# ==========================================

def build_models(cfg: TrainingConfig) -> Dict[str, BaseEstimator]:
    """Return model dictionary matching your original settings."""
    return {
        "RandomForest": RandomForestClassifier(
            n_estimators=cfg.rf_n_estimators,
            n_jobs=-1,
            random_state=cfg.random_state,
            class_weight="balanced",
        ),
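        # Note: XGBClassifier has no class_weight argument, so unlike the other
        # three models it trains unweighted here; passing sample_weight in fit()
        # would be the usual workaround if class imbalance becomes a problem.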
        "XGBoost": Sklearn6CompatibilityWrapper(
            model_class=xgb.XGBClassifier,
            n_estimators=cfg.xgb_n_estimators,
            learning_rate=cfg.xgb_learning_rate,
            max_depth=cfg.xgb_max_depth,
            subsample=cfg.xgb_subsample,
            colsample_bytree=cfg.xgb_colsample_bytree,
            eval_metric="mlogloss",
            n_jobs=-1,
            random_state=cfg.random_state,
        ),
        "LightGBM": Sklearn6CompatibilityWrapper(
            model_class=lgb.LGBMClassifier,
            n_estimators=cfg.lgb_n_estimators,
            learning_rate=cfg.lgb_learning_rate,
            num_leaves=cfg.lgb_num_leaves,
            subsample=cfg.lgb_subsample,
            colsample_bytree=cfg.lgb_colsample_bytree,
            min_child_samples=cfg.lgb_min_child_samples,
            class_weight="balanced",
            n_jobs=-1,
            random_state=cfg.random_state,
            verbose=-1,
        ),
        "CatBoost": Sklearn6CompatibilityWrapper(
            model_class=CatBoostClassifier,
            iterations=cfg.cb_iterations,
            learning_rate=cfg.cb_learning_rate,
            depth=cfg.cb_depth,
            verbose=0,
            random_seed=cfg.random_state,
            auto_class_weights="Balanced",
            allow_writing_files=False,
        ),
    }


def evaluate(y_true: np.ndarray, y_pred: np.ndarray, label_names: List[str]) -> Dict:
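    """Return accuracy, macro-F1, a per-class report, and the confusion matrix."""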
    acc = float(accuracy_score(y_true, y_pred))
    f1m = float(f1_score(y_true, y_pred, average="macro"))
    report = classification_report(
        y_true, y_pred, target_names=label_names, output_dict=True, zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred).tolist()
    return {"accuracy": acc, "f1_macro": f1m, "report": report, "confusion": cm}


def train_one_variant(
    df: pd.DataFrame,
    cfg: TrainingConfig,
    variant: str,
    out_dir: Path,
) -> Path:
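    """Train the base models plus the soft-voting ensemble for one variant and write artifacts."""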
    out_dir.mkdir(parents=True, exist_ok=True)

    df_clean = drop_junk_columns(df, cfg.junk_cols)
    if cfg.label_col not in df_clean.columns:
        raise ValueError(f"Missing label column '{cfg.label_col}'")

    X = df_clean.drop(columns=[cfg.label_col])
    y = df_clean[cfg.label_col]

    le = LabelEncoder()
    y_enc = le.fit_transform(y)
    class_names = le.classes_.tolist()
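
    # Stratified split keeps the crop-class proportions the same in the train
    # and test sets, which matters with imbalanced classes.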
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y_enc,
        test_size=cfg.test_size,
        random_state=cfg.random_state,
        stratify=y_enc,
    )
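
    # Feature selection is fitted on the training split only, so the held-out
    # test data never influences which columns are kept (avoids leakage).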
    selected_features = scout_feature_selection(
        X_train, y_train, n_estimators=cfg.scout_n_estimators, random_state=cfg.random_state
    )
    X_train = X_train[selected_features]
    X_test = X_test[selected_features]
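
    # Only the "Scaled" variant scales numeric features (via
    # features.scale_numeric_features); the "Raw" variant feeds the tree-based
    # models the engineered values as-is.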
    scaler = None
    if variant.lower() == "scaled":
        X_train, X_test, scaler = scale_numeric_features(X_train, X_test)

    models = build_models(cfg)
    metrics: Dict[str, Dict] = {}
    trained: Dict[str, BaseEstimator] = {}

    for name, model in models.items():
        model.fit(X_train, y_train)
        preds = model.predict(X_test)
        metrics[name] = evaluate(y_test, preds, class_names)
        trained[name] = model
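
    # Soft voting averages each base model's predicted class probabilities,
    # which is why every wrapped classifier must expose predict_proba.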
    ensemble = VotingClassifier(
        estimators=[(n, m) for n, m in trained.items()], voting="soft", n_jobs=-1
    )
    ensemble.fit(X_train, y_train)
    ens_preds = ensemble.predict(X_test)
    metrics["Ensemble"] = evaluate(y_test, ens_preds, class_names)

    # Persist artifacts needed for inference
    artifact_dir = out_dir / f"model_{variant.lower()}"
    artifact_dir.mkdir(parents=True, exist_ok=True)

    joblib.dump(ensemble, artifact_dir / "model.joblib")
    joblib.dump(le, artifact_dir / "label_encoder.joblib")

    if scaler is not None:
        joblib.dump(scaler, artifact_dir / "scaler.joblib")

    (artifact_dir / "selected_features.json").write_text(
        json.dumps(selected_features, indent=2)
    )

    meta = {
        "variant": variant,
        "class_names": class_names,
        "n_features": len(selected_features),
        "config": asdict(cfg),
    }
    (artifact_dir / "meta.json").write_text(json.dumps(meta, indent=2))
    (artifact_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))

    return artifact_dir


def maybe_upload_to_minio(artifact_dir: Path, cfg: TrainingConfig):
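    """Upload every artifact file to the configured MinIO bucket (no-op unless enabled)."""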
    if not cfg.upload_minio:
        return
    try:
        import boto3
        from botocore.client import Config
    except Exception as e:
        raise RuntimeError("boto3 is required for MinIO upload") from e

    s3 = boto3.client(
        "s3",
        endpoint_url=cfg.minio_endpoint,
        aws_access_key_id=cfg.minio_access_key,
        aws_secret_access_key=cfg.minio_secret_key,
        config=Config(signature_version="s3v4"),
        region_name="us-east-1",
    )

    # Ensure bucket exists
    try:
        s3.head_bucket(Bucket=cfg.minio_bucket)
    except Exception:
        s3.create_bucket(Bucket=cfg.minio_bucket)
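
    # Keys are written flat as "<prefix>/<artifact dir name>/<file name>"; note
    # that nested sub-directories inside artifact_dir would collide because only
    # p.name is used in the key.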
    prefix = f"{cfg.minio_prefix}/{artifact_dir.name}"
    for p in artifact_dir.rglob("*"):
        if p.is_file():
            key = f"{prefix}/{p.name}"
            s3.upload_file(str(p), cfg.minio_bucket, key)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True, help="CSV path: Zimbabwe_Crop_Engineered_Ready.csv")
    parser.add_argument("--out", required=True, help="Output directory for artifacts")
    parser.add_argument("--variant", choices=["Raw", "Scaled"], default="Raw")
    parser.add_argument("--upload-minio", action="store_true")
    args = parser.parse_args()

    cfg = TrainingConfig(
        upload_minio=args.upload_minio,
        minio_endpoint=os.getenv("MINIO_ENDPOINT", ""),
        minio_access_key=os.getenv("MINIO_ACCESS_KEY", ""),
        minio_secret_key=os.getenv("MINIO_SECRET_KEY", ""),
        minio_bucket=os.getenv("MINIO_BUCKET", "geocrop-models"),
        minio_prefix=os.getenv("MINIO_PREFIX", "models"),
    )

    df = pd.read_csv(args.data)
    out_dir = Path(args.out)

    artifact_dir = train_one_variant(df, cfg, args.variant, out_dir)
    maybe_upload_to_minio(artifact_dir, cfg)

    print(f"Saved artifacts to: {artifact_dir}")


if __name__ == "__main__":
    main()