"""
An adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.

Original source: https://github.com/karpathy/nanoGPT

Original License:

MIT License

Copyright (c) 2022 Andrej Karpathy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Original comments:

Full definition of a GPT Language Model, all of it in this single file.

References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
   https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
   https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
"""

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F


def new_gelu(x):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (
        1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))
    )
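

# A minimal sanity check for the approximation above (an assumption about recent PyTorch
# versions, not part of the original file): F.gelu with approximate="tanh" should agree
# with new_gelu up to floating-point tolerance, e.g.
#   x = torch.randn(4)
#   torch.testing.assert_close(new_gelu(x), F.gelu(x, approximate="tanh"))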


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, computed in a single batched matmul
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = new_gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # pre-norm residual connections around attention and MLP
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


@dataclass
class GPTConfig:
    block_size: int = 1024  # maximum sequence length
    input_dim: int = 256    # dimensionality of the continuous input vectors fed to the model
    output_dim: int = 256   # dimensionality of the vectors produced by the final linear head
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.1
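

# For illustration only (these particular values are an assumption, not taken from the
# original training setup), a smaller configuration might look like:
#   GPTConfig(block_size=128, input_dim=16, output_dim=16, n_layer=4, n_head=4, n_embd=128)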


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.input_dim is not None
        assert config.output_dim is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(config.input_dim, config.n_embd),  # projects continuous inputs into the embedding space
                wpe=nn.Embedding(config.block_size, config.n_embd),  # learned positional embeddings
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.output_dim, bias=False)

        # init all weights, then apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )

        # report number of parameters
        n_params = sum(p.numel() for p in self.parameters())
        print("number of parameters: %.2fM" % (n_params / 1e6,))

    def forward(self, input, targets=None):
        # `targets` is accepted for interface compatibility but unused here;
        # this adaptation returns raw logits and leaves loss computation to the caller.
        device = input.device
        b, t, d = input.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(input)  # input projection, shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings, shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # shape (b, t, output_dim)
        return logits

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def crop_block_size(self, block_size):
        # model surgery to decrease the block size if necessary,
        # e.g. when the model should run on sequences shorter than it was built for
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:block_size]
        )
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]

    def configure_optimizers(self, weight_decay, learning_rate, betas):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        we are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We then return the PyTorch AdamW optimizer built from these two groups.
        """

        # separate out all parameters into those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full parameter name
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # validate that we considered every parameter exactly once
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, (
            "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        )
        assert len(param_dict.keys() - union_params) == 0, (
            "parameters %s were not separated into either decay/no_decay set!"
            % (str(param_dict.keys() - union_params),)
        )

        # create the pytorch optimizer object
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(list(decay))],
                "weight_decay": weight_decay,
            },
            {
                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
                "weight_decay": 0.0,
            },
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer
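

# A minimal usage sketch (the shapes and hyperparameter values below are illustrative
# assumptions, not taken from the original training setup):
if __name__ == "__main__":
    config = GPTConfig(
        block_size=128, input_dim=16, output_dim=16, n_layer=2, n_head=2, n_embd=64
    )
    model = GPT(config)
    optimizer = model.configure_optimizers(
        weight_decay=0.1, learning_rate=3e-4, betas=(0.9, 0.95)
    )
    x = torch.randn(8, 128, 16)  # (batch, time, input_dim)
    logits = model(x)  # (batch, time, output_dim)
    print(logits.shape)  # expected: torch.Size([8, 128, 16])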