# Uploaded by jree423 as api.py via huggingface_hub (commit a2fd1ce, verified).
import base64
import io
import logging
import os
import threading
from typing import Optional, Dict, Any, Union

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image

# Import the handler
from handler import EndpointHandler
# Process-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# FastAPI application object; the route decorators in this module register on it.
app = FastAPI(
    title="diffsketcher_edit API",
    description="API for diffsketcher_edit text-to-SVG generation",
)

# Model weights location, overridable through the MODEL_DIR environment variable.
model_dir = os.getenv("MODEL_DIR", "/code/model_weights")
# Constructed once at import time; every request reuses this single instance.
handler = EndpointHandler(model_dir)
logger.info("Initialized handler with model_dir: %s", model_dir)
class TextToImageRequest(BaseModel):
    # JSON body accepted by POST "/". `inputs` is either a plain prompt string or
    # an arbitrary string-keyed dict; generate_image hands the whole model to the
    # handler as a dict. NOTE(review): the dict's expected keys are defined by
    # EndpointHandler (handler.py), not here — confirm against that file.
    inputs: Union[str, Dict[str, Any]]
# Serializes calls into the shared model handler (used by _render_png).
_HANDLER_LOCK = threading.Lock()


@app.post("/")
async def generate_image(request: TextToImageRequest):
    # Generate an image from a text prompt and return it as base64-encoded PNG.
    # Response body: {"image": "<base64 PNG>"}. Any handler/encoding failure becomes
    # HTTP 500 with the exception text as detail; an HTTPException raised by the
    # handler keeps its own status code.
    # (Documented with comments rather than a docstring because FastAPI publishes
    # an endpoint's docstring as its OpenAPI description.)
    from fastapi.concurrency import run_in_threadpool

    try:
        logger.info("Received request: %s", request)
        # Pydantic v2 deprecates .dict() in favour of .model_dump(); support both.
        if hasattr(request, "model_dump"):
            payload = request.model_dump()
        else:
            payload = request.dict()
        # Inference and PNG encoding are synchronous; running them in the threadpool
        # keeps the event loop free to serve other requests (e.g. /health) meanwhile.
        png_bytes = await run_in_threadpool(_render_png, payload)
    except HTTPException:
        # Already carries the intended status code; don't mask it as a 500.
        raise
    except Exception as e:
        # logger.exception also records the traceback, which logger.error dropped.
        logger.exception("Error processing request: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    # Return the image as base64
    return {"image": base64.b64encode(png_bytes).decode('utf-8')}


def _render_png(payload: Dict[str, Any]) -> bytes:
    """Run the model handler on *payload* and return its result encoded as PNG bytes.

    Runs in a worker thread (see generate_image). Assumes ``handler`` returns an
    object with a PIL-style ``save(fp, format=...)`` method. NOTE(review): confirm
    this against handler.py, since the app description mentions SVG output.
    """
    # The original async endpoint blocked the event loop, so at most one handler
    # call ever ran at a time. Keep that guarantee now that calls run on pool
    # threads. NOTE(review): the lock can be dropped if EndpointHandler is thread-safe.
    with _HANDLER_LOCK:
        image = handler(payload)
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return buffer.getvalue()
@app.get("/health")
async def health_check():
    # Liveness probe: a constant payload that never touches the model handler.
    return dict(status="ok")
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run as a script.
    import uvicorn

    # Serve the app on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")