Update app.py
Browse files
app.py
CHANGED
@@ -546,133 +546,175 @@ def main():
|
|
546 |
# Image upload section
|
547 |
with st.expander("Stage 2: Image Upload & Initial Detection", expanded=True):
|
548 |
st.subheader("Upload an Image")
|
549 |
-
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
550 |
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
618 |
|
619 |
-
#
|
620 |
-
if
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
626 |
|
627 |
-
#
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
)
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
st.error("GradCAM visualization failed - comparison image not generated")
|
643 |
-
|
644 |
-
# Generate caption for GradCAM overlay image if BLIP model is loaded
|
645 |
-
if st.session_state.blip_model_loaded and overlay:
|
646 |
-
with st.spinner("Analyzing GradCAM visualization..."):
|
647 |
-
gradcam_caption = generate_gradcam_caption(
|
648 |
-
overlay,
|
649 |
-
st.session_state.finetuned_processor,
|
650 |
-
st.session_state.finetuned_model
|
651 |
-
)
|
652 |
-
st.session_state.gradcam_caption = gradcam_caption
|
653 |
-
except Exception as e:
|
654 |
-
st.error(f"Error generating GradCAM: {str(e)}")
|
655 |
-
import traceback
|
656 |
-
st.error(traceback.format_exc())
|
657 |
|
658 |
-
|
659 |
-
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
|
|
|
|
664 |
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
676 |
|
677 |
# Image Analysis Summary section - AFTER Stage 2
|
678 |
if hasattr(st.session_state, 'current_image') and (hasattr(st.session_state, 'image_caption') or hasattr(st.session_state, 'gradcam_caption')):
|
|
|
# Image upload section
# Stage 2 of the app: let the user supply an image (file / URL / sample),
# show it, caption it with BLIP (if loaded), classify it with the Xception
# deepfake detector, and render a GradCAM explanation. Results are stashed
# in st.session_state for the later LLM-analysis stage.
with st.expander("Stage 2: Image Upload & Initial Detection", expanded=True):
    st.subheader("Upload an Image")

    # Add alternative upload methods
    upload_tab1, upload_tab2, upload_tab3 = st.tabs(["File Upload", "URL Input", "Sample Images"])

    uploaded_image = None  # PIL.Image set by whichever tab succeeds

    with upload_tab1:
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            try:
                # Try to load directly from bytes to avoid file system issues
                file_bytes = uploaded_file.getvalue()
                uploaded_image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
                st.session_state.upload_method = "file"
            except Exception as e:
                st.error(f"Error loading file: {str(e)}")

    with upload_tab2:
        url = st.text_input("Enter image URL:")
        if url and url.strip():
            try:
                import requests
                # timeout so a dead/slow URL cannot hang the Streamlit run
                response = requests.get(url, stream=True, timeout=10)
                if response.status_code == 200:
                    uploaded_image = Image.open(io.BytesIO(response.content)).convert("RGB")
                    st.session_state.upload_method = "url"
                else:
                    st.error(f"Failed to load image from URL: Status code {response.status_code}")
            except Exception as e:
                st.error(f"Error loading image from URL: {str(e)}")

    with upload_tab3:
        # NOTE(review): these sample URLs are external and may rot — verify periodically.
        sample_images = {
            "Real face sample": "https://raw.githubusercontent.com/deepfakes/faceswap/master/docs/full_guide/images/test_face.jpg",
            "Likely deepfake sample": "https://storage.googleapis.com/deepfake-detection/example_deepfake.jpg",
            "Neutral face sample": "https://t4.ftcdn.net/jpg/02/19/63/31/360_F_219633151_BW6TD8D1EA9OqZu4JgdmeJGg4JBaiAHj.jpg",
        }
        selected_sample = st.selectbox("Select a sample image:", list(sample_images.keys()))
        if st.button("Load Sample"):
            try:
                import requests
                response = requests.get(sample_images[selected_sample], stream=True, timeout=10)
                if response.status_code == 200:
                    uploaded_image = Image.open(io.BytesIO(response.content)).convert("RGB")
                    st.session_state.upload_method = "sample"
                else:
                    st.error(f"Failed to load sample image: Status code {response.status_code}")
            except Exception as e:
                st.error(f"Error loading sample image: {str(e)}")

    # If we have an uploaded image, process it
    if uploaded_image is not None:
        # Display the image
        image = uploaded_image
        col1, col2 = st.columns([1, 2])
        with col1:
            st.image(image, caption="Uploaded Image", width=300)

        # Generate detailed caption for original image if BLIP model is loaded
        if st.session_state.blip_model_loaded:
            with st.spinner("Generating image description..."):
                caption = generate_image_caption(
                    image,
                    st.session_state.original_processor,
                    st.session_state.original_model
                )
                st.session_state.image_caption = caption

        # Xception model analysis (initial real/fake detection + GradCAM)
        if st.session_state.xception_model_loaded:
            try:
                with st.spinner("Analyzing image with Xception model..."):
                    # Preprocess image for Xception
                    st.write("Starting Xception processing...")
                    input_tensor, original_image, face_box = preprocess_image_xception(image)

                    # Get device and model
                    device = st.session_state.device
                    model = st.session_state.xception_model

                    # Ensure model is in eval mode
                    model.eval()

                    # Move tensor to device
                    input_tensor = input_tensor.to(device)
                    st.write(f"Input tensor on device: {device}")

                    # Forward pass with proper error handling
                    try:
                        with torch.no_grad():
                            st.write("Running model inference...")
                            logits = model(input_tensor)
                            st.write(f"Raw logits: {logits}")
                            probabilities = torch.softmax(logits, dim=1)[0]
                            st.write(f"Probabilities: {probabilities}")
                            pred_class = torch.argmax(probabilities).item()
                            confidence = probabilities[pred_class].item()
                            st.write(f"Predicted class: {pred_class}, Confidence: {confidence:.4f}")

                        # Explicit class mapping - adjust if needed based on your model
                        # (assumes class 0 == "Fake" — TODO confirm against training labels)
                        pred_label = "Fake" if pred_class == 0 else "Real"
                        st.write(f"Mapped to label: {pred_label}")
                    except Exception as e:
                        st.error(f"Error in model inference: {str(e)}")
                        import traceback
                        st.error(traceback.format_exc())
                        # Set default values so the rest of the UI can still render
                        pred_class = 0
                        confidence = 0.5
                        pred_label = "Error in prediction"

                    # Display results
                    with col2:
                        st.markdown("### Detection Result")
                        st.markdown(f"**Classification:** {pred_label} (Confidence: {confidence:.2%})")

                        # Display face box on image if detected.
                        # Explicit None-check: face_box may be an array-like whose
                        # truthiness is ambiguous.
                        if face_box is not None:
                            img_to_show = original_image.copy()
                            img_draw = np.array(img_to_show)
                            x, y, w, h = face_box
                            cv2.rectangle(img_draw, (x, y), (x + w, y + h), (0, 255, 0), 2)
                            st.image(Image.fromarray(img_draw), caption="Detected Face", width=300)

                    # GradCAM visualization with error handling
                    st.subheader("GradCAM Visualization")
                    # Pre-initialize so the session-state save below is safe even
                    # if GradCAM raises before these are assigned (replaces the
                    # fragile "'overlay' in locals()" pattern).
                    overlay = None
                    detected_face_box = None
                    try:
                        st.write("Generating GradCAM visualization...")
                        cam, overlay, comparison, detected_face_box = process_image_with_xception_gradcam(
                            image, model, device, pred_class
                        )

                        if comparison is not None:
                            # Display GradCAM results (controlled size)
                            st.image(comparison, caption="Original | CAM | Overlay", width=700)

                            # Save for later use
                            st.session_state.comparison_image = comparison
                        else:
                            st.error("GradCAM visualization failed - comparison image not generated")

                        # Generate caption for GradCAM overlay image if BLIP model is loaded
                        if st.session_state.blip_model_loaded and overlay is not None:
                            with st.spinner("Analyzing GradCAM visualization..."):
                                gradcam_caption = generate_gradcam_caption(
                                    overlay,
                                    st.session_state.finetuned_processor,
                                    st.session_state.finetuned_model
                                )
                                st.session_state.gradcam_caption = gradcam_caption
                    except Exception as e:
                        st.error(f"Error generating GradCAM: {str(e)}")
                        import traceback
                        st.error(traceback.format_exc())

                    # Save results in session state for LLM analysis
                    st.session_state.current_image = image
                    st.session_state.current_overlay = overlay
                    st.session_state.current_face_box = detected_face_box
                    st.session_state.current_pred_label = pred_label
                    st.session_state.current_confidence = confidence

                    st.success("✅ Initial detection and GradCAM visualization complete!")
            except Exception as e:
                st.error(f"Overall error in Xception processing: {str(e)}")
                import traceback
                st.error(traceback.format_exc())
        else:
            st.warning("⚠️ Please load the Xception model first to perform initial detection.")
719 |
# Image Analysis Summary section - AFTER Stage 2
|
720 |
if hasattr(st.session_state, 'current_image') and (hasattr(st.session_state, 'image_caption') or hasattr(st.session_state, 'gradcam_caption')):
|