Spaces: Running on Zero
Changed initialization to MASK tokens instead of EOS tokens
Browse files
app.py
CHANGED
@@ -16,8 +16,8 @@ hf_token = os.getenv("HF_TOKEN")
|
|
16 |
# --- Load tokenizer ---
|
17 |
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
|
18 |
vocab_size = len(tokenizer)
|
19 |
-
|
20 |
-
|
21 |
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
|
22 |
|
23 |
# def load_model():
|
@@ -114,7 +114,6 @@ def confidence_guided_noising(input_ids, answer_start, confidences, noise_clippi
|
|
114 |
noised = input_ids.copy()
|
115 |
answer_len = len(input_ids) - answer_start
|
116 |
num_to_noise = int(threshold * answer_len * noise_start)
|
117 |
-
mask_token_id = tokenizer.encode('MASK', add_special_tokens = False)[0]
|
118 |
|
119 |
if num_to_noise == 0:
|
120 |
return noised
|
@@ -176,7 +175,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
|
|
176 |
return
|
177 |
|
178 |
if len(input_ids) < 256:
|
179 |
-
input_ids += [eos_token_id] * (256 - len(input_ids))
|
180 |
else:
|
181 |
input_ids = input_ids[:256]
|
182 |
|
@@ -203,7 +202,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
|
|
203 |
highlighted = []
|
204 |
for j, tok in enumerate(decoded_tokens):
|
205 |
tok_id = tokenizer.convert_tokens_to_ids(tok)
|
206 |
-
if tok_id == tokenizer.eos_token_id:
|
207 |
continue
|
208 |
token_str = tokenizer.convert_tokens_to_string([tok])
|
209 |
if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
|
@@ -245,7 +244,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
|
|
245 |
highlighted = []
|
246 |
for j, tok in enumerate(decoded_tokens):
|
247 |
tok_id = tokenizer.convert_tokens_to_ids(tok)
|
248 |
-
if tok_id == tokenizer.eos_token_id:
|
249 |
continue
|
250 |
token_str = tokenizer.convert_tokens_to_string([tok])
|
251 |
abs_idx = answer_start + j
|
@@ -259,7 +258,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
|
|
259 |
|
260 |
|
261 |
final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
|
262 |
-
final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != tokenizer.eos_token_id]
|
263 |
final_output = tokenizer.convert_tokens_to_string(final_tokens)
|
264 |
print(final_output)
|
265 |
yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
|
|
|
16 |
# --- Load tokenizer ---
|
17 |
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
|
18 |
vocab_size = len(tokenizer)
|
19 |
+
eos_token_id = tokenizer.eos_token_id
|
20 |
+
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
|
21 |
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
|
22 |
|
23 |
# def load_model():
|
|
|
114 |
noised = input_ids.copy()
|
115 |
answer_len = len(input_ids) - answer_start
|
116 |
num_to_noise = int(threshold * answer_len * noise_start)
|
|
|
117 |
|
118 |
if num_to_noise == 0:
|
119 |
return noised
|
|
|
175 |
return
|
176 |
|
177 |
if len(input_ids) < 256:
|
178 |
+
input_ids += [mask_token_id] * (256 - len(input_ids))
|
179 |
else:
|
180 |
input_ids = input_ids[:256]
|
181 |
|
|
|
202 |
highlighted = []
|
203 |
for j, tok in enumerate(decoded_tokens):
|
204 |
tok_id = tokenizer.convert_tokens_to_ids(tok)
|
205 |
+
if tok_id == eos_token_id:
|
206 |
continue
|
207 |
token_str = tokenizer.convert_tokens_to_string([tok])
|
208 |
if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
|
|
|
244 |
highlighted = []
|
245 |
for j, tok in enumerate(decoded_tokens):
|
246 |
tok_id = tokenizer.convert_tokens_to_ids(tok)
|
247 |
+
if tok_id == eos_token_id:
|
248 |
continue
|
249 |
token_str = tokenizer.convert_tokens_to_string([tok])
|
250 |
abs_idx = answer_start + j
|
|
|
258 |
|
259 |
|
260 |
final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
|
261 |
+
final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eos_token_id]
|
262 |
final_output = tokenizer.convert_tokens_to_string(final_tokens)
|
263 |
print(final_output)
|
264 |
yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
|