Brianpuz commited on
Commit
4c71e7c
Β·
1 Parent(s): 787eccb

fix problems

Browse files
Files changed (1) hide show
  1. app.py +48 -28
app.py CHANGED
@@ -81,23 +81,30 @@ class AbliterationProcessor:
81
  def load_model(self, model_id):
82
  """Load model and tokenizer"""
83
  try:
 
 
 
 
84
  self.model = AutoModelForCausalLM.from_pretrained(
85
  model_id,
86
  trust_remote_code=True,
87
- torch_dtype=torch.float16
 
88
  )
89
  self.tokenizer = AutoTokenizer.from_pretrained(
90
  model_id,
91
  trust_remote_code=True
92
  )
93
- return f"βœ… Model {model_id} loaded successfully!"
 
 
94
  except Exception as e:
95
  return f"❌ Model loading failed: {str(e)}"
96
 
97
  def process_abliteration(self, model_id, harmful_text, harmless_text, instructions,
98
  scale_factor, skip_begin, skip_end, layer_fraction,
99
  private_repo, export_to_org, repo_owner, org_token, oauth_token: gr.OAuthToken | None,
100
- progress=gr.Progress()):
101
  """Execute abliteration processing and upload to HuggingFace"""
102
  if oauth_token is None or oauth_token.token is None:
103
  return (
@@ -120,12 +127,12 @@ class AbliterationProcessor:
120
  repo_owner = "self"
121
 
122
  try:
123
- progress(0, desc="STEP 1/14: Loading model...")
124
  # Load model
125
  if self.model is None or self.tokenizer is None:
126
  self.load_model(model_id)
127
 
128
- progress(0, desc="STEP 2/14: Parsing instructions...")
129
  # Parse text content
130
  harmful_instructions = [line.strip() for line in harmful_text.strip().split('\n') if line.strip()]
131
  harmless_instructions = [line.strip() for line in harmless_text.strip().split('\n') if line.strip()]
@@ -134,12 +141,12 @@ class AbliterationProcessor:
134
  harmful_instructions = random.sample(harmful_instructions, min(instructions, len(harmful_instructions)))
135
  harmless_instructions = random.sample(harmless_instructions, min(instructions, len(harmless_instructions)))
136
 
137
- progress(0, desc="STEP 3/14: Calculating layer index...")
138
  # Calculate layer index
139
  layer_idx = int(len(self.model.model.layers) * layer_fraction)
140
  pos = -1
141
 
142
- progress(0, desc="STEP 4/14: Generating harmful tokens...")
143
  # Generate tokens
144
  harmful_toks = [
145
  self.tokenizer.apply_chat_template(
@@ -149,7 +156,7 @@ class AbliterationProcessor:
149
  ) for insn in harmful_instructions
150
  ]
151
 
152
- progress(0, desc="STEP 5/14: Generating harmless tokens...")
153
  harmless_toks = [
154
  self.tokenizer.apply_chat_template(
155
  conversation=[{"role": "user", "content": insn}],
@@ -168,13 +175,13 @@ class AbliterationProcessor:
168
  output_hidden_states=True
169
  )
170
 
171
- progress(0, desc="STEP 6/14: Processing harmful instructions...")
172
  harmful_outputs = [generate(toks) for toks in harmful_toks]
173
 
174
- progress(0, desc="STEP 7/14: Processing harmless instructions...")
175
  harmless_outputs = [generate(toks) for toks in harmless_toks]
176
 
177
- progress(0, desc="STEP 8/14: Extracting hidden states...")
178
  # Extract hidden states
179
  harmful_hidden = [output.hidden_states[0][layer_idx][:, pos, :] for output in harmful_outputs]
180
  harmless_hidden = [output.hidden_states[0][layer_idx][:, pos, :] for output in harmless_outputs]
@@ -182,7 +189,7 @@ class AbliterationProcessor:
182
  harmful_mean = torch.stack(harmful_hidden).mean(dim=0)
183
  harmless_mean = torch.stack(harmless_hidden).mean(dim=0)
184
 
185
- progress(0, desc="STEP 9/14: Calculating refusal direction...")
186
  # Calculate refusal direction
187
  refusal_dir = harmful_mean - harmless_mean
188
  refusal_dir = refusal_dir / refusal_dir.norm()
@@ -194,11 +201,11 @@ class AbliterationProcessor:
194
  self.refusal_dir = refusal_dir
195
  self.projection_matrix = projection_matrix
196
 
197
- progress(0, desc="STEP 10/14: Updating model weights...")
198
  # Modify model weights
199
  self.modify_layer_weights_optimized(projection_matrix, skip_begin, skip_end, scale_factor, progress)
200
 
201
- progress(0, desc="STEP 11/14: Preparing model for upload...")
202
  # Create temporary directory to save model
203
  with tempfile.TemporaryDirectory() as temp_dir:
204
  # Save model in safetensors format
@@ -206,7 +213,7 @@ class AbliterationProcessor:
206
  self.tokenizer.save_pretrained(temp_dir)
207
  torch.save(self.refusal_dir, os.path.join(temp_dir, "refusal_dir.pt"))
208
 
209
- progress(0, desc="STEP 12/14: Uploading to HuggingFace...")
210
  # Upload to HuggingFace
211
  repo_namespace = get_repo_namespace(repo_owner, username, user_orgs)
212
  model_name = model_id.split("/")[-1]
@@ -230,7 +237,7 @@ class AbliterationProcessor:
230
  repo_id=repo_id
231
  )
232
 
233
- progress(0, desc="STEP 13/14: Creating model card...")
234
  # Create model card
235
  try:
236
  original_card = ModelCard.load(model_id, token=oauth_token.token)
@@ -245,7 +252,7 @@ class AbliterationProcessor:
245
  repo_id=repo_id
246
  )
247
 
248
- progress(0, desc="STEP 14/14: Complete!")
249
  return (
250
  f'<h1>βœ… DONE</h1><br/>Repo: <a href="{new_repo_url}" target="_blank" style="text-decoration:underline">{repo_id}</a>',
251
  f"llama{np.random.randint(9)}.png",
@@ -265,7 +272,7 @@ class AbliterationProcessor:
265
 
266
  for i, layer_idx in enumerate(layers_to_modify):
267
  if progress:
268
- progress(0, desc=f"STEP 10/14: Updating layer {layer_idx+1}/{num_layers} (Layer {i+1}/{total_layers})")
269
 
270
  layer = self.model.model.layers[layer_idx]
271
 
@@ -303,22 +310,34 @@ class AbliterationProcessor:
303
  return_tensors="pt"
304
  )
305
 
306
- # Generate response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  gen = self.model.generate(
308
  toks.to(self.model.device),
309
  max_new_tokens=2048,
310
  temperature=0.7,
311
  do_sample=True,
312
- pad_token_id=self.tokenizer.eos_token_id
313
- )
314
-
315
- # Decode response
316
- decoded = self.tokenizer.batch_decode(
317
- gen[0][len(toks[0]):],
318
- skip_special_tokens=True
319
  )
320
 
321
- response = "".join(decoded).strip()
 
322
  return response, history + [[message, response]]
323
 
324
  except Exception as e:
@@ -473,7 +492,8 @@ def create_interface():
473
  with gr.TabItem("πŸ’¬ Chat Test"):
474
  chatbot = gr.Chatbot(
475
  label="Chat Window",
476
- height=400
 
477
  )
478
  msg = gr.Textbox(
479
  label="Input Message",
 
81
  def load_model(self, model_id):
82
  """Load model and tokenizer"""
83
  try:
84
+ # Auto-detect GPU
85
+ device = "cuda" if torch.cuda.is_available() else "cpu"
86
+ print(f"Using device: {device}")
87
+
88
  self.model = AutoModelForCausalLM.from_pretrained(
89
  model_id,
90
  trust_remote_code=True,
91
+ torch_dtype=torch.float16,
92
+ device_map="auto" if device == "cuda" else None
93
  )
94
  self.tokenizer = AutoTokenizer.from_pretrained(
95
  model_id,
96
  trust_remote_code=True
97
  )
98
+
99
+ device_info = f" on {device.upper()}" if device == "cuda" else ""
100
+ return f"βœ… Model {model_id} loaded successfully{device_info}!"
101
  except Exception as e:
102
  return f"❌ Model loading failed: {str(e)}"
103
 
104
  def process_abliteration(self, model_id, harmful_text, harmless_text, instructions,
105
  scale_factor, skip_begin, skip_end, layer_fraction,
106
  private_repo, export_to_org, repo_owner, org_token, oauth_token: gr.OAuthToken | None,
107
+ progress=gr.Progress(track_tqdm=False)):
108
  """Execute abliteration processing and upload to HuggingFace"""
109
  if oauth_token is None or oauth_token.token is None:
110
  return (
 
127
  repo_owner = "self"
128
 
129
  try:
130
+ progress(desc="STEP 1/14: Loading model...")
131
  # Load model
132
  if self.model is None or self.tokenizer is None:
133
  self.load_model(model_id)
134
 
135
+ progress(desc="STEP 2/14: Parsing instructions...")
136
  # Parse text content
137
  harmful_instructions = [line.strip() for line in harmful_text.strip().split('\n') if line.strip()]
138
  harmless_instructions = [line.strip() for line in harmless_text.strip().split('\n') if line.strip()]
 
141
  harmful_instructions = random.sample(harmful_instructions, min(instructions, len(harmful_instructions)))
142
  harmless_instructions = random.sample(harmless_instructions, min(instructions, len(harmless_instructions)))
143
 
144
+ progress(desc="STEP 3/14: Calculating layer index...")
145
  # Calculate layer index
146
  layer_idx = int(len(self.model.model.layers) * layer_fraction)
147
  pos = -1
148
 
149
+ progress(desc="STEP 4/14: Generating harmful tokens...")
150
  # Generate tokens
151
  harmful_toks = [
152
  self.tokenizer.apply_chat_template(
 
156
  ) for insn in harmful_instructions
157
  ]
158
 
159
+ progress(desc="STEP 5/14: Generating harmless tokens...")
160
  harmless_toks = [
161
  self.tokenizer.apply_chat_template(
162
  conversation=[{"role": "user", "content": insn}],
 
175
  output_hidden_states=True
176
  )
177
 
178
+ progress(desc="STEP 6/14: Processing harmful instructions...")
179
  harmful_outputs = [generate(toks) for toks in harmful_toks]
180
 
181
+ progress(desc="STEP 7/14: Processing harmless instructions...")
182
  harmless_outputs = [generate(toks) for toks in harmless_toks]
183
 
184
+ progress(desc="STEP 8/14: Extracting hidden states...")
185
  # Extract hidden states
186
  harmful_hidden = [output.hidden_states[0][layer_idx][:, pos, :] for output in harmful_outputs]
187
  harmless_hidden = [output.hidden_states[0][layer_idx][:, pos, :] for output in harmless_outputs]
 
189
  harmful_mean = torch.stack(harmful_hidden).mean(dim=0)
190
  harmless_mean = torch.stack(harmless_hidden).mean(dim=0)
191
 
192
+ progress(desc="STEP 9/14: Calculating refusal direction...")
193
  # Calculate refusal direction
194
  refusal_dir = harmful_mean - harmless_mean
195
  refusal_dir = refusal_dir / refusal_dir.norm()
 
201
  self.refusal_dir = refusal_dir
202
  self.projection_matrix = projection_matrix
203
 
204
+ progress(desc="STEP 10/14: Updating model weights...")
205
  # Modify model weights
206
  self.modify_layer_weights_optimized(projection_matrix, skip_begin, skip_end, scale_factor, progress)
207
 
208
+ progress(desc="STEP 11/14: Preparing model for upload...")
209
  # Create temporary directory to save model
210
  with tempfile.TemporaryDirectory() as temp_dir:
211
  # Save model in safetensors format
 
213
  self.tokenizer.save_pretrained(temp_dir)
214
  torch.save(self.refusal_dir, os.path.join(temp_dir, "refusal_dir.pt"))
215
 
216
+ progress(desc="STEP 12/14: Uploading to HuggingFace...")
217
  # Upload to HuggingFace
218
  repo_namespace = get_repo_namespace(repo_owner, username, user_orgs)
219
  model_name = model_id.split("/")[-1]
 
237
  repo_id=repo_id
238
  )
239
 
240
+ progress(desc="STEP 13/14: Creating model card...")
241
  # Create model card
242
  try:
243
  original_card = ModelCard.load(model_id, token=oauth_token.token)
 
252
  repo_id=repo_id
253
  )
254
 
255
+ progress(desc="STEP 14/14: Complete!")
256
  return (
257
  f'<h1>βœ… DONE</h1><br/>Repo: <a href="{new_repo_url}" target="_blank" style="text-decoration:underline">{repo_id}</a>',
258
  f"llama{np.random.randint(9)}.png",
 
272
 
273
  for i, layer_idx in enumerate(layers_to_modify):
274
  if progress:
275
+ progress(desc=f"STEP 10/14: Updating layer {layer_idx+1}/{num_layers} (Layer {i+1}/{total_layers})")
276
 
277
  layer = self.model.model.layers[layer_idx]
278
 
 
310
  return_tensors="pt"
311
  )
312
 
313
+ # Generate response with streaming like abliterated_optimized.py
314
+ from transformers import TextStreamer
315
+
316
+ # Create a custom streamer that captures all output
317
+ captured_output = []
318
+
319
+ class CustomStreamer(TextStreamer):
320
+ def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
321
+ super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
322
+ self.captured = []
323
+
324
+ def on_finalized_text(self, text: str, stream_end: bool = False):
325
+ self.captured.append(text)
326
+ super().on_finalized_text(text, stream_end)
327
+
328
+ streamer = CustomStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
329
+
330
  gen = self.model.generate(
331
  toks.to(self.model.device),
332
  max_new_tokens=2048,
333
  temperature=0.7,
334
  do_sample=True,
335
+ pad_token_id=self.tokenizer.eos_token_id,
336
+ streamer=streamer
 
 
 
 
 
337
  )
338
 
339
+ # Get the complete response from streamer
340
+ response = "".join(streamer.captured).strip()
341
  return response, history + [[message, response]]
342
 
343
  except Exception as e:
 
492
  with gr.TabItem("πŸ’¬ Chat Test"):
493
  chatbot = gr.Chatbot(
494
  label="Chat Window",
495
+ height=400,
496
+ type="messages"
497
  )
498
  msg = gr.Textbox(
499
  label="Input Message",