santhoshraghu commited on
Commit
e0d40bf
·
verified ·
1 Parent(s): d389cd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -176,8 +176,12 @@ multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filenam
176
  multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
177
  multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
178
 
179
- multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
180
- multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
 
 
 
 
181
 
182
  multilabel_model.eval()
183
  multiclass_model.eval()
@@ -193,7 +197,7 @@ def run_inference(image):
193
  transforms.ToTensor(),
194
  transforms.Normalize([0.5], [0.5])
195
  ])
196
- input_tensor = transform(image).unsqueeze(0)
197
  with torch.no_grad():
198
  probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().numpy()
199
  predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
 
176
  multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
177
  multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
178
 
179
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
180
+ multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location=device))
181
+ multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location=device))
182
+ multilabel_model.to(device)
183
+ multiclass_model.to(device)
184
+
185
 
186
  multilabel_model.eval()
187
  multiclass_model.eval()
 
197
  transforms.ToTensor(),
198
  transforms.Normalize([0.5], [0.5])
199
  ])
200
+ input_tensor = transform(image).unsqueeze(0).to(device)
201
  with torch.no_grad():
202
  probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().numpy()
203
  predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]