Noha90 commited on
Commit
33fe5de
·
1 Parent(s): f03c157

try to reload the model

Browse files
Files changed (1) hide show
  1. predict.py +1 -5
predict.py CHANGED
@@ -25,12 +25,8 @@ model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names
25
  # Download the model file from the Hub if not present
26
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
27
  print("Model path:", model_path)
28
- state_dict = torch.load(model_path, map_location="cpu")
29
-
30
- # Remove incompatible keys (classifier weights)
31
- filtered_state_dict = {k: v for k, v in state_dict.items() if "classifier" not in k}
32
- model.load_state_dict(filtered_state_dict, strict=False)
33
 
 
34
  model.eval()
35
 
36
 
 
25
  # Download the model file from the Hub if not present
26
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
27
  print("Model path:", model_path)
 
 
 
 
 
28
 
29
+ model.load_state_dict(torch.load((model_path)))
30
  model.eval()
31
 
32