Spaces: Running on Zero

First commit

- app.py +158 -0
- llama_diffusion_model.py +134 -0
- requirements.txt +7 -0
- token_probabilities.json +0 -0
app.py
ADDED
@@ -0,0 +1,158 @@
import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, disable_dropout
import os

hf_token = os.getenv("HF_TOKEN")


# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)

# --- Load token probabilities ---
with open("token_probabilities.json") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)


def load_model():
    config = CustomTransformerConfig(vocab_size=vocab_size)
    model = CustomTransformerModel(config)
    model.load_state_dict(torch.hub.load_state_dict_from_url(
        "https://huggingface.co/Ruurd/tini_model/resolve/main/diffusion-model.pth",
        map_location="cuda",
        headers={"Authorization": f"Bearer {hf_token}"}
    ))
    model = disable_dropout(model)
    model.to("cuda")
    model.eval()
    return model

rng = np.random.default_rng()

# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")

def find_answer_start(input_ids, marker_ids):
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None

def get_noising_schedule(i, max_it, sharpness=5.0):
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))

def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len)
    if num_to_noise > 0:
        indices = rng.choice(np.arange(answer_start, len(input_ids)), size=num_to_noise, replace=False)

        mixed_probs = token_probabilities.copy()
        mixed_probs[eot_token_id] *= eot_weight
        mixed_probs /= mixed_probs.sum()

        noise = rng.choice(np.arange(vocab_size), size=num_to_noise, p=mixed_probs)
        for idx, val in zip(indices, noise):
            noised[idx] = val
    return noised

def generate_diffusion_text(model, input_ids, answer_start):
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
    return input_ids[:answer_start] + sampled[answer_start:]

# --- Inference Wrapper ---
def diffusion_chat(question, eot_weight, max_it, sharpness, model):
    placeholder = "What do you know about the city of New York?"
    if question.strip() == "":
        question = placeholder

    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return

    if len(input_ids) < 256:
        input_ids += [pad_token] * (256 - len(input_ids))
    else:
        input_ids = input_ids[:256]

    ori_input_tokens = input_ids
    current_tokens = noisify_answer(ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight)
    prev_decoded_tokens = []
    last_tokens = []

    for i in range(max_it):
        generated_tokens = generate_diffusion_text(model, current_tokens, answer_start)
        current_tokens = generated_tokens

        decoded_ids = current_tokens[answer_start:]
        decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
        filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
        filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []

        if filtered_prev_tokens:
            highlighted = []
            for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
                if tok_new != tok_old:
                    highlighted.append(f'<span style="color:green">{tokenizer.convert_tokens_to_string([tok_new])}</span>')
                else:
                    highlighted.append(tokenizer.convert_tokens_to_string([tok_new]))
        else:
            highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]

        prev_decoded_tokens = decoded_tokens
        yield f"<b>Iteration {i+1}/{max_it} (running):</b><br>" + "".join(highlighted)

        last_tokens.append(generated_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"<b>Stopped early after {i+1} iterations.</b>"
            break

        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
        time.sleep(0.01)

    final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
    final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
    yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output

# --- Gradio Interface ---
model_state = gr.State(load_model())

demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
        gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
        gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
        gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
        model_state
    ],
    outputs=gr.HTML(label="Diffusion Output"),
    title="Diffusion Language Model Chat",
    description="This interface runs a diffusion-based language model to generate answers progressively."
)

demo.launch()
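Note on the re-noising loop: between iterations, diffusion_chat re-noises a shrinking fraction of the answer tokens according to the exponential decay in get_noising_schedule. A minimal standalone sketch (the formula is copied from app.py above; the iteration counts chosen here are just for illustration) of how that fraction falls over the default 64 iterations at sharpness 5.0:

import numpy as np

def get_noising_schedule(i, max_it, sharpness=5.0):
    # Same decay as app.py: ~1.0 at the first iteration (re-noise nearly the
    # whole answer), approaching 0.0 by the last iteration.
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))

max_it = 64
for i in (0, 8, 16, 32, 63):
    print(f"iteration {i:2d}: re-noise fraction ~ {get_noising_schedule(i, max_it):.3f}")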
llama_diffusion_model.py
ADDED
@@ -0,0 +1,134 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast
from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
from peft import LoraConfig, get_peft_model
import os

hf_token = os.getenv("HF_TOKEN")

class BidirectionalLlamaAttention(nn.Module):
    def __init__(self, original_layer, masking='unidirectional'):
        super().__init__()
        self.original = original_layer
        self.masking = masking

        self.q_proj = original_layer.q_proj
        self.k_proj = original_layer.k_proj
        self.v_proj = original_layer.v_proj
        self.o_proj = original_layer.o_proj

        self.head_dim = self.q_proj.out_features // original_layer.num_heads
        self.num_heads = original_layer.num_heads
        self.num_key_value_groups = original_layer.num_key_value_groups
        self.attention_dropout = original_layer.attention_dropout
        self.layer_idx = original_layer.layer_idx
        self.scaling = original_layer.scaling

    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, cache_position=None, **kwargs):
        bsz, seq_len, _ = hidden_states.size()

        query_states = self._split_heads(self.q_proj(hidden_states))
        key_states = self._split_heads(self.k_proj(hidden_states))
        value_states = self._split_heads(self.v_proj(hidden_states))

        cos, sin = position_embeddings
        query_states, key_states = self._apply_rotary(query_states, key_states, cos, sin)

        if self.masking == 'bidirectional':
            attn_mask = torch.ones((bsz, 1, seq_len, seq_len), device=hidden_states.device)
        else:
            attn_mask = torch.tril(torch.ones(seq_len, seq_len, device=hidden_states.device)).unsqueeze(0).unsqueeze(0)

        attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
        attn_weights = attn_weights + attn_mask.log()
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        return self.o_proj(attn_output), attn_weights

    def _split_heads(self, x):
        return x.view(x.size(0), x.size(1), self.num_heads, self.head_dim).transpose(1, 2)

    def _merge_heads(self, x):
        return x.transpose(1, 2).contiguous().view(x.size(0), -1, self.num_heads * self.head_dim)

    def _apply_rotary(self, q, k, cos, sin):
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        q_rot = (q * cos) + (self._rotate_half(q) * sin)
        k_rot = (k * cos) + (self._rotate_half(k) * sin)
        return q_rot, k_rot

    def _rotate_half(self, x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)


class CustomTransformerConfig(PretrainedConfig):
    def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0, max_position_embeddings=4096, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.prediction_chunk = prediction_chunk
        self.max_position_embeddings = max_position_embeddings


class CustomTransformerModel(PreTrainedModel):
    config_class = CustomTransformerConfig

    def __init__(self, config):
        super().__init__(config)

        self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, token=hf_token)
        self.llama.resize_token_embeddings(config.vocab_size)

        for i, layer in enumerate(self.llama.model.layers):
            layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking='bidirectional')

        for param in self.llama.parameters():
            param.requires_grad = False

        for param in self.llama.lm_head.parameters():
            param.requires_grad = True

        lora_config = LoraConfig(
            r=256,
            lora_alpha=256,
            lora_dropout=0.0,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            bias="none",
            task_type=None
        )

        self.llama = get_peft_model(self.llama, lora_config)
        self.llama = self.llama.to(torch.float16)

    def forward(self, input_ids, labels=None, **kwargs):
        batch_size, seq_length = input_ids.shape
        assert seq_length == self.config.prediction_chunk

        with autocast("cuda", dtype=torch.float16):
            outputs = self.llama(input_ids=input_ids, output_hidden_states=True, **kwargs)
            logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, self.config.prediction_chunk, self.config.vocab_size)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}


def disable_dropout(model):
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            setattr(model, name, nn.Identity())
    return model
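For intuition on the attention patch above: the only difference between the 'bidirectional' branch and the default causal branch is the mask shape (an all-ones matrix versus a lower-triangular one), and adding mask.log() to the scores turns disallowed positions into -inf before the softmax. A minimal standalone sketch of just that masking step, with dummy scores rather than the model's real query/key products:

import torch

seq_len = 4
scores = torch.zeros(seq_len, seq_len)  # stand-in for q @ k^T * scaling

causal_mask = torch.tril(torch.ones(seq_len, seq_len))   # position i sees <= i
bidirectional_mask = torch.ones(seq_len, seq_len)        # every position sees all

# log(1) = 0 leaves allowed positions unchanged; log(0) = -inf zeroes them out
# after the softmax.
causal_weights = torch.softmax(scores + causal_mask.log(), dim=-1)
bidirectional_weights = torch.softmax(scores + bidirectional_mask.log(), dim=-1)

print(causal_weights)         # row i attends only to positions 0..i
print(bidirectional_weights)  # every row attends uniformly to all positions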
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch>=2.0.0
transformers>=4.38.0
datasets>=2.16.0
peft>=0.8.2
accelerate>=0.24.1
gradio>=4.10.0
numpy
token_probabilities.json
ADDED
The diff for this file is too large to render.
See raw diff
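The file's contents are not rendered here, but from the way app.py consumes it, it is expected to be a JSON object mapping each token id (as a string key, "0" through str(vocab_size - 1)) to a sampling probability over the full tokenizer vocabulary. A hedged sketch of writing a placeholder file in that format, using uniform probabilities (the committed file presumably encodes real corpus token frequencies instead):

import json

vocab_size = 128256  # Llama-3.2 tokenizer size assumed by CustomTransformerConfig

# Uniform stand-in: app.py only needs string keys "0".."vocab_size-1" whose
# values form a probability vector it can renormalize and sample from.
token_probs = {str(i): 1.0 / vocab_size for i in range(vocab_size)}

with open("token_probabilities.json", "w") as f:
    json.dump(token_probs, f)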