Spaces:
Running
on
Zero
Running
on
Zero
Fix generation
Browse files
app.py
CHANGED
@@ -130,7 +130,7 @@ def confidence_guided_noising(input_ids, answer_start, confidences, threshold, e
|
|
130 |
|
131 |
|
132 |
@spaces.GPU
|
133 |
-
def generate_diffusion_text(input_ids
|
134 |
with torch.no_grad():
|
135 |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
136 |
logits = model(input_ids=input_tensor)["logits"]
|
@@ -170,15 +170,24 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
|
|
170 |
|
171 |
for i in range(max_it):
|
172 |
print('Generating output')
|
173 |
-
generated_tokens, confidences = generate_diffusion_text(current_tokens, answer_start)
|
174 |
-
current_tokens = generated_tokens
|
175 |
|
176 |
-
#
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
179 |
|
|
|
|
|
|
|
|
|
|
|
180 |
highlighted = []
|
181 |
for j, tok in enumerate(decoded_tokens):
|
|
|
|
|
|
|
182 |
token_str = tokenizer.convert_tokens_to_string([tok])
|
183 |
if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
|
184 |
highlighted.append(f'<span style="color:green">{token_str}</span>')
|
@@ -189,27 +198,29 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
|
|
189 |
yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
|
190 |
time.sleep(0.1)
|
191 |
|
192 |
-
# ---
|
193 |
threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
|
194 |
if use_confidence_noising:
|
195 |
-
|
196 |
generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping
|
197 |
)
|
198 |
-
just_noised_indices = []
|
199 |
else:
|
200 |
-
|
201 |
generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering
|
202 |
)
|
203 |
|
204 |
-
|
205 |
-
|
|
|
206 |
|
|
|
|
|
207 |
highlighted = []
|
208 |
for j, tok in enumerate(decoded_tokens):
|
209 |
tok_id = tokenizer.convert_tokens_to_ids(tok)
|
210 |
if tok_id == eot_token_id:
|
211 |
-
continue
|
212 |
-
|
213 |
token_str = tokenizer.convert_tokens_to_string([tok])
|
214 |
abs_idx = answer_start + j
|
215 |
if abs_idx in just_noised_indices:
|
@@ -228,8 +239,6 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
|
|
228 |
yield f"<b>Stopped early after {i+1} iterations.</b>"
|
229 |
break
|
230 |
|
231 |
-
|
232 |
-
|
233 |
final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
|
234 |
final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
|
235 |
final_output = tokenizer.convert_tokens_to_string(final_tokens)
|
|
|
130 |
|
131 |
|
132 |
@spaces.GPU
|
133 |
+
def generate_diffusion_text(input_ids):
|
134 |
with torch.no_grad():
|
135 |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
136 |
logits = model(input_ids=input_tensor)["logits"]
|
|
|
170 |
|
171 |
for i in range(max_it):
|
172 |
print('Generating output')
|
|
|
|
|
173 |
|
174 |
+
# Compose full input: original prompt + current answer
|
175 |
+
full_input_tokens = ori_input_tokens[:answer_start] + current_tokens[answer_start:]
|
176 |
+
full_input_tokens = full_input_tokens[:256] + [pad_token] * max(0, 256 - len(full_input_tokens))
|
177 |
+
|
178 |
+
# Model step
|
179 |
+
generated_tokens, confidences = generate_diffusion_text(full_input_tokens)
|
180 |
|
181 |
+
# Save full output for noising step
|
182 |
+
current_tokens = generated_tokens
|
183 |
+
|
184 |
+
# --- GREEN HIGHLIGHT ---
|
185 |
+
decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
|
186 |
highlighted = []
|
187 |
for j, tok in enumerate(decoded_tokens):
|
188 |
+
tok_id = tokenizer.convert_tokens_to_ids(tok)
|
189 |
+
if tok_id == eot_token_id:
|
190 |
+
continue
|
191 |
token_str = tokenizer.convert_tokens_to_string([tok])
|
192 |
if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
|
193 |
highlighted.append(f'<span style="color:green">{token_str}</span>')
|
|
|
198 |
yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
|
199 |
time.sleep(0.1)
|
200 |
|
201 |
+
# --- NOISING STEP ---
|
202 |
threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
|
203 |
if use_confidence_noising:
|
204 |
+
noised_answer = confidence_guided_noising(
|
205 |
generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping
|
206 |
)
|
207 |
+
just_noised_indices = []
|
208 |
else:
|
209 |
+
noised_answer, just_noised_indices = noisify_answer(
|
210 |
generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering
|
211 |
)
|
212 |
|
213 |
+
# Compose full input again: prompt + noised answer
|
214 |
+
current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
|
215 |
+
current_tokens = current_tokens[:256] + [pad_token] * max(0, 256 - len(current_tokens))
|
216 |
|
217 |
+
# --- RED HIGHLIGHT ---
|
218 |
+
decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
|
219 |
highlighted = []
|
220 |
for j, tok in enumerate(decoded_tokens):
|
221 |
tok_id = tokenizer.convert_tokens_to_ids(tok)
|
222 |
if tok_id == eot_token_id:
|
223 |
+
continue
|
|
|
224 |
token_str = tokenizer.convert_tokens_to_string([tok])
|
225 |
abs_idx = answer_start + j
|
226 |
if abs_idx in just_noised_indices:
|
|
|
239 |
yield f"<b>Stopped early after {i+1} iterations.</b>"
|
240 |
break
|
241 |
|
|
|
|
|
242 |
final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
|
243 |
final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
|
244 |
final_output = tokenizer.convert_tokens_to_string(final_tokens)
|