update deit
Browse files- predict.py +2 -2
predict.py
CHANGED
@@ -33,8 +33,8 @@ class DeiT(nn.Module):
|
|
33 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
34 |
print("Model path:", model_path)
|
35 |
model = DeiT(num_classes=len(class_names))
|
36 |
-
|
37 |
-
model.load_state_dict(
|
38 |
model.eval()
|
39 |
|
40 |
#deit transform
|
|
|
33 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
34 |
print("Model path:", model_path)
|
35 |
model = DeiT(num_classes=len(class_names))
|
36 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
37 |
+
model.load_state_dict(state_dict)
|
38 |
model.eval()
|
39 |
|
40 |
#deit transform
|