# geocrop-platform/apps/worker/config.py
"""Central configuration for GeoCrop.
This file keeps ALL constants and environment wiring in one place.
It also defines a StorageAdapter interface so you can swap:
- local filesystem (dev)
- MinIO S3 (prod)
Roo Code can extend this with:
- Zimbabwe polygon path
- DEA STAC collection/band config
- model registry
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from datetime import date
from pathlib import Path
from typing import Dict, Optional, Tuple
# ==========================================
# Training config
# ==========================================
@dataclass
class TrainingConfig:
    """Dataset settings and model hyperparameters for GeoCrop training.

    Hyperparameter values mirror the original training pipeline; the
    ``minio_*`` fields configure optional upload of trained artifacts.
    """

    # --- Dataset ---
    # Name of the target column in the training table.
    label_col: str = "label"
    # Non-feature columns dropped before training (ids, coordinates,
    # Earth Engine bookkeeping fields, batch metadata).
    junk_cols: list[str] = field(
        default_factory=lambda: [
            ".geo",
            "system:index",
            "latitude",
            "longitude",
            "lat",
            "lon",
            "ID",
            "parent_id",
            "batch_id",
            "is_syn",
        ]
    )
    # --- Split ---
    test_size: float = 0.2  # fraction of rows held out for evaluation
    random_state: int = 42  # seed for reproducible train/test splits
    # --- Scout (quick preliminary forest) ---
    scout_n_estimators: int = 100
    # --- Models (match your original hyperparams) ---
    # Random forest
    rf_n_estimators: int = 200
    # XGBoost
    xgb_n_estimators: int = 300
    xgb_learning_rate: float = 0.05
    xgb_max_depth: int = 7
    xgb_subsample: float = 0.8
    xgb_colsample_bytree: float = 0.8
    # LightGBM
    lgb_n_estimators: int = 800
    lgb_learning_rate: float = 0.03
    lgb_num_leaves: int = 63
    lgb_subsample: float = 0.8
    lgb_colsample_bytree: float = 0.8
    lgb_min_child_samples: int = 30
    # CatBoost
    cb_iterations: int = 500
    cb_learning_rate: float = 0.05
    cb_depth: int = 6
    # --- Artifact upload (disabled by default; credentials via env/config) ---
    upload_minio: bool = False
    minio_endpoint: str = ""
    minio_access_key: str = ""
    minio_secret_key: str = ""
    minio_bucket: str = "geocrop-models"
    minio_prefix: str = "models"
# ==========================================
# Inference config
# ==========================================
class StorageAdapter:
    """Abstract storage interface used by inference.

    Concrete adapters implement the three abstract methods below:
    a local-filesystem adapter for dev and a MinIO/S3 adapter for prod.
    ``write_layer_geotiff`` is a shared concrete helper.
    """

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Fetch the model artifact(s) named by ``model_key`` into ``dest_dir``."""
        raise NotImplementedError

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Return local filepath to DW baseline COG for given year/season.

        In prod you might download on-demand or mount a shared volume.
        """
        raise NotImplementedError

    def upload_result(self, local_path: Path, key: str) -> str:
        """Upload a file and return a URI (s3://... or https://signed-url)."""
        raise NotImplementedError

    def write_layer_geotiff(self, out_path: Path, arr, profile: dict):
        """Write ``arr`` as a GeoTIFF aligned to ``profile``.

        Accepts a single band as (H, W) or a band stack as (H, W, C) for
        any channel count C (generalized from the previous 3-band-only
        restriction). Shape validation runs *before* importing rasterio,
        so a bad array fails with ValueError even when rasterio is absent.

        Raises:
            ValueError: if ``arr`` is not 2-D or 3-D.
        """
        # Validate first: an import failure must not mask a usage error.
        if arr.ndim == 2:
            count = 1
        elif arr.ndim == 3:
            count = arr.shape[2]
        else:
            raise ValueError("arr must be (H,W) or (H,W,C)")
        import rasterio

        prof = profile.copy()
        prof.update({"count": count})
        with rasterio.open(out_path, "w", **prof) as dst:
            if arr.ndim == 2:
                dst.write(arr, 1)
            else:
                # (H,W,C) -> (C,H,W): rasterio expects band-major layout
                dst.write(arr.transpose(2, 0, 1))
class MinIOStorage(StorageAdapter):
    """Production storage adapter backed by MinIO's S3-compatible API.

    Responsibilities:
      - pull model artifacts from the models bucket
      - expose DW baseline COGs as VSI s3:// paths
      - push result rasters to the results bucket
      - mint presigned download URLs
    """

    def __init__(
        self,
        endpoint: str = "minio.geocrop.svc.cluster.local:9000",
        access_key: str = None,
        secret_key: str = None,
        bucket_models: str = "geocrop-models",
        bucket_baselines: str = "geocrop-baselines",
        bucket_results: str = "geocrop-results",
    ):
        self.endpoint = endpoint
        # Fall back to env vars, then to the stock MinIO credentials,
        # when no explicit keys are supplied.
        self.access_key = access_key or os.getenv("MINIO_ACCESS_KEY", "minioadmin")
        self.secret_key = secret_key or os.getenv("MINIO_SECRET_KEY", "minioadmin")
        self.bucket_models = bucket_models
        self.bucket_baselines = bucket_baselines
        self.bucket_results = bucket_results
        # boto3 client is built lazily on first use (see ``s3`` property).
        self._s3_client = None

    @property
    def s3(self):
        """Build the boto3 client on first access and cache it."""
        if self._s3_client is not None:
            return self._s3_client
        import boto3
        from botocore.config import Config

        self._s3_client = boto3.client(
            "s3",
            endpoint_url=f"http://{self.endpoint}",
            aws_access_key_id=self.access_key,
            aws_secret_access_key=self.secret_key,
            config=Config(signature_version="s3v4"),
            region_name="us-east-1",
        )
        return self._s3_client

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Download a model file from the models bucket.

        Args:
            model_key: Full key including prefix
                (e.g., "models/Zimbabwe_Ensemble_Raw_Model.pkl").
            dest_dir: Local directory to save files.

        Raises:
            FileNotFoundError: if the object cannot be fetched.
        """
        target_dir = Path(dest_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        # The local filename is the last key component.
        target = target_dir / Path(model_key).name
        try:
            print(f" Downloading s3://{self.bucket_models}/{model_key} -> {target}")
            self.s3.download_file(self.bucket_models, model_key, str(target))
        except Exception as e:
            raise FileNotFoundError(f"Failed to download model {model_key}: {e}") from e

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Return a VSI s3:// path for the DW baseline COG of a season.

        Args:
            year: Season start year (e.g., 2021 for the 2021-2022 season).
            season: Season type ("summer"); not used in the key layout.

        Returns:
            VSI S3 path such as
            "s3://geocrop-baselines/DW_Zim_HighestConf_2021_2022".
        """
        # Key layout is DW_Zim_HighestConf_{year}_{year+1}; the stored
        # objects may carry tile suffixes (-0000000000-...), which
        # rasterio resolves from this prefix.
        base_key = f"DW_Zim_HighestConf_{year}_{year + 1}"
        return f"s3://{self.bucket_baselines}/{base_key}"

    def upload_result(self, local_path: Path, key: str) -> str:
        """Upload a result file to the results bucket.

        Args:
            local_path: Local file path.
            key: S3 key (e.g., "results/refined_2022.tif").

        Returns:
            S3 URI of the uploaded object.

        Raises:
            RuntimeError: if the upload fails.
        """
        src = Path(local_path)
        try:
            self.s3.upload_file(str(src), self.bucket_results, key)
        except Exception as e:
            raise RuntimeError(f"Failed to upload {src}: {e}") from e
        return f"s3://{self.bucket_results}/{key}"

    def generate_presigned_url(self, bucket: str, key: str, expires: int = 3600) -> str:
        """Return a time-limited GET URL for ``key`` in ``bucket``.

        Args:
            bucket: Bucket name.
            key: S3 key.
            expires: URL expiration in seconds (default one hour).

        Raises:
            RuntimeError: if URL generation fails.
        """
        try:
            return self.s3.generate_presigned_url(
                "get_object",
                Params={"Bucket": bucket, "Key": key},
                ExpiresIn=expires,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to generate presigned URL: {e}") from e
@dataclass
class InferenceConfig:
    """Runtime settings for the GeoCrop inference pipeline."""

    # --- Constraints ---
    max_radius_m: float = 5000.0  # maximum AOI radius in meters
    # --- Season window (Sep -> May) ---
    # "year" means the first year in the season, so year=2019 spans
    # 2019-09-01 through 2020-05-31.
    summer_start_month: int = 9
    summer_start_day: int = 1
    summer_end_month: int = 5
    summer_end_day: int = 31
    # --- Post-processing ---
    smoothing_enabled: bool = True
    smoothing_kernel: int = 3
    # --- DEA STAC endpoints ---
    dea_root: str = "https://explorer.digitalearth.africa/stac"
    dea_search: str = "https://explorer.digitalearth.africa/stac/search"
    dea_stac_url: str = "https://explorer.digitalearth.africa/stac"
    # --- Storage adapter (injected by the caller; None until configured) ---
    storage: StorageAdapter = None

    def season_dates(self, year: int, season: str = "summer") -> Tuple[str, str]:
        """Return (start, end) ISO dates for the season beginning in ``year``.

        Raises:
            ValueError: for any season other than "summer".
        """
        if season.lower() != "summer":
            raise ValueError("Only summer season supported for now")
        start_iso = date(year, self.summer_start_month, self.summer_start_day).isoformat()
        end_iso = date(year + 1, self.summer_end_month, self.summer_end_day).isoformat()
        return start_iso, end_iso
# ==========================================
# Example local dev adapter
# ==========================================
class LocalStorage(StorageAdapter):
    """Simple dev adapter using the local filesystem.

    Layout under ``base_dir``:
        models/   model bundles (a directory of files, or a single file)
        dw/       Dynamic World baseline rasters
        results/  uploaded outputs
    """

    def __init__(self, base_dir: str = "/data/geocrop"):
        self.base = Path(base_dir)
        self.base.mkdir(parents=True, exist_ok=True)
        for sub in ("results", "models", "dw"):
            (self.base / sub).mkdir(exist_ok=True)

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Copy a model bundle from the local models dir into ``dest_dir``.

        Accepts either a directory bundle (original behavior) or a single
        model file — the latter matches MinIOStorage, whose keys address
        one object (e.g. "Zimbabwe_Ensemble_Raw_Model.pkl"), so the two
        adapters are interchangeable for file-style keys.

        Raises:
            FileNotFoundError: if ``model_key`` does not exist locally.
        """
        src = self.base / "models" / model_key
        if not src.exists():
            raise FileNotFoundError(f"Missing local model bundle: {src}")
        dest_dir = Path(dest_dir)
        dest_dir.mkdir(parents=True, exist_ok=True)
        if src.is_file():
            # Single-file "bundle": mirror MinIOStorage semantics.
            (dest_dir / src.name).write_bytes(src.read_bytes())
            return
        for p in src.iterdir():
            if p.is_file():
                (dest_dir / p.name).write_bytes(p.read_bytes())

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Return the path to the DW baseline GeoTIFF, failing fast if absent."""
        p = self.base / "dw" / f"dw_{season}_{year}.tif"
        if not p.exists():
            raise FileNotFoundError(f"Missing DW baseline: {p}")
        return str(p)

    def upload_result(self, local_path: Path, key: str) -> str:
        """Copy ``local_path`` to ``base/key`` and return a file:// URI."""
        # Coerce to Path for parity with MinIOStorage, which accepts str too.
        local_path = Path(local_path)
        dest = self.base / key
        dest.parent.mkdir(parents=True, exist_ok=True)
        dest.write_bytes(local_path.read_bytes())
        return f"file://{dest}"