Legola commited on
Commit
9f6a1fd
·
1 Parent(s): 8bf2724

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +10 -18
model.py CHANGED
@@ -5,14 +5,12 @@ import torch.nn as nn
5
  import torchvision.models as models
6
  from PIL import Image
7
  from torchvision import transforms
8
- import time
9
  import torch.nn.functional as F
10
  import torch.optim as optim
11
  import matplotlib.pyplot as plt
12
  import torchvision.transforms as transforms
13
  import copy
14
  import torchvision.models as models
15
- import torchvision.transforms.functional as TF
16
  from PIL import Image
17
  import numpy as np
18
 
@@ -68,25 +66,19 @@ transform = transforms.Compose([
68
  transforms.ToTensor()]) # transform it into a torch tensor
69
 
70
  def image_transform(image):
71
- if isinstance(image, str):
72
- # If image is a path to a file, open it using PIL
73
- image = Image.open(image).convert('RGB')
74
- else:
75
- # If image is a NumPy array, convert it to a PIL image
76
- image = Image.fromarray(image.astype('uint8'), 'RGB')
77
- # Apply the same transformations as before
78
- image = transform(image).unsqueeze(0)
 
 
79
  return image
80
 
81
 
82
-
83
-
84
-
85
- #Normalization
86
-
87
- cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
88
- cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
89
-
90
  # create a module to normalize input image so we can easily put it in a
91
  # ``nn.Sequential``
92
  class Normalization(nn.Module):
 
5
  import torchvision.models as models
6
  from PIL import Image
7
  from torchvision import transforms
 
8
  import torch.nn.functional as F
9
  import torch.optim as optim
10
  import matplotlib.pyplot as plt
11
  import torchvision.transforms as transforms
12
  import copy
13
  import torchvision.models as models
 
14
  from PIL import Image
15
  import numpy as np
16
 
 
66
  transforms.ToTensor()]) # transform it into a torch tensor
67
 
68
  def image_transform(image):
69
+
70
+ if image is not None:
71
+ if isinstance(image, str):
72
+ # If image is a path to a file, open it using PIL
73
+ image = Image.open(image).convert('RGB')
74
+ else:
75
+ # If image is a NumPy array, convert it to a PIL image
76
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
77
+ # Apply the same transformations as before
78
+ image = transform(image).unsqueeze(0)
79
  return image
80
 
81
 
 
 
 
 
 
 
 
 
82
  # create a module to normalize input image so we can easily put it in a
83
  # ``nn.Sequential``
84
  class Normalization(nn.Module):