Ashrafb commited on
Commit
c14c58e
·
verified ·
1 Parent(s): 8c2e897

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +38 -55
main.py CHANGED
@@ -1,32 +1,30 @@
1
- import os
2
- import cv2
3
- from fastapi import FastAPI, File, UploadFile
4
  from fastapi.responses import StreamingResponse, FileResponse
5
  from fastapi.staticfiles import StaticFiles
 
 
6
  import numpy as np
7
  import logging
8
  from io import BytesIO
9
  import tempfile
10
- from mtcnn import MTCNN
11
- from vtoonify_model import Model # Import VToonify model
12
- import torch
13
 
14
  app = FastAPI()
15
 
16
- # Initialize logging
17
- logging.basicConfig(level=logging.INFO)
18
-
19
- # Initialize the VToonify model and MTCNN detector
20
  model = None
21
- detector = MTCNN(min_face_size=20, scale_factor=0.709)
22
 
23
  def load_model():
24
  global model
 
25
  model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
26
- model.load_model('cartoon1')
 
 
 
27
 
28
  @app.post("/upload/")
29
- async def process_image(file: UploadFile = File(...)):
30
  global model
31
  if model is None:
32
  load_model()
@@ -34,65 +32,50 @@ async def process_image(file: UploadFile = File(...)):
34
  # Read the uploaded image file
35
  contents = await file.read()
36
 
37
- # Convert the uploaded image to a numpy array
38
  nparr = np.frombuffer(contents, np.uint8)
39
  frame_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # Read as BGR format by default
40
-
41
  if frame_bgr is None:
42
  logging.error("Failed to decode the image.")
43
  return {"error": "Failed to decode the image. Please ensure the file is a valid image format."}
44
 
45
  logging.info(f"Uploaded image shape: {frame_bgr.shape}")
46
 
47
- # Convert BGR to RGB for MTCNN
48
- frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
49
-
50
- # Detect faces using MTCNN
51
- results = detector.detect_faces(frame_rgb)
52
- logging.info(f"Detection results: {results}")
53
-
54
- if len(results) == 0:
55
- logging.error("No faces detected in the image.")
56
- return {"error": "No faces detected in the image."}
57
-
58
- # Use the first detected face
59
- x, y, width, height = results[0]['box']
60
- cropped_face = frame_rgb[y:y+height, x:x+width]
61
-
62
- # Save the cropped face temporarily to pass the file path to the model
63
- with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file:
64
- cv2.imwrite(temp_file.name, cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR))
65
  temp_file_path = temp_file.name
66
 
67
- try:
68
- # Process the cropped face using VToonify
69
- aligned_face, instyle, message = model.detect_and_align_image(temp_file_path, 0, 0, 0, 0)
70
- if aligned_face is None or instyle is None:
71
- logging.error("Failed to process the image: No face detected or alignment failed.")
72
- return {"error": message}
73
-
74
- processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1')
75
- if processed_image is None:
76
- logging.error("Failed to toonify the image.")
77
- return {"error": message}
78
 
79
- # Convert the processed image to RGB before returning
80
- processed_image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
 
 
81
 
82
- # Convert processed image to bytes
83
- _, encoded_image = cv2.imencode('.jpg', processed_image_rgb)
84
 
85
- # Return the processed image as a streaming response
86
- return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
87
 
88
- except Exception as e:
89
- logging.error(f"Error during processing: {e}")
90
- return {"error": str(e)}
 
 
 
91
 
92
  # Mount static files directory
93
- app.mount("/", StaticFiles(directory="AB", html=True), name="AB")
94
 
95
  # Define index route
96
  @app.get("/")
97
- def index() -> FileResponse:
98
  return FileResponse(path="/app/AB/index.html", media_type="text/html")
 
1
+ from fastapi import FastAPI, File, UploadFile, Form
 
 
2
  from fastapi.responses import StreamingResponse, FileResponse
3
  from fastapi.staticfiles import StaticFiles
4
+ import torch
5
+ import cv2
6
  import numpy as np
7
  import logging
8
  from io import BytesIO
9
  import tempfile
10
+ import os
 
 
11
 
12
  app = FastAPI()
13
 
14
+ # Load model and necessary components
 
 
 
15
  model = None
 
16
 
17
  def load_model():
18
  global model
19
+ from vtoonify_model import Model
20
  model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
21
+ model.load_model('cartoon4')
22
+
23
+ # Configure logging
24
+ logging.basicConfig(level=logging.INFO)
25
 
26
  @app.post("/upload/")
27
+ async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
28
  global model
29
  if model is None:
30
  load_model()
 
32
  # Read the uploaded image file
33
  contents = await file.read()
34
 
35
+ # Convert the uploaded image to numpy array
36
  nparr = np.frombuffer(contents, np.uint8)
37
  frame_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # Read as BGR format by default
38
+
39
  if frame_bgr is None:
40
  logging.error("Failed to decode the image.")
41
  return {"error": "Failed to decode the image. Please ensure the file is a valid image format."}
42
 
43
  logging.info(f"Uploaded image shape: {frame_bgr.shape}")
44
 
45
+ # Save the uploaded image temporarily to pass the file path to the model
46
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
47
+ cv2.imwrite(temp_file.name, frame_bgr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  temp_file_path = temp_file.name
49
 
50
+ try:
51
+ # Process the uploaded image using the file path
52
+ aligned_face, instyle, message = model.detect_and_align_image(temp_file_path, top, bottom, left, right)
53
+ if aligned_face is None or instyle is None:
54
+ logging.error("Failed to process the image: No face detected or alignment failed.")
55
+ return {"error": message}
 
 
 
 
 
56
 
57
+ processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1')
58
+ if processed_image is None:
59
+ logging.error("Failed to toonify the image.")
60
+ return {"error": message}
61
 
62
+ # Convert the processed image to RGB before returning
63
+ processed_image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
64
 
65
+ # Convert processed image to bytes
66
+ _, encoded_image = cv2.imencode('.jpg', processed_image_rgb)
67
 
68
+ # Return the processed image as a streaming response
69
+ return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
70
+
71
+ finally:
72
+ # Clean up the temporary file
73
+ os.remove(temp_file_path)
74
 
75
  # Mount static files directory
76
+ app.mount("/", StaticFiles(directory="AB", html=True), name="static")
77
 
78
  # Define index route
79
  @app.get("/")
80
+ def index():
81
  return FileResponse(path="/app/AB/index.html", media_type="text/html")