"""FastAPI wrapper around ``EndpointHandler`` for the diffsketcher_edit model.

Routes:
    POST /        Pass the request payload to the handler and return the
                  result as a base64-encoded PNG.
    GET  /health  Liveness probe.

The handler is constructed once at import time from the directory named by
the ``MODEL_DIR`` environment variable (default ``/code/model_weights``).
"""
import base64
import io
import logging
import os
from typing import Any, Dict, Optional, Union

from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel

# Import the handler
from handler import EndpointHandler

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Initialize the FastAPI app
app = FastAPI(
    title="diffsketcher_edit API",
    description="API for diffsketcher_edit text-to-SVG generation",
)

# Initialize the handler once at import time so every request reuses it.
model_dir = os.environ.get("MODEL_DIR", "/code/model_weights")
handler = EndpointHandler(model_dir)
logger.info(f"Initialized handler with model_dir: {model_dir}")


class TextToImageRequest(BaseModel):
    """Request body for ``POST /``."""

    # Either a plain string or a dict of arbitrary keys; the accepted schema
    # is defined by EndpointHandler, which is not visible from this module.
    inputs: Union[str, Dict[str, Any]]


@app.post("/")
async def generate_image(request: TextToImageRequest):
    """Generate an image from a text prompt.

    The request is converted to a dict and passed to the module-level
    ``handler``. The returned object is saved as PNG into an in-memory
    buffer and sent back base64-encoded.

    Args:
        request: Validated request body.

    Returns:
        ``{"image": <base64-encoded PNG as a UTF-8 string>}``.

    Raises:
        HTTPException: Status 500 with the original error text as ``detail``
            when the handler call or the PNG encoding raises.
    """
    try:
        logger.info("Received request: %s", request)

        # NOTE(review): this is a synchronous call inside an ``async def``
        # endpoint, so the event loop is occupied until it returns. If the
        # handler is slow, /health will not answer in the meantime - confirm
        # whether that is acceptable.
        # NOTE(review): ``.dict()`` is the Pydantic v1 spelling (deprecated in
        # v2 in favour of ``model_dump()``); kept so either version works.
        image = handler(request.dict())

        # Presumably a PIL image: all that is required here is a
        # ``save(fp, format=...)`` method - verify against EndpointHandler.
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        png_bytes = buffer.getvalue()

        # Return the image as base64
        return {"image": base64.b64encode(png_bytes).decode('utf-8')}
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback,
        # which a plain logger.error call would drop.
        logger.exception("Error processing request: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
async def health_check():
    """Health check endpoint; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}


if __name__ == "__main__":
    # Imported lazily so uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)