Update app.py
Browse filesFix fatal error
app.py
CHANGED
@@ -169,19 +169,20 @@ model = models_vit.__dict__['vit_base_patch16'](
|
|
169 |
num_classes=args.nb_classes,
|
170 |
drop_path_rate=args.drop_path,
|
171 |
global_pool=args.global_pool,
|
172 |
-
)
|
173 |
|
174 |
|
175 |
def load_model(ckpt):
|
176 |
if ckpt == 'choose from here' or 'continuously updating...':
|
177 |
return gr.update()
|
178 |
-
args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
|
179 |
if os.path.isfile(args.resume) == False:
|
180 |
hf_hub_download(local_dir=CKPT_SAVE_PATH,
|
181 |
repo_id='Wolowolo/fsfm-3c/' + CKPT_NAME[ckpt],
|
182 |
filename=ckpt)
|
183 |
checkpoint = torch.load(args.resume, map_location='cpu')
|
184 |
model.load_state_dict(checkpoint['model'])
|
|
|
185 |
return gr.update()
|
186 |
|
187 |
|
@@ -276,9 +277,7 @@ def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None, dev
|
|
276 |
return frame_indices
|
277 |
|
278 |
|
279 |
-
def FSFM3C_video_detection(video):
|
280 |
-
model.to(device)
|
281 |
-
|
282 |
# extract frames
|
283 |
num_frames = 32
|
284 |
|
@@ -308,21 +307,18 @@ def FSFM3C_video_detection(video):
|
|
308 |
|
309 |
real_prob_video = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
|
310 |
if real_prob_video > 50:
|
311 |
-
result_message = "real"
|
312 |
else:
|
313 |
-
result_message = "fake"
|
|
|
|
|
314 |
|
315 |
-
video_results = (f"The face in this video may be {result_message}
|
316 |
-
f"and the video-level real_face_probability is {real_prob_video}% \n"
|
317 |
-
f"The frame-level detection results ['sampled_frame_index': 'real_face_probability']: \n"
|
318 |
-
f"{frame_results} \n")
|
319 |
|
320 |
return video_results
|
321 |
|
322 |
|
323 |
-
def FSFM3C_image_detection(image):
|
324 |
-
model.to(device)
|
325 |
-
|
326 |
files = os.listdir(FRAME_SAVE_PATH)
|
327 |
num_files = len(files)
|
328 |
frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
|
@@ -352,12 +348,11 @@ def FSFM3C_image_detection(image):
|
|
352 |
|
353 |
real_prob_image = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
|
354 |
if real_prob_image > 50:
|
355 |
-
result_message = "real"
|
356 |
else:
|
357 |
-
result_message = "fake"
|
358 |
-
|
359 |
-
image_results = (f"The face in this image may be {result_message}
|
360 |
-
f"and the real_face_probability is {real_prob_image}%")
|
361 |
|
362 |
return image_results
|
363 |
|
@@ -406,12 +401,12 @@ with gr.Blocks() as demo:
|
|
406 |
|
407 |
image_submit_btn.click(
|
408 |
fn=FSFM3C_image_detection,
|
409 |
-
inputs=[image],
|
410 |
outputs=[output_results_image],
|
411 |
)
|
412 |
video_submit_btn.click(
|
413 |
fn=FSFM3C_video_detection,
|
414 |
-
inputs=[video],
|
415 |
outputs=[output_results_video],
|
416 |
)
|
417 |
ckpt_select_dropdown.change(
|
|
|
169 |
num_classes=args.nb_classes,
|
170 |
drop_path_rate=args.drop_path,
|
171 |
global_pool=args.global_pool,
|
172 |
+
).to(device)
|
173 |
|
174 |
|
175 |
def load_model(ckpt):
|
176 |
if ckpt == 'choose from here' or 'continuously updating...':
|
177 |
return gr.update()
|
178 |
+
args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_NAME[ckpt])
|
179 |
if os.path.isfile(args.resume) == False:
|
180 |
hf_hub_download(local_dir=CKPT_SAVE_PATH,
|
181 |
repo_id='Wolowolo/fsfm-3c/' + CKPT_NAME[ckpt],
|
182 |
filename=ckpt)
|
183 |
checkpoint = torch.load(args.resume, map_location='cpu')
|
184 |
model.load_state_dict(checkpoint['model'])
|
185 |
+
model.eval()
|
186 |
return gr.update()
|
187 |
|
188 |
|
|
|
277 |
return frame_indices
|
278 |
|
279 |
|
280 |
+
def FSFM3C_video_detection(video, ckpt_select_dropdown):
|
|
|
|
|
281 |
# extract frames
|
282 |
num_frames = 32
|
283 |
|
|
|
307 |
|
308 |
real_prob_video = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
|
309 |
if real_prob_video > 50:
|
310 |
+
result_message = "real" if 'FAS' not in ckpt_select_dropdown else 'spoof'
|
311 |
else:
|
312 |
+
result_message = "fake" if 'FAS' not in ckpt_select_dropdown else 'real'
|
313 |
+
prob = 1 - real_prob_image if real_prob_video <= 50 else real_prob_video
|
314 |
+
image_results = (f"The face in this image may be {result_message} with probability is {real_prob_image}%")
|
315 |
|
316 |
+
video_results = (f"The face in this video may be {result_message} with probability {prob}")
|
|
|
|
|
|
|
317 |
|
318 |
return video_results
|
319 |
|
320 |
|
321 |
+
def FSFM3C_image_detection(image, ckpt_select_dropdown):
|
|
|
|
|
322 |
files = os.listdir(FRAME_SAVE_PATH)
|
323 |
num_files = len(files)
|
324 |
frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
|
|
|
348 |
|
349 |
real_prob_image = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
|
350 |
if real_prob_image > 50:
|
351 |
+
result_message = "real" if 'FAS' not in ckpt_select_dropdown else 'spoof'
|
352 |
else:
|
353 |
+
result_message = "fake" if 'FAS' not in ckpt_select_dropdown else 'real'
|
354 |
+
prob = 1 - real_prob_image if real_prob_image <= 50 else real_prob_image
|
355 |
+
image_results = (f"The face in this image may be {result_message} with probability is {real_prob_image}%")
|
|
|
356 |
|
357 |
return image_results
|
358 |
|
|
|
401 |
|
402 |
image_submit_btn.click(
|
403 |
fn=FSFM3C_image_detection,
|
404 |
+
inputs=[image, ckpt_select_dropdown],
|
405 |
outputs=[output_results_image],
|
406 |
)
|
407 |
video_submit_btn.click(
|
408 |
fn=FSFM3C_video_detection,
|
409 |
+
inputs=[video, ckpt_select_dropdown],
|
410 |
outputs=[output_results_video],
|
411 |
)
|
412 |
ckpt_select_dropdown.change(
|