saakshigupta committed on
Commit
4690e29
·
verified ·
1 Parent(s): 8919ca1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -122
app.py CHANGED
@@ -546,133 +546,175 @@ def main():
546
  # Image upload section
547
  with st.expander("Stage 2: Image Upload & Initial Detection", expanded=True):
548
  st.subheader("Upload an Image")
549
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
550
 
551
- if uploaded_file is not None:
552
- try:
553
- # Load and display the image (with controlled size)
554
- image = Image.open(uploaded_file).convert("RGB")
555
-
556
- # Display the image with a controlled width
557
- col1, col2 = st.columns([1, 2])
558
- with col1:
559
- st.image(image, caption="Uploaded Image", width=300)
560
-
561
- # Generate detailed caption for original image if BLIP model is loaded
562
- if st.session_state.blip_model_loaded:
563
- with st.spinner("Generating image description..."):
564
- caption = generate_image_caption(
565
- image,
566
- st.session_state.original_processor,
567
- st.session_state.original_model
568
- )
569
- st.session_state.image_caption = caption
570
-
571
- # Detect with Xception model if loaded
572
- if st.session_state.xception_model_loaded:
573
- try:
574
- with st.spinner("Analyzing image with Xception model..."):
575
- # Preprocess image for Xception
576
- st.write("Starting Xception processing...")
577
- input_tensor, original_image, face_box = preprocess_image_xception(image)
578
-
579
- # Get device and model
580
- device = st.session_state.device
581
- model = st.session_state.xception_model
582
-
583
- # Ensure model is in eval mode
584
- model.eval()
585
-
586
- # Move tensor to device
587
- input_tensor = input_tensor.to(device)
588
- st.write(f"Input tensor on device: {device}")
589
-
590
- # Forward pass with proper error handling
591
- try:
592
- with torch.no_grad():
593
- st.write("Running model inference...")
594
- logits = model(input_tensor)
595
- st.write(f"Raw logits: {logits}")
596
- probabilities = torch.softmax(logits, dim=1)[0]
597
- st.write(f"Probabilities: {probabilities}")
598
- pred_class = torch.argmax(probabilities).item()
599
- confidence = probabilities[pred_class].item()
600
- st.write(f"Predicted class: {pred_class}, Confidence: {confidence:.4f}")
601
-
602
- # Explicit class mapping - adjust if needed based on your model
603
- pred_label = "Fake" if pred_class == 0 else "Real"
604
- st.write(f"Mapped to label: {pred_label}")
605
- except Exception as e:
606
- st.error(f"Error in model inference: {str(e)}")
607
- import traceback
608
- st.error(traceback.format_exc())
609
- # Set default values
610
- pred_class = 0
611
- confidence = 0.5
612
- pred_label = "Error in prediction"
613
-
614
- # Display results
615
- with col2:
616
- st.markdown("### Detection Result")
617
- st.markdown(f"**Classification:** {pred_label} (Confidence: {confidence:.2%})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618
 
619
- # Display face box on image if detected
620
- if face_box:
621
- img_to_show = original_image.copy()
622
- img_draw = np.array(img_to_show)
623
- x, y, w, h = face_box
624
- cv2.rectangle(img_draw, (x, y), (x + w, y + h), (0, 255, 0), 2)
625
- st.image(Image.fromarray(img_draw), caption="Detected Face", width=300)
 
 
 
 
 
 
 
 
 
626
 
627
- # GradCAM visualization with error handling
628
- st.subheader("GradCAM Visualization")
629
- try:
630
- st.write("Generating GradCAM visualization...")
631
- cam, overlay, comparison, detected_face_box = process_image_with_xception_gradcam(
632
- image, model, device, pred_class
633
- )
634
-
635
- if comparison:
636
- # Display GradCAM results (controlled size)
637
- st.image(comparison, caption="Original | CAM | Overlay", width=700)
638
-
639
- # Save for later use
640
- st.session_state.comparison_image = comparison
641
- else:
642
- st.error("GradCAM visualization failed - comparison image not generated")
643
-
644
- # Generate caption for GradCAM overlay image if BLIP model is loaded
645
- if st.session_state.blip_model_loaded and overlay:
646
- with st.spinner("Analyzing GradCAM visualization..."):
647
- gradcam_caption = generate_gradcam_caption(
648
- overlay,
649
- st.session_state.finetuned_processor,
650
- st.session_state.finetuned_model
651
- )
652
- st.session_state.gradcam_caption = gradcam_caption
653
- except Exception as e:
654
- st.error(f"Error generating GradCAM: {str(e)}")
655
- import traceback
656
- st.error(traceback.format_exc())
657
 
658
- # Save results in session state for LLM analysis
659
- st.session_state.current_image = image
660
- st.session_state.current_overlay = overlay if 'overlay' in locals() else None
661
- st.session_state.current_face_box = detected_face_box if 'detected_face_box' in locals() else None
662
- st.session_state.current_pred_label = pred_label
663
- st.session_state.current_confidence = confidence
 
 
664
 
665
- st.success("✅ Initial detection and GradCAM visualization complete!")
666
- except Exception as e:
667
- st.error(f"Overall error in Xception processing: {str(e)}")
668
- import traceback
669
- st.error(traceback.format_exc())
670
- else:
671
- st.warning("⚠️ Please load the Xception model first to perform initial detection.")
672
- except Exception as e:
673
- st.error(f"Error processing image: {str(e)}")
674
- import traceback
675
- st.error(traceback.format_exc()) # This will show the full error traceback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
 
677
  # Image Analysis Summary section - AFTER Stage 2
678
  if hasattr(st.session_state, 'current_image') and (hasattr(st.session_state, 'image_caption') or hasattr(st.session_state, 'gradcam_caption')):
 
546
  # Image upload section
547
  with st.expander("Stage 2: Image Upload & Initial Detection", expanded=True):
548
  st.subheader("Upload an Image")
 
549
 
550
+ # Add alternative upload methods
551
+ upload_tab1, upload_tab2, upload_tab3 = st.tabs(["File Upload", "URL Input", "Sample Images"])
552
+
553
+ uploaded_image = None
554
+
555
+ with upload_tab1:
556
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
557
+ if uploaded_file is not None:
558
+ try:
559
+ # Try to load directly from bytes to avoid file system issues
560
+ file_bytes = uploaded_file.getvalue()
561
+ uploaded_image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
562
+ st.session_state.upload_method = "file"
563
+ except Exception as e:
564
+ st.error(f"Error loading file: {str(e)}")
565
+
566
+ with upload_tab2:
567
+ url = st.text_input("Enter image URL:")
568
+ if url and url.strip():
569
+ try:
570
+ import requests
571
+ response = requests.get(url, stream=True)
572
+ if response.status_code == 200:
573
+ uploaded_image = Image.open(io.BytesIO(response.content)).convert("RGB")
574
+ st.session_state.upload_method = "url"
575
+ else:
576
+ st.error(f"Failed to load image from URL: Status code {response.status_code}")
577
+ except Exception as e:
578
+ st.error(f"Error loading image from URL: {str(e)}")
579
+
580
+ with upload_tab3:
581
+ sample_images = {
582
+ "Real face sample": "https://raw.githubusercontent.com/deepfakes/faceswap/master/docs/full_guide/images/test_face.jpg",
583
+ "Likely deepfake sample": "https://storage.googleapis.com/deepfake-detection/example_deepfake.jpg",
584
+ "Neutral face sample": "https://t4.ftcdn.net/jpg/02/19/63/31/360_F_219633151_BW6TD8D1EA9OqZu4JgdmeJGg4JBaiAHj.jpg",
585
+ }
586
+ selected_sample = st.selectbox("Select a sample image:", list(sample_images.keys()))
587
+ if st.button("Load Sample"):
588
+ try:
589
+ import requests
590
+ response = requests.get(sample_images[selected_sample], stream=True)
591
+ if response.status_code == 200:
592
+ uploaded_image = Image.open(io.BytesIO(response.content)).convert("RGB")
593
+ st.session_state.upload_method = "sample"
594
+ else:
595
+ st.error(f"Failed to load sample image: Status code {response.status_code}")
596
+ except Exception as e:
597
+ st.error(f"Error loading sample image: {str(e)}")
598
+
599
+ # If we have an uploaded image, process it
600
+ if uploaded_image is not None:
601
+ # Display the image
602
+ image = uploaded_image
603
+ col1, col2 = st.columns([1, 2])
604
+ with col1:
605
+ st.image(image, caption="Uploaded Image", width=300)
606
+
607
+ # Generate detailed caption for original image if BLIP model is loaded
608
+ if st.session_state.blip_model_loaded:
609
+ with st.spinner("Generating image description..."):
610
+ caption = generate_image_caption(
611
+ image,
612
+ st.session_state.original_processor,
613
+ st.session_state.original_model
614
+ )
615
+ st.session_state.image_caption = caption
616
+
617
+ # Continue with your existing code for the Xception model analysis
618
+ if st.session_state.xception_model_loaded:
619
+ try:
620
+ with st.spinner("Analyzing image with Xception model..."):
621
+ # Preprocess image for Xception
622
+ st.write("Starting Xception processing...")
623
+ input_tensor, original_image, face_box = preprocess_image_xception(image)
624
+
625
+ # Get device and model
626
+ device = st.session_state.device
627
+ model = st.session_state.xception_model
628
+
629
+ # Ensure model is in eval mode
630
+ model.eval()
631
+
632
+ # Move tensor to device
633
+ input_tensor = input_tensor.to(device)
634
+ st.write(f"Input tensor on device: {device}")
635
+
636
+ # Forward pass with proper error handling
637
+ try:
638
+ with torch.no_grad():
639
+ st.write("Running model inference...")
640
+ logits = model(input_tensor)
641
+ st.write(f"Raw logits: {logits}")
642
+ probabilities = torch.softmax(logits, dim=1)[0]
643
+ st.write(f"Probabilities: {probabilities}")
644
+ pred_class = torch.argmax(probabilities).item()
645
+ confidence = probabilities[pred_class].item()
646
+ st.write(f"Predicted class: {pred_class}, Confidence: {confidence:.4f}")
647
 
648
+ # Explicit class mapping - adjust if needed based on your model
649
+ pred_label = "Fake" if pred_class == 0 else "Real"
650
+ st.write(f"Mapped to label: {pred_label}")
651
+ except Exception as e:
652
+ st.error(f"Error in model inference: {str(e)}")
653
+ import traceback
654
+ st.error(traceback.format_exc())
655
+ # Set default values
656
+ pred_class = 0
657
+ confidence = 0.5
658
+ pred_label = "Error in prediction"
659
+
660
+ # Display results
661
+ with col2:
662
+ st.markdown("### Detection Result")
663
+ st.markdown(f"**Classification:** {pred_label} (Confidence: {confidence:.2%})")
664
 
665
+ # Display face box on image if detected
666
+ if face_box:
667
+ img_to_show = original_image.copy()
668
+ img_draw = np.array(img_to_show)
669
+ x, y, w, h = face_box
670
+ cv2.rectangle(img_draw, (x, y), (x + w, y + h), (0, 255, 0), 2)
671
+ st.image(Image.fromarray(img_draw), caption="Detected Face", width=300)
672
+
673
+ # GradCAM visualization with error handling
674
+ st.subheader("GradCAM Visualization")
675
+ try:
676
+ st.write("Generating GradCAM visualization...")
677
+ cam, overlay, comparison, detected_face_box = process_image_with_xception_gradcam(
678
+ image, model, device, pred_class
679
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
680
 
681
+ if comparison:
682
+ # Display GradCAM results (controlled size)
683
+ st.image(comparison, caption="Original | CAM | Overlay", width=700)
684
+
685
+ # Save for later use
686
+ st.session_state.comparison_image = comparison
687
+ else:
688
+ st.error("GradCAM visualization failed - comparison image not generated")
689
 
690
+ # Generate caption for GradCAM overlay image if BLIP model is loaded
691
+ if st.session_state.blip_model_loaded and overlay:
692
+ with st.spinner("Analyzing GradCAM visualization..."):
693
+ gradcam_caption = generate_gradcam_caption(
694
+ overlay,
695
+ st.session_state.finetuned_processor,
696
+ st.session_state.finetuned_model
697
+ )
698
+ st.session_state.gradcam_caption = gradcam_caption
699
+ except Exception as e:
700
+ st.error(f"Error generating GradCAM: {str(e)}")
701
+ import traceback
702
+ st.error(traceback.format_exc())
703
+
704
+ # Save results in session state for LLM analysis
705
+ st.session_state.current_image = image
706
+ st.session_state.current_overlay = overlay if 'overlay' in locals() else None
707
+ st.session_state.current_face_box = detected_face_box if 'detected_face_box' in locals() else None
708
+ st.session_state.current_pred_label = pred_label
709
+ st.session_state.current_confidence = confidence
710
+
711
+ st.success("✅ Initial detection and GradCAM visualization complete!")
712
+ except Exception as e:
713
+ st.error(f"Overall error in Xception processing: {str(e)}")
714
+ import traceback
715
+ st.error(traceback.format_exc())
716
+ else:
717
+ st.warning("⚠️ Please load the Xception model first to perform initial detection.")
718
 
719
  # Image Analysis Summary section - AFTER Stage 2
720
  if hasattr(st.session_state, 'current_image') and (hasattr(st.session_state, 'image_caption') or hasattr(st.session_state, 'gradcam_caption')):