383 lines
11 KiB
Python
383 lines
11 KiB
Python
"""Post-processing utilities for inference output.

STEP 7: Provides neighborhood smoothing and class utilities.

This module provides:

- Majority (mode) filter with nodata preservation
- Class remapping
- Confidence computation from probabilities

NOTE: Depends only on NumPy (no SciPy or other array libraries).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Optional, List
|
|
|
|
import numpy as np
|
|
|
|
|
|
# ==========================================
|
|
# Kernel Validation
|
|
# ==========================================
|
|
|
|
def validate_kernel(kernel: int) -> int:
    """Check that a smoothing kernel size is one of the supported values.

    Args:
        kernel: Requested window size; only 3, 5 and 7 are accepted.

    Returns:
        The same ``kernel`` value, unchanged, so the call can be used inline.

    Raises:
        ValueError: If ``kernel`` is not 3, 5 or 7.
    """
    allowed = {3, 5, 7}
    if kernel in allowed:
        return kernel
    raise ValueError(
        f"Invalid kernel size: {kernel}. Must be one of {allowed}."
    )
|
|
|
|
|
|
# ==========================================
|
|
# Majority Filter
|
|
# ==========================================
|
|
|
|
def _majority_filter_slow(
    cls: np.ndarray,
    kernel: int,
    nodata: int,
) -> np.ndarray:
    """Reference majority (mode) filter implemented with explicit Python loops.

    Rules:
        - Pixels equal to ``nodata`` are left unchanged.
        - ``nodata`` values, including the padding outside the raster, do not
          take part in the neighborhood vote.
        - Ties go to the original center value if it is among the most
          frequent classes, otherwise to the smallest tied class ID.

    Every pixel is visited in Python, so this is only practical for small
    rasters. It is useful as a readable reference for checking faster
    implementations.

    Args:
        cls: 2D array of class IDs (H, W).
        kernel: Window size (e.g. 3, 5 or 7); the window is centred on
            each pixel.
        nodata: Value treated as nodata.

    Returns:
        Filtered copy of ``cls`` with the same shape and dtype.
    """
    from collections import Counter

    height, width = cls.shape
    pad = kernel // 2
    result = cls.copy()

    # Pad with nodata so out-of-bounds neighbours never receive a vote.
    padded = np.pad(cls, pad, mode='constant', constant_values=nodata)

    for i in range(height):
        for j in range(width):
            center_val = cls[i, j]
            if center_val == nodata:
                continue

            window = padded[i:i + kernel, j:j + kernel]
            # The valid center pixel lies inside its own window, so at least
            # one vote is always cast and ``counts`` is never empty.
            counts = Counter(window[window != nodata].tolist())
            max_count = max(counts.values())
            candidates = [v for v, c in counts.items() if c == max_count]

            # Tie-breaking: keep the center if it is tied, else smallest ID.
            result[i, j] = center_val if center_val in candidates else min(candidates)

    return result
|
|
|
|
|
|
def _window_counts(padded_mask: np.ndarray, kernel: int) -> np.ndarray:
    """Count True cells in every ``kernel`` x ``kernel`` window of a padded mask.

    Uses a summed-area (integral-image) table, so the cost is O(rows * cols)
    regardless of the kernel size.

    Args:
        padded_mask: Boolean array of shape (H + kernel - 1, W + kernel - 1),
            i.e. an (H, W) mask padded by ``kernel // 2`` on every side.
        kernel: Window size.

    Returns:
        int64 array of shape (H, W); entry (i, j) is the count for the window
        whose top-left corner is ``padded_mask[i, j]``.
    """
    rows, cols = padded_mask.shape
    table = np.zeros((rows + 1, cols + 1), dtype=np.int64)
    # table[r, c] == number of True cells in padded_mask[:r, :c]
    table[1:, 1:] = padded_mask.cumsum(axis=0).cumsum(axis=1)
    k = kernel
    return table[k:, k:] - table[:-k, k:] - table[k:, :-k] + table[:-k, :-k]


def majority_filter(
    cls: np.ndarray,
    kernel: int = 5,
    nodata: int = 0,
) -> np.ndarray:
    """Apply a majority (mode) filter to a class raster.

    Args:
        cls: 2D array of class IDs (H, W); converted to int32.
        kernel: Kernel size (3, 5, or 7).
        nodata: Nodata value to preserve.

    Returns:
        Filtered int32 class raster of the same shape as ``cls``.

    Raises:
        ValueError: If ``kernel`` is not 3, 5 or 7, or ``cls`` is not 2D.

    Rules:
        - Nodata pixels in input stay nodata in output.
        - Nodata values, including the padding outside the raster, are
          excluded from the neighborhood vote.
        - The center pixel votes for itself, so every valid pixel has at
          least one vote and is never turned into nodata.
        - Tie-breaking:
            - Prefer original center pixel if it's part of the tie
            - Otherwise choose smallest class ID

    The vote is vectorized per class. For each distinct valid class, a
    summed-area table gives that class's count in every window. The total
    cost is therefore O(K * H * W) NumPy work for K distinct classes, with
    no Python loop over pixels. Any int32 class IDs are supported,
    including negative ones.
    """
    validate_kernel(kernel)

    cls = np.asarray(cls, dtype=np.int32)

    if cls.ndim != 2:
        raise ValueError(f"Expected 2D array, got shape {cls.shape}")

    valid = cls != nodata
    classes = np.unique(cls[valid])  # sorted ascending
    if classes.size == 0:
        # Nothing but nodata: nothing to smooth.
        return cls.copy()

    # Pad with nodata so out-of-bounds neighbours never receive a vote.
    pad = kernel // 2
    padded = np.pad(cls, pad, mode='constant', constant_values=nodata)

    best_count = np.zeros(cls.shape, dtype=np.int64)
    best_class = np.zeros(cls.shape, dtype=np.int32)
    center_count = np.zeros(cls.shape, dtype=np.int64)

    for c in classes:
        counts = _window_counts(padded == c, kernel)
        # Strict ">" with ascending classes means the smallest class ID wins
        # among classes sharing the maximum count.
        better = counts > best_count
        best_count[better] = counts[better]
        best_class[better] = c
        # Remember how many votes each pixel's own class received.
        is_center = cls == c
        center_count[is_center] = counts[is_center]

    # Keep the center whenever it ties the maximum; nodata pixels are
    # always kept as-is.
    keep_center = center_count == best_count
    return np.where(valid & ~keep_center, best_class, cls)
|
|
|
|
|
|
# ==========================================
|
|
# Class Remapping
|
|
# ==========================================
|
|
|
|
def remap_classes(
    cls: np.ndarray,
    mapping: dict,
    nodata: int = 0,
) -> np.ndarray:
    """Replace class IDs according to ``mapping``, leaving nodata untouched.

    Args:
        cls: Array of class IDs, typically 2D (H, W); converted to int32.
        mapping: ``{old_id: new_id}`` pairs to substitute.
        nodata: Value that is never remapped, even if it appears as a key.

    Returns:
        A new int32 array; ``cls`` itself is not modified.
    """
    source = np.asarray(cls, dtype=np.int32)
    remapped = source.copy()
    is_data = source != nodata

    # Masks are built from the original values, so chained entries such as
    # {1: 2, 2: 3} do not cascade (a 1 becomes 2, not 3).
    for old_id, new_id in mapping.items():
        remapped[is_data & (source == old_id)] = new_id

    return remapped
|
|
|
|
|
|
# ==========================================
|
|
# Confidence from Probabilities
|
|
# ==========================================
|
|
|
|
def compute_confidence_from_proba(
    proba_max: np.ndarray,
    nodata_mask: np.ndarray,
) -> np.ndarray:
    """Build a float32 confidence raster with nodata pixels zeroed.

    Args:
        proba_max: 2D array holding the maximum class probability per pixel.
        nodata_mask: Boolean array, True where the pixel is nodata; used as a
            boolean index into the confidence array.

    Returns:
        A new float32 array equal to ``proba_max`` except 0.0 at nodata
        pixels. Neither input is modified.
    """
    confidence = np.array(proba_max, dtype=np.float32)  # always a fresh copy
    is_nodata = np.asarray(nodata_mask, dtype=bool)
    confidence[is_nodata] = 0.0
    return confidence
|
|
|
|
|
|
# ==========================================
|
|
# Model Class Utilities
|
|
# ==========================================
|
|
|
|
def get_model_classes(model) -> Optional[list]:
    """Return a model's ``classes_`` attribute as a Python list, if present.

    Args:
        model: Any object; sklearn-compatible classifiers expose
            ``classes_`` after fitting.

    Returns:
        ``model.classes_`` converted via ``tolist()`` (numpy arrays and
        other array-likes) or ``list()`` (lists/tuples). Elements are
        returned as stored, so labels may be ints as well as strings.
        ``None`` if the attribute is missing, is ``None``, or has neither
        form. Note that ``tolist()`` on a 0-d array yields a scalar.
    """
    classes = getattr(model, 'classes_', None)
    if classes is None:
        return None
    if hasattr(classes, 'tolist'):
        return classes.tolist()
    if isinstance(classes, (list, tuple)):
        return list(classes)
    return None
|
|
|
|
|
|
# ==========================================
|
|
# Self-Test
|
|
# ==========================================
|
|
|
|
if __name__ == "__main__":
    import sys

    print("=== PostProcess Module Self-Test ===")

    # Labels of failed checks; a non-empty list makes the script exit with
    # status 1 so the self-test can gate automation, not just print.
    failed = []

    def check(ok, label):
        """Print a PASSED/FAILED line for *label* and record failures."""
        if ok:
            print(f" ✓ {label} test PASSED")
        else:
            print(f" ✗ {label} test FAILED")
            failed.append(label)

    # Create synthetic test raster
    print("\n1. Creating synthetic test raster...")

    H, W = 20, 20
    np.random.seed(42)

    # Create raster with classes 1..7 and nodata (0) holes
    cls = np.random.randint(1, 8, size=(H, W)).astype(np.int32)
    cls[3:6, 3:6] = 0  # nodata region
    cls[15:18, 15:18] = 0  # another nodata region
    nodata_in = cls == 0

    print(f" Input shape: {cls.shape}")
    print(f" Input unique values: {sorted(np.unique(cls))}")
    print(f" Nodata count: {np.sum(nodata_in)}")

    # Test majority filter for two kernel sizes
    for step, k in ((2, 3), (3, 5)):
        print(f"\n{step}. Testing majority_filter (kernel={k})...")
        filtered = majority_filter(cls, kernel=k, nodata=0)
        changed = np.sum((filtered != cls) & ~nodata_in)
        # Compare nodata *positions*, not just counts: equal counts could
        # hide nodata moving to different pixels.
        nodata_preserved = np.array_equal(filtered == 0, nodata_in)

        print(f" Output unique values: {sorted(np.unique(filtered))}")
        print(f" Changed pixels (excl nodata): {changed}")
        print(f" Nodata preserved: {nodata_preserved}")
        check(nodata_preserved, "Nodata preservation")

        # Cross-check against the pure-Python reference implementation.
        check(
            np.array_equal(filtered, _majority_filter_slow(cls, k, 0)),
            f"Reference agreement (kernel={k})",
        )

    # Test class remapping
    print("\n4. Testing remap_classes...")
    mapping = {1: 10, 2: 20, 3: 30}
    remapped = remap_classes(cls, mapping, nodata=0)

    # Independently build the expected result from the original values.
    expected = cls.copy()
    for old_val, new_val in mapping.items():
        expected[(cls == old_val) & ~nodata_in] = new_val

    mapped_count = np.sum(np.isin(cls, list(mapping)) & ~nodata_in)
    unchanged = np.sum(remapped == cls)
    print(f" Mapped pixels: {mapped_count}")
    print(f" Unchanged pixels: {unchanged}")
    check(np.array_equal(remapped, expected), "remap_classes")

    # Test confidence from proba
    print("\n5. Testing compute_confidence_from_proba...")
    proba = np.random.rand(H, W).astype(np.float32)
    confidence = compute_confidence_from_proba(proba, nodata_in)

    nodata_conf_zero = bool(np.all(confidence[nodata_in] == 0))
    valid_conf_kept = np.array_equal(confidence[~nodata_in], proba[~nodata_in])

    print(f" Nodata pixels have 0 confidence: {nodata_conf_zero}")
    print(f" Valid pixels keep their probability: {valid_conf_kept}")
    check(nodata_conf_zero and valid_conf_kept, "compute_confidence_from_proba")

    # Test kernel validation
    print("\n6. Testing kernel validation...")
    try:
        for valid_k in (3, 5, 7):
            validate_kernel(valid_k)
    except ValueError:
        print(" ✗ Valid kernels rejected")
        failed.append("valid kernels")
    else:
        print(" Valid kernels (3,5,7) accepted: ✓")

    try:
        validate_kernel(4)
    except ValueError:
        print(" Invalid kernel (4) rejected: ✓")
    else:
        print(" ✗ Invalid kernel accepted (should have failed)")
        failed.append("invalid kernel")

    print("\n=== PostProcess Module Test Complete ===")
    if failed:
        print(f"FAILED checks: {', '.join(failed)}")
        sys.exit(1)
|