santhoshraghu commited on
Commit
a182687
·
verified ·
1 Parent(s): 7a5b0e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -3,7 +3,7 @@ from PIL import Image
3
  import torch
4
  import torch.nn as nn
5
  from torchvision import transforms
6
- from torchvision.models import vit_b_16, ViT_B_16_Weights
7
  import pandas as pd
8
  from huggingface_hub import hf_hub_download
9
  from langchain_huggingface import HuggingFaceEmbeddings
@@ -31,7 +31,6 @@ nest_asyncio.apply()
31
 
32
  st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
33
 
34
- #os.environ["PGVECTOR_CONNECTION_STRING"] = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
35
 
36
  # === Model Selection ===
37
  available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
@@ -49,10 +48,6 @@ collection_name = "ks_collection_1.5BE"
49
  #local_embedding = SentenceTransformerEmbeddings(model=embedding_model)
50
 
51
 
52
- #local_embedding = HuggingFaceEmbeddings(
53
- # model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
54
- # model_kwargs={"trust_remote_code": True, "device":"cpu"}
55
- #)
56
 
57
  local_embedding = HuggingFaceEmbeddings(
58
  model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
@@ -143,25 +138,33 @@ class SkinViT(nn.Module):
143
  def __init__(self, num_classes):
144
  super(SkinViT, self).__init__()
145
  self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
146
- in_features = self.model.heads[0].in_features
147
- self.model.heads[0] = nn.Linear(in_features, num_classes)
 
 
 
 
 
 
 
 
 
 
 
148
  def forward(self, x):
149
  return self.model(x)
150
 
151
  #multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
152
  #multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
153
 
154
- # Download state_dicts from Hugging Face Hub
155
  multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
156
  multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
157
 
158
- # Load models
159
  multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
160
- multiclass_model = SkinViT(num_classes=len(multiclass_class_names))
161
-
162
  multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
163
  multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
164
-
165
  multilabel_model.eval()
166
  multiclass_model.eval()
167
 
 
3
  import torch
4
  import torch.nn as nn
5
  from torchvision import transforms
6
+ from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
7
  import pandas as pd
8
  from huggingface_hub import hf_hub_download
9
  from langchain_huggingface import HuggingFaceEmbeddings
 
31
 
32
  st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
33
 
 
34
 
35
  # === Model Selection ===
36
  available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
 
48
  #local_embedding = SentenceTransformerEmbeddings(model=embedding_model)
49
 
50
 
 
 
 
 
51
 
52
  local_embedding = HuggingFaceEmbeddings(
53
  model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
 
138
  def __init__(self, num_classes):
139
  super(SkinViT, self).__init__()
140
  self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
141
+ in_features = self.model.heads.head.in_features
142
+ self.model.heads.head = nn.Linear(in_features, num_classes)
143
+
144
+ def forward(self, x):
145
+ return self.model(x)
146
+
147
+ class DermNetViT(nn.Module):
148
+ def __init__(self, num_classes):
149
+ super(DermNetViT, self).__init__()
150
+ self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
151
+ in_features = self.model.heads.head.in_features
152
+ self.model.heads.head = nn.Linear(in_features, num_classes)
153
+
154
  def forward(self, x):
155
  return self.model(x)
156
 
157
  #multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
158
  #multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
159
 
160
+ # === Load Model State Dicts ===
161
  multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
162
  multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
163
 
 
164
  multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
165
+ multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
 
166
  multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
167
  multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
 
168
  multilabel_model.eval()
169
  multiclass_model.eval()
170