import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root-mean-square layer norm: x / rms(x) * weight, with no mean subtraction."""

    def __init__(self, cfg):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(cfg.lm_hidden_dim))
        self.eps = cfg.lm_rms_eps

    def forward(self, x):
        irms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)  # inverse RMS
        return x * irms * self.weight


class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with simple frequency scaling beyond the trained context."""

    def __init__(self, cfg):
        super().__init__()
        assert cfg.lm_hidden_dim % cfg.lm_n_heads == 0, "Hidden dimension must be divisible by number of heads"

        self.dim = cfg.lm_hidden_dim // cfg.lm_n_heads  # per-head dimension
        self.base = cfg.lm_re_base
        self.max_seq_len = cfg.lm_max_position_embeddings

        # Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq)
        self.original_max_seq_len = cfg.lm_max_position_embeddings
        self.attention_scaling = cfg.lm_attn_scaling

    @torch.no_grad()
    def forward(self, position_ids):
        batch_size, seq_len = position_ids.shape

        # Stretch the frequencies when positions exceed the trained context length.
        max_seq = position_ids.max() + 1
        if max_seq > self.original_max_seq_len:
            scale = max_seq / self.original_max_seq_len
            inv_freq = self.inv_freq / scale
        else:
            inv_freq = self.inv_freq

        flat_position_ids = position_ids.reshape(-1).float()

        # Outer product of positions and inverse frequencies -> rotation angles
        freqs = flat_position_ids.unsqueeze(-1) * inv_freq.unsqueeze(0)
        freqs = freqs.reshape(batch_size, seq_len, -1)

        # Duplicate the angles so cos/sin cover the full head dimension
        emb = torch.cat([freqs, freqs], dim=-1)

        cos = torch.cos(emb) * self.attention_scaling
        sin = torch.sin(emb) * self.attention_scaling

        return cos, sin


def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_embd(q, k, cos, sin, unsqueeze_dim=1):
    # Insert a head dimension so cos/sin broadcast over (B, n_heads, T, head_dim)
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Rotate queries and keys: x * cos + rotate_half(x) * sin
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed


class LanguageModelGroupedQueryAttention(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.n_heads = cfg.lm_n_heads
        self.n_kv_heads = cfg.lm_n_kv_heads
        self.embd_dim = cfg.lm_hidden_dim
        self.dropout = cfg.lm_dropout

        assert self.n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        assert self.embd_dim % self.n_heads == 0, "embd_dim must be divisible by n_heads"

        self.n_kv_groups = self.n_heads // self.n_kv_heads
        self.head_dim = self.embd_dim // self.n_heads

        self.q_proj = nn.Linear(self.embd_dim, self.embd_dim, bias=False)
        self.k_proj = nn.Linear(self.embd_dim, self.head_dim * self.n_kv_heads, bias=False)
        self.v_proj = nn.Linear(self.embd_dim, self.head_dim * self.n_kv_heads, bias=False)
        self.out_proj = nn.Linear(self.embd_dim, self.embd_dim, bias=False)

        self.attn_dropout = nn.Dropout(self.dropout)
        self.resid_dropout = nn.Dropout(self.dropout)

        # Use PyTorch's fused scaled_dot_product_attention when available.
        self.sdpa = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.sdpa:
            print("Warning: scaled dot product attention not available, using standard attention in LM.")

    def forward(self, x, cos, sin, attention_mask=None):
        B, T, C = x.size()

        # Project and reshape to (B, heads, T, head_dim); keys/values use fewer heads.
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Apply rotary position embeddings to queries and keys.
        q, k = apply_rotary_pos_embd(q, k, cos, sin)

        # Expand key/value heads to match the number of query heads (grouped-query attention).
        k = k.repeat_interleave(self.n_kv_groups, dim=1)
        v = v.repeat_interleave(self.n_kv_groups, dim=1)

        if attention_mask is not None:
            # (B, T) -> (B, 1, 1, T); keep a boolean copy to zero out padded positions afterwards.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            padding_mask = (attention_mask == 0).transpose(-1, -2)
            # Convert the 0/1 mask to an additive mask with large negative values on padding.
            attention_mask = (1.0 - attention_mask) * torch.finfo(q.dtype).min

        if self.sdpa:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attention_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True
            )
        else:
            # Manual attention fallback with an explicit causal mask.
            attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
            causal_mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            attn = attn.masked_fill(causal_mask == 0, float('-inf'))
            if attention_mask is not None:
                attn = attn + attention_mask

            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            y = attn @ v

        if attention_mask is not None:
            # Zero out outputs at padded positions.
            y = y.masked_fill(padding_mask, 0.0)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.out_proj(y)
        y = self.resid_dropout(y)

        return y


class LanguageModelMLP(nn.Module):
    """Gated feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, cfg):
        super().__init__()
        self.embd_dim = cfg.lm_hidden_dim
        self.inter_dim = cfg.lm_inter_dim

        self.activation_fn = F.silu
        self.gate_proj = nn.Linear(self.embd_dim, self.inter_dim, bias=False)
        self.up_proj = nn.Linear(self.embd_dim, self.inter_dim, bias=False)
        self.down_proj = nn.Linear(self.inter_dim, self.embd_dim, bias=False)

    def forward(self, x):
        gate = self.activation_fn(self.gate_proj(x))
        x = self.up_proj(x)
        x = self.down_proj(gate * x)

        return x


class LanguageModelBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.mlp = LanguageModelMLP(cfg)
        self.attn = LanguageModelGroupedQueryAttention(cfg)
        self.norm1 = RMSNorm(cfg)
        self.norm2 = RMSNorm(cfg)

    def forward(self, x, cos, sin, attention_mask=None):
        res = x
        x = self.norm1(x)
        x = self.attn(x, cos, sin, attention_mask)
        x = res + x

        res = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = res + x

        return x


class LanguageModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.lm_use_tokens = cfg.lm_use_tokens
        self.lm_tie_weights = cfg.lm_tie_weights

        self.token_embedding = nn.Embedding(cfg.lm_vocab_size, cfg.lm_hidden_dim)
        self.rotary_embd = RotaryEmbedding(cfg)
        self.blocks = nn.ModuleList([
            LanguageModelBlock(cfg) for _ in range(cfg.lm_n_blocks)
        ])
        self.norm = RMSNorm(cfg)
        self.head = nn.Linear(cfg.lm_hidden_dim, cfg.lm_vocab_size, bias=False)
        if self.lm_tie_weights:
            # Share the output projection with the input embedding matrix.
            self.head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

    def forward(self, x, attention_mask=None):
        if self.lm_use_tokens:
            x = self.token_embedding(x)  # (B, T) token ids -> (B, T, hidden_dim)

        B, T, _ = x.size()

        position_ids = torch.arange(T, device=x.device).unsqueeze(0).expand(B, -1)
        cos, sin = self.rotary_embd(position_ids)

        for block in self.blocks:
            x = block(x, cos, sin, attention_mask)
        x = self.norm(x)

        if self.lm_use_tokens:
            x = self.head(x)  # project back to vocabulary logits

        return x

    @torch.no_grad()
    def generate(self, inputs, max_new_tokens=20):
        # Greedy decoding: rerun the full forward pass and append the argmax token each step.
        if inputs.dim() == 1:
            inputs = inputs.unsqueeze(0)

        generated = inputs.clone()

        for _ in range(max_new_tokens):
            outputs = self.forward(generated)
            last_output = outputs[:, -1, :]

            if self.lm_use_tokens:
                next_token = torch.argmax(last_output, dim=-1, keepdim=True)
                generated = torch.cat((generated, next_token), dim=-1)
            else:
                # Operating on embeddings: append the last hidden state instead of a token id.
                next_token_embedding = last_output.unsqueeze(1)
                generated = torch.cat((generated, next_token_embedding), dim=1)

        return generated

    @classmethod
    def from_pretrained(cls, cfg):
        from transformers import AutoConfig
        from huggingface_hub import hf_hub_download
        import safetensors
        import torch.nn.init as init

        hf_config = AutoConfig.from_pretrained(cfg.lm_model_type)

        original_vocab_size = hf_config.vocab_size

        # Copy the architecture hyperparameters from the Hugging Face config.
        cfg.lm_hidden_dim = hf_config.hidden_size
        cfg.lm_inter_dim = hf_config.intermediate_size
        cfg.lm_rms_eps = hf_config.rms_norm_eps
        cfg.lm_re_base = hf_config.rope_theta
        cfg.lm_max_position_embeddings = hf_config.max_position_embeddings

        if hasattr(cfg, 'lm_vocab_size'):
            if cfg.lm_vocab_size < original_vocab_size:
                raise ValueError(f"Config vocab size ({cfg.lm_vocab_size}) is smaller than pretrained model vocab size ({original_vocab_size})")
        else:
            cfg.lm_vocab_size = original_vocab_size

        cfg.lm_n_heads = hf_config.num_attention_heads
        cfg.lm_n_kv_heads = hf_config.num_key_value_heads
        cfg.lm_dropout = hf_config.attention_dropout
        cfg.lm_n_blocks = hf_config.num_hidden_layers

        model = cls(cfg)
        safetensors_file = hf_hub_download(repo_id=cfg.lm_model_type, filename="model.safetensors")

        sd = model.state_dict()

        # Map Hugging Face parameter names to this module's parameter names.
        mapping = {
            'model.embed_tokens.weight': 'token_embedding.weight',
            'model.norm.weight': 'norm.weight'
        }

        for i in range(cfg.lm_n_blocks):
            layer_prefix = f'model.layers.{i}.'
            block_prefix = f'blocks.{i}.'

            mapping.update({
                f"{layer_prefix}self_attn.q_proj.weight": f"{block_prefix}attn.q_proj.weight",
                f"{layer_prefix}self_attn.k_proj.weight": f"{block_prefix}attn.k_proj.weight",
                f"{layer_prefix}self_attn.v_proj.weight": f"{block_prefix}attn.v_proj.weight",
                f"{layer_prefix}self_attn.o_proj.weight": f"{block_prefix}attn.out_proj.weight",
                f"{layer_prefix}mlp.gate_proj.weight": f"{block_prefix}mlp.gate_proj.weight",
                f"{layer_prefix}mlp.up_proj.weight": f"{block_prefix}mlp.up_proj.weight",
                f"{layer_prefix}mlp.down_proj.weight": f"{block_prefix}mlp.down_proj.weight",
                f"{layer_prefix}input_layernorm.weight": f"{block_prefix}norm1.weight",
                f"{layer_prefix}post_attention_layernorm.weight": f"{block_prefix}norm2.weight"
            })

        has_extended_embeddings = False
        with safetensors.safe_open(filename=safetensors_file, framework="pt", device="cpu") as f:
            for hf_key, our_key in mapping.items():
                if hf_key in f.keys() and our_key in sd:
                    tensor = f.get_tensor(hf_key)

                    if hf_key == 'model.embed_tokens.weight' and tensor.shape[0] != sd[our_key].shape[0]:
                        # The config requested a larger vocabulary: copy the pretrained rows
                        # and randomly initialize the extra ones.
                        has_extended_embeddings = True
                        print(f"Extending token embeddings from {tensor.shape} to {sd[our_key].shape}")

                        sd[our_key][:tensor.shape[0]].copy_(tensor)

                        std = 0.02
                        init.normal_(sd[our_key][tensor.shape[0]:], mean=0.0, std=std)

                        print(f"Initialized {sd[our_key].shape[0] - tensor.shape[0]} new token embeddings")
                        sd['head.weight'].copy_(sd[our_key])
                    elif tensor.shape == sd[our_key].shape:
                        sd[our_key].copy_(tensor)
                    else:
                        print(f"Shape mismatch for {hf_key} -> {our_key}: {tensor.shape} vs {sd[our_key].shape}")
                else:
                    if hf_key not in f.keys():
                        print(f"Warning: Key {hf_key} not found in safetensors file")
                    if our_key not in sd:
                        print(f"Warning: Key {our_key} not found in model state dict")

        model.load_state_dict(sd)

        # If the checkpoint carries an untied lm_head, copy (and extend) it as well.
        if has_extended_embeddings and hasattr(model, 'head') and 'head.weight' in sd:
            with safetensors.safe_open(filename=safetensors_file, framework="pt", device="cpu") as f:
                if 'lm_head.weight' in f.keys():
                    lm_head = f.get_tensor('lm_head.weight')
                    if lm_head.shape[0] != sd['head.weight'].shape[0]:
                        print(f"Extending LM head from {lm_head.shape} to {sd['head.weight'].shape}")

                        sd['head.weight'][:lm_head.shape[0]].copy_(lm_head)

                        std = 0.02
                        init.normal_(sd['head.weight'][lm_head.shape[0]:], mean=0.0, std=std)

                        model.load_state_dict(sd)

        # Re-tie the output head to the token embeddings if requested.
        if cfg.lm_tie_weights and hasattr(model, 'head') and hasattr(model, 'token_embedding'):
            model.head.weight = model.token_embedding.weight

        print(f"Successfully loaded {cfg.lm_model_type} weights from safetensors. Model has {sum(p.numel() for p in model.parameters()):,} parameters.")
        return model
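

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only). The SimpleNamespace config
# below is a hypothetical stand-in: its field values are assumptions chosen to
# satisfy the divisibility asserts above, not settings from any real model.
# Real usage would pass a project config object exposing the same lm_* fields.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        lm_hidden_dim=128,
        lm_inter_dim=512,
        lm_rms_eps=1e-5,
        lm_re_base=10000.0,
        lm_max_position_embeddings=256,
        lm_attn_scaling=1.0,
        lm_vocab_size=1000,
        lm_n_heads=8,
        lm_n_kv_heads=2,
        lm_dropout=0.0,
        lm_n_blocks=2,
        lm_use_tokens=True,
        lm_tie_weights=True,
    )

    model = LanguageModel(cfg)
    model.eval()

    tokens = torch.randint(0, cfg.lm_vocab_size, (2, 16))
    logits = model(tokens)                                # (2, 16, lm_vocab_size)
    generated = model.generate(tokens, max_new_tokens=5)  # (2, 16 + 5)
    print(logits.shape, generated.shape)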