santhoshraghu commited on
Commit
7a5b0e6
·
verified ·
1 Parent(s): 545cc0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -151,13 +151,16 @@ class SkinViT(nn.Module):
151
  #multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
152
  #multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
153
 
154
- # Download the full model from your Hugging Face repo
155
- multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10.pth")
156
- multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit.pth")
157
 
158
- # Load the models from local path
159
- multilabel_model = torch.load(multilabel_model_path, map_location="cpu")
160
- multiclass_model = torch.load(multiclass_model_path, map_location="cpu")
 
 
 
161
 
162
  multilabel_model.eval()
163
  multiclass_model.eval()
 
151
  #multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
152
  #multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
153
 
154
+ # Download state_dicts from Hugging Face Hub
155
+ multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
156
+ multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
157
 
158
+ # Load models
159
+ multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
160
+ multiclass_model = SkinViT(num_classes=len(multiclass_class_names))
161
+
162
+ multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
163
+ multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
164
 
165
  multilabel_model.eval()
166
  multiclass_model.eval()