# geocrop-platform/training/storage_client.py
import io
import logging
import os
from typing import Optional

import boto3
import pandas as pd
from botocore.exceptions import ClientError
# Configure logging for the worker/training scripts.
# NOTE(review): basicConfig at import time configures the root logger for any
# process importing this module — fine for job/worker scripts, but confirm this
# module is never imported by a service that manages its own logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class MinIOStorageClient:
    """
    A reusable client for interacting with the local MinIO storage.

    Handles Datasets (CSVs), Baselines (DW TIFFs), and Model Artifacts.
    Credentials and endpoint come from the standard AWS environment
    variables (``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``,
    ``AWS_S3_ENDPOINT_URL``).
    """

    def __init__(self):
        # Initialize S3 client using environment variables.
        # Defaults to the internal Kubernetes DNS for MinIO if not provided.
        self.endpoint_url = os.environ.get(
            'AWS_S3_ENDPOINT_URL',
            'http://minio.geocrop.svc.cluster.local:9000',
        )
        self.s3_client = boto3.client(
            's3',
            endpoint_url=self.endpoint_url,
            aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
            aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
            # MinIO requires SigV4 request signing.
            config=boto3.session.Config(signature_version='s3v4'),
            verify=False  # MinIO often uses self-signed certs internally
        )

    def list_files(self, bucket_name: str, prefix: str = "") -> list:
        """Lists files in a specific bucket, optionally filtered by a prefix folder.

        Uses the list_objects_v2 paginator so buckets holding more than
        1000 objects are fully enumerated (a single list_objects_v2 call
        truncates at 1000 keys).

        Returns a list of ``{"key": str, "size_mb": float}`` dicts, or an
        empty list when the bucket is empty or the request fails.
        """
        try:
            logger.info(f"Scanning bucket '{bucket_name}' with prefix '{prefix}'...")
            files = []
            paginator = self.s3_client.get_paginator('list_objects_v2')
            for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
                # 'Contents' is absent on empty result pages.
                for obj in page.get('Contents', []):
                    size_mb = obj['Size'] / 1024 / 1024
                    files.append({"key": obj['Key'], "size_mb": size_mb})
                    logger.info(f" - Found: {obj['Key']} ({size_mb:.2f} MB)")
            if not files:
                logger.warning(f"No files found in bucket {bucket_name}.")
            return files
        except ClientError as e:
            logger.error(f"Failed to list files in {bucket_name}: {e}")
            return []

    def download_file(self, bucket_name: str, object_name: str, download_path: str) -> bool:
        """Downloads a file from MinIO to the local pod/container storage.

        Returns True on success, False on a client/transfer error.
        """
        try:
            logger.info(f"Downloading {object_name} from {bucket_name} to {download_path}...")
            self.s3_client.download_file(bucket_name, object_name, download_path)
            logger.info("Download complete.")
            return True
        except ClientError as e:
            logger.error(f"Error downloading {object_name}: {e}")
            return False

    def upload_file(self, file_path: str, bucket_name: str, object_name: str) -> bool:
        """Uploads a local file (like a trained model or prediction COG) back to MinIO.

        Returns True on success, False on a client/transfer error.
        """
        try:
            logger.info(f"Uploading {file_path} to {bucket_name}/{object_name}...")
            self.s3_client.upload_file(file_path, bucket_name, object_name)
            logger.info("Upload complete.")
            return True
        except ClientError as e:
            logger.error(f"Error uploading {file_path}: {e}")
            return False

    def load_dataset(self, bucket_name: str, object_name: str) -> Optional[pd.DataFrame]:
        """Loads a CSV dataset directly from MinIO into a Pandas DataFrame in memory.

        Returns the parsed DataFrame, or None when the object cannot be
        fetched or cannot be parsed as CSV (callers must check for None).
        """
        try:
            logger.info(f"Loading {object_name} from {bucket_name} into memory...")
            response = self.s3_client.get_object(Bucket=bucket_name, Key=object_name)
            df = pd.read_csv(io.BytesIO(response['Body'].read()))
            logger.info(f"Successfully loaded dataset with shape: {df.shape}")
            return df
        except ClientError as e:
            logger.error(f"Error loading {object_name}: {e}")
            return None
        except Exception as e:
            # Broad catch kept deliberately: pandas raises several error
            # types on malformed CSVs and callers expect None, not a crash.
            logger.error(f"Error parsing {object_name} into DataFrame: {e}")
            return None
# ==========================================
# Example Usage (For your Jupyter Notebooks)
# ==========================================
if __name__ == "__main__":
    client = MinIOStorageClient()

    # 1. List the Zimbabwe Augmented CSV batches
    bucket = 'geocrop-datasets'
    batches = client.list_files(bucket)

    # 2. Load a batch directly into memory for training
    if batches:
        frame = client.load_dataset(bucket, batches[0]['key'])
        if frame is not None:
            print(frame.info())