Spaces: Running on Zero

Ruurd committed on
Commit 7252f98 · 1 Parent(s): acbd7fa

First commit

Files changed (4)
  1. app.py +158 -0
  2. llama_diffusion_model.py +134 -0
  3. requirements.txt +7 -0
  4. token_probabilities.json +0 -0
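
Both Python files in this commit read a Hugging Face access token from the environment, since meta-llama/Llama-3.2-3B is a gated checkpoint. A minimal sketch of that precondition (the assert is illustrative and not part of the commit; on Spaces the token would typically be configured as a repository secret):

import os

# HF_TOKEN must be set (e.g. as a Space secret) and must grant access to the
# gated meta-llama/Llama-3.2-3B repository; otherwise the tokenizer and model
# downloads in app.py / llama_diffusion_model.py below will fail.
hf_token = os.getenv("HF_TOKEN")
assert hf_token, "HF_TOKEN is not set"
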
app.py ADDED
@@ -0,0 +1,158 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ import json
+ import time
+ from transformers import AutoTokenizer
+ from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, disable_dropout
+ import os
+
+ hf_token = os.getenv("HF_TOKEN")
+
+
+ # --- Load tokenizer ---
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
+ vocab_size = len(tokenizer)
+ pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
+ eot_token_id = tokenizer.eos_token_id
+ assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
+
+ # --- Load token probabilities ---
+ with open("token_probabilities.json") as f:
+     token_probs_dict = json.load(f)
+ token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
+
+
+ def load_model():
+     config = CustomTransformerConfig(vocab_size=vocab_size)
+     model = CustomTransformerModel(config)
+     model.load_state_dict(torch.hub.load_state_dict_from_url(
+         "https://huggingface.co/Ruurd/tini_model/resolve/main/diffusion-model.pth",
+         map_location="cuda",
+         headers={"Authorization": f"Bearer {hf_token}"}
+     ))
+     model = disable_dropout(model)
+     model.to("cuda")
+     model.eval()
+     return model
+
+ rng = np.random.default_rng()
+
+ # --- Utility Functions ---
+ def decode_tokens_safe(token_ids):
+     return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")
+
+ def find_answer_start(input_ids, marker_ids):
+     for i in range(len(input_ids) - len(marker_ids) + 1):
+         if input_ids[i:i + len(marker_ids)] == marker_ids:
+             return i + len(marker_ids)
+     return None
+
+ def get_noising_schedule(i, max_it, sharpness=5.0):
+     x = i / max_it
+     return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
+
+ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
+     noised = input_ids.copy()
+     answer_len = len(input_ids) - answer_start
+     num_to_noise = int(threshold * answer_len)
+     if num_to_noise > 0:
+         indices = rng.choice(np.arange(answer_start, len(input_ids)), size=num_to_noise, replace=False)
+
+         mixed_probs = token_probabilities.copy()
+         mixed_probs[eot_token_id] *= eot_weight
+         mixed_probs /= mixed_probs.sum()
+
+         noise = rng.choice(np.arange(vocab_size), size=num_to_noise, p=mixed_probs)
+         for idx, val in zip(indices, noise):
+             noised[idx] = val
+     return noised
+
+ def generate_diffusion_text(model, input_ids, answer_start):
+     with torch.no_grad():
+         input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
+         logits = model(input_ids=input_tensor)["logits"]
+         probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
+         probs = torch.clamp(probs, min=1e-8, max=1.0)
+         sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
+     return input_ids[:answer_start] + sampled[answer_start:]
+
+ # --- Inference Wrapper ---
+ def diffusion_chat(question, eot_weight, max_it, sharpness, model):
+     placeholder = "What do you know about the city of New York?"
+     if question.strip() == "":
+         question = placeholder
+
+     prompt = f"User: {question}\nAssistant:"
+     input_ids = tokenizer.encode(prompt, add_special_tokens=False)
+     answer_start = find_answer_start(input_ids, assistant_marker_ids)
+     if answer_start is None:
+         yield "Error: Could not find Assistant marker in input."
+         return
+
+     if len(input_ids) < 256:
+         input_ids += [pad_token] * (256 - len(input_ids))
+     else:
+         input_ids = input_ids[:256]
+
+     ori_input_tokens = input_ids
+     current_tokens = noisify_answer(ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight)
+     prev_decoded_tokens = []
+     last_tokens = []
+
+     for i in range(max_it):
+         generated_tokens = generate_diffusion_text(model, current_tokens, answer_start)
+         current_tokens = generated_tokens
+
+         decoded_ids = current_tokens[answer_start:]
+         decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
+         filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
+         filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []
+
+         if filtered_prev_tokens:
+             highlighted = []
+             for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
+                 if tok_new != tok_old:
+                     highlighted.append(f'<span style="color:green">{tokenizer.convert_tokens_to_string([tok_new])}</span>')
+                 else:
+                     highlighted.append(tokenizer.convert_tokens_to_string([tok_new]))
+         else:
+             highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]
+
+         prev_decoded_tokens = decoded_tokens
+         yield f"<b>Iteration {i+1}/{max_it} (running):</b><br>" + "".join(highlighted)
+
+         last_tokens.append(generated_tokens)
+         if len(last_tokens) > 3:
+             last_tokens.pop(0)
+         if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
+             yield f"<b>Stopped early after {i+1} iterations.</b>"
+             break
+
+         threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
+         current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
+         time.sleep(0.01)
+
+     final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
+     final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
+     final_output = tokenizer.convert_tokens_to_string(final_tokens)
+     yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output
+
+ # --- Gradio Interface ---
+ model_state = gr.State(load_model())
+
+ demo = gr.Interface(
+     fn=diffusion_chat,
+     inputs=[
+         gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
+         gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
+         gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
+         gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
+         model_state
+     ],
+     outputs=gr.HTML(label="Diffusion Output"),
+     title="Diffusion Language Model Chat",
+     description="This interface runs a diffusion-based language model to generate answers progressively."
+ )
+
+ demo.launch()
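
Condensed, the procedure in app.py is: fully noise the answer region of a fixed 256-token prompt, then alternate a denoising forward pass (per-position sampling from the model's logits) with partial re-noising under an exponentially decaying threshold. A minimal sketch of that loop, reusing the functions defined above with the Gradio defaults (eot_weight=0.4, 64 iterations, sharpness=5.0); this is an illustration, not an exact copy of diffusion_chat:

model = load_model()

prompt = "User: What do you know about the city of New York?\nAssistant:"
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
answer_start = find_answer_start(input_ids, assistant_marker_ids)
input_ids = (input_ids + [pad_token] * 256)[:256]  # pad/trim to the fixed 256-token chunk

current = noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=0.4)  # fully noised start
for i in range(64):
    current = generate_diffusion_text(model, current, answer_start)   # denoise: resample answer tokens
    threshold = get_noising_schedule(i, 64, sharpness=5.0)            # decays from 1.0 towards 0.0
    current = noisify_answer(current, answer_start, threshold=threshold, eot_weight=0.4)  # partial re-noise

print(tokenizer.decode(current[answer_start:], skip_special_tokens=True))
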
llama_diffusion_model.py ADDED
@@ -0,0 +1,134 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.amp import autocast
+ from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
+ from peft import LoraConfig, get_peft_model
+ import os
+
+ hf_token = os.getenv("HF_TOKEN")
+
+ class BidirectionalLlamaAttention(nn.Module):
+     def __init__(self, original_layer, masking='unidirectional'):
+         super().__init__()
+         self.original = original_layer
+         self.masking = masking
+
+         self.q_proj = original_layer.q_proj
+         self.k_proj = original_layer.k_proj
+         self.v_proj = original_layer.v_proj
+         self.o_proj = original_layer.o_proj
+
+         self.head_dim = self.q_proj.out_features // original_layer.num_heads
+         self.num_heads = original_layer.num_heads
+         self.num_key_value_groups = original_layer.num_key_value_groups
+         self.attention_dropout = original_layer.attention_dropout
+         self.layer_idx = original_layer.layer_idx
+         self.scaling = original_layer.scaling
+
+     def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, cache_position=None, **kwargs):
+         bsz, seq_len, _ = hidden_states.size()
+
+         query_states = self._split_heads(self.q_proj(hidden_states))
+         key_states = self._split_heads(self.k_proj(hidden_states))
+         value_states = self._split_heads(self.v_proj(hidden_states))
+
+         cos, sin = position_embeddings
+         query_states, key_states = self._apply_rotary(query_states, key_states, cos, sin)
+
+         if self.masking == 'bidirectional':
+             attn_mask = torch.ones((bsz, 1, seq_len, seq_len), device=hidden_states.device)
+         else:
+             attn_mask = torch.tril(torch.ones(seq_len, seq_len, device=hidden_states.device)).unsqueeze(0).unsqueeze(0)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
+         attn_weights = attn_weights + attn_mask.log()
+         attn_weights = F.softmax(attn_weights, dim=-1)
+         attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+         attn_output = torch.matmul(attn_weights, value_states)
+         attn_output = self._merge_heads(attn_output)
+         return self.o_proj(attn_output), attn_weights
+
+     def _split_heads(self, x):
+         return x.view(x.size(0), x.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+
+     def _merge_heads(self, x):
+         return x.transpose(1, 2).contiguous().view(x.size(0), -1, self.num_heads * self.head_dim)
+
+     def _apply_rotary(self, q, k, cos, sin):
+         cos = cos.unsqueeze(1)
+         sin = sin.unsqueeze(1)
+         q_rot = (q * cos) + (self._rotate_half(q) * sin)
+         k_rot = (k * cos) + (self._rotate_half(k) * sin)
+         return q_rot, k_rot
+
+     def _rotate_half(self, x):
+         x1 = x[..., : x.shape[-1] // 2]
+         x2 = x[..., x.shape[-1] // 2 :]
+         return torch.cat((-x2, x1), dim=-1)
+
+
+ class CustomTransformerConfig(PretrainedConfig):
+     def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0, max_position_embeddings=4096, **kwargs):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.prediction_chunk = prediction_chunk
+         self.max_position_embeddings = max_position_embeddings
+
+
+ class CustomTransformerModel(PreTrainedModel):
+     config_class = CustomTransformerConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, token=hf_token)
+         self.llama.resize_token_embeddings(config.vocab_size)
+
+         for i, layer in enumerate(self.llama.model.layers):
+             layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking='bidirectional')
+
+         for param in self.llama.parameters():
+             param.requires_grad = False
+
+         for param in self.llama.lm_head.parameters():
+             param.requires_grad = True
+
+         lora_config = LoraConfig(
+             r=256,
+             lora_alpha=256,
+             lora_dropout=0.0,
+             target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
+             bias="none",
+             task_type=None
+         )
+
+         self.llama = get_peft_model(self.llama, lora_config)
+         self.llama = self.llama.to(torch.float16)
+
+     def forward(self, input_ids, labels=None, **kwargs):
+         batch_size, seq_length = input_ids.shape
+         assert seq_length == self.config.prediction_chunk
+
+         with autocast("cuda", dtype=torch.float16):
+             outputs = self.llama(input_ids=input_ids, output_hidden_states=True, **kwargs)
+             logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, self.config.prediction_chunk, self.config.vocab_size)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+         return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
+
+
+ def disable_dropout(model):
+     for name, module in model.named_modules():
+         if isinstance(module, nn.Dropout):
+             parent, _, child = name.rpartition(".")  # resolve the parent so nested (dotted) modules are actually replaced
+             setattr(model.get_submodule(parent) if parent else model, child, nn.Identity())
+     return model
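
A rough, hedged sketch of how this module is driven (app.py above additionally loads fine-tuned weights from Ruurd/tini_model before use; this snippet only exercises shapes with the base Llama weights and assumes a CUDA device plus a valid HF_TOKEN):

import torch
from llama_diffusion_model import CustomTransformerConfig, CustomTransformerModel, disable_dropout

config = CustomTransformerConfig(vocab_size=128256, prediction_chunk=256)
model = disable_dropout(CustomTransformerModel(config)).to("cuda").eval()

# forward() asserts seq_length == prediction_chunk, so inputs must be exactly 256 tokens wide.
dummy = torch.randint(0, config.vocab_size, (1, config.prediction_chunk), device="cuda")
with torch.no_grad():
    out = model(input_ids=dummy)
print(out["logits"].shape)  # expected: torch.Size([1, 256, 128256])
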
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch>=2.0.0
+ transformers>=4.38.0
+ datasets>=2.16.0
+ peft>=0.8.2
+ accelerate>=0.24.1
+ gradio>=4.10.0
+ numpy
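
A quick, hedged sanity check that an environment satisfies these pins and exposes the GPU both modules assume (not part of the commit):

import torch, transformers, peft, accelerate, gradio, datasets, numpy

print("torch", torch.__version__, "| transformers", transformers.__version__,
      "| peft", peft.__version__, "| gradio", gradio.__version__)
print("CUDA available:", torch.cuda.is_available())  # app.py and the model wrapper both hard-code "cuda"
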
token_probabilities.json ADDED
The diff for this file is too large to render. See raw diff
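
The raw file is not shown here, but app.py above fixes its expected shape: a JSON object mapping every token id (as a string) to a sampling probability over the tokenizer vocabulary. A small, hedged validation sketch:

import json
import numpy as np

with open("token_probabilities.json") as f:
    token_probs_dict = json.load(f)  # {"0": p0, "1": p1, ...}, one entry per token id

probs = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
print(len(probs), float(probs.sum()))  # should match the tokenizer vocab size and sum to roughly 1.0
# (app.py re-normalises after re-weighting the EOT token, so an approximate sum is fine)
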