diff --git a/apps/worker/hybrid_inference.py b/apps/worker/hybrid_inference.py
index 7eec9a2..312687d 100644
--- a/apps/worker/hybrid_inference.py
+++ b/apps/worker/hybrid_inference.py
@@ -207,14 +207,56 @@ class CropInferencePipeline:
 
     def _impute_inference_data(self, df):
         print("Imputing cloudy/missing timesteps via temporal interpolation...")
+        from feature_computation import handle_temporal_gaps, spatial_fill_nan
+
         df = df.copy()
         missing_mask = {}
+
+        # Track original NaNs before any imputation
         for band in self.bands:
             band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
             if band_cols:
                 missing_mask[band] = df[band_cols].isna().astype(float)
+
+        # Process each band: apply handle_temporal_gaps per pixel for each band
+        for band in self.bands:
+            band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
+            if band_cols:
+                print(f" Processing band {band} with gap handling...")
+
+                # Pull the band out once as a float64 array and write it back once.
+                # Per-row df.loc assignment is slow and, being label-based, touches
+                # several rows at once when df.index contains duplicate labels.
+                series = df[band_cols].to_numpy(dtype=np.float64, copy=True)
+                for row in range(series.shape[0]):
+                    # Apply handle_temporal_gaps: gaps >= 3 will result in NaNs for those timesteps
+                    series[row] = handle_temporal_gaps(series[row].copy(), gap_threshold=3)
+                df[band_cols] = series
+
+                # After gap handling, fill remaining NaNs with linear interpolation
                 df[band_cols] = df[band_cols].interpolate(method='linear', axis=1, limit_direction='both')
                 df[band_cols] = df[band_cols].ffill(axis=1).bfill(axis=1).fillna(0)
+
+        # Apply spatial fill to each band using spatial_fill_nan
+        # NOTE(review): the temporal pass above already ends with fillna(0), so no
+        # NaN reaches this loop; if spatial_fill_nan only rewrites NaN cells this
+        # pass changes nothing. Confirm the intended ordering.
+        for band in self.bands:
+            band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
+            if band_cols:
+                print(f" Applying spatial fill for band {band}...")
+
+                # Transposed copy, shape (num_dates, num_pixels). An explicit copy is
+                # needed: .values can hand back a (read-only) view of the frame's data.
+                band_data = df[band_cols].to_numpy(dtype=np.float64, copy=True).T
+
+                # Apply spatial_fill_nan per time step
+                for t_idx in range(band_data.shape[0]):
+                    band_data[t_idx] = spatial_fill_nan(band_data[t_idx].reshape(-1, 1)).reshape(-1)
+
+                # Put back into dataframe
+                df[band_cols] = band_data.T
+
         return df, missing_mask
 
     def predict(self, raw_df, apply_spatial_smoothing=False, coord_cols=['lat', 'lon']):
@@ -240,6 +282,22 @@ class CropInferencePipeline:
         final_probs = (fcn_probs * self.w_fcn) + (cb_probs * self.w_cb)
         final_preds = np.argmax(final_probs, axis=1)
 
+        # Identify No Data pixels: those with all NaNs or zeros after imputation
+        no_data_mask = np.zeros(len(df), dtype=bool)
+        for band in self.bands:
+            band_cols = [f"{date}_{band}" for date in self.dates if f"{date}_{band}" in df.columns]
+            if band_cols:
+                band_data = df[band_cols].values
+                # Check if pixel is all zeros or all NaN for this band
+                all_zeros = np.all(band_data == 0, axis=1)
+                all_nan = np.all(np.isnan(band_data), axis=1)
+                no_data_mask = no_data_mask | all_zeros | all_nan
+
+        # Override predictions for No Data pixels to class 0 (Background/No Data)
+        final_preds[no_data_mask] = 0
+        final_probs[no_data_mask] = 0.0
+        final_probs[no_data_mask, 0] = 1.0  # Set probability to 1.0 for class 0
+
         if apply_spatial_smoothing and all(col in df.columns for col in coord_cols):
             print(f"Applying spatial probability smoothing using {coord_cols}...")
             coords = df[coord_cols].values
@@ -249,15 +307,22 @@ class CropInferencePipeline:
             final_preds = np.argmax(smoothed_probs, axis=1)
             final_probs = smoothed_probs
 
+            # Re-apply No Data override after smoothing; zero the row first, like the
+            # pre-smoothing override, so only class 0 carries probability.
+            final_preds[no_data_mask] = 0
+            final_probs[no_data_mask] = 0.0
+            final_probs[no_data_mask, 0] = 1.0
+
         df['class_id'] = final_preds
         df['predicted_crop'] = self.le.inverse_transform(final_preds)
         df['confidence'] = np.max(final_probs, axis=1)
 
+        # Track missing data ratio for quality flag
         missing_ratio = np.mean([m.mean(axis=1) for m in missing_mask.values()], axis=0)
         df['high_missing'] = missing_ratio > 0.4
-        df['low_quality'] = (df['confidence'] < 0.5) | df['high_missing']
+        df['low_quality'] = (df['confidence'] < 0.5) | df['high_missing'] | no_data_mask
 
-        # Set NoData (0) for low quality
+        # Set NoData (0) for low quality pixels
         df.loc[df['low_quality'], 'class_id'] = 0
         df.loc[df['low_quality'], 'predicted_crop'] = 'Unknown/NoData'
         return df