197 lines
5.6 KiB
Python
197 lines
5.6 KiB
Python
"""Central configuration for GeoCrop.

This file keeps ALL constants and environment wiring in one place.

It also defines a StorageAdapter interface so you can swap:

- local filesystem (dev)
- MinIO S3 (prod)

Roo Code can extend this with:

- Zimbabwe polygon path
- DEA STAC collection/band config
- model registry
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from datetime import date
|
|
from pathlib import Path
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
|
|
# ==========================================
|
|
# Training config
|
|
# ==========================================
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Settings for a GeoCrop training run.

    Groups the dataset column names, the train/test split, the "scout"
    stage, the per-model hyperparameters (prefixes ``rf_``/``xgb_``/``lgb_``/
    ``cb_`` -- presumably RandomForest, XGBoost, LightGBM and CatBoost), and
    the optional MinIO artifact-upload settings.
    """

    # --- Dataset ---
    # Name of the target column.
    label_col: str = "label"
    # Non-feature columns (geometry, coordinates, IDs, bookkeeping flags);
    # presumably dropped before training -- confirm against the trainer.
    # default_factory gives every instance its own list, so mutating one
    # config's list cannot leak into another.
    junk_cols: list[str] = field(
        default_factory=lambda: [
            ".geo",
            "system:index",
            "latitude",
            "longitude",
            "lat",
            "lon",
            "ID",
            "parent_id",
            "batch_id",
            "is_syn",
        ]
    )

    # --- Split ---
    test_size: float = 0.2
    random_state: int = 42

    # --- Scout ---
    scout_n_estimators: int = 100

    # --- Models (match your original hyperparams) ---
    rf_n_estimators: int = 200

    xgb_n_estimators: int = 300
    xgb_learning_rate: float = 0.05
    xgb_max_depth: int = 7
    xgb_subsample: float = 0.8
    xgb_colsample_bytree: float = 0.8

    lgb_n_estimators: int = 800
    lgb_learning_rate: float = 0.03
    lgb_num_leaves: int = 63
    lgb_subsample: float = 0.8
    lgb_colsample_bytree: float = 0.8
    lgb_min_child_samples: int = 30

    cb_iterations: int = 500
    cb_learning_rate: float = 0.05
    cb_depth: int = 6

    # --- Artifact upload ---
    upload_minio: bool = False
    minio_endpoint: str = ""
    minio_access_key: str = ""
    # repr=False keeps the credential out of repr() output, so printing or
    # logging the config object does not leak it.
    minio_secret_key: str = field(default="", repr=False)
    minio_bucket: str = "geocrop-models"
    minio_prefix: str = "models"
|
|
|
|
|
|
# ==========================================
|
|
# Storage adapter interface and inference config
|
|
# ==========================================
|
|
|
|
|
|
class StorageAdapter:
    """Storage backend contract consumed by the inference code.

    Subclasses override the three I/O hooks; on this base class each of them
    just raises ``NotImplementedError``. ``write_layer_geotiff`` is a concrete
    helper inherited by every backend.

    Roo Code should implement a MinIO-backed adapter.
    """

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Fetch the bundle identified by ``model_key`` into ``dest_dir``."""
        raise NotImplementedError

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Give back a local filepath to the DW baseline COG for year/season.

        A production backend may fetch the file on demand or serve it from a
        shared mounted volume.
        """
        raise NotImplementedError

    def upload_result(self, local_path: Path, key: str) -> str:
        """Publish ``local_path`` under ``key``; return a URI (s3://... or https://signed-url)."""
        raise NotImplementedError

    def write_layer_geotiff(self, out_path: Path, arr, profile: dict):
        """Write ``arr`` -- shaped (H, W) or (H, W, 3) -- as a GeoTIFF using ``profile``.

        Raises:
            ValueError: if ``arr`` has any other shape.
        """
        import rasterio

        ndim = arr.ndim
        is_rgb = ndim == 3 and arr.shape[2] == 3
        if ndim != 2 and not is_rgb:
            raise ValueError("arr must be (H,W) or (H,W,3)")
        band_count = 1 if ndim == 2 else 3

        # Work on a copy so the caller's profile is not modified.
        out_profile = profile.copy()
        out_profile["count"] = band_count

        with rasterio.open(out_path, "w", **out_profile) as dst:
            if band_count == 3:
                # rasterio wants band-first data: (H, W, 3) -> (3, H, W).
                dst.write(arr.transpose(2, 0, 1))
            else:
                dst.write(arr, 1)
|
|
|
|
|
|
@dataclass
class InferenceConfig:
    """Settings for inference runs.

    Covers the request-radius limit, the seasonal date window, smoothing
    switches, the DEA STAC endpoints, and the storage backend.
    """

    # --- Constraints ---
    # Maximum allowed radius; the _m suffix suggests metres -- confirm.
    max_radius_m: float = 5000.0

    # --- Season window (YOU asked to use Sep -> May) ---
    # "year" is interpreted as the first calendar year of the season.
    # Example: year=2019 -> season 2019-09-01 to 2020-05-31
    summer_start_month: int = 9
    summer_start_day: int = 1
    summer_end_month: int = 5
    summer_end_day: int = 31

    smoothing_enabled: bool = True
    # NOTE(review): presumably the window size of the smoothing filter;
    # confirm whether the consumer requires an odd value.
    smoothing_kernel: int = 3

    # --- DEA STAC ---
    dea_root: str = "https://explorer.digitalearth.africa/stac"
    dea_search: str = "https://explorer.digitalearth.africa/stac/search"

    # --- Storage adapter ---
    # Defaults to None; a concrete backend must be supplied before use.
    storage: Optional["StorageAdapter"] = None

    def season_dates(self, year: int, season: str = "summer") -> Tuple[str, str]:
        """Return the (start, end) ISO-8601 dates of ``season`` starting in ``year``.

        The end date lands in ``year + 1`` when the window wraps past New Year
        (end month/day on or before start month/day, as with the default
        Sep -> May window); otherwise it lands in ``year`` itself.

        Args:
            year: First calendar year of the season.
            season: Season name; only "summer" (case-insensitive) is supported.

        Returns:
            Tuple of ``"YYYY-MM-DD"`` strings (start, end).

        Raises:
            ValueError: for an unsupported season, or if the configured
                month/day does not form a valid calendar date.
        """
        if season.lower() != "summer":
            raise ValueError("Only summer season supported for now")

        start_md = (self.summer_start_month, self.summer_start_day)
        end_md = (self.summer_end_month, self.summer_end_day)
        # Roll into the next year only when the window crosses New Year; a
        # window such as Jan -> May lies entirely within ``year``.
        end_year = year + 1 if end_md <= start_md else year

        start = date(year, *start_md)
        end = date(end_year, *end_md)
        return start.isoformat(), end.isoformat()
|
|
|
|
|
|
# ==========================================
|
|
# Example local dev adapter
|
|
# ==========================================
|
|
|
|
|
|
class LocalStorage(StorageAdapter):
    """Simple dev adapter using local filesystem.

    Layout under ``base_dir``:
      - ``models/<model_key>/`` -- model bundle files
      - ``dw/dw_<season>_<year>.tif`` -- DW baseline rasters
      - ``<key>`` -- destination of ``upload_result`` copies
    """

    def __init__(self, base_dir: str = "/data/geocrop"):
        self.base = Path(base_dir)
        # Create the expected layout eagerly so later lookups fail only for
        # genuinely missing files, not missing directories.
        self.base.mkdir(parents=True, exist_ok=True)
        for sub in ("results", "models", "dw"):
            (self.base / sub).mkdir(exist_ok=True)

    def _under_base(self, *parts: str) -> Path:
        """Join ``parts`` below the base dir, rejecting paths that would escape it.

        Joining an absolute path onto a ``Path`` silently discards the base,
        and ``..`` components climb out of it; both are refused. The check is
        lexical, so symlinks inside the base (e.g. a mounted ``dw`` volume)
        keep working.

        Raises:
            ValueError: if the joined relative path is absolute or contains ``..``.
        """
        rel = Path(*parts)
        if rel.is_absolute() or ".." in rel.parts:
            raise ValueError(f"Path must stay inside {self.base}: {rel}")
        return self.base / rel

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Copy the top-level files of ``models/<model_key>`` into ``dest_dir``.

        Subdirectories inside the bundle are not copied. ``dest_dir`` is
        created if needed and may be given as ``str`` or ``Path``.

        Raises:
            FileNotFoundError: if the bundle directory does not exist.
            ValueError: if ``model_key`` would point outside the base dir.
        """
        import shutil

        src = self._under_base("models", model_key)
        if not src.exists():
            raise FileNotFoundError(f"Missing local model bundle: {src}")
        dest_dir = Path(dest_dir)
        dest_dir.mkdir(parents=True, exist_ok=True)
        for item in src.iterdir():
            if item.is_file():
                try:
                    # Streamed copy: avoids loading large model files into memory.
                    shutil.copyfile(item, dest_dir / item.name)
                except shutil.SameFileError:
                    pass  # already in place; nothing to copy

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Return the path of ``dw/dw_<season>_<year>.tif`` under the base dir.

        Raises:
            FileNotFoundError: if that file does not exist.
            ValueError: if ``season`` would point outside the base dir.
        """
        p = self._under_base("dw", f"dw_{season}_{year}.tif")
        if not p.exists():
            raise FileNotFoundError(f"Missing DW baseline: {p}")
        return str(p)

    def upload_result(self, local_path: Path, key: str) -> str:
        """Copy ``local_path`` to ``<base_dir>/<key>`` and return a ``file://`` URI.

        Raises:
            ValueError: if ``key`` is absolute or contains ``..`` (which would
                otherwise write outside the base dir).
        """
        import shutil

        dest = self._under_base(key)
        dest.parent.mkdir(parents=True, exist_ok=True)
        try:
            # Streamed copy: result rasters can be large.
            shutil.copyfile(local_path, dest)
        except shutil.SameFileError:
            pass  # result was written directly into place
        return f"file://{dest}"
|