Ashrafb commited on
Commit
d797d75
·
verified ·
1 Parent(s): df0eb28

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -17
main.py CHANGED
@@ -1,22 +1,23 @@
1
- from fastapi import FastAPI, File, UploadFile, Form
 
 
2
  from fastapi.responses import StreamingResponse, FileResponse
3
  from fastapi.staticfiles import StaticFiles
4
- import torch
5
- import cv2
6
  import numpy as np
7
  import logging
8
  from io import BytesIO
9
  import tempfile
10
- import os
 
11
 
12
  app = FastAPI()
13
 
14
- # Load model and necessary components
15
  model = None
 
16
 
17
  def load_model():
18
  global model
19
- from vtoonify_model import Model
20
  model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
21
  model.load_model('cartoon4')
22
 
@@ -24,7 +25,7 @@ def load_model():
24
  logging.basicConfig(level=logging.INFO)
25
 
26
  @app.post("/upload/")
27
- async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
28
  global model
29
  if model is None:
30
  load_model()
@@ -32,24 +33,38 @@ async def process_image(file: UploadFile = File(...), top: int = Form(...), bott
32
  # Read the uploaded image file
33
  contents = await file.read()
34
 
35
- # Convert the uploaded image to numpy array
36
  nparr = np.frombuffer(contents, np.uint8)
37
  frame_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # Read as BGR format by default
38
-
39
  if frame_bgr is None:
40
  logging.error("Failed to decode the image.")
41
  return {"error": "Failed to decode the image. Please ensure the file is a valid image format."}
42
 
43
  logging.info(f"Uploaded image shape: {frame_bgr.shape}")
44
 
45
- # Save the uploaded image temporarily to pass the file path to the model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
47
- cv2.imwrite(temp_file.name, frame_bgr)
48
  temp_file_path = temp_file.name
49
 
50
  try:
51
- # Process the uploaded image using the file path
52
- aligned_face, instyle, message = model.detect_and_align_image(temp_file_path, top, bottom, left, right)
53
  if aligned_face is None or instyle is None:
54
  logging.error("Failed to process the image: No face detected or alignment failed.")
55
  return {"error": message}
@@ -67,15 +82,15 @@ async def process_image(file: UploadFile = File(...), top: int = Form(...), bott
67
 
68
  # Return the processed image as a streaming response
69
  return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
70
-
71
  finally:
72
  # Clean up the temporary file
73
  os.remove(temp_file_path)
74
 
75
  # Mount static files directory
76
- app.mount("/", StaticFiles(directory="AB", html=True), name="static")
77
 
78
  # Define index route
79
  @app.get("/")
80
- def index():
81
- return FileResponse(path="/app/AB/index.html", media_type="text/html")
 
1
+ import os
2
+ import cv2
3
+ from fastapi import FastAPI, File, UploadFile
4
  from fastapi.responses import StreamingResponse, FileResponse
5
  from fastapi.staticfiles import StaticFiles
 
 
6
  import numpy as np
7
  import logging
8
  from io import BytesIO
9
  import tempfile
10
+ from mtcnn import MTCNN
11
+ from vtoonify_model import Model # Import VToonify model
12
 
13
  app = FastAPI()
14
 
15
+ # Initialize the VToonify model and MTCNN detector
16
  model = None
17
+ detector = MTCNN()
18
 
19
  def load_model():
20
  global model
 
21
  model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
22
  model.load_model('cartoon4')
23
 
 
25
  logging.basicConfig(level=logging.INFO)
26
 
27
  @app.post("/upload/")
28
+ async def process_image(file: UploadFile = File(...)):
29
  global model
30
  if model is None:
31
  load_model()
 
33
  # Read the uploaded image file
34
  contents = await file.read()
35
 
36
+ # Convert the uploaded image to a numpy array
37
  nparr = np.frombuffer(contents, np.uint8)
38
  frame_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # Read as BGR format by default
39
+
40
  if frame_bgr is None:
41
  logging.error("Failed to decode the image.")
42
  return {"error": "Failed to decode the image. Please ensure the file is a valid image format."}
43
 
44
  logging.info(f"Uploaded image shape: {frame_bgr.shape}")
45
 
46
+ # Convert BGR to RGB for MTCNN
47
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
48
+
49
+ # Detect faces using MTCNN
50
+ results = detector.detect_faces(frame_rgb)
51
+
52
+ if len(results) == 0:
53
+ logging.error("No faces detected in the image.")
54
+ return {"error": "No faces detected in the image."}
55
+
56
+ # Use the first detected face
57
+ x, y, width, height = results[0]['box']
58
+ cropped_face = frame_rgb[y:y+height, x:x+width]
59
+
60
+ # Save the cropped face temporarily to pass the file path to the model
61
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
62
+ cv2.imwrite(temp_file.name, cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR))
63
  temp_file_path = temp_file.name
64
 
65
  try:
66
+ # Process the cropped face using VToonify
67
+ aligned_face, instyle, message = model.detect_and_align_image(temp_file_path, 0, 0, 0, 0)
68
  if aligned_face is None or instyle is None:
69
  logging.error("Failed to process the image: No face detected or alignment failed.")
70
  return {"error": message}
 
82
 
83
  # Return the processed image as a streaming response
84
  return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
85
+
86
  finally:
87
  # Clean up the temporary file
88
  os.remove(temp_file_path)
89
 
90
  # Mount static files directory
91
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
92
 
93
  # Define index route
94
  @app.get("/")
95
+ def index() -> FileResponse:
96
+ return FileResponse(path="/app/static/index.html", media_type="text/html")