File size: 2,432 Bytes
f647f39 8f49770 f647f39 fb700c4 f647f39 00227e3 f647f39 8b9baab f647f39 8f49770 f647f39 fb700c4 f647f39 fb700c4 8b9baab af4a7d4 8b9baab f647f39 fb700c4 f647f39 fb700c4 f647f39 fb700c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
import torch
import cv2
import numpy as np
import logging
from io import BytesIO
# ASGI application object; the upload route, the index route and the static
# file mount defined later in this module are all registered on it.
app = FastAPI()
# Load model and necessary components.
# The model is created lazily (on first use) and cached in this global.
model = None


def load_model():
    """Create the toonify model, load the 'cartoon4' style and cache it.

    The model runs on CUDA when ``torch.cuda.is_available()`` is true and on
    the CPU otherwise.  The module-level ``model`` global is only assigned
    after the style has loaded successfully, so an exception raised while
    constructing the model or loading its weights leaves ``model`` untouched
    and a later call can retry instead of reusing a half-initialised object.

    Raises:
        Whatever ``vtoonify_model.Model`` or its ``load_model`` method raise;
        nothing is caught here.
    """
    global model
    # Imported here rather than at module top, as in the original code, so the
    # project-local module is only needed once a model is actually requested.
    from vtoonify_model import Model

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    candidate = Model(device=device)
    candidate.load_model('cartoon4')
    # Publish only a fully loaded model.
    model = candidate
# Configure logging: request INFO and above on the root logger, which is the
# logger the module-level logging.info / logging.error calls in this file use.
logging.basicConfig(level=logging.INFO)
@app.post("/upload/")
async def process_image(
    file: UploadFile = File(...),
    top: int = Form(...),
    bottom: int = Form(...),
    left: int = Form(...),
    right: int = Form(...),
):
    """Toonify an uploaded image and return the result as a JPEG.

    Args:
        file: The uploaded image file (multipart form field).
        top, bottom, left, right: Integer form fields passed through unchanged
            to ``model.detect_and_align_image``.

    Returns:
        A ``StreamingResponse`` with media type ``image/jpeg`` on success, or
        a ``{"error": <message>}`` dict when the upload is empty or cannot be
        decoded, when alignment or toonification returns ``None``, or when
        JPEG encoding fails.
    """
    # Lazily create the model on the first request.
    if model is None:
        load_model()

    # Read the uploaded image file
    contents = await file.read()
    if not contents:
        # OpenCV rejects an empty buffer with an exception rather than
        # returning None, so report it through the normal decode-error path.
        logging.error("Failed to decode the image.")
        return {"error": "Failed to decode the image. Please ensure the file is a valid image format."}

    # Convert the uploaded bytes to a numpy array and decode to a BGR image
    nparr = np.frombuffer(contents, np.uint8)
    frame_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if frame_bgr is None:
        logging.error("Failed to decode the image.")
        return {"error": "Failed to decode the image. Please ensure the file is a valid image format."}
    logging.info("Uploaded image shape: %s", frame_bgr.shape)

    # Detect and align the face; `message` carries the model's own status text
    aligned_face, instyle, message = model.detect_and_align_image(frame_bgr, top, bottom, left, right)
    if aligned_face is None or instyle is None:
        logging.error("Failed to process the image: No face detected or alignment failed.")
        return {"error": message}

    # NOTE(review): load_model() loads the 'cartoon4' style while this call
    # asks for style_type='cartoon1' -- confirm the mismatch is intended.
    processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1')
    if processed_image is None:
        logging.error("Failed to toonify the image.")
        return {"error": message}

    # Swap the R and B channels before encoding.
    # NOTE(review): cv2.imencode interprets its input as BGR, so this swap is
    # only correct if image_toonify returns RGB -- confirm against the model.
    processed_image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)

    # Convert processed image to JPEG bytes; imencode reports failure via its
    # first return value rather than by raising.
    ok, encoded_image = cv2.imencode('.jpg', processed_image_rgb)
    if not ok:
        logging.error("Failed to encode the processed image.")
        return {"error": "Failed to encode the processed image."}

    # Return the processed image as a streaming response
    return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
# Define index route.
# It is registered before the static mount below because routes are matched in
# registration order and a mount at "/" matches every path; registering the
# mount first would leave this handler unreachable.
@app.get("/")
def index():
    """Serve the front-end entry page as ``text/html``."""
    from fastapi.responses import FileResponse
    return FileResponse(path="/app/AB/index.html", media_type="text/html")


# Mount static files directory last, so it only handles paths that no
# explicitly declared route has claimed.
app.mount("/", StaticFiles(directory="AB", html=True), name="static")