import os
import logging
import tempfile
from io import BytesIO

import cv2
import numpy as np
import torch  # needed for torch.cuda.is_available() in load_model()
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from mtcnn import MTCNN
from vtoonify_model import Model  # Import VToonify model

app = FastAPI()

# The MTCNN face detector is ready at startup; the VToonify model is loaded
# lazily on the first request (see load_model below)
model = None
detector = MTCNN()

def load_model():
    global model
    model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
    model.load_model('cartoon1')  # load the 'cartoon1' style checkpoint
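
# Optionally warm the model at startup instead of waiting for the first
# request; a minimal sketch using FastAPI's startup hook (the lazy check in
# the upload handler below still works as a fallback).
@app.on_event("startup")
def warm_model() -> None:
    load_model()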

# Configure logging
logging.basicConfig(level=logging.INFO)

@app.post("/upload/")
async def process_image(file: UploadFile = File(...)):
    global model
    if model is None:
        load_model()

    # Read the uploaded image file
    contents = await file.read()

    # Convert the uploaded image to a numpy array
    nparr = np.frombuffer(contents, np.uint8)
    frame_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)  # Read as BGR format by default

    if frame_bgr is None:
        logging.error("Failed to decode the image.")
        return {"error": "Failed to decode the image. Please ensure the file is a valid image format."}

    logging.info(f"Uploaded image shape: {frame_bgr.shape}")

    # Convert BGR to RGB for MTCNN
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

    # Detect faces using MTCNN
    results = detector.detect_faces(frame_rgb)

    if len(results) == 0:
        logging.error("No faces detected in the image.")
        return {"error": "No faces detected in the image."}

    # Use the first detected face; MTCNN can return negative box coordinates
    # near image borders, so clamp them before slicing
    x, y, width, height = results[0]['box']
    x, y = max(0, x), max(0, y)
    cropped_face = frame_rgb[y:y+height, x:x+width]

    # Save the cropped face temporarily to pass the file path to the model
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        cv2.imwrite(temp_file.name, cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR))
        temp_file_path = temp_file.name

    try:
        # Process the cropped face using VToonify
        aligned_face, instyle, message = model.detect_and_align_image(temp_file_path, 0, 0, 0, 0)
        if aligned_face is None or instyle is None:
            logging.error("Failed to process the image: No face detected or alignment failed.")
            return {"error": message}

        processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1')
        if processed_image is None:
            logging.error("Failed to toonify the image.")
            return {"error": message}

        # Swap the R and B channels so cv2.imencode, which assumes BGR input,
        # writes the JPEG with the correct colors (COLOR_BGR2RGB and
        # COLOR_RGB2BGR perform the same swap)
        processed_image = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)

        # Encode the processed image as JPEG bytes, checking for failure
        success, encoded_image = cv2.imencode('.jpg', processed_image)
        if not success:
            logging.error("Failed to encode the processed image.")
            return {"error": "Failed to encode the processed image."}

        # Return the processed image as a streaming response
        return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")

    finally:
        # Clean up the temporary file
        os.remove(temp_file_path)

# Define the index route before the catch-all static mount so it stays reachable
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="static/index.html", media_type="text/html")

# Mount the static files directory last; a mount at "/" registered earlier
# would shadow every route defined after it
app.mount("/", StaticFiles(directory="static", html=True), name="static")
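
# A minimal local entry point, assuming the standard uvicorn server; the host
# and port are illustrative defaults, not values taken from the original
# deployment. Example request once running:
#   curl -X POST -F "file=@face.jpg" http://localhost:8000/upload/ -o toonified.jpg
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)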