saakshigupta committed
Commit 2c1d4e3 · verified · 1 Parent(s): 9d40c4b

Update app.py

Files changed (1): app.py (+47 −61)
app.py CHANGED
@@ -4,9 +4,10 @@ import os
 import tempfile
 # First load unsloth
 from unsloth import FastVisionModel
-# Add configuration to fix TorchDynamo issues
+# Completely disable dynamic compilation due to compatibility issues
 import torch
-torch._dynamo.config.capture_scalar_outputs = True
+# Disable TorchDynamo completely to avoid optimization errors
+torch._dynamo.config.disable = True
 # Disable fallback warnings to reduce noise
 torch._dynamo.config.suppress_errors = True
 # Then transformers
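
Note on this hunk: the two knobs are complementary. torch._dynamo.config.disable turns Dynamo off entirely, while suppress_errors only matters for anything that still tries to compile. A minimal sketch of an equivalent setup, assuming nothing beyond stock PyTorch 2.x; the TORCHDYNAMO_DISABLE environment variable is read when torch._dynamo first loads, so it must be set before torch is imported:

    import os
    # Assumption: set before `import torch`, since torch._dynamo reads this
    # variable when its config module is first loaded.
    os.environ["TORCHDYNAMO_DISABLE"] = "1"

    import torch
    # Equivalent post-import switch, as used in this commit:
    torch._dynamo.config.disable = True
    torch._dynamo.config.suppress_errors = True  # quiet any residual fallbacks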
@@ -559,30 +560,14 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
 
         # Generate response with error handling
         with st.spinner("Generating detailed analysis... (this may take 15-30 seconds)"):
-            try:
-                # First try with dynamic compilation (default)
-                with torch.no_grad():
-                    output_ids = model.generate(
-                        **inputs,
-                        max_new_tokens=max_tokens,
-                        use_cache=True,
-                        temperature=temperature,
-                        top_p=0.9
-                    )
-            except Exception as dynamo_error:
-                st.warning(f"Encountered optimization error, falling back to eager mode: {str(dynamo_error)}")
-
-                # Try again with dynamo disabled
-                with torch.no_grad():
-                    # Temporarily disable dynamo
-                    with torch._dynamo.disable():
-                        output_ids = model.generate(
-                            **inputs,
-                            max_new_tokens=max_tokens,
-                            use_cache=True,
-                            temperature=temperature,
-                            top_p=0.9
-                        )
+            with torch.no_grad():
+                output_ids = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    use_cache=True,
+                    temperature=temperature,
+                    top_p=0.9
+                )
 
         # Decode the output
         response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
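
One caveat the kept call inherits: in transformers-style generate(), temperature and top_p only take effect when sampling is enabled; under the default greedy decoding they are silently ignored (recent transformers versions warn about exactly this). A minimal sketch of the sampled variant, reusing the names from the hunk above; do_sample is an addition, not part of this commit:

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            use_cache=True,
            do_sample=True,          # without this, temperature/top_p are no-ops
            temperature=temperature,
            top_p=0.9,
        )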
@@ -598,42 +583,43 @@ def analyze_image_with_llm(image, gradcam_overlay, face_box, pred_label, confide
     except Exception as e:
         st.error(f"Error during LLM analysis: {str(e)}")
 
-        # Try one more time with all optimizations disabled
+        # Try one more time with simpler input
         try:
-            st.info("Attempting fallback with all optimizations disabled...")
+            st.info("Attempting fallback with simplified input...")
+
+            # Prepare a simpler prompt
+            simple_message = [{"role": "user", "content": [
+                {"type": "text", "text": "Analyze this image and tell if it's a deepfake."},
+                {"type": "image", "image": image}
+            ]}]
+
+            # Apply simpler template
+            simple_text = tokenizer.apply_chat_template(simple_message, add_generation_prompt=True)
+
+            # Generate with minimal settings
             with torch.no_grad():
-                with torch._dynamo.disable():
-                    # Prepare a simpler prompt
-                    simple_message = [{"role": "user", "content": [
-                        {"type": "text", "text": "Analyze this image and tell if it's a deepfake."}
-                    ]}]
-                    simple_image = image
-
-                    # Apply simpler template
-                    simple_text = tokenizer.apply_chat_template(simple_message, add_generation_prompt=True)
-
-                    # Tokenize with just the image
-                    simple_inputs = tokenizer(
-                        simple_image,
-                        simple_text,
-                        add_special_tokens=False,
-                        return_tensors="pt",
-                    ).to(model.device)
-
-                    # Generate with minimal settings
-                    output_ids = model.generate(
-                        **simple_inputs,
-                        max_new_tokens=200,
-                        use_cache=True,
-                        temperature=0.5,
-                        top_p=0.9
-                    )
-
-                    # Decode
-                    fallback_response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-                    return "Error with optimized generation. Fallback analysis: " + fallback_response.split("Analyze this image and tell if it's a deepfake.")[-1].strip()
+                simple_inputs = tokenizer(
+                    image,
+                    simple_text,
+                    add_special_tokens=False,
+                    return_tensors="pt",
+                ).to(model.device)
+
+                simple_inputs = fix_cross_attention_mask(simple_inputs)
+
+                output_ids = model.generate(
+                    **simple_inputs,
+                    max_new_tokens=200,
+                    use_cache=True,
+                    temperature=0.5,
+                    top_p=0.9
+                )
+
+                # Decode
+                fallback_response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+                return "Error with primary analysis. Fallback result: " + fallback_response.split("Analyze this image and tell if it's a deepfake.")[-1].strip()
         except Exception as fallback_error:
-            return f"Error analyzing image. Primary error: {str(e)}\nFallback error: {str(fallback_error)}"
+            return f"Error analyzing image: {str(fallback_error)}"
 
 # Preprocess image for Xception
 def preprocess_image_xception(image):
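
A note on the fallback path: the rendered template is passed back into the tokenizer as text, but transformers' apply_chat_template returns token IDs by default on tokenizer objects (processor objects return a string). Whether that matters here depends on what unsloth hands back as tokenizer; a hedged, defensive sketch (fix_cross_attention_mask is a helper defined elsewhere in app.py, not shown in this diff):

    # Defensive variant (assumption: `tokenizer` follows the transformers API).
    simple_text = tokenizer.apply_chat_template(
        simple_message,
        add_generation_prompt=True,
        tokenize=False,  # force the rendered prompt string, never input IDs
    )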
@@ -1110,7 +1096,7 @@ def main():
 
     if clear_button:
         st.session_state.chat_history = []
-        st.experimental_rerun()
+        st.rerun()
 
     if analyze_button and new_question:
         try:
@@ -1167,7 +1153,7 @@ def main():
                 st.markdown(result)
 
                 # Rerun to update the chat history display
-                st.experimental_rerun()
+                st.rerun()
 
     except Exception as e:
         st.error(f"Error during LLM analysis: {str(e)}")
 