"""FastAPI front end that forwards image/question requests to EndpointHandler.

Exposes a health-check route at ``/`` and a ``/predict`` route that accepts a
base64-encoded image plus a question, hands them to ``EndpointHandler.predict``
and returns either a single JSON document or a ``text/event-stream`` response
when the handler returns a generator.
"""
import base64
import json
import logging
import types

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from pydantic import validator

from endpoint_handler import EndpointHandler  # your handler file

logger = logging.getLogger(__name__)

# Upper bound on the decoded size of the uploaded image (10 MB).
MAX_IMAGE_BYTES = 10 * 1024 * 1024

app = FastAPI()

# Populated by ``load_handler`` at application startup; ``None`` until then.
handler = None


@app.on_event("startup")
async def load_handler():
    """Instantiate the shared ``EndpointHandler`` once, when the app starts."""
    global handler
    handler = EndpointHandler()


class PredictInput(BaseModel):
    """Payload of a single prediction request."""

    image: str  # base64-encoded image string
    question: str
    stream: bool = False

    @validator("question")
    def question_not_empty(cls, v):
        """Reject questions that are empty or whitespace-only."""
        if not v.strip():
            raise ValueError("Question must not be empty")
        return v

    @validator("image")
    def valid_base64_and_size(cls, v):
        """Require strict base64 that decodes to at most ``MAX_IMAGE_BYTES``.

        The original (still encoded) string is returned unchanged; decoding is
        done here only to validate it.
        """
        try:
            decoded = base64.b64decode(v, validate=True)
        except ValueError as err:
            # binascii.Error (bad alphabet/padding) is a ValueError subclass,
            # and non-ASCII input is reported as a plain ValueError.
            raise ValueError("`image` must be valid base64") from err
        if len(decoded) > MAX_IMAGE_BYTES:
            raise ValueError("Image exceeds 10 MB after decoding")
        return v


class PredictRequest(BaseModel):
    """Request envelope: the prediction payload lives under ``inputs``."""

    inputs: PredictInput


@app.get("/")
async def root():
    """Health-check route returning a static message."""
    return {"message": "FastAPI app is running on Hugging Face"}


def _sse_event(payload):
    """Frame ``payload`` as one ``data: ...`` server-sent event.

    A string that already holds a JSON object or array is decoded first so it
    is not JSON-encoded a second time; any other string is sent as a JSON
    string, and non-string payloads are serialised directly.
    """
    if isinstance(payload, str):
        try:
            decoded = json.loads(payload)
        except ValueError:
            decoded = None
        if isinstance(decoded, (dict, list)):
            payload = decoded
    return f"data: {json.dumps(payload)}\n\n"


def _event_stream(chunks):
    """Yield SSE frames for every item of ``chunks``, then an end marker.

    If iterating or serialising a chunk fails, the failure is logged, a single
    ``{"error": ...}`` event is emitted and the stream stops without the end
    marker.
    """
    try:
        for chunk in chunks:
            yield _sse_event(chunk)
        # Finally send an end-of-stream marker
        yield f"data: {json.dumps({'end': True})}\n\n"
    except Exception as e:
        logger.exception("Streaming prediction failed")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"


@app.post("/predict")
async def predict_endpoint(payload: PredictRequest):
    """Run a prediction and return it as JSON or as an event stream.

    Args:
        payload: Validated request body.

    Returns:
        * ``StreamingResponse`` (``text/event-stream``) when the handler
          returns a generator.
        * ``JSONResponse`` with the handler result otherwise; a string result
          is parsed as JSON first.
        * ``JSONResponse`` with an ``error`` key and status 400 when the
          handler raises ``ValueError``, or status 500 for any other handler
          failure or an unusable handler result.
    """
    print(f"[Request] Received question: {payload.inputs.question}")

    data = {
        "inputs": {
            "image": payload.inputs.image,
            "question": payload.inputs.question,
            "stream": payload.inputs.stream
        }
    }

    # NOTE(review): handler.predict is a synchronous call inside an async
    # route, so it runs on the event loop. Consider a threadpool if
    # predictions are slow and the handler is known to be thread-safe.
    try:
        result = handler.predict(data)
    except ValueError as ve:
        return JSONResponse({"error": str(ve)}, status_code=400)
    except Exception:
        # Keep the client-facing message generic but record the traceback.
        logger.exception("handler.predict failed")
        return JSONResponse({"error": "Internal server error"}, status_code=500)

    # If it's a generator, return SSE/streaming
    if isinstance(result, types.GeneratorType):
        return StreamingResponse(
            _event_stream(result), media_type="text/event-stream"
        )

    # Otherwise (non-streaming), return a single JSON response.
    # A str result is expected to be JSON-encoded; anything else is passed
    # to JSONResponse as-is.
    if isinstance(result, str):
        try:
            result = json.loads(result)
        except (ValueError, RecursionError):
            logger.exception("Handler returned a string that is not valid JSON")
            return JSONResponse(
                {"error": "Invalid JSON from handler"}, status_code=500
            )

    try:
        # JSONResponse serialises its content on construction, so an
        # unserialisable result surfaces here.
        return JSONResponse(result)
    except (TypeError, ValueError):
        logger.exception("Handler result is not JSON-serialisable")
        return JSONResponse(
            {"error": "Invalid JSON from handler"}, status_code=500
        )