change to base deit
Browse files- app.py +1 -1
- predict.py +2 -2
app.py
CHANGED
@@ -9,7 +9,7 @@ demo = gr.Interface(
|
|
9 |
gr.Image(label="Top-1 Class Example"),
|
10 |
gr.Label(label="Top-5 Probabilities")
|
11 |
],
|
12 |
-
title="Scene Classification with Reference Image
|
13 |
description="Upload an image to get the predicted class with a sample image and top-5 prediction chart."
|
14 |
)
|
15 |
|
|
|
9 |
gr.Image(label="Top-1 Class Example"),
|
10 |
gr.Label(label="Top-5 Probabilities")
|
11 |
],
|
12 |
+
title="Scene Classification with Reference Image Testing using DeiT version 1.0",
|
13 |
description="Upload an image to get the predicted class with a sample image and top-5 prediction chart."
|
14 |
)
|
15 |
|
predict.py
CHANGED
@@ -17,7 +17,7 @@ with open("labels.json", "r") as f:
|
|
17 |
print("class_names:", class_names)
|
18 |
|
19 |
class DeiT(nn.Module):
|
20 |
-
def __init__(self, model_name="facebook/deit-
|
21 |
super(DeiT, self).__init__()
|
22 |
self.model = ViTForImageClassification.from_pretrained(model_name)
|
23 |
in_features = self.model.classifier.in_features
|
@@ -30,7 +30,7 @@ class DeiT(nn.Module):
|
|
30 |
return outputs.logits
|
31 |
|
32 |
# Load model
|
33 |
-
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="
|
34 |
print("Model path:", model_path)
|
35 |
model = DeiT(model_name="facebook/deit-tiny-patch16-224", num_classes=40)
|
36 |
state_dict = torch.load(model_path, map_location="cpu")
|
|
|
17 |
print("class_names:", class_names)
|
18 |
|
19 |
class DeiT(nn.Module):
|
20 |
+
def __init__(self, model_name="facebook/deit-base-patch16-224", num_classes=40):
|
21 |
super(DeiT, self).__init__()
|
22 |
self.model = ViTForImageClassification.from_pretrained(model_name)
|
23 |
in_features = self.model.classifier.in_features
|
|
|
30 |
return outputs.logits
|
31 |
|
32 |
# Load model
|
33 |
+
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_base_best_model.pth")
|
34 |
print("Model path:", model_path)
|
35 |
model = DeiT(model_name="facebook/deit-tiny-patch16-224", num_classes=40)
|
36 |
state_dict = torch.load(model_path, map_location="cpu")
|