import contextlib
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

rng = np.random.default_rng()


def disable_dropout(model):
    """Replace every nn.Dropout in the model with nn.Identity (in place)."""
    dropout_names = [name for name, module in model.named_modules() if isinstance(module, nn.Dropout)]
    for name in dropout_names:
        # setattr must target the *parent* module; using the dotted name on the
        # root model would only add a new attribute instead of swapping the child.
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, nn.Identity())
    return model


def load_trained_model(checkpoint_path: str, base_model_name: str = "meta-llama/Llama-3.2-3B"):
    # Load the tokenizer for the base model (HF_TOKEN is needed for gated repos)
    hf_token = os.getenv("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, use_fast=True, token=hf_token)

    # Load the pickled model from the checkpoint
    model = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=False)

    # Disable dropout for deterministic inference
    model = disable_dropout(model)
    print("✅ Model successfully loaded from checkpoint:", checkpoint_path)

    # Move to the best available device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    model.to(device)
    model.eval()
    return model, tokenizer


def filter_logits(logits, top_k=0, top_p=1.0, temperature=1.0):
    """
    Vectorized top-k and/or top-p (nucleus) filtering with temperature scaling.
    Accepts logits of shape (seq_len, vocab_size) or (1, seq_len, vocab_size),
    and returns logits in the same shape.
    """
    original_shape = logits.shape
    if logits.dim() == 3:
        logits = logits.squeeze(0)  # shape: (seq_len, vocab_size)
    logits = logits.clone()

    # --- Temperature scaling ---
    if temperature != 1.0:
        logits = logits / temperature

    # --- Top-k filtering ---
    if 0 < top_k < logits.size(-1):
        topk_vals, _ = torch.topk(logits, top_k, dim=-1)
        thresholds = topk_vals[:, -1].unsqueeze(-1)
        logits = torch.where(logits < thresholds, torch.full_like(logits, float("-inf")), logits)

    # --- Top-p filtering ---
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        cum_probs = probs.cumsum(dim=-1)
        mask = cum_probs > top_p
        mask[:, 0] = False  # always keep the top token
        scatter_mask = torch.zeros_like(logits, dtype=torch.bool).scatter(dim=-1, index=sorted_indices, src=mask)
        logits = torch.where(scatter_mask, torch.full_like(logits, float("-inf")), logits)

    # Restore the original (batched) shape
    if len(original_shape) == 3:
        logits = logits.unsqueeze(0)
    return logits


def decode_tokens_safe(tokenizer, token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")


def find_answer_start(input_ids, marker_ids):
    """Return the index right after the first occurrence of marker_ids, or None."""
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None


def noisify_answer(input_ids, answer_start, threshold=1.0, is_unmasked=None, mask_token_id=128002):
    """Randomly replace tokens in the answer region with the mask token."""
    noised = input_ids.copy()
    total_len = len(input_ids)
    candidates = [
        i for i in range(answer_start, total_len)
        if is_unmasked is None or not is_unmasked[i]
    ]
    # Number of positions to mask, proportional to the full sequence length
    num_to_add = int(threshold * total_len)
    if num_to_add > 0 and len(candidates) > 0:
        newly_masked = rng.choice(candidates, size=min(num_to_add, len(candidates)), replace=False)
        for idx in newly_masked:
            noised[idx] = mask_token_id
    return noised


def get_noising_schedule(i, max_it, sharpness=5.0):
    """Exponential decay from 1.0 (i = 0) to 0.0 (i = max_it)."""
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
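
# Quick self-check for the sampling filter (an illustrative sketch, not part of the
# generation pipeline): with random logits and no ties, top-k filtering should leave
# exactly `top_k` finite entries per row and -inf everywhere else.
def _check_filter_logits(top_k: int = 3, seq_len: int = 4, vocab_size: int = 10):
    dummy_logits = torch.randn(seq_len, vocab_size)
    filtered = filter_logits(dummy_logits, top_k=top_k)
    kept_per_row = torch.isfinite(filtered).sum(dim=-1)
    assert bool((kept_per_row == top_k).all()), "top-k filtering kept an unexpected number of tokens"
    return filtered
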
def generate_diffusion_text(model, input_ids, answer_start, top_k=0, top_p=1.0, temperature=1.0,
                            eos_token_id=None, eos_boost=0.0):
    """Run one denoising pass: predict logits for every position and resample the answer region."""
    model.eval()
    # Autocast to bfloat16 only on CUDA; fall back to a no-op context on CPU/MPS
    autocast_ctx = (
        torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        if model.device.type == "cuda"
        else contextlib.nullcontext()
    )
    with torch.no_grad(), autocast_ctx:
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]  # (1, seq_len, vocab_size)

        # Optionally boost or suppress the EOS token
        if eos_token_id is not None and eos_boost != 0.0:
            logits[:, :, eos_token_id] += eos_boost

        # Filter and sample one token per position
        filtered_logits = filter_logits(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        probs = F.softmax(filtered_logits, dim=-1).squeeze(0)  # (seq_len, vocab_size)
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
        confidences = probs.gather(1, sampled.unsqueeze(-1)).squeeze(-1)

    # Keep the prompt untouched; only the answer region is replaced by sampled tokens
    return input_ids[:answer_start] + sampled[answer_start:].tolist(), confidences


def calculate_answer_perplexity(prompt, answer, model_name="gpt2-large"):
    """Score the answer with an external causal LM; prompt tokens are excluded from the loss."""
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)

    full_input = prompt + answer
    enc = tokenizer(full_input, return_tensors="pt")
    input_ids = enc.input_ids.to(device)

    with torch.no_grad():
        labels = input_ids.clone()
        prompt_len = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
        labels[0, :prompt_len] = -100  # ignore the prompt when computing the loss
        loss = model(input_ids, labels=labels).loss
    return torch.exp(loss).item()
def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5, noising_sharpness=5.0,
                    max_length=256, top_k=100, top_p=1.0, temperature=1.0,
                    eos_token_id=None, eos_boost=0.0) -> str:
    """Iteratively denoise a fully masked answer region until it converges or max_it is reached."""
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    # Format the prompt with the LLaMA 3 chat template
    prompt = (
        "<|begin_of_text|>\n"
        "<|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant.\n"
        "<|eot_id|>\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        f"{question.strip()}\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    marker = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)

    answer_start = find_answer_start(input_ids, marker)
    if answer_start is None:
        raise ValueError("Assistant marker not found in prompt.")

    # Pad the sequence to max_length with the mask token (first sub-token of "MASK")
    mask_token = tokenizer.encode("MASK", add_special_tokens=False)[0]
    input_ids = input_ids[:max_length]
    if len(input_ids) < max_length:
        input_ids += [mask_token] * (max_length - len(input_ids))

    # Start from a fully masked answer region
    ori_tokens = input_ids
    current_tokens = noisify_answer(ori_tokens, answer_start, threshold=1.0, mask_token_id=mask_token)

    last_tokens = []
    for step in range(max_it):
        # Generate a new prediction for every position in the answer region
        current_tokens, confidence_scores = generate_diffusion_text(
            model, current_tokens, answer_start,
            top_k=top_k, top_p=top_p, temperature=temperature,
            eos_token_id=eos_token_id, eos_boost=eos_boost
        )

        # Display for debugging / tracking
        display_diffusion_output(
            step, max_it, question, ori_tokens, current_tokens,
            confidence_scores, answer_start, tokenizer
        )

        # Early stopping: break once the last four iterations produced identical output
        last_tokens.append(current_tokens)
        if len(last_tokens) > 4:
            last_tokens.pop(0)
            if all(t == last_tokens[0] for t in last_tokens):
                break

        # Re-apply noise for the next iteration, following the annealed schedule
        if step < max_it - 1:
            threshold = noise_start * get_noising_schedule(step, max_it, sharpness=noising_sharpness)
            current_tokens = noisify_answer(current_tokens, answer_start,
                                            threshold=threshold, mask_token_id=mask_token)

    return tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).strip()
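

# Example usage (a minimal sketch: the checkpoint path and question below are
# placeholders, and display_diffusion_output must be available in this module):
if __name__ == "__main__":
    model, tokenizer = load_trained_model("checkpoints/diffusion_llama.pt")  # hypothetical path
    question = "What is the capital of France?"
    answer = generate_answer(
        question,
        model,
        tokenizer,
        max_it=16,
        top_k=100,
        temperature=1.0,
    )
    print("Answer:", answer)
    # Optionally score the generated answer with an external causal LM
    print("Perplexity:", calculate_answer_perplexity(question + " ", answer))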