Spaces:
Running
on
Zero
Fix clamping and introduce top-k and top-p filtering
Browse files
app.py
CHANGED
@@ -151,13 +151,40 @@ def confidence_guided_noising(input_ids, answer_start, confidences, noise_clippi
|
|
151 |
noised_indices = sorted(noised_indices)
|
152 |
return noised, noised_indices
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
@spaces.GPU
|
156 |
def generate_diffusion_text(input_ids):
|
157 |
with torch.no_grad():
|
158 |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
159 |
logits = model(input_ids=input_tensor)["logits"]
|
160 |
-
logits = logits
|
|
|
161 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
162 |
probs = torch.clamp(probs, min=1e-8, max=1.0)
|
163 |
assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
|
|
|
151 |
noised_indices = sorted(noised_indices)
|
152 |
return noised, noised_indices
|
153 |
|
154 |
+
def filter_logits(logits, top_k=0, top_p=0.0):
    """Filter logits per position for top-k / nucleus (top-p) sampling.

    Args:
        logits: float tensor of shape (batch, seq_len, vocab_size).
        top_k: if > 0, keep only the k highest-scoring tokens per position.
            Values larger than vocab_size are clamped (the original code
            would make torch.topk raise a RuntimeError in that case).
        top_p: if > 0.0, keep the smallest set of tokens whose cumulative
            softmax probability exceeds top_p (nucleus sampling).

    Returns:
        A new tensor of the same shape with filtered-out entries set to -inf;
        the caller's tensor is left untouched.
    """
    logits = logits.clone()  # don't modify in-place
    batch_size, seq_len, vocab_size = logits.shape

    # The original only filtered batch row 0, silently leaving any other
    # rows unfiltered; loop over the whole batch (identical for batch==1).
    for b in range(batch_size):
        for i in range(seq_len):
            # View into `logits`, so in-place edits below are reflected
            # in the result; no write-back assignment is needed.
            token_logits = logits[b, i]

            if top_k > 0:
                # Clamp so a user-supplied top_k > vocab_size cannot crash.
                k = min(top_k, vocab_size)
                top_values, _ = torch.topk(token_logits, k)
                threshold = top_values[-1]
                token_logits[token_logits < threshold] = float("-inf")

            if top_p > 0.0:
                sorted_logits, sorted_indices = torch.sort(token_logits, descending=True)
                cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

                # Mark tokens past the nucleus, then shift right by one so the
                # token that crosses the top_p threshold is itself kept.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                sorted_indices_to_remove[0] = 0  # always keep at least 1 token

                token_logits[sorted_indices[sorted_indices_to_remove]] = float("-inf")

    return logits
|
180 |
|
181 |
@spaces.GPU
|
182 |
def generate_diffusion_text(input_ids):
|
183 |
with torch.no_grad():
|
184 |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
185 |
logits = model(input_ids=input_tensor)["logits"]
|
186 |
+
logits = filter_logits(logits, top_k=top_k, top_p=top_p)
|
187 |
+
logits = logits.clamp(min=-1e8, max=1e4)
|
188 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
189 |
probs = torch.clamp(probs, min=1e-8, max=1.0)
|
190 |
assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
|