Update app.py
app.py CHANGED
@@ -4,6 +4,11 @@ import os
 import tempfile
 # First load unsloth
 from unsloth import FastVisionModel
+# Add configuration to fix TorchDynamo issues
+import torch
+torch._dynamo.config.capture_scalar_outputs = True
+# Set a reasonable optimization level
+torch._dynamo.config.opt_level = "default"
 # Then transformers
 from transformers import BlipProcessor, BlipForConditionalGeneration
 import torch
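A note on the configuration block above: `capture_scalar_outputs` is a real `torch._dynamo.config` flag that lets compiled graphs keep scalar reads such as `Tensor.item()` instead of breaking the graph, while `opt_level` does not appear in recent PyTorch releases' dynamo config, so that line may raise or be ignored depending on the installed version. The added `import torch` also duplicates the existing import a few lines below, which is harmless. A minimal sketch of what the first flag buys, separate from app.py and with an illustrative function name:

    import torch

    # Allow compiled code to read scalars out of tensors (e.g. Tensor.item())
    # without forcing a graph break.
    torch._dynamo.config.capture_scalar_outputs = True

    def scale_by_sum(x: torch.Tensor) -> torch.Tensor:
        n = x.sum().item()  # scalar read; graph-breaks when the flag is False
        return x * n

    compiled = torch.compile(scale_by_sum)
    print(compiled(torch.ones(3)))  # tensor([3., 3., 3.])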
@@ -552,16 +557,32 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
         # Fix cross-attention mask if needed
         inputs = fix_cross_attention_mask(inputs)
 
-        # Generate response
+        # Generate response with error handling
         with st.spinner("Generating detailed analysis... (this may take 15-30 seconds)"):
-            with torch.no_grad():
-                output_ids = model.generate(
-                    **inputs,
-                    max_new_tokens=max_tokens,
-                    use_cache=True,
-                    temperature=temperature,
-                    top_p=0.9
-                )
+            try:
+                # First try with dynamic compilation (default)
+                with torch.no_grad():
+                    output_ids = model.generate(
+                        **inputs,
+                        max_new_tokens=max_tokens,
+                        use_cache=True,
+                        temperature=temperature,
+                        top_p=0.9
+                    )
+            except Exception as dynamo_error:
+                st.warning(f"Encountered optimization error, falling back to eager mode: {str(dynamo_error)}")
+
+                # Try again with dynamo disabled
+                with torch.no_grad():
+                    # Temporarily disable dynamo
+                    with torch._dynamo.disable():
+                        output_ids = model.generate(
+                            **inputs,
+                            max_new_tokens=max_tokens,
+                            use_cache=True,
+                            temperature=temperature,
+                            top_p=0.9
+                        )
 
         # Decode the output
         response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
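The hunk above wraps generation in a try/except and retries under `torch._dynamo.disable()` when the compiled path raises. The same pattern can be factored into a helper so the duplicated argument lists cannot drift apart. A minimal sketch, assuming `model` and `inputs` follow the transformers `generate` convention; the helper name is illustrative, not part of app.py:

    import torch

    def generate_with_fallback(model, inputs, **gen_kwargs):
        try:
            # First attempt: run with whatever compilation is currently active.
            with torch.no_grad():
                return model.generate(**inputs, **gen_kwargs)
        except Exception:
            # Retry with TorchDynamo disabled for this call. torch._dynamo.disable
            # wraps a callable and disables compilation recursively by default.
            eager_generate = torch._dynamo.disable(model.generate)
            with torch.no_grad():
                return eager_generate(**inputs, **gen_kwargs)

Called as generate_with_fallback(model, inputs, max_new_tokens=max_tokens, use_cache=True, temperature=temperature, top_p=0.9), it approximates the behavior of the hunk in one place.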
@@ -576,7 +597,43 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
 
     except Exception as e:
         st.error(f"Error during LLM analysis: {str(e)}")
-
+
+        # Try one more time with all optimizations disabled
+        try:
+            st.info("Attempting fallback with all optimizations disabled...")
+            with torch.no_grad():
+                with torch._dynamo.disable():
+                    # Prepare a simpler prompt
+                    simple_message = [{"role": "user", "content": [
+                        {"type": "text", "text": "Analyze this image and tell if it's a deepfake."}
+                    ]}]
+                    simple_image = image
+
+                    # Apply simpler template
+                    simple_text = tokenizer.apply_chat_template(simple_message, add_generation_prompt=True)
+
+                    # Tokenize with just the image
+                    simple_inputs = tokenizer(
+                        simple_image,
+                        simple_text,
+                        add_special_tokens=False,
+                        return_tensors="pt",
+                    ).to(model.device)
+
+                    # Generate with minimal settings
+                    output_ids = model.generate(
+                        **simple_inputs,
+                        max_new_tokens=200,
+                        use_cache=True,
+                        temperature=0.5,
+                        top_p=0.9
+                    )
+
+                    # Decode
+                    fallback_response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+                    return "Error with optimized generation. Fallback analysis: " + fallback_response.split("Analyze this image and tell if it's a deepfake.")[-1].strip()
+        except Exception as fallback_error:
+            return f"Error analyzing image. Primary error: {str(e)}\nFallback error: {str(fallback_error)}"
 
 # Preprocess image for Xception
 def preprocess_image_xception(image):
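One fragile spot in the last-resort fallback: it recovers the answer by splitting the decoded string on the literal prompt text, which returns the whole decoded output whenever the chat template rewrites or wraps that text. A sturdier alternative, sketched here under the assumption that simple_inputs["input_ids"] holds the prompt tokens (the function name is illustrative, not part of app.py), is to slice the prompt off by token count before decoding:

    def decode_new_tokens(tokenizer, output_ids, input_ids):
        # Keep only the tokens generated after the prompt.
        prompt_len = input_ids.shape[1]
        return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

In the fallback above this would be decode_new_tokens(tokenizer, output_ids, simple_inputs["input_ids"]).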