santhoshraghu commited on
Commit
c1f11b3
·
verified ·
1 Parent(s): 13f6011

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -14
app.py CHANGED
@@ -55,13 +55,37 @@ collection_name = "ks_collection_1.5BE"
55
 
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
57
 
58
- local_embedding = HuggingFaceEmbeddings(
59
- model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
60
- model_kwargs={
61
- "trust_remote_code": True,
62
- "device": device
63
- }
64
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  print(" Qwen2-1.5B local embedding model loaded.")
66
 
67
 
@@ -178,17 +202,33 @@ class DermNetViT(nn.Module):
178
  multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
179
  multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
180
 
181
- multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
182
- multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
183
-
184
- #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
185
- multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
186
- multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
187
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  multilabel_model.eval()
190
  multiclass_model.eval()
191
 
 
192
  # === Session Init ===
193
  if "messages" not in st.session_state:
194
  st.session_state.messages = []
 
55
 
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
57
 
58
+ def get_safe_embedding_model():
59
+ model_name = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
60
+
61
+ try:
62
+ print("Trying to load embedding model on CUDA...")
63
+ embedding = HuggingFaceEmbeddings(
64
+ model_name=model_name,
65
+ model_kwargs={
66
+ "trust_remote_code": True,
67
+ "device": "cuda"
68
+ }
69
+ )
70
+ print("Loaded embedding model on GPU.")
71
+ return embedding
72
+ except RuntimeError as e:
73
+ if "CUDA out of memory" in str(e):
74
+ print("CUDA OOM. Falling back to CPU.")
75
+ else:
76
+ print(" Error loading model on CUDA:", str(e))
77
+ print("Loading embedding model on CPU...")
78
+ return HuggingFaceEmbeddings(
79
+ model_name=model_name,
80
+ model_kwargs={
81
+ "trust_remote_code": True,
82
+ "device": "cpu"
83
+ }
84
+ )
85
+
86
+ # Replace your old local_embedding line with this
87
+ local_embedding = get_safe_embedding_model()
88
+
89
  print(" Qwen2-1.5B local embedding model loaded.")
90
 
91
 
 
202
  multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
203
  multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
204
 
205
+ def load_model_with_fallback(model_class, weight_path, num_classes, model_name):
206
+ try:
207
+ print(f"🔍 Loading {model_name} on GPU...")
208
+ model = model_class(num_classes)
209
+ model.load_state_dict(torch.load(weight_path, map_location="cuda"))
210
+ model.to("cuda")
211
+ print(f"✅ {model_name} loaded on GPU.")
212
+ return model
213
+ except RuntimeError as e:
214
+ if "CUDA out of memory" in str(e):
215
+ print(f"⚠️ {model_name} OOM. Falling back to CPU.")
216
+ else:
217
+ print(f"❌ Error loading {model_name} on CUDA: {e}")
218
+ print(f"🔄 Loading {model_name} on CPU...")
219
+ model = model_class(num_classes)
220
+ model.load_state_dict(torch.load(weight_path, map_location="cpu"))
221
+ model.to("cpu")
222
+ return model
223
+
224
+ # Load both models with fallback
225
+ multilabel_model = load_model_with_fallback(SkinViT, multilabel_model_path, len(multilabel_class_names), "SkinViT")
226
+ multiclass_model = load_model_with_fallback(DermNetViT, multiclass_model_path, len(multiclass_class_names), "DermNetViT")
227
 
228
  multilabel_model.eval()
229
  multiclass_model.eval()
230
 
231
+
232
  # === Session Init ===
233
  if "messages" not in st.session_state:
234
  st.session_state.messages = []