Legola commited on
Commit
6c21671
·
1 Parent(s): e3abfd6

Update model.py

Browse files

the code structure modified

Files changed (1) hide show
  1. model.py +37 -36
model.py CHANGED
@@ -48,53 +48,49 @@ class StyleLoss(nn.Module):
48
  self.loss = F.mse_loss(G, self.target)
49
  return input
50
 
51
-
52
- #Normalization
53
-
54
- cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
55
- cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
56
- #image transformation
57
- # def image_transform(image):
58
- # if isinstance(image, str):
59
- # # If image is a path to a file, open it using PIL
60
- # image = Image.open(image).convert('RGB')
61
- # else:
62
- # # If image is a NumPy array, convert it to a PIL image
63
- # image = Image.fromarray(image.astype('uint8'), 'RGB')
64
- # # Apply the same transformations as before
65
- # image = transform(image).unsqueeze(0)
66
- # return image
67
 
68
  def image_transform(image):
69
- if image is None:
70
- return None
71
-
72
  if isinstance(image, str):
73
  # If image is a path to a file, open it using PIL
74
- with open(image, "rb") as f:
75
- image = Image.open(f).convert('RGB')
76
  else:
77
- # If image is already a PIL image, just convert it to RGB mode
78
- image = image.convert('RGB')
79
-
80
  # Apply the same transformations as before
81
- image =image_transform(image).unsqueeze(0)
82
  return image
83
 
84
 
85
- #Defining a model
86
- # weights=weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1
87
- # cnn = models.vgg19(weights=weights).features.eval()
88
- weights = models.vgg19(pretrained='imagenet')
89
- cnn = weights.features.eval()
90
 
91
-
92
- #getting the input optimizer
93
- def get_input_optimizer(input_img):
94
- # this line to show that input is a parameter that requires a gradient
95
- optimizer = optim.LBFGS([input_img])
96
- return optimizer
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
 
98
  # desired depth layers to compute style/content losses :
99
 
100
 
@@ -159,3 +155,8 @@ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
159
  model = model[:(i + 1)]
160
 
161
  return model, style_losses, content_losses
 
 
 
 
 
 
48
  self.loss = F.mse_loss(G, self.target)
49
  return input
50
 
51
+ #Image Transform
52
+
53
+ transform = transforms.Compose([
54
+ transforms.Resize((128,128)), # scale imported image
55
+ transforms.ToTensor()]) # transform it into a torch tensor
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def image_transform(image):
 
 
 
58
  if isinstance(image, str):
59
  # If image is a path to a file, open it using PIL
60
+ image = Image.open(image).convert('RGB')
 
61
  else:
62
+ # If image is a NumPy array, convert it to a PIL image
63
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
 
64
  # Apply the same transformations as before
65
+ image = transform(image).unsqueeze(0)
66
  return image
67
 
68
 
69
+
70
+ #Defining the Model
71
+ cnn = models.vgg19(pretrained=True).features.eval()
 
 
72
 
73
+ #Normalization
74
+
75
+ cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
76
+ cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
77
+
78
+ # create a module to normalize input image so we can easily put it in a
79
+ # ``nn.Sequential``
80
+ class Normalization(nn.Module):
81
+ def __init__(self, mean, std):
82
+ super(Normalization, self).__init__()
83
+ # .view the mean and std to make them [C x 1 x 1] so that they can
84
+ # directly work with image Tensor of shape [B x C x H x W].
85
+ # B is batch size. C is number of channels. H is height and W is width.
86
+ self.mean = torch.tensor(mean).view(-1, 1, 1)
87
+ self.std = torch.tensor(std).view(-1, 1, 1)
88
+
89
+ def forward(self, img):
90
+ # normalize ``img``
91
+ return (img - self.mean) / self.std
92
 
93
+
94
  # desired depth layers to compute style/content losses :
95
 
96
 
 
155
  model = model[:(i + 1)]
156
 
157
  return model, style_losses, content_losses
158
+
159
+ def get_input_optimizer(input_img):
160
+ # this line to show that input is a parameter that requires a gradient
161
+ optimizer = optim.LBFGS([input_img])
162
+ return optimizer