saakshigupta commited on
Commit
cf71d82
·
verified ·
1 Parent(s): a396d49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -16
app.py CHANGED
@@ -26,16 +26,7 @@ from peft import PeftModel
26
  from gradcam_xception import generate_smoothgrad_visualizations_xception
27
  warnings.filterwarnings("ignore", category=UserWarning)
28
 
29
- # Define Xception transform function directly in app.py
30
- def get_xception_transform():
31
- """Get the image transformation pipeline for Xception input."""
32
- # Standard Xception preprocessing
33
- transform = transforms.Compose([
34
- transforms.Resize((299, 299)),
35
- transforms.ToTensor(),
36
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
37
- ])
38
- return transform
39
 
40
  # App title and description
41
  st.set_page_config(
@@ -357,18 +348,32 @@ class ImageDataset(torch.utils.data.Dataset):
357
  face_img = Image.fromarray(face_img_np)
358
 
359
  # Apply transform to face image
 
360
  if self.transform:
361
  face_tensor = self.transform(face_img)
362
  else:
363
- face_tensor = transforms.ToTensor()(face_img)
 
 
 
 
 
 
364
 
365
  return face_tensor, label, "uploaded_image", original_image, face_box, self.dataset_name
366
  else:
367
  # Process the whole image
 
368
  if self.transform:
369
  image_tensor = self.transform(self.image)
370
  else:
371
- image_tensor = transforms.ToTensor()(self.image)
 
 
 
 
 
 
372
 
373
  return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name
374
 
@@ -387,7 +392,12 @@ def process_image_with_xception_gradcam(image, model, device, pred_class):
387
  raw_cam, cam_img, overlay, comparison = cam_results
388
 
389
  # Extract the face box from the dataset if needed
390
- transform = get_xception_transform()
 
 
 
 
 
391
  dataset = ImageDataset(image, transform=transform, face_only=True)
392
  _, _, _, _, face_box, _ = dataset[0]
393
 
@@ -399,11 +409,12 @@ def process_image_with_xception_gradcam(image, model, device, pred_class):
399
  # ----- Xception Model Loading -----
400
  @st.cache_resource
401
  def load_detection_model_xception():
402
- """Loads the Xception model from our module"""
403
  with st.spinner("Loading Xception model for deepfake detection..."):
404
  try:
405
  log_debug("Beginning Xception model loading")
406
  from gradcam_xception import load_xception_model
 
407
  model = load_xception_model()
408
 
409
  # Get the device
@@ -412,7 +423,7 @@ def load_detection_model_xception():
412
 
413
  model.to(device)
414
  model.eval()
415
- log_debug("Xception model loaded successfully")
416
  return model, device
417
  except ImportError as e:
418
  st.error(f"Import Error: {str(e)}. Make sure gradcam_xception.py is present.")
@@ -695,9 +706,15 @@ def preprocess_image_xception(image):
695
  face_img_for_transform = Image.fromarray(image_np[y1:y2, x1:x2])
696
  # Use the original detected box (without padding) for display rectangle
697
  face_box_display = (x, y, w, h)
 
698
 
699
  # Xception specific transform
700
- transform = get_xception_transform()
 
 
 
 
 
701
  # Apply transform to the selected region (face or whole image)
702
  input_tensor = transform(face_img_for_transform).unsqueeze(0)
703
 
 
26
  from gradcam_xception import generate_smoothgrad_visualizations_xception
27
  warnings.filterwarnings("ignore", category=UserWarning)
28
 
29
+ # Xception transform is now defined directly in preprocess_image_xception
 
 
 
 
 
 
 
 
 
30
 
31
  # App title and description
32
  st.set_page_config(
 
348
  face_img = Image.fromarray(face_img_np)
349
 
350
  # Apply transform to face image
351
+ IMAGE_SIZE = 299
352
  if self.transform:
353
  face_tensor = self.transform(face_img)
354
  else:
355
+ # Use default transform if none provided
356
+ transform = transforms.Compose([
357
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
358
+ transforms.ToTensor(),
359
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
360
+ ])
361
+ face_tensor = transform(face_img)
362
 
363
  return face_tensor, label, "uploaded_image", original_image, face_box, self.dataset_name
364
  else:
365
  # Process the whole image
366
+ IMAGE_SIZE = 299
367
  if self.transform:
368
  image_tensor = self.transform(self.image)
369
  else:
370
+ # Use default transform if none provided
371
+ transform = transforms.Compose([
372
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
373
+ transforms.ToTensor(),
374
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
375
+ ])
376
+ image_tensor = transform(self.image)
377
 
378
  return image_tensor, label, "uploaded_image", original_image, None, self.dataset_name
379
 
 
392
  raw_cam, cam_img, overlay, comparison = cam_results
393
 
394
  # Extract the face box from the dataset if needed
395
+ IMAGE_SIZE = 299
396
+ transform = transforms.Compose([
397
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
398
+ transforms.ToTensor(),
399
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
400
+ ])
401
  dataset = ImageDataset(image, transform=transform, face_only=True)
402
  _, _, _, _, face_box, _ = dataset[0]
403
 
 
409
  # ----- Xception Model Loading -----
410
  @st.cache_resource
411
  def load_detection_model_xception():
412
+ """Loads the Xception model from HF Hub."""
413
  with st.spinner("Loading Xception model for deepfake detection..."):
414
  try:
415
  log_debug("Beginning Xception model loading")
416
  from gradcam_xception import load_xception_model
417
+ log_debug("Loading Xception model (this may take a moment)...")
418
  model = load_xception_model()
419
 
420
  # Get the device
 
423
 
424
  model.to(device)
425
  model.eval()
426
+ log_debug(f"Xception model loaded to {device}.")
427
  return model, device
428
  except ImportError as e:
429
  st.error(f"Import Error: {str(e)}. Make sure gradcam_xception.py is present.")
 
706
  face_img_for_transform = Image.fromarray(image_np[y1:y2, x1:x2])
707
  # Use the original detected box (without padding) for display rectangle
708
  face_box_display = (x, y, w, h)
709
+ log_debug(f"Face detected: Box {face_box_display}")
710
 
711
  # Xception specific transform
712
+ IMAGE_SIZE = 299
713
+ transform = transforms.Compose([
714
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
715
+ transforms.ToTensor(),
716
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # Standard Xception norm
717
+ ])
718
  # Apply transform to the selected region (face or whole image)
719
  input_tensor = transform(face_img_for_transform).unsqueeze(0)
720