Noha90 commited on
Commit
2061c9f
·
1 Parent(s): 08b0e25

upload weights only

Browse files
Files changed (1) hide show
  1. predict.py +3 -1
predict.py CHANGED
@@ -34,7 +34,9 @@ model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_best_model.
34
  print("Model path:", model_path)
35
  model = DeiT(num_classes=len(class_names))
36
  state_dict = torch.load(model_path, map_location="cpu")
37
- model.load_state_dict(state_dict)
 
 
38
  model.eval()
39
 
40
  #deit transform
 
34
  print("Model path:", model_path)
35
  model = DeiT(num_classes=len(class_names))
36
  state_dict = torch.load(model_path, map_location="cpu")
37
+ if "model_state_dict" in state_dict:
38
+ state_dict = state_dict["model_state_dict"]
39
+ model.load_state_dict(state_dict, strict=False)
40
  model.eval()
41
 
42
  #deit transform