Spaces:
Running on Zero

Ruurd commited on
Commit
86c363a
·
verified ·
1 Parent(s): 3125ce6

Only noise until max_it-1

Browse files
Files changed (1) hide show
  1. app.py +28 -27
app.py CHANGED
@@ -135,33 +135,34 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
135
  break
136
 
137
  # NOISING
138
- threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
139
- if use_confidence_noising:
140
- noised_answer, just_noised_indices = confidence_guided_noising(
141
- current_tokens, answer_start, tokenizer, confidences, noise_clipping,
142
- threshold=threshold, noise_start=noise_start
143
- )
144
- elif use_permanent_unmasking:
145
- noised_answer, just_noised_indices = noisify_answer_without_remasking(
146
- current_tokens, answer_start, tokenizer, threshold=threshold,
147
- noise_start=noise_start, unmasked_mask=unmasked_mask
148
- )
149
- else:
150
- noised_answer, just_noised_indices = noisify_answer(
151
- current_tokens, answer_start, tokenizer,
152
- threshold=threshold, clustering=clustering, noise_start=noise_start
153
- )
154
-
155
- for idx in range(answer_start, len(current_tokens)):
156
- if noised_answer[idx] != mask_token_id:
157
- unmasked_mask[idx] = True
158
-
159
-
160
-
161
- yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
162
- highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))
163
-
164
- current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
 
165
 
166
  # Final output
167
  answer_ids = current_tokens[answer_start:]
 
135
  break
136
 
137
  # NOISING
138
+ if i < max_it-1:
139
+ threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
140
+ if use_confidence_noising:
141
+ noised_answer, just_noised_indices = confidence_guided_noising(
142
+ current_tokens, answer_start, tokenizer, confidences, noise_clipping,
143
+ threshold=threshold, noise_start=noise_start
144
+ )
145
+ elif use_permanent_unmasking:
146
+ noised_answer, just_noised_indices = noisify_answer_without_remasking(
147
+ current_tokens, answer_start, tokenizer, threshold=threshold,
148
+ noise_start=noise_start, unmasked_mask=unmasked_mask
149
+ )
150
+ else:
151
+ noised_answer, just_noised_indices = noisify_answer(
152
+ current_tokens, answer_start, tokenizer,
153
+ threshold=threshold, clustering=clustering, noise_start=noise_start
154
+ )
155
+
156
+ for idx in range(answer_start, len(current_tokens)):
157
+ if noised_answer[idx] != mask_token_id:
158
+ unmasked_mask[idx] = True
159
+
160
+
161
+
162
+ yield render_html(f"Iteration {i+1}/{max_it} (before noising)",
163
+ highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))
164
+
165
+ current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
166
 
167
  # Final output
168
  answer_ids = current_tokens[answer_start:]