Ashrafb commited on
Commit
c7c2154
·
verified ·
1 Parent(s): bb43511

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +57 -53
main.py CHANGED
@@ -1,23 +1,24 @@
1
- from fastapi import FastAPI, File, UploadFile
 
 
2
  from fastapi.responses import StreamingResponse, FileResponse
3
  from fastapi.staticfiles import StaticFiles
4
- import torch
5
- import cv2
6
  import numpy as np
7
  import logging
8
  from io import BytesIO
9
  import tempfile
10
- import os
 
11
 
12
  app = FastAPI()
13
 
14
- # Load model and necessary components
 
 
15
  model = None
16
- face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
17
 
18
  def load_model():
19
  global model
20
- from vtoonify_model import Model
21
  model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
22
  model.load_model('cartoon4')
23
 
@@ -25,7 +26,7 @@ def load_model():
25
  logging.basicConfig(level=logging.INFO)
26
 
27
  @app.post("/upload/")
28
- async def process_image(file: UploadFile = File(...)):
29
  global model
30
  if model is None:
31
  load_model()
@@ -43,52 +44,55 @@ async def process_image(file: UploadFile = File(...)):
43
 
44
  logging.info(f"Uploaded image shape: {frame_bgr.shape}")
45
 
46
- # Detect faces in the image
47
- gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
48
- faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
49
-
50
- if len(faces) == 0:
51
- logging.error("No faces detected in the image.")
52
- return {"error": "No faces detected in the image."}
53
-
54
- # Use the first detected face
55
- (x, y, w, h) = faces[0]
56
- top, bottom, left, right = y, y + h, x, x + w
57
-
58
- # Save the uploaded image temporarily to pass the file path to the model
59
- with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
60
- cv2.imwrite(temp_file.name, frame_bgr)
61
- temp_file_path = temp_file.name
62
 
 
63
  try:
64
- # Process the uploaded image using the file path
65
- aligned_face, instyle, message = model.detect_and_align_image(temp_file_path, top, bottom, left, right)
66
- if aligned_face is None or instyle is None:
67
- logging.error("Failed to process the image: No face detected or alignment failed.")
68
- return {"error": message}
69
-
70
- processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1')
71
- if processed_image is None:
72
- logging.error("Failed to toonify the image.")
73
- return {"error": message}
74
-
75
- # Convert the processed image to RGB before returning
76
- processed_image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
77
-
78
- # Convert processed image to bytes
79
- _, encoded_image = cv2.imencode('.jpg', processed_image_rgb)
80
-
81
- # Return the processed image as a streaming response
82
- return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
83
-
84
- finally:
85
- # Clean up the temporary file
86
- os.remove(temp_file_path)
87
-
88
- # Mount static files directory
89
- app.mount("/", StaticFiles(directory="AB", html=True), name="static")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- # Define index route
92
  @app.get("/")
93
- def index():
94
- return FileResponse(path="/app/AB/index.html", media_type="text/html")
 
1
+ import os
2
+ import cv2
3
+ from fastapi import FastAPI, File, UploadFile, Form
4
  from fastapi.responses import StreamingResponse, FileResponse
5
  from fastapi.staticfiles import StaticFiles
 
 
6
  import numpy as np
7
  import logging
8
  from io import BytesIO
9
  import tempfile
10
+ import AnimeGANv3_src # Assuming this module contains the face detection logic
11
+ from vtoonify_model import Model # Import VToonify model
12
 
13
  app = FastAPI()
14
 
15
+ os.makedirs('output', exist_ok=True)
16
+
17
+ # Initialize VToonify model
18
  model = None
 
19
 
20
  def load_model():
21
  global model
 
22
  model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
23
  model.load_model('cartoon4')
24
 
 
26
  logging.basicConfig(level=logging.INFO)
27
 
28
  @app.post("/upload/")
29
+ async def process_image(file: UploadFile = File(...), Style: str = Form(...)):
30
  global model
31
  if model is None:
32
  load_model()
 
44
 
45
  logging.info(f"Uploaded image shape: {frame_bgr.shape}")
46
 
47
+ # Convert BGR to RGB
48
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ # Use AnimeGANv3's face detection logic
51
  try:
52
+ # Assume AnimeGANv3_src.Convert detects and returns the cropped face
53
+ det_face = True # Assume we always want to detect face
54
+ detected_face, _ = AnimeGANv3_src.Convert(frame_rgb, Style, det_face)
55
+
56
+ if detected_face is None:
57
+ logging.error("No face detected by AnimeGANv3.")
58
+ return {"error": "No face detected in the image."}
59
+
60
+ # Save the detected face temporarily to pass the file path to the model
61
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
62
+ cv2.imwrite(temp_file.name, detected_face[:, :, ::-1]) # Convert RGB to BGR for saving
63
+ temp_file_path = temp_file.name
64
+
65
+ try:
66
+ # Process the detected face using VToonify
67
+ aligned_face, instyle, message = model.detect_and_align_image(temp_file_path, 0, 0, 0, 0)
68
+ if aligned_face is None or instyle is None:
69
+ logging.error("Failed to process the image: No face detected or alignment failed.")
70
+ return {"error": message}
71
+
72
+ processed_image, message = model.image_toonify(aligned_face, instyle, model.exstyle, style_degree=0.5, style_type='cartoon1')
73
+ if processed_image is None:
74
+ logging.error("Failed to toonify the image.")
75
+ return {"error": message}
76
+
77
+ # Convert the processed image to RGB before returning
78
+ processed_image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
79
+
80
+ # Convert processed image to bytes
81
+ _, encoded_image = cv2.imencode('.jpg', processed_image_rgb)
82
+
83
+ # Return the processed image as a streaming response
84
+ return StreamingResponse(BytesIO(encoded_image.tobytes()), media_type="image/jpeg")
85
+
86
+ finally:
87
+ # Clean up the temporary file
88
+ os.remove(temp_file_path)
89
+
90
+ except RuntimeError as error:
91
+ logging.error(f"Error during AnimeGANv3 processing: {error}")
92
+ return {"error": "Failed to process the image with AnimeGANv3."}
93
+
94
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
95
 
 
96
  @app.get("/")
97
+ def index() -> FileResponse:
98
+ return FileResponse(path="/app/static/index.html", media_type="text/html")