try to reload the model
Browse files- predict.py +2 -2
predict.py
CHANGED
@@ -25,8 +25,8 @@ model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names
|
|
25 |
# Download the model file from the Hub if not present
|
26 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
27 |
print("Model path:", model_path)
|
28 |
-
|
29 |
-
model.load_state_dict(
|
30 |
model.eval()
|
31 |
|
32 |
|
|
|
25 |
# Download the model file from the Hub if not present
|
26 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
27 |
print("Model path:", model_path)
|
28 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
29 |
+
model.load_state_dict(state_dict)
|
30 |
model.eval()
|
31 |
|
32 |
|