File size: 1,744 Bytes
a2fd1ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import asyncio
import base64
import io
import logging
import os
import threading
from typing import Optional, Dict, Any, Union

from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel

# Import the handler
from handler import EndpointHandler
# Process-wide logging: INFO level, each record prefixed with time, logger
# name and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# The ASGI application; title/description feed the generated OpenAPI docs.
# NOTE(review): the POST endpoint in this file returns a PNG even though the
# description says text-to-SVG -- confirm the intended wording.
app = FastAPI(
    title="diffsketcher_edit API",
    description="API for diffsketcher_edit text-to-SVG generation",
)

# Build the model handler once, at import time. The weights directory comes
# from the MODEL_DIR environment variable, defaulting to /code/model_weights.
# (What EndpointHandler does on construction -- presumably loading weights --
# is not visible here.)
model_dir = os.environ.get("MODEL_DIR", "/code/model_weights")
handler = EndpointHandler(model_dir)
logger.info("Initialized handler with model_dir: %s", model_dir)
class TextToImageRequest(BaseModel):
    """Body of a generation request.

    ``inputs`` is either a string (presumably the text prompt) or a dict of
    arguments. Its keys are interpreted by ``EndpointHandler``, not validated
    here.
    """

    inputs: Union[str, Dict[str, Any]]
# Guards calls into the shared model handler. Inference runs in a worker
# thread (see generate_image), and the handler's thread-safety is not known
# from this file. This lock keeps at most one handler call in flight -- the
# same one-at-a-time behaviour the old blocking call on the event loop had.
_handler_lock = threading.Lock()


def _generate_base64_png(payload: Dict[str, Any]) -> str:
    """Run the model handler on *payload* and return its image as base64 PNG.

    Meant to be executed in a worker thread. The handler's return value is
    assumed to be a PIL-style image exposing ``save(fp, format=...)`` -- TODO
    confirm against ``EndpointHandler``.
    """
    with _handler_lock:
        image = handler(payload)
    # Encode outside the lock so the next request can start inference sooner.
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")


@app.post("/")
async def generate_image(request: TextToImageRequest):
    """Generate an image from a text prompt.

    The validated request body is passed to the model handler as a dict. The
    resulting image is returned as ``{"image": "<base64-encoded PNG>"}``.

    Raises:
        HTTPException: status 500, with the error text as ``detail``, if
            building the payload, running the handler, or PNG encoding fails.
    """
    try:
        logger.info("Received request: %s", request)
        # pydantic v2 renamed ``.dict()`` to ``.model_dump()`` and deprecated
        # the old name; prefer the new one and fall back for pydantic v1.
        to_dict = getattr(request, "model_dump", None) or request.dict
        payload = to_dict()
        # The handler call is synchronous and blocking. Awaiting it in a
        # worker thread keeps the event loop free, so /health and other
        # requests are still served while an image is being generated.
        loop = asyncio.get_running_loop()
        encoded = await loop.run_in_executor(None, _generate_base64_png, payload)
        return {"image": encoded}
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Error processing request: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Liveness probe.

    Always answers ``{"status": "ok"}``; it does not exercise the model
    handler.
    """
    return dict(status="ok")
if __name__ == "__main__":
    # uvicorn is only needed when this module is run directly as a script.
    import uvicorn

    # Bind on every interface, port 8000.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
|