saakshigupta committed (verified) · Commit e0112d5 · Parent: d742a0e

Update app.py

Files changed (1):
  app.py (+120 / -305)
app.py CHANGED
@@ -3,8 +3,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from torchvision import transforms
-from transformers import CLIPModel, BlipProcessor, BlipForConditionalGeneration
-from transformers.models.clip import CLIPModel
+from transformers import BlipProcessor, BlipForConditionalGeneration
 from PIL import Image
 import numpy as np
 import io
@@ -16,6 +15,7 @@ from unsloth import FastVisionModel
 import os
 import tempfile
 import warnings
+from gradcam_xception import load_xception_model, generate_smoothgrad_visualizations_xception, get_xception_transform
 warnings.filterwarnings("ignore", category=UserWarning)
 
 # App title and description
@@ -42,14 +42,13 @@ def check_gpu():
 # Sidebar components
 st.sidebar.title("About")
 st.sidebar.markdown("""
-This tool detects deepfakes using four AI models:
-- **CLIP**: Initial Real/Fake classification
-- **GradCAM**: Highlights suspicious regions
+This tool detects deepfakes using three AI models:
+- **Xception**: Initial Real/Fake classification
 - **BLIP**: Describes image content
 - **Llama 3.2**: Explains potential manipulations
 
 ### Quick Start
-1. **Load Models** - Start with CLIP, add others as needed
+1. **Load Models** - Start with Xception, add others as needed
 2. **Upload Image** - View classification and heat map
 3. **Analyze** - Get explanations and ask questions
 
@@ -72,8 +71,7 @@ if use_custom_instructions:
 else:
     custom_instruction = ""
 
-# ----- GradCAM Implementation -----
-
+# ----- GradCAM Implementation for Xception -----
 class ImageDataset(torch.utils.data.Dataset):
     def __init__(self, image, transform=None, face_only=True, dataset_name=None):
         self.image = image
@@ -149,262 +147,45 @@ class ImageDataset(torch.utils.data.Dataset):
 
         return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name
 
-class GradCAM:
-    def __init__(self, model, target_layer):
-        self.model = model
-        self.target_layer = target_layer
-        self.gradients = None
-        self.activations = None
-        self._register_hooks()
-
-    def _register_hooks(self):
-        def forward_hook(module, input, output):
-            if isinstance(output, tuple):
-                self.activations = output[0]
-            else:
-                self.activations = output
-
-        def backward_hook(module, grad_in, grad_out):
-            if isinstance(grad_out, tuple):
-                self.gradients = grad_out[0]
-            else:
-                self.gradients = grad_out
-
-        layer = dict([*self.model.named_modules()])[self.target_layer]
-        layer.register_forward_hook(forward_hook)
-        layer.register_backward_hook(backward_hook)
-
-    def generate(self, input_tensor, class_idx):
-        self.model.zero_grad()
-
-        try:
-            # Use only the vision part of the model for gradient calculation
-            vision_outputs = self.model.vision_model(pixel_values=input_tensor)
-
-            # Get the pooler output
-            features = vision_outputs.pooler_output
-
-            # Create a dummy gradient for the feature based on the class idx
-            one_hot = torch.zeros_like(features)
-            one_hot[0, class_idx] = 1
-
-            # Manually backpropagate
-            features.backward(gradient=one_hot)
-
-            # Check for None values
-            if self.gradients is None or self.activations is None:
-                st.warning("Warning: Gradients or activations are None. Using fallback CAM.")
-                return np.ones((14, 14), dtype=np.float32) * 0.5
-
-            # Process gradients and activations for transformer-based model
-            gradients = self.gradients.cpu().detach().numpy()
-            activations = self.activations.cpu().detach().numpy()
-
-            if len(activations.shape) == 3:  # [batch, sequence_length, hidden_dim]
-                seq_len = activations.shape[1]
-
-                # CLIP ViT typically has 196 patch tokens (14×14) + 1 class token = 197
-                if seq_len >= 197:
-                    # Skip the class token (first token) and reshape the patch tokens into a square
-                    patch_tokens = activations[0, 1:197, :]  # Remove the class token
-                    # Take the mean across the hidden dimension
-                    token_importance = np.mean(np.abs(patch_tokens), axis=1)
-                    # Reshape to the expected grid size (14×14 for CLIP ViT)
-                    cam = token_importance.reshape(14, 14)
-                else:
-                    # Try to find factors close to a square
-                    side_len = int(np.sqrt(seq_len))
-                    # Use the mean across features as importance
-                    token_importance = np.mean(np.abs(activations[0]), axis=1)
-                    # Create as square-like shape as possible
-                    cam = np.zeros((side_len, side_len))
-                    # Fill the cam with available values
-                    flat_cam = cam.flatten()
-                    flat_cam[:min(len(token_importance), len(flat_cam))] = token_importance[:min(len(token_importance), len(flat_cam))]
-                    cam = flat_cam.reshape(side_len, side_len)
-            else:
-                # Fallback
-                st.info("Using fallback CAM shape (14x14)")
-                cam = np.ones((14, 14), dtype=np.float32) * 0.5  # Default fallback
-
-            # Ensure we have valid values
-            cam = np.maximum(cam, 0)
-            if np.max(cam) > 0:
-                cam = cam / np.max(cam)
-
-            return cam
-
-        except Exception as e:
-            st.error(f"Error in GradCAM.generate: {str(e)}")
-            return np.ones((14, 14), dtype=np.float32) * 0.5
-
-def overlay_cam_on_image(image, cam, face_box=None, alpha=0.5):
-    """Overlay the CAM on the image"""
-    if face_box is not None:
-        x, y, w, h = face_box
-        # Create a mask for the entire image (all zeros initially)
-        img_np = np.array(image)
-        full_h, full_w = img_np.shape[:2]
-        full_cam = np.zeros((full_h, full_w), dtype=np.float32)
-
-        # Resize CAM to match face region
-        face_cam = cv2.resize(cam, (w, h))
-
-        # Copy the face CAM into the full image CAM at the face position
-        full_cam[y:y+h, x:x+w] = face_cam
-
-        # Convert full CAM to image
-        cam_resized = Image.fromarray((full_cam * 255).astype(np.uint8))
-        cam_colormap = plt.cm.jet(np.array(cam_resized) / 255.0)[:, :, :3]  # Apply colormap
-        cam_colormap = (cam_colormap * 255).astype(np.uint8)
-    else:
-        # Resize CAM to match image dimensions
-        img_np = np.array(image)
-        h, w = img_np.shape[:2]
-        cam_resized = cv2.resize(cam, (w, h))
-
-        # Apply colormap
-        cam_colormap = plt.cm.jet(cam_resized)[:, :, :3]  # Apply colormap
-        cam_colormap = (cam_colormap * 255).astype(np.uint8)
-
-    # Blend the original image with the colormap
-    img_np_float = img_np.astype(float) / 255.0
-    cam_colormap_float = cam_colormap.astype(float) / 255.0
-
-    blended = img_np_float * (1 - alpha) + cam_colormap_float * alpha
-    blended = (blended * 255).astype(np.uint8)
-
-    return Image.fromarray(blended)
-
-def save_comparison(image, cam, overlay, face_box=None):
-    """Create a side-by-side comparison of the original, CAM, and overlay"""
-    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
-
-    # Original Image
-    axes[0].imshow(image)
-    axes[0].set_title("Original")
-    if face_box is not None:
-        x, y, w, h = face_box
-        rect = plt.Rectangle((x, y), w, h, edgecolor='lime', linewidth=2, fill=False)
-        axes[0].add_patch(rect)
-    axes[0].axis("off")
-
-    # CAM
-    if face_box is not None:
-        # Create a full image CAM that highlights only the face
-        img_np = np.array(image)
-        h, w = img_np.shape[:2]
-        full_cam = np.zeros((h, w))
-
-        x, y, fw, fh = face_box
-        # Resize CAM to face size
-        face_cam = cv2.resize(cam, (fw, fh))
-        # Place it in the right position
-        full_cam[y:y+fh, x:x+fw] = face_cam
-        axes[1].imshow(full_cam, cmap="jet")
-    else:
-        cam_resized = cv2.resize(cam, (image.width, image.height))
-        axes[1].imshow(cam_resized, cmap="jet")
-    axes[1].set_title("CAM")
-    axes[1].axis("off")
-
-    # Overlay
-    axes[2].imshow(overlay)
-    axes[2].set_title("Overlay")
-    axes[2].axis("off")
-
-    plt.tight_layout()
-
-    # Convert plot to PIL Image for Streamlit display
-    buf = io.BytesIO()
-    plt.savefig(buf, format="png", bbox_inches="tight")
-    plt.close()
-    buf.seek(0)
-    return Image.open(buf)
-
-# Function to load GradCAM CLIP model
-@st.cache_resource
-def load_clip_model():
-    with st.spinner("Loading CLIP model for GradCAM..."):
-        try:
-            model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-
-            # Apply a simple classification head
-            model.classification_head = nn.Linear(1024, 2)
-            model.classification_head.weight.data.normal_(mean=0.0, std=0.02)
-            model.classification_head.bias.data.zero_()
-
-            model.eval()
-            return model
-        except Exception as e:
-            st.error(f"Error loading CLIP model: {str(e)}")
-            return None
-
-def get_target_layer_clip(model):
-    """Get the target layer for GradCAM"""
-    return "vision_model.encoder.layers.23"
-
-def process_image_with_gradcam(image, model, device, pred_class):
-    """Process an image with GradCAM"""
-    # Set up transformations
-    transform = transforms.Compose([
-        transforms.Resize((224, 224)),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
-    ])
-
-    # Create dataset for the single image
-    dataset = ImageDataset(image, transform=transform, face_only=True)
-
-    # Custom collate function
-    def custom_collate(batch):
-        tensors = [item[0] for item in batch]
-        labels = [item[1] for item in batch]
-        paths = [item[2] for item in batch]
-        images = [item[3] for item in batch]
-        face_boxes = [item[4] for item in batch]
-        dataset_names = [item[5] for item in batch]
-
-        tensors = torch.stack(tensors)
-        labels = torch.tensor(labels)
-
-        return tensors, labels, paths, images, face_boxes, dataset_names
-
-    # Create dataloader
-    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=custom_collate)
-
-    # Extract the batch
-    for batch in dataloader:
-        input_tensor, label, img_paths, original_images, face_boxes, dataset_names = batch
-        original_image = original_images[0]
-        face_box = face_boxes[0]
-
-    # Move tensors and model to device
-    input_tensor = input_tensor.to(device)
-    model = model.to(device)
-
-    try:
-        # Create GradCAM extractor
-        target_layer = get_target_layer_clip(model)
-        cam_extractor = GradCAM(model, target_layer)
-
-        # Generate CAM
-        cam = cam_extractor.generate(input_tensor, pred_class)
-
-        # Create visualizations
-        overlay = overlay_cam_on_image(original_image, cam, face_box)
-        comparison = save_comparison(original_image, cam, overlay, face_box)
-
-        # Return results
-        return cam, overlay, comparison, face_box
-
-    except Exception as e:
-        st.error(f"Error processing image with GradCAM: {str(e)}")
-        # Return default values
-        default_cam = np.ones((14, 14), dtype=np.float32) * 0.5
-        overlay = overlay_cam_on_image(original_image, default_cam, face_box)
-        comparison = save_comparison(original_image, default_cam, overlay, face_box)
-        return default_cam, overlay, comparison, face_box
+# Function to process image with Xception GradCAM
+def process_image_with_xception_gradcam(image, model, device, pred_class):
+    """Process an image with Xception GradCAM"""
+    cam_results = generate_smoothgrad_visualizations_xception(
+        model=model,
+        image=image,
+        target_class=pred_class,
+        face_only=True,
+        num_samples=5  # Can be adjusted
+    )
+
+    if cam_results and len(cam_results) == 4:
+        raw_cam, cam_img, overlay, comparison = cam_results
+
+        # Extract the face box from the dataset if needed
+        transform = get_xception_transform()
+        dataset = ImageDataset(image, transform=transform, face_only=True)
+        _, _, _, _, face_box, _ = dataset[0]
+
+        return raw_cam, overlay, comparison, face_box
+    else:
+        st.error("Failed to generate GradCAM visualization")
+        return None, None, None, None
+
+# ----- Xception Model Loading -----
+@st.cache_resource
+def load_detection_model_xception():
+    """Loads the Xception model from our module"""
+    with st.spinner("Loading Xception model for deepfake detection..."):
+        try:
+            model = load_xception_model()
+            # Get the device
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            model.to(device)
+            model.eval()
+            return model, device
+        except Exception as e:
+            st.error(f"Error loading Xception model: {str(e)}")
+            return None, None
 
 # ----- BLIP Image Captioning -----
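Note: the gradcam_xception module itself is not part of this commit, so only its interface is visible above. A minimal sketch of what the two helpers used here presumably provide, assuming the conventional Xception input pipeline (299×299 crops scaled to [-1, 1]) and a timm backbone with a two-class head; the model id and checkpoint handling are illustrative assumptions, not the module's actual code:

    # Hypothetical sketch of the gradcam_xception helpers (not the real module).
    import timm
    import torch
    from torchvision import transforms

    def get_xception_transform():
        # Conventional Xception preprocessing: resize to 299x299, scale to [-1, 1]
        return transforms.Compose([
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def load_xception_model(checkpoint_path=None):
        # Two-class Xception classifier; the real module likely loads fine-tuned weights.
        # The timm model name varies by version (e.g. "xception" or "legacy_xception").
        model = timm.create_model("legacy_xception", pretrained=True, num_classes=2)
        if checkpoint_path is not None:
            state = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(state)
        return model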
 
@@ -624,12 +405,47 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
         st.error(f"Error during LLM analysis: {str(e)}")
         return f"Error analyzing image: {str(e)}"
 
+# Preprocess image for Xception
+def preprocess_image_xception(image):
+    """Preprocesses image for Xception model input and face detection."""
+    face_detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
+    image_np = np.array(image.convert('RGB'))  # Ensure RGB
+    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
+    faces = face_detector.detectMultiScale(gray, 1.1, 5)
+
+    face_img_for_transform = image  # Default to whole image
+    face_box_display = None  # For drawing on original image
+
+    if len(faces) == 0:
+        st.warning("No face detected, using whole image for prediction/CAM.")
+    else:
+        areas = [w * h for (x, y, w, h) in faces]
+        largest_idx = np.argmax(areas)
+        x, y, w, h = faces[largest_idx]
+        padding_x = int(w * 0.05)  # Use percentages as in gradcam_xception
+        padding_y = int(h * 0.05)
+        x1, y1 = max(0, x - padding_x), max(0, y - padding_y)
+        x2, y2 = min(image_np.shape[1], x + w + padding_x), min(image_np.shape[0], y + h + padding_y)
+
+        # Use the padded face region for the model transform
+        face_img_for_transform = Image.fromarray(image_np[y1:y2, x1:x2])
+        # Use the original detected box (without padding) for display rectangle
+        face_box_display = (x, y, w, h)
+
+    # Xception specific transform
+    transform = get_xception_transform()
+    # Apply transform to the selected region (face or whole image)
+    input_tensor = transform(face_img_for_transform).unsqueeze(0)
+
+    # Return tensor, original full image, and the display face box
+    return input_tensor, image, face_box_display
+
 # Main app
 def main():
     # Initialize session state variables
-    if 'clip_model_loaded' not in st.session_state:
-        st.session_state.clip_model_loaded = False
-        st.session_state.clip_model = None
+    if 'xception_model_loaded' not in st.session_state:
+        st.session_state.xception_model_loaded = False
+        st.session_state.xception_model = None
 
     if 'llm_model_loaded' not in st.session_state:
         st.session_state.llm_model_loaded = False
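Note: taken together, the pieces above can be exercised outside Streamlit. A short sanity check of the new detection path (skipping the Haar-cascade face crop and using the whole image; the sample path is a placeholder, and the 0 = Fake / 1 = Real mapping is the one assumed by the new code further down):

    # Offline smoke test of the Xception path; mirrors what main() does below.
    import torch
    from PIL import Image
    from gradcam_xception import load_xception_model, get_xception_transform

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_xception_model().to(device).eval()

    image = Image.open("sample.jpg").convert("RGB")   # placeholder test image
    input_tensor = get_xception_transform()(image).unsqueeze(0).to(device)

    with torch.no_grad():
        probabilities = torch.softmax(model(input_tensor), dim=1)[0]

    pred_class = int(probabilities.argmax())
    print("Fake" if pred_class == 0 else "Real", float(probabilities[pred_class]))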
@@ -652,21 +468,22 @@ def main():
     st.write("Please load the models using the buttons below:")
 
     # Button for loading models
-    clip_col, blip_col, llm_col = st.columns(3)
+    xception_col, blip_col, llm_col = st.columns(3)
 
-    with clip_col:
-        if not st.session_state.clip_model_loaded:
-            if st.button("📥 Load CLIP Model for Detection", type="primary"):
-                # Load CLIP model
-                model = load_clip_model()
+    with xception_col:
+        if not st.session_state.xception_model_loaded:
+            if st.button("📥 Load Xception Model for Detection", type="primary"):
+                # Load Xception model
+                model, device = load_detection_model_xception()
                 if model is not None:
-                    st.session_state.clip_model = model
-                    st.session_state.clip_model_loaded = True
-                    st.success("✅ CLIP model loaded successfully!")
+                    st.session_state.xception_model = model
+                    st.session_state.device = device
+                    st.session_state.xception_model_loaded = True
+                    st.success("✅ Xception model loaded successfully!")
                 else:
-                    st.error("❌ Failed to load CLIP model.")
+                    st.error("❌ Failed to load Xception model.")
         else:
-            st.success("✅ CLIP model loaded and ready!")
+            st.success("✅ Xception model loaded and ready!")
 
     with blip_col:
         if not st.session_state.blip_model_loaded:
@@ -724,56 +541,56 @@ def main():
                     st.session_state.original_model
                 )
                 st.session_state.image_caption = caption
-
-                # Store caption but don't display it yet
 
-            # Detect with CLIP model if loaded
-            if st.session_state.clip_model_loaded:
-                with st.spinner("Analyzing image with CLIP model..."):
-                    # Preprocess image for CLIP
-                    transform = transforms.Compose([
-                        transforms.Resize((224, 224)),
-                        transforms.ToTensor(),
-                        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
-                    ])
+            # Detect with Xception model if loaded
+            if st.session_state.xception_model_loaded:
+                with st.spinner("Analyzing image with Xception model..."):
+                    # Preprocess image for Xception
+                    input_tensor, original_image, face_box = preprocess_image_xception(image)
 
-                    # Create a simple dataset for the image
-                    dataset = ImageDataset(image, transform=transform, face_only=True)
-                    tensor, _, _, _, face_box, _ = dataset[0]
-                    tensor = tensor.unsqueeze(0)
+                    # Get device and model
+                    device = st.session_state.device
+                    model = st.session_state.xception_model
 
-                    # Get device
-                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-                    # Move model and tensor to device
-                    model = st.session_state.clip_model.to(device)
-                    tensor = tensor.to(device)
+                    # Move tensor to device
+                    input_tensor = input_tensor.to(device)
 
                     # Forward pass
                     with torch.no_grad():
-                        outputs = model.vision_model(pixel_values=tensor).pooler_output
-                        logits = model.classification_head(outputs)
-                        probs = torch.softmax(logits, dim=1)[0]
-                        pred_class = torch.argmax(probs).item()
-                        confidence = probs[pred_class].item()
-                        pred_label = "Fake" if pred_class == 1 else "Real"
+                        logits = model(input_tensor)
+                        probabilities = torch.softmax(logits, dim=1)[0]
+                        pred_class = torch.argmax(probabilities).item()
+                        confidence = probabilities[pred_class].item()
+                        pred_label = "Fake" if pred_class == 0 else "Real"  # Check class mapping
 
                     # Display results
                     with col2:
                         st.markdown("### Detection Result")
                         st.markdown(f"**Classification:** {pred_label} (Confidence: {confidence:.2%})")
+
+                        # Display face box on image if detected
+                        if face_box:
+                            img_to_show = original_image.copy()
+                            img_draw = np.array(img_to_show)
+                            x, y, w, h = face_box
+                            cv2.rectangle(img_draw, (x, y), (x + w, y + h), (0, 255, 0), 2)
+                            st.image(Image.fromarray(img_draw), caption="Detected Face", width=300)
 
                     # GradCAM visualization
                     st.subheader("GradCAM Visualization")
-                    cam, overlay, comparison, detected_face_box = process_image_with_gradcam(
+                    cam, overlay, comparison, detected_face_box = process_image_with_xception_gradcam(
                         image, model, device, pred_class
                     )
 
-                    # Display GradCAM results (controlled size)
-                    st.image(comparison, caption="Original | CAM | Overlay", width=700)
+                    if comparison:
+                        # Display GradCAM results (controlled size)
+                        st.image(comparison, caption="Original | CAM | Overlay", width=700)
+
+                        # Save for later use
+                        st.session_state.comparison_image = comparison
 
                     # Generate caption for GradCAM overlay image if BLIP model is loaded
-                    if st.session_state.blip_model_loaded:
+                    if st.session_state.blip_model_loaded and overlay:
                         with st.spinner("Analyzing GradCAM visualization..."):
                             gradcam_caption = generate_gradcam_caption(
                                 overlay,
@@ -781,8 +598,6 @@ def main():
                                 st.session_state.finetuned_model
                             )
                             st.session_state.gradcam_caption = gradcam_caption
-
-                            # Store caption but don't display it yet
 
                     # Save results in session state for LLM analysis
                     st.session_state.current_image = image
@@ -793,7 +608,7 @@
 
                 st.success("✅ Initial detection and GradCAM visualization complete!")
             else:
-                st.warning("⚠️ Please load the CLIP model first to perform initial detection.")
+                st.warning("⚠️ Please load the Xception model first to perform initial detection.")
         except Exception as e:
             st.error(f"Error processing image: {str(e)}")
             import traceback
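Note on the "# Check class mapping" comment in the detection hunk above: the removed CLIP head treated index 1 as Fake, while the new Xception branch assumes index 0 is Fake. One way to make that assumption explicit and easy to audit (a sketch, not part of the commit; verify the order against the labels the checkpoint was trained with):

    # Assumed label order for the Xception checkpoint; confirm against training labels.
    XCEPTION_CLASS_NAMES = {0: "Fake", 1: "Real"}
    pred_label = XCEPTION_CLASS_NAMES[pred_class]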
@@ -930,7 +745,7 @@ def main():
     st.markdown("---")
 
     # Add model version indicator in sidebar
-    st.sidebar.info("Using deepfake-explainer-2 model")
+    st.sidebar.info("Using Xception + deepfake-explainer-2 models")
 
 if __name__ == "__main__":
     main()
 