Commit · b54146b
codebase without model
- .gitignore +19 -0
- README.md +54 -0
- app.py +8 -0
- app/gradio_app.py +85 -0
- app/preprocess.py +26 -0
- dataset.py +72 -0
- eval.py +73 -0
- launch.py +8 -0
- model/decoder.py +125 -0
- model/encoder.py +93 -0
- model/feature_extractor.py +31 -0
- model/model.py +46 -0
- requirements.txt +6 -0
- task.txt +48 -0
- train.py +65 -0
- utils/tokenizer.py +33 -0
.gitignore
ADDED
@@ -0,0 +1,19 @@
# Byte-compiled / cache
__pycache__/
*.py[cod]
*.so
*.ipynb_checkpoints

# Virtual environments
.venv/
env/
venv/

# System files
.DS_Store

# PyTorch checkpoints
*.pt

# Gradio session files
gradio_cached_examples/
README.md
ADDED
@@ -0,0 +1,54 @@
# Transformer MNIST 2×2 — Image-to-Sequence Prediction

This project implements a minimal Transformer-based model that takes a 2×2 grid of MNIST digits as input and autoregressively predicts the corresponding 4-digit sequence. It serves as a practical deep dive into the inner workings of the Transformer architecture and basic multimodality concepts, combining vision (image patches) with language modeling (digit sequences).

## 1. Project Overview

The goal is to understand how a vanilla Transformer encoder-decoder can be applied to a simple multimodal task: mapping an image input to a discrete token sequence. This project focuses on building each architectural component from scratch and wiring them together cleanly.

## 2. Task Definition

- **Input:** a 2×2 grid composed of 4 random MNIST digits, forming a 56×56 grayscale image.
- **Output:** the 4-digit sequence corresponding to the digits in the grid (top-left → bottom-right).
- **Modeling approach:** sequence-to-sequence using an autoregressive decoder with special `<start>` and `<finish>` tokens.

## 3. Model Architecture

The model follows a clean encoder-decoder Transformer architecture:

- **Feature Extractor:** splits the 56×56 image into 16 non-overlapping patches of 14×14 pixels and projects each to a 64-dimensional embedding.
- **Transformer Encoder:** processes the 16 patch embeddings using standard multi-head self-attention, positional embeddings, and MLP blocks.
- **Transformer Decoder:** autoregressively predicts the digit sequence:
  - Uses masked self-attention over token embeddings
  - Attends to encoder output via cross-attention
  - Outputs a sequence of logits over a vocabulary of 13 tokens (digits 0–9, `<start>`, `<finish>`, plus one unused index)
- **Tokenizer:** handles token ↔ digit conversions and input preparation.

## 4. Training Setup

- **Dataset:** MNIST, wrapped into a custom `MNIST_2x2` PyTorch dataset that returns the stitched image and 4-digit target.
- **Batch size:** 64
- **Epochs:** 10
- **Loss:** `CrossEntropyLoss` over vocabulary tokens
- **Optimizer:** Adam
- **Hardware:** Apple M4 with `mps` acceleration
- **Logging:** `tqdm` per-batch loss tracking for clear training progress

## 5. Evaluation

Evaluation is done on the held-out MNIST test set using greedy decoding:

- Starts with the `<start>` token
- Predicts one token at a time (no teacher forcing)
- Stops after 4 tokens or if `<finish>` is predicted

### Evaluation Metrics

- **Sequence accuracy:** % of samples where all 4 digits are predicted correctly
- **Per-digit accuracy:** % of individual digits predicted correctly across all positions

### Final results after 10 epochs of training

- **Training loss at epoch 10:** 0.0101
- **Sequence accuracy:** 93.77% on held-out test set
- **Per-digit accuracy:** 98.43% on held-out test set
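As a quick wiring check for the architecture described above (an illustrative sketch, not part of this commit; it mirrors the `__main__` block of `model/model.py` and assumes it is run from the repo root):

import torch
from model.model import ImageToDigitTransformer

model = ImageToDigitTransformer(vocab_size=13)   # randomly initialized weights, shape check only
images = torch.randn(4, 1, 56, 56)               # a batch of 4 stitched 2x2 grids
decoder_input = torch.randint(0, 13, (4, 5))     # [<start>, d1, d2, d3, d4] token IDs
logits = model(images, decoder_input)
print(logits.shape)                              # torch.Size([4, 5, 13])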
app.py
ADDED
@@ -0,0 +1,8 @@
# entrypoint for HuggingFace Space

import sys
sys.path.append('.')

from app.gradio_app import demo

demo.launch()
app/gradio_app.py
ADDED
@@ -0,0 +1,85 @@
import sys
sys.path.append('.')

import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw

from model.model import ImageToDigitTransformer
from utils.tokenizer import START, FINISH, decode
from app.preprocess import preprocess_canvases

# Load model
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = ImageToDigitTransformer(vocab_size=13).to(device)
model.load_state_dict(torch.load("checkpoints/transformer_mnist.pt", map_location=device))
model.eval()

def split_into_quadrants(image):
    """Split a PIL Image or numpy array into 4 quadrants (TL, TR, BL, BR)."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    w, h = image.size
    return [
        np.array(image.crop((0, 0, w // 2, h // 2))),
        np.array(image.crop((w // 2, 0, w, h // 2))),
        np.array(image.crop((0, h // 2, w // 2, h))),
        np.array(image.crop((w // 2, h // 2, w, h))),
    ]

def predict_digit_sequence(editor_data):
    """Predicts 4-digit sequence from 2×2 canvas image."""
    if editor_data is None or "composite" not in editor_data:
        return "No image provided."
    img = editor_data["composite"]
    quadrants = split_into_quadrants(img)
    image_tensor = preprocess_canvases(quadrants).to(device)

    decoded = [START]
    for _ in range(4):
        input_ids = torch.tensor(decoded, dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(image_tensor, input_ids)
        next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
        decoded.append(next_token)
        if next_token == FINISH:
            break

    pred = decoded[1:]
    return "".join(decode(pred[:4]))

def create_black_canvas(size=(800, 800)):
    """Create a black canvas with a 2×2 light gray grid overlay."""
    img = Image.new("L", size, color=0)
    draw = ImageDraw.Draw(img)
    w, h = size
    draw.line([(w // 2, 0), (w // 2, h)], fill=128, width=2)
    draw.line([(0, h // 2), (w, h // 2)], fill=128, width=2)
    return img

# === UI ===
canvas_size = 800

with gr.Blocks() as demo:
    gr.Markdown("## Draw 4 digits in a 2×2 grid using a white brush")

    canvas = gr.ImageEditor(
        label="White brush only on black canvas (no uploads)",
        value=create_black_canvas(),
        image_mode="L",
        height=canvas_size,
        width=canvas_size,
        sources=[],  # disables uploads
        type="pil",
        brush=gr.Brush(colors=["#FFFFFF"], default_color="#FFFFFF", default_size=15, color_mode="fixed")
    )

    predict_btn = gr.Button("Predict")
    clear_btn = gr.Button("Erase")
    output = gr.Textbox(label="Predicted 4-digit sequence", interactive=True)

    predict_btn.click(fn=predict_digit_sequence, inputs=[canvas], outputs=[output])
    clear_btn.click(fn=lambda: create_black_canvas(), inputs=[], outputs=[canvas])

demo.launch()
app/preprocess.py
ADDED
@@ -0,0 +1,26 @@
import numpy as np
import torch
from PIL import Image

def preprocess_canvases(images):
    """
    Takes a list of 4 RGBA images (top-left, top-right, bottom-left, bottom-right),
    resizes to 28x28, converts to grayscale, stitches to a (1, 1, 56, 56) tensor.
    """
    assert len(images) == 4, "Expected 4 images"

    digits = []
    for img in images:
        img = Image.fromarray(img).convert("L")  # convert to grayscale
        img = img.resize((28, 28))
        img = np.array(img).astype(np.float32) / 255.0  # scale to [0, 1]
        digits.append(img)

    top = np.hstack([digits[0], digits[1]])
    bottom = np.hstack([digits[2], digits[3]])
    grid = np.vstack([top, bottom])  # shape (56, 56)

    # Normalize like MNIST
    grid = (grid - 0.1307) / 0.3081
    grid = torch.tensor(grid).unsqueeze(0).unsqueeze(0)  # shape (1, 1, 56, 56)
    return grid
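A quick shape check for `preprocess_canvases` (illustrative only, not part of this commit; assumes it is run from the repo root) with four dummy quadrant arrays standing in for the Gradio canvas crops:

import numpy as np
from app.preprocess import preprocess_canvases

quadrants = [np.zeros((400, 400), dtype=np.uint8) for _ in range(4)]  # fake TL, TR, BL, BR crops
grid = preprocess_canvases(quadrants)
print(grid.shape)  # torch.Size([1, 1, 56, 56]) -- the input shape ImageToDigitTransformer expects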
dataset.py
ADDED
@@ -0,0 +1,72 @@
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

from utils.tokenizer import prepare_decoder_labels, encode, decode

class MNIST_2x2(Dataset):
    def __init__(self, base_dataset, transform=None, seed=42):
        self.base_dataset = base_dataset
        self.transform = transform
        self.length = len(base_dataset)

        torch.manual_seed(seed)
        self.index_map = [
            torch.randint(0, self.length, (4,))
            for _ in range(self.length)
        ]

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        indices = self.index_map[idx]

        images = [self.base_dataset[i][0] for i in indices]
        top_row = torch.cat([images[0], images[1]], dim=2)
        bottom_row = torch.cat([images[2], images[3]], dim=2)
        grid_image = torch.cat([top_row, bottom_row], dim=1)

        labels = [self.base_dataset[i][1] for i in indices]
        decoder_input_ids, decoder_target_ids = prepare_decoder_labels(labels)
        decoder_input = torch.tensor(decoder_input_ids, dtype=torch.long)
        decoder_target = torch.tensor(decoder_target_ids, dtype=torch.long)

        return grid_image, decoder_input, decoder_target

# test the dataset and visualize a few samples
if __name__ == "__main__":

    import matplotlib.pyplot as plt

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transform)

    train_dataset = MNIST_2x2(mnist_train, seed=42)
    test_dataset = MNIST_2x2(mnist_test, seed=42)

    def show_grid_image(grid_tensor, decoder_target):
        # Undo normalization for visualization
        img = grid_tensor.clone()
        img = img * 0.3081 + 0.1307
        img = img.squeeze().numpy()

        # Decode token IDs into digit strings
        digits = decode(decoder_target.tolist()[:-1])  # Remove <finish> for display
        label_str = " ".join(digits)

        plt.imshow(img, cmap="gray")
        plt.title(f"Digits: {label_str}")
        plt.axis("off")
        plt.show()

    # Visualize a few samples
    for i in range(3):
        grid_image, decoder_input, decoder_target = train_dataset[i]
        show_grid_image(grid_image, decoder_target)
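For reference (illustrative, not part of this commit), the stitching in `__getitem__` operates on `(C, H, W)` tensors: concatenating two `(1, 28, 28)` digits along `dim=2` gives a `(1, 28, 56)` row, and concatenating the two rows along `dim=1` gives the final `(1, 56, 56)` grid:

import torch

a, b, c, d = (torch.randn(1, 28, 28) for _ in range(4))  # stand-ins for 4 MNIST digits
top_row = torch.cat([a, b], dim=2)               # (1, 28, 56)
bottom_row = torch.cat([c, d], dim=2)            # (1, 28, 56)
grid = torch.cat([top_row, bottom_row], dim=1)   # (1, 56, 56)
print(grid.shape)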
eval.py
ADDED
@@ -0,0 +1,73 @@
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

from dataset import MNIST_2x2
from model.model import ImageToDigitTransformer
from utils.tokenizer import START, FINISH, decode

# device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# config
VOCAB_SIZE = 13
MAX_LEN = 5  # length of decoder input: [<start>, d1, d2, d3, d4]
SEQ_LEN = 4  # number of predicted digits

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_dataset = MNIST_2x2(mnist_test, seed=42)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

model = ImageToDigitTransformer(vocab_size=VOCAB_SIZE).to(device)
model.load_state_dict(torch.load("checkpoints/transformer_mnist.pt", map_location=device))
model.eval()

# Evaluation Loop
correct_sequences = 0
digit_correct = 0
digit_total = 0

with torch.no_grad():
    loop = tqdm(test_loader, desc="Evaluating", leave=False)

    for image, _, target_ids in loop:
        image = image.to(device)
        target_ids = target_ids.squeeze(0).tolist()[:-1]  # remove <finish>

        decoded = [START]
        for _ in range(SEQ_LEN):
            input_ids = torch.tensor(decoded, dtype=torch.long).unsqueeze(0).to(device)
            logits = model(image, input_ids)
            next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
            decoded.append(next_token)
            if next_token == FINISH:
                break

        pred = decoded[1:][:SEQ_LEN]
        target = target_ids

        if pred == target:
            correct_sequences += 1
        digit_correct += sum(p == t for p, t in zip(pred, target))
        digit_total += len(target)

        seq_acc = 100.0 * correct_sequences / (digit_total // SEQ_LEN)
        digit_acc = 100.0 * digit_correct / digit_total
        loop.set_postfix(seq_acc=f"{seq_acc:.2f}%", digit_acc=f"{digit_acc:.2f}%")


# final results
total_samples = len(test_loader)
seq_acc = 100.0 * correct_sequences / total_samples
digit_acc = 100.0 * digit_correct / digit_total

print(f"\nFinal Evaluation Results:")
print(f"  Sequence accuracy: {seq_acc:.2f}%")
print(f"  Per-digit accuracy: {digit_acc:.2f}%")
launch.py
ADDED
@@ -0,0 +1,8 @@
# entrypoint for HuggingFace Space

import sys
sys.path.append('.')

from app.gradio_app import demo

demo.launch()
model/decoder.py
ADDED
@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class DecoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        assert d_model % n_heads == 0, "d_model must be divisible by number of heads"

        # Self-attention: Q, K, V from decoder input
        self.self_attn_proj = nn.Linear(d_model, 3 * d_model)

        # Cross-attention: Q from decoder input, K/V from encoder output
        self.cross_attn_q = nn.Linear(d_model, d_model)
        self.cross_attn_kv = nn.Linear(d_model, 2 * d_model)

        # Output projections
        self.self_out = nn.Linear(d_model, d_model)
        self.cross_out = nn.Linear(d_model, d_model)

        # Feedforward MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model)
        )

        # LayerNorms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out):
        """
        x: (B, T, D) - decoder input embeddings
        enc_out: (B, N, D) - encoder outputs (image patch representations)
        Returns: (B, T, D)
        """
        B, T, D = x.shape
        _, N, _ = enc_out.shape

        # Masked Self-Attention
        x_norm = self.norm1(x)
        qkv = self.self_attn_proj(x_norm).reshape(B, T, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, T, T)

        # Causal mask: prevent attention to future positions
        mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)
        attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_out = attn_weights @ v  # (B, n_heads, T, head_dim)
        attn_out = attn_out.transpose(1, 2).reshape(B, T, D)
        attn_out = self.self_out(attn_out)
        x = x + attn_out  # Residual

        # Cross-Attention
        x_norm = self.norm2(x)
        q = self.cross_attn_q(x_norm).reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, n_heads, T, head_dim)
        kv = self.cross_attn_kv(enc_out).reshape(B, N, 2, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]  # (B, n_heads, N, head_dim)

        cross_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, T, N)
        cross_weights = F.softmax(cross_scores, dim=-1)
        cross_out = cross_weights @ v  # (B, n_heads, T, head_dim)
        cross_out = cross_out.transpose(1, 2).reshape(B, T, D)
        cross_out = self.cross_out(cross_out)
        x = x + cross_out  # Residual

        # Feedforward
        x_norm = self.norm3(x)
        x = x + self.mlp(x_norm)  # Residual

        return x


# implement the entire decoder

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size=13, max_len=5, d_model=64, n_heads=4, ff_dim=128, depth=2):
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))  # (1, 5, 64)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])

        self.output_proj = nn.Linear(d_model, vocab_size)  # Final projection to logits

    def forward(self, decoder_input_ids, encoder_output):
        """
        decoder_input_ids: (B, T) token IDs
        encoder_output: (B, N, d_model) from image encoder
        returns: logits over vocab, shape (B, T, vocab_size)
        """
        x = self.token_embedding(decoder_input_ids)  # (B, T, d_model)
        x = x + self.pos_embedding[:, :x.size(1), :]  # Add positional embedding

        for layer in self.layers:
            x = layer(x, encoder_output)  # (B, T, d_model)

        logits = self.output_proj(x)  # (B, T, vocab_size)
        return logits


# quick test

if __name__ == "__main__":
    decoder = TransformerDecoder()
    decoder_input = torch.randint(0, 13, (4, 5))  # (B=4, T=5)
    encoder_out = torch.randn(4, 16, 64)  # (B=4, N=16, D=64)

    logits = decoder(decoder_input, encoder_out)
    print("Logits shape:", logits.shape)  # (4, 5, 13)
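As an aside (illustrative, not part of this commit), the causal mask built with `torch.tril` above looks like this for the full decoder length T = 5; row t has ones only at positions <= t, so token t can never attend to future tokens:

import torch

print(torch.tril(torch.ones(5, 5)))
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])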
model/encoder.py
ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class EncoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, ff_dim=128):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # attention projections
        self.qkv_proj = nn.Linear(d_model, d_model * 3)  # efficient way of projecting to q, k, v
        self.out_proj = nn.Linear(d_model, d_model)

        # FF MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model)
        )

        # layernorms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        B, N, D = x.shape

        # multi-head attention
        x_norm = self.norm1(x)
        qkv = self.qkv_proj(x_norm)
        qkv = qkv.reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # qkv: (3, B, n_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, n_heads, N, N)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = attn_weights @ v  # (B, n_heads, N, head_dim)

        attn_output = attn_output.transpose(1, 2).reshape(B, N, D)  # (B, N, D)
        attn_output = self.out_proj(attn_output)
        x = x + attn_output  # Residual connection

        # === Feedforward ===
        x_norm = self.norm2(x)
        x = x + self.mlp(x_norm)  # Residual

        return x


class TransformerEncoder(nn.Module):
    def __init__(self, depth=4, d_model=64, n_heads=4, ff_dim=128, num_patches=16):
        super().__init__()

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, d_model))  # (1, 16, 64)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model=d_model, n_heads=n_heads, ff_dim=ff_dim)
            for _ in range(depth)
        ])

    def forward(self, x):
        """
        x: Tensor of shape (B, num_patches, d_model)
        returns: Tensor of same shape (B, num_patches, d_model)
        """
        x = x + self.pos_embedding

        for layer in self.layers:
            x = layer(x)

        return x


# simple testing of dimensions
if __name__ == "__main__":
    import torch

    B = 4   # batch size
    N = 16  # number of patches
    D = 64  # embedding dim

    dummy_input = torch.randn(B, N, D)

    print("Testing EncoderLayer...")
    layer = EncoderLayer(d_model=D, n_heads=4, ff_dim=128)
    out = layer(dummy_input)
    print("EncoderLayer output shape:", out.shape)  # (B, N, D) torch.Size([4, 16, 64])

    print("Testing TransformerEncoder...")
    encoder = TransformerEncoder(depth=3, d_model=D, n_heads=4, ff_dim=128, num_patches=N)
    out = encoder(dummy_input)
    print("TransformerEncoder output shape:", out.shape)  # (B, N, D) torch.Size([4, 16, 64])
model/feature_extractor.py
ADDED
@@ -0,0 +1,31 @@
import torch.nn as nn
import torch

class FeatureExtractor(nn.Module):
    def __init__(self, patch_size=14, emb_dim=64):
        super().__init__()
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.proj = nn.Linear(patch_size * patch_size, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: Tensor of shape (B, 1, 56, 56)
        returns patch_embeddings of shape (B, 16, emb_dim)
        """
        B, C, H, W = x.shape
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.contiguous().view(B, -1, self.patch_size * self.patch_size)
        patch_embeddings = self.proj(patches)

        return patch_embeddings


if __name__ == "__main__":
    feature_extractor = FeatureExtractor()
    dummy_input = torch.randn(8, 1, 56, 56)
    out = feature_extractor(dummy_input)

    print(out.shape)  # should expect (8, 16, 64)
model/model.py
ADDED
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
from .feature_extractor import FeatureExtractor
from .encoder import TransformerEncoder
from .decoder import TransformerDecoder

class ImageToDigitTransformer(nn.Module):
    def __init__(self, vocab_size=13, d_model=64, n_heads=4, ff_dim=128,
                 encoder_depth=4, decoder_depth=2, num_patches=16, max_seq_len=5):
        super().__init__()

        self.feature_extractor = FeatureExtractor(patch_size=14, emb_dim=d_model)
        self.encoder = TransformerEncoder(
            depth=encoder_depth,
            d_model=d_model,
            n_heads=n_heads,
            ff_dim=ff_dim,
            num_patches=num_patches
        )
        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            max_len=max_seq_len,
            d_model=d_model,
            n_heads=n_heads,
            ff_dim=ff_dim,
            depth=decoder_depth
        )

    def forward(self, image_tensor, decoder_input_ids):
        """
        image_tensor: (B, 1, 56, 56)
        decoder_input_ids: (B, 5)
        Returns:
            logits: (B, 5, vocab_size)
        """
        patch_embeddings = self.feature_extractor(image_tensor)  # (B, 16, 64)
        encoder_output = self.encoder(patch_embeddings)  # (B, 16, 64)
        logits = self.decoder(decoder_input_ids, encoder_output)  # (B, 5, 13)
        return logits

if __name__ == '__main__':
    model = ImageToDigitTransformer()
    img = torch.randn(4, 1, 56, 56)
    tokens = torch.randint(0, 13, (4, 5))
    logits = model(img, tokens)
    print(logits.shape)  # Expected: (4, 5, 13)
requirements.txt
ADDED
@@ -0,0 +1,6 @@
torch
torchvision
gradio
numpy
Pillow
tqdm
task.txt
ADDED
@@ -0,0 +1,48 @@
# MNIST multimodal transformer - Overview

## Task 1
Our goal is to build and train a multimodal transformer from scratch.
The task is to predict the sequence of digits from an image composed of 4 MNIST images tiled together in a 2x2 grid.
The transformer should be able to predict the labels from top left to top right, then bottom left to bottom right.

## Outcome

A clean, minimal, and well-organized project folder structure
Clean and minimal PyTorch code, well organized across dataset, dataloader, and model classes
Clear evaluation metrics

# Execution

## Dataset

Create a dataset class that returns a single example of:
- an image made of 2x2 MNIST images picked at random (from the training split) and stitched together
- the 4 labels organized as top-left, top-right, bottom-left, bottom-right

## Model

Create a transformer architecture, encoder-decoder, for this task. The architecture is made of three main elements:
- Feature extractor
- Encoder
- Decoder

### Feature extractor

Each image is cut into 16 patches of dim 14x14 px (given my stitched 2x2 image is now 56x56 pixels)
and linearly projected to a dimension of 64, which is the constant latent vector size D for the encoder.
These represent the image embeddings that are fed as input to the encoder block.

### Encoder

Should follow closely the "Attention Is All You Need" vanilla implementation, similarly to the ViT (Vision Transformer) paper:

- positional embeddings are added to the patch embeddings to retain positional information,
  using standard learnable 1D position embeddings
- the encoder then consists of alternating layers of multi-headed self-attention and MLP blocks
- layernorm is applied before every attention block and MLP block
- residual connections are applied after every block

The output of the encoder is a set of encoded representations of the image patches (16x64).

### Decoder
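To make the feature-extractor arithmetic above concrete (an illustrative sketch, not part of this commit): a 56x56 image yields (56/14) x (56/14) = 4 x 4 = 16 patches, each flattened to 14 x 14 = 196 values and linearly projected to D = 64, matching the shapes produced by `model/feature_extractor.py`:

import torch
import torch.nn as nn

x = torch.randn(1, 1, 56, 56)                          # one stitched 2x2 grid image
patches = x.unfold(2, 14, 14).unfold(3, 14, 14)        # (1, 1, 4, 4, 14, 14): a 4x4 grid of 14x14 patches
patches = patches.contiguous().view(1, -1, 14 * 14)    # (1, 16, 196): 16 flattened patches
proj = nn.Linear(196, 64)                              # projection to the constant latent size D = 64
print(proj(patches).shape)                             # torch.Size([1, 16, 64])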
train.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import os

from dataset import MNIST_2x2
from model.model import ImageToDigitTransformer

# Use MPS if available (Apple Silicon)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Config
BATCH_SIZE = 64
EPOCHS = 10
LR = 1e-3
VOCAB_SIZE = 13

# Transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Dataset & DataLoader
train_base = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_dataset = MNIST_2x2(train_base, seed=42)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model, Loss, Optimizer
model = ImageToDigitTransformer(vocab_size=VOCAB_SIZE).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Training Loop
model.train()
for epoch in range(EPOCHS):
    total_loss = 0.0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)

    for images, dec_input, dec_target in loop:
        images = images.to(device)
        dec_input = dec_input.to(device)
        dec_target = dec_target.to(device)

        logits = model(images, dec_input)
        loss = loss_fn(logits.view(-1, VOCAB_SIZE), dec_target.view(-1))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

        # Update tqdm every batch
        loop.set_postfix(batch_loss=loss.item())

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{EPOCHS} - Loss: {avg_loss:.4f}")

# save weights
os.makedirs("checkpoints", exist_ok=True)
torch.save(model.state_dict(), "checkpoints/transformer_mnist.pt")
print("Model saved to checkpoints/transformer_mnist.pt")
utils/tokenizer.py
ADDED
@@ -0,0 +1,33 @@
token_to_id = {
    str(i): i for i in range(10)
}
token_to_id["<start>"] = 10
token_to_id["<finish>"] = 11

id_to_token = {v: k for k, v in token_to_id.items()}

START = token_to_id["<start>"]
FINISH = token_to_id["<finish>"]
# I don't need padding or a pad token because the input is a fixed-length sequence of 5

def encode(label_list):
    return [token_to_id[str(d)] for d in label_list]

def decode(token_ids):
    return [id_to_token[t] for t in token_ids]

def prepare_decoder_labels(labels):
    """
    Prepare decoder input and target sequences for training.
    Input labels: [7, 7, 6, 9]
    Output:
        decoder_input = [<start>, 7, 7, 6, 9]
        decoder_target = [7, 7, 6, 9, <finish>]
    """
    token_ids = encode(labels)
    decoder_input = [START] + token_ids
    decoder_target = token_ids + [FINISH]
    return decoder_input, decoder_target
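A quick usage sketch for the tokenizer helpers (illustrative only, not part of this commit; assumes it is run from the repo root):

from utils.tokenizer import prepare_decoder_labels, decode

dec_input, dec_target = prepare_decoder_labels([7, 7, 6, 9])
print(dec_input)                 # [10, 7, 7, 6, 9]  i.e. [<start>, 7, 7, 6, 9]
print(dec_target)                # [7, 7, 6, 9, 11]  i.e. [7, 7, 6, 9, <finish>]
print(decode(dec_target[:-1]))   # ['7', '7', '6', '9']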