Spaces:
Running on Zero

lad / app.py
Ruurd's picture
Fix red highlighting
9756472
raw
history blame
9.99 kB
import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
import os
import importlib
from huggingface_hub import hf_hub_download
from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, BidirectionalLlamaAttention, disable_dropout
import spaces
hf_token = os.getenv("HF_TOKEN")
# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
# The tokenizer may define no pad token, in which case EOS is used for padding.
# Compare against None explicitly: a pad id of 0 is valid and must not fall back.
pad_token = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
# Token ids of the "Assistant:" marker that separates the prompt from the answer.
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
# --- Load token probabilities ---
# Background distribution used to sample noise tokens. The JSON maps stringified
# token ids ("0", "1", ...) to probabilities; keys are assumed to be contiguous.
with open("token_probabilities.json", encoding="utf-8") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array(
    [token_probs_dict[str(i)] for i in range(len(token_probs_dict))],
    dtype=np.float32,
)
def load_model():
    """Download the diffusion checkpoint from the Hub and return it ready for inference.

    The checkpoint is a whole pickled model object, not a state dict: the result
    of ``torch.load`` is passed to ``disable_dropout`` and has ``.eval()`` called
    on it. The model is placed on CUDA when available, otherwise on CPU.

    Returns:
        The loaded model in eval mode with dropout disabled.
    """
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=hf_token,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a full
    # module object, so opt out explicitly.
    # SECURITY: weights_only=False executes arbitrary pickle code. This is only
    # acceptable because the file comes from this project's own Hub repo; never
    # point this at an untrusted checkpoint.
    model = torch.load(ckpt_path, map_location=device, weights_only=False)
    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model
# Module-wide, unseeded random generator shared by the noising helpers below.
rng = np.random.Generator(np.random.PCG64())
# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    """Decode ids to text, dropping special tokens and replacing newlines with spaces."""
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    return " ".join(text.split("\n"))
def find_answer_start(input_ids, marker_ids):
    """Return the index just past the LAST occurrence of ``marker_ids`` in ``input_ids``.

    The chat prompt places the assistant marker after the user's text, so the
    final occurrence is the one that starts the answer. Scanning from the end
    keeps a marker that happens to appear inside the question from being
    mistaken for the answer boundary.

    Args:
        input_ids: token id list to search.
        marker_ids: token id list to look for.

    Returns:
        The position right after the last match, or None when there is no match.
    """
    width = len(marker_ids)
    for start in range(len(input_ids) - width, -1, -1):
        if input_ids[start:start + width] == marker_ids:
            return start + width
    return None
def get_noising_schedule(i, max_it, sharpness=5.0):
    """Fraction of the answer to re-noise at iteration ``i`` of ``max_it``.

    An exponential decay rescaled so it is exactly 1.0 at ``i == 0`` and 0.0 at
    ``i == max_it``; larger ``sharpness`` makes the fraction drop faster early on.
    """
    progress = i / max_it
    floor = np.exp(-sharpness)
    return (np.exp(-sharpness * progress) - floor) / (1 - floor)
def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, clustering=0.5):
    """Overwrite a fraction of the answer tokens with randomly sampled noise tokens.

    Positions are chosen as contiguous spans ("clusters") inside the answer region
    ``input_ids[answer_start:]``. Replacement ids are sampled from the global
    ``token_probabilities`` with the EOT token's probability scaled by
    ``eot_weight``.

    Args:
        input_ids: full token id sequence (prompt + answer); not modified.
        answer_start: index of the first answer token.
        threshold: fraction of the answer length to noise.
        eot_weight: multiplier applied to the EOT token's sampling probability.
        clustering: in [0, 1]; higher values give fewer, larger spans.

    Returns:
        ``(noised, noised_indices)``: a copy of ``input_ids`` with noise applied,
        and the sorted absolute positions that were overwritten.
    """
    noised = input_ids.copy()
    answer_len = len(noised) - answer_start
    num_to_noise = int(threshold * answer_len)
    # "<= 0" also covers answer_start beyond the sequence end (negative answer_len),
    # where rng.integers below would otherwise raise.
    if num_to_noise <= 0:
        return noised, []
    mixed_probs = token_probabilities.copy()
    mixed_probs[eot_token_id] *= eot_weight
    mixed_probs /= mixed_probs.sum()
    num_clusters = max(1, int((1 - clustering) * num_to_noise))
    cluster_size = max(1, int(num_to_noise / num_clusters))
    noised_indices = set()
    for _ in range(num_clusters):
        center = rng.integers(answer_start, len(noised))
        span_start = max(answer_start, center - cluster_size // 2)
        span_end = min(len(noised), span_start + cluster_size)
        noised_indices.update(range(span_start, span_end))
    # Overlapping spans may cover fewer than num_to_noise positions; otherwise trim.
    noised_indices = sorted(noised_indices)[:num_to_noise]
    # Passing the population size draws the same ids as np.arange(vocab_size)
    # without allocating the index array on every call.
    noise = rng.choice(vocab_size, size=len(noised_indices), p=mixed_probs)
    for idx, val in zip(noised_indices, noise):
        noised[idx] = int(val)
    return noised, noised_indices
# Confidence-guided alternative to noisify_answer: positions the model was least
# sure about are the most likely to be re-noised.
def confidence_guided_noising(input_ids, answer_start, confidences, threshold, eot_weight, noise_clipping):
    """Re-noise answer tokens, sampling positions by ``1 - confidence``.

    Args:
        input_ids: full token id sequence (prompt + answer); not modified.
        answer_start: index of the first answer token.
        confidences: per-position probability of the sampled token, aligned with
            ``input_ids`` (only ``confidences[answer_start:]`` is used).
        threshold: fraction of the answer length to noise.
        eot_weight: multiplier applied to the EOT token's sampling probability.
        noise_clipping: lower bound on each selection weight ``1 - confidence``.
            A value of 1 makes every answer position equally likely; values near 0
            follow the confidences closely.

    Returns:
        A copy of ``input_ids`` with noise applied. Unlike noisify_answer, the
        noised positions are not returned.
    """
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len)
    # "<= 0" also covers answer_start beyond the sequence end (negative answer_len),
    # where rng.choice below would be asked for a negative sample size.
    if num_to_noise <= 0:
        return noised
    raw_weights = 1.0 - np.asarray(confidences[answer_start:])
    # Clip from below so every position keeps a non-zero chance of being chosen.
    raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)
    weights = raw_weights / raw_weights.sum()
    # Positions are drawn without replacement, so never request more than exist.
    num_to_noise = min(num_to_noise, len(weights))
    indices = rng.choice(
        np.arange(answer_start, len(input_ids)),
        size=num_to_noise,
        replace=False,
        p=weights,
    )
    mixed_probs = token_probabilities.copy()
    mixed_probs[eot_token_id] *= eot_weight
    mixed_probs /= mixed_probs.sum()
    # Passing the population size draws the same ids as np.arange(vocab_size)
    # without allocating the index array on every call.
    noise = rng.choice(vocab_size, size=num_to_noise, p=mixed_probs)
    for idx, val in zip(indices, noise):
        noised[int(idx)] = int(val)
    return noised
@spaces.GPU
def generate_diffusion_text(input_ids, answer_start):
    """Run one denoising step: predict every position and sample a token for each.

    Args:
        input_ids: full token id sequence (prompt + current, possibly noised, answer).
        answer_start: not used here; kept so existing call sites stay valid.

    Returns:
        ``(sampled, conf)``: the sampled token ids as a Python list (one per input
        position), and a NumPy array with the probability of each sampled token.
    """
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        # Batch of one: take row 0 explicitly. squeeze() would also drop the sequence
        # axis for a length-1 input. Assumes logits are (batch, seq_len, vocab) —
        # TODO confirm against CustomTransformerModel.
        # Softmax in float32 so reduced-precision (e.g. bfloat16) logits can still be
        # converted to NumPy below.
        probs = torch.nn.functional.softmax(logits[0].float(), dim=-1)
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        sampled_ids = torch.multinomial(probs, num_samples=1)  # (seq_len, 1)
        # Confidence = probability the model assigned to the token actually sampled.
        conf = probs.gather(-1, sampled_ids).squeeze(-1).cpu().numpy()
    return sampled_ids.squeeze(-1).tolist(), conf
# --- Inference Wrapper ---
# Fixed context length: the prompt is padded or truncated to this many tokens.
_SEQ_LEN = 256


def _text_to_html(text):
    """Escape model text for the gr.HTML output and turn newlines into <br> tags."""
    import html

    return html.escape(text).replace("\n", "<br>")


def _tokens_to_html(token_ids, highlight, color, skip_ids=()):
    """Render ``token_ids`` token by token, colouring the positions in ``highlight``.

    Args:
        token_ids: ids to render (the answer region).
        highlight: positions, relative to ``token_ids``, wrapped in a coloured span.
        color: CSS colour used for highlighted tokens.
        skip_ids: ids left out of the rendering entirely.
    """
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    pieces = []
    for j, (tok_id, tok) in enumerate(zip(token_ids, tokens)):
        if tok_id in skip_ids:
            continue
        text = _text_to_html(tokenizer.convert_tokens_to_string([tok]))
        if j in highlight:
            text = f'<span style="color:{color}">{text}</span>'
        pieces.append(text)
    return "".join(pieces)


def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_confidence_noising, clustering):
    """Generate an answer by iterative denoising, streaming HTML snapshots to the UI.

    Each iteration does four things:
      * runs the model on the current (noised) sequence;
      * yields the prediction, with tokens that changed since the previous
        prediction in green;
      * re-noises part of the answer following get_noising_schedule;
      * yields that noised state, with the re-noised tokens in red.
    The loop stops after ``max_it`` iterations, or earlier once three consecutive
    predictions are identical.

    Args (Gradio inputs, in UI order):
        question: user question; a blank value falls back to a placeholder.
        eot_weight: multiplier for the EOT token's noise-sampling probability.
        max_it: maximum number of denoising iterations (at least one is run).
        sharpness: decay rate passed to get_noising_schedule.
        noise_clipping: lower bound for confidence-guided selection weights.
        use_confidence_noising: use confidence_guided_noising instead of noisify_answer.
        clustering: span clustering strength passed to noisify_answer.

    Yields:
        HTML strings for the gr.HTML output. Model text is HTML-escaped.
    """
    from collections import deque

    placeholder = "What do you know about the city of New York?"
    if not question or not question.strip():
        question = placeholder
    print('started generation')
    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return
    if answer_start >= _SEQ_LEN:
        # The prompt alone fills the context, leaving no positions for an answer.
        yield "Error: Question is too long to leave room for an answer."
        return
    # Pad with pad_token (or truncate) to the fixed context length.
    input_ids = (input_ids + [pad_token] * _SEQ_LEN)[:_SEQ_LEN]
    max_it = max(1, int(max_it))

    # Start from a fully noised answer region.
    current_tokens, _ = noisify_answer(
        input_ids, answer_start, threshold=1.0, eot_weight=eot_weight, clustering=clustering
    )
    generated_tokens = current_tokens
    prev_answer_ids = []
    recent_outputs = deque(maxlen=3)
    iterations = 0
    for i in range(max_it):
        iterations = i + 1
        print('Generating output')
        generated_tokens, confidences = generate_diffusion_text(current_tokens, answer_start)

        # --- Show the prediction; tokens changed since the last one in GREEN ---
        answer_ids = generated_tokens[answer_start:]
        changed = {
            j for j, (new, old) in enumerate(zip(answer_ids, prev_answer_ids)) if new != old
        }
        prev_answer_ids = answer_ids
        yield (
            f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>"
            + _tokens_to_html(answer_ids, changed, "green")
        )
        time.sleep(0.1)

        # --- Re-noise part of the answer and show the noised tokens in RED ---
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        if use_confidence_noising:
            current_tokens = confidence_guided_noising(
                generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping
            )
            just_noised = set()  # confidence_guided_noising does not report its positions
        else:
            current_tokens, noised_indices = noisify_answer(
                generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering
            )
            just_noised = {idx - answer_start for idx in noised_indices}
        yield (
            f"<b>Iteration {i+1}/{max_it} (after noising):</b><br>"
            + _tokens_to_html(current_tokens[answer_start:], just_noised, "red", skip_ids=(eot_token_id,))
        )
        time.sleep(0.1)

        # --- Early stopping: three identical predictions in a row ---
        recent_outputs.append(generated_tokens)
        if len(recent_outputs) == 3 and recent_outputs[0] == recent_outputs[1] == recent_outputs[2]:
            yield f"<b>Stopped early after {i+1} iterations.</b>"
            break

    # Report the last model prediction, not the re-noised state that follows it.
    final_ids = generated_tokens[answer_start:]
    final_tokens = [
        tok
        for tok_id, tok in zip(final_ids, tokenizer.convert_ids_to_tokens(final_ids))
        if tok_id != eot_token_id
    ]
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
    print(final_output)
    yield f"<b>Final Output (after {iterations} iterations):</b><br>" + _text_to_html(final_output)
# --- Gradio Interface ---
def _build_demo():
    """Create the Gradio Interface that wraps diffusion_chat.

    The input components must stay in the same order as diffusion_chat's
    positional parameters: question, eot_weight, max_it, sharpness,
    noise_clipping, use_confidence_noising, clustering.
    """
    question_box = gr.Textbox(
        label="User Question",
        lines=2,
        placeholder="What do you know about the city of New York?",
    )
    eot_slider = gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)")
    iterations_slider = gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations")
    sharpness_slider = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)")
    clipping_slider = gr.Slider(
        0.01, 1.0, value=0.05, step=0.01, label="↓ = more confidence guidance (noise clipping)"
    )
    confidence_toggle = gr.Checkbox(value=False, label="Use confidence-guided noising")
    clustering_slider = gr.Slider(
        0.0, 1.0, value=0.5, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"
    )
    return gr.Interface(
        fn=diffusion_chat,
        inputs=[
            question_box,
            eot_slider,
            iterations_slider,
            sharpness_slider,
            clipping_slider,
            confidence_toggle,
            clustering_slider,
        ],
        outputs=[gr.HTML(label="Diffusion Output")],
        title="Diffusion Language Model Chat",
        theme="default",
        description="This interface runs a diffusion-based language model to generate answers progressively.",
    )


print("Loading model...")
model = load_model()
print("✅ Model loaded.")
demo = _build_demo()
# NOTE(review): allowed_paths=["."] lets Gradio serve files from the whole working
# directory, yet this UI only returns HTML — confirm the setting is actually needed.
demo.launch(share=True, allowed_paths=["."], ssr_mode=False)