Update model.py
Browse filesno gpu so updated the device agnoistic code
model.py
CHANGED
@@ -46,8 +46,8 @@ class StyleLoss(nn.Module):
|
|
46 |
|
47 |
#Normalization
|
48 |
|
49 |
-
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
|
50 |
-
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
|
51 |
#image transformation
|
52 |
def image_transform(image):
|
53 |
if isinstance(image, str):
|
@@ -58,12 +58,12 @@ def image_transform(image):
|
|
58 |
image = Image.fromarray(image.astype('uint8'), 'RGB')
|
59 |
# Apply the same transformations as before
|
60 |
image = transform(image).unsqueeze(0)
|
61 |
-
return image
|
62 |
|
63 |
|
64 |
|
65 |
#Defining a model
|
66 |
-
cnn = models.vgg19(pretrained=True).features.
|
67 |
|
68 |
#getting the input optimizer
|
69 |
def get_input_optimizer(input_img):
|
@@ -82,7 +82,7 @@ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
|
|
82 |
content_layers=content_layers_default,
|
83 |
style_layers=style_layers_default):
|
84 |
# normalization module
|
85 |
-
normalization = Normalization(normalization_mean, normalization_std)
|
86 |
|
87 |
# just in order to have an iterable access to or list of content/style
|
88 |
# losses
|
|
|
46 |
|
47 |
#Normalization
|
48 |
|
49 |
+
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
|
50 |
+
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
|
51 |
#image transformation
|
52 |
def image_transform(image):
|
53 |
if isinstance(image, str):
|
|
|
58 |
image = Image.fromarray(image.astype('uint8'), 'RGB')
|
59 |
# Apply the same transformations as before
|
60 |
image = transform(image).unsqueeze(0)
|
61 |
+
return image
|
62 |
|
63 |
|
64 |
|
65 |
#Defining a model
|
66 |
+
cnn = models.vgg19(pretrained=True).features.eval()
|
67 |
|
68 |
#getting the input optimizer
|
69 |
def get_input_optimizer(input_img):
|
|
|
82 |
content_layers=content_layers_default,
|
83 |
style_layers=style_layers_default):
|
84 |
# normalization module
|
85 |
+
normalization = Normalization(normalization_mean, normalization_std)
|
86 |
|
87 |
# just in order to have an iterable access to or list of content/style
|
88 |
# losses
|