Ashrafb committed on
Commit
1776609
·
verified ·
1 Parent(s): e410b45

Update vtoonify_model.py

Browse files
Files changed (1) hide show
  1. vtoonify_model.py +8 -109
vtoonify_model.py CHANGED
@@ -18,7 +18,7 @@ import gc
18
  import huggingface_hub
19
  import os
20
  import logging
21
- from PIL import Image # Importing Image from PIL
22
 
23
  # Configure logging
24
  logging.basicConfig(level=logging.INFO)
@@ -132,12 +132,14 @@ class Model():
132
  # Align face based on landmarks
133
  aligned_face = self.align_face(frame, landmarks)
134
  if aligned_face is not None:
 
135
  with torch.no_grad():
136
  I = self.transform(aligned_face).unsqueeze(dim=0).to(self.device)
137
  instyle = self.pspencoder(I)
138
  instyle = self.vtoonify.zplus2wplus(instyle)
139
  message = 'Successfully aligned the face.'
140
  else:
 
141
  frame = np.zeros((256, 256, 3), np.uint8)
142
  else:
143
  logging.warning("No face detected.")
@@ -172,14 +174,13 @@ class Model():
172
  # Transform and crop the image
173
  transform_size = 256
174
  output_size = 256
175
- img = Image.fromarray(image) # Corrected to use PIL.Image
176
  img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
177
  if output_size < transform_size:
178
  img = img.resize((output_size, output_size), Image.ANTIALIAS)
179
 
180
  return np.array(img)
181
 
182
- # Other methods remain unchanged
183
  def detect_and_align_image(self, frame_rgb: np.ndarray, top: int, bottom: int, left: int, right: int) -> tuple:
184
  if frame_rgb is None:
185
  return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load the image.'
@@ -188,61 +189,13 @@ class Model():
188
  frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
189
  return self.detect_and_align(frame_bgr, top, bottom, left, right)
190
 
191
- def detect_and_align_video(self, video: str, top: int, bottom: int, left: int, right: int) -> tuple:
192
- if video is None:
193
- return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load empty file.'
194
- video_cap = cv2.VideoCapture(video)
195
- if video_cap.get(7) == 0:
196
- video_cap.release()
197
- return np.zeros((256, 256, 3), np.uint8), torch.zeros(1, 18, 512).to(self.device), 'Error: fail to load the video.'
198
- success, frame = video_cap.read()
199
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
200
- video_cap.release()
201
- return self.detect_and_align(frame, top, bottom, left, right)
202
-
203
- def detect_and_align_full_video(self, video: str, top: int, bottom: int, left: int, right: int) -> tuple:
204
- message = 'Error: no face detected! Please retry or change the video.'
205
- instyle = None
206
- if video is None:
207
- return 'default.mp4', instyle, 'Error: fail to load empty file.'
208
- video_cap = cv2.VideoCapture(video)
209
- if video_cap.get(7) == 0:
210
- video_cap.release()
211
- return 'default.mp4', instyle, 'Error: fail to load the video.'
212
- num = min(self.video_limit_gpu, int(video_cap.get(7)))
213
- if self.device == 'cpu':
214
- num = min(self.video_limit_cpu, num)
215
- success, frame = video_cap.read()
216
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
217
- frame, instyle, message, w, h, top, bottom, left, right, scale = self.detect_and_align(frame, top, bottom, left, right, True)
218
- if instyle is None:
219
- return 'default.mp4', instyle, message
220
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
221
- videoWriter = cv2.VideoWriter('input.mp4', fourcc, video_cap.get(5), (int(right-left), int(bottom-top)))
222
- videoWriter.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
223
- kernel_1d = np.array([[0.125], [0.375], [0.375], [0.125]])
224
- for i in range(num-1):
225
- success, frame = video_cap.read()
226
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
227
- if scale <= 0.75:
228
- frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
229
- if scale <= 0.375:
230
- frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
231
- frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
232
- videoWriter.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
233
-
234
- videoWriter.release()
235
- video_cap.release()
236
-
237
- return 'input.mp4', instyle, 'Successfully rescale the video to (%d, %d)' % (bottom-top, right-left)
238
-
239
  def image_toonify(self, aligned_face: np.ndarray, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float, style_type: str) -> tuple:
240
  if instyle is None or aligned_face is None:
241
- return np.zeros((256, 256, 3), np.uint8), 'Opps, something wrong with the input. Please go to Step 2 and Rescale Image/First Frame again.'
242
  if self.style_name != style_type:
243
  exstyle, _ = self.load_model(style_type)
244
  if exstyle is None:
245
- return np.zeros((256, 256, 3), np.uint8), 'Opps, something wrong with the style type. Please go to Step 1 and load model again.'
246
  with torch.no_grad():
247
  if self.color_transfer:
248
  s_w = exstyle
@@ -251,69 +204,15 @@ class Model():
251
  s_w[:, :7] = exstyle[:, :7]
252
 
253
  x = self.transform(aligned_face).unsqueeze(dim=0).to(self.device)
 
254
  x_p = F.interpolate(self.parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
255
  scale_factor=0.5, recompute_scale_factor=False).detach()
256
  inputs = torch.cat((x, x_p/16.), dim=1)
257
  y_tilde = self.vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s=style_degree)
258
  y_tilde = torch.clamp(y_tilde, -1, 1)
259
- print('*** Toonify %dx%d image with style of %s' % (y_tilde.shape[2], y_tilde.shape[3], style_type))
260
  return ((y_tilde[0].cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8), 'Successfully toonify the image with style of %s' % (self.style_name)
261
 
262
- def video_toonify(self, aligned_video: str, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float, style_type: str) -> tuple:
263
- if aligned_video is None:
264
- return 'default.mp4', 'Opps, something wrong with the input. Please go to Step 2 and Rescale Video again.'
265
- video_cap = cv2.VideoCapture(aligned_video)
266
- if instyle is None or aligned_video is None or video_cap.get(7) == 0:
267
- video_cap.release()
268
- return 'default.mp4', 'Opps, something wrong with the input. Please go to Step 2 and Rescale Video again.'
269
- if self.style_name != style_type:
270
- exstyle, _ = self.load_model(style_type)
271
- if exstyle is None:
272
- return 'default.mp4', 'Opps, something wrong with the style type. Please go to Step 1 and load model again.'
273
- num = min(self.video_limit_gpu, int(video_cap.get(7)))
274
- if self.device == 'cpu':
275
- num = min(self.video_limit_cpu, num)
276
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
277
- videoWriter = cv2.VideoWriter('output.mp4', fourcc,
278
- video_cap.get(5), (int(video_cap.get(3)*4),
279
- int(video_cap.get(4)*4)))
280
-
281
- batch_frames = []
282
- if video_cap.get(3) != 0:
283
- if self.device == 'cpu':
284
- batch_size = max(1, int(4 * 256 * 256 / video_cap.get(3) / video_cap.get(4)))
285
- else:
286
- batch_size = min(max(1, int(4 * 400 * 360 / video_cap.get(3) / video_cap.get(4))), 4)
287
- else:
288
- batch_size = 1
289
- print('*** Toonify using batch size of %d on %dx%d video of %d frames with style of %s' % (batch_size, int(video_cap.get(3)*4), int(video_cap.get(4)*4), num, style_type))
290
- with torch.no_grad():
291
- if self.color_transfer:
292
- s_w = exstyle
293
- else:
294
- s_w = instyle.clone()
295
- s_w[:, :7] = exstyle[:, :7]
296
- for i in range(num):
297
- success, frame = video_cap.read()
298
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
299
- batch_frames += [self.transform(frame).unsqueeze(dim=0).to(self.device)]
300
- if len(batch_frames) == batch_size or (i+1) == num:
301
- x = torch.cat(batch_frames, dim=0)
302
- batch_frames = []
303
- with torch.no_grad():
304
- x_p = F.interpolate(self.parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
305
- scale_factor=0.5, recompute_scale_factor=False).detach()
306
- inputs = torch.cat((x, x_p/16.), dim=1)
307
- y_tilde = self.vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), style_degree)
308
- y_tilde = torch.clamp(y_tilde, -1, 1)
309
- for k in range(y_tilde.size(0)):
310
- videoWriter.write(tensor2cv2(y_tilde[k].cpu()))
311
- gc.collect()
312
-
313
- videoWriter.release()
314
- video_cap.release()
315
- return 'output.mp4', 'Successfully toonify video of %d frames with style of %s' % (num, self.style_name)
316
-
317
  def tensor2cv2(self, img):
318
  """Convert a tensor image to OpenCV format."""
319
  tmp = ((img.cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8).copy()
 
18
  import huggingface_hub
19
  import os
20
  import logging
21
+ from PIL import Image
22
 
23
  # Configure logging
24
  logging.basicConfig(level=logging.INFO)
 
132
  # Align face based on landmarks
133
  aligned_face = self.align_face(frame, landmarks)
134
  if aligned_face is not None:
135
+ logging.info(f"Aligned face shape: {aligned_face.shape}")
136
  with torch.no_grad():
137
  I = self.transform(aligned_face).unsqueeze(dim=0).to(self.device)
138
  instyle = self.pspencoder(I)
139
  instyle = self.vtoonify.zplus2wplus(instyle)
140
  message = 'Successfully aligned the face.'
141
  else:
142
+ logging.warning("Failed to align face.")
143
  frame = np.zeros((256, 256, 3), np.uint8)
144
  else:
145
  logging.warning("No face detected.")
 
174
  # Transform and crop the image
175
  transform_size = 256
176
  output_size = 256
177
+ img = Image.fromarray(image)
178
  img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
179
  if output_size < transform_size:
180
  img = img.resize((output_size, output_size), Image.ANTIALIAS)
181
 
182
  return np.array(img)
183
 
 
184
  def detect_and_align_image(self, frame_rgb: np.ndarray, top: int, bottom: int, left: int, right: int) -> tuple:
185
  if frame_rgb is None:
186
  return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load the image.'
 
189
  frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
190
  return self.detect_and_align(frame_bgr, top, bottom, left, right)
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  def image_toonify(self, aligned_face: np.ndarray, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float, style_type: str) -> tuple:
193
  if instyle is None or aligned_face is None:
194
+ return np.zeros((256, 256, 3), np.uint8), 'Oops, something wrong with the input. Please go to Step 2 and Rescale Image/First Frame again.'
195
  if self.style_name != style_type:
196
  exstyle, _ = self.load_model(style_type)
197
  if exstyle is None:
198
+ return np.zeros((256, 256, 3), np.uint8), 'Oops, something wrong with the style type. Please go to Step 1 and load model again.'
199
  with torch.no_grad():
200
  if self.color_transfer:
201
  s_w = exstyle
 
204
  s_w[:, :7] = exstyle[:, :7]
205
 
206
  x = self.transform(aligned_face).unsqueeze(dim=0).to(self.device)
207
+ logging.info(f"Input to VToonify shape: {x.shape}")
208
  x_p = F.interpolate(self.parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
209
  scale_factor=0.5, recompute_scale_factor=False).detach()
210
  inputs = torch.cat((x, x_p/16.), dim=1)
211
  y_tilde = self.vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s=style_degree)
212
  y_tilde = torch.clamp(y_tilde, -1, 1)
213
+ logging.info(f"Output from VToonify shape: {y_tilde.shape}")
214
  return ((y_tilde[0].cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8), 'Successfully toonify the image with style of %s' % (self.style_name)
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  def tensor2cv2(self, img):
217
  """Convert a tensor image to OpenCV format."""
218
  tmp = ((img.cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8).copy()