409 lines
11 KiB
Python
409 lines
11 KiB
Python
"""GeoTIFF and COG output utilities.

STEP 8: Provides functions to write GeoTIFFs and convert them to Cloud
Optimized GeoTIFFs (COGs).

This module provides:

- Profile normalization for output
- GeoTIFF writing with compression
- COG conversion with overviews
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import subprocess
|
|
import tempfile
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional, Union
|
|
|
|
import numpy as np
|
|
|
|
|
|
# ==========================================
|
|
# Profile Normalization
|
|
# ==========================================
|
|
|
|
def normalize_profile_for_output(
    profile: dict,
    dtype: str,
    nodata,
    count: int = 1,
) -> dict:
    """Normalize a rasterio profile for tiled, DEFLATE-compressed GeoTIFF output.

    The input profile is copied, never mutated. Driver, dtype, nodata and
    band count are overridden. Tiling, compression, predictor and BigTIFF
    creation options are set.

    Args:
        profile: Input rasterio profile (e.g., from DW baseline window).
            Its ``width`` and ``height`` select the tile size; missing
            values count as 0.
        dtype: Output data type (e.g., 'uint8', 'uint16', 'float32'). A
            numpy dtype object or scalar type is also accepted.
        nodata: Nodata value, stored in the profile as-is.
        count: Number of bands.

    Returns:
        A new, normalized profile dictionary.
    """
    out_profile = dict(profile)

    # Output-specific values.
    out_profile["driver"] = "GTiff"
    out_profile["dtype"] = dtype
    out_profile["nodata"] = nodata
    out_profile["count"] = count

    # Tiling: rasters under 1M pixels get 256x256 tiles, larger ones 512x512.
    width = profile.get("width") or 0
    height = profile.get("height") or 0
    block_size = 256 if width * height < 1024 * 1024 else 512
    out_profile["tiled"] = True
    out_profile["blockxsize"] = block_size
    out_profile["blockysize"] = block_size

    out_profile["compress"] = "DEFLATE"

    # Pick the predictor from the *output* dtype. First drop any predictor
    # inherited from the input profile. A PREDICTOR=3 carried over from a
    # float source would otherwise be kept for an integer output, which
    # libtiff refuses to write.
    for key in ("predictor", "PREDICTOR"):
        out_profile.pop(key, None)

    try:
        dtype_name = np.dtype(dtype).name if dtype is not None else ""
    except TypeError:
        # Not a numpy dtype (e.g. a GDAL name); no predictor is chosen.
        dtype_name = str(dtype)

    if dtype_name in (
        "uint8", "int8", "uint16", "int16",
        "uint32", "int32", "uint64", "int64",
    ):
        out_profile["predictor"] = 2  # Horizontal differencing
    elif dtype_name in ("float32", "float64"):
        out_profile["predictor"] = 3  # Floating point prediction

    # Switch to BigTIFF when GDAL estimates the file might exceed 4 GB.
    out_profile["BIGTIFF"] = "IF_SAFER"

    return out_profile
|
|
|
|
|
|
# ==========================================
|
|
# GeoTIFF Writing
|
|
# ==========================================
|
|
|
|
def write_geotiff(
    out_path: str,
    arr: np.ndarray,
    profile: dict,
) -> str:
    """Write an array to a GeoTIFF with rasterio.

    The profile is copied. Its ``count`` and ``dtype`` are replaced by the
    band count and dtype of *arr*; the array is written without casting.

    Args:
        out_path: Output file path.
        arr: 2D (H, W) or 3D (count, H, W) array. Anything ``np.asarray``
            accepts works.
        profile: Rasterio profile. Its ``height``/``width`` must match the
            array's spatial shape.

    Returns:
        ``out_path``.

    Raises:
        ValueError: If *arr* is not 2D/3D or its shape does not match the
            profile.
        ImportError: If rasterio is not installed.
    """
    arr = np.asarray(arr)

    # Promote a single band to (1, H, W) so there is a single write path.
    if arr.ndim == 2:
        arr = arr[np.newaxis, :, :]
    elif arr.ndim != 3:
        raise ValueError(f"Expected 2D or 3D array, got {arr.ndim}D")
    count = arr.shape[0]

    if arr.shape[1] != profile.get("height") or arr.shape[2] != profile.get("width"):
        raise ValueError(
            f"Array shape {arr.shape[1:]} doesn't match profile dimensions "
            f"({profile.get('height')}, {profile.get('width')})"
        )

    out_profile = dict(profile)
    out_profile["count"] = count
    out_profile["dtype"] = str(arr.dtype)

    # The profile's predictor may have been chosen for a dtype other than the
    # one actually written here. libtiff rejects PREDICTOR=3 on integer
    # samples, so fall back to horizontal differencing for integers and drop
    # the predictor for non-numeric dtypes.
    predictor = out_profile.get("predictor")
    if predictor is not None:
        if arr.dtype.kind in "iu" and predictor in (3, "3"):
            out_profile["predictor"] = 2
        elif arr.dtype.kind not in "iuf":
            del out_profile["predictor"]

    # Imported only after validation, so bad arguments are reported even
    # without rasterio installed.
    try:
        import rasterio
    except ImportError as err:
        raise ImportError("rasterio is required for GeoTIFF writing") from err

    with rasterio.open(out_path, "w", **out_profile) as dst:
        dst.write(arr)

    return out_path
|
|
|
|
|
|
# ==========================================
|
|
# COG Conversion
|
|
# ==========================================
|
|
|
|
# Creation options shared by the rasterio and gdal_translate paths, so both
# produce the same COG layout.
_COG_CREATION_OPTIONS = {
    "BLOCKSIZE": "512",
    "COMPRESS": "DEFLATE",
    # AUTO lets the COG driver build internal overviews during the copy, where
    # the COG layout requires them. Running gdaladdo on a finished COG instead
    # appends overviews at the end of the file and breaks that layout.
    "OVERVIEWS": "AUTO",
    # Same resampling as the previous `gdaladdo -r average` step.
    # NOTE(review): AVERAGE blends class codes in categorical rasters;
    # NEAREST or MODE may be more appropriate there. Confirm per product.
    "RESAMPLING": "AVERAGE",
}

# numpy-style dtype names -> GDAL data type names used by gdal_translate -ot.
_GDAL_TYPE_NAMES = {
    "uint8": "Byte",
    "int8": "Int8",
    "uint16": "UInt16",
    "int16": "Int16",
    "uint32": "UInt32",
    "int32": "Int32",
    "uint64": "UInt64",
    "int64": "Int64",
    "float32": "Float32",
    "float64": "Float64",
    "complex64": "CFloat32",
    "complex128": "CFloat64",
}


def _gdal_type_name(dtype) -> Optional[str]:
    """Return the GDAL type name for *dtype*, or None if no dtype was given.

    Accepts numpy dtype names and objects. Any other string, such as an
    explicit GDAL name like "Byte", is passed through unchanged.
    """
    if dtype is None or (isinstance(dtype, str) and not dtype):
        return None
    name = dtype if isinstance(dtype, str) else np.dtype(dtype).name
    return _GDAL_TYPE_NAMES.get(name.lower(), name)


def _same_nodata(a, b) -> bool:
    """Compare two nodata values, treating NaN as equal to NaN."""
    if a is None or b is None:
        return a is None and b is None
    return bool(a == b or (a != a and b != b))  # NaN is the only value != itself


def _copy_with_rasterio(src_path, dst_path, dtype, nodata) -> Optional[str]:
    """Try to write the COG with rasterio's COG driver.

    Returns None on success, otherwise a short reason for the failure.
    ``rasterio.shutil.copy`` forwards extra keywords as GDAL creation options.
    It therefore cannot change the data type or nodata value, so this path is
    skipped unless the source already matches the requested *dtype*/*nodata*.
    """
    try:
        import rasterio
        from rasterio import shutil as rio_shutil
    except ImportError as err:
        return f"rasterio not available: {err}"

    wanted_type = _gdal_type_name(dtype)
    try:
        with rasterio.open(src_path) as src:
            if wanted_type is not None and any(
                _gdal_type_name(band_dtype).lower() != wanted_type.lower()
                for band_dtype in src.dtypes
            ):
                return f"source dtype {src.dtypes[0]} differs from requested {dtype}"
            if nodata is not None and not _same_nodata(src.nodata, nodata):
                return f"source nodata {src.nodata} differs from requested {nodata}"
            rio_shutil.copy(src, dst_path, driver="COG", **_COG_CREATION_OPTIONS)
    except Exception as err:  # any rasterio/GDAL failure -> try gdal_translate
        return str(err)
    return None


def _translate_with_gdal(src_path, dst_path, dtype, nodata, rasterio_reason) -> str:
    """Write the COG by running the gdal_translate command-line tool."""
    cmd = ["gdal_translate", "-of", "COG"]
    for key, value in _COG_CREATION_OPTIONS.items():
        cmd.extend(["-co", f"{key}={value}"])

    gdal_type = _gdal_type_name(dtype)
    if gdal_type is not None:
        cmd.extend(["-ot", gdal_type])
    if nodata is not None:
        cmd.extend(["-a_nodata", str(nodata)])
    cmd.extend([str(src_path), str(dst_path)])

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError as err:
        raise RuntimeError(
            f"Cannot convert to COG: rasterio failed ({rasterio_reason}) and "
            "gdal_translate not available. "
            "Please install GDAL or ensure rasterio has COG support."
        ) from err

    if result.returncode != 0:
        raise RuntimeError(
            f"gdal_translate failed: {result.stderr}"
        )
    return dst_path


def translate_to_cog(
    src_path: str,
    dst_path: str,
    dtype: Optional[str] = None,
    nodata=None,
) -> str:
    """Convert a GeoTIFF to a Cloud Optimized GeoTIFF.

    rasterio's COG driver is used when the source already has the requested
    dtype/nodata. Otherwise, or if rasterio fails, the ``gdal_translate`` CLI
    is used, which applies the overrides via ``-ot``/``-a_nodata``. Both
    paths use the same creation options, including overviews generated
    internally by the COG driver.

    Args:
        src_path: Source GeoTIFF path.
        dst_path: Destination COG path.
        dtype: Optional output dtype override (numpy name/object or GDAL name).
        nodata: Optional nodata value override.

    Returns:
        Destination path.

    Raises:
        RuntimeError: If neither rasterio nor gdal_translate can produce the COG.
    """
    rasterio_reason = _copy_with_rasterio(src_path, dst_path, dtype, nodata)
    if rasterio_reason is None:
        return dst_path
    return _translate_with_gdal(src_path, dst_path, dtype, nodata, rasterio_reason)
|
|
|
|
|
|
def translate_to_cog_with_retry(
    src_path: str,
    dst_path: str,
    dtype: Optional[str] = None,
    nodata=None,
    max_retries: int = 3,
) -> str:
    """Convert a GeoTIFF to COG, retrying failures with exponential backoff.

    Calls ``translate_to_cog`` up to ``max_retries`` times, and always at
    least once. Between attempts it sleeps 1 s, 2 s, 4 s, and so on.

    Args:
        src_path: Source GeoTIFF path.
        dst_path: Destination COG path.
        dtype: Optional output dtype override.
        nodata: Optional nodata value override.
        max_retries: Maximum number of attempts; values below 1 are treated
            as 1.

    Returns:
        Destination path.

    Raises:
        RuntimeError: If every attempt fails. The last underlying error is
            chained as ``__cause__``.
    """
    attempts = max(1, max_retries)
    last_error = None

    for attempt in range(attempts):
        try:
            return translate_to_cog(src_path, dst_path, dtype, nodata)
        except Exception as e:
            last_error = e
            if attempt < attempts - 1:
                time.sleep(2 ** attempt)  # Exponential backoff

    raise RuntimeError(
        f"Failed to convert to COG after {attempts} retries. "
        f"Last error: {last_error}"
    ) from last_error
|
|
|
|
|
|
# ==========================================
|
|
# Convenience Wrapper
|
|
# ==========================================
|
|
|
|
def write_cog(
    dst_path: str,
    arr: np.ndarray,
    base_profile: dict,
    dtype: str,
    nodata,
) -> str:
    """Write an array as a Cloud Optimized GeoTIFF.

    Convenience wrapper that:
      1. writes *arr* to a temporary GeoTIFF,
      2. converts it to a COG at *dst_path*, with retries, applying *dtype*
         and *nodata*,
      3. removes the temporary file, even if a step fails.

    Args:
        dst_path: Destination COG path.
        arr: 2D (H, W) or 3D (count, H, W) array. Anything ``np.asarray``
            accepts works.
        base_profile: Base rasterio profile.
        dtype: Output data type of the COG.
        nodata: Nodata value.

    Returns:
        Destination COG path.
    """
    arr = np.asarray(arr)

    # write_geotiff stores the intermediate in the array's own dtype; the cast
    # to *dtype* happens during COG conversion. Normalizing with the array's
    # dtype keeps the compression predictor consistent with the data actually
    # written (e.g. no float predictor on integer data).
    profile = normalize_profile_for_output(
        base_profile,
        dtype=str(arr.dtype),
        nodata=nodata,
        count=arr.shape[0] if arr.ndim == 3 else 1,
    )

    # Reserve a unique temp path for the intermediate GeoTIFF.
    fd, tmp_path = tempfile.mkstemp(suffix=".tif")
    os.close(fd)

    try:
        write_geotiff(tmp_path, arr, profile)
        translate_to_cog_with_retry(tmp_path, dst_path, dtype=dtype, nodata=nodata)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return dst_path
|
|
|
|
|
|
# ==========================================
|
|
# Self-Test
|
|
# ==========================================
|
|
|
|
if __name__ == "__main__":
    # Smoke test: normalize a profile, write a GeoTIFF and a COG, and read
    # each back, checking the data round-trips. All files are written to a
    # throwaway temporary directory that is removed afterwards.
    import sys

    print("=== COG Module Self-Test ===")

    try:
        import rasterio
        from rasterio.transform import from_origin
    except ImportError:
        print("rasterio not available - skipping test")
        sys.exit(0)

    print("\n1. Testing normalize_profile_for_output...")

    # Minimal 128x128 single-band profile. rasterio expects an Affine
    # transform, not a GDAL-ordered 6-element list; this is the same
    # geotransform as GDAL's [0, 1, 0, 0, 0, -1].
    base_profile = {
        "driver": "GTiff",
        "height": 128,
        "width": 128,
        "count": 1,
        "crs": "EPSG:4326",
        "transform": from_origin(0.0, 0.0, 1.0, 1.0),
    }

    out_profile = normalize_profile_for_output(
        base_profile,
        dtype="uint8",
        nodata=0,
    )

    print(f" Driver: {out_profile.get('driver')}")
    print(f" Dtype: {out_profile.get('dtype')}")
    print(f" Tiled: {out_profile.get('tiled')}")
    print(f" Block size: {out_profile.get('blockxsize')}x{out_profile.get('blockysize')}")
    print(f" Compress: {out_profile.get('compress')}")
    assert out_profile["driver"] == "GTiff"
    assert out_profile["dtype"] == "uint8"
    assert out_profile["tiled"] is True
    assert out_profile["blockxsize"] == out_profile["blockysize"] == 256
    assert out_profile["compress"] == "DEFLATE"
    print(" ✓ normalize_profile test PASSED")

    # Deterministic synthetic data with a square of nodata holes.
    arr = np.random.default_rng(0).integers(0, 256, size=(128, 128), dtype=np.uint8)
    arr[10:20, 10:20] = 0

    with tempfile.TemporaryDirectory() as tmp_dir:
        print("\n2. Testing write_geotiff...")

        out_path = os.path.join(tmp_dir, "test_output.tif")
        write_geotiff(out_path, arr, out_profile)

        print(f" Written to: {out_path}")
        print(f" File size: {os.path.getsize(out_path)} bytes")

        with rasterio.open(out_path) as src:
            read_arr = src.read(1)
        print(f" Read back shape: {read_arr.shape}")
        assert np.array_equal(read_arr, arr), "GeoTIFF round-trip changed the data"
        print(" ✓ write_geotiff test PASSED")

        print("\n3. Testing write_cog...")

        cog_path = os.path.join(tmp_dir, "test_cog.tif")
        write_cog(cog_path, arr, base_profile, dtype="uint8", nodata=0)

        print(f" Written to: {cog_path}")
        print(f" File size: {os.path.getsize(cog_path)} bytes")

        with rasterio.open(cog_path) as src:
            read_arr = src.read(1)
            print(f" Read back shape: {read_arr.shape}")
            print(f" Profile: driver={src.driver}, count={src.count}")
            assert src.count == 1
        assert np.array_equal(read_arr, arr), "COG round-trip changed the data"
        print(" ✓ write_cog test PASSED")

    print("\n=== COG Module Test Complete ===")
|