Ashrafb committed
Commit 344e066 · verified · 1 Parent(s): 22bf4fb

Update vtoonify_model.py

Files changed (1)
  1. vtoonify_model.py +76 -72
vtoonify_model.py CHANGED
@@ -1,10 +1,4 @@
 from __future__ import annotations
-import gradio as gr
-import pathlib
-import sys
-sys.path.insert(0, 'vtoonify')
-
-from util import load_psp_standalone, get_video_crop_parameter, tensor2cv2
 import torch
 import torch.nn as nn
 import numpy as np
@@ -14,9 +8,7 @@ from model.vtoonify import VToonify
 from model.bisenet.model import BiSeNet
 import torch.nn.functional as F
 from torchvision import transforms
-import gc
 import huggingface_hub
-import os
 import logging
 from PIL import Image
 
@@ -28,65 +20,43 @@ MODEL_REPO = 'PKUWilliamYang/VToonify'
 class Model():
     def __init__(self, device):
         super().__init__()
-
+
         self.device = device
         self.style_types = {
             'cartoon1': ['vtoonify_d_cartoon/vtoonify_s026_d0.5.pt', 26],
-            'cartoon1-d': ['vtoonify_d_cartoon/vtoonify_s_d.pt', 26],
-            'cartoon2-d': ['vtoonify_d_cartoon/vtoonify_s_d.pt', 64],
-            'cartoon3-d': ['vtoonify_d_cartoon/vtoonify_s_d.pt', 153],
-            'cartoon4': ['vtoonify_d_cartoon/vtoonify_s299_d0.5.pt', 299],
-            'cartoon4-d': ['vtoonify_d_cartoon/vtoonify_s_d.pt', 299],
-            'cartoon5-d': ['vtoonify_d_cartoon/vtoonify_s_d.pt', 8],
-            'comic1-d': ['vtoonify_d_comic/vtoonify_s_d.pt', 28],
-            'comic2-d': ['vtoonify_d_comic/vtoonify_s_d.pt', 18],
-            'arcane1': ['vtoonify_d_arcane/vtoonify_s000_d0.5.pt', 0],
-            'arcane1-d': ['vtoonify_d_arcane/vtoonify_s_d.pt', 0],
-            'arcane2': ['vtoonify_d_arcane/vtoonify_s077_d0.5.pt', 77],
-            'arcane2-d': ['vtoonify_d_arcane/vtoonify_s_d.pt', 77],
-            'caricature1': ['vtoonify_d_caricature/vtoonify_s039_d0.5.pt', 39],
-            'caricature2': ['vtoonify_d_caricature/vtoonify_s068_d0.5.pt', 68],
-            'pixar': ['vtoonify_d_pixar/vtoonify_s052_d0.5.pt', 52],
-            'pixar-d': ['vtoonify_d_pixar/vtoonify_s_d.pt', 52],
-            'illustration1-d': ['vtoonify_d_illustration/vtoonify_s054_d_c.pt', 54],
-            'illustration2-d': ['vtoonify_d_illustration/vtoonify_s004_d_c.pt', 4],
-            'illustration3-d': ['vtoonify_d_illustration/vtoonify_s009_d_c.pt', 9],
-            'illustration4-d': ['vtoonify_d_illustration/vtoonify_s043_d_c.pt', 43],
-            'illustration5-d': ['vtoonify_d_illustration/vtoonify_s086_d_c.pt', 86],
+            # Add other styles as needed
         }
-
+
         self.face_detector = self._create_insightface_detector()
         self.parsingpredictor = self._create_parsing_model()
-        self.pspencoder = self._load_encoder()
+        self.pspencoder = self._load_encoder()
         self.transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         ])
-
+
         self.vtoonify, self.exstyle = self._load_default_model()
         self.color_transfer = False
         self.style_name = 'cartoon1'
-        self.video_limit_cpu = 100
-        self.video_limit_gpu = 300
-
+
     def _create_insightface_detector(self):
         # Initialize InsightFace
         app = insightface.app.FaceAnalysis()
         app.prepare(ctx_id=0 if self.device == 'cuda' else -1, det_size=(640, 640))
         return app
-
+
     def _create_parsing_model(self):
         parsingpredictor = BiSeNet(n_classes=19)
         parsingpredictor.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO, 'models/faceparsing.pth'),
                                                     map_location=lambda storage, loc: storage))
         parsingpredictor.to(self.device).eval()
         return parsingpredictor
-
+
     def _load_encoder(self) -> nn.Module:
         style_encoder_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'models/encoder.pt')
         return load_psp_standalone(style_encoder_path, self.device)
-
-    def _load_default_model(self) -> tuple[torch.Tensor, str]:
+
+    def _load_default_model(self) -> tuple:
         vtoonify = VToonify(backbone='dualstylegan')
         vtoonify.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO,
                                             'models/vtoonify_d_cartoon/vtoonify_s026_d0.5.pt'),
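A caveat worth flagging at this point: the first hunk deletes the line "from util import load_psp_standalone, get_video_crop_parameter, tensor2cv2", yet the unchanged _load_encoder above still calls load_psp_standalone, and the retained code also references insightface and cv2. Unless those names are imported elsewhere in vtoonify_model.py, outside the visible hunks, the module will raise NameError at runtime. A sketch of the imports the file still appears to need (hypothetical placement; the full import list is not shown in this diff):

    import cv2          # cv2.imread / cv2.cvtColor in detect_and_align_image
    import insightface  # insightface.app.FaceAnalysis in _create_insightface_detector

    from util import load_psp_standalone  # still called by _load_encoder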
@@ -97,8 +67,8 @@ class Model():
         with torch.no_grad():
             exstyle = vtoonify.zplus2wplus(exstyle)
         return vtoonify, exstyle
-
-    def load_model(self, style_type: str) -> tuple[torch.Tensor, str]:
+
+    def load_model(self, style_type: str) -> tuple:
         if 'illustration' in style_type:
             self.color_transfer = True
         else:
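Both rewritten signatures, here and in the hunk above, loosen tuple[torch.Tensor, str] to a bare tuple. Since the file keeps from __future__ import annotations, the parameterized form already evaluates lazily on older interpreters, so the stricter annotation could have stayed. A minimal sketch of that alternative (hypothetical; the commit itself uses plain tuple):

    from __future__ import annotations  # retained by this commit

    import torch

    class Model:
        # Stricter signature matching the visible return statement:
        # (exstyle, status_message).
        def load_model(self, style_type: str) -> tuple[torch.Tensor, str]:
            ...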
@@ -115,45 +85,79 @@ class Model():
         with torch.no_grad():
             exstyle = self.vtoonify.zplus2wplus(exstyle)
         return exstyle, 'Model of %s loaded.' % (style_type)
+
     def detect_and_align(self, frame, top, bottom, left, right, return_para=False):
         message = 'Error: no face detected! Please retry or change the photo.'
-        paras = get_video_crop_parameter(frame, self.landmarkpredictor, [left, right, top, bottom])
         instyle = None
-        h, w, scale = 0, 0, 0
-        if paras is not None:
-            h,w,top,bottom,left,right,scale = paras
-            H, W = int(bottom-top), int(right-left)
-            # for HR image, we apply gaussian blur to it to avoid over-sharp stylization results
-            kernel_1d = np.array([[0.125],[0.375],[0.375],[0.125]])
-            if scale <= 0.75:
-                frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
-            if scale <= 0.375:
-                frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
-            frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
-            with torch.no_grad():
-                I = align_face(frame, self.landmarkpredictor)
-                if I is not None:
-                    I = self.transform(I).unsqueeze(dim=0).to(self.device)
+
+        # Use InsightFace for face detection
+        faces = self.face_detector.get(frame)
+        if len(faces) > 0:
+            logging.info(f"Detected {len(faces)} face(s).")
+            face = faces[0]
+            landmarks = face.landmark_2d_106
+
+            # Align face based on mapped landmarks
+            aligned_face = self.align_face(frame, landmarks)
+            if aligned_face is not None:
+                logging.info(f"Aligned face shape: {aligned_face.shape}")
+                with torch.no_grad():
+                    I = self.transform(aligned_face).unsqueeze(dim=0).to(self.device)
                     instyle = self.pspencoder(I)
                     instyle = self.vtoonify.zplus2wplus(instyle)
-                    message = 'Successfully rescale the frame to (%d, %d)'%(bottom-top, right-left)
-                else:
-                    frame = np.zeros((256,256,3), np.uint8)
+                    message = 'Successfully aligned the face.'
+            else:
+                logging.warning("Failed to align face.")
+                frame = np.zeros((256, 256, 3), np.uint8)
         else:
-            frame = np.zeros((256,256,3), np.uint8)
+            logging.warning("No face detected.")
+            frame = np.zeros((256, 256, 3), np.uint8)
+
         if return_para:
-            return frame, instyle, message, w, h, top, bottom, left, right, scale
+            return frame, instyle, message
         return frame, instyle, message
-
-

-    def detect_and_align_image(self, frame_rgb: np.ndarray, top: int, bottom: int, left: int, right: int) -> tuple:
-        if frame_rgb is None:
-            return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load the image.'
-
-        # Convert RGB to BGR
-        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
-        return self.detect_and_align(frame_bgr, top, bottom, left, right)
+    def align_face(self, image, landmarks):
+        # Map InsightFace landmarks to dlib's 68-point model
+        # Example: use specific indices for eyes and mouth
+        eye_left = np.mean(landmarks[36:42], axis=0)
+        eye_right = np.mean(landmarks[42:48], axis=0)
+        mouth_left = landmarks[48]
+        mouth_right = landmarks[54]
+
+        # Calculate transformation parameters
+        eye_center = (eye_left + eye_right) / 2
+        mouth_center = (mouth_left + mouth_right) / 2
+        eye_to_eye = eye_right - eye_left
+        eye_to_mouth = mouth_center - eye_center
+
+        # Define the transformation matrix
+        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+        x /= np.hypot(*x)
+        x *= np.hypot(*eye_to_eye) * 2.0
+        y = np.flipud(x) * [-1, 1]
+        c = eye_center + eye_to_mouth * 0.1
+        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+        qsize = np.hypot(*x) * 2
+
+        # Transform and crop the image
+        transform_size = 256
+        output_size = 256
+        img = Image.fromarray(image)
+        img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
+        if output_size < transform_size:
+            img = img.resize((output_size, output_size), Image.ANTIALIAS)
+
+        return np.array(img)
+
+    def detect_and_align_image(self, image: str, top: int, bottom: int, left: int, right: int) -> tuple:
+        if image is None:
+            return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load empty file.'
+        frame = cv2.imread(image)
+        if frame is None:
+            return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load the image.'
+        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+        return self.detect_and_align(frame, top, bottom, left, right)
 
     def image_toonify(self, aligned_face: np.ndarray, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float, style_type: str) -> tuple:
         if instyle is None or aligned_face is None:
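Taken together, the reworked pipeline is driven roughly as follows. This is a minimal sketch, assuming the unchanged remainder of vtoonify_model.py behaves as its signatures suggest and using a placeholder image path; note that detect_and_align_image now takes a file path rather than a decoded array, and the crop bounds are accepted but no longer used by the rewritten detect_and_align:

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Model(device)

    # load_model returns the style's extended style code and a status message.
    exstyle, status = model.load_model('cartoon1')

    # Takes a path now (cv2.imread happens inside); the bounds are ignored
    # by the new detect_and_align, so zeros are fine here.
    aligned_face, instyle, message = model.detect_and_align_image(
        'input.jpg', top=0, bottom=0, left=0, right=0)

    # image_toonify's tuple contents sit outside the visible hunks; this
    # assumes it yields the stylized frame plus a message.
    result = model.image_toonify(aligned_face, instyle, exstyle,
                                 style_degree=0.5, style_type='cartoon1')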
 
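Two portability notes on the added code. First, align_face indexes landmark_2d_106 with dlib's 68-point positions (36-47 for the eyes, 48 and 54 for the mouth corners), but InsightFace's 106-point model orders its landmarks differently, so those indices should be verified against the 2d106 layout before trusting the alignment. Second, Image.ANTIALIAS was removed in Pillow 10, and cv2.imread already returns BGR, which makes the COLOR_RGB2BGR call an effective BGR-to-RGB swap. A small sketch of the Pillow-10-safe resize and the intent-revealing conversion flag (hypothetical input path):

    import cv2
    import numpy as np
    from PIL import Image

    # Pillow >= 10 removed Image.ANTIALIAS; Image.LANCZOS is the equivalent filter.
    img = Image.fromarray(np.zeros((512, 512, 3), np.uint8))
    img = img.resize((256, 256), Image.LANCZOS)

    # cv2.imread returns BGR, so COLOR_BGR2RGB performs the same channel swap
    # the commit does with COLOR_RGB2BGR, while stating the actual intent.
    frame = cv2.imread('input.jpg')  # hypothetical path
    if frame is not None:
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)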