Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from PIL import Image | |
def preprocess_canvases(images):
    """
    Stitch four canvas images into one MNIST-normalized 2x2 grid tensor.

    Each image is converted to grayscale, resized to 28x28 and scaled to
    [0, 1]. The four tiles are then arranged as::

        images[0] | images[1]
        ----------+----------
        images[2] | images[3]

    and the resulting 56x56 grid is normalized with the MNIST mean/std.

    Args:
        images: Sequence of exactly 4 image arrays accepted by
            ``PIL.Image.fromarray`` (RGBA canvases per the original
            contract), ordered top-left, top-right, bottom-left,
            bottom-right.

    Returns:
        A float32 ``torch.Tensor`` of shape (1, 1, 56, 56), i.e.
        (batch, channel, height, width).

    Raises:
        AssertionError: If ``images`` does not contain exactly 4 items.
    """
    tile_size = (28, 28)
    mnist_mean, mnist_std = 0.1307, 0.3081

    # Raise explicitly rather than via an `assert` statement: asserts are
    # removed under `python -O`, which would let a wrong-length input fail
    # later with an unrelated IndexError (or silently drop extra images).
    # AssertionError is kept so the exception type callers see is unchanged.
    if len(images) != 4:
        raise AssertionError("Expected 4 images")

    digits = []
    for img in images:
        # NOTE: converting to "L" keeps only luminance; any alpha channel
        # in the input is discarded rather than composited.
        gray = Image.fromarray(img).convert("L")
        gray = gray.resize(tile_size)
        # Scale 8-bit pixel values to [0, 1].
        digits.append(np.array(gray).astype(np.float32) / 255.0)

    # Assemble the 2x2 layout in one step -> shape (56, 56).
    grid = np.block([[digits[0], digits[1]], [digits[2], digits[3]]])

    # Normalize like MNIST.
    grid = (grid - mnist_mean) / mnist_std

    # Add batch and channel dimensions -> shape (1, 1, 56, 56).
    return torch.tensor(grid).unsqueeze(0).unsqueeze(0)