Wolowolo committed on
Commit 1e770e5 · verified · 1 Parent(s): 2f6003d

Update app.py

Files changed (1)
  1. app.py +424 -415
app.py CHANGED
@@ -1,416 +1,425 @@
- # -*- coding: utf-8 -*-
- # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
- # --------------------------------------------------------
- # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
- # You can find the license in the LICENSE file in the root directory of this source tree.
- # --------------------------------------------------------
-
- import sys
- import os
- os.system('pip install dlib')  # install dlib at runtime
- import dlib
- import argparse
- import numpy as np
- from PIL import Image
- import cv2
- import torch
- from huggingface_hub import hf_hub_download
- import gradio as gr
-
- import models_vit
- from util.datasets import build_dataset
- from engine_finetune import test_two_class, test_multi_class
- import matplotlib.pyplot as plt
- from torchvision import transforms
- import traceback
- from pytorch_grad_cam import (
-     GradCAM, ScoreCAM,
-     XGradCAM, EigenCAM
- )
- from pytorch_grad_cam import GuidedBackpropReLUModel
- from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
-
- def reshape_transform(tensor, height=14, width=14):
-     result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
-     result = result.transpose(2, 3).transpose(1, 2)
-     return result
-
- def get_args_parser():
-     parser = argparse.ArgumentParser('FSFM3C fine-tuning&Testing for image classification', add_help=False)
-     parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU')
-     parser.add_argument('--epochs', default=50, type=int)
-     parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations')
-     parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', help='Name of model to train')
-     parser.add_argument('--input_size', default=224, type=int, help='images input size')
-     parser.add_argument('--normalize_from_IMN', action='store_true', help='cal mean and std from imagenet')
-     parser.set_defaults(normalize_from_IMN=True)
-     parser.add_argument('--apply_simple_augment', action='store_true', help='apply simple data augment')
-     parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate')
-     parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm')
-     parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay')
-     parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate')
-     parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='base learning rate')
-     parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay')
-     parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound')
-     parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR')
-     parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', help='Color jitter factor')
-     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy')
-     parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing')
-     parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob')
-     parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode')
-     parser.add_argument('--recount', type=int, default=1, help='Random erase count')
-     parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first augmentation split')
-     parser.add_argument('--mixup', type=float, default=0, help='mixup alpha')
-     parser.add_argument('--cutmix', type=float, default=0, help='cutmix alpha')
-     parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio')
-     parser.add_argument('--mixup_prob', type=float, default=1.0, help='Probability of performing mixup or cutmix')
-     parser.add_argument('--mixup_switch_prob', type=float, default=0.5, help='Probability of switching to cutmix')
-     parser.add_argument('--mixup_mode', type=str, default='batch', help='How to apply mixup/cutmix params')
-     parser.add_argument('--finetune', default='', help='finetune from checkpoint')
-     parser.add_argument('--global_pool', action='store_true')
-     parser.set_defaults(global_pool=True)
-     parser.add_argument('--cls_token', action='store_false', dest='global_pool', help='Use class token for classification')
-     parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path')
-     parser.add_argument('--nb_classes', default=1000, type=int, help='number of the classification types')
-     parser.add_argument('--output_dir', default='', help='path where to save')
-     parser.add_argument('--log_dir', default='', help='path where to tensorboard log')
-     parser.add_argument('--device', default='cuda', help='device to use for training / testing')
-     parser.add_argument('--seed', default=0, type=int)
-     parser.add_argument('--resume', default='', help='resume from checkpoint')
-     parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
-     parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
-     parser.set_defaults(eval=True)
-     parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
-     parser.add_argument('--num_workers', default=10, type=int)
-     parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader')
-     parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
-     parser.set_defaults(pin_mem=True)
-     parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
-     parser.add_argument('--local_rank', default=-1, type=int)
-     parser.add_argument('--dist_on_itp', action='store_true')
-     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
-     return parser
-
-
- def load_model(select_skpt):
-     global ckpt, device, model, checkpoint
-     if select_skpt not in CKPT_NAME:
-         return gr.update(), "Select a correct model"
-     ckpt = select_skpt
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     args.nb_classes = CKPT_CLASS[ckpt]
-     model = models_vit.__dict__[CKPT_MODEL[ckpt]](
-         num_classes=args.nb_classes,
-         drop_path_rate=args.drop_path,
-         global_pool=args.global_pool,
-     ).to(device)
-
-     args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
-     if not os.path.isfile(args.resume):
-         hf_hub_download(local_dir=CKPT_SAVE_PATH,
-                         local_dir_use_symlinks=False,
-                         repo_id='Wolowolo/fsfm-3c',
-                         filename=CKPT_PATH[ckpt])
-         args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
-     checkpoint = torch.load(args.resume, map_location=device)
-     model.load_state_dict(checkpoint['model'], strict=False)
-     model.eval()
-     global cam
-     cam = GradCAM(model=model,
-                   target_layers=[model.blocks[-1].norm1],
-                   reshape_transform=reshape_transform)
-     return gr.update(), f"[Loaded Model Successfully: {args.resume}]"
-
-
- def get_boundingbox(face, width, height, minsize=None):
-     x1, y1, x2, y2 = face.left(), face.top(), face.right(), face.bottom()
-     size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
-     if minsize and size_bb < minsize:
-         size_bb = minsize
-     center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
-     x1, y1 = max(int(center_x - size_bb // 2), 0), max(int(center_y - size_bb // 2), 0)
-     size_bb = min(width - x1, size_bb)
-     size_bb = min(height - y1, size_bb)
-     return x1, y1, size_bb
-
-
- def extract_face(frame):
-     face_detector = dlib.get_frontal_face_detector()
-     image = np.array(frame.convert('RGB'))
-     faces = face_detector(image, 1)
-     if faces:
-         face = faces[0]
-         x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
-         cropped_face = image[y:y + size, x:x + size]
-         return Image.fromarray(cropped_face)
-     return None
-
-
- def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
-     return np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int).tolist()
-
-
- def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None):
-     video_capture = cv2.VideoCapture(src_video)
-     total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
-     frame_indices = get_frame_index_uniform_sample(total_frames, num_frames) if num_frames else range(total_frames)
-     for frame_index in frame_indices:
-         video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
-         ret, frame = video_capture.read()
-         if not ret:
-             continue
-         image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-         img = extract_face(image)
-         if img:
-             img = img.resize((224, 224), Image.BICUBIC)
-             save_img_name = f"frame_{frame_index}.png"
-             img.save(os.path.join(dst_path, '0', save_img_name))
-     video_capture.release()
-     return frame_indices
-
-
- class TargetCategory:
-     def __init__(self, category_index):
-         self.category_index = category_index
-
-     def __call__(self, output):
-         return output[self.category_index]
-
-
- def preprocess_image_cam(pil_img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
-     # Convert the PIL image to a numpy array
-     img_np = np.array(pil_img)
-
-     # Normalize to [0, 1]
-     img_np = img_np.astype(np.float32) / 255.0
-
-     # Standardize with the given mean and std
-     img_np = (img_np - mean) / std
-
-     # Reorder dimensions to the model input layout (C, H, W)
-     img_np = np.transpose(img_np, (2, 0, 1))
-
-     # Add a batch dimension (B, C, H, W)
-     img_np = np.expand_dims(img_np, axis=0)
-
-     return img_np
- def FSFM3C_image_detection(image):
-     frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
-     os.makedirs(frame_path, exist_ok=True)
-     os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
-     img = extract_face(image)
-     if img is None:
-         return 'No face detected, please upload a clear face!', None, None
-     img = img.resize((224, 224), Image.BICUBIC)
-     img.save(os.path.join(frame_path, '0', "frame_0.png"))
-     args.data_path = frame_path
-     args.batch_size = 1
-     dataset_val = build_dataset(is_train=False, args=args)
-     sampler_val = torch.utils.data.SequentialSampler(dataset_val)
-     data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
-
-     if CKPT_CLASS[ckpt] > 2:
-         frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
-         class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
-         avg_video_pred = np.mean(video_pred_list, axis=0)
-         max_prob_index = np.argmax(avg_video_pred)
-         max_prob_class = class_names[max_prob_index]
-         probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
-         image_results = f"The largest face in this image may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]"
-
-         # Generate the CAM heatmap for the detected class
-         use_cuda = torch.cuda.is_available()
-         input_tensor = preprocess_image(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-         if use_cuda:
-             input_tensor = input_tensor.cuda()
-
-         # Dynamically determine the target category based on the maximum probability class
-         category_names_to_index = {
-             'Real or Bonafide': 0,
-             'Deepfake': 1,
-             'Diffusion or AIGC generated': 2,
-             'Spoofing or Presentation-attack': 3
-         }
-         target_category = TargetCategory(category_names_to_index[max_prob_class])
-
-         grayscale_cam = cam(input_tensor=input_tensor, targets=[target_category])
-         grayscale_cam = 1 - grayscale_cam[0, :]
-         img = np.array(img)
-         if img.shape[2] == 4:
-             img = img[:, :, :3]
-         img = img.astype(np.float32) / 255.0
-         visualization = show_cam_on_image(img, grayscale_cam)
-         visualization = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)
-
-         # Add text overlay to the heatmap
-         # text = f"Detected: {max_prob_class}"
-         # cv2.putText(visualization, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
-         os.makedirs("./CAM_images", exist_ok=True)
-         output_path = "./CAM_images/output_heatmap.png"
-         cv2.imwrite(output_path, visualization)
-         return image_results, output_path, probabilities[max_prob_index]
-
-     if CKPT_CLASS[ckpt] == 2:
-         frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
-         if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
-             prob = sum(video_pred_list) / len(video_pred_list)
-             label = "Deepfake" if prob <= 0.5 else "Real"
-             prob = prob if label == "Real" else 1 - prob
-         if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
-             prob = sum(video_pred_list) / len(video_pred_list)
-             label = "Spoofing" if prob <= 0.5 else "Bonafide"
-             prob = prob if label == "Bonafide" else 1 - prob
-         image_results = f"The largest face in this image may be {label} with probability {prob * 100:.1f}%"
-         return image_results, None, None
-
-
- def FSFM3C_video_detection(video, num_frames):
-     try:
-         frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
-         os.makedirs(frame_path, exist_ok=True)
-         os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
-         frame_indices = extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames)
-         args.data_path = frame_path
-         args.batch_size = num_frames
-         dataset_val = build_dataset(is_train=False, args=args)
-         sampler_val = torch.utils.data.SequentialSampler(dataset_val)
-         data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
-
-         if CKPT_CLASS[ckpt] > 2:
-             frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
-             class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
-             avg_video_pred = np.mean(video_pred_list, axis=0)
-             max_prob_index = np.argmax(avg_video_pred)
-             max_prob_class = class_names[max_prob_index]
-             probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
-
-             frame_results = {f"frame_{frame_indices[i]}": [f"{class_names[j]}: {prob * 100:.1f}%" for j, prob in enumerate(frame_preds_list[i])] for i in range(len(frame_indices))}
-             video_results = (f"The largest face in this video may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]\n \n"
-                              f"The frame-level detection results ['frame_index': 'probabilities']: \n{frame_results}")
-             return video_results
-
-         if CKPT_CLASS[ckpt] == 2:
-             frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
-             if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
-                 prob = sum(video_pred_list) / len(video_pred_list)
-                 label = "Deepfake" if prob <= 0.5 else "Real"
-                 prob = prob if label == "Real" else 1 - prob
-                 frame_results = ({f"frame_{frame_indices[i]}": f"{frame_preds_list[i] * 100:.1f}%" for i in range(len(frame_indices))}
-                                  if label == "Real" else
-                                  {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in range(len(frame_indices))})
-             if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
-                 prob = sum(video_pred_list) / len(video_pred_list)
-                 label = "Spoofing" if prob <= 0.5 else "Bonafide"
-                 prob = prob if label == "Bonafide" else 1 - prob
-                 frame_results = ({f"frame_{frame_indices[i]}": f"{frame_preds_list[i] * 100:.1f}%" for i in range(len(frame_indices))}
-                                  if label == "Bonafide" else
-                                  {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in range(len(frame_indices))})
-
-             video_results = (f"The largest face in this video may be {label} with probability {prob * 100:.1f}%\n \n"
-                              f"The frame-level detection results ['frame_index': 'real_face_probability']: \n{frame_results}")
-             return video_results
-     except Exception:
-         return "Error occurred. Please provide a clear face video or reduce the number of frames."
-
- # Paths and Constants
- P = os.path.abspath(__file__)
- FRAME_SAVE_PATH = os.path.join(os.path.dirname(P), 'frame')
- CKPT_SAVE_PATH = os.path.join(os.path.dirname(P), 'checkpoints')
- os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
- os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
- CKPT_NAME = [
-     '✨Unified-detector_v1_Fine-tuned_on_4_classes',
-     'DfD-Checkpoint_Fine-tuned_on_FF++',
-     'FAS-Checkpoint_Fine-tuned_on_MCIO',
- ]
- CKPT_PATH = {
-     '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_train_loss.pth',
-     'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
-     'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
- }
- CKPT_CLASS = {
-     '✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
-     'DfD-Checkpoint_Fine-tuned_on_FF++': 2,
-     'FAS-Checkpoint_Fine-tuned_on_MCIO': 2,
- }
- CKPT_MODEL = {
-     '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'vit_base_patch16',
-     'DfD-Checkpoint_Fine-tuned_on_FF++': 'vit_base_patch16',
-     'FAS-Checkpoint_Fine-tuned_on_MCIO': 'vit_base_patch16',
- }
-
- with gr.Blocks(css=".custom-label { font-weight: bold !important; font-size: 16px !important; }") as demo:
-     gr.HTML("<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery (Deepfake/Diffusion) and Spoofing (Presentation-attacks)</h1>")
-     gr.Markdown("<b>☉ Powered by fine-tuned ViT models pre-trained with [FSFM-3C](https://fsfm-3c.github.io/)</b> <br> "
-                 "<b>☉ We do not and cannot access or store the data you have uploaded!</b> <br> "
-                 "<b>☉ Release (continuously updating)</b> <br> <b>[V1.0] 2025/02/22-Current🎉</b>: "
-                 "1) Added <b>[✨Unified-detector_v1] for Unified Physical-Digital Face Attack&Forgery Detection, a ViT-B/16-224 (FSFM Pre-trained) detector that can identify Real&Bonafide, Deepfake, Diffusion&AIGC, and Spoofing&Presentation-attack facial images or videos</b>; 2) Added a selector for the number of video frames (uniformly sampling 1-32 frames; more frames may be time-consuming on this page without GPU acceleration); 3) Fixed some errors of V0.1, including loading and prediction. <br>"
-                 "<b>[V0.1] 2024/12-2025/02/21</b>: "
-                 "Created this page with the basic detectors [DfD-Checkpoint_Fine-tuned_on_FF++, FAS-Checkpoint_Fine-tuned_on_MCIO] that follow the paper implementation. <br> ")
-     gr.Markdown("- Please <b>provide a facial image or video (<100s)</b> and <b>select the model</b> for detection: <br> <b>[SUGGESTED] [✨Unified-detector_v1_Fine-tuned_on_4_classes]</b>: an FSFM pre-trained ViT-B/16-224 that detects Real/Deepfake/Diffusion/Spoofing facial images and videos <br> <b>[DfD-Checkpoint_Fine-tuned_on_FF++]</b>: for deepfake detection, an FSFM ViT-B/16-224 fine-tuned on the FF++_c23 train&val sets (4 manipulations, 32 frames per video) <br> <b>[FAS-Checkpoint_Fine-tuned_on_MCIO]</b>: for face anti-spoofing, an FSFM ViT-B/16-224 fine-tuned on the MCIO datasets (2 frames per video)")
-
-     with gr.Row():
-         ckpt_select_dropdown = gr.Dropdown(
-             label="Select the Model for Detection ⬇️",
-             elem_classes="custom-label",
-             choices=['Choose Model Here 🖱️'] + CKPT_NAME + ['continuously updating...'],
-             multiselect=False,
-             value='Choose Model Here 🖱️',
-             interactive=True,
-         )
-         model_loading_status = gr.Textbox(label="Model Loading Status")
-     with gr.Row():
-         with gr.Column(scale=5):
-             gr.Markdown("### Image Detection (Fast Try: copy an image from [whichfaceisreal](https://whichfaceisreal.com/))")
-             image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
-             image_submit_btn = gr.Button("Submit")
-             output_results_image = gr.Textbox(label="Detection Result")
-
-             with gr.Row():
-                 output_heatmap = gr.Image(label="Grad_CAM")
-                 output_max_prob_class = gr.Textbox(label="Detected Class")
-         with gr.Column(scale=5):
-             gr.Markdown("### Video Detection")
-             video = gr.Video(label="Upload/Capture your video")
-             frame_slider = gr.Slider(minimum=1, maximum=32, step=1, value=32, label="Number of Frames for Detection")
-             video_submit_btn = gr.Button("Submit")
-             output_results_video = gr.Textbox(label="Detection Result")
-
-     ckpt_select_dropdown.change(
-         fn=load_model,
-         inputs=[ckpt_select_dropdown],
-         outputs=[ckpt_select_dropdown, model_loading_status],
-     )
-     image_submit_btn.click(
-         fn=FSFM3C_image_detection,
-         inputs=[image],
-         outputs=[output_results_image, output_heatmap, output_max_prob_class],
-     )
-     video_submit_btn.click(
-         fn=FSFM3C_video_detection,
-         inputs=[video, frame_slider],
-         outputs=[output_results_video],
-     )
-
- if __name__ == "__main__":
-     args = get_args_parser()
-     args = args.parse_args()
-     ckpt = '✨Unified-detector_v1_Fine-tuned_on_4_classes'
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     args.nb_classes = CKPT_CLASS[ckpt]
-     model = models_vit.__dict__[CKPT_MODEL[ckpt]](
-         num_classes=args.nb_classes,
-         drop_path_rate=args.drop_path,
-         global_pool=args.global_pool,
-     ).to(device)
-     args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
-     if not os.path.isfile(args.resume):
-         hf_hub_download(local_dir=CKPT_SAVE_PATH,
-                         local_dir_use_symlinks=False,
-                         repo_id='Wolowolo/fsfm-3c',
-                         filename=CKPT_PATH[ckpt])
-         args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
-     checkpoint = torch.load(args.resume, map_location=device)
-     model.load_state_dict(checkpoint['model'], strict=False)
-     model.eval()
-
-     gr.close_all()
-     demo.queue()
- demo.queue()
+ # -*- coding: utf-8 -*-
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
+ # --------------------------------------------------------
+ # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
+ # You can find the license in the LICENSE file in the root directory of this source tree.
+ # --------------------------------------------------------
+
+ import sys
+ import os
+ os.system('pip install dlib')  # install dlib at runtime
+ import dlib
+ import argparse
+ import numpy as np
+ from PIL import Image
+ import cv2
+ import torch
+ from huggingface_hub import hf_hub_download
+ import gradio as gr
+
+ import models_vit
+ from util.datasets import build_dataset
+ from engine_finetune import test_two_class, test_multi_class
+ import matplotlib.pyplot as plt
+ from torchvision import transforms
+ import traceback
+ from pytorch_grad_cam import (
+     GradCAM, ScoreCAM,
+     XGradCAM, EigenCAM
+ )
+ from pytorch_grad_cam import GuidedBackpropReLUModel
+ from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
+
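+ # Reshape ViT token outputs for Grad-CAM: drop the class token and turn the
+ # (B, 1+H*W, C) token sequence into a (B, C, H, W) feature map.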
+ def reshape_transform(tensor, height=14, width=14):
+     result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
+     result = result.transpose(2, 3).transpose(1, 2)
+     return result
+
+ def get_args_parser():
+     parser = argparse.ArgumentParser('FSFM3C fine-tuning&Testing for image classification', add_help=False)
+     parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU')
+     parser.add_argument('--epochs', default=50, type=int)
+     parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations')
+     parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', help='Name of model to train')
+     parser.add_argument('--input_size', default=224, type=int, help='images input size')
+     parser.add_argument('--normalize_from_IMN', action='store_true', help='cal mean and std from imagenet')
+     parser.set_defaults(normalize_from_IMN=True)
+     parser.add_argument('--apply_simple_augment', action='store_true', help='apply simple data augment')
+     parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate')
+     parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm')
+     parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay')
+     parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate')
+     parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='base learning rate')
+     parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay')
+     parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound')
+     parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR')
+     parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', help='Color jitter factor')
+     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy')
+     parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing')
+     parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob')
+     parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode')
+     parser.add_argument('--recount', type=int, default=1, help='Random erase count')
+     parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first augmentation split')
+     parser.add_argument('--mixup', type=float, default=0, help='mixup alpha')
+     parser.add_argument('--cutmix', type=float, default=0, help='cutmix alpha')
+     parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio')
+     parser.add_argument('--mixup_prob', type=float, default=1.0, help='Probability of performing mixup or cutmix')
+     parser.add_argument('--mixup_switch_prob', type=float, default=0.5, help='Probability of switching to cutmix')
+     parser.add_argument('--mixup_mode', type=str, default='batch', help='How to apply mixup/cutmix params')
+     parser.add_argument('--finetune', default='', help='finetune from checkpoint')
+     parser.add_argument('--global_pool', action='store_true')
+     parser.set_defaults(global_pool=True)
+     parser.add_argument('--cls_token', action='store_false', dest='global_pool', help='Use class token for classification')
+     parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path')
+     parser.add_argument('--nb_classes', default=1000, type=int, help='number of the classification types')
+     parser.add_argument('--output_dir', default='', help='path where to save')
+     parser.add_argument('--log_dir', default='', help='path where to tensorboard log')
+     parser.add_argument('--device', default='cuda', help='device to use for training / testing')
+     parser.add_argument('--seed', default=0, type=int)
+     parser.add_argument('--resume', default='', help='resume from checkpoint')
+     parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
+     parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+     parser.set_defaults(eval=True)
+     parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
+     parser.add_argument('--num_workers', default=10, type=int)
+     parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader')
+     parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
+     parser.set_defaults(pin_mem=True)
+     parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+     parser.add_argument('--local_rank', default=-1, type=int)
+     parser.add_argument('--dist_on_itp', action='store_true')
+     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+     return parser
+
+
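+ # Load the selected checkpoint: download it from the Hugging Face Hub if it is
+ # not cached locally, restore the ViT weights, and rebuild the Grad-CAM
+ # extractor on the last transformer block.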
+ def load_model(select_skpt):
+     global ckpt, device, model, checkpoint
+     if select_skpt not in CKPT_NAME:
+         return gr.update(), "Select a correct model"
+     ckpt = select_skpt
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     args.nb_classes = CKPT_CLASS[ckpt]
+     model = models_vit.__dict__[CKPT_MODEL[ckpt]](
+         num_classes=args.nb_classes,
+         drop_path_rate=args.drop_path,
+         global_pool=args.global_pool,
+     ).to(device)
+
+     args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
+     if not os.path.isfile(args.resume):
+         hf_hub_download(local_dir=CKPT_SAVE_PATH,
+                         local_dir_use_symlinks=False,
+                         repo_id='Wolowolo/fsfm-3c',
+                         filename=CKPT_PATH[ckpt])
+         args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
+     checkpoint = torch.load(args.resume, map_location=device)
+     model.load_state_dict(checkpoint['model'], strict=False)
+     model.eval()
+     global cam
+     cam = GradCAM(model=model,
+                   target_layers=[model.blocks[-1].norm1],
+                   reshape_transform=reshape_transform)
+     return gr.update(), f"[Loaded Model Successfully: {args.resume}]"
+
+
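+ # Expand the dlib face box to a square crop (1.3x the longer side), clamped to
+ # the image borders.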
+ def get_boundingbox(face, width, height, minsize=None):
+     x1, y1, x2, y2 = face.left(), face.top(), face.right(), face.bottom()
+     size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
+     if minsize and size_bb < minsize:
+         size_bb = minsize
+     center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
+     x1, y1 = max(int(center_x - size_bb // 2), 0), max(int(center_y - size_bb // 2), 0)
+     size_bb = min(width - x1, size_bb)
+     size_bb = min(height - y1, size_bb)
+     return x1, y1, size_bb
+
+
+ def extract_face(frame):
+     face_detector = dlib.get_frontal_face_detector()
+     image = np.array(frame.convert('RGB'))
+     faces = face_detector(image, 1)
+     if faces:
+         face = faces[0]
+         x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
+         cropped_face = image[y:y + size, x:x + size]
+         return Image.fromarray(cropped_face)
+     return None
+
+
+ def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
+     return np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int).tolist()
+
+
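+ # Uniformly sample frames from the video, crop the first detected face in each
+ # frame, and save the crops under dst_path/0 so build_dataset can read them as
+ # a single-class image folder.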
+ def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None):
+     video_capture = cv2.VideoCapture(src_video)
+     total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
+     frame_indices = get_frame_index_uniform_sample(total_frames, num_frames) if num_frames else range(total_frames)
+     for frame_index in frame_indices:
+         video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
+         ret, frame = video_capture.read()
+         if not ret:
+             continue
+         image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+         img = extract_face(image)
+         if img:
+             img = img.resize((224, 224), Image.BICUBIC)
+             save_img_name = f"frame_{frame_index}.png"
+             img.save(os.path.join(dst_path, '0', save_img_name))
+     video_capture.release()
+     return frame_indices
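+
+
+ # Callable target for pytorch-grad-cam that picks out the logit of one class.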
+ class TargetCategory:
+     def __init__(self, category_index):
+         self.category_index = category_index
+
+     def __call__(self, output):
+         return output[self.category_index]
+
+
+ def preprocess_image_cam(pil_img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
+     # Convert the PIL image to a numpy array
+     img_np = np.array(pil_img)
+
+     # Normalize to [0, 1]
+     img_np = img_np.astype(np.float32) / 255.0
+
+     # Standardize with the given mean and std
+     img_np = (img_np - mean) / std
+
+     # Reorder dimensions to the model input layout (C, H, W)
+     img_np = np.transpose(img_np, (2, 0, 1))
+
+     # Add a batch dimension (B, C, H, W)
+     img_np = np.expand_dims(img_np, axis=0)
+
+     return img_np
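+
+
+ # Image pipeline: crop the face, run the classifier, and (for the 4-class
+ # unified detector) overlay a Grad-CAM heatmap for the predicted class.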
+ def FSFM3C_image_detection(image):
+     frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
+     os.makedirs(frame_path, exist_ok=True)
+     os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
+     img = extract_face(image)
+     if img is None:
+         return 'No face detected, please upload a clear face!', None, None
+     img = img.resize((224, 224), Image.BICUBIC)
+     img.save(os.path.join(frame_path, '0', "frame_0.png"))
+     args.data_path = frame_path
+     args.batch_size = 1
+     dataset_val = build_dataset(is_train=False, args=args)
+     sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+     data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
+
+     if CKPT_CLASS[ckpt] > 2:
+         frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
+         class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
+         avg_video_pred = np.mean(video_pred_list, axis=0)
+         max_prob_index = np.argmax(avg_video_pred)
+         max_prob_class = class_names[max_prob_index]
+         probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
+         image_results = f"The largest face in this image may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]"
+
+         # Generate the CAM heatmap for the detected class
+         use_cuda = torch.cuda.is_available()
+         input_tensor = preprocess_image(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         if use_cuda:
+             input_tensor = input_tensor.cuda()
+
+         # Dynamically determine the target category based on the maximum probability class
+         category_names_to_index = {
+             'Real or Bonafide': 0,
+             'Deepfake': 1,
+             'Diffusion or AIGC generated': 2,
+             'Spoofing or Presentation-attack': 3
+         }
+         target_category = TargetCategory(category_names_to_index[max_prob_class])
+
+         grayscale_cam = cam(input_tensor=input_tensor, targets=[target_category])
+         grayscale_cam = 1 - grayscale_cam[0, :]
+         img = np.array(img)
+         if img.shape[2] == 4:
+             img = img[:, :, :3]
+         img = img.astype(np.float32) / 255.0
+         visualization = show_cam_on_image(img, grayscale_cam)
+         visualization = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)
+
+         # Add text overlay to the heatmap
+         # text = f"Detected: {max_prob_class}"
+         # cv2.putText(visualization, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+         os.makedirs("./CAM_images", exist_ok=True)
+         output_path = "./CAM_images/output_heatmap.png"
+         cv2.imwrite(output_path, visualization)
+         return image_results, output_path, probabilities[max_prob_index]
+
+     if CKPT_CLASS[ckpt] == 2:
+         frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
+         if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
+             prob = sum(video_pred_list) / len(video_pred_list)
+             label = "Deepfake" if prob <= 0.5 else "Real"
+             prob = prob if label == "Real" else 1 - prob
+         if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
+             prob = sum(video_pred_list) / len(video_pred_list)
+             label = "Spoofing" if prob <= 0.5 else "Bonafide"
+             prob = prob if label == "Bonafide" else 1 - prob
+         image_results = f"The largest face in this image may be {label} with probability {prob * 100:.1f}%"
+         return image_results, None, None
+
+
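+ # Video pipeline: classify the sampled face crops frame by frame, then average
+ # the frame-level predictions into a single video-level result.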
+ def FSFM3C_video_detection(video, num_frames):
+     try:
+         frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
+         os.makedirs(frame_path, exist_ok=True)
+         os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
+         frame_indices = extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames)
+         args.data_path = frame_path
+         args.batch_size = num_frames
+         dataset_val = build_dataset(is_train=False, args=args)
+         sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+         data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
+
+         if CKPT_CLASS[ckpt] > 2:
+             frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
+             class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
+             avg_video_pred = np.mean(video_pred_list, axis=0)
+             max_prob_index = np.argmax(avg_video_pred)
+             max_prob_class = class_names[max_prob_index]
+             probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
+
+             frame_results = {f"frame_{frame_indices[i]}": [f"{class_names[j]}: {prob * 100:.1f}%" for j, prob in enumerate(frame_preds_list[i])] for i in range(len(frame_indices))}
+             video_results = (f"The largest face in this video may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]\n \n"
+                              f"The frame-level detection results ['frame_index': 'probabilities']: \n{frame_results}")
+             return video_results
+
+         if CKPT_CLASS[ckpt] == 2:
+             frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
+             if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
+                 prob = sum(video_pred_list) / len(video_pred_list)
+                 label = "Deepfake" if prob <= 0.5 else "Real"
+                 prob = prob if label == "Real" else 1 - prob
+                 frame_results = ({f"frame_{frame_indices[i]}": f"{frame_preds_list[i] * 100:.1f}%" for i in range(len(frame_indices))}
+                                  if label == "Real" else
+                                  {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in range(len(frame_indices))})
+             if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
+                 prob = sum(video_pred_list) / len(video_pred_list)
+                 label = "Spoofing" if prob <= 0.5 else "Bonafide"
+                 prob = prob if label == "Bonafide" else 1 - prob
+                 frame_results = ({f"frame_{frame_indices[i]}": f"{frame_preds_list[i] * 100:.1f}%" for i in range(len(frame_indices))}
+                                  if label == "Bonafide" else
+                                  {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in range(len(frame_indices))})
+
+             video_results = (f"The largest face in this video may be {label} with probability {prob * 100:.1f}%\n \n"
+                              f"The frame-level detection results ['frame_index': 'real_face_probability']: \n{frame_results}")
+             return video_results
+     except Exception:
+         return "Error occurred. Please provide a clear face video or reduce the number of frames."
+
+ # Paths and Constants
+ P = os.path.abspath(__file__)
+ FRAME_SAVE_PATH = os.path.join(os.path.dirname(P), 'frame')
+ CKPT_SAVE_PATH = os.path.join(os.path.dirname(P), 'checkpoints')
+ os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
+ os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
+ CKPT_NAME = [
+     '✨Unified-detector_v1_Fine-tuned_on_4_classes',
+     'DfD-Checkpoint_Fine-tuned_on_FF++',
+     'FAS-Checkpoint_Fine-tuned_on_MCIO',
+ ]
+ CKPT_PATH = {
+     '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_train_loss.pth',
+     'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
+     'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
+ }
+ CKPT_CLASS = {
+     '✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
+     'DfD-Checkpoint_Fine-tuned_on_FF++': 2,
+     'FAS-Checkpoint_Fine-tuned_on_MCIO': 2,
+ }
+ CKPT_MODEL = {
+     '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'vit_base_patch16',
+     'DfD-Checkpoint_Fine-tuned_on_FF++': 'vit_base_patch16',
+     'FAS-Checkpoint_Fine-tuned_on_MCIO': 'vit_base_patch16',
+ }
+
+ with gr.Blocks(css=".custom-label { font-weight: bold !important; font-size: 16px !important; }") as demo:
+     gr.HTML("<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery (Deepfake/Diffusion) and Spoofing (Presentation-attacks)</h1>")
+     gr.Markdown("<b>☉ Powered by fine-tuned ViT models pre-trained with [FSFM-3C](https://fsfm-3c.github.io/)</b> <br> "
+                 "<b>☉ We do not and cannot access or store the data you have uploaded!</b> <br> "
+                 "<b>☉ Release (continuously updating)</b> <br> <b>[V1.0] 2025/02/22-Current🎉</b>: "
+                 "1) Added <b>[✨Unified-detector_v1] for Unified Physical-Digital Face Attack&Forgery Detection, a ViT-B/16-224 (FSFM Pre-trained) detector that can identify Real&Bonafide, Deepfake, Diffusion&AIGC, and Spoofing&Presentation-attack facial images or videos</b>; 2) Added a selector for the number of video frames (uniformly sampling 1-32 frames; more frames may be time-consuming on this page without GPU acceleration); 3) Fixed some errors of V0.1, including loading and prediction. <br>"
+                 "<b>[V0.1] 2024/12-2025/02/21</b>: "
+                 "Created this page with the basic detectors [DfD-Checkpoint_Fine-tuned_on_FF++, FAS-Checkpoint_Fine-tuned_on_MCIO] that follow the paper implementation. <br> ")
+     gr.Markdown("- Please <b>provide a facial image or video (<100s)</b> and <b>select the model</b> for detection: <br> <b>[SUGGESTED] [✨Unified-detector_v1_Fine-tuned_on_4_classes]</b>: an FSFM pre-trained ViT-B/16-224 that detects Real/Deepfake/Diffusion/Spoofing facial images and videos <br> <b>[DfD-Checkpoint_Fine-tuned_on_FF++]</b>: for deepfake detection, an FSFM ViT-B/16-224 fine-tuned on the FF++_c23 train&val sets (4 manipulations, 32 frames per video) <br> <b>[FAS-Checkpoint_Fine-tuned_on_MCIO]</b>: for face anti-spoofing, an FSFM ViT-B/16-224 fine-tuned on the MCIO datasets (2 frames per video)")
+
+     with gr.Row():
+         ckpt_select_dropdown = gr.Dropdown(
+             label="Select the Model for Detection ⬇️",
+             elem_classes="custom-label",
+             choices=['Choose Model Here 🖱️'] + CKPT_NAME + ['continuously updating...'],
+             multiselect=False,
+             value='Choose Model Here 🖱️',
+             interactive=True,
+         )
+         model_loading_status = gr.Textbox(label="Model Loading Status")
+     with gr.Row():
+         with gr.Column(scale=5):
+             gr.Markdown("### Image Detection (Fast Try: copy an image from [whichfaceisreal](https://whichfaceisreal.com/))")
+             image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
+             image_submit_btn = gr.Button("Submit")
+             output_results_image = gr.Textbox(label="Detection Result")
+
+             with gr.Row():
+                 output_heatmap = gr.Image(label="Grad_CAM")
+                 output_max_prob_class = gr.Textbox(label="Detected Class")
+         with gr.Column(scale=5):
+             gr.Markdown("### Video Detection")
+             video = gr.Video(label="Upload/Capture your video")
+             frame_slider = gr.Slider(minimum=1, maximum=32, step=1, value=32, label="Number of Frames for Detection")
+             video_submit_btn = gr.Button("Submit")
+             output_results_video = gr.Textbox(label="Detection Result")
+
+     gr.HTML(
+         '<div style="display: flex; justify-content: center; gap: 20px; margin-bottom: 20px;">'
+         '<a href="https://mapmyvisitors.com/web/1bxvi" title="Visit tracker">'
+         '<img src="https://mapmyvisitors.com/map.png?d=FYhBoxLDEaFAxdfRzk5TuchYOBGrnSa98Ky59EkEEpY&cl=ffffff">'
+         '</a>'
+         '</div>'
+     )
+
+     ckpt_select_dropdown.change(
+         fn=load_model,
+         inputs=[ckpt_select_dropdown],
+         outputs=[ckpt_select_dropdown, model_loading_status],
+     )
+     image_submit_btn.click(
+         fn=FSFM3C_image_detection,
+         inputs=[image],
+         outputs=[output_results_image, output_heatmap, output_max_prob_class],
+     )
+     video_submit_btn.click(
+         fn=FSFM3C_video_detection,
+         inputs=[video, frame_slider],
+         outputs=[output_results_video],
+     )
+
+ if __name__ == "__main__":
+     args = get_args_parser()
+     args = args.parse_args()
+     ckpt = '✨Unified-detector_v1_Fine-tuned_on_4_classes'
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     args.nb_classes = CKPT_CLASS[ckpt]
+     model = models_vit.__dict__[CKPT_MODEL[ckpt]](
+         num_classes=args.nb_classes,
+         drop_path_rate=args.drop_path,
+         global_pool=args.global_pool,
+     ).to(device)
+     args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
+     if not os.path.isfile(args.resume):
+         hf_hub_download(local_dir=CKPT_SAVE_PATH,
+                         local_dir_use_symlinks=False,
+                         repo_id='Wolowolo/fsfm-3c',
+                         filename=CKPT_PATH[ckpt])
+         args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
+     checkpoint = torch.load(args.resume, map_location=device)
+     model.load_state_dict(checkpoint['model'], strict=False)
+     model.eval()
+
+     gr.close_all()
+     demo.queue()
  demo.launch()