Commit 02b6bc9 (verified) by taruschirag
Parent: e101ec5

Update app.py

Changed the model loading to a two-step process to work around the meta tensor error.

Files changed (1):
  app.py  +31 -10
app.py CHANGED

@@ -7,6 +7,7 @@ import torch
 import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from huggingface_hub import snapshot_download
 
 
 # --- Constants ---
@@ -76,17 +77,37 @@ class ModelWrapper:
         self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
 
         print(f"Loading model: {model_name}...")
 
-        # We can now use the same, simpler loading logic for all models.
-        # The `from_pretrained` method will handle downloading from the Hub
-        # and applying the device_map.
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            device_map="auto",
-            torch_dtype=torch.bfloat16,
-            offload_folder="offload"  # Keep this for memory management
-        ).eval()
-
+        # For large models, we use a more robust, memory-safe loading method.
+        # This explicitly handles the "meta tensor" device placement.
+        if "8b" in model_name.lower() or "4b" in model_name.lower():
+
+            # Step 1: Download the model files and get the local path.
+            print(f"Ensuring model checkpoint is available locally for {model_name}...")
+            checkpoint_path = snapshot_download(repo_id=model_name)
+            print(f"Checkpoint is at: {checkpoint_path}")
+
+            # Step 2: Create the model's "skeleton" on the meta device (no memory used).
+            config = AutoConfig.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+            with init_empty_weights():
+                model_empty = AutoModelForCausalLM.from_config(config)
+
+            # Step 3: Load the real weights from the local files directly onto the GPU(s).
+            # This function is designed to handle the meta->device transition correctly.
+            self.model = load_checkpoint_and_dispatch(
+                model_empty,
+                checkpoint_path,
+                device_map="auto",
+                offload_folder="offload"
+            ).eval()
+
+        else:  # For smaller models, the simpler method is fine.
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                device_map="auto",
+                torch_dtype=torch.bfloat16
+            ).eval()
+
         print(f"Model {model_name} loaded successfully.")
 
     def get_message_template(self, system_content=None, user_content=None, assistant_content=None):