import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class DecoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        assert d_model % n_heads == 0, "d_model must be divisible by number of heads"
        self.head_dim = d_model // n_heads

        # Self-attention: Q, K, V from decoder input
        self.self_attn_proj = nn.Linear(d_model, 3 * d_model)

        # Cross-attention: Q from decoder input, K/V from encoder output
        self.cross_attn_q = nn.Linear(d_model, d_model)
        self.cross_attn_kv = nn.Linear(d_model, 2 * d_model)

        # Output projections
        self.self_out = nn.Linear(d_model, d_model)
        self.cross_out = nn.Linear(d_model, d_model)

        # Feedforward MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model)
        )

        # LayerNorms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out):
        """
        x: (B, T, D) - decoder input embeddings
        enc_out: (B, N, D) - encoder outputs (image patch representations)
        Returns: (B, T, D)
        """
        B, T, D = x.shape
        _, N, _ = enc_out.shape

        # Masked Self-Attention (pre-norm)
        x_norm = self.norm1(x)
        qkv = self.self_attn_proj(x_norm).reshape(B, T, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, T, T)

        # Causal mask: prevent attention to future positions
        mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)
        attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_out = attn_weights @ v  # (B, n_heads, T, head_dim)
        attn_out = attn_out.transpose(1, 2).reshape(B, T, D)
        attn_out = self.self_out(attn_out)
        x = x + attn_out  # Residual

        # Cross-Attention: queries from the decoder, keys/values from the encoder
        x_norm = self.norm2(x)
        q = self.cross_attn_q(x_norm).reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, n_heads, T, head_dim)
        kv = self.cross_attn_kv(enc_out).reshape(B, N, 2, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]  # (B, n_heads, N, head_dim)

        cross_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, T, N)
        cross_weights = F.softmax(cross_scores, dim=-1)
        cross_out = cross_weights @ v  # (B, n_heads, T, head_dim)
        cross_out = cross_out.transpose(1, 2).reshape(B, T, D)
        cross_out = self.cross_out(cross_out)
        x = x + cross_out  # Residual

        # Feedforward
        x_norm = self.norm3(x)
        x = x + self.mlp(x_norm)  # Residual

        return x


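# note: the attention math in DecoderLayer is written out explicitly for clarity.
# on PyTorch >= 2.0 the same result could also be obtained with the fused
# F.scaled_dot_product_attention(q, k, v, is_causal=True) for the masked
# self-attention and F.scaled_dot_product_attention(q, k, v) for cross-attention
# (ignoring dropout, which this layer does not use).
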
# full decoder: token/positional embeddings + stacked decoder layers + output projection

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size=13, max_len=5, d_model=64, n_heads=4, ff_dim=128, depth=2):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))  # (1, max_len, d_model)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])

        self.output_proj = nn.Linear(d_model, vocab_size)  # Final projection to logits

    def forward(self, decoder_input_ids, encoder_output):
        """
        decoder_input_ids: (B, T) token IDs
        encoder_output: (B, N, d_model) from image encoder
        returns: logits over vocab, shape (B, T, vocab_size)
        """
        x = self.token_embedding(decoder_input_ids)  # (B, T, d_model)
        x = x + self.pos_embedding[:, :x.size(1), :]  # Add positional embedding

        for layer in self.layers:
            x = layer(x, encoder_output)  # (B, T, d_model)

        logits = self.output_proj(x)  # (B, T, vocab_size)
        return logits
     

# quick test

if __name__ == "__main__":
    decoder = TransformerDecoder()
    decoder_input = torch.randint(0, 13, (4, 5))  # (B=4, T=5)
    encoder_out = torch.randn(4, 16, 64)  # (B=4, N=16, D=64)

    logits = decoder(decoder_input, encoder_out)
    print("Logits shape:", logits.shape)  # (4, 5, 13)
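    # Illustrative greedy-decoding sketch: shows how the decoder would be driven
    # autoregressively at inference time, feeding each predicted token back in.
    # It assumes token id 10 serves as a BOS/start token; that id is an assumption
    # for the demo, not something defined in this file, and the untrained weights
    # make the actual tokens meaningless here — only the shapes matter.
    bos_id = 10  # assumed start-token id
    generated = torch.full((4, 1), bos_id, dtype=torch.long)  # (B, 1)
    with torch.no_grad():
        for _ in range(4):  # grow the sequence up to max_len = 5
            step_logits = decoder(generated, encoder_out)      # (B, cur_len, vocab_size)
            next_token = step_logits[:, -1, :].argmax(dim=-1)  # greedy pick per sample
            generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
    print("Greedy tokens shape:", generated.shape)  # (4, 5)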