
Ruurd committed · Commit 7346e83 · Parent: a8d72d4

Add semi-autoregressive generation

Files changed (2):
  1. app.py   +12 -11
  2. infer.py  +3 -3
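
In outline, the change makes generation semi-autoregressive: instead of denoising a fixed-length, fully masked answer, diffusion_chat now truncates the sequence to the prompt and appends a block of added_tokens mask tokens at the top of every iteration, capped at 256 tokens so the model's maximum length is never exceeded. A minimal sketch of that loop, assuming a hypothetical model_step in place of generate_diffusion_text and omitting the re-noising between iterations:

def semi_autoregressive_generate(prompt_ids, model_step, mask_token_id,
                                 added_tokens=32, max_it=16, max_len=256):
    # Sketch only: model_step stands in for generate_diffusion_text and is
    # assumed to return (token_ids, confidences) for the full sequence.
    answer_start = len(prompt_ids)
    current = list(prompt_ids)  # start from the prompt alone
    for _ in range(max_it):
        # Grow the canvas by a block of masks, never past the max length.
        current = (current + [mask_token_id] * added_tokens)[:max_len]
        generated, _confidences = model_step(current)
        # Keep the prompt fixed; adopt the denoised answer region.
        current = current[:answer_start] + generated[answer_start:]
    return current[answer_start:]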
app.py CHANGED

@@ -16,7 +16,6 @@ from infer import (
     find_answer_start,
     get_noising_schedule,
     noisify_answer,
-    generate_diffusion_text,
     filter_logits,
     confidence_guided_noising,
     noisify_answer_without_remasking
@@ -84,10 +83,10 @@ def highlight_tokens(token_ids, answer_start, changed_indices, color):
         highlighted.append(tok_str)
     return "".join(highlighted)
 
-def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
-                   clustering, noise_start, use_confidence_noising,
+def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
+                   noise_start, use_confidence_noising,
                    use_permanent_unmasking, noise_clipping, top_p,
-                   top_k):
+                   top_k, added_tokens):
 
     eos_bias = -eos_bias
     if question.strip() == "":
@@ -105,7 +104,7 @@ def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
 
     # Initial noising
     current_tokens, just_noised_indices = noisify_answer(
-        input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0
+        input_ids, answer_start, tokenizer, threshold=1.0, noise_start=1.0
     )
     yield render_html("Iteration 0 (initial noise)",
                       highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))
@@ -115,8 +114,10 @@ def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
     prev_decoded = []
 
     unmasked_mask = [False] * len(current_tokens)
-
+    current_tokens = current_tokens[:answer_start]
     for i in range(max_it):
+        current_tokens = current_tokens + [mask_token_id] * added_tokens
+        current_tokens = current_tokens[:256]  # Ensure we don't exceed the max length
         generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p, top_k, eos_bias = eos_bias)
         current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
 
@@ -156,7 +157,7 @@ def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
         else:
             noised_answer, just_noised_indices = noisify_answer(
                 current_tokens, answer_start, tokenizer,
-                threshold=threshold, clustering=clustering, noise_start=noise_start
+                threshold=threshold, noise_start=noise_start
             )
 
         for idx in range(answer_start, len(current_tokens)):
@@ -178,7 +179,7 @@ def diffusion_chat(question, max_it, pause_length, eos_bias, sharpness,
         final_ids = answer_ids
 
     final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
-    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)
+    yield render_html(f"Final Output ({len(final_ids)} tokens after {i+1} iterations)", final_output)  # type: ignore
 
 
 def is_running_on_spaces():
@@ -195,7 +196,7 @@ if is_running_on_spaces():
     )
 else:
     # Load from local path
-    ckpt_path = "diffusion-model-8B.pth"  # change to your actual local path
+    ckpt_path = "diffusion-model-3B.pth"  # change to your actual local path
 
     model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
     print("✅ Model loaded.")
@@ -213,13 +214,13 @@ demo = gr.Interface(
         gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iteration ↑ = longer pause"),
         gr.Slider(-5.0, 5.0, value=0.0, step=0.1, label="Generation length: ↑ = more output tokens by decreasing eos token probability"),
         gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
-        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
         gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Noise start fraction: ↑ = more noise"),
         gr.Checkbox(value=False, label="Use confidence-guided noising"),
         gr.Checkbox(value=False, label="Use permanent unmasking"),
         gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
         gr.Slider(1, 1000, value = 3, step = 1, label = "Top-p: ↑ = more random answers"),
-        gr.Slider(0.0, 1.0, value = 1.0, step = 0.01, label = "Top-k: ↑ = more random answers")
+        gr.Slider(0.0, 1.0, value = 1.0, step = 0.01, label = "Top-k: ↑ = more random answers"),
+        gr.Slider(1, 256, value=256, step=1, label="Semi-autoregressive generation: number of added tokens per iteration"),
     ],
     outputs=[gr.HTML(label="Diffusion Output")],
    title="Diffusion Language Model Chat",
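
A note on the "Generation length" slider above: diffusion_chat negates the slider value (eos_bias = -eos_bias) before passing it to generate_diffusion_text, so raising the slider biases the eos logit downward and lengthens the answer. The biasing itself happens inside generate_diffusion_text; a minimal sketch of the standard technique, under an assumed helper name:

import torch

def bias_eos_logit(logits: torch.Tensor, eos_token_id: int, eos_bias: float) -> torch.Tensor:
    # Add a constant to the eos logit at every position; a negative bias
    # lowers the eos probability, so generation keeps producing tokens.
    logits[..., eos_token_id] += eos_bias
    return logits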
 
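
app.py also routes the Top-p and Top-k sliders into filter_logits. A sketch of the standard top-k / nucleus filter is below; this is an assumption about what filter_logits does, not its verified implementation. By convention top_k is an integer cutoff and top_p a probability mass in (0, 1], so the slider ranges above (1-1000 for "Top-p", 0.0-1.0 for "Top-k") look as though the two labels may be swapped.

import torch

def filter_logits_sketch(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    # Top-k: drop everything below the k-th highest logit.
    if top_k > 0:
        kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    # Top-p (nucleus): keep the smallest prefix of sorted tokens whose
    # cumulative probability exceeds top_p, always retaining the top token.
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        to_remove = cum_probs > top_p
        to_remove[..., 1:] = to_remove[..., :-1].clone()
        to_remove[..., 0] = False
        mask = to_remove.scatter(-1, sorted_indices, to_remove)
        logits = logits.masked_fill(mask, float("-inf"))
    return logits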
infer.py CHANGED

@@ -97,7 +97,7 @@ def get_noising_schedule(i, max_it, sharpness=5.0):
     x = i / max_it
     return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
 
-def noisify_answer(input_ids, answer_start, tokenizer, threshold=1.0, clustering=0.5, noise_start = 1.0):
+def noisify_answer(input_ids, answer_start, tokenizer, threshold=1.0, clustering=0, noise_start = 1.0):
     noised = input_ids.copy()
     answer_len = len(noised) - answer_start
     num_to_noise = int(threshold * answer_len * noise_start)
@@ -316,7 +316,7 @@ def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
     input_ids += [mask_token] * (max_length - len(input_ids))
 
     ori_tokens = input_ids
-    current_tokens = noisify_answer(ori_tokens, answer_start, threshold=1.0, mask_token_id=mask_token)
+    current_tokens = noisify_answer(ori_tokens, answer_start, tokenizer, threshold=1.0)
 
     last_tokens = []
     for step in range(max_it):
@@ -344,6 +344,6 @@ def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
         # Re-apply noise for next iteration
         if step < max_it - 1:
             threshold = noise_start * get_noising_schedule(step, max_it, sharpness=noising_sharpness)
-            current_tokens = noisify_answer(current_tokens, answer_start, threshold=threshold, mask_token_id=mask_token)
+            current_tokens = noisify_answer(current_tokens, answer_start, tokenizer, threshold=threshold)
 
     return tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).strip()
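
Only the first lines of noisify_answer's body are visible in this diff: it copies the ids, measures the answer region, and computes num_to_noise = int(threshold * answer_len * noise_start). Assuming it then masks that many randomly chosen answer positions and returns them, which is the shape the call sites in app.py expect back, a minimal version would look like the sketch below (the clustering parameter is left out, since this commit stops passing it):

import random

def noisify_answer_sketch(input_ids, answer_start, tokenizer,
                          threshold=1.0, noise_start=1.0):
    # Assumed behaviour: re-mask a random fraction of the answer region and
    # return the noised ids plus the positions that were just masked.
    noised = input_ids.copy()
    answer_len = len(noised) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    indices = random.sample(range(answer_start, len(noised)), num_to_noise)
    for idx in indices:
        noised[idx] = tokenizer.mask_token_id  # assumes the tokenizer defines a mask token
    return noised, set(indices)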
 
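
For reference, get_noising_schedule (context lines above) is a normalized exponential decay: it is exactly 1.0 at i = 0 and 0.0 at i = max_it, with sharpness controlling how front-loaded the decay is; this matches the "Noise decay sharpness" slider, where lower sharpness keeps more noise in later iterations. A quick numeric check of the function as it appears in infer.py:

import numpy as np

def get_noising_schedule(i, max_it, sharpness=5.0):
    # As in infer.py: decays from 1.0 (i = 0) to 0.0 (i = max_it).
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))

for i in range(0, 17, 4):
    print(i, round(float(get_noising_schedule(i, 16)), 3))
# With the default sharpness=5.0 this prints approximately:
# 0 1.0, 4 0.282, 8 0.076, 12 0.017, 16 0.0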