santhoshraghu commited on
Commit
e09e1b2
·
verified ·
1 Parent(s): ef749ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -5,6 +5,8 @@ import torch.nn as nn
5
  from torchvision import transforms
6
  from torchvision.models import vit_b_16, ViT_B_16_Weights
7
  import pandas as pd
 
 
8
  import io
9
  import os
10
  import base64
@@ -149,18 +151,13 @@ class SkinViT(nn.Module):
149
  #multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
150
  #multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
151
 
152
- multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
153
- multiclass_model = SkinViT(num_classes=len(multiclass_class_names))
154
-
155
- multilabel_model.load_state_dict(torch.hub.load_state_dict_from_url(
156
- "https://huggingface.co/santhoshraghu/DermBOT/resolve/main/skin_vit_fold10.pth",
157
- map_location="cpu"
158
- ))
159
- multiclass_model.load_state_dict(torch.hub.load_state_dict_from_url(
160
- "https://huggingface.co/santhoshraghu/DermBOT/resolve/main/best_dermnet_vit.pth",
161
- map_location="cpu"
162
- ))
163
 
 
 
 
164
 
165
  multilabel_model.eval()
166
  multiclass_model.eval()
 
5
  from torchvision import transforms
6
  from torchvision.models import vit_b_16, ViT_B_16_Weights
7
  import pandas as pd
8
+ from huggingface_hub import hf_hub_download
9
+ from langchain_huggingface import HuggingFaceEmbeddings
10
  import io
11
  import os
12
  import base64
 
151
  #multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
152
  #multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
153
 
154
+ # Download the full model from your Hugging Face repo
155
+ multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10.pth")
156
+ multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit.pth")
 
 
 
 
 
 
 
 
157
 
158
+ # Load the models from local path
159
+ multilabel_model = torch.load(multilabel_model_path, map_location="cpu")
160
+ multiclass_model = torch.load(multiclass_model_path, map_location="cpu")
161
 
162
  multilabel_model.eval()
163
  multiclass_model.eval()