File size: 867 Bytes
b54146b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import numpy as np
import torch
from PIL import Image

def preprocess_canvases(images):
    """Stitch four digit canvases into one normalized 2x2 grid tensor.

    Each canvas is converted to grayscale, resized to 28x28 and scaled to
    [0, 1]. The four tiles are then arranged in the order they were given
    (top-left, top-right, bottom-left, bottom-right) to form a 56x56 grid,
    which is standardized with mean 0.1307 and std 0.3081.

    Args:
        images: Sequence of exactly 4 image arrays accepted by
            ``Image.fromarray`` (RGBA canvases, per the original docstring).

    Returns:
        A float32 tensor of shape (1, 1, 56, 56): batch and channel
        dimensions are both added in front of the 56x56 grid.

    Raises:
        AssertionError: If ``images`` does not contain exactly 4 items.
    """
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would let a wrong-length input slip
    # through. AssertionError is kept so existing callers see the same type.
    if len(images) != 4:
        raise AssertionError("Expected 4 images")

    tiles = []
    for canvas in images:
        gray = Image.fromarray(canvas).convert("L").resize((28, 28))
        tiles.append(np.array(gray).astype(np.float32) / 255.0)  # scale to [0, 1]

    top_row = np.hstack([tiles[0], tiles[1]])
    bottom_row = np.hstack([tiles[2], tiles[3]])
    grid = np.vstack([top_row, bottom_row])  # shape (56, 56)

    # Normalize like MNIST
    grid = (grid - 0.1307) / 0.3081

    # Add batch and channel dimensions -> shape (1, 1, 56, 56)
    return torch.tensor(grid).unsqueeze(0).unsqueeze(0)