import gradio as gr
import torch
import numpy as np
import json
import html
import time
from transformers import AutoTokenizer
import os
import importlib
from huggingface_hub import hf_hub_download
import spaces
from dotenv import load_dotenv

from infer import (
    load_trained_model,
    find_answer_start,
    get_noising_schedule,
    noisify_answer,
    generate_diffusion_text,
    filter_logits,
)
from models import CustomTransformerModel
from model_config import CustomTransformerConfig

# Load .env only when running locally
if os.getenv("HF_TOKEN") is None:
    load_dotenv()
hf_token = os.getenv("HF_TOKEN")
if hf_token is None:
    raise ValueError("HF_TOKEN is not set")

rng = np.random.default_rng()

# Prompts are padded with MASK tokens (or truncated) to exactly this many tokens.
MAX_SEQ_LEN = 256


def _sample_low_confidence(candidates, budget, confidences, conf_offset, noise_clipping):
    """Draw up to `budget` distinct indices from `candidates`, favouring low-confidence ones.

    The sampling weight of index i is ``1 - confidences[i - conf_offset]``, floored
    at `noise_clipping` so that confident tokens can still be chosen.
    """
    if not candidates or budget <= 0:
        return []
    weights = 1.0 - np.array([confidences[i - conf_offset] for i in candidates], dtype=np.float64)
    # The tiny epsilon keeps every weight (and the sum) positive even if noise_clipping is 0;
    # rng.choice(replace=False) needs at least `size` non-zero probabilities.
    weights = np.clip(weights, a_min=max(noise_clipping, 1e-8), a_max=None)
    weights /= weights.sum()
    chosen = rng.choice(candidates, size=min(budget, len(candidates)), replace=False, p=weights)
    return chosen.tolist()


def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
    """Replace answer tokens with MASK, preferring tokens the model was least confident about.

    Args:
        input_ids: Full token sequence (prompt + answer) as a list of ids.
        answer_start: Index of the first answer token in `input_ids`.
        confidences: Confidence of each token from the latest generation step. Either one
            value per position of `input_ids` (what `generate_diffusion_text` in this file
            returns) or one value per answer position.
        noise_clipping: Lower bound on the sampling weight ``1 - confidence``; smaller values
            make the choice follow the confidences more closely.
        threshold: Fraction of the answer to mask (multiplied by `noise_start`).
        noise_start: Additional scale factor on the number of masked tokens.

    Returns:
        A tuple ``(noised, noised_indices)``: a copy of `input_ids` with the chosen positions
        set to `mask_token_id`, and the sorted absolute indices that were masked.
    """
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len * noise_start)
    if num_to_noise <= 0:
        return noised, []

    # generate_diffusion_text yields one confidence per absolute position; fall back to
    # answer-relative indexing for callers that pass only the answer's confidences.
    conf_offset = 0 if len(confidences) == len(input_ids) else answer_start

    answer_positions = range(answer_start, len(input_ids))
    eos_indices = [i for i in answer_positions if input_ids[i] == eos_token_id]
    non_eos_indices = [i for i in answer_positions if input_ids[i] != eos_token_id]

    # Split the masking budget between non-EOS and EOS tokens in proportion to their counts.
    num_non_eos_to_noise = int(
        num_to_noise * len(non_eos_indices) / (len(non_eos_indices) + len(eos_indices) + 1e-5)
    )
    num_eos_to_noise = num_to_noise - num_non_eos_to_noise

    noised_indices = _sample_low_confidence(
        non_eos_indices, num_non_eos_to_noise, confidences, conf_offset, noise_clipping
    )
    noised_indices += _sample_low_confidence(
        eos_indices, num_eos_to_noise, confidences, conf_offset, noise_clipping
    )

    for idx in noised_indices:
        noised[idx] = mask_token_id
    return noised, sorted(noised_indices)


# NOTE: deliberately shadows infer.generate_diffusion_text so the GPU work runs under @spaces.GPU.
@spaces.GPU
def generate_diffusion_text(input_ids, top_p, top_k):
    """Run one denoising pass of the model and sample a token for every position.

    Args:
        input_ids: Token ids of the full (partially masked) sequence.
        top_p: Nucleus-sampling threshold, forwarded to `filter_logits` as ``top_p``.
        top_k: Top-k cutoff, forwarded to `filter_logits` as ``top_k``.

    Returns:
        ``(sampled, conf)``: a list with one sampled token id per input position, and a
        numpy array with the probability of each sampled token at that position.
    """
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        with torch.amp.autocast('cuda', dtype=torch.float16):
            logits = model(input_ids=input_tensor)["logits"]
        # Each knob goes to its namesake (these were previously passed swapped).
        logits = filter_logits(logits, top_k=top_k, top_p=top_p)
        logits = logits.clamp(min=-1e8, max=1e4)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        # Floor probabilities for numerical safety; multinomial does not require normalized rows.
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
        assert (probs >= 0).all(), "Negative probs!"
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()

    # Confidence of each selected token, indexed by absolute sequence position.
    conf = probs[range(len(sampled)), sampled].cpu().numpy()
    return sampled, conf


def format_chat_prompt(question):
    """Wrap `question` in the chat template expected by the model."""
    return (
        "<|begin_of_text|>\n"
        "<|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant.\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        f"{question}\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )


def _to_html(text):
    """Escape `text` for gr.HTML output and turn newlines into <br> line breaks."""
    return html.escape(text).replace("\n", "<br>")


def _render_answer(answer_ids, marked, color):
    """Render answer token ids as HTML, coloring the positions in `marked`.

    `marked` holds positions relative to the start of the answer. EOS tokens are omitted.
    """
    pieces = []
    tokens = tokenizer.convert_ids_to_tokens(answer_ids)
    for j, (tok_id, tok) in enumerate(zip(answer_ids, tokens)):
        if tok_id == eos_token_id:
            continue
        piece = html.escape(tokenizer.convert_tokens_to_string([tok]))
        if j in marked:
            piece = f'<span style="color:{color}">{piece}</span>'
        pieces.append(piece)
    return "".join(pieces).replace("\n", "<br>")


# --- Inference Wrapper ---
def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_start,
                   use_confidence_noising, noise_clipping, top_p, top_k):
    """Generate an answer by iterative denoising, yielding an HTML frame after each step.

    Each iteration samples every answer position (`generate_diffusion_text`), shows the
    tokens that changed in green, then re-masks part of the answer (via
    `confidence_guided_noising` or `noisify_answer`) and shows the tokens about to be
    masked in red. Stops early when three consecutive generations are identical.
    The last frame is the decoded answer, cut at the first EOS token.
    """
    placeholder = "What do you know about the city of Amsterdam?"
    if not question.strip():
        question = placeholder

    print('started generation')
    prompt = format_chat_prompt(question)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return
    if answer_start >= MAX_SEQ_LEN:
        yield f"Error: The question is too long; it leaves no room for an answer within {MAX_SEQ_LEN} tokens."
        return

    # Pad the answer region with MASK tokens (or truncate) to the fixed sequence length.
    if len(input_ids) < MAX_SEQ_LEN:
        input_ids += [mask_token_id] * (MAX_SEQ_LEN - len(input_ids))
    else:
        input_ids = input_ids[:MAX_SEQ_LEN]
    ori_input_tokens = input_ids

    current_tokens, just_noised_indices = noisify_answer(
        input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0,
    )
    yield "Iteration 0 (initial noise):<br>" + _to_html(
        tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True)
    )
    time.sleep(pause_length)

    max_it = int(max_it)
    last_tokens = []
    prev_answer_ids = []
    # Last denoised sequence; the final answer is decoded from this, never from a noised one.
    final_tokens = current_tokens
    iterations_done = 0
    generation_start = time.time()

    for i in range(max_it):
        iterations_done = i + 1
        print('Generating output')

        # Model step
        generated_tokens, confidences = generate_diffusion_text(current_tokens, top_p=top_p, top_k=top_k)
        remaining = pause_length - (time.time() - generation_start)
        if remaining > 0:
            time.sleep(remaining)

        # Keep the prompt fixed; take the model's output for the answer region.
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]
        final_tokens = current_tokens

        # --- GREEN HIGHLIGHT: tokens that changed since the previous generation ---
        answer_ids = current_tokens[answer_start:]
        changed = {j for j, (new, old) in enumerate(zip(answer_ids, prev_answer_ids)) if new != old}
        prev_answer_ids = answer_ids
        yield f"Iteration {i+1}/{max_it} (after generation):<br>" + _render_answer(answer_ids, changed, "green")
        time.sleep(pause_length)

        # --- Early stopping: three identical generations in a row ---
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"Stopped early after {i+1} iterations."
            break

        # --- NOISING STEP ---
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        if use_confidence_noising:
            noised_answer, just_noised_indices = confidence_guided_noising(
                current_tokens, answer_start, confidences, noise_clipping,
                threshold=threshold, noise_start=noise_start,
            )
        else:
            noised_answer, just_noised_indices = noisify_answer(
                current_tokens, answer_start, tokenizer,
                threshold=threshold, clustering=clustering, noise_start=noise_start,
            )

        # --- RED HIGHLIGHT: tokens about to be masked ---
        to_be_noised = {idx - answer_start for idx in just_noised_indices}
        red_html = _render_answer(answer_ids, to_be_noised, "red")

        # Compose full input again: prompt + noised answer
        current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]
        yield f"Iteration {i+1}/{max_it} (before noising):<br>" + red_html
        generation_start = time.time()

    answer_ids = final_tokens[answer_start:]
    try:
        final_ids = answer_ids[:answer_ids.index(eos_token_id)]
    except ValueError:
        final_ids = answer_ids
    num_tokens = len(final_ids)
    final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
    print(final_output)
    yield f"Final Output ({num_tokens} tokens after {iterations_done} iterations):<br>" + _to_html(final_output)


# --- Gradio Interface ---
print("Loading model...")
ckpt_path = hf_hub_download(
    repo_id="ruurd/tini_model",
    filename="diffusion-model.pth",
    token=hf_token,
)
model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
print("✅ Model loaded.")

vocab_size = len(tokenizer)
eos_token_id = tokenizer.eos_token_id
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False)

demo = gr.Interface(
    fn=diffusion_chat,
    # Passed positionally: the order must match diffusion_chat's parameters.
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
        gr.Slider(1, 512, value=64, step=1, label="Number of iterations: ↑ = more iterations"),
        gr.Slider(0.01, 5, value=0.01, step=0.01, label="Pause between iteration ↑ = longer pause"),
        gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Noise decay sharpness: ↓ = more noise in later iterations"),
        gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Clustering: ↑ = more clustered noising"),
        gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Noise start fraction: ↑ = more noise"),
        gr.Checkbox(value=False, label="Use confidence-guided noising"),
        gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="Noise clipping: ↓ = more confidence guidance"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p: ↑ = more random answers"),
        gr.Slider(1, 1000, value=100, step=1, label="Top-k: ↑ = more random answers"),
    ],
    outputs=[gr.HTML(label="Diffusion Output")],
    title="Diffusion Language Model Chat",
    theme="default",
    description="This interface runs a diffusion-based language model to generate answers progressively.",
)

demo.launch(share=True, allowed_paths=["."], ssr_mode=False)