Legola commited on
Commit
529009f
·
1 Parent(s): 19e21cf

Update model.py

Browse files

no gpu so updated the device agnoistic code

Files changed (1) hide show
  1. model.py +5 -5
model.py CHANGED
@@ -46,8 +46,8 @@ class StyleLoss(nn.Module):
46
 
47
  #Normalization
48
 
49
- cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
50
- cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
51
  #image transformation
52
  def image_transform(image):
53
  if isinstance(image, str):
@@ -58,12 +58,12 @@ def image_transform(image):
58
  image = Image.fromarray(image.astype('uint8'), 'RGB')
59
  # Apply the same transformations as before
60
  image = transform(image).unsqueeze(0)
61
- return image.to(device)
62
 
63
 
64
 
65
  #Defining a model
66
- cnn = models.vgg19(pretrained=True).features.to(device).eval()
67
 
68
  #getting the input optimizer
69
  def get_input_optimizer(input_img):
@@ -82,7 +82,7 @@ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
82
  content_layers=content_layers_default,
83
  style_layers=style_layers_default):
84
  # normalization module
85
- normalization = Normalization(normalization_mean, normalization_std).to(device)
86
 
87
  # just in order to have an iterable access to or list of content/style
88
  # losses
 
46
 
47
  #Normalization
48
 
49
+ cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
50
+ cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
51
  #image transformation
52
  def image_transform(image):
53
  if isinstance(image, str):
 
58
  image = Image.fromarray(image.astype('uint8'), 'RGB')
59
  # Apply the same transformations as before
60
  image = transform(image).unsqueeze(0)
61
+ return image
62
 
63
 
64
 
65
  #Defining a model
66
+ cnn = models.vgg19(pretrained=True).features.eval()
67
 
68
  #getting the input optimizer
69
  def get_input_optimizer(input_img):
 
82
  content_layers=content_layers_default,
83
  style_layers=style_layers_default):
84
  # normalization module
85
+ normalization = Normalization(normalization_mean, normalization_std)
86
 
87
  # just in order to have an iterable access to or list of content/style
88
  # losses