Noha90 commited on
Commit
fb7d78b
·
1 Parent(s): 506e2ea

update deit

Browse files
Files changed (1) hide show
  1. predict.py +2 -2
predict.py CHANGED
@@ -33,8 +33,8 @@ class DeiT(nn.Module):
33
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
34
  print("Model path:", model_path)
35
  model = DeiT(num_classes=len(class_names))
36
- checkpoint = torch.load(model_path, map_location="cpu")
37
- model.load_state_dict(checkpoint["model_state_dict"])
38
  model.eval()
39
 
40
  #deit transform
 
33
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
34
  print("Model path:", model_path)
35
  model = DeiT(num_classes=len(class_names))
36
+ state_dict = torch.load(model_path, map_location="cpu")
37
+ model.load_state_dict(state_dict)
38
  model.eval()
39
 
40
  #deit transform