try to reload the model
Browse files- predict.py +1 -4
predict.py
CHANGED
@@ -17,10 +17,7 @@ with open("labels.json", "r") as f:
|
|
17 |
print("class_names:", class_names)
|
18 |
|
19 |
# Load model
|
20 |
-
|
21 |
-
model = SwinForImageClassification.from_pretrained("microsoft/swin-base-patch4-window7-224")
|
22 |
-
|
23 |
-
model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))
|
24 |
|
25 |
# Download the model file from the Hub if not present
|
26 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|
|
|
17 |
print("class_names:", class_names)
|
18 |
|
19 |
# Load model
|
20 |
+
model = SwinForImageClassification.from_pretrained("microsoft/swin-base-patch4-window7-224", num_labels=len(class_names))
|
|
|
|
|
|
|
21 |
|
22 |
# Download the model file from the Hub if not present
|
23 |
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")
|