Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,8 @@ import torch.nn as nn
|
|
5 |
from torchvision import transforms
|
6 |
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
7 |
import pandas as pd
|
|
|
|
|
8 |
import io
|
9 |
import os
|
10 |
import base64
|
@@ -149,18 +151,13 @@ class SkinViT(nn.Module):
|
|
149 |
#multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
|
150 |
#multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
|
151 |
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
multilabel_model.load_state_dict(torch.hub.load_state_dict_from_url(
|
156 |
-
"https://huggingface.co/santhoshraghu/DermBOT/resolve/main/skin_vit_fold10.pth",
|
157 |
-
map_location="cpu"
|
158 |
-
))
|
159 |
-
multiclass_model.load_state_dict(torch.hub.load_state_dict_from_url(
|
160 |
-
"https://huggingface.co/santhoshraghu/DermBOT/resolve/main/best_dermnet_vit.pth",
|
161 |
-
map_location="cpu"
|
162 |
-
))
|
163 |
|
|
|
|
|
|
|
164 |
|
165 |
multilabel_model.eval()
|
166 |
multiclass_model.eval()
|
|
|
5 |
from torchvision import transforms
|
6 |
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
7 |
import pandas as pd
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
10 |
import io
|
11 |
import os
|
12 |
import base64
|
|
|
151 |
#multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
|
152 |
#multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
|
153 |
|
154 |
+
# Download the full model from your Hugging Face repo
|
155 |
+
multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10.pth")
|
156 |
+
multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
+
# Load the models from local path
|
159 |
+
multilabel_model = torch.load(multilabel_model_path, map_location="cpu")
|
160 |
+
multiclass_model = torch.load(multiclass_model_path, map_location="cpu")
|
161 |
|
162 |
multilabel_model.eval()
|
163 |
multiclass_model.eval()
|