Spaces:
Running on Zero

Ruurd commited on
Commit
b57b92e
·
verified ·
1 Parent(s): 093a557

Removed custom bidirectional layer as it is not needed when using the Llama attention_masks

Browse files
Files changed (1) hide show
  1. llama_diffusion_model.py +0 -70
llama_diffusion_model.py CHANGED
@@ -9,72 +9,6 @@ from typing import Optional, Tuple
9
 
10
  hf_token = os.getenv("HF_TOKEN")
11
 
12
- class BidirectionalLlamaAttention(LlamaAttention):
13
- def __init__(self, original_layer, masking='unidirectional'):
14
- super().__init__(original_layer.config, layer_idx=original_layer.layer_idx)
15
- self.masking = masking
16
- self.q_proj.weight = original_layer.q_proj.weight
17
- self.k_proj.weight = original_layer.k_proj.weight
18
- self.v_proj.weight = original_layer.v_proj.weight
19
- self.o_proj.weight = original_layer.o_proj.weight
20
-
21
- def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
22
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
23
- if n_rep == 1:
24
- return hidden_states
25
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
26
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
27
-
28
- def eager_attention_forward(self, module: nn.Module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
29
- key_states = self.repeat_kv(key, module.num_key_value_groups)
30
- value_states = self.repeat_kv(value, module.num_key_value_groups)
31
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
32
-
33
- if attention_mask is not None:
34
- attn_mask = (1.0 - attention_mask) * float('-inf')
35
- attn_mask = attn_mask.to(dtype=query.dtype)
36
- attn_weights = attn_weights + attn_mask
37
-
38
- attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
39
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
40
- attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
41
- return attn_output, attn_weights
42
-
43
- def rotate_half(self, x):
44
- x1 = x[..., : x.shape[-1] // 2]
45
- x2 = x[..., x.shape[-1] // 2:]
46
- return torch.cat((-x2, x1), dim=-1)
47
-
48
- def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
49
- cos = cos.unsqueeze(unsqueeze_dim)
50
- sin = sin.unsqueeze(unsqueeze_dim)
51
- q_embed = (q * cos) + (self.rotate_half(q) * sin)
52
- k_embed = (k * cos) + (self.rotate_half(k) * sin)
53
- return q_embed, k_embed
54
-
55
- def forward(self, hidden_states, position_embeddings, attention_mask=None, past_key_value=None, cache_position=None, **kwargs):
56
- input_shape = hidden_states.shape[:-1]
57
- hidden_shape = (*input_shape, -1, self.head_dim)
58
-
59
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
60
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
61
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
62
-
63
- cos, sin = position_embeddings
64
- query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)
65
-
66
- if past_key_value is not None:
67
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
68
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
69
-
70
- attn_output, attn_weights = self.eager_attention_forward(
71
- self, query_states, key_states, value_states, attention_mask,
72
- dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs
73
- )
74
-
75
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
76
- return self.o_proj(attn_output), attn_weights
77
-
78
  class CustomTransformerConfig(PretrainedConfig):
79
  def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
80
  max_position_embeddings=4096, masking_type="bidirectional", **kwargs):
@@ -97,9 +31,6 @@ class CustomTransformerModel(PreTrainedModel):
97
  self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
98
  self.llama.resize_token_embeddings(config.vocab_size)
99
 
100
- # for i, layer in enumerate(self.llama.model.layers):
101
- # layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)
102
-
103
  for param in self.llama.parameters():
104
  param.requires_grad = False
105
  for param in self.llama.lm_head.parameters():
@@ -113,7 +44,6 @@ class CustomTransformerModel(PreTrainedModel):
113
 
114
  self.llama = get_peft_model(self.llama, lora_config)
115
  self.llama.print_trainable_parameters()
116
- # self.llama = self.llama.to(torch.float16)
117
 
118
  def forward(self, input_ids, labels=None, **kwargs):
119
  batch_size, seq_len = input_ids.shape
 
9
 
10
  hf_token = os.getenv("HF_TOKEN")
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class CustomTransformerConfig(PretrainedConfig):
13
  def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
14
  max_position_embeddings=4096, masking_type="bidirectional", **kwargs):
 
31
  self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
32
  self.llama.resize_token_embeddings(config.vocab_size)
33
 
 
 
 
34
  for param in self.llama.parameters():
35
  param.requires_grad = False
36
  for param in self.llama.lm_head.parameters():
 
44
 
45
  self.llama = get_peft_model(self.llama, lora_config)
46
  self.llama.print_trainable_parameters()
 
47
 
48
  def forward(self, input_ids, labels=None, **kwargs):
49
  batch_size, seq_len = input_ids.shape