santhoshraghu committed on
Commit
a670d2e
·
verified ·
1 Parent(s): 754b611

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -14
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  from PIL import Image
3
  import torch
 
4
  import torch.nn as nn
5
  from torchvision import transforms
6
  from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
@@ -32,6 +33,8 @@ torch.cuda.empty_cache()
32
 
33
  import nest_asyncio
34
  nest_asyncio.apply()
 
 
35
 
36
  st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
37
 
@@ -97,6 +100,8 @@ vector_store = Qdrant(
97
  retriever = vector_store.as_retriever()
98
 
99
 
 
 
100
  # Dynamically initialize LLM based on selection
101
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
102
  selected_model = st.session_state["selected_model"]
@@ -145,12 +150,12 @@ Your Response:
145
 
146
  prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
147
 
148
- rag_chain = RetrievalQA.from_chain_type(
149
- llm=llm,
150
- retriever=retriever,
151
- chain_type="stuff",
152
- chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
153
- )
154
 
155
  # === Class Names ===
156
  multilabel_class_names = [
@@ -268,6 +273,25 @@ def export_chat_to_pdf(messages):
268
  buf.seek(0)
269
  return buf
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  # === App UI ===
272
 
273
  st.title("🧬 DermBOT β€” Skin AI Assistant")
@@ -278,22 +302,21 @@ if uploaded_file:
278
  st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
279
  image = Image.open(uploaded_file).convert("RGB")
280
 
281
-
282
  predicted_multi, predicted_single = run_inference(image)
283
 
284
  # Show predictions clearly to the user
285
- st.markdown(f" Skin Issues : {', '.join(predicted_multi)}")
286
- st.markdown(f" Most Likely Diagnosis : {predicted_single}")
287
 
288
  query = f"What are my treatment options for {predicted_multi} and {predicted_single}?"
289
  st.session_state.messages.append({"role": "user", "content": query})
290
 
291
- with st.spinner("Analyzing the image and retrieving response..."):
292
- response = rag_chain.invoke(query)
293
- st.session_state.messages.append({"role": "assistant", "content": response['result']})
294
 
295
  with st.chat_message("assistant"):
296
- st.markdown(response['result'])
297
 
298
  # === Chat Interface ===
299
  if prompt := st.chat_input("Ask a follow-up..."):
@@ -301,7 +324,7 @@ if prompt := st.chat_input("Ask a follow-up..."):
301
  with st.chat_message("user"):
302
  st.markdown(prompt)
303
 
304
- response = llm.invoke([{"role": m["role"], "content": m["content"]} for m in st.session_state.messages])
305
  st.session_state.messages.append({"role": "assistant", "content": response.content})
306
  with st.chat_message("assistant"):
307
  st.markdown(response.content)
 
1
  import streamlit as st
2
  from PIL import Image
3
  import torch
4
+ import cohere
5
  import torch.nn as nn
6
  from torchvision import transforms
7
  from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
 
33
 
34
  import nest_asyncio
35
  nest_asyncio.apply()
36
+ co = cohere.Client(st.secrets["COHERE_API_KEY"])
37
+
38
 
39
  st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
40
 
 
100
  retriever = vector_store.as_retriever()
101
 
102
 
103
+
104
+
105
  # Dynamically initialize LLM based on selection
106
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
107
  selected_model = st.session_state["selected_model"]
 
150
 
151
  prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
152
 
153
+ #rag_chain = RetrievalQA.from_chain_type(
154
+ # llm=llm,
155
+ # retriever=retriever,
156
+ # chain_type="stuff",
157
+ # chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
158
+ #)
159
 
160
  # === Class Names ===
161
  multilabel_class_names = [
 
273
  buf.seek(0)
274
  return buf
275
 
276
+
277
# Reranker utility
def rerank_with_cohere(query: str, documents: list, top_n: int = 5) -> list:
    """Reorder retrieved documents by relevance to *query* via Cohere rerank.

    Args:
        query: The user question to rank the documents against.
        documents: Retrieved docs — LangChain Documents (with `.page_content`)
            or any objects convertible with ``str()``.
        top_n: Maximum number of documents to keep (default 5).

    Returns:
        The input objects reordered by Cohere relevance and truncated to
        at most ``top_n`` entries. An empty input returns ``[]`` without
        calling the API.
    """
    if not documents:
        return []

    # Cohere expects plain strings; fall back to str() for non-Document items.
    raw_texts = [doc.page_content if hasattr(doc, "page_content") else str(doc) for doc in documents]
    response = co.rerank(
        query=query,
        documents=raw_texts,
        top_n=min(top_n, len(raw_texts)),
        model="rerank-v3.5",
    )

    # BUG FIX: the Cohere Python SDK v5+ returns a response object whose hits
    # are under `.results`; iterating the response itself breaks. Older SDKs
    # returned an iterable of hits directly — support both.
    hits = getattr(response, "results", response)

    # Map the ranked indices back onto the original (possibly Document) objects.
    return [documents[hit.index] for hit in hits]
286
+
287
# Final answer generation using reranked context
def get_reranked_response(query: str):
    """Retrieve, rerank, and answer: the full RAG round-trip for one query.

    Pulls candidate passages from the vector-store retriever, reorders them
    with Cohere via ``rerank_with_cohere``, fills ``AI_PROMPT_TEMPLATE`` with
    the question and the concatenated passages, and returns the raw LLM
    response object (callers read ``.content``).
    """
    # Fetch candidates from the vector store, then reorder by relevance.
    candidates = retriever.get_relevant_documents(query)
    top_docs = rerank_with_cohere(query, candidates)

    # Concatenate the surviving passages into one context blob for the prompt.
    context_text = "\n\n".join(doc.page_content for doc in top_docs)
    filled_prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context_text)

    # One system message carries both the instructions and the retrieved context.
    return llm.invoke([{"role": "system", "content": filled_prompt}])
294
+
295
  # === App UI ===
296
 
297
  st.title("🧬 DermBOT β€” Skin AI Assistant")
 
302
  st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
303
  image = Image.open(uploaded_file).convert("RGB")
304
 
 
305
  predicted_multi, predicted_single = run_inference(image)
306
 
307
  # Show predictions clearly to the user
308
+ st.markdown(f"🧾 **Skin Issues**: {', '.join(predicted_multi)}")
309
+ st.markdown(f"πŸ“Œ **Most Likely Diagnosis**: {predicted_single}")
310
 
311
  query = f"What are my treatment options for {predicted_multi} and {predicted_single}?"
312
  st.session_state.messages.append({"role": "user", "content": query})
313
 
314
+ with st.spinner("πŸ”Ž Analyzing and retrieving context..."):
315
+ response = get_reranked_response(query)
316
+ st.session_state.messages.append({"role": "assistant", "content": response.content})
317
 
318
  with st.chat_message("assistant"):
319
+ st.markdown(response.content)
320
 
321
  # === Chat Interface ===
322
  if prompt := st.chat_input("Ask a follow-up..."):
 
324
  with st.chat_message("user"):
325
  st.markdown(prompt)
326
 
327
+ response = get_reranked_response(prompt)
328
  st.session_state.messages.append({"role": "assistant", "content": response.content})
329
  with st.chat_message("assistant"):
330
  st.markdown(response.content)