switch to deit
predict.py CHANGED (+26 -16)
@@ -4,7 +4,7 @@ from PIL import Image
 import json
 import numpy as np
 # from model import load_model
-from transformers import AutoImageProcessor, SwinForImageClassification
+from transformers import AutoImageProcessor, SwinForImageClassification, ViTForImageClassification
 import torch.nn as nn
 import os
 import pandas as pd
@@ -16,30 +16,40 @@ with open("labels.json", "r") as f:
 class_names = json.load(f)
 print("class_names:", class_names)
 
-
-
+class DeiT(nn.Module):
+    def __init__(self, model_name="facebook/deit-small-patch16-224", num_classes=None):
+        super(DeiT, self).__init__()
+        self.model = ViTForImageClassification.from_pretrained(model_name)
+        in_features = self.model.classifier.in_features
+        self.model.classifier = nn.Sequential(
+            nn.Linear(in_features, num_classes)
+        )
+
+    def forward(self, images):
+        outputs = self.model(images)
+        return outputs.logits
 
-#
+# Load model
 model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
 print("Model path:", model_path)
-
-
+model = DeiT(num_classes=len(class_names))
+checkpoint = torch.load(model_path, map_location="cpu")
+model.load_state_dict(checkpoint["model_state_dict"])
 model.eval()
 
-
-# Image transform
-# transform = transforms.Compose([
-#     transforms.Resize((224, 224)),
-#     transforms.ToTensor(),
-#     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
-# ])
-#Swin
+#deit transform
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
     transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                         std=[0.229, 0.224, 0.225])
+    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
 ])
+#Swin
+# transform = transforms.Compose([
+#     transforms.Resize((224, 224)),
+#     transforms.ToTensor(),
+#     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+#                          std=[0.229, 0.224, 0.225])
+# ])
 
 
 def predict(image_path):
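The hunk cuts off at the predict() entry point, so the commit does not show how the new DeiT model is actually invoked. A minimal sketch of an inference path consistent with the names defined above (model, transform, class_names) is below; the function body is an assumption, not part of this commit, and it treats class_names loaded from labels.json as an index-ordered list. The PIL, torch, and torchvision imports already exist at the top of predict.py.

import torch
import torch.nn.functional as F
from PIL import Image

def predict(image_path):
    # Assumed inference flow; the real body is not shown in this diff.
    image = Image.open(image_path).convert("RGB")
    x = transform(image).unsqueeze(0)       # -> [1, 3, 224, 224]
    with torch.no_grad():
        logits = model(x)                   # DeiT wrapper returns raw logits
        probs = F.softmax(logits, dim=1)[0]
    idx = int(probs.argmax())
    # Assumes class_names is a list indexed by label id.
    return class_names[idx], float(probs[idx])

If the normalization used during fine-tuning is uncertain, the already-imported AutoImageProcessor exposes image_mean and image_std (e.g. AutoImageProcessor.from_pretrained("facebook/deit-small-patch16-224")), so the hard-coded [0.5, 0.5, 0.5] values could be cross-checked against the checkpoint's own preprocessing config.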