import torch
import torch.nn as nn
from .feature_extractor import FeatureExtractor
from .encoder import TransformerEncoder
from .decoder import TransformerDecoder

class ImageToDigitTransformer(nn.Module):
    """Encoder-decoder transformer that reads a grayscale image and emits a
    digit token sequence. With the defaults, a (B, 1, 56, 56) input is cut
    into a 4x4 grid of 14x14 patches (num_patches=16), and the decoder
    predicts up to max_seq_len tokens over a vocab_size-token vocabulary
    (13 by default, likely the ten digits plus special tokens).
    """

    def __init__(self, vocab_size=13, d_model=64, n_heads=4, ff_dim=128,
                 encoder_depth=4, decoder_depth=2, num_patches=16, max_seq_len=5):
        super().__init__()

        # 56x56 input with 14x14 patches -> (56/14)^2 = 16 patch embeddings.
        self.feature_extractor = FeatureExtractor(patch_size=14, emb_dim=d_model)
        self.encoder = TransformerEncoder(
            depth=encoder_depth,
            d_model=d_model,
            n_heads=n_heads,
            ff_dim=ff_dim,
            num_patches=num_patches
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            max_len=max_seq_len,
            d_model=d_model,
            n_heads=n_heads,
            ff_dim=ff_dim,
            depth=decoder_depth
        )

    def forward(self, image_tensor, decoder_input_ids):
        """
        Args:
            image_tensor: (B, 1, 56, 56) grayscale image batch.
            decoder_input_ids: (B, max_seq_len) token ids, (B, 5) by default.

        Returns:
            logits: (B, max_seq_len, vocab_size) per-position token scores.
        """
        patch_embeddings = self.feature_extractor(image_tensor)   # (B, 16, 64)
        encoder_output = self.encoder(patch_embeddings)           # (B, 16, 64)
        logits = self.decoder(decoder_input_ids, encoder_output)  # (B, 5, 13)
        return logits
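

# --- Inference sketch (illustrative, not part of the module API) ------------
# A minimal greedy decoding loop. It assumes the decoder applies a causal
# mask (so positions after the current step cannot influence earlier logits)
# and that token id 0 serves as the start/pad token -- neither assumption is
# confirmed by this file, so adapt the ids to the project's actual vocab.
@torch.no_grad()
def greedy_decode(model, image_tensor, max_len=5, sos_id=0):
    model.eval()
    batch = image_tensor.size(0)
    tokens = torch.full((batch, 1), sos_id, dtype=torch.long)
    for step in range(1, max_len):
        # Right-pad the running sequence to the fixed decoder length.
        pad = torch.full((batch, max_len - tokens.size(1)), sos_id, dtype=torch.long)
        logits = model(image_tensor, torch.cat([tokens, pad], dim=1))
        next_token = logits[:, step - 1].argmax(dim=-1, keepdim=True)  # (B, 1)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens  # (B, max_len), including the start token
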

if __name__ == '__main__':
    # Smoke test. Because of the relative imports above, run this file as a
    # module (e.g. `python -m <package>.model`), not as a standalone script.
    model = ImageToDigitTransformer()
    img = torch.randn(4, 1, 56, 56)
    tokens = torch.randint(0, 13, (4, 5))
    logits = model(img, tokens)
    print(logits.shape)  # Expected: torch.Size([4, 5, 13])
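
    # Loss sketch (illustrative): with random ids standing in for real labels,
    # a training step would flatten logits and targets for cross-entropy. The
    # random targets below are placeholders, not a real label format.
    targets = torch.randint(0, 13, (4, 5))
    loss = nn.functional.cross_entropy(logits.reshape(-1, 13), targets.reshape(-1))
    print(f'cross-entropy on random targets: {loss.item():.4f}')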