VT45 / main.py
Ashrafb's picture
Update main.py
c175957 verified
raw
history blame
1.81 kB
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
import torch
import cv2
import numpy as np
from io import BytesIO
# Shared VToonify model handle. It stays None until load_model() fills it in,
# which process_image() triggers on the first /upload/ request.
model = None

# The ASGI application that every route and mount in this module attaches to.
app = FastAPI()
def load_model():
    """Load the VToonify ``cartoon4`` model into the module-level ``model``.

    The model runs on ``'cuda'`` when ``torch.cuda.is_available()`` is true,
    otherwise on ``'cpu'``. The global is assigned only after
    ``Model.load_model`` returns. If loading raises, ``model`` is left
    unchanged, so a later call retries instead of reusing a half-initialised
    instance.
    """
    global model
    # Deferred import: the model module is only needed once loading is
    # requested. NOTE(review): presumably kept local to speed up app start-up.
    from vtoonify_model import Model

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    candidate = Model(device=device)
    # Finish loading before publishing the instance. A failure here must not
    # leave a broken object in ``model``, because callers only reload when
    # ``model is None``.
    candidate.load_model('cartoon4')
    model = candidate
@app.post("/upload/")
async def process_image(file: UploadFile = File(...)):
    """Toonify the face in an uploaded image and return the result as a JPEG.

    The model is loaded lazily via ``load_model()`` on the first request.

    Args:
        file: The uploaded image, in any format ``cv2.imdecode`` can read.

    Returns:
        On success, a ``StreamingResponse`` with ``image/jpeg`` content.
        Otherwise an ``{"error": ...}`` dict, returned when:
        - the upload is empty or cannot be decoded,
        - no face is detected, or
        - the result cannot be JPEG-encoded.
    """
    if model is None:
        load_model()

    contents = await file.read()
    # An empty buffer makes cv2.imdecode raise instead of returning None.
    if not contents:
        return {"error": "Uploaded file is empty. Please choose an image."}

    # cv2.imdecode returns None (it does not raise) for undecodable data.
    frame = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        return {"error": "Could not decode the uploaded file as an image."}

    # NOTE(review): cv2.imdecode yields BGR channel order, and the frame is
    # passed to the model unchanged. Confirm that detect_and_align_image
    # expects BGR.
    # NOTE(review): the model calls below are synchronous and block the event
    # loop while they run inside this async handler.
    aligned_face, instyle, _ = model.detect_and_align_image(frame, 0, 0, 0, 0)
    if instyle is None:
        return {"error": "No face detected. Please try a different image."}

    processed_image, _ = model.image_toonify(
        aligned_face,
        instyle,
        model.exstyle,
        style_degree=0.5,
        style_type='cartoon4',
    )

    # cv2.imencode expects BGR input. This swap assumes image_toonify returns
    # RGB. NOTE(review): confirm that assumption.
    processed_bgr = cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR)
    ok, encoded_image = cv2.imencode('.jpg', processed_bgr)
    if not ok:
        return {"error": "Failed to encode the processed image."}

    return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
# Front-end asset directory, relative to the working directory. It is shared
# by the index route and the static mount so that both resolve the same files.
STATIC_DIR = "AB"


@app.get("/")
def index():
    """Serve the front-end entry page ``index.html`` from ``STATIC_DIR``.

    This route must be registered before the catch-all static mount.
    Starlette matches routes in registration order, so a ``Mount("/")`` added
    first would shadow this route and make it unreachable.
    """
    return FileResponse(path=f"{STATIC_DIR}/index.html", media_type="text/html")


# Catch-all static mount for the front-end assets. It must come after every
# explicit route, otherwise it would intercept their paths.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")