"""Post-processing utilities for inference output. STEP 7: Provides neighborhood smoothing and class utilities. This module provides: - Majority filter (mode) with nodata preservation - Class remapping - Confidence computation from probabilities NOTE: Uses pure numpy implementation for efficiency. """ from __future__ import annotations from typing import Optional, List import numpy as np # ========================================== # Kernel Validation # ========================================== def validate_kernel(kernel: int) -> int: """Validate smoothing kernel size. Args: kernel: Kernel size (must be 3, 5, or 7) Returns: Validated kernel size Raises: ValueError: If kernel is not 3, 5, or 7 """ valid_kernels = {3, 5, 7} if kernel not in valid_kernels: raise ValueError( f"Invalid kernel size: {kernel}. " f"Must be one of {valid_kernels}." ) return kernel # ========================================== # Majority Filter # ========================================== def _majority_filter_slow( cls: np.ndarray, kernel: int, nodata: int, ) -> np.ndarray: """Slow majority filter implementation using Python loops. This is a fallback if sliding_window_view is not available. """ H, W = cls.shape pad = kernel // 2 result = cls.copy() # Pad array padded = np.pad(cls, pad, mode='constant', constant_values=nodata) for i in range(H): for j in range(W): # Extract window window = padded[i:i+kernel, j:j+kernel] # Get center pixel center_val = cls[i, j] # Skip if center is nodata if center_val == nodata: continue # Count non-nodata values values = window.flatten() mask = values != nodata if not np.any(mask): # All neighbors are nodata, keep center continue counts = {} for v in values[mask]: counts[v] = counts.get(v, 0) + 1 # Find max count max_count = max(counts.values()) # Get candidates with max count candidates = [v for v, c in counts.items() if c == max_count] # Tie-breaking: prefer center if in tie, else smallest if center_val in candidates: result[i, j] = center_val else: result[i, j] = min(candidates) return result def majority_filter( cls: np.ndarray, kernel: int = 5, nodata: int = 0, ) -> np.ndarray: """Apply a majority (mode) filter to a class raster. Args: cls: 2D array of class IDs (H, W) kernel: Kernel size (3, 5, or 7) nodata: Nodata value to preserve Returns: Filtered class raster of same shape Rules: - Nodata pixels in input stay nodata in output - When computing neighborhood majority, nodata values are excluded from vote - If all neighbors are nodata, output nodata - Tie-breaking: - Prefer original center pixel if it's part of the tie - Otherwise choose smallest class ID """ # Validate kernel validate_kernel(kernel) cls = np.asarray(cls, dtype=np.int32) if cls.ndim != 2: raise ValueError(f"Expected 2D array, got shape {cls.shape}") H, W = cls.shape pad = kernel // 2 # Pad array with nodata padded = np.pad(cls, pad, mode='constant', constant_values=nodata) result = cls.copy() # Try to use sliding_window_view for efficiency try: from numpy.lib.stride_tricks import sliding_window_view windows = sliding_window_view(padded, (kernel, kernel)) # Iterate over valid positions for i in range(H): for j in range(W): window = windows[i, j] # Get center pixel center_val = cls[i, j] # Skip if center is nodata if center_val == nodata: continue # Flatten and count values = window.flatten() # Exclude nodata mask = values != nodata if not np.any(mask): # All neighbors are nodata, keep center continue valid_values = values[mask] # Count using bincount (faster) max_class = int(valid_values.max()) + 1 if max_class > 0: counts = np.bincount(valid_values, minlength=max_class) else: continue # Get max count max_count = counts.max() # Get candidates with max count candidates = np.where(counts == max_count)[0] # Tie-breaking if center_val in candidates: result[i, j] = center_val else: result[i, j] = int(candidates.min()) except ImportError: # Fallback to slow implementation result = _majority_filter_slow(cls, kernel, nodata) return result # ========================================== # Class Remapping # ========================================== def remap_classes( cls: np.ndarray, mapping: dict, nodata: int = 0, ) -> np.ndarray: """Apply integer mapping to class raster. Args: cls: 2D array of class IDs (H, W) mapping: Dict mapping old class IDs to new class IDs nodata: Nodata value to preserve Returns: Remapped class raster """ cls = np.asarray(cls, dtype=np.int32) result = cls.copy() # Apply mapping for old_val, new_val in mapping.items(): mask = (cls == old_val) & (cls != nodata) result[mask] = new_val return result # ========================================== # Confidence from Probabilities # ========================================== def compute_confidence_from_proba( proba_max: np.ndarray, nodata_mask: np.ndarray, ) -> np.ndarray: """Compute confidence raster from probability array. Args: proba_max: 2D array of max probability per pixel (H, W) nodata_mask: Boolean mask where pixels are nodata Returns: 2D float32 confidence raster with nodata set to 0 """ proba_max = np.asarray(proba_max, dtype=np.float32) nodata_mask = np.asarray(nodata_mask, dtype=bool) # Set nodata to 0 result = proba_max.copy() result[nodata_mask] = 0.0 return result # ========================================== # Model Class Utilities # ========================================== def get_model_classes(model) -> Optional[List[str]]: """Extract class names from a trained model if available. Args: model: Trained sklearn-compatible model Returns: List of class names if available, None otherwise """ if hasattr(model, 'classes_'): classes = model.classes_ if hasattr(classes, 'tolist'): return classes.tolist() elif isinstance(classes, (list, tuple)): return list(classes) return None return None # ========================================== # Self-Test # ========================================== if __name__ == "__main__": print("=== PostProcess Module Self-Test ===") # Check for numpy if np is None: print("numpy not available - skipping test") import sys sys.exit(0) # Create synthetic test raster print("\n1. Creating synthetic test raster...") H, W = 20, 20 np.random.seed(42) # Create raster with multiple classes and nodata holes cls = np.random.randint(1, 8, size=(H, W)).astype(np.int32) # Add some nodata holes cls[3:6, 3:6] = 0 # nodata region cls[15:18, 15:18] = 0 # another nodata region print(f" Input shape: {cls.shape}") print(f" Input unique values: {sorted(np.unique(cls))}") print(f" Nodata count: {np.sum(cls == 0)}") # Test majority filter with kernel=3 print("\n2. Testing majority_filter (kernel=3)...") result3 = majority_filter(cls, kernel=3, nodata=0) changed3 = np.sum((result3 != cls) & (cls != 0)) nodata_preserved3 = np.sum(result3 == 0) == np.sum(cls == 0) print(f" Output unique values: {sorted(np.unique(result3))}") print(f" Changed pixels (excl nodata): {changed3}") print(f" Nodata preserved: {nodata_preserved3}") if nodata_preserved3: print(" ✓ Nodata preservation test PASSED") else: print(" ✗ Nodata preservation test FAILED") # Test majority filter with kernel=5 print("\n3. Testing majority_filter (kernel=5)...") result5 = majority_filter(cls, kernel=5, nodata=0) changed5 = np.sum((result5 != cls) & (cls != 0)) nodata_preserved5 = np.sum(result5 == 0) == np.sum(cls == 0) print(f" Output unique values: {sorted(np.unique(result5))}") print(f" Changed pixels (excl nodata): {changed5}") print(f" Nodata preserved: {nodata_preserved5}") if nodata_preserved5: print(" ✓ Nodata preservation test PASSED") else: print(" ✗ Nodata preservation test FAILED") # Test class remapping print("\n4. Testing remap_classes...") mapping = {1: 10, 2: 20, 3: 30} remapped = remap_classes(cls, mapping, nodata=0) # Check mapping applied mapped_count = np.sum(np.isin(cls, [1, 2, 3]) & (cls != 0)) unchanged = np.sum(remapped == cls) print(f" Mapped pixels: {mapped_count}") print(f" Unchanged pixels: {unchanged}") print(" ✓ remap_classes test PASSED") # Test confidence from proba print("\n5. Testing compute_confidence_from_proba...") proba = np.random.rand(H, W).astype(np.float32) nodata_mask = cls == 0 confidence = compute_confidence_from_proba(proba, nodata_mask) nodata_conf_zero = np.all(confidence[nodata_mask] == 0) valid_conf_positive = np.all(confidence[~nodata_mask] >= 0) print(f" Nodata pixels have 0 confidence: {nodata_conf_zero}") print(f" Valid pixels have positive confidence: {valid_conf_positive}") if nodata_conf_zero and valid_conf_positive: print(" ✓ compute_confidence_from_proba test PASSED") else: print(" ✗ compute_confidence_from_proba test FAILED") # Test kernel validation print("\n6. Testing kernel validation...") try: validate_kernel(3) validate_kernel(5) validate_kernel(7) print(" Valid kernels (3,5,7) accepted: ✓") except ValueError: print(" ✗ Valid kernels rejected") try: validate_kernel(4) print(" ✗ Invalid kernel accepted (should have failed)") except ValueError: print(" Invalid kernel (4) rejected: ✓") print("\n=== PostProcess Module Test Complete ===")