Update app.py
Browse files
app.py
CHANGED
@@ -151,13 +151,16 @@ class SkinViT(nn.Module):
|
|
151 |
#multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
|
152 |
#multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
|
153 |
|
154 |
-
# Download
|
155 |
-
multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="
|
156 |
-
multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="
|
157 |
|
158 |
-
# Load
|
159 |
-
multilabel_model =
|
160 |
-
multiclass_model =
|
|
|
|
|
|
|
161 |
|
162 |
multilabel_model.eval()
|
163 |
multiclass_model.eval()
|
|
|
151 |
#multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
|
152 |
#multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
|
153 |
|
154 |
+
# Download state_dicts from Hugging Face Hub
|
155 |
+
multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
|
156 |
+
multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
|
157 |
|
158 |
+
# Load models
|
159 |
+
multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
|
160 |
+
multiclass_model = SkinViT(num_classes=len(multiclass_class_names))
|
161 |
+
|
162 |
+
multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
|
163 |
+
multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
|
164 |
|
165 |
multilabel_model.eval()
|
166 |
multiclass_model.eval()
|