Spaces:
Running on Zero

Ruurd committed on
Commit
d86917b
·
verified ·
1 Parent(s): 7065c9f

Improve confidence guided noising and show number of tokens generated

Browse files
Files changed (1) hide show
  1. app.py +46 -25
app.py CHANGED
@@ -110,35 +110,48 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, clustering=0.5, noise
110
 
111
 
112
  # Add new noising function
113
- def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start = 1.0):
114
  noised = input_ids.copy()
115
  answer_len = len(input_ids) - answer_start
116
  num_to_noise = int(threshold * answer_len * noise_start)
117
-
118
  if num_to_noise == 0:
119
  return noised
120
 
121
-
122
- raw_weights = 1.0 - np.array(confidences[answer_start:])
123
-
124
- # Avoid zero-probability weights for selection
125
- # If noise clipping == 1, all tokens have equal chance to be noised.
126
- # If noise_clipping == 0.00001, all tokens are noised according to the confidence of the past prediction
127
- raw_weights = np.clip(raw_weights, a_min = noise_clipping, a_max = None)
128
 
129
- weights = raw_weights / raw_weights.sum()
 
130
 
131
- if num_to_noise > len(weights):
132
- num_to_noise = len(weights) # prevent oversampling
 
 
133
 
134
- indices = rng.choice(
135
- np.arange(answer_start, len(input_ids)),
136
- size=num_to_noise,
137
  replace=False,
138
- p=weights
139
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- for idx in indices:
142
  noised[idx] = mask_token_id
143
 
144
  return noised
@@ -256,11 +269,19 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
256
  time.sleep(pause_length)
257
 
258
 
259
- final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
260
- final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eos_token_id]
261
- final_output = tokenizer.convert_tokens_to_string(final_tokens)
 
 
 
 
 
 
 
262
  print(final_output)
263
- yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
 
264
 
265
  # --- Gradio Interface ---
266
  print("Loading model...")
@@ -271,11 +292,11 @@ demo = gr.Interface(
271
  fn=diffusion_chat,
272
  inputs=[
273
  gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
274
- gr.Slider(1, 512, value=32, step=1, label="↑ = more iterations"),
275
  gr.Slider(0.01, 5, value=0.01, step=0.01, label="↑ = longer pause (for visualization)"),
276
- gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
277
  gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"),
278
- gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="↑ = more noise (noise start)"),
279
  gr.Checkbox(value=False, label="Use confidence-guided noising"),
280
  gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
281
 
 
110
 
111
 
112
  # Add new noising function
113
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
    """Re-mask part of the answer, preferring tokens predicted with low confidence.

    The answer span ``input_ids[answer_start:]`` is split into EOS and non-EOS
    positions. The noising budget is shared between the two groups in
    proportion to their sizes, and inside each group positions are sampled
    without replacement with probability proportional to
    ``max(1 - confidence, noise_clipping)``.

    Reads the module-level names ``rng`` (NumPy random generator),
    ``eos_token_id`` and ``mask_token_id``.

    Args:
        input_ids: Token-id sequence (prompt followed by answer). Not modified.
        answer_start: Index of the first answer token in ``input_ids``.
        confidences: Per-token confidences, indexed relative to the answer
            (entry 0 belongs to position ``answer_start``).
            NOTE(review): this alignment is what the indexing below implies;
            confirm against the caller.
        noise_clipping: Lower bound applied to every sampling weight. Larger
            values flatten the distribution (1.0 gives uniform sampling).
        threshold: Fraction of the answer eligible for noising this step.
        noise_start: Extra scale factor applied to the noising budget.

    Returns:
        A copy of ``input_ids`` with the selected answer positions replaced by
        ``mask_token_id``.
    """
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)

    # Nothing to do for an empty budget or an empty answer span; this also
    # keeps the integer division below away from a zero denominator.
    if num_to_noise <= 0 or answer_len <= 0:
        return noised

    positions = range(answer_start, len(input_ids))
    eos_indices = [i for i in positions if input_ids[i] == eos_token_id]
    non_eos_indices = [i for i in positions if input_ids[i] != eos_token_id]

    # Exact integer split of the budget. The earlier float form divided by
    # (total + 1e-5), which pushed exact quotients just below the integer and
    # made int() drop one token; with no EOS present that token was assigned
    # to the empty EOS group and silently lost.
    total = len(non_eos_indices) + len(eos_indices)
    num_non_eos_to_noise = num_to_noise * len(non_eos_indices) // total
    num_eos_to_noise = num_to_noise - num_non_eos_to_noise

    def _pick(candidates, budget):
        """Sample up to ``budget`` positions from ``candidates`` by low confidence."""
        if not candidates:
            return []
        weights = 1.0 - np.array([confidences[i - answer_start] for i in candidates])
        weights = np.clip(weights, a_min=noise_clipping, a_max=None)
        weight_sum = weights.sum()
        if weight_sum <= 0:
            return []
        # Sampling without replacement cannot return more items than there
        # are non-zero probabilities, so cap the size accordingly.
        size = min(budget, len(candidates), int(np.count_nonzero(weights)))
        return list(rng.choice(candidates, size=size, replace=False, p=weights / weight_sum))

    # Non-EOS positions are drawn first, then EOS, so the order in which the
    # random generator is consumed is fixed.
    chosen = _pick(non_eos_indices, num_non_eos_to_noise) + _pick(eos_indices, num_eos_to_noise)

    for idx in chosen:
        noised[idx] = mask_token_id

    return noised
 
269
  time.sleep(pause_length)
270
 
271
 
272
+ answer_ids = current_tokens[answer_start:]
273
+ try:
274
+ eos_index = answer_ids.index(eos_token_id)
275
+ final_ids = answer_ids[:eos_index]
276
+ except ValueError:
277
+ final_ids = answer_ids
278
+
279
+ num_tokens = len(final_ids)
280
+ final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
281
+
282
  print(final_output)
283
+ yield f"<b>Final Output ({num_tokens} tokens after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
284
+
285
 
286
  # --- Gradio Interface ---
287
  print("Loading model...")
 
292
  fn=diffusion_chat,
293
  inputs=[
294
  gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
295
+ gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
296
  gr.Slider(0.01, 5, value=0.01, step=0.01, label="↑ = longer pause (for visualization)"),
297
+ gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="↓ = more noising (sharpness)"),
298
  gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"),
299
+ gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="↑ = more noise (noise start)"),
300
  gr.Checkbox(value=False, label="Use confidence-guided noising"),
301
  gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
302