check the model path
Browse files- predict.py +1 -0
predict.py
CHANGED
@@ -24,6 +24,7 @@ model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names
|
|
24 |
|
25 |
# Download the model file from the Hub if not present
|
26 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
|
|
27 |
state_dict = torch.load(model_path, map_location="cpu")
|
28 |
|
29 |
# Remove incompatible keys (classifier weights)
|
|
|
24 |
|
25 |
# Download the model file from the Hub if not present
|
26 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
27 |
+
print("Model path:", model_path)
|
28 |
state_dict = torch.load(model_path, map_location="cpu")
|
29 |
|
30 |
# Remove incompatible keys (classifier weights)
|