Ashrafb committed on
Commit
c93b3ba
·
verified ·
1 Parent(s): ae78d7f

Update vtoonify_model.py

Browse files
Files changed (1) hide show
  1. vtoonify_model.py +70 -62
vtoonify_model.py CHANGED
@@ -14,10 +14,13 @@ from model.vtoonify import VToonify
14
  from model.bisenet.model import BiSeNet
15
  import torch.nn.functional as F
16
  from torchvision import transforms
17
- from model.encoder.align_all_parallel import align_face
18
  import gc
19
  import huggingface_hub
20
  import os
 
 
 
 
21
 
22
  MODEL_REPO = 'PKUWilliamYang/VToonify'
23
 
@@ -68,7 +71,7 @@ class Model():
68
  def _create_insightface_detector(self):
69
  # Initialize InsightFace
70
  app = insightface.app.FaceAnalysis()
71
- app.prepare(ctx_id=0, det_size=(640, 640)) # ctx_id=-1 for CPU, 0 for GPU
72
  return app
73
 
74
  def _create_parsing_model(self):
@@ -94,66 +97,7 @@ class Model():
94
  exstyle = vtoonify.zplus2wplus(exstyle)
95
  return vtoonify, exstyle
96
 
97
- def detect_and_align(self, frame, top, bottom, left, right, return_para=False):
98
- message = 'Error: no face detected! Please retry or change the photo.'
99
- instyle = None
100
- # Use InsightFace for face detection
101
- faces = self.face_detector.get(frame)
102
- if len(faces) > 0:
103
- face = faces[0]
104
- bbox = face.bbox.astype(int)
105
- x, y, w, h = bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]
106
- top, bottom, left, right = y, y + h, x, x + w
107
- scale = 1.0 # Adjust scale as needed
108
- h, w = frame.shape[:2]
109
- H, W = int(bottom-top), int(right-left)
110
- # for HR image, we apply gaussian blur to it to avoid over-sharp stylization results
111
- kernel_1d = np.array([[0.125], [0.375], [0.375], [0.125]])
112
- if scale <= 0.75:
113
- frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
114
- if scale <= 0.375:
115
- frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
116
- frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
117
- with torch.no_grad():
118
- I = align_face(frame, self.face_detector)
119
- if I is not None:
120
- I = self.transform(I).unsqueeze(dim=0).to(self.device)
121
- instyle = self.pspencoder(I)
122
- instyle = self.vtoonify.zplus2wplus(instyle)
123
- message = 'Successfully rescale the frame to (%d, %d)' % (bottom-top, right-left)
124
- else:
125
- frame = np.zeros((256, 256, 3), np.uint8)
126
- else:
127
- frame = np.zeros((256, 256, 3), np.uint8)
128
- if return_para:
129
- return frame, instyle, message, w, h, top, bottom, left, right, scale
130
- return frame, instyle, message
131
-
132
- # Other methods remain unchanged
133
- def _create_parsing_model(self):
134
- parsingpredictor = BiSeNet(n_classes=19)
135
- parsingpredictor.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO, 'models/faceparsing.pth'),
136
- map_location=lambda storage, loc: storage))
137
- parsingpredictor.to(self.device).eval()
138
- return parsingpredictor
139
-
140
- def _load_encoder(self) -> nn.Module:
141
- style_encoder_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'models/encoder.pt')
142
- return load_psp_standalone(style_encoder_path, self.device)
143
-
144
- def _load_default_model(self) -> tuple:
145
- vtoonify = VToonify(backbone='dualstylegan')
146
- vtoonify.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO,
147
- 'models/vtoonify_d_cartoon/vtoonify_s026_d0.5.pt'),
148
- map_location=lambda storage, loc: storage)['g_ema'])
149
- vtoonify.to(self.device)
150
- tmp = np.load(huggingface_hub.hf_hub_download(MODEL_REPO, 'models/vtoonify_d_cartoon/exstyle_code.npy'), allow_pickle=True).item()
151
- exstyle = torch.tensor(tmp[list(tmp.keys())[26]]).to(self.device)
152
- with torch.no_grad():
153
- exstyle = vtoonify.zplus2wplus(exstyle)
154
- return vtoonify, exstyle
155
-
156
- def load_model(self, style_type: str) -> tuple:
157
  if 'illustration' in style_type:
158
  self.color_transfer = True
159
  else:
@@ -170,7 +114,71 @@ class Model():
170
  with torch.no_grad():
171
  exstyle = self.vtoonify.zplus2wplus(exstyle)
172
  return exstyle, 'Model of %s loaded.' % (style_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  def detect_and_align_image(self, frame_rgb: np.ndarray, top: int, bottom: int, left: int, right: int) -> tuple:
175
  if frame_rgb is None:
176
  return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load the image.'
 
14
  from model.bisenet.model import BiSeNet
15
  import torch.nn.functional as F
16
  from torchvision import transforms
 
17
  import gc
18
  import huggingface_hub
19
  import os
20
+ import logging
21
+
22
+ # Configure logging
23
+ logging.basicConfig(level=logging.INFO)
24
 
25
  MODEL_REPO = 'PKUWilliamYang/VToonify'
26
 
 
71
  def _create_insightface_detector(self):
72
  # Initialize InsightFace
73
  app = insightface.app.FaceAnalysis()
74
+ app.prepare(ctx_id=0 if self.device == 'cuda' else -1, det_size=(640, 640))
75
  return app
76
 
77
  def _create_parsing_model(self):
 
97
  exstyle = vtoonify.zplus2wplus(exstyle)
98
  return vtoonify, exstyle
99
 
100
+ def load_model(self, style_type: str) -> tuple[torch.Tensor, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  if 'illustration' in style_type:
102
  self.color_transfer = True
103
  else:
 
114
  with torch.no_grad():
115
  exstyle = self.vtoonify.zplus2wplus(exstyle)
116
  return exstyle, 'Model of %s loaded.' % (style_type)
117
+
118
+ def detect_and_align(self, frame, top, bottom, left, right, return_para=False):
119
+ message = 'Error: no face detected! Please retry or change the photo.'
120
+ instyle = None
121
+ h, w, scale = 0, 0, 0
122
+
123
+ # Use InsightFace for face detection
124
+ faces = self.face_detector.get(frame)
125
+ if len(faces) > 0:
126
+ logging.info(f"Detected {len(faces)} face(s).")
127
+ face = faces[0]
128
+ bbox = face.bbox.astype(int)
129
+ landmarks = face.landmark_2d_106
130
 
131
+ # Align face based on landmarks
132
+ aligned_face = self.align_face(frame, landmarks)
133
+ if aligned_face is not None:
134
+ with torch.no_grad():
135
+ I = self.transform(aligned_face).unsqueeze(dim=0).to(self.device)
136
+ instyle = self.pspencoder(I)
137
+ instyle = self.vtoonify.zplus2wplus(instyle)
138
+ message = 'Successfully aligned the face.'
139
+ else:
140
+ frame = np.zeros((256, 256, 3), np.uint8)
141
+ else:
142
+ logging.warning("No face detected.")
143
+ frame = np.zeros((256, 256, 3), np.uint8)
144
+
145
+ if return_para:
146
+ return frame, instyle, message, h, w, top, bottom, left, right, scale
147
+ return frame, instyle, message
148
+
149
+ def align_face(self, image, landmarks):
150
+ # Calculate auxiliary vectors for alignment
151
+ eye_left = np.mean(landmarks[36:42], axis=0)
152
+ eye_right = np.mean(landmarks[42:48], axis=0)
153
+ mouth_left = landmarks[48]
154
+ mouth_right = landmarks[54]
155
+
156
+ # Calculate transformation parameters
157
+ eye_center = (eye_left + eye_right) / 2
158
+ mouth_center = (mouth_left + mouth_right) / 2
159
+ eye_to_eye = eye_right - eye_left
160
+ eye_to_mouth = mouth_center - eye_center
161
+
162
+ # Define the transformation matrix
163
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
164
+ x /= np.hypot(*x)
165
+ x *= np.hypot(*eye_to_eye) * 2.0
166
+ y = np.flipud(x) * [-1, 1]
167
+ c = eye_center + eye_to_mouth * 0.1
168
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
169
+ qsize = np.hypot(*x) * 2
170
+
171
+ # Transform and crop the image
172
+ transform_size = 256
173
+ output_size = 256
174
+ img = PIL.Image.fromarray(image)
175
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
176
+ if output_size < transform_size:
177
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
178
+
179
+ return np.array(img)
180
+
181
+ # Other methods remain unchanged
182
  def detect_and_align_image(self, frame_rgb: np.ndarray, top: int, bottom: int, left: int, right: int) -> tuple:
183
  if frame_rgb is None:
184
  return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load the image.'