Removed the custom bidirectional layer, as it is not needed when using Llama's attention_mask
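A minimal sketch of the idea behind this change, assuming a transformers release that accepts pre-built 4D attention masks (the helper name below is illustrative, not part of this repo): instead of patching every LlamaAttention layer, an additive float mask of shape (batch, 1, seq_len, seq_len) can be passed through the model's attention_mask argument, which already yields bidirectional attention.

import torch

def bidirectional_attention_mask(batch_size, seq_len, dtype=torch.float16, device="cpu"):
    # Additive 4D mask of shape (batch, 1, query_len, key_len):
    # 0.0 lets a position attend; torch.finfo(dtype).min would block it.
    # All zeros means every token attends to every other token (bidirectional).
    return torch.zeros(batch_size, 1, seq_len, seq_len, dtype=dtype, device=device)

# Illustrative usage inside CustomTransformerModel.forward:
#   mask = bidirectional_attention_mask(batch_size, seq_len,
#                                       dtype=self.llama.dtype,
#                                       device=input_ids.device)
#   outputs = self.llama(input_ids=input_ids, attention_mask=mask)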
llama_diffusion_model.py  CHANGED  (+0 -70)
@@ -9,72 +9,6 @@ from typing import Optional, Tuple
 
 hf_token = os.getenv("HF_TOKEN")
 
-class BidirectionalLlamaAttention(LlamaAttention):
-    def __init__(self, original_layer, masking='unidirectional'):
-        super().__init__(original_layer.config, layer_idx=original_layer.layer_idx)
-        self.masking = masking
-        self.q_proj.weight = original_layer.q_proj.weight
-        self.k_proj.weight = original_layer.k_proj.weight
-        self.v_proj.weight = original_layer.v_proj.weight
-        self.o_proj.weight = original_layer.o_proj.weight
-
-    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-        if n_rep == 1:
-            return hidden_states
-        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-    def eager_attention_forward(self, module: nn.Module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
-        key_states = self.repeat_kv(key, module.num_key_value_groups)
-        value_states = self.repeat_kv(value, module.num_key_value_groups)
-        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-
-        if attention_mask is not None:
-            attn_mask = (1.0 - attention_mask) * float('-inf')
-            attn_mask = attn_mask.to(dtype=query.dtype)
-            attn_weights = attn_weights + attn_mask
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-        attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
-        return attn_output, attn_weights
-
-    def rotate_half(self, x):
-        x1 = x[..., : x.shape[-1] // 2]
-        x2 = x[..., x.shape[-1] // 2:]
-        return torch.cat((-x2, x1), dim=-1)
-
-    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
-        cos = cos.unsqueeze(unsqueeze_dim)
-        sin = sin.unsqueeze(unsqueeze_dim)
-        q_embed = (q * cos) + (self.rotate_half(q) * sin)
-        k_embed = (k * cos) + (self.rotate_half(k) * sin)
-        return q_embed, k_embed
-
-    def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, cache_position=None, **kwargs):
-        input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
-
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
-        cos, sin = position_embeddings
-        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        attn_output, attn_weights = self.eager_attention_forward(
-            self, query_states, key_states, value_states, attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs
-        )
-
-        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        return self.o_proj(attn_output), attn_weights
-
 class CustomTransformerConfig(PretrainedConfig):
     def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
                  max_position_embeddings=4096, masking_type="bidirectional", **kwargs):
@@ -97,9 +31,6 @@ class CustomTransformerModel(PreTrainedModel):
         self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
         self.llama.resize_token_embeddings(config.vocab_size)
 
-        # for i, layer in enumerate(self.llama.model.layers):
-        #     layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)
-
         for param in self.llama.parameters():
             param.requires_grad = False
         for param in self.llama.lm_head.parameters():
@@ -113,7 +44,6 @@ class CustomTransformerModel(PreTrainedModel):
 
         self.llama = get_peft_model(self.llama, lora_config)
         self.llama.print_trainable_parameters()
-        # self.llama = self.llama.to(torch.float16)
 
     def forward(self, input_ids, labels=None, **kwargs):
         batch_size, seq_len = input_ids.shape