"""FastAPI front-end that forwards ``/predict`` requests to an EndpointHandler.

The handler's result is returned as JSON, or as a Server-Sent Events stream
when ``handler.predict`` returns a generator.
"""

import asyncio
import base64
import json
import logging
import types

from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, validator

from endpoint_handler import EndpointHandler

logger = logging.getLogger("uvicorn.error")

app = FastAPI()

# Set on startup; left as None when construction fails so /predict answers
# 503 instead of raising.
handler = None

# Default value passed to next() so an exhausted generator is signalled by a
# sentinel rather than by StopIteration crossing the worker-thread boundary.
_STREAM_END = object()


@app.on_event("startup")
async def load_handler():
    """Construct the global EndpointHandler; leave it as None on failure."""
    global handler
    try:
        handler = EndpointHandler()
        logger.info("EndpointHandler initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize handler: {e}", exc_info=True)
        handler = None  # so /predict will return 503


@app.on_event("shutdown")
async def cleanup_handler():
    """Call ``handler.close()`` on shutdown, logging (not raising) any error."""
    if handler is not None:
        try:
            handler.close()
            logger.info("Handler cleaned up on shutdown.")
        except Exception:
            logger.error("Error during handler cleanup", exc_info=True)


class PredictInput(BaseModel):
    """Inner payload of a prediction request."""

    image: str  # base64 text; must pass strict decoding (see valid_base64)
    question: str  # must contain at least one non-whitespace character
    stream: bool = False  # forwarded unchanged to handler.predict

    # NOTE: ``validator`` is the Pydantic v1 API (superseded by
    # ``field_validator`` in v2); kept to match the installed version.
    @validator("question")
    def question_not_empty(cls, v):
        """Reject questions that are empty or whitespace-only."""
        if not v.strip():
            raise ValueError("Question must not be empty")
        return v

    @validator("image")
    def valid_base64(cls, v):
        """Reject strings that are not strict base64.

        ``validate=True`` makes any non-alphabet character (e.g. a
        ``data:...;base64,`` prefix or embedded newlines) an error.
        """
        try:
            base64.b64decode(v, validate=True)
        except Exception:
            raise ValueError("`image` must be valid base64") from None
        return v


class PredictRequest(BaseModel):
    """Request body of the form ``{"inputs": {...}}``."""

    inputs: PredictInput


def _close_generator(gen):
    """Best-effort ``gen.close()`` so the handler's generator can release resources."""
    try:
        gen.close()
    except Exception:
        # close() raises ValueError if a worker thread is still inside
        # next(gen) (possible after cancellation), or re-raises whatever the
        # generator's own cleanup raised. Cleanup must never crash the response.
        logger.warning("Could not close stream generator", exc_info=True)


async def _sse_events(gen):
    """Advance the sync generator *gen* in worker threads and yield SSE frames.

    A generator's body only runs when it is advanced, so calling ``next()``
    directly here would execute the handler's work on the event loop. Each
    step is therefore run via ``asyncio.to_thread``.
    """
    try:
        while True:
            chunk = await asyncio.to_thread(next, gen, _STREAM_END)
            if chunk is _STREAM_END:
                break
            yield f"data: {json.dumps(chunk)}\n\n"
    except asyncio.CancelledError:
        logger.info("Client disconnected from stream.")
        raise  # never swallow cancellation
    except ConnectionResetError:
        logger.info("Client disconnected from stream.")
    except Exception:
        logger.error("Error during streaming", exc_info=True)
        yield f"data: {json.dumps({'error': 'Stream error'})}\n\n"
    finally:
        _close_generator(gen)


@app.post("/predict")
async def predict_endpoint(payload: PredictRequest):
    """Run ``handler.predict`` on the validated request.

    Returns:
        * 503 if the handler failed to initialise;
        * 400 with the message if ``handler.predict`` raises ValueError;
        * 500 on any other exception from ``handler.predict``;
        * a ``text/event-stream`` response if it returns a generator (errors
          raised while the stream is consumed become an in-stream
          ``{"error": "Stream error"}`` event, not an HTTP status);
        * otherwise its result serialised as JSON with status 200.
    """
    if handler is None:
        return JSONResponse({"error": "Service unavailable"}, status_code=503)

    logger.info("Received question: %s", payload.inputs.question)

    # Prepare the data dict exactly how EndpointHandler expects
    request_dict = {
        "inputs": {
            "image": payload.inputs.image,
            "question": payload.inputs.question,
            "stream": payload.inputs.stream,
        }
    }

    try:
        # Run the synchronous predict call in a worker thread so it does not
        # block the event loop.
        result = await asyncio.to_thread(handler.predict, request_dict)
    except ValueError as ve:
        return JSONResponse({"error": str(ve)}, status_code=400)
    except Exception:
        logger.error("Unexpected error in handler.predict", exc_info=True)
        return JSONResponse({"error": "Internal server error"}, status_code=500)

    # If handler.predict returned a generator, wrap it in SSE
    if isinstance(result, types.GeneratorType):
        return StreamingResponse(_sse_events(result), media_type="text/event-stream")

    # Otherwise normal JSON
    return JSONResponse(content=result, status_code=200)