Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
from PIL import Image
|
3 |
import torch
|
|
|
4 |
import torch.nn as nn
|
5 |
from torchvision import transforms
|
6 |
from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
|
@@ -32,6 +33,8 @@ torch.cuda.empty_cache()
|
|
32 |
|
33 |
import nest_asyncio
|
34 |
nest_asyncio.apply()
|
|
|
|
|
35 |
|
36 |
st.set_page_config(page_title="DermBOT", page_icon="π§¬", layout="centered")
|
37 |
|
@@ -97,6 +100,8 @@ vector_store = Qdrant(
|
|
97 |
retriever = vector_store.as_retriever()
|
98 |
|
99 |
|
|
|
|
|
100 |
# Dynamically initialize LLM based on selection
|
101 |
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
102 |
selected_model = st.session_state["selected_model"]
|
@@ -145,12 +150,12 @@ Your Response:
|
|
145 |
|
146 |
prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
|
147 |
|
148 |
-
rag_chain = RetrievalQA.from_chain_type(
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
)
|
154 |
|
155 |
# === Class Names ===
|
156 |
multilabel_class_names = [
|
@@ -268,6 +273,25 @@ def export_chat_to_pdf(messages):
|
|
268 |
buf.seek(0)
|
269 |
return buf
|
270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
# === App UI ===
|
272 |
|
273 |
st.title("𧬠DermBOT β Skin AI Assistant")
|
@@ -278,22 +302,21 @@ if uploaded_file:
|
|
278 |
st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
|
279 |
image = Image.open(uploaded_file).convert("RGB")
|
280 |
|
281 |
-
|
282 |
predicted_multi, predicted_single = run_inference(image)
|
283 |
|
284 |
# Show predictions clearly to the user
|
285 |
-
st.markdown(f" Skin Issues
|
286 |
-
st.markdown(f" Most Likely Diagnosis
|
287 |
|
288 |
query = f"What are my treatment options for {predicted_multi} and {predicted_single}?"
|
289 |
st.session_state.messages.append({"role": "user", "content": query})
|
290 |
|
291 |
-
with st.spinner("Analyzing
|
292 |
-
response =
|
293 |
-
st.session_state.messages.append({"role": "assistant", "content": response
|
294 |
|
295 |
with st.chat_message("assistant"):
|
296 |
-
st.markdown(response
|
297 |
|
298 |
# === Chat Interface ===
|
299 |
if prompt := st.chat_input("Ask a follow-up..."):
|
@@ -301,7 +324,7 @@ if prompt := st.chat_input("Ask a follow-up..."):
|
|
301 |
with st.chat_message("user"):
|
302 |
st.markdown(prompt)
|
303 |
|
304 |
-
response =
|
305 |
st.session_state.messages.append({"role": "assistant", "content": response.content})
|
306 |
with st.chat_message("assistant"):
|
307 |
st.markdown(response.content)
|
|
|
1 |
import streamlit as st
|
2 |
from PIL import Image
|
3 |
import torch
|
4 |
+
import cohere
|
5 |
import torch.nn as nn
|
6 |
from torchvision import transforms
|
7 |
from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
|
|
|
33 |
|
34 |
import nest_asyncio
|
35 |
nest_asyncio.apply()
|
36 |
+
co = cohere.Client(st.secrets["COHERE_API_KEY"])
|
37 |
+
|
38 |
|
39 |
st.set_page_config(page_title="DermBOT", page_icon="π§¬", layout="centered")
|
40 |
|
|
|
100 |
retriever = vector_store.as_retriever()
|
101 |
|
102 |
|
103 |
+
|
104 |
+
|
105 |
# Dynamically initialize LLM based on selection
|
106 |
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
107 |
selected_model = st.session_state["selected_model"]
|
|
|
150 |
|
151 |
prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
|
152 |
|
153 |
+
#rag_chain = RetrievalQA.from_chain_type(
|
154 |
+
# llm=llm,
|
155 |
+
# retriever=retriever,
|
156 |
+
# chain_type="stuff",
|
157 |
+
# chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
|
158 |
+
#)
|
159 |
|
160 |
# === Class Names ===
|
161 |
multilabel_class_names = [
|
|
|
273 |
buf.seek(0)
|
274 |
return buf
|
275 |
|
276 |
+
|
277 |
+
#Reranker utility
|
278 |
+
def rerank_with_cohere(query: str, documents: list, top_n: int = 5) -> list:
|
279 |
+
if not documents:
|
280 |
+
return []
|
281 |
+
|
282 |
+
raw_texts = [doc.page_content if hasattr(doc, "page_content") else str(doc) for doc in documents]
|
283 |
+
results = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
|
284 |
+
reranked_docs = [documents[result.index] for result in results]
|
285 |
+
return reranked_docs
|
286 |
+
|
287 |
+
# Final answer generation using reranked context
|
288 |
+
def get_reranked_response(query: str):
|
289 |
+
docs = retriever.get_relevant_documents(query)
|
290 |
+
reranked_docs = rerank_with_cohere(query, docs)
|
291 |
+
context = "\n\n".join([doc.page_content for doc in reranked_docs])
|
292 |
+
prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
|
293 |
+
return llm.invoke([{"role": "system", "content": prompt}])
|
294 |
+
|
295 |
# === App UI ===
|
296 |
|
297 |
st.title("𧬠DermBOT β Skin AI Assistant")
|
|
|
302 |
st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
|
303 |
image = Image.open(uploaded_file).convert("RGB")
|
304 |
|
|
|
305 |
predicted_multi, predicted_single = run_inference(image)
|
306 |
|
307 |
# Show predictions clearly to the user
|
308 |
+
st.markdown(f"π§Ύ **Skin Issues**: {', '.join(predicted_multi)}")
|
309 |
+
st.markdown(f"π **Most Likely Diagnosis**: {predicted_single}")
|
310 |
|
311 |
query = f"What are my treatment options for {predicted_multi} and {predicted_single}?"
|
312 |
st.session_state.messages.append({"role": "user", "content": query})
|
313 |
|
314 |
+
with st.spinner("π Analyzing and retrieving context..."):
|
315 |
+
response = get_reranked_response(query)
|
316 |
+
st.session_state.messages.append({"role": "assistant", "content": response.content})
|
317 |
|
318 |
with st.chat_message("assistant"):
|
319 |
+
st.markdown(response.content)
|
320 |
|
321 |
# === Chat Interface ===
|
322 |
if prompt := st.chat_input("Ask a follow-up..."):
|
|
|
324 |
with st.chat_message("user"):
|
325 |
st.markdown(prompt)
|
326 |
|
327 |
+
response = get_reranked_response(prompt)
|
328 |
st.session_state.messages.append({"role": "assistant", "content": response.content})
|
329 |
with st.chat_message("assistant"):
|
330 |
st.markdown(response.content)
|