Noha90 commited on
Commit
d226319
·
1 Parent(s): f25a889

try to reload the model

Browse files
Files changed (1) hide show
  1. predict.py +1 -4
predict.py CHANGED
@@ -17,10 +17,7 @@ with open("labels.json", "r") as f:
17
  print("class_names:", class_names)
18
 
19
  # Load model
20
-
21
- model = SwinForImageClassification.from_pretrained("microsoft/swin-base-patch4-window7-224")
22
-
23
- model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))
24
 
25
  # Download the model file from the Hub if not present
26
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
 
17
  print("class_names:", class_names)
18
 
19
  # Load model
20
+ model = SwinForImageClassification.from_pretrained("microsoft/swin-base-patch4-window7-224", num_labels=len(class_names))
 
 
 
21
 
22
  # Download the model file from the Hub if not present
23
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")