VT45 / main.py
Ashrafb's picture
Update main.py
fb700c4 verified
raw
history blame
2.43 kB
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
import torch
import cv2
import numpy as np
import logging
from io import BytesIO
# Holds the toonify model once loaded; it starts empty and is filled in by
# load_model() the first time the upload endpoint finds it unset.
model = None

# ASGI application object that the upload route, index route and static
# mount in this file are attached to.
app = FastAPI()
def load_model():
    """Construct the VToonify model and publish it in the module-level ``model``.

    The device is ``'cuda'`` when ``torch.cuda.is_available()`` reports a GPU,
    otherwise ``'cpu'``. The ``'cartoon4'`` style is then loaded via
    ``Model.load_model``.

    The global is assigned only after ``Model.load_model`` returns. If
    construction or weight loading raises, ``model`` is left unchanged, so the
    next caller that finds it unset retries the load instead of using a
    half-initialised model.
    """
    global model
    # Imported inside the function so vtoonify_model is only imported when a
    # model is actually needed.
    from vtoonify_model import Model

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    candidate = Model(device=device)
    candidate.load_model('cartoon4')
    # Publish only a fully loaded model.
    model = candidate
# Root-logger setup: INFO and above are emitted (the upload handler reports
# the decoded image shape at INFO level), using the standard basic format.
logging.basicConfig(format=logging.BASIC_FORMAT, level=logging.INFO)
@app.post("/upload/")
async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
    """Toonify an uploaded face photo and stream the result back as a JPEG.

    The model is loaded lazily via ``load_model()`` on the first request.

    Args:
        file: Uploaded image bytes, decoded with ``cv2.imdecode`` as a
            3-channel colour image.
        top, bottom, left, right: Forwarded unchanged to
            ``model.detect_and_align_image`` (their exact meaning is defined
            in ``vtoonify_model``, which is not visible here).

    Returns:
        A ``StreamingResponse`` with ``image/jpeg`` content on success.
        Otherwise a ``{"error": <message>}`` dict, which is served with the
        route's default HTTP 200 status, when decoding, face alignment,
        toonification or JPEG encoding fails.
    """
    if model is None:
        load_model()

    contents = await file.read()
    # cv2.imdecode asserts on an empty buffer (raising instead of returning
    # None), so an empty upload goes through the same "could not decode" path.
    frame_bgr = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR) if contents else None
    if frame_bgr is None:
        logging.error("Failed to decode the image.")
        return {"error": "Failed to decode the image. Please ensure the file is a valid image format."}
    logging.info("Uploaded image shape: %s", frame_bgr.shape)

    # NOTE(review): the model calls below are synchronous inside an async
    # handler, so they block the event loop while they run.
    aligned_face, instyle, message = model.detect_and_align_image(frame_bgr, top, bottom, left, right)
    if aligned_face is None or instyle is None:
        logging.error("Failed to process the image: No face detected or alignment failed.")
        return {"error": message}

    # style_type must name the style that load_model() loaded ('cartoon4').
    processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon4')
    if processed_image is None:
        logging.error("Failed to toonify the image.")
        return {"error": message}

    # cv2.imencode expects BGR channel order. The toonify output is presumably
    # RGB (as in upstream VToonify) -- TODO confirm; this swaps channels 0 and 2
    # (COLOR_RGB2BGR is the same conversion code as COLOR_BGR2RGB).
    processed_image_bgr = cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR)
    ok, encoded_image = cv2.imencode('.jpg', processed_image_bgr)
    if not ok:
        logging.error("Failed to encode the processed image as JPEG.")
        return {"error": "Failed to encode the processed image."}

    return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
# Explicit handler for "/". It must be registered BEFORE the static mount
# below: routes are matched in registration order, and a Mount at "/" matches
# every path, so a route added after it is never reached.
@app.get("/")
def index():
    """Return the front-end entry page from the same ``AB`` directory that is mounted below."""
    from fastapi.responses import FileResponse
    # Relative to the working directory, exactly like StaticFiles(directory="AB").
    return FileResponse(path="AB/index.html", media_type="text/html")


# Serve the front-end assets from ./AB for every path not claimed by a route
# registered above; html=True also maps directory requests to index.html.
app.mount("/", StaticFiles(directory="AB", html=True), name="static")