# geocrop-platform/apps/worker/postprocess.py
"""Post-processing utilities for inference output.
STEP 7: Provides neighborhood smoothing and class utilities.
This module provides:
- Majority filter (mode) with nodata preservation
- Class remapping
- Confidence computation from probabilities
NOTE: Uses pure numpy implementation for efficiency.
"""
from __future__ import annotations
from typing import Optional, List
import numpy as np
# ==========================================
# Kernel Validation
# ==========================================
def validate_kernel(kernel: int) -> int:
    """Check that a smoothing kernel size is supported.

    Args:
        kernel: Kernel size (must be 3, 5, or 7)

    Returns:
        The same kernel size, unchanged

    Raises:
        ValueError: If kernel is not 3, 5, or 7
    """
    allowed = {3, 5, 7}
    if kernel not in allowed:
        raise ValueError(
            f"Invalid kernel size: {kernel}. Must be one of {allowed}."
        )
    return kernel
# ==========================================
# Majority Filter
# ==========================================
def _majority_filter_slow(
    cls: np.ndarray,
    kernel: int,
    nodata: int,
) -> np.ndarray:
    """Pure-Python majority filter used as a fallback.

    Same contract as majority_filter: nodata centers are kept as-is,
    nodata neighbors do not vote, and ties prefer the center pixel and
    then the smallest class ID.
    """
    rows, cols = cls.shape
    half = kernel // 2
    out = cls.copy()
    # Pad with nodata so border windows have full kernel size
    padded = np.pad(cls, half, mode='constant', constant_values=nodata)
    for r in range(rows):
        for c in range(cols):
            center = cls[r, c]
            # Nodata pixels pass through untouched
            if center == nodata:
                continue
            neighborhood = padded[r:r + kernel, c:c + kernel].ravel()
            votes = neighborhood[neighborhood != nodata]
            # No valid voters at all: keep the original center value
            if votes.size == 0:
                continue
            # Tally votes; np.unique returns labels sorted ascending
            labels, tallies = np.unique(votes, return_counts=True)
            tied = labels[tallies == tallies.max()]
            # Prefer the center when it is part of the tie, else the
            # smallest class ID (first element of the sorted tie set)
            out[r, c] = center if center in tied else int(tied.min())
    return out
def majority_filter(
    cls: np.ndarray,
    kernel: int = 5,
    nodata: int = 0,
) -> np.ndarray:
    """Apply a majority (mode) filter to a class raster.

    Fully vectorized: per-class neighborhood counts are built by summing
    k*k shifted slices of the padded raster, so there is no per-pixel
    Python loop. This also fixes the previous fast path, which raised
    ValueError on negative class IDs (np.bincount rejects negatives)
    while the slow fallback handled them.

    Args:
        cls: 2D array of class IDs (H, W)
        kernel: Kernel size (3, 5, or 7)
        nodata: Nodata value to preserve

    Returns:
        Filtered int32 class raster of the same shape

    Raises:
        ValueError: If kernel is not 3, 5, or 7, or cls is not 2D

    Rules:
        - Nodata pixels in input stay nodata in output
        - Nodata values are excluded from the neighborhood vote
        - Tie-breaking: prefer the original center pixel if it is part
          of the tie, otherwise choose the smallest class ID
    """
    # Inline kernel validation (mirrors validate_kernel: same exception
    # type and message) so this function is self-contained.
    valid_kernels = {3, 5, 7}
    if kernel not in valid_kernels:
        raise ValueError(
            f"Invalid kernel size: {kernel}. "
            f"Must be one of {valid_kernels}."
        )
    cls = np.asarray(cls, dtype=np.int32)
    if cls.ndim != 2:
        raise ValueError(f"Expected 2D array, got shape {cls.shape}")
    H, W = cls.shape
    pad = kernel // 2
    # Pad with nodata so border windows have full kernel size but the
    # padding never votes.
    padded = np.pad(cls, pad, mode='constant', constant_values=nodata)
    result = cls.copy()

    valid = cls != nodata
    if not valid.any():
        # Raster is all nodata: output equals input.
        return result
    # Sorted ascending; nodata excluded, so it can never win a vote.
    classes = np.unique(cls[valid])

    # counts[ci, i, j] = number of pixels of class classes[ci] inside the
    # kernel window centered at (i, j). Each per-class count is a sum of
    # k*k shifted slices — all C-speed, O(n_classes * k^2 * H * W) with
    # only O(H * W) temporary memory per class.
    counts = np.empty((classes.size, H, W), dtype=np.int32)
    for ci, c in enumerate(classes):
        hits = (padded == c).astype(np.int32)
        acc = np.zeros((H, W), dtype=np.int32)
        for di in range(kernel):
            for dj in range(kernel):
                acc += hits[di:di + H, dj:dj + W]
        counts[ci] = acc

    max_count = counts.max(axis=0)
    # Default tie-break: smallest class ID reaching the max count.
    # classes is sorted ascending and argmax returns the first True.
    winner = classes[(counts == max_count).argmax(axis=0)]

    # Preferred tie-break: keep the center wherever its own class reaches
    # the max count. Every valid center is inside its own window, so its
    # count is always >= 1. Nodata centers get a clipped (arbitrary but
    # in-range) index here and are masked out below.
    center_idx = np.searchsorted(classes, np.clip(cls, classes[0], classes[-1]))
    center_count = np.take_along_axis(counts, center_idx[None, :, :], axis=0)[0]
    keep_center = valid & (center_count == max_count)

    filtered = np.where(keep_center, cls, winner)
    # Only valid pixels are overwritten; nodata pixels stay nodata.
    result[valid] = filtered[valid]
    return result
# ==========================================
# Class Remapping
# ==========================================
def remap_classes(
    cls: np.ndarray,
    mapping: dict,
    nodata: int = 0,
) -> np.ndarray:
    """Return a copy of the class raster with IDs translated via mapping.

    Args:
        cls: 2D array of class IDs (H, W)
        mapping: Dict of {old class ID: new class ID}
        nodata: Nodata value that is never remapped

    Returns:
        Remapped int32 class raster
    """
    source = np.asarray(cls, dtype=np.int32)
    remapped = source.copy()
    for old_id, new_id in mapping.items():
        # Masks are built against the untouched source array, so chained
        # mappings (e.g. 1->2 then 2->3) cannot cascade, and nodata
        # pixels are always excluded from remapping.
        remapped[(source == old_id) & (source != nodata)] = new_id
    return remapped
# ==========================================
# Confidence from Probabilities
# ==========================================
def compute_confidence_from_proba(
    proba_max: np.ndarray,
    nodata_mask: np.ndarray,
) -> np.ndarray:
    """Build a confidence raster from per-pixel max probabilities.

    Args:
        proba_max: 2D array of max probability per pixel (H, W)
        nodata_mask: Boolean mask marking nodata pixels

    Returns:
        2D float32 confidence raster; nodata pixels are zeroed
    """
    # np.array (not asarray) guarantees a fresh copy, so the caller's
    # input is never modified in place.
    confidence = np.array(proba_max, dtype=np.float32)
    confidence[np.asarray(nodata_mask, dtype=bool)] = 0.0
    return confidence
# ==========================================
# Model Class Utilities
# ==========================================
def get_model_classes(model) -> Optional[List[str]]:
    """Extract class names from a trained model if available.

    Args:
        model: Trained sklearn-compatible model

    Returns:
        List of class names if the model exposes a usable ``classes_``
        attribute (ndarray, list, or tuple), otherwise None
    """
    classes = getattr(model, 'classes_', None)
    if classes is not None:
        # ndarray-like: convert via tolist(); plain sequences: copy
        if hasattr(classes, 'tolist'):
            return classes.tolist()
        if isinstance(classes, (list, tuple)):
            return list(classes)
    # Attribute missing, None, or an unsupported type
    return None
# ==========================================
# Self-Test
# ==========================================
if __name__ == "__main__":
    print("=== PostProcess Module Self-Test ===")
    # NOTE: numpy is imported unconditionally at module top; a missing
    # numpy would already have raised ImportError before reaching this
    # point, so the old `if np is None` guard was unreachable dead code.

    # 1. Synthetic raster: classes 1..7 with two nodata (0) holes
    print("\n1. Creating synthetic test raster...")
    H, W = 20, 20
    np.random.seed(42)  # deterministic raster so reruns are comparable
    cls = np.random.randint(1, 8, size=(H, W)).astype(np.int32)
    cls[3:6, 3:6] = 0  # nodata region
    cls[15:18, 15:18] = 0  # another nodata region
    print(f" Input shape: {cls.shape}")
    print(f" Input unique values: {sorted(np.unique(cls))}")
    print(f" Nodata count: {np.sum(cls == 0)}")

    # 2. Majority filter with kernel=3: nodata count must be unchanged
    print("\n2. Testing majority_filter (kernel=3)...")
    result3 = majority_filter(cls, kernel=3, nodata=0)
    changed3 = np.sum((result3 != cls) & (cls != 0))
    nodata_preserved3 = np.sum(result3 == 0) == np.sum(cls == 0)
    print(f" Output unique values: {sorted(np.unique(result3))}")
    print(f" Changed pixels (excl nodata): {changed3}")
    print(f" Nodata preserved: {nodata_preserved3}")
    if nodata_preserved3:
        print(" ✓ Nodata preservation test PASSED")
    else:
        print(" ✗ Nodata preservation test FAILED")

    # 3. Same check with kernel=5
    print("\n3. Testing majority_filter (kernel=5)...")
    result5 = majority_filter(cls, kernel=5, nodata=0)
    changed5 = np.sum((result5 != cls) & (cls != 0))
    nodata_preserved5 = np.sum(result5 == 0) == np.sum(cls == 0)
    print(f" Output unique values: {sorted(np.unique(result5))}")
    print(f" Changed pixels (excl nodata): {changed5}")
    print(f" Nodata preserved: {nodata_preserved5}")
    if nodata_preserved5:
        print(" ✓ Nodata preservation test PASSED")
    else:
        print(" ✗ Nodata preservation test FAILED")

    # 4. Class remapping — now actually verified (the old version printed
    # PASSED unconditionally without checking the result)
    print("\n4. Testing remap_classes...")
    mapping = {1: 10, 2: 20, 3: 30}
    remapped = remap_classes(cls, mapping, nodata=0)
    mapped_mask = np.isin(cls, list(mapping))
    mapped_count = np.sum(mapped_mask)
    unchanged = np.sum(remapped == cls)
    print(f" Mapped pixels: {mapped_count}")
    print(f" Unchanged pixels: {unchanged}")
    mapping_ok = all(
        np.all(remapped[cls == old] == new) for old, new in mapping.items()
    )
    untouched_ok = np.array_equal(remapped[~mapped_mask], cls[~mapped_mask])
    if mapping_ok and untouched_ok:
        print(" ✓ remap_classes test PASSED")
    else:
        print(" ✗ remap_classes test FAILED")

    # 5. Confidence raster: nodata pixels must be zeroed
    print("\n5. Testing compute_confidence_from_proba...")
    proba = np.random.rand(H, W).astype(np.float32)
    nodata_mask = cls == 0
    confidence = compute_confidence_from_proba(proba, nodata_mask)
    nodata_conf_zero = np.all(confidence[nodata_mask] == 0)
    valid_conf_positive = np.all(confidence[~nodata_mask] >= 0)
    print(f" Nodata pixels have 0 confidence: {nodata_conf_zero}")
    print(f" Valid pixels have positive confidence: {valid_conf_positive}")
    if nodata_conf_zero and valid_conf_positive:
        print(" ✓ compute_confidence_from_proba test PASSED")
    else:
        print(" ✗ compute_confidence_from_proba test FAILED")

    # 6. Kernel validation: valid sizes accepted, others rejected
    print("\n6. Testing kernel validation...")
    try:
        validate_kernel(3)
        validate_kernel(5)
        validate_kernel(7)
        print(" Valid kernels (3,5,7) accepted: ✓")
    except ValueError:
        print(" ✗ Valid kernels rejected")
    try:
        validate_kernel(4)
        print(" ✗ Invalid kernel accepted (should have failed)")
    except ValueError:
        print(" Invalid kernel (4) rejected: ✓")

    print("\n=== PostProcess Module Test Complete ===")