taruschirag committed
Commit 9d6f3f8 · verified · 1 Parent(s): 14c5a18

Update app.py


fixing the model loading bug

Files changed (1)
  1. app.py +11 -14
app.py CHANGED
@@ -94,20 +94,17 @@ class ModelWrapper:
         self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
 
         print(f"Loading model: {model_name}...")
-        if "8b" in model_name.lower():
-            config = AutoConfig.from_pretrained(model_name, torch_dtype=torch.bfloat16)
-            with init_empty_weights():
-                model_empty = AutoModelForCausalLM.from_config(config)
-
-            self.model = load_checkpoint_and_dispatch(
-                model_empty,
-                model_name,
-                device_map="auto",
-                offload_folder="offload",
-            ).eval()
-        else:
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
+
+        # We can now use the same, simpler loading logic for all models.
+        # The `from_pretrained` method will handle downloading from the Hub
+        # and applying the device_map.
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            offload_folder="offload"  # Keep this for memory management
+        ).eval()
+
         print(f"Model {model_name} loaded successfully.")
 
     def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
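
For reference, a minimal standalone sketch of the simplified loading path outside the ModelWrapper class. The checkpoint name is a placeholder (the commit does not name one), and device_map="auto" with disk offload assumes the accelerate package is installed alongside transformers:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"  # placeholder checkpoint, not from the commit

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id

# from_pretrained downloads the (possibly sharded) checkpoint and places
# layers across available devices on its own, which is why the manual
# init_empty_weights / load_checkpoint_and_dispatch branch could be dropped.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",           # let accelerate pick GPU/CPU placement
    torch_dtype=torch.bfloat16,
    offload_folder="offload",    # spill weights to disk when memory runs out
).eval()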