import torch
import numpy as np
from PIL import Image
from torchvision.transforms import Compose, Resize, Grayscale, ToTensor, ToPILImage

# Global grayscale preprocessing transform: resize to 400x400 and convert to a
# single-channel tensor
transform_gs = Compose(
    [Resize((400, 400)), Grayscale(num_output_channels=1), ToTensor()]
)
def process_gs_image(image):
    """
    Preprocess a PIL image: resize, convert to a grayscale tensor, and add a
    batch dimension. Returns the tensor and the image's original size.
    """
    # Save the original size so the output can be resized back later
    original_size = image.size  # (width, height)
    # Convert the image to grayscale and resize it
    image = transform_gs(image)
    # Add the batch dimension
    image = image.unsqueeze(0)
    # Return both the processed tensor and the original size
    return image, original_size
def inverse_transform_cs(tensor, original_size):
    """
    Convert the tensor back to a color image and resize it to its original size.
    """
    # Convert the tensor back to a PIL image
    to_pil = ToPILImage()
    pil_image = to_pil(tensor.squeeze(0))  # Remove the batch dimension
    # Resize the image back to the original size
    pil_image = pil_image.resize(original_size)
    return pil_image