Update app.py
app.py CHANGED
@@ -55,13 +55,37 @@ collection_name = "ks_collection_1.5BE"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-local_embedding = HuggingFaceEmbeddings(
-    model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
-    model_kwargs={
-        "trust_remote_code": True,
-        "device": device
-    }
-)
+def get_safe_embedding_model():
+    model_name = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
+
+    try:
+        print("Trying to load embedding model on CUDA...")
+        embedding = HuggingFaceEmbeddings(
+            model_name=model_name,
+            model_kwargs={
+                "trust_remote_code": True,
+                "device": "cuda"
+            }
+        )
+        print("Loaded embedding model on GPU.")
+        return embedding
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            print("CUDA OOM. Falling back to CPU.")
+        else:
+            print(" Error loading model on CUDA:", str(e))
+        print("Loading embedding model on CPU...")
+        return HuggingFaceEmbeddings(
+            model_name=model_name,
+            model_kwargs={
+                "trust_remote_code": True,
+                "device": "cpu"
+            }
+        )
+
+# Replace your old local_embedding line with this
+local_embedding = get_safe_embedding_model()
+
 print(" Qwen2-1.5B local embedding model loaded.")
 
 
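The fallback only special-cases the `CUDA out of memory` message; any other RuntimeError is printed and the loader still retries on CPU, so the function always returns a usable embedder. A quick smoke test (my example, not part of the commit; the 1536-dimension figure comes from the gte-Qwen2-1.5B-instruct model card and is worth confirming against your installed version):

emb = get_safe_embedding_model()
vec = emb.embed_query("probe sentence")  # standard LangChain embeddings call
print(len(vec))  # expected 1536 for gte-Qwen2-1.5B-instruct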
@@ -178,17 +202,33 @@ class DermNetViT(nn.Module):
 multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
 multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
 
-[7 removed lines: the previous direct model loading; their text was not preserved in this view]
+def load_model_with_fallback(model_class, weight_path, num_classes, model_name):
+    try:
+        print(f"🔍 Loading {model_name} on GPU...")
+        model = model_class(num_classes)
+        model.load_state_dict(torch.load(weight_path, map_location="cuda"))
+        model.to("cuda")
+        print(f"✅ {model_name} loaded on GPU.")
+        return model
+    except RuntimeError as e:
+        if "CUDA out of memory" in str(e):
+            print(f"⚠️ {model_name} OOM. Falling back to CPU.")
+        else:
+            print(f"❌ Error loading {model_name} on CUDA: {e}")
+        print(f"🔄 Loading {model_name} on CPU...")
+        model = model_class(num_classes)
+        model.load_state_dict(torch.load(weight_path, map_location="cpu"))
+        model.to("cpu")
+        return model
+
+# Load both models with fallback
+multilabel_model = load_model_with_fallback(SkinViT, multilabel_model_path, len(multilabel_class_names), "SkinViT")
+multiclass_model = load_model_with_fallback(DermNetViT, multiclass_model_path, len(multiclass_class_names), "DermNetViT")
 
 multilabel_model.eval()
 multiclass_model.eval()
 
+
 # === Session Init ===
 if "messages" not in st.session_state:
     st.session_state.messages = []
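Because each model falls back independently, SkinViT and DermNetViT can land on different devices. Downstream inference should therefore follow the model's device rather than assuming CUDA; a minimal sketch (the `predict` helper is mine, not from the commit):

import torch

def predict(model, batch):
    # Route inputs to whichever device the fallback loader chose,
    # so a CPU-loaded model never receives CUDA tensors.
    device = next(model.parameters()).device
    with torch.no_grad():
        return model(batch.to(device))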
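Since app.py is a Streamlit script, every interaction reruns the module top to bottom and would repeat these weight loads. One possible refinement, not part of this commit, is to wrap the loaders in `st.cache_resource` so they run once per process:

import streamlit as st

@st.cache_resource
def get_vision_models():
    # Hypothetical wrapper: load once, reuse across Streamlit reruns.
    ml = load_model_with_fallback(SkinViT, multilabel_model_path,
                                  len(multilabel_class_names), "SkinViT")
    mc = load_model_with_fallback(DermNetViT, multiclass_model_path,
                                  len(multiclass_class_names), "DermNetViT")
    ml.eval()
    mc.eval()
    return ml, mc

multilabel_model, multiclass_model = get_vision_models()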