# GeoCrop API — FastAPI auth + RQ inference-job service
import json
import os
from datetime import datetime, timedelta, timezone
from typing import List, Optional

import jwt
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr
from redis import Redis
from rq import Queue
from rq.job import Job
# --- Configuration ---
# SECURITY NOTE(review): the fallback SECRET_KEY below is a published
# placeholder — tokens are forgeable unless SECRET_KEY is set in the
# environment of every deployment.
SECRET_KEY = os.getenv("SECRET_KEY", "your-super-secret-portfolio-key-change-this")
ALGORITHM = "HS256"
# Token lifetime: 1440 minutes = 24 hours.
ACCESS_TOKEN_EXPIRE_MINUTES = 1440

# Redis Connection
# Defaults to the in-cluster service DNS name; override REDIS_HOST for local dev.
REDIS_HOST = os.getenv("REDIS_HOST", "redis.geocrop.svc.cluster.local")
redis_conn = Redis(host=REDIS_HOST, port=6379)
# Work queue consumed by the RQ worker ('worker.run_inference' — see /jobs).
task_queue = Queue('geocrop_tasks', connection=redis_conn)
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI(title="GeoCrop API", version="1.1")

# Add CORS middleware
# Browser clients: the deployed portfolio frontend plus the local Vite dev
# server. allow_credentials=True requires origins to stay an explicit list
# (a "*" wildcard would be rejected by browsers when credentials are sent).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://portfolio.techarvest.co.zw", "http://localhost:5173"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Password hashing (bcrypt) and OAuth2 bearer-token extraction.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# tokenUrl must match the /auth/login route defined below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")

# In-memory DB
# NOTE(review): state lives in process memory only — users and login counts
# reset on every restart and are not shared across replicas.
USERS = {
    "fchinembiri24@gmail.com": {
        "email": "fchinembiri24@gmail.com",
        # bcrypt hash; only the digest is stored for this seed admin account.
        "hashed_password": "$2b$12$iyR6fFeQAd2CfCDm/CdTSeB8CIjJhAHjA6Et7/UMWm0i0nIAFu21W",
        "is_active": True,
        "is_admin": True,
        "login_count": 0,
        # Effectively unlimited logins for the seed admin.
        "login_limit": 9999
    }
}
class UserCreate(BaseModel):
    """Request payload for the admin user-creation endpoint (POST /admin/users)."""

    email: EmailStr
    password: str
    # Number of successful logins allowed before the account is locked out.
    login_limit: int = 3
class UserResponse(BaseModel):
    """Public view of a user record — deliberately omits the password hash."""

    email: EmailStr
    is_active: bool
    is_admin: bool
    login_count: int
    login_limit: int
class Token(BaseModel):
    """Response body for /auth/login: bearer token plus the caller's admin flag."""

    access_token: str
    token_type: str
    is_admin: bool
class InferenceJobRequest(BaseModel):
    """Request payload for POST /jobs — the ROI and model for one inference run."""

    lat: float
    lon: float
    # ROI radius in kilometres; the endpoint rejects values above 5.0.
    radius_km: float
    year: str
    model_name: str
def create_access_token(data: dict, expires_delta: timedelta):
    """Return a signed JWT carrying *data* plus an ``exp`` expiry claim.

    Args:
        data: Claims to embed (callers pass ``{"sub": email}``).
        expires_delta: Lifetime added to the current UTC time.

    Returns:
        The encoded JWT string, signed with SECRET_KEY using ALGORITHM.
    """
    to_encode = data.copy()
    # datetime.utcnow() is deprecated and returns a naive datetime; an aware
    # UTC timestamp encodes the same "exp" epoch without ambiguity.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Dependency: resolve the bearer token to a user record.

    Decodes and verifies the JWT, then looks the ``sub`` claim up in the
    in-memory USERS store. Raises 401 on any decode failure, a missing
    ``sub`` claim, or an unknown user.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        email: str = payload.get("sub")
        if email is None or email not in USERS:
            # HTTPException is not a PyJWTError, so this propagates past the
            # except clause below rather than being re-wrapped.
            raise HTTPException(status_code=401, detail="Invalid credentials")
        return USERS[email]
    except jwt.PyJWTError:
        # Covers expired signatures, bad signatures, and malformed tokens.
        raise HTTPException(status_code=401, detail="Invalid credentials")
async def get_admin_user(current_user: dict = Depends(get_current_user)):
    """Dependency: pass the authenticated user through only if they are an admin."""
    if current_user.get("is_admin"):
        return current_user
    raise HTTPException(status_code=403, detail="Admin privileges required")
@app.post("/auth/login", response_model=Token, tags=["Authentication"])
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Authenticate a user and issue a 24-hour bearer token.

    Verifies the password against the stored bcrypt hash, enforces the
    per-user login limit, and increments the login counter on success.
    Raises 401 on bad credentials and 403 when the limit is reached.
    """
    username = form_data.username.strip()
    password = form_data.password.strip()

    def _issue_token(user: dict) -> dict:
        # Shared issuance path for both branches below: count the login and
        # build the Token response body.
        user["login_count"] += 1
        access_token = create_access_token(
            data={"sub": user["email"]},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        )
        return {
            "access_token": access_token,
            "token_type": "bearer",
            "is_admin": user.get("is_admin", False),
        }

    # SECURITY NOTE(review): hard-coded plaintext admin bypass — the
    # credential lives in source control and skips bcrypt verification and
    # the login limit entirely. Kept to preserve existing behavior, but it
    # should be removed once the stored hash is confirmed to match.
    if username == "fchinembiri24@gmail.com" and password == "P@55w0rd.123":
        return _issue_token(USERS["fchinembiri24@gmail.com"])

    user = USERS.get(username)
    if not user or not pwd_context.verify(password, user["hashed_password"]):
        raise HTTPException(status_code=401, detail="Incorrect email or password")

    if user["login_count"] >= user.get("login_limit", 3):
        raise HTTPException(status_code=403, detail="Login limit reached.")

    return _issue_token(user)
@app.get("/admin/users", response_model=List[UserResponse], tags=["Admin"])
async def list_users(admin: dict = Depends(get_admin_user)):
    """Admin-only: list every user record, minus the password hash."""

    def _public_view(record: dict) -> dict:
        # Project a stored record onto the UserResponse shape; defaults
        # mirror the store's own fallbacks.
        return {
            "email": record["email"],
            "is_active": record["is_active"],
            "is_admin": record.get("is_admin", False),
            "login_count": record.get("login_count", 0),
            "login_limit": record.get("login_limit", 3),
        }

    return [_public_view(record) for record in USERS.values()]
@app.post("/admin/users", response_model=UserResponse, tags=["Admin"])
async def create_user(user_in: UserCreate, admin: dict = Depends(get_admin_user)):
    """Admin-only: register a new non-admin user.

    Hashes the password with bcrypt and stores the record in the in-memory
    USERS map. Raises 400 if the email is already registered.
    """
    if user_in.email in USERS:
        raise HTTPException(status_code=400, detail="User already exists")

    # Build the record once and store it — the original duplicated this dict
    # in the response literal, which could silently drift from the store.
    record = {
        "email": user_in.email,
        "hashed_password": pwd_context.hash(user_in.password),
        "is_active": True,
        "is_admin": False,
        "login_count": 0,
        "login_limit": user_in.login_limit,
    }
    USERS[user_in.email] = record

    # Return a hash-free view so the digest never enters the response path.
    return {k: v for k, v in record.items() if k != "hashed_password"}
@app.post("/jobs", tags=["Inference"])
async def create_inference_job(job_req: InferenceJobRequest, current_user: dict = Depends(get_current_user)):
    """Enqueue an inference run on the RQ queue; rejects ROIs wider than 5 km."""
    if job_req.radius_km > 5.0:
        raise HTTPException(status_code=400, detail="Radius exceeds 5km limit.")

    # The worker imports 'worker.run_inference' by dotted path and receives
    # the validated request as a plain dict.
    queued = task_queue.enqueue(
        'worker.run_inference',
        job_req.model_dump(),
        job_timeout='25m',
    )
    return {"job_id": queued.id, "status": "queued"}
def _fetch_detailed_status(job_id: str):
    """Best-effort read of the worker-maintained progress blob for *job_id*.

    The worker writes progress JSON to the ``job:<id>:status`` Redis key.
    Returns the parsed dict, or None when the key is absent or unreadable.
    """
    try:
        status_bytes = redis_conn.get(f"job:{job_id}:status")
        if status_bytes:
            return json.loads(status_bytes.decode('utf-8'))
    except Exception as e:
        # Best-effort: a malformed blob must not break the status endpoint.
        print(f"Error fetching detailed status: {e}")
    return None


def _extract_roi(job):
    """Recover the requested ROI (lat/lon/radius_m) from the job's first arg."""
    if not job.args or len(job.args) == 0:
        return None
    args = job.args[0]
    if not isinstance(args, dict):
        return None
    return {
        "lat": args.get("lat"),
        "lon": args.get("lon"),
        # Newer payloads carry radius_km (converted to metres here); older
        # ones may already carry radius_m — TODO confirm against worker.
        "radius_m": int(float(args.get("radius_km", 0)) * 1000) if "radius_km" in args else args.get("radius_m"),
    }


@app.get("/jobs/{job_id}", tags=["Inference"])
async def get_job_status(job_id: str, current_user: dict = Depends(get_current_user)):
    """Report an RQ job's state, merged with worker-side detailed progress.

    Returns a dict keyed by lifecycle: "finished" (with result), "failed"
    (with error), or the raw RQ status plus any worker progress fields.
    Raises 404 when the job id is unknown.
    """
    try:
        job = Job.fetch(job_id, connection=redis_conn)
    except Exception:
        # rq raises NoSuchJobError for unknown ids; the broad catch also maps
        # redis/deserialization failures to 404 rather than a 500.
        raise HTTPException(status_code=404, detail="Job not found")

    detailed_status = _fetch_detailed_status(job_id)
    roi = _extract_roi(job)

    if job.is_finished:
        result = job.result
        # Prefer the worker's explicit outputs over the raw RQ return value.
        if detailed_status and "outputs" in detailed_status:
            result = detailed_status["outputs"]
        return {
            "job_id": job.id,
            "status": "finished",
            "result": result,
            "detailed": detailed_status,
            "roi": roi,
        }

    if job.is_failed:
        return {
            "job_id": job.id,
            "status": "failed",
            "error": detailed_status.get("error") if detailed_status else None,
            "roi": roi,
        }

    # Still queued / started / deferred: report the RQ status and overlay the
    # worker's own stage/progress fields when available.
    response = {
        "job_id": job.id,
        "status": job.get_status(),
        "roi": roi,
    }
    if detailed_status:
        response.update({
            "worker_status": detailed_status.get("status"),
            "stage": detailed_status.get("stage"),
            "progress": detailed_status.get("progress"),
            "message": detailed_status.get("message"),
        })
    return response