saakshigupta committed
Commit 8c6db4b · verified · 1 Parent(s): 7dc4d76

Update app.py

Files changed (1)
  1. app.py +67 -10
app.py CHANGED
@@ -4,6 +4,11 @@ import os
 import tempfile
 # First load unsloth
 from unsloth import FastVisionModel
+# Add configuration to fix TorchDynamo issues
+import torch
+torch._dynamo.config.capture_scalar_outputs = True
+# Set a reasonable optimization level
+torch._dynamo.config.opt_level = "default"
 # Then transformers
 from transformers import BlipProcessor, BlipForConditionalGeneration
 import torch
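
A note on the new Dynamo settings: capture_scalar_outputs is a documented torch._dynamo.config flag (it lets Dynamo trace through tensor.item() calls instead of graph-breaking on them), but opt_level is not a documented config field in current PyTorch releases, and assigning an unknown name on the config module can raise AttributeError at import time. A minimal defensive sketch, assuming only stock PyTorch:

    import torch

    # Documented flag: trace through tensor.item() instead of graph-breaking.
    torch._dynamo.config.capture_scalar_outputs = True

    # "opt_level" is an assumption here, not a standard Dynamo config field;
    # guard it so the app still imports on versions where it does not exist.
    if hasattr(torch._dynamo.config, "opt_level"):
        torch._dynamo.config.opt_level = "default"

Also note that import torch now appears twice in this hunk (new lines 8 and 14); the second import is harmless but redundant.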
@@ -552,16 +557,32 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
         # Fix cross-attention mask if needed
         inputs = fix_cross_attention_mask(inputs)
 
-        # Generate response
+        # Generate response with error handling
         with st.spinner("Generating detailed analysis... (this may take 15-30 seconds)"):
-            with torch.no_grad():
-                output_ids = model.generate(
-                    **inputs,
-                    max_new_tokens=max_tokens,
-                    use_cache=True,
-                    temperature=temperature,
-                    top_p=0.9
-                )
+            try:
+                # First try with dynamic compilation (default)
+                with torch.no_grad():
+                    output_ids = model.generate(
+                        **inputs,
+                        max_new_tokens=max_tokens,
+                        use_cache=True,
+                        temperature=temperature,
+                        top_p=0.9
+                    )
+            except Exception as dynamo_error:
+                st.warning(f"Encountered optimization error, falling back to eager mode: {str(dynamo_error)}")
+
+                # Try again with dynamo disabled
+                with torch.no_grad():
+                    # Temporarily disable dynamo
+                    with torch._dynamo.disable():
+                        output_ids = model.generate(
+                            **inputs,
+                            max_new_tokens=max_tokens,
+                            use_cache=True,
+                            temperature=temperature,
+                            top_p=0.9
+                        )
 
         # Decode the output
         response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
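
torch._dynamo.disable is documented primarily as a function wrapper/decorator, and whether with torch._dynamo.disable(): works as a context manager depends on the PyTorch version, so the fallback path above can itself fail on some releases. A version-portable sketch of the same retry, assuming model, inputs, max_tokens, temperature, and st (streamlit) from the surrounding function, with _generate_eager as a hypothetical local helper:

    import torch

    @torch._dynamo.disable  # decorator form: run this helper eagerly, skipping compilation
    def _generate_eager(model, inputs, max_tokens, temperature):
        with torch.no_grad():
            return model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                use_cache=True,
                temperature=temperature,
                top_p=0.9,
            )

    try:
        # First attempt: whatever compiled path unsloth configured.
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                use_cache=True,
                temperature=temperature,
                top_p=0.9,
            )
    except Exception as dynamo_error:
        st.warning(f"Optimization error, retrying in eager mode: {dynamo_error}")
        output_ids = _generate_eager(model, inputs, max_tokens, temperature)

One more caveat: unless the model's generation config enables sampling (do_sample=True), generate ignores temperature and top_p, and recent transformers versions emit a warning about exactly this combination.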
@@ -576,7 +597,43 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
 
     except Exception as e:
         st.error(f"Error during LLM analysis: {str(e)}")
-        return f"Error analyzing image: {str(e)}"
+
+        # Try one more time with all optimizations disabled
+        try:
+            st.info("Attempting fallback with all optimizations disabled...")
+            with torch.no_grad():
+                with torch._dynamo.disable():
+                    # Prepare a simpler prompt
+                    simple_message = [{"role": "user", "content": [
+                        {"type": "text", "text": "Analyze this image and tell if it's a deepfake."}
+                    ]}]
+                    simple_image = image
+
+                    # Apply simpler template
+                    simple_text = tokenizer.apply_chat_template(simple_message, add_generation_prompt=True)
+
+                    # Tokenize with just the image
+                    simple_inputs = tokenizer(
+                        simple_image,
+                        simple_text,
+                        add_special_tokens=False,
+                        return_tensors="pt",
+                    ).to(model.device)
+
+                    # Generate with minimal settings
+                    output_ids = model.generate(
+                        **simple_inputs,
+                        max_new_tokens=200,
+                        use_cache=True,
+                        temperature=0.5,
+                        top_p=0.9
+                    )
+
+                    # Decode
+                    fallback_response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+                    return "Error with optimized generation. Fallback analysis: " + fallback_response.split("Analyze this image and tell if it's a deepfake.")[-1].strip()
+        except Exception as fallback_error:
+            return f"Error analyzing image. Primary error: {str(e)}\nFallback error: {str(fallback_error)}"
 
 # Preprocess image for Xception
 def preprocess_image_xception(image):
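
One likely bug in this last-resort path: simple_message contains only a text part, so apply_chat_template emits no image placeholder token, and calling the processor with both an image and that text will generally fail (or silently drop the image) on Llama-3.2-Vision-style models. A sketch of the prompt with an explicit image slot, assuming tokenizer here is the FastVisionModel processor and follows the standard multimodal chat-message format:

    # The {"type": "image"} entry makes the chat template insert the image
    # placeholder token that the processor expects to pair with the pixel data.
    simple_message = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Analyze this image and tell if it's a deepfake."},
        ],
    }]
    simple_text = tokenizer.apply_chat_template(simple_message, add_generation_prompt=True)
    simple_inputs = tokenizer(
        image,        # the PIL image passed into analyze_image_with_llm
        simple_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)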