Update app.py
Browse files
app.py
CHANGED
@@ -147,13 +147,20 @@ class SkinViT(nn.Module):
|
|
147 |
class DermNetViT(nn.Module):
|
148 |
def __init__(self, num_classes):
|
149 |
super(DermNetViT, self).__init__()
|
|
|
150 |
self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
|
151 |
-
in_features = self.model.heads.
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
def forward(self, x):
|
155 |
return self.model(x)
|
156 |
|
|
|
157 |
#multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
|
158 |
#multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
|
159 |
|
@@ -163,8 +170,10 @@ multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filenam
|
|
163 |
|
164 |
multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
|
165 |
multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
|
|
|
166 |
multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
|
167 |
multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
|
|
|
168 |
multilabel_model.eval()
|
169 |
multiclass_model.eval()
|
170 |
|
|
|
147 |
class DermNetViT(nn.Module):
|
148 |
def __init__(self, num_classes):
|
149 |
super(DermNetViT, self).__init__()
|
150 |
+
from torchvision.models import vit_l_16, ViT_L_16_Weights
|
151 |
self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
|
152 |
+
in_features = self.model.heads[0].in_features
|
153 |
+
# Match the structure used during training (with head.1)
|
154 |
+
self.model.heads = nn.Sequential(
|
155 |
+
nn.Linear(in_features, 1024),
|
156 |
+
nn.ReLU(),
|
157 |
+
nn.Linear(1024, num_classes)
|
158 |
+
)
|
159 |
|
160 |
def forward(self, x):
|
161 |
return self.model(x)
|
162 |
|
163 |
+
|
164 |
#multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
|
165 |
#multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
|
166 |
|
|
|
170 |
|
171 |
multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
|
172 |
multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
|
173 |
+
|
174 |
multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
|
175 |
multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
|
176 |
+
|
177 |
multilabel_model.eval()
|
178 |
multiclass_model.eval()
|
179 |
|