monster07 commited on
Commit
bbe970a
Β·
verified Β·
1 Parent(s): 2b83e03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -47
app.py CHANGED
@@ -8,37 +8,123 @@ import numpy as np
8
  import matplotlib.pyplot as plt
9
  import os
10
 
11
- # βœ… Load Xception model trained on DFDC
12
- class XceptionModel(nn.Module):
13
- def __init__(self):
14
  super().__init__()
15
- base = torch.hub.load('pytorch/vision:v0.10.0', 'xception', pretrained=False)
16
- base.fc = nn.Linear(base.fc.in_features, 1)
17
- self.base = base
18
-
19
  def forward(self, x):
20
- return self.base(x)
 
 
21
 
22
- # Load model and weights
23
- model = XceptionModel()
24
- state_dict = torch.hub.load_state_dict_from_url(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  "https://huggingface.co/Selimsef/xception-cnn-df/resolve/main/xception-binary-weights.pt",
26
- map_location=torch.device('cpu')
27
- )
28
- model.load_state_dict(state_dict)
29
  model.eval()
30
 
31
- # βœ… Define transforms
32
  transform = transforms.Compose([
33
  transforms.Resize((299, 299)),
34
  transforms.ToTensor(),
35
  transforms.Normalize([0.5]*3, [0.5]*3)
36
  ])
37
 
38
- # βœ… Inference function
39
- def predict_deepfake(video_path):
40
- if video_path is None:
41
- return "❌ Please upload a video.", None
42
 
43
  cap = cv2.VideoCapture(video_path)
44
  preds = []
@@ -55,15 +141,14 @@ def predict_deepfake(video_path):
55
  y2 = int(h * 0.75)
56
  x1 = int(w * 0.25)
57
  x2 = int(w * 0.75)
58
- face_crop = frame[y1:y2, x1:x2] # Center crop
59
-
60
- image = cv2.cvtColor(face_crop, cv2.COLOR_BGR2RGB)
61
- pil_img = Image.fromarray(image)
62
- input_tensor = transform(pil_img).unsqueeze(0)
63
 
64
  with torch.no_grad():
65
- output = model(input_tensor)
66
- prob = torch.sigmoid(output)[0].item()
67
  preds.append(prob)
68
 
69
  count += 1
@@ -71,36 +156,31 @@ def predict_deepfake(video_path):
71
  cap.release()
72
 
73
  if not preds:
74
- return "❌ No frames processed. Try a different video.", None
75
 
76
  avg = np.mean(preds)
77
- result = "βœ… **REAL**" if avg < 0.5 else "❌ **FAKE**"
78
- verdict = f"""
79
- 🎯 Final Verdict: {result}
80
- πŸ“Š Confidence (avg): {avg:.2f}
81
- """
82
-
83
- # Plot histogram
84
- fig, ax = plt.subplots(figsize=(6, 4))
85
  ax.hist(preds, bins=10, color="red" if avg > 0.5 else "green", edgecolor="black")
86
- ax.set_title("Confidence per Frame (0 = Real, 1 = Fake)")
87
  ax.set_xlabel("Fake Probability")
88
- ax.set_ylabel("Frame Count")
89
  ax.grid(True)
90
 
91
- return verdict, fig
92
 
93
- # βœ… Gradio Interface
94
  with gr.Blocks() as demo:
95
- gr.Markdown("## 🎭 Real Deepfake Detection (Xception DFDC Model)")
96
- gr.Markdown("Upload a short `.mp4` video and the app will classify it as **REAL** or **FAKE** using a pretrained deepfake detection model.")
97
-
98
- video_input = gr.Video(label="Upload video")
99
- result_output = gr.Markdown()
100
- graph_output = gr.Plot()
101
 
102
- analyze_btn = gr.Button("πŸ” Analyze")
 
 
 
103
 
104
- analyze_btn.click(fn=predict_deepfake, inputs=video_input, outputs=[result_output, graph_output])
105
 
106
  demo.queue().launch()
 
8
  import matplotlib.pyplot as plt
9
  import os
10
 
11
+ # βœ… Xception Block Definition
12
+ class SeparableConv2d(nn.Module):
13
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
14
  super().__init__()
15
+ self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=False)
16
+ self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)
 
 
17
  def forward(self, x):
18
+ x = self.depthwise(x)
19
+ x = self.pointwise(x)
20
+ return x
21
 
22
+ class Block(nn.Module):
23
+ def __init__(self, in_filters, out_filters, reps, stride=1, start_with_relu=True, grow_first=True):
24
+ super().__init__()
25
+ layers = []
26
+ filters = in_filters
27
+ if grow_first:
28
+ if start_with_relu:
29
+ layers.append(nn.ReLU(inplace=True))
30
+ layers.extend([
31
+ SeparableConv2d(in_filters, out_filters, 3, 1, 1),
32
+ nn.BatchNorm2d(out_filters)
33
+ ])
34
+ filters = out_filters
35
+ for _ in range(reps - 1):
36
+ layers.extend([
37
+ nn.ReLU(inplace=True),
38
+ SeparableConv2d(filters, filters, 3, 1, 1),
39
+ nn.BatchNorm2d(filters)
40
+ ])
41
+ if not grow_first:
42
+ layers.extend([
43
+ nn.ReLU(inplace=True),
44
+ SeparableConv2d(in_filters, out_filters, 3, 1, 1),
45
+ nn.BatchNorm2d(out_filters)
46
+ ])
47
+ if stride != 1:
48
+ layers.append(nn.MaxPool2d(3, stride, 1))
49
+ self.block = nn.Sequential(*layers)
50
+ self.skip = nn.Conv2d(in_filters, out_filters, 1, stride, bias=False)
51
+ self.skipbn = nn.BatchNorm2d(out_filters)
52
+ def forward(self, inp):
53
+ x = self.block(inp)
54
+ skip = self.skipbn(self.skip(inp))
55
+ x += skip
56
+ return x
57
+
58
+ # βœ… Xception Architecture
59
+ class Xception(nn.Module):
60
+ def __init__(self, num_classes=1):
61
+ super().__init__()
62
+ self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
63
+ self.bn1 = nn.BatchNorm2d(32)
64
+ self.relu = nn.ReLU(inplace=True)
65
+ self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
66
+ self.bn2 = nn.BatchNorm2d(64)
67
+ self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
68
+ self.block2 = Block(128, 256, 2, 2, grow_first=True)
69
+ self.block3 = Block(256, 728, 2, 2, grow_first=True)
70
+ self.block4 = Block(728, 728, 3)
71
+ self.block5 = Block(728, 728, 3)
72
+ self.block6 = Block(728, 728, 3)
73
+ self.block7 = Block(728, 728, 3)
74
+ self.block8 = Block(728, 728, 3)
75
+ self.block9 = Block(728, 728, 3)
76
+ self.block10 = Block(728, 728, 3)
77
+ self.block11 = Block(728, 728, 3)
78
+ self.block12 = Block(728, 1024, 2, 2, grow_first=False)
79
+ self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
80
+ self.bn3 = nn.BatchNorm2d(1536)
81
+ self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
82
+ self.bn4 = nn.BatchNorm2d(2048)
83
+ self.fc = nn.Linear(2048, num_classes)
84
+ def features(self, input):
85
+ x = self.relu(self.bn1(self.conv1(input)))
86
+ x = self.relu(self.bn2(self.conv2(x)))
87
+ x = self.block1(x)
88
+ x = self.block2(x)
89
+ x = self.block3(x)
90
+ x = self.block4(x)
91
+ x = self.block5(x)
92
+ x = self.block6(x)
93
+ x = self.block7(x)
94
+ x = self.block8(x)
95
+ x = self.block9(x)
96
+ x = self.block10(x)
97
+ x = self.block11(x)
98
+ x = self.block12(x)
99
+ x = self.relu(self.bn3(self.conv3(x)))
100
+ x = self.relu(self.bn4(self.conv4(x)))
101
+ return x
102
+ def forward(self, input):
103
+ x = self.features(input)
104
+ x = nn.AdaptiveAvgPool2d((1, 1))(x)
105
+ x = x.view(x.size(0), -1)
106
+ x = self.fc(x)
107
+ return x
108
+
109
+ # βœ… Load weights
110
+ model = Xception()
111
+ model.load_state_dict(torch.hub.load_state_dict_from_url(
112
  "https://huggingface.co/Selimsef/xception-cnn-df/resolve/main/xception-binary-weights.pt",
113
+ map_location="cpu"
114
+ ))
 
115
  model.eval()
116
 
117
+ # βœ… Transform
118
  transform = transforms.Compose([
119
  transforms.Resize((299, 299)),
120
  transforms.ToTensor(),
121
  transforms.Normalize([0.5]*3, [0.5]*3)
122
  ])
123
 
124
+ # βœ… Analyze function
125
+ def analyze_deepfake(video_path):
126
+ if not video_path:
127
+ return "❌ No video uploaded", None
128
 
129
  cap = cv2.VideoCapture(video_path)
130
  preds = []
 
141
  y2 = int(h * 0.75)
142
  x1 = int(w * 0.25)
143
  x2 = int(w * 0.75)
144
+ crop = frame[y1:y2, x1:x2]
145
+ image = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
146
+ image = Image.fromarray(image)
147
+ input_tensor = transform(image).unsqueeze(0)
 
148
 
149
  with torch.no_grad():
150
+ out = model(input_tensor)
151
+ prob = torch.sigmoid(out)[0].item()
152
  preds.append(prob)
153
 
154
  count += 1
 
156
  cap.release()
157
 
158
  if not preds:
159
+ return "❌ No frames analyzed", None
160
 
161
  avg = np.mean(preds)
162
+ label = "**FAKE**" if avg > 0.5 else "**REAL**"
163
+ result = f"🎯 Verdict: {label}\nConfidence: {avg:.2f}"
164
+
165
+ fig, ax = plt.subplots()
 
 
 
 
166
  ax.hist(preds, bins=10, color="red" if avg > 0.5 else "green", edgecolor="black")
167
+ ax.set_title("Confidence per Frame")
168
  ax.set_xlabel("Fake Probability")
169
+ ax.set_ylabel("Frames")
170
  ax.grid(True)
171
 
172
+ return result, fig
173
 
174
+ # βœ… Gradio App
175
  with gr.Blocks() as demo:
176
+ gr.Markdown("# 🎭 Deepfake Detector with Xception (DFDC)")
177
+ gr.Markdown("Upload a `.mp4` video. The app will classify it as REAL or FAKE based on pretrained deepfake model.")
 
 
 
 
178
 
179
+ video = gr.Video(label="Upload Video")
180
+ output_text = gr.Markdown()
181
+ output_plot = gr.Plot()
182
+ analyze = gr.Button("πŸ” Analyze")
183
 
184
+ analyze.click(fn=analyze_deepfake, inputs=video, outputs=[output_text, output_plot])
185
 
186
  demo.queue().launch()