Thanush committed
Commit f3b4260 · 1 Parent(s): 01a984c

Refactor app.py to implement LangChain memory for enhanced conversation tracking. Update the prompt-building and response-generation logic to use the full conversation context, improving user interaction and medical assessment accuracy.
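For orientation, here is a minimal sketch of the LangChain pattern this commit adopts. The import line is an assumption — the diff below does not show app.py's import block being updated — and the sample messages are made up:

from langchain.memory import ConversationBufferMemory  # assumed import; not visible in this diff

memory = ConversationBufferMemory(return_messages=True)

# Each exchange is stored as a human/AI message pair.
memory.save_context({"input": "I have a headache."},
                    {"output": "How long have you had it?"})

# Stored turns come back as typed message objects.
for msg in memory.chat_memory.messages:
    print(msg.type, "->", msg.content)  # prints "human -> ..." then "ai -> ..."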

Files changed (1)
  1. app.py +58 -27
app.py CHANGED
@@ -68,7 +68,10 @@ meditron_model = AutoModelForCausalLM.from_pretrained(
 )
 print("Meditron model loaded successfully!")
 
-# Simple conversation state tracking
+# Initialize LangChain memory for conversation tracking
+memory = ConversationBufferMemory(return_messages=True)
+
+# Simple state for basic info tracking
 conversation_state = {
     'name': None,
     'age': None,
@@ -95,13 +98,19 @@ def get_meditron_suggestions(patient_info):
     suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
     return suggestion
 
-def build_simple_prompt(system_prompt, conversation_history, current_input):
-    """Build a simple prompt for Llama-2"""
+def build_prompt_with_memory(system_prompt, current_input):
+    """Build prompt using LangChain memory for full conversation context"""
     prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
 
-    # Add conversation history
-    for i, (user_msg, bot_msg) in enumerate(conversation_history):
-        prompt += f"{user_msg} [/INST] {bot_msg} </s><s>[INST] "
+    # Get conversation history from memory
+    messages = memory.chat_memory.messages
+
+    # Add conversation history to prompt
+    for msg in messages:
+        if msg.type == "human":
+            prompt += f"{msg.content} [/INST] "
+        elif msg.type == "ai":
+            prompt += f"{msg.content} </s><s>[INST] "
 
     # Add current input
     prompt += f"{current_input} [/INST] "
@@ -110,7 +119,7 @@ def build_simple_prompt(system_prompt, conversation_history, current_input):
 
 @spaces.GPU
 def generate_response(message, history):
-    """Generate a response using simple state tracking."""
+    """Generate a response using LangChain ConversationBufferMemory."""
     global conversation_state
 
     # Reset state if this is a new conversation
@@ -122,35 +131,44 @@ def generate_response(message, history):
             'has_name': False,
             'has_age': False
         }
+        # Clear memory for new conversation
+        memory.clear()
+
+    # Save current user message to memory (we'll save bot response later)
+    memory.save_context({"input": message}, {"output": ""})
 
     # Step 1: Ask for name if not provided
     if not conversation_state['has_name']:
         conversation_state['has_name'] = True
-        return "Hello! Before we discuss your health concerns, could you please tell me your name?"
+        bot_response = "Hello! Before we discuss your health concerns, could you please tell me your name?"
+        # Update memory with bot response
+        memory.save_context({"input": message}, {"output": bot_response})
+        return bot_response
 
     # Step 2: Store name and ask for age
     if conversation_state['name'] is None:
         conversation_state['name'] = message.strip()
-        return f"Nice to meet you, {conversation_state['name']}! Could you please tell me your age?"
+        bot_response = f"Nice to meet you, {conversation_state['name']}! Could you please tell me your age?"
+        # Update memory with bot response
+        memory.save_context({"input": message}, {"output": bot_response})
+        return bot_response
 
     # Step 3: Store age and start medical questions
     if not conversation_state['has_age']:
         conversation_state['age'] = message.strip()
         conversation_state['has_age'] = True
-        return f"Thank you, {conversation_state['name']}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
+        bot_response = f"Thank you, {conversation_state['name']}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
+        # Update memory with bot response
+        memory.save_context({"input": message}, {"output": bot_response})
+        return bot_response
 
-    # Step 4: Medical consultation phase
+    # Step 4: Medical consultation phase using ConversationBufferMemory
     conversation_state['medical_turns'] += 1
 
-    # Prepare conversation history for the model
-    medical_history = []
-    if len(history) >= 3:  # Skip name/age exchanges
-        medical_history = history[3:]
-
-    # Build the prompt for medical consultation
+    # Build the prompt using memory for full conversation context
    if conversation_state['medical_turns'] <= 5:
         # Still gathering information - let LLM ask intelligent follow-up questions
-        prompt = build_simple_prompt(SYSTEM_PROMPT, medical_history, message)
+        prompt = build_prompt_with_memory(SYSTEM_PROMPT, message)
 
         # Generate response with intelligent follow-up questions
         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -168,21 +186,31 @@ def generate_response(message, history):
         full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         llama_response = full_response.split('[/INST]')[-1].strip()
 
+        # Save bot response to memory
+        memory.save_context({"input": message}, {"output": llama_response})
+
         return llama_response
 
     else:
         # Time for diagnosis and treatment (after 5+ turns)
-        # Compile patient information
+        # Get all conversation messages from memory
+        all_messages = memory.chat_memory.messages
+
+        # Compile patient information from memory
         patient_info = f"Patient: {conversation_state['name']}, Age: {conversation_state['age']}\n\n"
-        patient_info += "Symptoms and Information:\n"
+        patient_info += "Complete Conversation History:\n"
 
-        # Add all medical conversation history
-        for user_msg, bot_msg in medical_history:
-            patient_info += f"Patient: {user_msg}\n"
-        patient_info += f"Patient: {message}\n"
+        # Add all messages from memory
+        for msg in all_messages:
+            if msg.type == "human":
+                patient_info += f"Patient: {msg.content}\n"
+            elif msg.type == "ai":
+                patient_info += f"Doctor: {msg.content}\n"
 
-        # Generate diagnosis with Llama-2
-        diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on all the information provided, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {conversation_state['name']}.\n\nPatient Information:\n{patient_info} [/INST] "
+        patient_info += f"Current: {message}\n"
+
+        # Generate diagnosis with full conversation context
+        diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on the complete conversation history, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {conversation_state['name']}.\n\nComplete Patient Information:\n{patient_info} [/INST] "
 
         inputs = tokenizer(diagnosis_prompt, return_tensors="pt").to(model.device)
         with torch.no_grad():
@@ -199,12 +227,15 @@ def generate_response(message, history):
         full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         diagnosis = full_response.split('[/INST]')[-1].strip()
 
-        # Get treatment suggestions from Meditron
+        # Get treatment suggestions from Meditron using memory context
         treatment_suggestions = get_meditron_suggestions(patient_info)
 
         # Combine responses
         final_response = f"{diagnosis}\n\n--- TREATMENT RECOMMENDATIONS ---\n\n{treatment_suggestions}\n\n**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice."
 
+        # Save final response to memory
+        memory.save_context({"input": message}, {"output": final_response})
+
         return final_response
 
 # Create the Gradio interface
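For reference, a rough illustration of what the new build_prompt_with_memory produces once turns have been saved — the sample inputs are hypothetical, and SYSTEM_PROMPT stands in for the app's actual system prompt:

# Hypothetical round trip: one stored exchange, then a new user turn.
memory.clear()
memory.save_context(
    {"input": "My name is Alex."},
    {"output": "Nice to meet you, Alex! Could you please tell me your age?"},
)
prompt = build_prompt_with_memory(SYSTEM_PROMPT, "I am 30.")

# prompt now looks roughly like:
# <s>[INST] <<SYS>>
# ...system prompt text...
# <</SYS>>
#
# My name is Alex. [/INST] Nice to meet you, Alex! Could you please tell me your age? </s><s>[INST] I am 30. [/INST]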