geocrop-platform./apps/worker/stac_client.py

378 lines
12 KiB
Python

"""DEA STAC client for the worker.
STEP 3: STAC client using pystac-client.
This module provides:
- Collection resolution with fallback
- STAC search with cloud filtering
- Item normalization without downloading
NOTE: This does NOT implement stackstac loading - that comes in Step 4/5.
"""
from __future__ import annotations
import os
import time
import logging
from datetime import datetime
from typing import List, Optional, Dict, Any
# Module-level logger; no handlers or levels are configured in this module.
logger = logging.getLogger(__name__)

# ==========================================
# Configuration
# ==========================================
# Every setting below is read from the environment once, at import time, and
# falls back to the literal default when the variable is unset.
DEA_STAC_ROOT = os.environ.get(
    "DEA_STAC_ROOT", "https://explorer.digitalearth.africa/stac"
)
DEA_STAC_SEARCH = os.environ.get(
    "DEA_STAC_SEARCH", "https://explorer.digitalearth.africa/stac/search"
)
# Integer settings: a non-integer value makes int() raise ValueError on import.
DEA_CLOUD_MAX = int(os.environ.get("DEA_CLOUD_MAX", "30"))  # cloud cover threshold (%)
DEA_TIMEOUT_S = int(os.environ.get("DEA_TIMEOUT_S", "30"))  # seconds, per the name

# Candidate Sentinel-2 collection IDs, most preferred first.
S2_COLLECTION_PREFER = ["s2_l2a", "s2_l2a_c1", "sentinel-2-l2a", "sentinel_2_l2a"]

# Asset keys to pick out of each STAC item when they are present.
DESIRED_ASSETS = [
    "red",     # B4
    "green",   # B3
    "blue",    # B2
    "nir",     # B8
    "nir08",   # B8A (narrow NIR)
    "nir09",   # B9
    "swir16",  # B11
    "swir22",  # B12
    "scl",     # Scene Classification Layer
    "qa",      # QA band
]
# ==========================================
# STAC Client Class
# ==========================================
class DEASTACClient:
    """Client for the Digital Earth Africa (DEA) STAC API.

    Wraps a lazily opened ``pystac_client.Client`` and adds:

    - retry with exponential backoff for transient-looking failures,
    - resolution of the preferred Sentinel-2 collection ID,
    - item search with an optional ``eo:cloud_cover`` filter,
    - metadata-only summaries of search results (nothing is downloaded).
    """

    # Lower-case substrings that mark an error message as worth retrying.
    # "temporar" covers "temporary"/"temporarily"; "temporal" is kept so that
    # every message that was retried before is still retried.
    _RETRYABLE_KEYWORDS = (
        "connection",
        "timeout",
        "network",
        "temporal",
        "temporar",
    )

    def __init__(
        self,
        root: str = DEA_STAC_ROOT,
        search_url: str = DEA_STAC_SEARCH,
        cloud_max: int = DEA_CLOUD_MAX,
        timeout: int = DEA_TIMEOUT_S,
    ):
        """Store connection settings; no network I/O happens here.

        Args:
            root: URL handed to ``pystac_client.Client.open``.
            search_url: URL of the STAC search endpoint.
            cloud_max: Exclusive upper bound for ``eo:cloud_cover``;
                a value <= 0 disables the cloud filter.
            timeout: Request timeout (seconds, per the ``DEA_TIMEOUT_S`` name).
        """
        self.root = root
        # NOTE(review): search_url and timeout are stored but not read by any
        # method of this class -- confirm whether they should be wired into
        # the underlying client.
        self.search_url = search_url
        self.cloud_max = cloud_max
        self.timeout = timeout
        self._client = None       # opened on first access of ``client``
        self._collections = None  # collection IDs cached by resolve_s2_collection

    @property
    def client(self):
        """Return the pystac-client ``Client``, opening it on first access."""
        if self._client is None:
            # Imported here rather than at module top, so importing this
            # module does not itself import pystac_client.
            import pystac_client

            self._client = pystac_client.Client.open(self.root)
        return self._client

    def _retry_operation(self, operation, max_retries: int = 3, *args, **kwargs):
        """Run ``operation`` and retry it with exponential backoff.

        An error is retried only when its lower-cased message contains one of
        ``_RETRYABLE_KEYWORDS``; any other error propagates immediately.
        Waits between attempts are 1s, 2s, 4s, ...

        Args:
            operation: Callable to execute.
            max_retries: Total number of attempts (values below 1 are treated
                as 1). Because it precedes ``*args``, positional arguments
                for ``operation`` can only be supplied after it.
            *args, **kwargs: Arguments forwarded to ``operation``.

        Returns:
            Whatever ``operation`` returns.

        Raises:
            Exception: The error raised by ``operation`` when it is not
                retryable, or the last error once all attempts are used up.
        """
        # Always try at least once; otherwise there would be nothing to
        # return and nothing to raise.
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                return operation(*args, **kwargs)
            except Exception as exc:
                # Retryability is decided from the message text, not the
                # type, so no library-specific exception classes are needed.
                message = str(exc).lower()
                retryable = any(kw in message for kw in self._RETRYABLE_KEYWORDS)
                if not retryable or attempt == attempts - 1:
                    raise
                wait_time = 2 ** attempt
                logger.warning(
                    "Retry %d/%d after %ds: %s",
                    attempt + 1,
                    attempts,
                    wait_time,
                    exc,
                )
                time.sleep(wait_time)

    def list_collections(self) -> List[str]:
        """List available collections.

        Returns:
            List of collection IDs reported by the API.
        """

        def _list():
            return [coll.id for coll in self.client.get_collections()]

        return self._retry_operation(_list)

    def resolve_s2_collection(self) -> Optional[str]:
        """Resolve the best available Sentinel-2 collection ID.

        The collection list is fetched once and cached on the instance;
        candidates are checked in ``S2_COLLECTION_PREFER`` order.

        Returns:
            Collection ID if one of the preferred IDs exists, None otherwise.
        """
        if self._collections is None:
            self._collections = self.list_collections()

        for coll_id in S2_COLLECTION_PREFER:
            if coll_id in self._collections:
                logger.info(f"Resolved S2 collection: {coll_id}")
                return coll_id

        # Nothing matched: log a sample of what IS available to aid debugging.
        logger.warning(
            f"None of {S2_COLLECTION_PREFER} found. "
            f"Available: {self._collections[:10]}..."
        )
        return None

    def search_items(
        self,
        bbox: List[float],
        start_date: str,
        end_date: str,
        collections: Optional[List[str]] = None,
        limit: int = 200,
    ) -> List[Any]:
        """Search for STAC items.

        Args:
            bbox: [minx, miny, maxx, maxy]
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            collections: Optional list of collection IDs; auto-resolves the
                Sentinel-2 collection if None
            limit: Maximum items to return (also used as the page size)

        Returns:
            List of pystac.Item objects

        Raises:
            ValueError: If no collection is given and none can be resolved
        """
        if collections is None:
            coll_id = self.resolve_s2_collection()
            if coll_id is None:
                # resolve_s2_collection has just cached the collection list,
                # so reuse it instead of hitting the API a second time.
                available = self._collections or []
                raise ValueError(
                    f"No Sentinel-2 collection found. "
                    f"Available collections: {available[:20]}..."
                )
            collections = [coll_id]

        # Cloud filter is applied only for a positive threshold.
        query = None
        if self.cloud_max > 0:
            query = {"eo:cloud_cover": {"lt": self.cloud_max}}

        def _search():
            search = self.client.search(
                collections=collections,
                bbox=bbox,
                datetime=f"{start_date}/{end_date}",
                # ``limit`` alone is only the per-page size in pystac-client;
                # ``max_items`` is what caps the total number returned.
                limit=limit,
                max_items=limit,
                query=query,
            )
            return list(search.items())

        return self._retry_operation(_search)

    @staticmethod
    def _describe_asset(asset: Any) -> Dict[str, Any]:
        """Reduce one asset to its href, media type and roles."""
        return {
            "href": str(asset.href) if asset.href else None,
            "type": getattr(asset, "media_type", None),
            "roles": list(asset.roles) if asset.roles else [],
        }

    def _get_asset_info(self, item: Any) -> Dict[str, Dict]:
        """Extract minimal asset information from an item.

        Args:
            item: pystac.Item

        Returns:
            Dict of asset key -> {href, type, roles}. Holds the assets named
            in ``DESIRED_ASSETS`` that the item has; if it has none of them,
            the item's first five assets are returned as a hint instead.
        """
        if not item.assets:
            return {}

        result = {
            key: self._describe_asset(item.assets[key])
            for key in DESIRED_ASSETS
            if key in item.assets
        }
        if not result:
            # No known key matched: expose a few assets so the caller can
            # see what naming scheme this collection uses.
            for key, asset in list(item.assets.items())[:5]:
                result[key] = self._describe_asset(asset)
        return result

    def summarize_items(self, items: List[Any]) -> Dict[str, Any]:
        """Summarize search results without downloading.

        Args:
            items: List of pystac.Item objects

        Returns:
            Dict with:
            {
                "count": int,
                "collection": str|None,
                "time_start": str|None,
                "time_end": str|None,
                "items": [
                    {
                        "id": str,
                        "datetime": str|None,
                        "bbox": [...]|None,
                        "cloud_cover": float|None,
                        "assets": {...}
                    }, ...
                ]
            }
            For an empty ``items`` list, count is 0, "items" is [] and the
            other fields are None.
        """
        if not items:
            return {
                "count": 0,
                "collection": None,
                "time_start": None,
                "time_end": None,
                "items": [],
            }

        # The first item's collection stands in for the whole result set.
        collection = items[0].collection_id or "unknown"

        # Items without a datetime are left out of the time range.
        times = [item.datetime for item in items if item.datetime]
        time_start = min(times).isoformat() if times else None
        time_end = max(times).isoformat() if times else None

        item_summaries = []
        for item in items:
            # Tolerate a missing or None ``properties`` mapping.
            properties = getattr(item, "properties", None) or {}
            item_summaries.append({
                "id": item.id,
                "datetime": item.datetime.isoformat() if item.datetime else None,
                "bbox": list(item.bbox) if item.bbox else None,
                "cloud_cover": properties.get("eo:cloud_cover"),
                "assets": self._get_asset_info(item),
            })

        return {
            "count": len(items),
            "collection": collection,
            "time_start": time_start,
            "time_end": time_end,
            "items": item_summaries,
        }
# ==========================================
# Self-Test
# ==========================================
if __name__ == "__main__":
    # Manual smoke test: it queries the configured STAC endpoint, so it needs
    # network access. Failures are printed rather than raised.
    import traceback

    print("=== DEA STAC Client Self-Test ===")
    print(f"Root: {DEA_STAC_ROOT}")
    print(f"Search: {DEA_STAC_SEARCH}")
    print(f"Cloud max: {DEA_CLOUD_MAX}%")
    print()

    client = DEASTACClient()

    # 1) Collection resolution.
    print("Testing collection resolution...")
    try:
        print(f" Resolved S2 collection: {client.resolve_s2_collection()}")
    except Exception as e:
        print(f" Error: {e}")

    # 2) Item search over a small AOI and a short date range.
    print("\nTesting search...")
    # Zimbabwe AOI centred on lon 30.46, lat -16.81; the box spans
    # 0.12 deg x 0.18 deg (roughly 13 km x 20 km).
    bbox = [30.40, -16.90, 30.52, -16.72]  # [minx, miny, maxx, maxy]
    # 30-day window in 2021.
    start_date, end_date = "2021-11-01", "2021-12-01"
    print(f" bbox: {bbox}")
    print(f" dates: {start_date} to {end_date}")
    try:
        items = client.search_items(bbox, start_date, end_date)
        print(f" Found {len(items)} items")

        summary = client.summarize_items(items)
        print(f" Collection: {summary['collection']}")
        print(f" Time range: {summary['time_start']} to {summary['time_end']}")

        if summary['items']:
            first = summary['items'][0]
            print(" First item:")
            for field in ("id", "datetime", "cloud_cover"):
                print(f" {field}: {first[field]}")
            print(f" assets: {list(first['assets'])}")
    except Exception as e:
        print(f" Search error: {e}")
        traceback.print_exc()

    print("\n=== Self-Test Complete ===")