Update model.py
Browse files
model.py
CHANGED
@@ -53,13 +53,25 @@ class StyleLoss(nn.Module):
|
|
53 |
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
|
54 |
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
|
55 |
#image transformation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
def image_transform(image):
|
57 |
if isinstance(image, str):
|
58 |
# If image is a path to a file, open it using PIL
|
59 |
-
|
|
|
60 |
else:
|
61 |
-
# If image is a
|
62 |
-
image =
|
63 |
# Apply the same transformations as before
|
64 |
image = transform(image).unsqueeze(0)
|
65 |
return image
|
|
|
53 |
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
|
54 |
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
|
55 |
#image transformation
|
56 |
+
# def image_transform(image):
|
57 |
+
# if isinstance(image, str):
|
58 |
+
# # If image is a path to a file, open it using PIL
|
59 |
+
# image = Image.open(image).convert('RGB')
|
60 |
+
# else:
|
61 |
+
# # If image is a NumPy array, convert it to a PIL image
|
62 |
+
# image = Image.fromarray(image.astype('uint8'), 'RGB')
|
63 |
+
# # Apply the same transformations as before
|
64 |
+
# image = transform(image).unsqueeze(0)
|
65 |
+
# return image
|
66 |
+
|
67 |
def image_transform(image):
|
68 |
if isinstance(image, str):
|
69 |
# If image is a path to a file, open it using PIL
|
70 |
+
with open(image, "rb") as f:
|
71 |
+
image = Image.open(f).convert('RGB')
|
72 |
else:
|
73 |
+
# If image is already a PIL image, just convert it to RGB mode
|
74 |
+
image = image.convert('RGB')
|
75 |
# Apply the same transformations as before
|
76 |
image = transform(image).unsqueeze(0)
|
77 |
return image
|