upload weights only
Browse files- predict.py +3 -1
predict.py
CHANGED
@@ -34,7 +34,9 @@ model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_best_model.
|
|
34 |
print("Model path:", model_path)
|
35 |
model = DeiT(num_classes=len(class_names))
|
36 |
state_dict = torch.load(model_path, map_location="cpu")
|
37 |
-
|
|
|
|
|
38 |
model.eval()
|
39 |
|
40 |
#deit transform
|
|
|
34 |
print("Model path:", model_path)
|
35 |
model = DeiT(num_classes=len(class_names))
|
36 |
state_dict = torch.load(model_path, map_location="cpu")
|
37 |
+
if "model_state_dict" in state_dict:
|
38 |
+
state_dict = state_dict["model_state_dict"]
|
39 |
+
model.load_state_dict(state_dict, strict=False)
|
40 |
model.eval()
|
41 |
|
42 |
#deit transform
|