Wolowolo commited on
Commit
e46e042
·
verified ·
1 Parent(s): 189285f

Update app.py

Browse files

Fix fatal error

Files changed (1) hide show
  1. app.py +16 -21
app.py CHANGED
@@ -169,19 +169,20 @@ model = models_vit.__dict__['vit_base_patch16'](
169
  num_classes=args.nb_classes,
170
  drop_path_rate=args.drop_path,
171
  global_pool=args.global_pool,
172
- )
173
 
174
 
175
  def load_model(ckpt):
176
  if ckpt == 'choose from here' or 'continuously updating...':
177
  return gr.update()
178
- args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
179
  if os.path.isfile(args.resume) == False:
180
  hf_hub_download(local_dir=CKPT_SAVE_PATH,
181
  repo_id='Wolowolo/fsfm-3c/' + CKPT_NAME[ckpt],
182
  filename=ckpt)
183
  checkpoint = torch.load(args.resume, map_location='cpu')
184
  model.load_state_dict(checkpoint['model'])
 
185
  return gr.update()
186
 
187
 
@@ -276,9 +277,7 @@ def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None, dev
276
  return frame_indices
277
 
278
 
279
- def FSFM3C_video_detection(video):
280
- model.to(device)
281
-
282
  # extract frames
283
  num_frames = 32
284
 
@@ -308,21 +307,18 @@ def FSFM3C_video_detection(video):
308
 
309
  real_prob_video = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
310
  if real_prob_video > 50:
311
- result_message = "real"
312
  else:
313
- result_message = "fake"
 
 
314
 
315
- video_results = (f"The face in this video may be {result_message}, "
316
- f"and the video-level real_face_probability is {real_prob_video}% \n"
317
- f"The frame-level detection results ['sampled_frame_index': 'real_face_probability']: \n"
318
- f"{frame_results} \n")
319
 
320
  return video_results
321
 
322
 
323
- def FSFM3C_image_detection(image):
324
- model.to(device)
325
-
326
  files = os.listdir(FRAME_SAVE_PATH)
327
  num_files = len(files)
328
  frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
@@ -352,12 +348,11 @@ def FSFM3C_image_detection(image):
352
 
353
  real_prob_image = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
354
  if real_prob_image > 50:
355
- result_message = "real"
356
  else:
357
- result_message = "fake"
358
-
359
- image_results = (f"The face in this image may be {result_message},"
360
- f"and the real_face_probability is {real_prob_image}%")
361
 
362
  return image_results
363
 
@@ -406,12 +401,12 @@ with gr.Blocks() as demo:
406
 
407
  image_submit_btn.click(
408
  fn=FSFM3C_image_detection,
409
- inputs=[image],
410
  outputs=[output_results_image],
411
  )
412
  video_submit_btn.click(
413
  fn=FSFM3C_video_detection,
414
- inputs=[video],
415
  outputs=[output_results_video],
416
  )
417
  ckpt_select_dropdown.change(
 
169
  num_classes=args.nb_classes,
170
  drop_path_rate=args.drop_path,
171
  global_pool=args.global_pool,
172
+ ).to(device)
173
 
174
 
175
  def load_model(ckpt):
176
  if ckpt == 'choose from here' or 'continuously updating...':
177
  return gr.update()
178
+ args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_NAME[ckpt])
179
  if os.path.isfile(args.resume) == False:
180
  hf_hub_download(local_dir=CKPT_SAVE_PATH,
181
  repo_id='Wolowolo/fsfm-3c/' + CKPT_NAME[ckpt],
182
  filename=ckpt)
183
  checkpoint = torch.load(args.resume, map_location='cpu')
184
  model.load_state_dict(checkpoint['model'])
185
+ model.eval()
186
  return gr.update()
187
 
188
 
 
277
  return frame_indices
278
 
279
 
280
+ def FSFM3C_video_detection(video, ckpt_select_dropdown):
 
 
281
  # extract frames
282
  num_frames = 32
283
 
 
307
 
308
  real_prob_video = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
309
  if real_prob_video > 50:
310
+ result_message = "real" if 'FAS' not in ckpt_select_dropdown else 'spoof'
311
  else:
312
+ result_message = "fake" if 'FAS' not in ckpt_select_dropdown else 'real'
313
+ prob = 1 - real_prob_image if real_prob_video <= 50 else real_prob_video
314
+ image_results = (f"The face in this image may be {result_message} with probability is {real_prob_image}%")
315
 
316
+ video_results = (f"The face in this video may be {result_message} with probability {prob}")
 
 
 
317
 
318
  return video_results
319
 
320
 
321
+ def FSFM3C_image_detection(image, ckpt_select_dropdown):
 
 
322
  files = os.listdir(FRAME_SAVE_PATH)
323
  num_files = len(files)
324
  frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
 
348
 
349
  real_prob_image = int(round(1. - (sum(video_pred_list) / len(video_pred_list)), 2) * 100)
350
  if real_prob_image > 50:
351
+ result_message = "real" if 'FAS' not in ckpt_select_dropdown else 'spoof'
352
  else:
353
+ result_message = "fake" if 'FAS' not in ckpt_select_dropdown else 'real'
354
+ prob = 1 - real_prob_image if real_prob_image <= 50 else real_prob_image
355
+ image_results = (f"The face in this image may be {result_message} with probability is {real_prob_image}%")
 
356
 
357
  return image_results
358
 
 
401
 
402
  image_submit_btn.click(
403
  fn=FSFM3C_image_detection,
404
+ inputs=[image, ckpt_select_dropdown],
405
  outputs=[output_results_image],
406
  )
407
  video_submit_btn.click(
408
  fn=FSFM3C_video_detection,
409
+ inputs=[video, ckpt_select_dropdown],
410
  outputs=[output_results_video],
411
  )
412
  ckpt_select_dropdown.change(