taruschirag commited on
Commit
14c5a18
·
verified ·
1 Parent(s): 7591333

Update app.py

Browse files

Fixed the torch_dtype bug. Needed to initialize the dtype when loading the model rather than later.

Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -95,7 +95,7 @@ class ModelWrapper:
95
 
96
  print(f"Loading model: {model_name}...")
97
  if "8b" in model_name.lower():
98
- config = AutoConfig.from_pretrained(model_name)
99
  with init_empty_weights():
100
  model_empty = AutoModelForCausalLM.from_config(config)
101
 
@@ -104,7 +104,6 @@ class ModelWrapper:
104
  model_name,
105
  device_map="auto",
106
  offload_folder="offload",
107
- torch_dtype=torch.bfloat16
108
  ).eval()
109
  else:
110
  self.model = AutoModelForCausalLM.from_pretrained(
 
95
 
96
  print(f"Loading model: {model_name}...")
97
  if "8b" in model_name.lower():
98
+ config = AutoConfig.from_pretrained(model_name, torch_dtype=torch.bfloat16)
99
  with init_empty_weights():
100
  model_empty = AutoModelForCausalLM.from_config(config)
101
 
 
104
  model_name,
105
  device_map="auto",
106
  offload_folder="offload",
 
107
  ).eval()
108
  else:
109
  self.model = AutoModelForCausalLM.from_pretrained(