import streamlit as st
import torch
from PIL import Image
import gc
from transformers import AutoProcessor, MllamaForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel

# Page config
st.set_page_config(
    page_title="Deepfake Image Analyzer",
    page_icon="🔍",
    layout="wide"
)

# App title and description
st.title("Deepfake Image Analyzer")
st.markdown("Upload an image to analyze it for possible deepfake manipulation")

# Function to free up memory
def free_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Helper function to check CUDA
def init_device():
    if torch.cuda.is_available():
        st.sidebar.success("✓ GPU available: Using CUDA")
        return "cuda"
    else:
        st.sidebar.warning("⚠️ No GPU detected: Using CPU (analysis will be slow)")
        return "cpu"

# Set device
device = init_device()

@st.cache_resource
def load_model():
    """Load model with proper quantization handling"""
    try:
        # Base model: 4-bit quantized Llama 3.2 Vision
        base_model_id = "unsloth/llama-3.2-11b-vision-instruct-unsloth-bnb-4bit"

        # Load processor
        processor = AutoProcessor.from_pretrained(base_model_id)

        # Llama 3.2 Vision is an Mllama model, so use the conditional-generation
        # class: AutoModelForCausalLM would load only the text decoder and the
        # image inputs (pixel_values, cross_attention_mask) would fail
        model = MllamaForConditionalGeneration.from_pretrained(
            base_model_id,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            ),
            torch_dtype=torch.float16
        )

        # Load the fine-tuned LoRA adapter on top of the base model
        adapter_id = "saakshigupta/deepfake-explainer-1"
        model = PeftModel.from_pretrained(model, adapter_id)

        return model, processor
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.exception(e)
        return None, None

# Function to fix cross-attention masks
def fix_processor_outputs(inputs):
    """Fix cross-attention mask dimensions if needed"""
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404  # The exact dimension used in training
        new_mask = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device
        )
        inputs['cross_attention_mask'] = new_mask
        return True, inputs
    return False, inputs

# Create sidebar with options
with st.sidebar:
    st.header("Options")

    temperature = st.slider(
        "Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
        help="Higher values make output more random, lower values more deterministic"
    )

    max_length = st.slider("Maximum response length", min_value=100, max_value=1000, value=500, step=50)

    custom_prompt = st.text_area(
        "Custom instruction (optional)",
        value="Analyze this image and determine if it's a deepfake. Provide both technical and non-technical explanations.",
        height=100
    )

    st.markdown("### About")
    st.markdown("""
    This app uses a fine-tuned Llama 3.2 Vision model to detect and explain deepfakes.

    The analyzer looks for:
    - Inconsistencies in facial features
    - Unusual lighting or shadows
    - Unnatural blur patterns
    - Artifacts around edges
    - Texture inconsistencies

    Model by [saakshigupta](https://huggingface.co/saakshigupta/deepfake-explainer-1)
    """)

# Load model button
if st.button("Load Model"):
    with st.spinner("Loading model... this may take several minutes"):
        try:
            model, processor = load_model()
            if model is not None and processor is not None:
                st.session_state['model'] = model
                st.session_state['processor'] = processor
                st.success("Model loaded successfully!")
            else:
                st.error("Failed to load model.")
        except Exception as e:
            st.error(f"Error during model loading: {str(e)}")
            st.exception(e)

# Main content area - file uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# Check if model is loaded
model_loaded = 'model' in st.session_state and st.session_state['model'] is not None

if uploaded_file is not None:
    # Display the image
    image = Image.open(uploaded_file).convert('RGB')
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Analyze button (only enabled if model is loaded)
    if st.button("Analyze Image", disabled=not model_loaded):
        if not model_loaded:
            st.warning("Please load the model first by clicking the 'Load Model' button.")
        else:
            with st.spinner("Analyzing the image... This may take 15-30 seconds"):
                try:
                    # Get components from session state
                    model = st.session_state['model']
                    processor = st.session_state['processor']

                    # The Mllama processor expects an <|image|> placeholder in the
                    # text marking where the image attends; without it the
                    # cross-attention mask is not built correctly
                    prompt = f"<|image|>{custom_prompt}"

                    # Process the image and prompt together
                    inputs = processor(text=prompt, images=image, return_tensors="pt")

                    # Fix cross-attention mask if needed
                    fixed, inputs = fix_processor_outputs(inputs)
                    if fixed:
                        st.info("Fixed cross-attention mask dimensions")

                    # Move tensors to the model's device
                    inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

                    # Generate the analysis (do_sample=True so the temperature
                    # and top_p settings actually take effect)
                    with torch.no_grad():
                        output_ids = model.generate(
                            **inputs,
                            max_new_tokens=max_length,
                            do_sample=True,
                            temperature=temperature,
                            top_p=0.9
                        )

                    # Decode the output
                    response = processor.decode(output_ids[0], skip_special_tokens=True)

                    # Extract the actual response (removing the prompt)
                    if custom_prompt in response:
                        result = response.split(custom_prompt)[-1].strip()
                    else:
                        result = response

                    # Display result in a nice format
                    st.success("Analysis complete!")

                    # Show technical and non-technical explanations separately if they exist
                    if "Technical Explanation:" in result and "Non-Technical Explanation:" in result:
                        technical, non_technical = result.split("Non-Technical Explanation:")
                        technical = technical.replace("Technical Explanation:", "").strip()

                        col1, col2 = st.columns(2)
                        with col1:
                            st.subheader("Technical Analysis")
                            st.write(technical)
                        with col2:
                            st.subheader("Simple Explanation")
                            st.write(non_technical)
                    else:
                        st.subheader("Analysis Result")
                        st.write(result)

                    # Free memory after analysis
                    free_memory()

                except Exception as e:
                    st.error(f"Error analyzing image: {str(e)}")
                    st.exception(e)
elif not model_loaded:
    st.warning("Please load the model first by clicking the 'Load Model' button at the top of the page.")
else:
    st.info("Please upload an image to begin analysis")

# Add footer
st.markdown("---")
st.markdown("Deepfake Image Analyzer")
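
# ---------------------------------------------------------------------------
# Environment note: a minimal sketch of the dependencies this app assumes
# (the script itself does not pin versions, so treat these as assumptions):
# streamlit, torch, transformers, peft, bitsandbytes (for BitsAndBytesConfig),
# accelerate (for device_map="auto"), and pillow. Save the file as app.py
# (hypothetical name) and launch with:
#
#   pip install streamlit torch transformers peft bitsandbytes accelerate pillow
#   streamlit run app.py
# ---------------------------------------------------------------------------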