Update vtoonify_model.py
vtoonify_model.py  CHANGED  (+8 -109)

@@ -18,7 +18,7 @@ import gc
 import huggingface_hub
 import os
 import logging
-from PIL import Image
+from PIL import Image

 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -132,12 +132,14 @@ class Model():
             # Align face based on landmarks
             aligned_face = self.align_face(frame, landmarks)
             if aligned_face is not None:
+                logging.info(f"Aligned face shape: {aligned_face.shape}")
                 with torch.no_grad():
                     I = self.transform(aligned_face).unsqueeze(dim=0).to(self.device)
                     instyle = self.pspencoder(I)
                     instyle = self.vtoonify.zplus2wplus(instyle)
                     message = 'Successfully aligned the face.'
             else:
+                logging.warning("Failed to align face.")
                 frame = np.zeros((256, 256, 3), np.uint8)
         else:
             logging.warning("No face detected.")
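
Note on the encoding step in the hunk above: the aligned face is converted to a normalized tensor with a batch dimension before it reaches the pSp encoder. A minimal, self-contained sketch of that preprocessing, assuming the usual VToonify ToTensor + Normalize(0.5, 0.5) transform (the constants and the all-black input are stand-ins, not taken from this diff):

import numpy as np
from torchvision import transforms

# Assumed transform: scales uint8 HxWx3 to [0, 1], then shifts to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

aligned_face = np.zeros((256, 256, 3), np.uint8)   # stand-in aligned face
I = transform(aligned_face).unsqueeze(dim=0)       # add batch dim: [1, 3, 256, 256]
print(I.shape, I.min().item(), I.max().item())     # all values -1.0 for a black input
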
@@ -172,14 +174,13 @@ class Model():
         # Transform and crop the image
         transform_size = 256
         output_size = 256
-        img = Image.fromarray(image)
+        img = Image.fromarray(image)
         img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
         if output_size < transform_size:
             img = img.resize((output_size, output_size), Image.ANTIALIAS)

         return np.array(img)

-    # Other methods remain unchanged
     def detect_and_align_image(self, frame_rgb: np.ndarray, top: int, bottom: int, left: int, right: int) -> tuple:
         if frame_rgb is None:
             return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load the image.'
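
For context on the Image.QUAD call above: PIL maps four source corners (upper-left, lower-left, lower-right, upper-right, flattened to eight values) onto the output rectangle, which is how the detected face region is warped to 256x256. A runnable sketch on a synthetic image; the quad corners here are invented for illustration:

import numpy as np
from PIL import Image

transform_size = 256
gradient = np.tile(np.arange(512, dtype=np.uint8), (512, 1))
image = np.stack([gradient] * 3, axis=-1)          # synthetic 512x512 RGB image

# Corner order: upper-left, lower-left, lower-right, upper-right.
quad = np.array([[96, 96], [96, 416], [416, 416], [416, 96]], dtype=np.float32)

img = Image.fromarray(image)
img = img.transform((transform_size, transform_size), Image.QUAD,
                    (quad + 0.5).flatten(), Image.BILINEAR)
print(np.array(img).shape)                         # (256, 256, 3)

As a side note, Image.ANTIALIAS in the resize branch above is deprecated and removed in Pillow 10; Image.LANCZOS is the modern equivalent.
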
@@ -188,61 +189,13 @@ class Model():
         frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
         return self.detect_and_align(frame_bgr, top, bottom, left, right)

-    def detect_and_align_video(self, video: str, top: int, bottom: int, left: int, right: int) -> tuple:
-        if video is None:
-            return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load empty file.'
-        video_cap = cv2.VideoCapture(video)
-        if video_cap.get(7) == 0:
-            video_cap.release()
-            return np.zeros((256, 256, 3), np.uint8), torch.zeros(1, 18, 512).to(self.device), 'Error: fail to load the video.'
-        success, frame = video_cap.read()
-        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        video_cap.release()
-        return self.detect_and_align(frame, top, bottom, left, right)
-
-    def detect_and_align_full_video(self, video: str, top: int, bottom: int, left: int, right: int) -> tuple:
-        message = 'Error: no face detected! Please retry or change the video.'
-        instyle = None
-        if video is None:
-            return 'default.mp4', instyle, 'Error: fail to load empty file.'
-        video_cap = cv2.VideoCapture(video)
-        if video_cap.get(7) == 0:
-            video_cap.release()
-            return 'default.mp4', instyle, 'Error: fail to load the video.'
-        num = min(self.video_limit_gpu, int(video_cap.get(7)))
-        if self.device == 'cpu':
-            num = min(self.video_limit_cpu, num)
-        success, frame = video_cap.read()
-        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        frame, instyle, message, w, h, top, bottom, left, right, scale = self.detect_and_align(frame, top, bottom, left, right, True)
-        if instyle is None:
-            return 'default.mp4', instyle, message
-        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-        videoWriter = cv2.VideoWriter('input.mp4', fourcc, video_cap.get(5), (int(right-left), int(bottom-top)))
-        videoWriter.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
-        kernel_1d = np.array([[0.125], [0.375], [0.375], [0.125]])
-        for i in range(num-1):
-            success, frame = video_cap.read()
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            if scale <= 0.75:
-                frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
-            if scale <= 0.375:
-                frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
-            frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
-            videoWriter.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
-
-        videoWriter.release()
-        video_cap.release()
-
-        return 'input.mp4', instyle, 'Successfully rescale the video to (%d, %d)' % (bottom-top, right-left)
-
     def image_toonify(self, aligned_face: np.ndarray, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float, style_type: str) -> tuple:
         if instyle is None or aligned_face is None:
-            return np.zeros((256, 256, 3), np.uint8), 'Opps, something wrong with the input. Please go to Step 2 and Rescale Image/First Frame again.'
+            return np.zeros((256, 256, 3), np.uint8), 'Oops, something wrong with the input. Please go to Step 2 and Rescale Image/First Frame again.'
         if self.style_name != style_type:
             exstyle, _ = self.load_model(style_type)
         if exstyle is None:
-            return np.zeros((256, 256, 3), np.uint8), 'Opps, something wrong with the style type. Please go to Step 1 and load model again.'
+            return np.zeros((256, 256, 3), np.uint8), 'Oops, something wrong with the style type. Please go to Step 1 and load model again.'
         with torch.no_grad():
             if self.color_transfer:
                 s_w = exstyle
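
The video helpers removed above (and video_toonify, removed in the next hunk) drive OpenCV through bare numeric property ids. For reference, a sketch using the named constants, plus the separable 4-tap binomial blur the rescaler applies before heavy downscaling; 'sample.mp4' and the random frame are stand-ins:

import cv2
import numpy as np

cap = cv2.VideoCapture('sample.mp4')               # stand-in path
n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)       # == cap.get(7), the emptiness check
fps      = cap.get(cv2.CAP_PROP_FPS)               # == cap.get(5)
width    = cap.get(cv2.CAP_PROP_FRAME_WIDTH)       # == cap.get(3)
height   = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)      # == cap.get(4)
cap.release()

# Anti-aliasing before downscaling: a separable binomial kernel,
# applied once for scale <= 0.75 and twice for scale <= 0.375.
kernel_1d = np.array([[0.125], [0.375], [0.375], [0.125]])
frame = np.random.randint(0, 256, (480, 640, 3), np.uint8)
blurred = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
print(blurred.shape)                               # (480, 640, 3)
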
@@ -251,69 +204,15 @@ class Model():
                 s_w[:, :7] = exstyle[:, :7]

             x = self.transform(aligned_face).unsqueeze(dim=0).to(self.device)
+            logging.info(f"Input to VToonify shape: {x.shape}")
             x_p = F.interpolate(self.parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
                                 scale_factor=0.5, recompute_scale_factor=False).detach()
             inputs = torch.cat((x, x_p/16.), dim=1)
             y_tilde = self.vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s=style_degree)
             y_tilde = torch.clamp(y_tilde, -1, 1)
-
+            logging.info(f"Output from VToonify shape: {y_tilde.shape}")
         return ((y_tilde[0].cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8), 'Successfully toonify the image with style of %s' % (self.style_name)

-    def video_toonify(self, aligned_video: str, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float, style_type: str) -> tuple:
-        if aligned_video is None:
-            return 'default.mp4', 'Opps, something wrong with the input. Please go to Step 2 and Rescale Video again.'
-        video_cap = cv2.VideoCapture(aligned_video)
-        if instyle is None or aligned_video is None or video_cap.get(7) == 0:
-            video_cap.release()
-            return 'default.mp4', 'Opps, something wrong with the input. Please go to Step 2 and Rescale Video again.'
-        if self.style_name != style_type:
-            exstyle, _ = self.load_model(style_type)
-        if exstyle is None:
-            return 'default.mp4', 'Opps, something wrong with the style type. Please go to Step 1 and load model again.'
-        num = min(self.video_limit_gpu, int(video_cap.get(7)))
-        if self.device == 'cpu':
-            num = min(self.video_limit_cpu, num)
-        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-        videoWriter = cv2.VideoWriter('output.mp4', fourcc,
-                                      video_cap.get(5), (int(video_cap.get(3)*4),
-                                      int(video_cap.get(4)*4)))
-
-        batch_frames = []
-        if video_cap.get(3) != 0:
-            if self.device == 'cpu':
-                batch_size = max(1, int(4 * 256 * 256 / video_cap.get(3) / video_cap.get(4)))
-            else:
-                batch_size = min(max(1, int(4 * 400 * 360 / video_cap.get(3) / video_cap.get(4))), 4)
-        else:
-            batch_size = 1
-        print('*** Toonify using batch size of %d on %dx%d video of %d frames with style of %s' % (batch_size, int(video_cap.get(3)*4), int(video_cap.get(4)*4), num, style_type))
-        with torch.no_grad():
-            if self.color_transfer:
-                s_w = exstyle
-            else:
-                s_w = instyle.clone()
-                s_w[:, :7] = exstyle[:, :7]
-            for i in range(num):
-                success, frame = video_cap.read()
-                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                batch_frames += [self.transform(frame).unsqueeze(dim=0).to(self.device)]
-                if len(batch_frames) == batch_size or (i+1) == num:
-                    x = torch.cat(batch_frames, dim=0)
-                    batch_frames = []
-                    with torch.no_grad():
-                        x_p = F.interpolate(self.parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
-                                            scale_factor=0.5, recompute_scale_factor=False).detach()
-                    inputs = torch.cat((x, x_p/16.), dim=1)
-                    y_tilde = self.vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), style_degree)
-                    y_tilde = torch.clamp(y_tilde, -1, 1)
-                    for k in range(y_tilde.size(0)):
-                        videoWriter.write(tensor2cv2(y_tilde[k].cpu()))
-                    gc.collect()
-
-        videoWriter.release()
-        video_cap.release()
-        return 'output.mp4', 'Successfully toonify video of %d frames with style of %s' % (num, self.style_name)
-
     def tensor2cv2(self, img):
         """Convert a tensor image to OpenCV format."""
         tmp = ((img.cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8).copy()
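
Shape-wise, the forward pass kept in image_toonify runs the parsing network on a 2x-upsampled input, halves the parsing map back to input size, and stacks it onto the RGB channels. A self-contained sketch of that plumbing; StubParser is a made-up stand-in for the repo's face-parsing model (assumed here to emit 19 classes and, like the real one, return a tuple, hence the [0]), and the random y_tilde stands in for the 4x-upscaled generator output:

import numpy as np
import torch
import torch.nn.functional as F

class StubParser(torch.nn.Module):
    # Hypothetical stand-in for self.parsingpredictor.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 19, kernel_size=1)
    def forward(self, x):
        return (self.conv(x),)

parsingpredictor = StubParser()
x = torch.rand(1, 3, 256, 256) * 2 - 1             # normalized input frame
with torch.no_grad():
    # Parse at 2x resolution, then bring the map back to input size.
    x_p = F.interpolate(parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
                        scale_factor=0.5, recompute_scale_factor=False).detach()
    inputs = torch.cat((x, x_p/16.), dim=1)        # 3 RGB + 19 parsing channels
    y_tilde = torch.clamp(torch.randn(1, 3, 1024, 1024), -1, 1)   # stand-in output
out = ((y_tilde[0].numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
print(inputs.shape, out.shape)                     # [1, 22, 256, 256] (1024, 1024, 3)

The division by 16 presumably compresses the parsing output into a range comparable to the [-1, 1] RGB channels before concatenation.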