# geocrop-platform/apps/worker/config.py
# (originally listed as: Python, 386 lines, 13 KiB)

"""Central configuration for GeoCrop.

This module keeps all constants and environment wiring in one place.
It also defines a StorageAdapter interface so the storage backend can be swapped:
- local filesystem (dev): LocalStorage
- MinIO / S3 (prod): MinIOStorage

Planned extensions (e.g. to be added with Roo Code):
- Zimbabwe polygon path
- DEA STAC collection/band config
- model registry
"""
from __future__ import annotations

import os
import shutil
from dataclasses import dataclass, field
from datetime import date
from pathlib import Path
from typing import Dict, Optional, Tuple
# ==========================================
# Training config
# ==========================================
@dataclass
class TrainingConfig:
    """Hyperparameters and artifact-upload settings for model training.

    Every field has a default; override any of them with keyword arguments,
    e.g. ``TrainingConfig(test_size=0.25, upload_minio=True)``.
    """

    # --- Dataset ---
    label_col: str = "label"
    # Non-feature columns (ids, coordinates, export metadata). Presumably
    # dropped before training -- confirm in the trainer. default_factory gives
    # every instance its own list, so mutating one config never affects another.
    junk_cols: list = field(
        default_factory=lambda: [
            ".geo",
            "system:index",
            "latitude",
            "longitude",
            "lat",
            "lon",
            "ID",
            "parent_id",
            "batch_id",
            "is_syn",
        ]
    )

    # --- Train/test split ---
    test_size: float = 0.2
    random_state: int = 42

    # --- Scout model ---
    scout_n_estimators: int = 100

    # --- Final models (match the original hyperparameters) ---
    # Prefixes presumably map to: rf = random forest, xgb = XGBoost,
    # lgb = LightGBM, cb = CatBoost.
    rf_n_estimators: int = 200

    xgb_n_estimators: int = 300
    xgb_learning_rate: float = 0.05
    xgb_max_depth: int = 7
    xgb_subsample: float = 0.8
    xgb_colsample_bytree: float = 0.8

    lgb_n_estimators: int = 800
    lgb_learning_rate: float = 0.03
    lgb_num_leaves: int = 63
    lgb_subsample: float = 0.8
    lgb_colsample_bytree: float = 0.8
    lgb_min_child_samples: int = 30

    cb_iterations: int = 500
    cb_learning_rate: float = 0.05
    cb_depth: int = 6

    # --- Artifact upload ---
    upload_minio: bool = False
    minio_endpoint: str = ""
    minio_access_key: str = ""
    # repr=False keeps the secret out of logs/tracebacks that print the config.
    # It is still a normal constructor argument and still takes part in ==.
    minio_secret_key: str = field(default="", repr=False)
    minio_bucket: str = "geocrop-models"
    minio_prefix: str = "models"
# ==========================================
# Inference config
# ==========================================
class StorageAdapter:
    """Storage backend contract used by the inference pipeline.

    Concrete adapters (a filesystem one for development, a MinIO/S3 one for
    production) override the three hooks that raise NotImplementedError.
    The GeoTIFF writer at the bottom is shared by every adapter.
    """

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Fetch the model artifact(s) named by ``model_key`` into ``dest_dir``."""
        raise NotImplementedError

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Return a path to the DW baseline COG for ``year``/``season``.

        A production adapter may fetch it on demand or point at a shared mount.
        """
        raise NotImplementedError

    def upload_result(self, local_path: Path, key: str) -> str:
        """Store ``local_path`` under ``key``; return its URI (s3://... or a signed https URL)."""
        raise NotImplementedError

    def write_layer_geotiff(self, out_path: Path, arr, profile: dict):
        """Write ``arr`` to ``out_path`` as a GeoTIFF laid out by ``profile``.

        ``arr`` must be ``(H, W)`` (written as one band) or ``(H, W, 3)``
        (written as three bands). ``profile`` is copied, never mutated; only
        its ``count`` entry is overridden.

        Raises:
            ValueError: if ``arr`` has any other shape.
        """
        import rasterio

        single_band = arr.ndim == 2
        three_band = arr.ndim == 3 and arr.shape[2] == 3
        if not (single_band or three_band):
            raise ValueError("arr must be (H,W) or (H,W,3)")

        out_profile = profile.copy()
        out_profile["count"] = 1 if single_band else 3

        with rasterio.open(out_path, "w", **out_profile) as dst:
            if single_band:
                dst.write(arr, 1)
            else:
                # rasterio wants band-first arrays: (H,W,3) -> (3,H,W)
                dst.write(arr.transpose(2, 0, 1))
class MinIOStorage(StorageAdapter):
    """MinIO/S3-backed storage adapter for production.

    Supports:
    - model artifact download (``bucket_models``, default geocrop-models)
    - DW baseline lookup (``bucket_baselines``, default geocrop-baselines)
    - result upload (``bucket_results``, default geocrop-results)
    - presigned URL generation

    The boto3 client is created lazily on first access of :attr:`s3`, so
    constructing the adapter does not import boto3 or touch the network.
    """

    def __init__(
        self,
        endpoint: Optional[str] = None,
        access_key: Optional[str] = None,
        secret_key: Optional[str] = None,
        bucket_models: str = "geocrop-models",
        bucket_baselines: str = "geocrop-baselines",
        bucket_results: str = "geocrop-results",
        secure: Optional[bool] = None,
    ):
        """
        Args:
            endpoint: ``host:port`` or a full URL. Defaults to
                ``$MINIO_SERVICE_HOST:$MINIO_SERVICE_PORT`` with fallbacks
                ``minio.geocrop.svc.cluster.local`` and ``9000``.
            access_key: Falls back to ``$MINIO_ACCESS_KEY``, then ``"minioadmin"``.
            secret_key: Falls back to ``$MINIO_SECRET_KEY``, then ``"minioadmin"``.
            bucket_models / bucket_baselines / bucket_results: Bucket names.
            secure: When ``endpoint`` has no scheme, force HTTPS (True) or
                HTTP (False). ``None`` keeps the legacy heuristic: HTTPS only
                for hosts containing ``.techarvest.co.zw``. Ignored when
                ``endpoint`` already contains ``://``.
        """
        # Default to the internal service if not provided
        if endpoint is None:
            host = os.getenv("MINIO_SERVICE_HOST", "minio.geocrop.svc.cluster.local")
            port = os.getenv("MINIO_SERVICE_PORT", "9000")
            endpoint = f"{host}:{port}"
        self.endpoint = endpoint
        # NOTE(review): "minioadmin" looks like a dev-only default; production
        # should always provide MINIO_ACCESS_KEY / MINIO_SECRET_KEY.
        self.access_key = access_key or os.getenv("MINIO_ACCESS_KEY", "minioadmin")
        self.secret_key = secret_key or os.getenv("MINIO_SECRET_KEY", "minioadmin")
        self.bucket_models = bucket_models
        self.bucket_baselines = bucket_baselines
        self.bucket_results = bucket_results
        self.secure = secure
        # boto3 client, created on first access of `s3`.
        self._s3_client = None

    @property
    def s3(self):
        """Return the boto3 S3 client, creating (and caching) it on first access."""
        if self._s3_client is None:
            import boto3
            from botocore.config import Config

            if "://" in self.endpoint:
                url = self.endpoint
            else:
                use_https = self.secure
                if use_https is None:
                    use_https = ".techarvest.co.zw" in self.endpoint
                url = f"{'https' if use_https else 'http'}://{self.endpoint}"
            self._s3_client = boto3.client(
                "s3",
                endpoint_url=url,
                aws_access_key_id=self.access_key,
                aws_secret_access_key=self.secret_key,
                config=Config(
                    signature_version="s3v4",
                    retries={"max_attempts": 3, "mode": "standard"}
                ),
                region_name="us-east-1",
            )
        return self._s3_client

    def _list_keys(self, bucket: str, prefix: str) -> list:
        """Return every object key in ``bucket`` under ``prefix``, following pagination."""
        paginator = self.s3.get_paginator("list_objects_v2")
        keys = []
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            keys.extend(obj["Key"] for obj in page.get("Contents", []))
        return keys

    @staticmethod
    def _bundle_relpath(prefix: str, key: str) -> Optional[str]:
        """Return ``key``'s path relative to the "folder" ``prefix``, or None.

        ``prefix`` is treated as a folder (a trailing "/" is implied), so
        ``models/v10/x`` is NOT part of bundle ``models/v1``. Returns None for
        keys outside the folder, for "directory" placeholder keys ending in
        "/", and for keys whose remainder has empty, "." or ".." segments.
        Those segments could make the local target escape ``dest_dir``.
        """
        if prefix and not prefix.endswith("/"):
            prefix += "/"
        if not key.startswith(prefix) or key.endswith("/"):
            return None
        rel = key[len(prefix):]
        if any(part in ("", ".", "..") for part in rel.split("/")):
            return None
        return rel

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Download one model file, or every file under a model prefix.

        Args:
            model_key: Either the exact key of one object (e.g.
                "models/Zimbabwe_Ensemble_Raw_Model.pkl") or a folder prefix
                (e.g. "models/v1/" or "models/v1").
            dest_dir: Local directory to save files into (created if missing).

        If ``model_key`` names an existing object, only that object is
        downloaded, even when other keys share it as a prefix. A ``.pkl`` is
        saved as ``dest_dir/model.pkl``; other files keep their name.
        Otherwise, every object under ``<model_key>/`` is downloaded and keeps
        its sub-path.

        Raises:
            FileNotFoundError: if nothing matches, or on any listing/download
                error (the underlying exception is chained as __cause__).
        """
        dest_dir = Path(dest_dir)
        dest_dir.mkdir(parents=True, exist_ok=True)
        try:
            keys = self._list_keys(self.bucket_models, model_key)
            if not keys:
                raise FileNotFoundError(f"No objects found with prefix/key {model_key} in {self.bucket_models}")

            # Exact object match wins, even if siblings share the prefix
            # (e.g. "model.pkl" next to "model.pkl.sha256").
            if model_key in keys and not model_key.endswith("/"):
                filename = Path(model_key).name
                # If inference.py expects 'model.pkl', we provide it
                local_path = dest_dir / "model.pkl" if model_key.endswith(".pkl") else dest_dir / filename
                print(f" Downloading single file s3://{self.bucket_models}/{model_key} -> {local_path}")
                self.s3.download_file(self.bucket_models, model_key, str(local_path))
                return

            # It's a prefix, download all files within it
            print(f" Downloading prefix s3://{self.bucket_models}/{model_key} to {dest_dir}")
            downloaded = 0
            for key in keys:
                rel_path = self._bundle_relpath(model_key, key)
                if rel_path is None:
                    # Directory markers, keys of sibling prefixes (e.g.
                    # "models/v10/..." for "models/v1"), or unsafe paths.
                    continue
                target_path = dest_dir / rel_path
                target_path.parent.mkdir(parents=True, exist_ok=True)
                print(f" -> {key} to {target_path}")
                self.s3.download_file(self.bucket_models, key, str(target_path))
                downloaded += 1
            if downloaded == 0:
                raise FileNotFoundError(f"No downloadable objects under prefix {model_key} in {self.bucket_models}")
        except Exception as e:
            raise FileNotFoundError(f"Failed to download model bundle {model_key}: {e}") from e

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Find a DW baseline tile for a season and return it as a /vsis3 path.

        Args:
            year: Season start year (e.g. 2021 for the 2021-2022 season).
            season: Season name. NOTE(review): not used to build the key; the
                lookup below is hard-wired to the summer layout.

        Returns:
            "/vsis3/<bucket>/<key>" for rasterio, e.g.
            "/vsis3/geocrop-baselines/dw/zim/summer/DW_Zim_HighestConf_2021_2022-...".

        Raises:
            FileNotFoundError: if no tile matches, or on any listing error.
        """
        # Preferred layout: dw/zim/summer/DW_Zim_HighestConf_<year>_<year+1>...
        prefix = f"dw/zim/summer/DW_Zim_HighestConf_{year}_{year + 1}"
        try:
            response = self.s3.list_objects_v2(Bucket=self.bucket_baselines, Prefix=prefix)
            if 'Contents' not in response:
                # Fallback layout: tiles stored at the bucket root
                prefix_alt = f"DW_Zim_HighestConf_{year}_{year + 1}"
                response = self.s3.list_objects_v2(Bucket=self.bucket_baselines, Prefix=prefix_alt)
            if 'Contents' not in response:
                raise FileNotFoundError(f"No DW baseline tiles found for {year} {season} in {self.bucket_baselines}")
            # Only the first listed tile is used. A real system should build a
            # VRT or pick the tile covering the AOI; the first tile only works
            # when the AOI happens to fall inside it.
            key = response['Contents'][0]['Key']
            print(f" Found DW baseline tile: {key}")
            return f"/vsis3/{self.bucket_baselines}/{key}"
        except Exception as e:
            raise FileNotFoundError(f"Failed to find DW baseline: {e}") from e

    def upload_result(self, local_path: Path, key: str) -> str:
        """Upload a file to the results bucket.

        Args:
            local_path: Local file path (str or Path).
            key: S3 key (e.g. "results/refined_2022.tif").

        Returns:
            "s3://<bucket_results>/<key>".

        Raises:
            RuntimeError: wrapping any upload error.
        """
        local_path = Path(local_path)
        try:
            self.s3.upload_file(
                str(local_path),
                self.bucket_results,
                key
            )
        except Exception as e:
            raise RuntimeError(f"Failed to upload {local_path}: {e}") from e
        return f"s3://{self.bucket_results}/{key}"

    def generate_presigned_url(self, bucket: str, key: str, expires: int = 3600) -> str:
        """Generate a presigned GET URL.

        Args:
            bucket: Bucket name.
            key: S3 key.
            expires: URL lifetime in seconds.

        Returns:
            The presigned URL.

        Raises:
            RuntimeError: wrapping any signing error.
        """
        try:
            url = self.s3.generate_presigned_url(
                "get_object",
                Params={"Bucket": bucket, "Key": key},
                ExpiresIn=expires,
            )
            return url
        except Exception as e:
            raise RuntimeError(f"Failed to generate presigned URL: {e}") from e
@dataclass
class InferenceConfig:
    """Runtime settings for the inference worker.

    A season is identified by its *starting* year: ``year=2019`` covers
    2019-09-01 through 2020-05-31 (the requested Sep -> May window).
    """

    # --- Constraints ---
    max_radius_m: float = 5000.0

    # --- Season window: starts in `year`, ends in `year + 1` ---
    summer_start_month: int = 9
    summer_start_day: int = 1
    summer_end_month: int = 5
    summer_end_day: int = 31

    smoothing_enabled: bool = True
    smoothing_kernel: int = 3

    # --- DEA STAC endpoints ---
    dea_root: str = "https://explorer.digitalearth.africa/stac"
    dea_search: str = "https://explorer.digitalearth.africa/stac/search"
    dea_stac_url: str = "https://explorer.digitalearth.africa/stac"

    # --- Storage backend (unset by default) ---
    storage: Optional[StorageAdapter] = None

    def season_dates(self, year: int, season: str = "summer") -> Tuple[str, str]:
        """Return ISO ``(start, end)`` dates for the season beginning in ``year``.

        ``season`` is matched case-insensitively; only "summer" is supported.

        Raises:
            ValueError: for any other season, or if a configured month/day
                pair is not a valid calendar date.
        """
        if season.lower() == "summer":
            first_day = date(year, self.summer_start_month, self.summer_start_day)
            last_day = date(year + 1, self.summer_end_month, self.summer_end_day)
            return first_day.isoformat(), last_day.isoformat()
        raise ValueError("Only summer season supported for now")
# ==========================================
# Example local dev adapter
# ==========================================
class LocalStorage(StorageAdapter):
    """Simple dev adapter backed by the local filesystem.

    Layout under ``base_dir`` (the three sub-directories are created up front):

        models/<model_key>          model bundle: a directory or a single file
        dw/dw_<season>_<year>.tif   DW baseline rasters
        results/                    created for uploads; upload keys are
                                    resolved relative to ``base_dir``
    """

    def __init__(self, base_dir: str = "/data/geocrop"):
        self.base = Path(base_dir)
        self.base.mkdir(parents=True, exist_ok=True)
        for sub in ("results", "models", "dw"):
            (self.base / sub).mkdir(exist_ok=True)

    def download_model_bundle(self, model_key: str, dest_dir: Path):
        """Copy a model bundle from ``<base>/models/<model_key>`` into ``dest_dir``.

        ``model_key`` may name a directory, whose files (including nested ones)
        are copied with their relative paths preserved. It may also name a
        single file, which is copied as ``model.pkl`` when its name ends in
        ".pkl" and under its own name otherwise.

        NOTE(review): keys are resolved under ``<base>/models``, so a key that
        already starts with "models/" ends up at ``<base>/models/models/...``.

        Raises:
            FileNotFoundError: if the source path does not exist.
        """
        src = self.base / "models" / model_key
        if not src.exists():
            raise FileNotFoundError(f"Missing local model bundle: {src}")
        dest_dir = Path(dest_dir)  # accept str as well as Path
        dest_dir.mkdir(parents=True, exist_ok=True)

        if src.is_file():
            target_name = "model.pkl" if src.name.endswith(".pkl") else src.name
            shutil.copyfile(src, dest_dir / target_name)
            return

        for p in src.rglob("*"):
            if p.is_file():
                target = dest_dir / p.relative_to(src)
                target.parent.mkdir(parents=True, exist_ok=True)
                shutil.copyfile(p, target)

    def get_dw_local_path(self, year: int, season: str) -> str:
        """Return ``<base>/dw/dw_<season>_<year>.tif``.

        Raises:
            FileNotFoundError: if that file does not exist.
        """
        p = self.base / "dw" / f"dw_{season}_{year}.tif"
        if not p.exists():
            raise FileNotFoundError(f"Missing DW baseline: {p}")
        return str(p)

    def upload_result(self, local_path: Path, key: str) -> str:
        """Copy ``local_path`` to ``<base>/<key>`` and return a ``file://`` URI."""
        local_path = Path(local_path)  # accept str as well as Path
        dest = self.base / key
        dest.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(local_path, dest)
        return f"file://{dest}"