Spaces: Running on Zero
Thanush committed · Commit f3b4260 · Parent(s): 01a984c
Refactor app.py to implement LangChain memory for enhanced conversation tracking. Update prompt building and response generation logic to utilize full conversation context, improving user interaction and medical assessment accuracy.
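For reference, a minimal, self-contained sketch of the LangChain ConversationBufferMemory calls this commit relies on (not part of the diff; the import path assumes the classic langchain package):

from langchain.memory import ConversationBufferMemory

# Buffer memory that keeps raw HumanMessage/AIMessage objects
memory = ConversationBufferMemory(return_messages=True)

# Record one user/assistant exchange
memory.save_context({"input": "I have a headache"},
                    {"output": "How long have you had it?"})

# Read the stored turns back; msg.type is "human" or "ai"
for msg in memory.chat_memory.messages:
    print(msg.type, "->", msg.content)

# Wipe the buffer when a new conversation starts
memory.clear()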
app.py CHANGED
@@ -68,7 +68,10 @@ meditron_model = AutoModelForCausalLM.from_pretrained(
 )
 print("Meditron model loaded successfully!")
 
-# …
+# Initialize LangChain memory for conversation tracking
+memory = ConversationBufferMemory(return_messages=True)
+
+# Simple state for basic info tracking
 conversation_state = {
     'name': None,
     'age': None,
@@ -95,13 +98,19 @@ def get_meditron_suggestions(patient_info):
     suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
     return suggestion
 
-def build_simple_prompt(system_prompt, conversation_history, current_input):
-    """Build …
+def build_prompt_with_memory(system_prompt, current_input):
+    """Build prompt using LangChain memory for full conversation context"""
     prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
 
-    # …
-    …
-    …
+    # Get conversation history from memory
+    messages = memory.chat_memory.messages
+
+    # Add conversation history to prompt
+    for msg in messages:
+        if msg.type == "human":
+            prompt += f"{msg.content} [/INST] "
+        elif msg.type == "ai":
+            prompt += f"{msg.content} </s><s>[INST] "
 
     # Add current input
     prompt += f"{current_input} [/INST] "
@@ -110,7 +119,7 @@ def build_simple_prompt(system_prompt, conversation_history, current_input):
 
 @spaces.GPU
 def generate_response(message, history):
-    """Generate a response using …
+    """Generate a response using LangChain ConversationBufferMemory."""
     global conversation_state
 
     # Reset state if this is a new conversation
@@ -122,35 +131,44 @@ def generate_response(message, history):
             'has_name': False,
            'has_age': False
         }
+        # Clear memory for new conversation
+        memory.clear()
+
+    # Save current user message to memory (we'll save bot response later)
+    memory.save_context({"input": message}, {"output": ""})
 
     # Step 1: Ask for name if not provided
     if not conversation_state['has_name']:
         conversation_state['has_name'] = True
-        …
+        bot_response = "Hello! Before we discuss your health concerns, could you please tell me your name?"
+        # Update memory with bot response
+        memory.save_context({"input": message}, {"output": bot_response})
+        return bot_response
 
     # Step 2: Store name and ask for age
     if conversation_state['name'] is None:
         conversation_state['name'] = message.strip()
-        …
+        bot_response = f"Nice to meet you, {conversation_state['name']}! Could you please tell me your age?"
+        # Update memory with bot response
+        memory.save_context({"input": message}, {"output": bot_response})
+        return bot_response
 
     # Step 3: Store age and start medical questions
     if not conversation_state['has_age']:
         conversation_state['age'] = message.strip()
         conversation_state['has_age'] = True
-        …
+        bot_response = f"Thank you, {conversation_state['name']}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
+        # Update memory with bot response
+        memory.save_context({"input": message}, {"output": bot_response})
+        return bot_response
 
-    # Step 4: Medical consultation phase
+    # Step 4: Medical consultation phase using ConversationBufferMemory
     conversation_state['medical_turns'] += 1
 
-    # …
-    medical_history = []
-    if len(history) >= 3:  # Skip name/age exchanges
-        medical_history = history[3:]
-
-    # Build the prompt for medical consultation
+    # Build the prompt using memory for full conversation context
    if conversation_state['medical_turns'] <= 5:
         # Still gathering information - let LLM ask intelligent follow-up questions
-        prompt = …
+        prompt = build_prompt_with_memory(SYSTEM_PROMPT, message)
 
         # Generate response with intelligent follow-up questions
         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -168,21 +186,31 @@ def generate_response(message, history):
         full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         llama_response = full_response.split('[/INST]')[-1].strip()
 
+        # Save bot response to memory
+        memory.save_context({"input": message}, {"output": llama_response})
+
         return llama_response
 
     else:
         # Time for diagnosis and treatment (after 5+ turns)
-        # …
+        # Get all conversation messages from memory
+        all_messages = memory.chat_memory.messages
+
+        # Compile patient information from memory
         patient_info = f"Patient: {conversation_state['name']}, Age: {conversation_state['age']}\n\n"
-        patient_info += "…
+        patient_info += "Complete Conversation History:\n"
 
-        # Add all …
-        for …
-            …
-            …
+        # Add all messages from memory
+        for msg in all_messages:
+            if msg.type == "human":
+                patient_info += f"Patient: {msg.content}\n"
+            elif msg.type == "ai":
+                patient_info += f"Doctor: {msg.content}\n"
 
-        …
-        …
+        patient_info += f"Current: {message}\n"
+
+        # Generate diagnosis with full conversation context
+        diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on the complete conversation history, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {conversation_state['name']}.\n\nComplete Patient Information:\n{patient_info} [/INST] "
 
         inputs = tokenizer(diagnosis_prompt, return_tensors="pt").to(model.device)
         with torch.no_grad():
@@ -199,12 +227,15 @@ def generate_response(message, history):
         full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         diagnosis = full_response.split('[/INST]')[-1].strip()
 
-        # Get treatment suggestions from Meditron
+        # Get treatment suggestions from Meditron using memory context
         treatment_suggestions = get_meditron_suggestions(patient_info)
 
         # Combine responses
         final_response = f"{diagnosis}\n\n--- TREATMENT RECOMMENDATIONS ---\n\n{treatment_suggestions}\n\n**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice."
 
+        # Save final response to memory
+        memory.save_context({"input": message}, {"output": final_response})
+
         return final_response
 
 # Create the Gradio interface
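For illustration, a rough usage sketch of the new prompt builder, showing the Llama-2 chat string it assembles from memory; the SYSTEM_PROMPT value here is a stand-in for the constant defined earlier in app.py:

from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(return_messages=True)
SYSTEM_PROMPT = "You are a medical assistant."  # stand-in; app.py defines the real system prompt

def build_prompt_with_memory(system_prompt, current_input):
    """Mirrors the diff: stored turns become alternating [INST] ... [/INST] blocks."""
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    for msg in memory.chat_memory.messages:
        if msg.type == "human":
            prompt += f"{msg.content} [/INST] "
        elif msg.type == "ai":
            prompt += f"{msg.content} </s><s>[INST] "
    prompt += f"{current_input} [/INST] "
    return prompt

# One prior exchange in memory, then a new user message
memory.save_context({"input": "My name is Alex"},
                    {"output": "Nice to meet you, Alex! Could you please tell me your age?"})
print(build_prompt_with_memory(SYSTEM_PROMPT, "I've had a fever for two days"))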