waghib commited on
Commit
d0ba7fb
·
verified ·
1 Parent(s): c50df20

Updated app.py to improve the syntax compatible with spaces.

Browse files
Files changed (1) hide show
  1. app.py +129 -56
app.py CHANGED
@@ -6,12 +6,13 @@ import torch
6
  import streamlit as st
7
  from dotenv import load_dotenv
8
  from langchain_groq import ChatGroq
9
- from langchain.embeddings import HuggingFaceEmbeddings
10
- from langchain.text_splitter import RecursiveCharacterTextSplitter
11
  from langchain_community.vectorstores import FAISS
12
- from langchain.docstore.document import Document
13
- from langchain.prompts import ChatPromptTemplate
14
  from langchain.chains import create_retrieval_chain
 
15
  import numpy as np
16
  from sentence_transformers import util
17
  import time
@@ -19,15 +20,37 @@ import time
19
  # Set device for model (CUDA if available)
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
- # Load environment variables
23
  load_dotenv()
24
 
25
  # Set up the clinical assistant LLM
26
- groq_api_key = os.getenv('GROQ_API_KEY')
27
- if not groq_api_key:
28
- raise ValueError("API Key is not set in the secrets.")
 
 
 
29
 
30
- llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # Set up embeddings for clinical context (Bio_ClinicalBERT)
33
  embeddings = HuggingFaceEmbeddings(
@@ -38,38 +61,74 @@ embeddings = HuggingFaceEmbeddings(
38
  def load_clinical_data():
39
  """Load both flowcharts and patient cases"""
40
  docs = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- # Load diagnosis flowcharts
43
- for fpath in glob.glob("./Diagnosis_flowchart/*.json"):
44
- with open(fpath) as f:
45
- data = json.load(f)
46
- content = f"""
47
- DIAGNOSTIC FLOWCHART: {Path(fpath).stem}
48
- Diagnostic Path: {data['diagnostic']}
49
- Key Criteria: {data['knowledge']}
50
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  docs.append(Document(
52
- page_content=content,
53
- metadata={"source": fpath, "type": "flowchart"}
 
 
54
  ))
55
-
56
- # Load patient cases
57
- for category_dir in glob.glob("./Finished/*"):
58
- if os.path.isdir(category_dir):
59
- for case_file in glob.glob(f"{category_dir}/*.json"):
60
- with open(case_file) as f:
61
- case_data = json.load(f)
62
- notes = "\n".join(
63
- f"{k}: {v}" for k, v in case_data.items() if k.startswith("input")
64
- )
65
- docs.append(Document(
66
- page_content=f"""
67
- PATIENT CASE: {Path(case_file).stem}
68
- Category: {Path(category_dir).name}
69
- Notes: {notes}
70
- """,
71
- metadata={"source": case_file, "type": "patient_case"}
72
- ))
73
  return docs
74
 
75
  def build_vectorstore():
@@ -88,31 +147,45 @@ def get_vectorstore():
88
 
89
  def run_rag_chat(query, vectorstore):
90
  """Run the Retrieval-Augmented Generation (RAG) for clinical questions"""
91
- retriever = vectorstore.as_retriever()
92
-
93
- prompt_template = ChatPromptTemplate.from_template("""
94
- You are a clinical assistant AI. Based on the following clinical context, provide a reasoned and medically sound answer to the question.
95
 
96
- <context>
97
- {context}
98
- </context>
99
 
100
- Question: {input}
 
 
101
 
102
- Answer:
103
- """)
104
 
105
- retrieved_docs = retriever.invoke(query, k=3)
106
- retrieved_context = "\n".join([doc.page_content for doc in retrieved_docs])
107
 
108
- chain = create_retrieval_chain(
109
- retriever,
110
- create_stuff_documents_chain(llm, prompt_template)
111
- )
112
 
113
- response = chain.invoke({"input": query, "context": retrieved_context})
 
 
 
 
114
 
115
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  def calculate_hit_rate(retriever, query, expected_docs, k=3):
118
  """Calculate the hit rate for top-k retrieved documents"""
 
6
  import streamlit as st
7
  from dotenv import load_dotenv
8
  from langchain_groq import ChatGroq
9
+ from langchain_community.embeddings import HuggingFaceEmbeddings
10
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
11
  from langchain_community.vectorstores import FAISS
12
+ from langchain_core.documents import Document
13
+ from langchain_core.prompts import ChatPromptTemplate
14
  from langchain.chains import create_retrieval_chain
15
+ from langchain.chains.combine_documents import create_stuff_documents_chain
16
  import numpy as np
17
  from sentence_transformers import util
18
  import time
 
20
  # Set device for model (CUDA if available)
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
+ # Load environment variables - works for both local and Hugging Face Spaces
24
  load_dotenv()
25
 
26
  # Set up the clinical assistant LLM
27
+ # Try to get API key from Hugging Face Spaces secrets first, then fall back to .env file
28
+ try:
29
+ # For Hugging Face Spaces
30
+ from huggingface_hub.inference_api import InferenceApi
31
+ import os
32
+ groq_api_key = os.environ.get('GROQ_API_KEY')
33
 
34
+ # If not found in environment, try to get from st.secrets (Streamlit Cloud/Spaces)
35
+ if not groq_api_key and hasattr(st, 'secrets') and 'GROQ_API_KEY' in st.secrets:
36
+ groq_api_key = st.secrets['GROQ_API_KEY']
37
+
38
+ if not groq_api_key:
39
+ st.warning("API Key is not set in the secrets. Using a placeholder for UI demonstration.")
40
+ # For UI demonstration without API key
41
+ class MockLLM:
42
+ def invoke(self, prompt):
43
+ return {"answer": "This is a placeholder response. Please set up your GROQ_API_KEY to get real responses."}
44
+ llm = MockLLM()
45
+ else:
46
+ llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile")
47
+
48
+ except Exception as e:
49
+ st.error(f"Error setting up LLM: {str(e)}")
50
+ class MockLLM:
51
+ def invoke(self, prompt):
52
+ return {"answer": f"Error setting up LLM: {str(e)}. Please check your API key configuration."}
53
+ llm = MockLLM()
54
 
55
  # Set up embeddings for clinical context (Bio_ClinicalBERT)
56
  embeddings = HuggingFaceEmbeddings(
 
61
  def load_clinical_data():
62
  """Load both flowcharts and patient cases"""
63
  docs = []
64
+
65
+ # Get the absolute path to the current script
66
+ current_dir = os.path.dirname(os.path.abspath(__file__))
67
+
68
+ # Try to handle potential errors with file loading
69
+ try:
70
+ # Load diagnosis flowcharts
71
+ flowchart_dir = os.path.join(current_dir, "Diagnosis_flowchart")
72
+ if os.path.exists(flowchart_dir):
73
+ for fpath in glob.glob(os.path.join(flowchart_dir, "*.json")):
74
+ try:
75
+ with open(fpath, 'r', encoding='utf-8') as f:
76
+ data = json.load(f)
77
+ content = f"""
78
+ DIAGNOSTIC FLOWCHART: {Path(fpath).stem}
79
+ Diagnostic Path: {data.get('diagnostic', 'N/A')}
80
+ Key Criteria: {data.get('knowledge', 'N/A')}
81
+ """
82
+ docs.append(Document(
83
+ page_content=content,
84
+ metadata={"source": fpath, "type": "flowchart"}
85
+ ))
86
+ except Exception as e:
87
+ st.warning(f"Error loading flowchart file {fpath}: {str(e)}")
88
+ else:
89
+ st.warning(f"Flowchart directory not found at {flowchart_dir}")
90
 
91
+ # Load patient cases
92
+ finished_dir = os.path.join(current_dir, "Finished")
93
+ if os.path.exists(finished_dir):
94
+ for category_dir in glob.glob(os.path.join(finished_dir, "*")):
95
+ if os.path.isdir(category_dir):
96
+ for case_file in glob.glob(os.path.join(category_dir, "*.json")):
97
+ try:
98
+ with open(case_file, 'r', encoding='utf-8') as f:
99
+ case_data = json.load(f)
100
+ notes = "\n".join(
101
+ f"{k}: {v}" for k, v in case_data.items() if k.startswith("input")
102
+ )
103
+ docs.append(Document(
104
+ page_content=f"""
105
+ PATIENT CASE: {Path(case_file).stem}
106
+ Category: {Path(category_dir).name}
107
+ Notes: {notes}
108
+ """,
109
+ metadata={"source": case_file, "type": "patient_case"}
110
+ ))
111
+ except Exception as e:
112
+ st.warning(f"Error loading case file {case_file}: {str(e)}")
113
+ else:
114
+ st.warning(f"Finished directory not found at {finished_dir}")
115
+
116
+ # If no documents were loaded, add a sample document for testing
117
+ if not docs:
118
+ st.warning("No clinical data files found. Using sample data for demonstration.")
119
  docs.append(Document(
120
+ page_content="""SAMPLE CLINICAL DATA: This is sample data for demonstration purposes.
121
+ This application requires clinical data files to be present in the correct directories.
122
+ Please ensure the Diagnosis_flowchart and Finished directories exist with proper JSON files.""",
123
+ metadata={"source": "sample", "type": "sample"}
124
  ))
125
+ except Exception as e:
126
+ st.error(f"Error loading clinical data: {str(e)}")
127
+ # Add a fallback document
128
+ docs.append(Document(
129
+ page_content="Error loading clinical data. This is a fallback document for demonstration purposes.",
130
+ metadata={"source": "error", "type": "error"}
131
+ ))
 
 
 
 
 
 
 
 
 
 
 
132
  return docs
133
 
134
  def build_vectorstore():
 
147
 
148
  def run_rag_chat(query, vectorstore):
149
  """Run the Retrieval-Augmented Generation (RAG) for clinical questions"""
150
+ try:
151
+ retriever = vectorstore.as_retriever()
 
 
152
 
153
+ prompt_template = ChatPromptTemplate.from_template("""
154
+ You are a clinical assistant AI. Based on the following clinical context, provide a reasoned and medically sound answer to the question.
 
155
 
156
+ <context>
157
+ {context}
158
+ </context>
159
 
160
+ Question: {input}
 
161
 
162
+ Answer:
163
+ """)
164
 
165
+ retrieved_docs = retriever.invoke(query, k=3)
166
+ retrieved_context = "\n".join([doc.page_content for doc in retrieved_docs])
 
 
167
 
168
+ # Create document chain first
169
+ document_chain = create_stuff_documents_chain(llm, prompt_template)
170
+
171
+ # Then create retrieval chain
172
+ chain = create_retrieval_chain(retriever, document_chain)
173
 
174
+ # Invoke the chain
175
+ response = chain.invoke({"input": query})
176
+
177
+ # Add retrieved documents to response for transparency
178
+ response["context"] = retrieved_docs
179
+
180
+ return response
181
+ except Exception as e:
182
+ st.error(f"Error in RAG processing: {str(e)}")
183
+ # Return a fallback response
184
+ return {
185
+ "answer": f"I encountered an error processing your query: {str(e)}",
186
+ "context": [],
187
+ "input": query
188
+ }
189
 
190
  def calculate_hit_rate(retriever, query, expected_docs, k=3):
191
  """Calculate the hit rate for top-k retrieved documents"""