nico-x committed
Commit b54146b · 0 Parent(s)

codebase without model

Files changed (16)
  1. .gitignore +19 -0
  2. README.md +54 -0
  3. app.py +8 -0
  4. app/gradio_app.py +85 -0
  5. app/preprocess.py +26 -0
  6. dataset.py +72 -0
  7. eval.py +73 -0
  8. launch.py +8 -0
  9. model/decoder.py +125 -0
  10. model/encoder.py +93 -0
  11. model/feature_extractor.py +31 -0
  12. model/model.py +46 -0
  13. requirements.txt +6 -0
  14. task.txt +48 -0
  15. train.py +65 -0
  16. utils/tokenizer.py +33 -0
.gitignore ADDED
# Byte-compiled / cache
__pycache__/
*.py[cod]
*.so
*.ipynb_checkpoints

# Virtual environments
.venv/
env/
venv/

# System files
.DS_Store

# PyTorch checkpoints
*.pt

# Gradio session files
gradio_cached_examples/

README.md ADDED
# Transformer MNIST 2×2 — Image-to-Sequence Prediction

This project implements a minimal Transformer-based model that takes a 2×2 grid of MNIST digits as input and autoregressively predicts the corresponding 4-digit sequence. It serves as a practical deep dive into the inner workings of the Transformer architecture and basic multimodality concepts, combining vision (image patches) with language modeling (digit sequences).

## 1. Project Overview

The goal is to understand how a vanilla Transformer encoder-decoder can be applied to a simple multimodal task: mapping an image input to a discrete token sequence. This project focuses on building each architectural component from scratch and wiring them together cleanly.

## 2. Task Definition

- **Input:** a 2×2 grid composed of 4 random MNIST digits, forming a 56×56 grayscale image.
- **Output:** the 4-digit sequence corresponding to the digits in the grid (top-left → bottom-right).
- **Modeling approach:** sequence-to-sequence using an autoregressive decoder with special `<start>` and `<finish>` tokens (see the sketch below).
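
For illustration, here is a minimal sketch of the sequence format, using the helpers defined in `utils/tokenizer.py` (run from the repo root):

```python
# Sketch: how the 4 grid digits (TL, TR, BL, BR) become decoder input/target sequences
from utils.tokenizer import prepare_decoder_labels

labels = [3, 1, 4, 7]  # digits read top-left, top-right, bottom-left, bottom-right
decoder_input, decoder_target = prepare_decoder_labels(labels)
print(decoder_input)   # [10, 3, 1, 4, 7]  ->  [<start>, 3, 1, 4, 7]
print(decoder_target)  # [3, 1, 4, 7, 11]  ->  [3, 1, 4, 7, <finish>]
```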

## 3. Model Architecture

The model follows a clean encoder-decoder Transformer architecture (a shape walkthrough follows the list):

- **Feature Extractor:** splits the 56×56 image into 16 non-overlapping patches of 14×14 pixels and projects each to a 64-dimensional embedding.
- **Transformer Encoder:** processes the 16 patch embeddings using standard multi-head self-attention, positional embeddings, and MLP blocks.
- **Transformer Decoder:** autoregressively predicts the digit sequence:
  - Uses masked self-attention over token embeddings
  - Attends to encoder output via cross-attention
  - Outputs a sequence of logits over a vocabulary of size 13 (digits 0–9, `<start>`, `<finish>`, and one unused slot)
- **Tokenizer:** handles token ↔ digit conversions and input preparation.
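
As a quick sanity check on how the pieces fit together, here is a minimal shape walkthrough mirroring the `__main__` test in `model/model.py` (random tensors, no real data):

```python
import torch
from model.model import ImageToDigitTransformer

model = ImageToDigitTransformer(vocab_size=13)
images = torch.randn(2, 1, 56, 56)            # batch of 2 stitched 2x2 grids
decoder_input = torch.randint(0, 13, (2, 5))  # [<start>, d1, d2, d3, d4] per sample

logits = model(images, decoder_input)         # feature extractor -> encoder -> decoder
print(logits.shape)                           # torch.Size([2, 5, 13])
```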

## 4. Training Setup

- **Dataset:** MNIST, wrapped into a custom `MNIST_2x2` PyTorch dataset that returns the stitched image and 4-digit target.
- **Batch size:** 64
- **Epochs:** 10
- **Loss:** `CrossEntropyLoss` over vocabulary tokens, applied per position (see the sketch after this list)
- **Optimizer:** Adam
- **Hardware:** Apple M4 with `mps` acceleration
- **Logging:** `tqdm` per-batch loss tracking for clear training progress
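
Because the decoder emits a distribution over the vocabulary at every position, the loss flattens the batch and sequence dimensions before applying cross-entropy. A minimal sketch of the pattern used in `train.py`:

```python
import torch
import torch.nn as nn

VOCAB_SIZE = 13
logits = torch.randn(64, 5, VOCAB_SIZE)          # (batch, seq_len, vocab) from the model
targets = torch.randint(0, VOCAB_SIZE, (64, 5))  # (batch, seq_len) target token ids

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits.view(-1, VOCAB_SIZE), targets.view(-1))  # one term per token position
```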

## 5. Evaluation

Evaluation is done on the held-out MNIST test set using greedy decoding (sketched in code below):

- Starts with the `<start>` token
- Predicts one token at a time (no teacher forcing)
- Stops after 4 tokens or if `<finish>` is predicted
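
The greedy loop is a condensed version of `eval.py`; the sketch below runs it with an untrained model and a random image just to show the mechanics:

```python
import torch
from model.model import ImageToDigitTransformer
from utils.tokenizer import START, FINISH

model = ImageToDigitTransformer(vocab_size=13).eval()
image = torch.randn(1, 1, 56, 56)                   # stand-in for a stitched 2x2 grid

decoded = [START]
for _ in range(4):                                  # at most 4 digits
    input_ids = torch.tensor(decoded).unsqueeze(0)  # tokens generated so far, shape (1, T)
    with torch.no_grad():
        logits = model(image, input_ids)            # (1, T, vocab)
    next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
    decoded.append(next_token)
    if next_token == FINISH:
        break
print(decoded[1:])                                  # predicted token ids (without <start>)
```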

### Evaluation Metrics

- **Sequence accuracy:** % of samples where all 4 digits are predicted correctly
- **Per-digit accuracy:** % of individual digits predicted correctly across all positions

### Final results after 10 epochs of training

- **Training loss at epoch 10:** 0.0101
- **Sequence accuracy:** 93.77% on the held-out test set
- **Per-digit accuracy:** 98.43% on the held-out test set

app.py ADDED
# entrypoint for HuggingFace Space

import sys
sys.path.append('.')

from app.gradio_app import demo

demo.launch()

app/gradio_app.py ADDED
import sys
sys.path.append('.')

import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw

from model.model import ImageToDigitTransformer
from utils.tokenizer import START, FINISH, decode
from app.preprocess import preprocess_canvases

# Load model
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = ImageToDigitTransformer(vocab_size=13).to(device)
model.load_state_dict(torch.load("checkpoints/transformer_mnist.pt", map_location=device))
model.eval()

def split_into_quadrants(image):
    """Split a PIL Image or numpy array into 4 quadrants (TL, TR, BL, BR)."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    w, h = image.size
    return [
        np.array(image.crop((0, 0, w // 2, h // 2))),
        np.array(image.crop((w // 2, 0, w, h // 2))),
        np.array(image.crop((0, h // 2, w // 2, h))),
        np.array(image.crop((w // 2, h // 2, w, h))),
    ]

def predict_digit_sequence(editor_data):
    """Predicts 4-digit sequence from 2×2 canvas image."""
    if editor_data is None or "composite" not in editor_data:
        return "No image provided."
    img = editor_data["composite"]
    quadrants = split_into_quadrants(img)
    image_tensor = preprocess_canvases(quadrants).to(device)

    decoded = [START]
    for _ in range(4):
        input_ids = torch.tensor(decoded, dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(image_tensor, input_ids)
        next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
        decoded.append(next_token)
        if next_token == FINISH:
            break

    pred = decoded[1:]
    return "".join(decode(pred[:4]))

def create_black_canvas(size=(800, 800)):
    """Create a black canvas with a 2×2 light gray grid overlay."""
    img = Image.new("L", size, color=0)
    draw = ImageDraw.Draw(img)
    w, h = size
    draw.line([(w // 2, 0), (w // 2, h)], fill=128, width=2)
    draw.line([(0, h // 2), (w, h // 2)], fill=128, width=2)
    return img

# === UI ===
canvas_size = 800

with gr.Blocks() as demo:
    gr.Markdown("## Draw 4 digits in a 2×2 grid using a white brush")

    canvas = gr.ImageEditor(
        label="White brush only on black canvas (no uploads)",
        value=create_black_canvas(),
        image_mode="L",
        height=canvas_size,
        width=canvas_size,
        sources=[],  # disables uploads
        type="pil",
        brush=gr.Brush(colors=["#FFFFFF"], default_color="#FFFFFF", default_size=15, color_mode="fixed")
    )

    predict_btn = gr.Button("Predict")
    clear_btn = gr.Button("Erase")
    output = gr.Textbox(label="Predicted 4-digit sequence", interactive=True)

    predict_btn.click(fn=predict_digit_sequence, inputs=[canvas], outputs=[output])
    clear_btn.click(fn=lambda: create_black_canvas(), inputs=[], outputs=[canvas])

demo.launch()

app/preprocess.py ADDED
import numpy as np
import torch
from PIL import Image

def preprocess_canvases(images):
    """
    Takes a list of 4 RGBA images (top-left, top-right, bottom-left, bottom-right),
    resizes to 28x28, converts to grayscale, stitches to a (1, 1, 56, 56) tensor.
    """
    assert len(images) == 4, "Expected 4 images"

    digits = []
    for img in images:
        img = Image.fromarray(img).convert("L")  # convert to grayscale
        img = img.resize((28, 28))
        img = np.array(img).astype(np.float32) / 255.0  # scale to [0, 1]
        digits.append(img)

    top = np.hstack([digits[0], digits[1]])
    bottom = np.hstack([digits[2], digits[3]])
    grid = np.vstack([top, bottom])  # shape (56, 56)

    # Normalize like MNIST
    grid = (grid - 0.1307) / 0.3081
    grid = torch.tensor(grid).unsqueeze(0).unsqueeze(0)  # shape (1, 1, 56, 56)
    return grid

dataset.py ADDED
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

from utils.tokenizer import prepare_decoder_labels, encode, decode

class MNIST_2x2(Dataset):
    def __init__(self, base_dataset, transform=None, seed=42):
        self.base_dataset = base_dataset
        self.transform = transform
        self.length = len(base_dataset)

        torch.manual_seed(seed)
        self.index_map = [
            torch.randint(0, self.length, (4,))
            for _ in range(self.length)
        ]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        indices = self.index_map[idx]

        images = [self.base_dataset[i][0] for i in indices]
        top_row = torch.cat([images[0], images[1]], dim=2)
        bottom_row = torch.cat([images[2], images[3]], dim=2)
        grid_image = torch.cat([top_row, bottom_row], dim=1)

        labels = [self.base_dataset[i][1] for i in indices]
        decoder_input_ids, decoder_target_ids = prepare_decoder_labels(labels)
        decoder_input = torch.tensor(decoder_input_ids, dtype=torch.long)
        decoder_target = torch.tensor(decoder_target_ids, dtype=torch.long)

        return grid_image, decoder_input, decoder_target

# test the dataset and visualize a few samples
if __name__ == "__main__":

    import matplotlib.pyplot as plt

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transform)

    train_dataset = MNIST_2x2(mnist_train, seed=42)
    test_dataset = MNIST_2x2(mnist_test, seed=42)


    def show_grid_image(grid_tensor, decoder_target):
        # Undo normalization for visualization
        img = grid_tensor.clone()
        img = img * 0.3081 + 0.1307
        img = img.squeeze().numpy()

        # Decode token IDs into digit strings
        digits = decode(decoder_target.tolist()[:-1])  # Remove <finish> for display
        label_str = " ".join(digits)

        plt.imshow(img, cmap="gray")
        plt.title(f"Digits: {label_str}")
        plt.axis("off")
        plt.show()

    # Visualize a few samples
    for i in range(3):
        grid_image, decoder_input, decoder_target = train_dataset[i]
        show_grid_image(grid_image, decoder_target)

eval.py ADDED
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

from dataset import MNIST_2x2
from model.model import ImageToDigitTransformer
from utils.tokenizer import START, FINISH, decode

# device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# config
VOCAB_SIZE = 13
MAX_LEN = 5  # length of decoder input: [<start>, d1, d2, d3, d4]
SEQ_LEN = 4  # number of predicted digits

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_dataset = MNIST_2x2(mnist_test, seed=42)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

model = ImageToDigitTransformer(vocab_size=VOCAB_SIZE).to(device)
model.load_state_dict(torch.load("checkpoints/transformer_mnist.pt", map_location=device))
model.eval()

# Evaluation Loop
correct_sequences = 0
digit_correct = 0
digit_total = 0

with torch.no_grad():
    loop = tqdm(test_loader, desc="Evaluating", leave=False)

    for image, _, target_ids in loop:
        image = image.to(device)
        target_ids = target_ids.squeeze(0).tolist()[:-1]  # remove <finish>

        decoded = [START]
        for _ in range(SEQ_LEN):
            input_ids = torch.tensor(decoded, dtype=torch.long).unsqueeze(0).to(device)
            logits = model(image, input_ids)
            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
            decoded.append(next_token)
            if next_token == FINISH:
                break

        pred = decoded[1:][:SEQ_LEN]
        target = target_ids

        if pred == target:
            correct_sequences += 1
        digit_correct += sum(p == t for p, t in zip(pred, target))
        digit_total += len(target)

        seq_acc = 100.0 * correct_sequences / (digit_total // SEQ_LEN)
        digit_acc = 100.0 * digit_correct / digit_total
        loop.set_postfix(seq_acc=f"{seq_acc:.2f}%", digit_acc=f"{digit_acc:.2f}%")


# final results
total_samples = len(test_loader)
seq_acc = 100.0 * correct_sequences / total_samples
digit_acc = 100.0 * digit_correct / digit_total

print(f"\nFinal Evaluation Results:")
print(f"  Sequence accuracy: {seq_acc:.2f}%")
print(f"  Per-digit accuracy: {digit_acc:.2f}%")

launch.py ADDED
# entrypoint for HuggingFace Space

import sys
sys.path.append('.')

from app.gradio_app import demo

demo.launch()

model/decoder.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class DecoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        assert d_model % n_heads == 0, "d_model must be divisible by number of heads"

        # Self-attention: Q, K, V from decoder input
        self.self_attn_proj = nn.Linear(d_model, 3 * d_model)

        # Cross-attention: Q from decoder input, K/V from encoder output
        self.cross_attn_q = nn.Linear(d_model, d_model)
        self.cross_attn_kv = nn.Linear(d_model, 2 * d_model)

        # Output projections
        self.self_out = nn.Linear(d_model, d_model)
        self.cross_out = nn.Linear(d_model, d_model)

        # Feedforward MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model)
        )

        # LayerNorms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out):
        """
        x: (B, T, D) - decoder input embeddings
        enc_out: (B, N, D) - encoder outputs (image patch representations)
        Returns: (B, T, D)
        """
        B, T, D = x.shape
        _, N, _ = enc_out.shape

        # Masked Self-Attention
        x_norm = self.norm1(x)
        qkv = self.self_attn_proj(x_norm).reshape(B, T, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, T, T)

        # Causal mask: prevent attention to future positions
        mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)
        attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_out = attn_weights @ v  # (B, n_heads, T, head_dim)
        attn_out = attn_out.transpose(1, 2).reshape(B, T, D)
        attn_out = self.self_out(attn_out)
        x = x + attn_out  # Residual

        # Cross-Attention
        x_norm = self.norm2(x)
        q = self.cross_attn_q(x_norm).reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, n_heads, T, head_dim)
        kv = self.cross_attn_kv(enc_out).reshape(B, N, 2, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]  # (B, n_heads, N, head_dim)

        cross_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, T, N)
        cross_weights = F.softmax(cross_scores, dim=-1)
        cross_out = cross_weights @ v  # (B, n_heads, T, head_dim)
        cross_out = cross_out.transpose(1, 2).reshape(B, T, D)
        cross_out = self.cross_out(cross_out)
        x = x + cross_out  # Residual

        # Feedforward
        x_norm = self.norm3(x)
        x = x + self.mlp(x_norm)  # Residual

        return x


# implement the entire decoder

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size=13, max_len=5, d_model=64, n_heads=4, ff_dim=128, depth=2):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))  # (1, 5, 64)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])

        self.output_proj = nn.Linear(d_model, vocab_size)  # Final projection to logits

    def forward(self, decoder_input_ids, encoder_output):
        """
        decoder_input_ids: (B, T) token IDs
        encoder_output: (B, N, d_model) from image encoder
        returns: logits over vocab, shape (B, T, vocab_size)
        """
        x = self.token_embedding(decoder_input_ids)   # (B, T, d_model)
        x = x + self.pos_embedding[:, :x.size(1), :]  # Add positional embedding

        for layer in self.layers:
            x = layer(x, encoder_output)  # (B, T, d_model)

        logits = self.output_proj(x)  # (B, T, vocab_size)
        return logits


# quick test

if __name__ == "__main__":
    decoder = TransformerDecoder()
    decoder_input = torch.randint(0, 13, (4, 5))  # (B=4, T=5)
    encoder_out = torch.randn(4, 16, 64)          # (B=4, N=16, D=64)

    logits = decoder(decoder_input, encoder_out)
    print("Logits shape:", logits.shape)  # (4, 5, 13)

model/encoder.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class EncoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # attention projections
        self.qkv_proj = nn.Linear(d_model, d_model * 3)  # efficient way of projecting to q, k, v
        self.out_proj = nn.Linear(d_model, d_model)

        # FF MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model)
        )

        # layernorms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        B, N, D = x.shape

        # multi-head attention
        x_norm = self.norm1(x)
        qkv = self.qkv_proj(x_norm)
        qkv = qkv.reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # qkv: (3, B, n_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, N, N)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = attn_weights @ v  # (B, n_heads, N, head_dim)

        attn_output = attn_output.transpose(1, 2).reshape(B, N, D)  # (B, N, D)
        attn_output = self.out_proj(attn_output)
        x = x + attn_output  # Residual connection

        # === Feedforward ===
        x_norm = self.norm2(x)
        x = x + self.mlp(x_norm)  # Residual

        return x


class TransformerEncoder(nn.Module):
    def __init__(self, depth=4, d_model=64, n_heads=4, ff_dim=128, num_patches=16):
        super().__init__()

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, d_model))  # (1, 16, 64)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])

    def forward(self, x):
        """
        x: Tensor of shape (B, num_patches, d_model)
        returns: Tensor of same shape (B, num_patches, d_model)
        """
        x = x + self.pos_embedding

        for layer in self.layers:
            x = layer(x)

        return x


# simple testing of dimensions
if __name__ == "__main__":
    B = 4   # batch size
    N = 16  # number of patches
    D = 64  # embedding dim

    dummy_input = torch.randn(B, N, D)

    print("Testing EncoderLayer...")
    layer = EncoderLayer(d_model=D, n_heads=4, ff_dim=128)
    out = layer(dummy_input)
    print("EncoderLayer output shape:", out.shape)  # (B, N, D) torch.Size([4, 16, 64])

    print("Testing TransformerEncoder...")
    encoder = TransformerEncoder(depth=3, d_model=D, n_heads=4, ff_dim=128, num_patches=N)
    out = encoder(dummy_input)
    print("TransformerEncoder output shape:", out.shape)  # (B, N, D) torch.Size([4, 16, 64])

model/feature_extractor.py ADDED
import torch
import torch.nn as nn

class FeatureExtractor(nn.Module):
    def __init__(self, patch_size=14, emb_dim=64):
        super().__init__()
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.proj = nn.Linear(patch_size * patch_size, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: Tensor of shape (B, 1, 56, 56)
        returns patch_embeddings of shape (B, 16, emb_dim)
        """
        B, C, H, W = x.shape
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.contiguous().view(B, -1, self.patch_size * self.patch_size)
        patch_embeddings = self.proj(patches)

        return patch_embeddings


if __name__ == "__main__":
    feature_extractor = FeatureExtractor()
    dummy_input = torch.randn(8, 1, 56, 56)
    out = feature_extractor(dummy_input)

    print(out.shape)  # should expect (8, 16, 64)

model/model.py ADDED
import torch
import torch.nn as nn
from .feature_extractor import FeatureExtractor
from .encoder import TransformerEncoder
from .decoder import TransformerDecoder

class ImageToDigitTransformer(nn.Module):
    def __init__(self, vocab_size=13, d_model=64, n_heads=4, ff_dim=128,
                 encoder_depth=4, decoder_depth=2, num_patches=16, max_seq_len=5):
        super().__init__()

        self.feature_extractor = FeatureExtractor(patch_size=14, emb_dim=d_model)
        self.encoder = TransformerEncoder(
            depth=encoder_depth,
            d_model=d_model,
            n_heads=n_heads,
            ff_dim=ff_dim,
            num_patches=num_patches
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            max_len=max_seq_len,
            d_model=d_model,
            n_heads=n_heads,
            ff_dim=ff_dim,
            depth=decoder_depth
        )

    def forward(self, image_tensor, decoder_input_ids):
        """
        image_tensor: (B, 1, 56, 56)
        decoder_input_ids: (B, 5)
        Returns:
            logits: (B, 5, vocab_size)
        """
        patch_embeddings = self.feature_extractor(image_tensor)   # (B, 16, 64)
        encoder_output = self.encoder(patch_embeddings)           # (B, 16, 64)
        logits = self.decoder(decoder_input_ids, encoder_output)  # (B, 5, 13)
        return logits

if __name__ == '__main__':
    model = ImageToDigitTransformer()
    img = torch.randn(4, 1, 56, 56)
    tokens = torch.randint(0, 13, (4, 5))
    logits = model(img, tokens)
    print(logits.shape)  # Expected: (4, 5, 13)

requirements.txt ADDED
torch
torchvision
gradio
numpy
Pillow
tqdm

task.txt ADDED
# MNIST multimodal transformer - Overview

## Task 1
Our goal is to build and train a multimodal transformer from scratch.
The task is to predict the sequence of digits from an image composed of 2x2 MNIST images tiled together.
The transformer should be able to predict the labels from top left to top right, bottom left, bottom right.

## Outcome

- a clean, minimal, and well organized project folder structure
- clean and minimal PyTorch code, well organized across dataset, dataloader and model classes
- clear evaluation metrics

# Execution

## Dataset

Create a dataset class that returns a single example of (a sketch of one returned example follows the list):
- an image made of 2x2 MNIST images picked at random (from the training split) and stitched together
- the 4 labels organized in top-left, top-right, bottom-left, bottom-right order
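
A minimal sketch of a single returned example, using the MNIST_2x2 class from dataset.py and a normalized torchvision MNIST split:

```python
from torchvision import datasets, transforms
from dataset import MNIST_2x2

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transform)

grid_image, decoder_input, decoder_target = MNIST_2x2(mnist_train, seed=42)[0]
print(grid_image.shape)      # torch.Size([1, 56, 56]) -- stitched 2x2 grid
print(decoder_input.shape)   # torch.Size([5])  -> [<start>, d1, d2, d3, d4]
print(decoder_target.shape)  # torch.Size([5])  -> [d1, d2, d3, d4, <finish>]
```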

## Model

Create an encoder-decoder transformer architecture for this task. The architecture is made of three main elements:
- Feature extractor
- Encoder
- Decoder

### Feature extractor

Each image is cut into 16 patches of dim 14x14px (given my stitched 2x2 image is now 56x56 pixels)
and linearly projected to a dimension of 64, which is the constant latent vector size D for the encoder.
These represent the image embeddings that are fed as input to the encoder block.
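
A minimal sketch of this patching step, mirroring model/feature_extractor.py (names here are illustrative):

```python
import torch
import torch.nn as nn

patch_size, emb_dim = 14, 64
proj = nn.Linear(patch_size * patch_size, emb_dim)  # linear projection to D = 64

x = torch.randn(1, 1, 56, 56)                       # stitched 2x2 MNIST grid
patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.contiguous().view(1, -1, patch_size * patch_size)  # (1, 16, 196)
patch_embeddings = proj(patches)                    # (1, 16, 64), the encoder input
print(patch_embeddings.shape)
```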

### Encoder

It should closely follow the vanilla "Attention Is All You Need" implementation, similarly to the ViT (Vision Transformer) paper:

- positional embeddings are added to the patch embeddings to retain positional information, using standard learnable 1D position embeddings
- the encoder then consists of alternating layers of multi-headed self-attention and MLP blocks
- layernorm is applied before every attention block and MLP block
- residual connections are applied after every block

The output of the encoder is a set of encoded representations of the image patches (16x64), as checked below.
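
A quick check of that contract, using the TransformerEncoder from model/encoder.py with random patch embeddings (a sketch, not part of training):

```python
import torch
from model.encoder import TransformerEncoder

encoder = TransformerEncoder(depth=4, d_model=64, n_heads=4, ff_dim=128, num_patches=16)
patch_embeddings = torch.randn(1, 16, 64)  # 16 patch embeddings of size 64
encoded = encoder(patch_embeddings)
print(encoded.shape)                       # torch.Size([1, 16, 64]) -- one encoded vector per patch
```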

### Decoder

train.py ADDED
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import os

from dataset import MNIST_2x2
from model.model import ImageToDigitTransformer

# Use MPS if available (Apple Silicon)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Config
BATCH_SIZE = 64
EPOCHS = 10
LR = 1e-3
VOCAB_SIZE = 13

# Transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Dataset & DataLoader
train_base = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_dataset = MNIST_2x2(train_base, seed=42)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model, Loss, Optimizer
model = ImageToDigitTransformer(vocab_size=VOCAB_SIZE).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Training Loop
model.train()
for epoch in range(EPOCHS):
    total_loss = 0.0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)

    for images, dec_input, dec_target in loop:
        images = images.to(device)
        dec_input = dec_input.to(device)
        dec_target = dec_target.to(device)

        logits = model(images, dec_input)
        loss = loss_fn(logits.view(-1, VOCAB_SIZE), dec_target.view(-1))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

        # Update tqdm every batch
        loop.set_postfix(batch_loss=loss.item())

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{EPOCHS} - Loss: {avg_loss:.4f}")

# save weights
os.makedirs("checkpoints", exist_ok=True)
torch.save(model.state_dict(), "checkpoints/transformer_mnist.pt")
print("Model saved to checkpoints/transformer_mnist.pt")

utils/tokenizer.py ADDED
token_to_id = {
    str(i): i for i in range(10)
}
token_to_id["<start>"] = 10
token_to_id["<finish>"] = 11

id_to_token = {v: k for k, v in token_to_id.items()}

START = token_to_id["<start>"]
FINISH = token_to_id["<finish>"]
# I don't need padding or a pad token because the input is a fixed-length sequence of 5

def encode(label_list):
    return [token_to_id[str(d)] for d in label_list]

def decode(token_ids):
    return [id_to_token[t] for t in token_ids]

def prepare_decoder_labels(labels):
    """
    Prepare decoder input and target sequences for training.
    Input labels: [7, 7, 6, 9]
    Output:
        decoder_input = [<start>, 7, 7, 6, 9]
        decoder_target = [7, 7, 6, 9, <finish>]
    """
    token_ids = encode(labels)
    decoder_input = [START] + token_ids
    decoder_target = token_ids + [FINISH]
    return decoder_input, decoder_target