Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -176,8 +176,12 @@ multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filenam
|
|
176 |
multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
|
177 |
multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
|
178 |
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
|
|
181 |
|
182 |
multilabel_model.eval()
|
183 |
multiclass_model.eval()
|
@@ -193,7 +197,7 @@ def run_inference(image):
|
|
193 |
transforms.ToTensor(),
|
194 |
transforms.Normalize([0.5], [0.5])
|
195 |
])
|
196 |
-
input_tensor = transform(image).unsqueeze(0)
|
197 |
with torch.no_grad():
|
198 |
probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().numpy()
|
199 |
predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
|
|
|
176 |
multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
|
177 |
multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
|
178 |
|
179 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
180 |
+
multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location=device))
|
181 |
+
multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location=device))
|
182 |
+
multilabel_model.to(device)
|
183 |
+
multiclass_model.to(device)
|
184 |
+
|
185 |
|
186 |
multilabel_model.eval()
|
187 |
multiclass_model.eval()
|
|
|
197 |
transforms.ToTensor(),
|
198 |
transforms.Normalize([0.5], [0.5])
|
199 |
])
|
200 |
+
input_tensor = transform(image).unsqueeze(0).to(device)
|
201 |
with torch.no_grad():
|
202 |
probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().numpy()
|
203 |
predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
|