Update vtoonify_model.py
vtoonify_model.py  +71 -62
CHANGED
@@ -14,10 +14,14 @@ from model.vtoonify import VToonify
 from model.bisenet.model import BiSeNet
 import torch.nn.functional as F
 from torchvision import transforms
-from model.encoder.align_all_parallel import align_face
 import gc
 import huggingface_hub
 import os
+import logging
+import PIL.Image  # needed by the new align_face() below
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
 
 MODEL_REPO = 'PKUWilliamYang/VToonify'
 
@@ -68,7 +72,7 @@ class Model():
     def _create_insightface_detector(self):
         # Initialize InsightFace
         app = insightface.app.FaceAnalysis()
-        app.prepare(ctx_id=0, det_size=(640, 640))
+        app.prepare(ctx_id=0 if self.device == 'cuda' else -1, det_size=(640, 640))
         return app
 
     def _create_parsing_model(self):
@@ -94,66 +98,7 @@ class Model():
         exstyle = vtoonify.zplus2wplus(exstyle)
         return vtoonify, exstyle
 
-    def detect_and_align(self, frame, top, bottom, left, right, return_para=False):
-        message = 'Error: no face detected! Please retry or change the photo.'
-        instyle = None
-        # Use InsightFace for face detection
-        faces = self.face_detector.get(frame)
-        if len(faces) > 0:
-            face = faces[0]
-            bbox = face.bbox.astype(int)
-            x, y, w, h = bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]
-            top, bottom, left, right = y, y + h, x, x + w
-            scale = 1.0  # Adjust scale as needed
-            h, w = frame.shape[:2]
-            H, W = int(bottom-top), int(right-left)
-            # for HR image, we apply gaussian blur to it to avoid over-sharp stylization results
-            kernel_1d = np.array([[0.125], [0.375], [0.375], [0.125]])
-            if scale <= 0.75:
-                frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
-            if scale <= 0.375:
-                frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
-            frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
-            with torch.no_grad():
-                I = align_face(frame, self.face_detector)
-                if I is not None:
-                    I = self.transform(I).unsqueeze(dim=0).to(self.device)
-                    instyle = self.pspencoder(I)
-                    instyle = self.vtoonify.zplus2wplus(instyle)
-                    message = 'Successfully rescale the frame to (%d, %d)' % (bottom-top, right-left)
-                else:
-                    frame = np.zeros((256, 256, 3), np.uint8)
-        else:
-            frame = np.zeros((256, 256, 3), np.uint8)
-        if return_para:
-            return frame, instyle, message, w, h, top, bottom, left, right, scale
-        return frame, instyle, message
-
-    # Other methods remain unchanged
-    def _create_parsing_model(self):
-        parsingpredictor = BiSeNet(n_classes=19)
-        parsingpredictor.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO, 'models/faceparsing.pth'),
-                                                    map_location=lambda storage, loc: storage))
-        parsingpredictor.to(self.device).eval()
-        return parsingpredictor
-
-    def _load_encoder(self) -> nn.Module:
-        style_encoder_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'models/encoder.pt')
-        return load_psp_standalone(style_encoder_path, self.device)
-
-    def _load_default_model(self) -> tuple:
-        vtoonify = VToonify(backbone='dualstylegan')
-        vtoonify.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO,
-                                            'models/vtoonify_d_cartoon/vtoonify_s026_d0.5.pt'),
-                                            map_location=lambda storage, loc: storage)['g_ema'])
-        vtoonify.to(self.device)
-        tmp = np.load(huggingface_hub.hf_hub_download(MODEL_REPO, 'models/vtoonify_d_cartoon/exstyle_code.npy'), allow_pickle=True).item()
-        exstyle = torch.tensor(tmp[list(tmp.keys())[26]]).to(self.device)
-        with torch.no_grad():
-            exstyle = vtoonify.zplus2wplus(exstyle)
-        return vtoonify, exstyle
-
-    def load_model(self, style_type: str) -> tuple:
+    def load_model(self, style_type: str) -> tuple[torch.Tensor, str]:
         if 'illustration' in style_type:
             self.color_transfer = True
         else:
@@ -170,7 +115,71 @@ class Model():
         with torch.no_grad():
             exstyle = self.vtoonify.zplus2wplus(exstyle)
         return exstyle, 'Model of %s loaded.' % (style_type)
+
+    def detect_and_align(self, frame, top, bottom, left, right, return_para=False):
+        message = 'Error: no face detected! Please retry or change the photo.'
+        instyle = None
+        h, w, scale = 0, 0, 0
+
+        # Use InsightFace for face detection
+        faces = self.face_detector.get(frame)
+        if len(faces) > 0:
+            logging.info(f"Detected {len(faces)} face(s).")
+            face = faces[0]
+            bbox = face.bbox.astype(int)
+            # 5-point keypoints: left eye, right eye, nose, left/right mouth corner
+            kps = face.kps
+
+            # Align face based on the detected keypoints
+            aligned_face = self.align_face(frame, kps)
+            if aligned_face is not None:
+                with torch.no_grad():
+                    I = self.transform(aligned_face).unsqueeze(dim=0).to(self.device)
+                    instyle = self.pspencoder(I)
+                    instyle = self.vtoonify.zplus2wplus(instyle)
+                message = 'Successfully aligned the face.'
+            else:
+                frame = np.zeros((256, 256, 3), np.uint8)
+        else:
+            logging.warning("No face detected.")
+            frame = np.zeros((256, 256, 3), np.uint8)
+
+        if return_para:
+            return frame, instyle, message, h, w, top, bottom, left, right, scale
+        return frame, instyle, message
+
+    def align_face(self, image, kps):
+        # Keypoints from InsightFace: left eye, right eye, nose, mouth corners
+        eye_left = kps[0]
+        eye_right = kps[1]
+        mouth_left = kps[3]
+        mouth_right = kps[4]
+
+        # Calculate transformation parameters
+        eye_center = (eye_left + eye_right) / 2
+        mouth_center = (mouth_left + mouth_right) / 2
+        eye_to_eye = eye_right - eye_left
+        eye_to_mouth = mouth_center - eye_center
+
+        # Define the oriented crop rectangle (FFHQ-style alignment)
+        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+        x /= np.hypot(*x)
+        x *= np.hypot(*eye_to_eye) * 2.0
+        y = np.flipud(x) * [-1, 1]
+        c = eye_center + eye_to_mouth * 0.1
+        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+
+        # Transform and crop the image
+        transform_size = 256
+        output_size = 256
+        img = PIL.Image.fromarray(image)
+        img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
+        if output_size < transform_size:
+            img = img.resize((output_size, output_size), PIL.Image.LANCZOS)
+        return np.array(img)
+
+    # Other methods remain unchanged
 
     def detect_and_align_image(self, frame_rgb: np.ndarray, top: int, bottom: int, left: int, right: int) -> tuple:
         if frame_rgb is None:
            return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load the image.'
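
Reviewer note: a minimal smoke test of the new InsightFace detection/alignment path, sketched under assumptions — the `Model(device=...)` constructor is taken from the surrounding space code, and the image path is illustrative, not part of this diff.

    import cv2
    import torch
    from vtoonify_model import Model

    # Hypothetical usage; assumes Model accepts a device string as in app.py.
    model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
    frame = cv2.imread('input.jpg')  # illustrative path; cv2 loads BGR
    # top/bottom/left/right are passed through unchanged by the new code.
    aligned, instyle, message = model.detect_and_align(frame, 0, 0, 0, 0)
    print(message)  # expect 'Successfully aligned the face.' when a face is found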