banao-tech commited on
Commit
29f706c
·
verified ·
1 Parent(s): ef9514e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +20 -10
main.py CHANGED
@@ -37,23 +37,33 @@ except:
37
  yolo_model = torch.load("weights/icon_detect/best.pt", map_location="cpu", weights_only=False)["model"]
38
 
39
  from transformers import AutoProcessor, AutoModelForCausalLM
 
40
 
41
- processor = AutoProcessor.from_pretrained(
42
- "microsoft/Florence-2-base", trust_remote_code=True
43
- )
44
 
45
  try:
46
  model = AutoModelForCausalLM.from_pretrained(
47
  "weights/icon_caption_florence",
48
- torch_dtype=torch.float16,
49
- trust_remote_code=True,
50
- ).to("cuda")
51
- except:
 
 
52
  model = AutoModelForCausalLM.from_pretrained(
53
  "weights/icon_caption_florence",
54
- torch_dtype=torch.float16,
55
- trust_remote_code=True,
56
- )
 
 
 
 
 
 
 
57
  caption_model_processor = {"processor": processor, "model": model}
58
  print("finish loading model!!!")
59
 
 
37
  yolo_model = torch.load("weights/icon_detect/best.pt", map_location="cpu", weights_only=False)["model"]
38
 
39
  from transformers import AutoProcessor, AutoModelForCausalLM
40
+ import torch
41
 
42
+ # Check if CUDA is available
43
+ device = "cuda" if torch.cuda.is_available() else "cpu"
44
+ dtype = torch.float16 if device == "cuda" else torch.float32 # Use float32 on CPU
45
 
46
  try:
47
  model = AutoModelForCausalLM.from_pretrained(
48
  "weights/icon_caption_florence",
49
+ torch_dtype=dtype, # Dynamic dtype based on device
50
+ trust_remote_code=True
51
+ ).to(device)
52
+ except Exception as e:
53
+ print(f"Error loading model: {str(e)}")
54
+ # Fallback to CPU with float32
55
  model = AutoModelForCausalLM.from_pretrained(
56
  "weights/icon_caption_florence",
57
+ torch_dtype=torch.float32,
58
+ trust_remote_code=True
59
+ ).to("cpu")
60
+
61
+ # Force config for DaViT vision tower
62
+ if not hasattr(model.config, 'vision_config'):
63
+ model.config.vision_config = {}
64
+ if 'model_type' not in model.config.vision_config:
65
+ model.config.vision_config['model_type'] = 'davit'
66
+
67
  caption_model_processor = {"processor": processor, "model": model}
68
  print("finish loading model!!!")
69