336 lines
10 KiB
Python
336 lines
10 KiB
Python
"""Central configuration for GeoCrop.
|
|
|
|
This file keeps ALL constants and environment wiring in one place.
|
|
It also defines a StorageAdapter interface (with LocalStorage and MinIOStorage implementations) so you can swap between:
|
|
- local filesystem (dev)
|
|
- MinIO S3 (prod)
|
|
|
|
Roo Code can extend this with:
|
|
- Zimbabwe polygon path
|
|
- DEA STAC collection/band config
|
|
- model registry
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from datetime import date
|
|
from pathlib import Path
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
|
|
# ==========================================
|
|
# Training config
|
|
# ==========================================
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Settings for the offline training pipeline.

    Every field has a default, so ``TrainingConfig()`` gives the baseline
    configuration; override individual fields with keyword arguments.
    Secret fields are excluded from ``repr`` so configs can be logged safely.
    """

    # --- Dataset ---
    # Name of the target column.
    label_col: str = "label"
    # Columns treated as non-features (geometry, index, coordinate and
    # bookkeeping columns). Presumably dropped before fitting — the consumer
    # is not in this file. default_factory gives each instance its own list.
    junk_cols: list = field(
        default_factory=lambda: [
            ".geo",
            "system:index",
            "latitude",
            "longitude",
            "lat",
            "lon",
            "ID",
            "parent_id",
            "batch_id",
            "is_syn",
        ]
    )

    # --- Split ---
    test_size: float = 0.2
    random_state: int = 42

    # --- Scout ---
    scout_n_estimators: int = 100

    # --- Models (match your original hyperparams) ---
    # Prefixes presumably map to RandomForest (rf), XGBoost (xgb),
    # LightGBM (lgb) and CatBoost (cb) — confirm against the trainer.
    rf_n_estimators: int = 200

    xgb_n_estimators: int = 300
    xgb_learning_rate: float = 0.05
    xgb_max_depth: int = 7
    xgb_subsample: float = 0.8
    xgb_colsample_bytree: float = 0.8

    lgb_n_estimators: int = 800
    lgb_learning_rate: float = 0.03
    lgb_num_leaves: int = 63
    lgb_subsample: float = 0.8
    lgb_colsample_bytree: float = 0.8
    lgb_min_child_samples: int = 30

    cb_iterations: int = 500
    cb_learning_rate: float = 0.05
    cb_depth: int = 6

    # --- Artifact upload ---
    upload_minio: bool = False
    minio_endpoint: str = ""
    minio_access_key: str = ""
    # repr=False: keep the secret out of the auto-generated __repr__ so that
    # printing or logging the config does not leak credentials.
    minio_secret_key: str = field(default="", repr=False)
    minio_bucket: str = "geocrop-models"
    minio_prefix: str = "models"
|
|
|
|
|
|
# ==========================================
|
|
# Inference config
|
|
# ==========================================
|
|
|
|
|
|
class StorageAdapter:
    """Interface for the storage backend used by inference.

    Subclasses override ``download_model_bundle``, ``get_dw_local_path`` and
    ``upload_result``; the base versions only raise ``NotImplementedError``.
    ``write_layer_geotiff`` is a concrete helper shared by every backend.
    Implementations in this module: ``LocalStorage`` (dev filesystem) and
    ``MinIOStorage`` (MinIO/S3).
    """

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Fetch the model artifact(s) identified by *model_key* into *dest_dir*."""
        raise NotImplementedError

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Return a filepath/URI to the DW baseline COG for given year/season.

        In prod you might download on-demand or mount a shared volume.
        """
        raise NotImplementedError

    def upload_result(self, local_path: Path, key: str) -> str:
        """Upload a file and return a URI (s3://... or https://signed-url)."""
        raise NotImplementedError

    def write_layer_geotiff(self, out_path: Path, arr, profile: dict):
        """Write a 1-band or 3-band GeoTIFF aligned to *profile*.

        Args:
            out_path: Destination file path.
            arr: Array of shape ``(H, W)`` (written as band 1) or
                ``(H, W, 3)`` (band-last, written as bands 1-3).
            profile: Keyword arguments forwarded to ``rasterio.open`` in write
                mode. It is copied, not mutated; only ``count`` is overridden.

        Raises:
            ValueError: If *arr* has any other shape. The check runs before
                rasterio is imported, so bad input fails fast and with the
                same error whether or not rasterio is installed.
        """
        if arr.ndim == 2:
            count = 1
        elif arr.ndim == 3 and arr.shape[2] == 3:
            count = 3
        else:
            raise ValueError("arr must be (H,W) or (H,W,3)")

        # Deferred import: rasterio is only needed once we actually write.
        import rasterio

        prof = profile.copy()
        prof.update({"count": count})

        with rasterio.open(out_path, "w", **prof) as dst:
            if count == 1:
                dst.write(arr, 1)
            else:
                # rasterio wants band-first data: (H,W,3) -> (3,H,W).
                dst.write(arr.transpose(2, 0, 1))
|
|
|
|
|
|
class MinIOStorage(StorageAdapter):
    """MinIO/S3-backed storage adapter for production.

    Supports:
      - model artifact download (from ``bucket_models``, default ``geocrop-models``)
      - DW baseline URIs (in ``bucket_baselines``, default ``geocrop-baselines``)
      - result uploads (to ``bucket_results``, default ``geocrop-results``)
      - presigned download URL generation

    The boto3 client is created lazily on first use of ``s3``, so constructing
    an instance neither imports boto3 nor touches the network.
    """

    def __init__(
        self,
        endpoint: str = "minio.geocrop.svc.cluster.local:9000",
        access_key: Optional[str] = None,
        secret_key: Optional[str] = None,
        bucket_models: str = "geocrop-models",
        bucket_baselines: str = "geocrop-baselines",
        bucket_results: str = "geocrop-results",
        secure: bool = False,
    ):
        """Store connection settings; no client is created yet.

        Args:
            endpoint: S3 API address, either ``host:port`` or a full URL that
                already includes an ``http://`` / ``https://`` scheme.
            access_key: Access key; if falsy, ``MINIO_ACCESS_KEY`` from the
                environment, else ``"minioadmin"``.
            secret_key: Secret key; if falsy, ``MINIO_SECRET_KEY`` from the
                environment, else ``"minioadmin"``.
            bucket_models: Bucket holding model artifacts.
            bucket_baselines: Bucket holding DW baseline rasters.
            bucket_results: Bucket receiving inference outputs.
            secure: Use ``https`` when *endpoint* has no scheme. Ignored when
                *endpoint* already specifies one.
        """
        self.endpoint = endpoint
        # NOTE(review): "minioadmin" is MinIO's stock root credential; silently
        # falling back to it is handy in dev but should not happen in prod.
        self.access_key = access_key or os.getenv("MINIO_ACCESS_KEY", "minioadmin")
        self.secret_key = secret_key or os.getenv("MINIO_SECRET_KEY", "minioadmin")
        self.bucket_models = bucket_models
        self.bucket_baselines = bucket_baselines
        self.bucket_results = bucket_results
        self.secure = secure

        # Lazily created boto3 client (see the ``s3`` property).
        self._s3_client = None

    def _endpoint_url(self) -> str:
        """Return ``endpoint`` as a URL, adding a scheme only if it has none."""
        endpoint = self.endpoint.strip().rstrip("/")
        if "://" in endpoint:
            # Already a URL; prefixing again would give "http://http://...".
            return endpoint
        scheme = "https" if self.secure else "http"
        return f"{scheme}://{endpoint}"

    @property
    def s3(self):
        """boto3 S3 client for this endpoint, created on first access."""
        if self._s3_client is None:
            import boto3
            from botocore.config import Config

            self._s3_client = boto3.client(
                "s3",
                endpoint_url=self._endpoint_url(),
                aws_access_key_id=self.access_key,
                aws_secret_access_key=self.secret_key,
                config=Config(signature_version="s3v4"),
                # Fixed region for SigV4 signing; presumably matches the MinIO
                # server's region (MinIO's default is us-east-1) — confirm.
                region_name="us-east-1",
            )
        return self._s3_client

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Download one model object from ``bucket_models`` into *dest_dir*.

        Args:
            model_key: Full object key including any prefix, e.g.
                ``"models/Zimbabwe_Ensemble_Raw_Model.pkl"``.
            dest_dir: Local directory (created if missing). The file is saved
                under the key's basename.

        Raises:
            FileNotFoundError: If the download fails for any reason (missing
                key, bad credentials, boto3 unavailable, network error, ...).
        """
        dest_dir = Path(dest_dir)
        dest_dir.mkdir(parents=True, exist_ok=True)
        local_path = dest_dir / Path(model_key).name

        print(f" Downloading s3://{self.bucket_models}/{model_key} -> {local_path}")
        try:
            self.s3.download_file(self.bucket_models, model_key, str(local_path))
        except Exception as e:
            # Normalised to FileNotFoundError, the same error LocalStorage
            # raises for a missing bundle.
            raise FileNotFoundError(f"Failed to download model {model_key}: {e}") from e

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Return an ``s3://`` URI for the DW baseline of the season starting in *year*.

        The result is ``s3://<bucket_baselines>/DW_Zim_HighestConf_{year}_{year+1}``
        (e.g. year=2021 -> ``.../DW_Zim_HighestConf_2021_2022``). It is built
        from the naming convention only: the object is not looked up, no
        ``.tif`` suffix is appended, and *season* is currently ignored.

        NOTE(review): stored files may carry tile suffixes such as
        ``-0000000000-0000000000.tif``. rasterio/GDAL do not expand wildcards,
        so this prefix is presumably not directly openable — TODO resolve the
        concrete tile key(s) (e.g. by listing the prefix).
        """
        base_key = f"DW_Zim_HighestConf_{year}_{year + 1}"
        return f"s3://{self.bucket_baselines}/{base_key}"

    def upload_result(self, local_path: Path, key: str) -> str:
        """Upload *local_path* to ``bucket_results`` under *key*.

        Args:
            local_path: Local file path.
            key: Object key (e.g. ``"results/refined_2022.tif"``).

        Returns:
            The ``s3://<bucket_results>/<key>`` URI of the uploaded object.

        Raises:
            RuntimeError: If the upload fails for any reason.
        """
        local_path = Path(local_path)
        try:
            self.s3.upload_file(str(local_path), self.bucket_results, key)
        except Exception as e:
            raise RuntimeError(f"Failed to upload {local_path}: {e}") from e
        return f"s3://{self.bucket_results}/{key}"

    def generate_presigned_url(self, bucket: str, key: str, expires: int = 3600) -> str:
        """Return a presigned GET URL for ``bucket``/``key``.

        Args:
            bucket: Bucket name.
            key: Object key.
            expires: URL lifetime in seconds.

        Raises:
            RuntimeError: If URL generation fails.

        NOTE(review): the URL embeds this adapter's endpoint, so it is only
        usable by clients that can reach that host.
        """
        try:
            return self.s3.generate_presigned_url(
                "get_object",
                Params={"Bucket": bucket, "Key": key},
                ExpiresIn=expires,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to generate presigned URL: {e}") from e
|
|
|
|
|
|
@dataclass
class InferenceConfig:
    """Runtime settings for inference: area limit, season window, smoothing,
    DEA STAC endpoints and the storage backend."""

    # --- Constraints ---
    # Maximum radius in metres (per the "_m" suffix). Not enforced in this class.
    max_radius_m: float = 5000.0

    # --- Season window (Sep -> May) ---
    # "year" is the first calendar year of the season.
    # Example: year=2019 -> season 2019-09-01 to 2020-05-31
    summer_start_month: int = 9
    summer_start_day: int = 1
    summer_end_month: int = 5
    summer_end_day: int = 31

    # --- Post-processing ---
    # Presumably a square smoothing window size in pixels — consumer not in this file.
    smoothing_enabled: bool = True
    smoothing_kernel: int = 3

    # --- DEA STAC ---
    # dea_root and dea_stac_url currently hold the same URL; both names are kept.
    dea_root: str = "https://explorer.digitalearth.africa/stac"
    dea_search: str = "https://explorer.digitalearth.africa/stac/search"
    dea_stac_url: str = "https://explorer.digitalearth.africa/stac"

    # --- Storage adapter (None until a backend is assigned) ---
    storage: Optional[StorageAdapter] = None

    def season_dates(self, year: int, season: str = "summer") -> Tuple[str, str]:
        """Return the ``(start, end)`` ISO dates of the season starting in *year*.

        The end date falls in ``year + 1`` when the configured window wraps
        past the new year (end month/day earlier than start month/day, as with
        the default Sep -> May window); otherwise it falls in ``year``.

        Args:
            year: First calendar year of the season.
            season: Only ``"summer"`` (case-insensitive) is supported.

        Returns:
            ``(start, end)`` as ``YYYY-MM-DD`` strings.

        Raises:
            ValueError: For an unsupported season, or if the configured
                month/day values do not form a valid date.
        """
        if season.lower() != "summer":
            raise ValueError("Only summer season supported for now")

        start = date(year, self.summer_start_month, self.summer_start_day)
        # Tuple comparison orders (month, day) chronologically within a year.
        wraps = (self.summer_end_month, self.summer_end_day) < (
            self.summer_start_month,
            self.summer_start_day,
        )
        end_year = year + 1 if wraps else year
        end = date(end_year, self.summer_end_month, self.summer_end_day)
        return start.isoformat(), end.isoformat()
|
|
|
|
|
|
# ==========================================
|
|
# Example local dev adapter
|
|
# ==========================================
|
|
|
|
|
|
class LocalStorage(StorageAdapter):
    """Simple dev adapter using the local filesystem.

    On construction, *base_dir* and its ``results/``, ``models/`` and ``dw/``
    subdirectories are created if missing. Model bundles are read from
    ``models/``, DW baselines from ``dw/dw_{season}_{year}.tif``, and
    uploads are copied to ``<base>/<key>``.
    """

    def __init__(self, base_dir: str = "/data/geocrop"):
        self.base = Path(base_dir)
        self.base.mkdir(parents=True, exist_ok=True)
        for sub in ("results", "models", "dw"):
            (self.base / sub).mkdir(exist_ok=True)

    @staticmethod
    def _copy_file(src: Path, dst: Path) -> None:
        """Stream-copy *src* to *dst*; copying a file onto itself is a no-op."""
        import shutil

        try:
            # copyfile copies in chunks instead of loading the whole file.
            shutil.copyfile(src, dst)
        except shutil.SameFileError:
            # Nothing to do: src and dst are the same file.
            pass

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Copy ``<base>/models/<model_key>`` into *dest_dir*.

        *model_key* may name a directory (every regular file directly inside
        it is copied; subdirectories are skipped) or a single file.

        NOTE(review): MinIOStorage documents keys that include a "models/"
        prefix; here the key is joined under ``models/`` — confirm callers
        pass the right form for each backend.

        Raises:
            FileNotFoundError: If the source path does not exist.
        """
        src = self.base / "models" / model_key
        if not src.exists():
            raise FileNotFoundError(f"Missing local model bundle: {src}")

        dest_dir = Path(dest_dir)
        dest_dir.mkdir(parents=True, exist_ok=True)

        if src.is_file():
            sources = [src]
        else:
            sources = [p for p in src.iterdir() if p.is_file()]
        for p in sources:
            self._copy_file(p, dest_dir / p.name)

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Return the path of ``<base>/dw/dw_{season}_{year}.tif``.

        Raises:
            FileNotFoundError: If that file does not exist.
        """
        p = self.base / "dw" / f"dw_{season}_{year}.tif"
        if not p.exists():
            raise FileNotFoundError(f"Missing DW baseline: {p}")
        return str(p)

    def upload_result(self, local_path: Path, key: str) -> str:
        """Copy *local_path* to ``<base>/<key>`` and return a ``file://`` URI.

        Parent directories are created as needed.
        NOTE(review): *key* is not sanitised; an absolute key or one containing
        ".." resolves outside the base directory — fine for dev, confirm
        before exposing to external input.
        """
        dest = self.base / key
        dest.parent.mkdir(parents=True, exist_ok=True)
        self._copy_file(Path(local_path), dest)
        return f"file://{dest}"
|