# lad / llama_diffusion_model.py
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.amp import autocast
from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from peft import LoraConfig, get_peft_model

# Access token for the gated meta-llama/Llama-3.2-3B-Instruct checkpoint.
hf_token = os.getenv("HF_TOKEN")


class CustomTransformerConfig(PretrainedConfig):
    def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32,
                 prediction_chunk=256, dropout=0, max_position_embeddings=4096,
                 masking_type="bidirectional", **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.prediction_chunk = prediction_chunk
        self.max_position_embeddings = max_position_embeddings
        self.input_size = prediction_chunk
        self.masking_type = masking_type


class CustomTransformerModel(PreTrainedModel):
    config_class = CustomTransformerConfig

    def __init__(self, config):
        super().__init__(config)

        # Load the Llama backbone in fp16 and resize its embeddings to the configured vocabulary.
        self.llama = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3B-Instruct",
            torch_dtype=torch.float16,
            device_map="auto",
            token=hf_token,
        )
        self.llama.resize_token_embeddings(config.vocab_size)

        # Freeze the backbone; only the LM head and the LoRA adapters below are trained.
        for param in self.llama.parameters():
            param.requires_grad = False
        for param in self.llama.lm_head.parameters():
            param.requires_grad = True

        # Large LoRA adapters on all attention projections.
        lora_config = LoraConfig(
            r=512, lora_alpha=512, lora_dropout=0.0,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            bias="none", task_type=None
        )
        self.llama = get_peft_model(self.llama, lora_config)
        self.llama.print_trainable_parameters()

    def forward(self, input_ids, labels=None, **kwargs):
        batch_size, seq_len = input_ids.shape
        assert seq_len == self.config.prediction_chunk, \
            f"Expected input length {self.config.prediction_chunk}, got {seq_len}"

        # Build the attention mask for the requested masking scheme.
        device = input_ids.device
        masking_type = getattr(self.config, "masking_type", "bidirectional")
        if masking_type == 'bidirectional':
            # Every position attends to every position.
            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        elif masking_type == 'bidirectional_masked':
            # Bidirectional, except a position may not attend to itself.
            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
            base_mask.fill_diagonal_(False)
        elif masking_type == 'unidirectional':
            # Standard causal (lower-triangular) mask.
            base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        else:
            raise ValueError(f"Unknown masking type: {masking_type}")

        # Expand to (batch, 1, seq_len, seq_len) for the attention implementation.
        attention_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()
        attention_mask = attention_mask.to(dtype=torch.float32)  # required for SDPA and Flash attention

        with autocast("cuda", dtype=torch.float16):
            outputs = self.llama(
                input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                use_cache=False,
                **kwargs
            )

        logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, seq_len, self.config.vocab_size)

        loss = None
        if labels is not None:
            assert labels.shape == (batch_size, seq_len), \
                f"Labels shape mismatch: expected ({batch_size}, {seq_len}), got {labels.shape}"
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}


def disable_dropout(model):
    # Swap every nn.Dropout for nn.Identity. Dotted submodule names must be resolved
    # to their parent module; setattr on the root with a dotted name would only add
    # a new attribute instead of replacing the dropout layer.
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, nn.Identity())
    return model
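

# --- Usage sketch (illustrative) ---
# A minimal smoke test of the classes above, assuming HF_TOKEN grants access to
# meta-llama/Llama-3.2-3B-Instruct and a single CUDA device can hold the 3B model
# in fp16. The dummy token ids below stand in for real data.
if __name__ == "__main__":
    config = CustomTransformerConfig(masking_type="bidirectional_masked")
    model = CustomTransformerModel(config)
    model = disable_dropout(model)

    # One dummy batch of exactly prediction_chunk tokens, used as both input and labels.
    dummy_ids = torch.randint(0, config.vocab_size, (1, config.prediction_chunk), device=model.device)
    out = model(dummy_ids, labels=dummy_ids.clone())
    print(out["loss"].item(), out["logits"].shape)  # logits: (1, 256, 128256)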