# app.py — "tini" diffusion language model demo (Hugging Face Space, running on ZeroGPU)
# Last commit: "Fix clamping and introduce top-k and top-p filtering"

import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
import os
import importlib
from huggingface_hub import hf_hub_download
from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, disable_dropout
import spaces
hf_token = os.getenv("HF_TOKEN")
# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
eos_token_id = tokenizer.eos_token_id
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
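# NOTE: 'MASK' is not a reserved special token in the Llama 3.2 tokenizer, so
# encode('MASK') may return more than one id; only the first id is used as the
# mask token throughout this script.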
# def load_model():
#     ckpt_path = hf_hub_download(
#         repo_id="ruurd/tini_bi_m",
#         filename="diffusion-model.pth",
#         token=os.getenv("HF_TOKEN")
#     )
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = torch.load(ckpt_path, map_location=device)
#     model = disable_dropout(model)
#     model.to(device)
#     model.eval()
#     return model
def load_model():
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN"),
        # revision="xxx",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: Create model from scratch
    model = CustomTransformerModel(CustomTransformerConfig())

    # Step 2: Load state_dict from the full checkpoint
    full_model = torch.load(ckpt_path, map_location=device)

    # This handles both a pickled full model and a bare state_dict
    try:
        state_dict = full_model.state_dict()
    except AttributeError:
        state_dict = full_model  # already a state_dict

    # Step 3: Load weights (may print mismatches)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)

    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model
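# load_state_dict(strict=False) tolerates checkpoints whose keys do not exactly match
# CustomTransformerModel; missing/unexpected keys are printed by load_model() rather
# than raised as errors, so check the logs if generation looks broken.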
rng = np.random.default_rng()
# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")


def find_answer_start(input_ids, marker_ids):
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None


def get_noising_schedule(i, max_it, sharpness=5.0):
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
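# Rough shape of the schedule (a sketch, not part of the original code): the returned
# fraction decays from 1.0 at i=0 to 0.0 at i=max_it, and with sharpness=5.0 it has
# already dropped to roughly 0.08 by the halfway point, so most re-noising happens in
# the earliest iterations.
#   get_noising_schedule(0, 64)   -> 1.0
#   get_noising_schedule(32, 64)  -> ~0.076  (sharpness=5.0)
#   get_noising_schedule(64, 64)  -> 0.0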
def noisify_answer(input_ids, answer_start, threshold=1.0, clustering=0.5, noise_start=1.0):
    noised = input_ids.copy()
    answer_len = len(noised) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]

    if num_to_noise == 0:
        return noised, []

    num_clusters = max(1, int((1 - clustering) * num_to_noise))
    cluster_size = max(1, int(num_to_noise / num_clusters))

    noised_indices = set()
    for _ in range(num_clusters):
        center = rng.integers(answer_start, len(noised))
        span_start = max(answer_start, center - cluster_size // 2)
        span_end = min(len(noised), span_start + cluster_size)
        noised_indices.update(range(span_start, span_end))

    noised_indices = sorted(list(noised_indices))[:num_to_noise]
    for idx in noised_indices:
        noised[idx] = mask_token_id

    return noised, noised_indices
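# Illustration (hypothetical numbers): for a 100-token answer with threshold=0.5,
# noise_start=1.0 and clustering=0.5, num_to_noise = 50, num_clusters = 25 and
# cluster_size = 2, i.e. roughly 50 masked tokens spread over 25 short spans.
# Higher clustering -> fewer but longer contiguous masked spans.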
# Add new noising function
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    if num_to_noise == 0:
        return noised, []

    all_indices = np.arange(answer_start, len(input_ids))
    eos_indices = [i for i in all_indices if input_ids[i] == eos_token_id]
    non_eos_indices = [i for i in all_indices if input_ids[i] != eos_token_id]

    # Proportionally split how many to noise
    num_non_eos_to_noise = int(num_to_noise * len(non_eos_indices) / (len(non_eos_indices) + len(eos_indices) + 1e-5))
    num_eos_to_noise = num_to_noise - num_non_eos_to_noise

    noised_indices = []

    # --- Non-EOS ---
    if non_eos_indices:
        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in non_eos_indices])
        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
        weights = raw_weights / raw_weights.sum()
        chosen = rng.choice(non_eos_indices, size=min(num_non_eos_to_noise, len(non_eos_indices)), replace=False, p=weights)
        noised_indices.extend(chosen.tolist())

    # --- EOS ---
    if eos_indices and num_eos_to_noise > 0:
        raw_weights = 1.0 - np.array([confidences[i - answer_start] for i in eos_indices])
        raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
        weights = raw_weights / raw_weights.sum()
        chosen = rng.choice(eos_indices, size=min(num_eos_to_noise, len(eos_indices)), replace=False, p=weights)
        noised_indices.extend(chosen.tolist())

    for idx in noised_indices:
        noised[idx] = mask_token_id

    noised_indices = sorted(noised_indices)
    return noised, noised_indices
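# In confidence_guided_noising, tokens the model sampled with low confidence are more
# likely to be re-masked (weight = 1 - confidence). noise_clipping puts a floor on that
# weight, so even high-confidence tokens keep some chance of being re-noised; lowering
# the clipping value therefore strengthens the confidence guidance.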
def filter_logits(logits, top_k=0, top_p=0.0):
    """Filter logits per position for top-k / nucleus (top-p) sampling."""
    logits = logits.clone()  # don't modify in-place
    batch_size, seq_len, vocab_size = logits.shape

    for i in range(seq_len):
        token_logits = logits[0, i]

        if top_k > 0:
            top_values, _ = torch.topk(token_logits, top_k)
            threshold = top_values[-1]
            token_logits[token_logits < threshold] = float("-inf")

        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(token_logits, descending=True)
            cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = 0  # always keep at least 1 token
            token_logits[sorted_indices[sorted_indices_to_remove]] = float("-inf")

        logits[0, i] = token_logits

    return logits
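# top_k / top_p are referenced as module-level settings in generate_diffusion_text()
# but are not defined elsewhere in this file; the defaults below are an assumption
# (0 disables both filters in filter_logits) so the script runs, and they can be tuned
# or wired to extra Gradio inputs as needed.
top_k = 0    # assumed default: top-k filtering disabled
top_p = 0.0  # assumed default: nucleus (top-p) filtering disabled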
@spaces.GPU
def generate_diffusion_text(input_ids):
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        logits = filter_logits(logits, top_k=top_k, top_p=top_p)
        logits = logits.clamp(min=-1e8, max=1e4)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
        assert (probs >= 0).all(), "Negative probs!"
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()

    # Extract confidence of selected tokens
    conf = probs[range(len(sampled)), sampled].cpu().numpy()
    return sampled, conf
# --- Inference Wrapper ---
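# diffusion_chat() streams the denoising process to the UI: the answer region is fully
# masked, then each iteration (1) resamples every answer token with the model,
# (2) shows the result with changed tokens highlighted in green, and (3) re-masks a
# shrinking fraction of tokens according to get_noising_schedule (or the
# confidence-guided variant), shown in red. It stops early once three consecutive
# iterations produce identical tokens.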
def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_start, use_confidence_noising, noise_clipping):
    placeholder = "What do you know about the city of Amsterdam?"
    if question.strip() == "":
        question = placeholder

    print('started generation')
    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return

    if len(input_ids) < 256:
        input_ids += [mask_token_id] * (256 - len(input_ids))
    else:
        input_ids = input_ids[:256]

    ori_input_tokens = input_ids
    current_tokens, just_noised_indices = noisify_answer(
        input_ids, answer_start, threshold=1.0, clustering=clustering, noise_start=1.0,
    )
    yield f"<b>Iteration 0 (initial noise):</b><br>" + tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).replace('\n', '<br>')
    time.sleep(pause_length)

    last_tokens = []
    prev_decoded_tokens = []

    for i in range(max_it):
        print('Generating output')

        # Model step
        generated_tokens, confidences = generate_diffusion_text(current_tokens)

        # Save full output for noising step
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

        # --- GREEN HIGHLIGHT ---
        decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        highlighted = []
        for j, tok in enumerate(decoded_tokens):
            tok_id = tokenizer.convert_tokens_to_ids(tok)
            if tok_id == eos_token_id:
                continue
            token_str = tokenizer.convert_tokens_to_string([tok])
            if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
                highlighted.append(f'<span style="color:green">{token_str}</span>')
            else:
                highlighted.append(token_str)

        prev_decoded_tokens = decoded_tokens
        yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
        time.sleep(pause_length)

        # --- Early stopping ---
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"<b>Stopped early after {i+1} iterations.</b>"
            break

        previous_tokens = current_tokens.copy()

        # --- NOISING STEP ---
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        if use_confidence_noising:
            noised_answer, just_noised_indices = confidence_guided_noising(
                current_tokens, answer_start, confidences, noise_clipping, threshold=threshold, noise_start=noise_start
            )
            # NOTE: clearing the list below means re-noised tokens are not highlighted
            # in red when confidence-guided noising is enabled.
            just_noised_indices = []
        else:
            noised_answer, just_noised_indices = noisify_answer(
                current_tokens, answer_start, threshold=threshold, clustering=clustering, noise_start=noise_start,
            )

        # Compose full input again: prompt + noised answer
        current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]

        # --- RED HIGHLIGHT ---
        decoded_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        highlighted = []
        for j, tok in enumerate(decoded_tokens):
            tok_id = tokenizer.convert_tokens_to_ids(tok)
            if tok_id == eos_token_id:
                continue
            token_str = tokenizer.convert_tokens_to_string([tok])
            abs_idx = answer_start + j
            if abs_idx in just_noised_indices:
                highlighted.append(f'<span style="color:red">{token_str}</span>')
            else:
                highlighted.append(token_str)

        yield f"<b>Iteration {i+1}/{max_it} (before noising):</b><br>" + "".join(highlighted).replace('\n', '<br>')
        time.sleep(pause_length)

    answer_ids = current_tokens[answer_start:]
    try:
        eos_index = answer_ids.index(eos_token_id)
        final_ids = answer_ids[:eos_index]
    except ValueError:
        final_ids = answer_ids

    num_tokens = len(final_ids)
    final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
    print(final_output)
    yield f"<b>Final Output ({num_tokens} tokens after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
# --- Gradio Interface ---
print("Loading model...")
model = load_model()
print("✅ Model loaded.")
demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
        gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="↑ = longer pause (for visualization)"),
        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="↓ = more noising (sharpness)"),
        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"),
        gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="↑ = more noise (noise start)"),
        gr.Checkbox(value=False, label="Use confidence-guided noising"),
        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
    ],
    outputs=[gr.HTML(label="Diffusion Output")],
    title="Diffusion Language Model Chat",
    theme="default",
    description="This interface runs a diffusion-based language model to generate answers progressively.",
)

demo.launch(share=True, allowed_paths=["."], ssr_mode=False)
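# Running this script locally requires an HF_TOKEN with access to the
# meta-llama/Llama-3.2-3B tokenizer (a gated repo) and to the ruurd/tini_model
# checkpoint, plus the local llama_diffusion_model module on the Python path.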