try to reload the model
Browse files- predict.py +1 -5
predict.py
CHANGED
@@ -25,12 +25,8 @@ model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names
|
|
25 |
# Download the model file from the Hub if not present
|
26 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
27 |
print("Model path:", model_path)
|
28 |
-
state_dict = torch.load(model_path, map_location="cpu")
|
29 |
-
|
30 |
-
# Remove incompatible keys (classifier weights)
|
31 |
-
filtered_state_dict = {k: v for k, v in state_dict.items() if "classifier" not in k}
|
32 |
-
model.load_state_dict(filtered_state_dict, strict=False)
|
33 |
|
|
|
34 |
model.eval()
|
35 |
|
36 |
|
|
|
25 |
# Download the model file from the Hub if not present
|
26 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
27 |
print("Model path:", model_path)
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
+
model.load_state_dict(torch.load((model_path)))
|
30 |
model.eval()
|
31 |
|
32 |
|