
Ruurd committed
Commit df4c990 · verified · 1 Parent(s): d83b96d

Changed initialization to MASK tokens instead of EOS tokens

Files changed (1): app.py (+6 −7)
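For context, below is a minimal sketch of what the new initialization does. The tokenizer call, `mask_token_id`, `eos_token_id`, and the 256-token window are taken from app.py; the standalone setup and the example prompt are illustrative assumptions (the gated Llama checkpoint still needs an HF_TOKEN, as in the app).

```python
import os
from transformers import AutoTokenizer

# Same tokenizer as app.py; meta-llama/Llama-3.2-3B is gated, so a token is required.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-3B", use_fast=True, token=os.getenv("HF_TOKEN")
)

eos_token_id = tokenizer.eos_token_id
# First sub-token of the literal string 'MASK', exactly as app.py derives it.
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]

# Illustrative prompt; app.py builds its own prompt ending in "Assistant:".
input_ids = tokenizer.encode("User: What is diffusion?\nAssistant:", add_special_tokens=False)

# Old behaviour: the answer region was padded with pad/EOS tokens.
# New behaviour: it is padded with MASK tokens for the denoising loop to fill in.
if len(input_ids) < 256:
    input_ids += [mask_token_id] * (256 - len(input_ids))
else:
    input_ids = input_ids[:256]
```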
app.py CHANGED
@@ -16,8 +16,8 @@ hf_token = os.getenv("HF_TOKEN")
 # --- Load tokenizer ---
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
 vocab_size = len(tokenizer)
-pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
-eot_token_id = tokenizer.eos_token_id
+eos_token_id = tokenizer.eos_token_id
+mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
 assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)

 # def load_model():
@@ -114,7 +114,6 @@ def confidence_guided_noising(input_ids, answer_start, confidences, noise_clippi
 noised = input_ids.copy()
 answer_len = len(input_ids) - answer_start
 num_to_noise = int(threshold * answer_len * noise_start)
-mask_token_id = tokenizer.encode('MASK', add_special_tokens = False)[0]

 if num_to_noise == 0:
 return noised
@@ -176,7 +175,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
 return

 if len(input_ids) < 256:
-input_ids += [pad_token] * (256 - len(input_ids))
+input_ids += [mask_token_id] * (256 - len(input_ids))
 else:
 input_ids = input_ids[:256]

@@ -203,7 +202,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
 highlighted = []
 for j, tok in enumerate(decoded_tokens):
 tok_id = tokenizer.convert_tokens_to_ids(tok)
-if tok_id == eot_token_id:
+if tok_id == eos_token_id:
 continue
 token_str = tokenizer.convert_tokens_to_string([tok])
 if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
@@ -245,7 +244,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
 highlighted = []
 for j, tok in enumerate(decoded_tokens):
 tok_id = tokenizer.convert_tokens_to_ids(tok)
-if tok_id == eot_token_id:
+if tok_id == eos_token_id:
 continue
 token_str = tokenizer.convert_tokens_to_string([tok])
 abs_idx = answer_start + j
@@ -259,7 +258,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_


 final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
-final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
+final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eos_token_id]
 final_output = tokenizer.convert_tokens_to_string(final_tokens)
 print(final_output)
 yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')