santhoshraghu commited on
Commit
039752f
·
verified ·
1 Parent(s): a182687

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -147,13 +147,20 @@ class SkinViT(nn.Module):
147
  class DermNetViT(nn.Module):
148
  def __init__(self, num_classes):
149
  super(DermNetViT, self).__init__()
 
150
  self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
151
- in_features = self.model.heads.head.in_features
152
- self.model.heads.head = nn.Linear(in_features, num_classes)
 
 
 
 
 
153
 
154
  def forward(self, x):
155
  return self.model(x)
156
 
 
157
  #multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
158
  #multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
159
 
@@ -163,8 +170,10 @@ multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filenam
163
 
164
  multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
165
  multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
 
166
  multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
167
  multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
 
168
  multilabel_model.eval()
169
  multiclass_model.eval()
170
 
 
147
  class DermNetViT(nn.Module):
148
  def __init__(self, num_classes):
149
  super(DermNetViT, self).__init__()
150
+ from torchvision.models import vit_l_16, ViT_L_16_Weights
151
  self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
152
+ in_features = self.model.heads[0].in_features
153
+ # Match the structure used during training (with head.1)
154
+ self.model.heads = nn.Sequential(
155
+ nn.Linear(in_features, 1024),
156
+ nn.ReLU(),
157
+ nn.Linear(1024, num_classes)
158
+ )
159
 
160
  def forward(self, x):
161
  return self.model(x)
162
 
163
+
164
  #multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
165
  #multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
166
 
 
170
 
171
  multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
172
  multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
173
+
174
  multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
175
  multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
176
+
177
  multilabel_model.eval()
178
  multiclass_model.eval()
179