change to deit base
Browse files- predict.py +1 -1
predict.py
CHANGED
@@ -32,7 +32,7 @@ class DeiT(nn.Module):
|
|
32 |
# Load model
|
33 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_base_best_model.pth")
|
34 |
print("Model path:", model_path)
|
35 |
-
model = DeiT(model_name="facebook/deit-
|
36 |
state_dict = torch.load(model_path, map_location="cpu")
|
37 |
if "model_state_dict" in state_dict:
|
38 |
state_dict = state_dict["model_state_dict"]
|
|
|
32 |
# Load model
|
33 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_base_best_model.pth")
|
34 |
print("Model path:", model_path)
|
35 |
+
model = DeiT(model_name="facebook/deit-base-patch16-224", num_classes=40)
|
36 |
state_dict = torch.load(model_path, map_location="cpu")
|
37 |
if "model_state_dict" in state_dict:
|
38 |
state_dict = state_dict["model_state_dict"]
|