Add semi-autoregressive generation
app.py
CHANGED
@@ -16,7 +16,6 @@ from infer import (
     find_answer_start,
     get_noising_schedule,
     noisify_answer,
-    generate_diffusion_text,
     filter_logits,
     confidence_guided_noising,
     noisify_answer_without_remasking
@@ -84,10 +83,10 @@ def highlight_tokens(token_ids, answer_start, changed_indices, color):
         highlighted.append(tok_str)
     return "".join(highlighted)
 
 def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
                    noise_start, use_confidence_noising,
                    use_permanent_unmasking, noise_clipping, top_p,
-                   top_k):
+                   top_k, added_tokens):
 
     eos_bias = -eos_bias
     if question.strip() == "":
@@ -105,7 +104,7 @@ def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
 
     # Initial noising
     current_tokens, just_noised_indices = noisify_answer(
-        input_ids, answer_start, tokenizer, threshold=1.0,
+        input_ids, answer_start, tokenizer, threshold=1.0, noise_start=1.0
     )
     yield render_html("Iteration 0 (initial noise)",
                       highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))
@@ -115,8 +114,10 @@ def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
     prev_decoded = []
 
     unmasked_mask = [False] * len(current_tokens)
-
+    current_tokens = current_tokens[:answer_start]
     for i in range(max_it):
+        current_tokens = current_tokens + [mask_token_id] * added_tokens
+        current_tokens = current_tokens[:256]  # Ensure we don't exceed the max length
         generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k, eos_bias = eos_bias)
         current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
 
@@ -156,7 +157,7 @@ def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
         else:
             noised_answer, just_noised_indices = noisify_answer(
                 current_tokens, answer_start, tokenizer,
-                threshold=threshold,
+                threshold=threshold, noise_start=noise_start
             )
 
         for idx in range(answer_start, len(current_tokens)):
@@ -178,7 +179,7 @@ def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
     final_ids = answer_ids
 
     final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
-    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)
+    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)  # type: ignore
 
 
 def is_running_on_spaces():
@@ -195,7 +196,7 @@ if is_running_on_spaces():
     )
 else:
     # Load from local path
-    ckpt_path = "diffusion-model-
+    ckpt_path = "diffusion-model-3B.pth"  # change to your actual local path
 
 model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
 print("✅ Model loaded.")
@@ -213,13 +214,13 @@ demo = gr.Interface(
         gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iteration ↑ = longer pause"),
         gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label="Generation length: ↑ = more output tokens by decreasing eos token probability"),
         gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
-        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
         gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start fraction: ↑ = more noise"),
         gr.Checkbox(value=False, label="Use confidence-guided noising"),
         gr.Checkbox(value=False, label="Use permanent unmasking"),
         gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
         gr.Slider(1, 1000, value = 3, step = 1, label = "Top-p: ↑ = more random answers"),
-        gr.Slider(0.0, 1.0, value = 1.0, step = 0.01, label = "Top-k: ↑ = more random answers")
+        gr.Slider(0.0, 1.0, value = 1.0, step = 0.01, label = "Top-k: ↑ = more random answers"),
+        gr.Slider(1, 256, value=256, step=1, label="Semi-autoregressive generation: number of added tokens per iteration"),
     ],
     outputs=[gr.HTML(label="Diffusion Output")],
     title="Diffusion Language Model Chat",
infer.py
CHANGED
@@ -97,7 +97,7 @@ def get_noising_schedule(i, max_it, sharpness=5.0):
     x = i / max_it
     return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
 
-def noisify_answer(input_ids, answer_start, tokenizer, threshold=1.0, clustering=0):
+def noisify_answer(input_ids, answer_start, tokenizer, threshold=1.0, clustering=0, noise_start = 1.0):
     noised = input_ids.copy()
     answer_len = len(noised) - answer_start
     num_to_noise = int(threshold * answer_len * noise_start)
@@ -316,7 +316,7 @@ def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
     input_ids += [mask_token] * (max_length - len(input_ids))
 
     ori_tokens = input_ids
-    current_tokens = noisify_answer(ori_tokens, answer_start, threshold=1.0)
+    current_tokens = noisify_answer(ori_tokens, answer_start, tokenizer, threshold=1.0)
 
     last_tokens = []
     for step in range(max_it):
@@ -344,6 +344,6 @@ def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
         # Re-apply noise for next iteration
         if step < max_it - 1:
            threshold = noise_start * get_noising_schedule(step, max_it, sharpness=noising_sharpness)
-           current_tokens = noisify_answer(current_tokens, answer_start, threshold=threshold)
+           current_tokens = noisify_answer(current_tokens, answer_start, tokenizer, threshold=threshold)
 
     return tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).strip()
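
Both infer.py call sites had been calling `noisify_answer` without its `tokenizer` argument; the commit repairs the calls and threads the new `noise_start` factor into the count of re-masked tokens. Since the decay schedule and the `num_to_noise` formula are fully visible in the hunks, their interaction can be reproduced in a minimal sketch; the uniform index sampling here is an assumption (the real function also supports clustered noising), and `MASK_TOKEN_ID` is again a hypothetical stand-in.

```python
import random
import numpy as np

MASK_TOKEN_ID = 0  # hypothetical stand-in for the tokenizer's mask id

def get_noising_schedule(i, max_it, sharpness=5.0):
    # Exponential decay from 1 toward 0, copied from the hunk above.
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))

def noisify_answer_uniform(input_ids, answer_start, threshold=1.0, noise_start=1.0):
    """Uniform re-masking; clustering and confidence guidance are omitted."""
    noised = list(input_ids)
    answer_len = len(noised) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)  # same formula as the diff
    indices = random.sample(range(answer_start, len(noised)), num_to_noise)
    for idx in indices:
        noised[idx] = MASK_TOKEN_ID
    return noised, set(indices)

# The schedule decays, so fewer answer tokens are re-masked each iteration:
tokens = list(range(100, 120))  # 8-token prompt followed by a 12-token answer
for step in range(4):
    threshold = get_noising_schedule(step, max_it=4, sharpness=5.0)
    _, remasked = noisify_answer_uniform(tokens, answer_start=8,
                                         threshold=threshold, noise_start=0.5)
    print(f"step {step}: threshold={threshold:.2f}, re-masked {len(remasked)}/12 tokens")
```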