Noha90 commited on
Commit
02a7128
·
1 Parent(s): 092bd1b

change to deit base

Browse files
Files changed (1) hide show
  1. predict.py +1 -1
predict.py CHANGED
@@ -32,7 +32,7 @@ class DeiT(nn.Module):
32
  # Load model
33
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_base_best_model.pth")
34
  print("Model path:", model_path)
35
- model = DeiT(model_name="facebook/deit-tiny-patch16-224", num_classes=40)
36
  state_dict = torch.load(model_path, map_location="cpu")
37
  if "model_state_dict" in state_dict:
38
  state_dict = state_dict["model_state_dict"]
 
32
  # Load model
33
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_base_best_model.pth")
34
  print("Model path:", model_path)
35
+ model = DeiT(model_name="facebook/deit-base-patch16-224", num_classes=40)
36
  state_dict = torch.load(model_path, map_location="cpu")
37
  if "model_state_dict" in state_dict:
38
  state_dict = state_dict["model_state_dict"]