jree423 commited on
Commit
a2fd1ce
·
verified ·
1 Parent(s): 2f06631

Upload api.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. api.py +56 -0
api.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import logging
4
+ from typing import Optional, Dict, Any, Union
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel
7
+ from PIL import Image
8
+ import base64
9
+
10
+ # Import the handler
11
+ from handler import EndpointHandler
12
+
13
+ # Configure logging
14
+ logging.basicConfig(level=logging.INFO,
15
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Initialize the FastAPI app
19
+ app = FastAPI(title="diffsketcher_edit API", description="API for diffsketcher_edit text-to-SVG generation")
20
+
21
+ # Initialize the handler
22
+ model_dir = os.environ.get("MODEL_DIR", "/code/model_weights")
23
+ handler = EndpointHandler(model_dir)
24
+ logger.info(f"Initialized handler with model_dir: {model_dir}")
25
+
26
+ class TextToImageRequest(BaseModel):
27
+ inputs: Union[str, Dict[str, Any]]
28
+
29
+ @app.post("/")
30
+ async def generate_image(request: TextToImageRequest):
31
+ # Generate an image from a text prompt
32
+ try:
33
+ logger.info(f"Received request: {request}")
34
+
35
+ # Process the request using the handler
36
+ image = handler(request.dict())
37
+
38
+ # Convert the image to bytes
39
+ img_byte_arr = io.BytesIO()
40
+ image.save(img_byte_arr, format='PNG')
41
+ img_byte_arr = img_byte_arr.getvalue()
42
+
43
+ # Return the image as base64
44
+ return {"image": base64.b64encode(img_byte_arr).decode('utf-8')}
45
+ except Exception as e:
46
+ logger.error(f"Error processing request: {e}")
47
+ raise HTTPException(status_code=500, detail=str(e))
48
+
49
+ @app.get("/health")
50
+ async def health_check():
51
+ # Health check endpoint
52
+ return {"status": "ok"}
53
+
54
+ if __name__ == "__main__":
55
+ import uvicorn
56
+ uvicorn.run(app, host="0.0.0.0", port=8000)