Noha90 commited on
Commit
bb484df
·
1 Parent(s): ea03b5e

switch to deit

Browse files
Files changed (1) hide show
  1. predict.py +26 -16
predict.py CHANGED
@@ -4,7 +4,7 @@ from PIL import Image
4
  import json
5
  import numpy as np
6
  # from model import load_model
7
- from transformers import AutoImageProcessor, SwinForImageClassification
8
  import torch.nn as nn
9
  import os
10
  import pandas as pd
@@ -16,30 +16,40 @@ with open("labels.json", "r") as f:
16
  class_names = json.load(f)
17
  print("class_names:", class_names)
18
 
19
- # Load model
20
- model = SwinForImageClassification.from_pretrained("microsoft/swin-base-patch4-window7-224", num_labels=1000)
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Download the model file from the Hub if not present
23
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
24
  print("Model path:", model_path)
25
- state_dict = torch.load(model_path, map_location="cpu")
26
- model.load_state_dict(state_dict)
 
27
  model.eval()
28
 
29
-
30
- # Image transform
31
- # transform = transforms.Compose([
32
- # transforms.Resize((224, 224)),
33
- # transforms.ToTensor(),
34
- # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
35
- # ])
36
- #Swin
37
  transform = transforms.Compose([
38
  transforms.Resize((224, 224)),
39
  transforms.ToTensor(),
40
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
41
- std=[0.229, 0.224, 0.225])
42
  ])
 
 
 
 
 
 
 
43
 
44
 
45
  def predict(image_path):
 
4
  import json
5
  import numpy as np
6
  # from model import load_model
7
+ from transformers import AutoImageProcessor, SwinForImageClassification, ViTForImageClassification
8
  import torch.nn as nn
9
  import os
10
  import pandas as pd
 
16
  class_names = json.load(f)
17
  print("class_names:", class_names)
18
 
19
+ class DeiT(nn.Module):
20
+ def __init__(self, model_name="facebook/deit-small-patch16-224", num_classes=None):
21
+ super(DeiT, self).__init__()
22
+ self.model = ViTForImageClassification.from_pretrained(model_name)
23
+ in_features = self.model.classifier.in_features
24
+ self.model.classifier = nn.Sequential(
25
+ nn.Linear(in_features, num_classes)
26
+ )
27
+
28
+ def forward(self, images):
29
+ outputs = self.model(images)
30
+ return outputs.logits
31
 
32
+ # Load model
33
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
34
  print("Model path:", model_path)
35
+ model = DeiT(num_classes=len(class_names))
36
+ checkpoint = torch.load(model_path, map_location="cpu")
37
+ model.load_state_dict(checkpoint["model_state_dict"])
38
  model.eval()
39
 
40
+ #deit transform
 
 
 
 
 
 
 
41
  transform = transforms.Compose([
42
  transforms.Resize((224, 224)),
43
  transforms.ToTensor(),
44
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
 
45
  ])
46
+ #Swin
47
+ # transform = transforms.Compose([
48
+ # transforms.Resize((224, 224)),
49
+ # transforms.ToTensor(),
50
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406],
51
+ # std=[0.229, 0.224, 0.225])
52
+ # ])
53
 
54
 
55
  def predict(image_path):