quarterturn commited on
Commit
4b8c042
·
1 Parent(s): e6cfd14

more fixes

Browse files
app.py CHANGED
@@ -32,13 +32,16 @@ else:
32
  device = "cpu"
33
  print("GPU is not available. Using CPU.")
34
 
 
 
 
35
  # Load the processor
36
  model = "allenai/Molmo-7B-D-0924"
37
 
38
  processor = AutoProcessor.from_pretrained(
39
  model,
40
  trust_remote_code=True,
41
- torch_dtype='auto',
42
  device_map='auto'
43
  )
44
 
@@ -46,7 +49,7 @@ processor = AutoProcessor.from_pretrained(
46
  model = AutoModelForCausalLM.from_pretrained(
47
  model,
48
  trust_remote_code=True,
49
- torch_dtype='auto',
50
  device_map='auto',
51
  )
52
 
@@ -95,7 +98,7 @@ def generate_caption(image_path, processor, model, generation_config, bits_and_b
95
  inputs["images"] = inputs["images"].to(torch.bfloat16)
96
 
97
  # generate output; maximum 500 new tokens; stop generation when is generated
98
- with torch.autocast(device_type=device, enabled=True, dtype="auto"):
99
  output = model.generate_from_batch(
100
  inputs,
101
  GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
 
32
  device = "cpu"
33
  print("GPU is not available. Using CPU.")
34
 
35
+ device = "cpu"
36
+ print("Forcing CPU")
37
+
38
  # Load the processor
39
  model = "allenai/Molmo-7B-D-0924"
40
 
41
  processor = AutoProcessor.from_pretrained(
42
  model,
43
  trust_remote_code=True,
44
+ torch_dtype=torch.bfloat16,
45
  device_map='auto'
46
  )
47
 
 
49
  model = AutoModelForCausalLM.from_pretrained(
50
  model,
51
  trust_remote_code=True,
52
+ torch_dtype=torch.bfloat16,
53
  device_map='auto',
54
  )
55
 
 
98
  inputs["images"] = inputs["images"].to(torch.bfloat16)
99
 
100
  # generate output; maximum 500 new tokens; stop generation when is generated
101
+ with torch.autocast(device_type=device, enabled=True, dtype=torch.bfloat16):
102
  output = model.generate_from_batch(
103
  inputs,
104
  GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
images/6a5e7d80-1c47-4dbf-89e2-43b5749f74ed.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:508ee3059b8e9ea66544730f4e912471ed19d9a87d1c6afd3cc94786dc469eea
3
+ size 994