import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class DecoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by number of heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Self-attention: Q, K, V from decoder input
        self.self_attn_proj = nn.Linear(d_model, 3 * d_model)
        # Cross-attention: Q from decoder input, K/V from encoder output
        self.cross_attn_q = nn.Linear(d_model, d_model)
        self.cross_attn_kv = nn.Linear(d_model, 2 * d_model)
        # Output projections
        self.self_out = nn.Linear(d_model, d_model)
        self.cross_out = nn.Linear(d_model, d_model)
        # Feedforward MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model)
        )
        # LayerNorms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out):
        """
        x: (B, T, D) - decoder input embeddings
        enc_out: (B, N, D) - encoder outputs (image patch representations)
        Returns: (B, T, D)
        """
        B, T, D = x.shape
        _, N, _ = enc_out.shape

        # Masked Self-Attention
        x_norm = self.norm1(x)
        qkv = self.self_attn_proj(x_norm).reshape(B, T, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, T, T)
        # Causal mask: prevent attention to future positions
        mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)
        attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_out = attn_weights @ v  # (B, n_heads, T, head_dim)
        attn_out = attn_out.transpose(1, 2).reshape(B, T, D)
        attn_out = self.self_out(attn_out)
        x = x + attn_out  # Residual

        # Cross-Attention
        x_norm = self.norm2(x)
        q = self.cross_attn_q(x_norm).reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, n_heads, T, head_dim)
        kv = self.cross_attn_kv(enc_out).reshape(B, N, 2, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]  # (B, n_heads, N, head_dim)
        cross_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, T, N)
        cross_weights = F.softmax(cross_scores, dim=-1)
        cross_out = cross_weights @ v  # (B, n_heads, T, head_dim)
        cross_out = cross_out.transpose(1, 2).reshape(B, T, D)
        cross_out = self.cross_out(cross_out)
        x = x + cross_out  # Residual

        # Feedforward
        x_norm = self.norm3(x)
        x = x + self.mlp(x_norm)  # Residual
        return x

# Full decoder: token + positional embeddings, a stack of DecoderLayers, and a
# final projection to vocabulary logits
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size=13, max_len=5, d_model=64, n_heads=4, ff_dim=128, depth=2):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))  # (1, 5, 64)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])
        self.output_proj = nn.Linear(d_model, vocab_size)  # Final projection to logits

    def forward(self, decoder_input_ids, encoder_output):
        """
        decoder_input_ids: (B, T) token IDs
        encoder_output: (B, N, d_model) from image encoder
        returns: logits over vocab, shape (B, T, vocab_size)
        """
        x = self.token_embedding(decoder_input_ids)  # (B, T, d_model)
        x = x + self.pos_embedding[:, :x.size(1), :]  # Add positional embedding
        for layer in self.layers:
            x = layer(x, encoder_output)  # (B, T, d_model)
        logits = self.output_proj(x)  # (B, T, vocab_size)
        return logits

# quick test
if __name__ == "__main__":
    decoder = TransformerDecoder()
    decoder_input = torch.randint(0, 13, (4, 5))  # (B=4, T=5)
    encoder_out = torch.randn(4, 16, 64)  # (B=4, N=16, D=64)
    logits = decoder(decoder_input, encoder_out)
    print("Logits shape:", logits.shape)  # (4, 5, 13)