# VT45 / main.py
import os
import logging
import tempfile
from io import BytesIO

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from mtcnn import MTCNN

from vtoonify_model import Model  # Import VToonify model

app = FastAPI()

# Initialize the VToonify model and MTCNN face detector
model = None
detector = MTCNN()


def load_model():
    """Create the VToonify model on the available device and load the 'cartoon1' style."""
    global model
    model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
    model.load_model('cartoon1')
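
# Note: the model is created lazily on the first /upload/ request rather than at
# import time, so the app starts quickly and GPU memory is only allocated when needed.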
# Configure logging
logging.basicConfig(level=logging.INFO)


@app.post("/upload/")
async def process_image(file: UploadFile = File(...)):
    global model
    if model is None:
        load_model()

    # Read the uploaded image file
    contents = await file.read()

    # Convert the uploaded bytes to a numpy array and decode them (OpenCV reads BGR by default)
    nparr = np.frombuffer(contents, np.uint8)
    frame_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if frame_bgr is None:
        logging.error("Failed to decode the image.")
        return {"error": "Failed to decode the image. Please ensure the file is a valid image format."}

    logging.info(f"Uploaded image shape: {frame_bgr.shape}")

    # Convert BGR to RGB for MTCNN
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

    # Detect faces using MTCNN
    results = detector.detect_faces(frame_rgb)
    if len(results) == 0:
        logging.error("No faces detected in the image.")
        return {"error": "No faces detected in the image."}

    # Use the first detected face; MTCNN can return slightly negative box
    # coordinates, so clamp them before slicing
    x, y, width, height = results[0]['box']
    x, y = max(0, x), max(0, y)
    cropped_face = frame_rgb[y:y + height, x:x + width]

    # Save the cropped face to a temporary file so the model can be given a file path
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        cv2.imwrite(temp_file.name, cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR))
        temp_file_path = temp_file.name

    try:
        # Process the cropped face using VToonify
        aligned_face, instyle, message = model.detect_and_align_image(temp_file_path, 0, 0, 0, 0)
        if aligned_face is None or instyle is None:
            logging.error("Failed to process the image: No face detected or alignment failed.")
            return {"error": message}

        processed_image, message = model.image_toonify(
            aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1'
        )
        if processed_image is None:
            logging.error("Failed to toonify the image.")
            return {"error": message}

        # Convert the processed image to RGB before returning
        processed_image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)

        # Encode the processed image as JPEG bytes
        _, encoded_image = cv2.imencode('.jpg', processed_image_rgb)

        # Return the processed image as a streaming response
        return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
    finally:
        # Clean up the temporary file
        os.remove(temp_file_path)


# Define the index route before mounting static files at "/" so the explicit
# route is registered first and is not shadowed by the catch-all mount
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")


# Mount the static files directory (html=True also serves static/index.html as a fallback)
app.mount("/", StaticFiles(directory="static", html=True), name="static")
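
# A minimal local smoke test (a sketch, assuming this file runs as main:app and
# the Space exposes port 7860; adjust the host/port and image path as needed):
#
#   uvicorn main:app --host 0.0.0.0 --port 7860
#   curl -F "file=@face.jpg" http://localhost:7860/upload/ -o toonified.jpg
#
# "face.jpg" is a placeholder input; the response body is the toonified JPEG.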