geocrop-platform./training/config.py

197 lines
5.6 KiB
Python

"""Central configuration for GeoCrop.
This file keeps ALL constants and environment wiring in one place.
It also defines a StorageAdapter interface so you can swap:
- local filesystem (dev)
- MinIO S3 (prod)
Roo Code can extend this with:
- Zimbabwe polygon path
- DEA STAC collection/band config
- model registry
"""
from __future__ import annotations

import os
import shutil
from dataclasses import dataclass, field
from datetime import date
from pathlib import Path
from typing import Dict, Optional, Tuple
# ==========================================
# Training config
# ==========================================
@dataclass
class TrainingConfig:
    """Settings for a training run: dataset columns, split, models, upload.

    All fields have defaults, so ``TrainingConfig()`` is a valid config and
    individual values can be overridden by keyword.

    ``minio_secret_key`` is excluded from the generated ``repr`` so that
    printing or logging a config does not leak the credential. It is still
    a normal init argument and still takes part in equality comparison.
    """

    # Dataset
    label_col: str = "label"
    # NOTE(review): presumably non-feature columns to discard before
    # training -- confirm against the training script.
    junk_cols: list = field(
        default_factory=lambda: [
            ".geo",
            "system:index",
            "latitude",
            "longitude",
            "lat",
            "lon",
            "ID",
            "parent_id",
            "batch_id",
            "is_syn",
        ]
    )

    # Split
    test_size: float = 0.2
    random_state: int = 42

    # Scout
    scout_n_estimators: int = 100

    # Models (match your original hyperparams).
    # NOTE(review): prefixes presumably map to random forest (rf_),
    # XGBoost (xgb_), LightGBM (lgb_) and CatBoost (cb_) -- confirm.
    rf_n_estimators: int = 200

    xgb_n_estimators: int = 300
    xgb_learning_rate: float = 0.05
    xgb_max_depth: int = 7
    xgb_subsample: float = 0.8
    xgb_colsample_bytree: float = 0.8

    lgb_n_estimators: int = 800
    lgb_learning_rate: float = 0.03
    lgb_num_leaves: int = 63
    lgb_subsample: float = 0.8
    lgb_colsample_bytree: float = 0.8
    lgb_min_child_samples: int = 30

    cb_iterations: int = 500
    cb_learning_rate: float = 0.05
    cb_depth: int = 6

    # Artifact upload
    upload_minio: bool = False
    minio_endpoint: str = ""
    minio_access_key: str = ""
    # repr=False keeps the secret out of repr()/str() output and logs.
    minio_secret_key: str = field(default="", repr=False)
    minio_bucket: str = "geocrop-models"
    minio_prefix: str = "models"
# ==========================================
# Inference config
# ==========================================
class StorageAdapter:
    """Storage interface that the inference code talks to.

    The three backend hooks (`download_model_bundle`, `get_dw_local_path`,
    `upload_result`) raise ``NotImplementedError`` here and are meant to be
    overridden, e.g. by a MinIO-backed adapter. `write_layer_geotiff` is a
    concrete helper available to every adapter.
    """

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Backend hook for fetching the bundle ``model_key`` into ``dest_dir``."""
        raise NotImplementedError

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Backend hook: local filepath of the DW baseline COG for year/season.

        A production adapter might download on demand or read from a
        mounted shared volume.
        """
        raise NotImplementedError

    def upload_result(self, local_path: Path, key: str) -> str:
        """Backend hook: upload a file, return its URI (s3://... or a signed URL)."""
        raise NotImplementedError

    def write_layer_geotiff(self, out_path: Path, arr, profile: dict):
        """Write ``arr`` to ``out_path`` as a 1-band or 3-band raster.

        ``arr`` must be 2-D ``(H, W)`` or 3-D ``(H, W, 3)``. ``profile`` is
        copied, its ``count`` is overridden with the band count, and the
        result is passed to ``rasterio.open`` as keyword arguments.

        Raises:
            ValueError: if ``arr`` has any other dimensionality/shape.
        """
        # Imported here rather than at module level, so the dependency is
        # only required when a layer is actually written.
        import rasterio

        single_band = arr.ndim == 2
        if not single_band and not (arr.ndim == 3 and arr.shape[2] == 3):
            raise ValueError("arr must be (H,W) or (H,W,3)")

        out_profile = profile.copy()
        out_profile.update(count=1 if single_band else 3)

        with rasterio.open(out_path, "w", **out_profile) as sink:
            if not single_band:
                # rasterio expects band-first layout: (H,W,3) -> (3,H,W)
                sink.write(arr.transpose(2, 0, 1))
            else:
                sink.write(arr, 1)
@dataclass
class InferenceConfig:
    """Settings for an inference run: limits, season window, STAC, storage.

    The season window is configured as a start month/day and an end
    month/day. `season_dates` turns it into concrete ISO dates for a year.
    """

    # Constraints
    # NOTE(review): "_m" presumably means metres -- confirm with callers.
    max_radius_m: float = 5000.0

    # Season window (September -> May by default).
    # "year" is the first calendar year of the season,
    # e.g. year=2019 -> 2019-09-01 to 2020-05-31.
    summer_start_month: int = 9
    summer_start_day: int = 1
    summer_end_month: int = 5
    summer_end_day: int = 31

    smoothing_enabled: bool = True
    # NOTE(review): presumably the smoothing window size -- confirm.
    smoothing_kernel: int = 3

    # DEA STAC
    dea_root: str = "https://explorer.digitalearth.africa/stac"
    dea_search: str = "https://explorer.digitalearth.africa/stac/search"

    # Storage adapter. None until one is supplied, hence Optional; the
    # name is quoted so the annotation is never evaluated eagerly.
    storage: Optional["StorageAdapter"] = None

    def season_dates(self, year: int, season: str = "summer") -> Tuple[str, str]:
        """Return ``(start, end)`` ISO dates of the season starting in ``year``.

        The end date falls in ``year + 1`` when the configured end month/day
        is not after the start month/day (the window crosses New Year, as
        with the default Sep -> May). Otherwise it falls in ``year`` itself.

        Args:
            year: calendar year in which the season starts.
            season: season name, case-insensitive; only "summer" is accepted.

        Raises:
            ValueError: for any season other than "summer", or (from
                ``datetime.date``) when the configured month/day is not a
                valid calendar date.
        """
        if season.lower() != "summer":
            raise ValueError("Only summer season supported for now")

        start_md = (self.summer_start_month, self.summer_start_day)
        end_md = (self.summer_end_month, self.summer_end_day)
        # Only roll into the next year when the window wraps past Dec 31.
        end_year = year + 1 if end_md <= start_md else year

        start = date(year, *start_md)
        end = date(end_year, *end_md)
        return start.isoformat(), end.isoformat()
# ==========================================
# Example local dev adapter
# ==========================================
class LocalStorage(StorageAdapter):
    """Development storage adapter backed by the local filesystem.

    Layout under ``base_dir``:
    - ``models/<model_key>/``       files of one model bundle
    - ``dw/dw_<season>_<year>.tif`` DW baseline rasters
    - ``results/``                  created up front

    Files are copied with ``shutil.copyfile`` so large bundles and rasters
    are streamed instead of being read fully into memory.
    """

    def __init__(self, base_dir: str = "/data/geocrop"):
        """Create ``base_dir`` and its standard sub-directories if missing."""
        self.base = Path(base_dir)
        self.base.mkdir(parents=True, exist_ok=True)
        for sub in ("results", "models", "dw"):
            (self.base / sub).mkdir(exist_ok=True)

    @staticmethod
    def _copy_file(src: Path, dst: Path) -> None:
        """Copy file contents from ``src`` to ``dst`` without buffering it all."""
        try:
            shutil.copyfile(src, dst)
        except shutil.SameFileError:
            # Source and destination are the same file: nothing to copy.
            pass

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Copy every regular file of ``models/<model_key>`` into ``dest_dir``.

        Sub-directories of the bundle are skipped. ``dest_dir`` is created
        when it does not exist.

        Raises:
            FileNotFoundError: if the bundle directory does not exist.
        """
        src = self.base / "models" / model_key
        if not src.exists():
            raise FileNotFoundError(f"Missing local model bundle: {src}")

        dest_dir = Path(dest_dir)
        dest_dir.mkdir(parents=True, exist_ok=True)
        for entry in src.iterdir():
            if entry.is_file():
                self._copy_file(entry, dest_dir / entry.name)

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Return the path of ``dw/dw_<season>_<year>.tif`` as a string.

        Raises:
            FileNotFoundError: if that file does not exist.
        """
        p = self.base / "dw" / f"dw_{season}_{year}.tif"
        if not p.exists():
            raise FileNotFoundError(f"Missing DW baseline: {p}")
        return str(p)

    def upload_result(self, local_path: Path, key: str) -> str:
        """Copy ``local_path`` to ``<base_dir>/<key>`` and return a file:// URI.

        Parent directories of the destination are created as needed.
        """
        dest = self.base / key
        dest.parent.mkdir(parents=True, exist_ok=True)
        self._copy_file(Path(local_path), dest)
        return f"file://{dest}"