Legola commited on
Commit
b1dd78e
·
1 Parent(s): 464d185

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +15 -3
model.py CHANGED
@@ -53,13 +53,25 @@ class StyleLoss(nn.Module):
53
  cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
54
  cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
55
  #image transformation
 
 
 
 
 
 
 
 
 
 
 
56
  def image_transform(image):
57
  if isinstance(image, str):
58
  # If image is a path to a file, open it using PIL
59
- image = Image.open(image).convert('RGB')
 
60
  else:
61
- # If image is a NumPy array, convert it to a PIL image
62
- image = Image.fromarray(image.astype('uint8'), 'RGB')
63
  # Apply the same transformations as before
64
  image = transform(image).unsqueeze(0)
65
  return image
 
53
  cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
54
  cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
55
  #image transformation
56
+ # def image_transform(image):
57
+ # if isinstance(image, str):
58
+ # # If image is a path to a file, open it using PIL
59
+ # image = Image.open(image).convert('RGB')
60
+ # else:
61
+ # # If image is a NumPy array, convert it to a PIL image
62
+ # image = Image.fromarray(image.astype('uint8'), 'RGB')
63
+ # # Apply the same transformations as before
64
+ # image = transform(image).unsqueeze(0)
65
+ # return image
66
+
67
  def image_transform(image):
68
  if isinstance(image, str):
69
  # If image is a path to a file, open it using PIL
70
+ with open(image, "rb") as f:
71
+ image = Image.open(f).convert('RGB')
72
  else:
73
+ # If image is already a PIL image, just convert it to RGB mode
74
+ image = image.convert('RGB')
75
  # Apply the same transformations as before
76
  image = transform(image).unsqueeze(0)
77
  return image