Thanush commited on
Commit
0b85ef5
Β·
1 Parent(s): 10736b1

Enhance medical consultation app with LangChain memory management and improved patient context tracking

Browse files
Files changed (1) hide show
  1. app.py +202 -35
app.py CHANGED
@@ -2,9 +2,13 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import spaces
 
 
 
 
5
 
6
  # Model configuration - Using correct Me-LLaMA model identifier
7
- ME_LLAMA_MODEL = "clinicalnlplab/me-llama-13b" # Corrected model name
8
 
9
  # System prompts for different phases
10
  CONSULTATION_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
@@ -26,7 +30,8 @@ MEDICINE_PROMPT = """You are a specialized medical assistant. Based on the patie
26
 
27
  Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
28
 
29
- Patient information: {patient_info}"""
 
30
 
31
  # Global variables
32
  me_llama_model = None
@@ -34,13 +39,121 @@ me_llama_tokenizer = None
34
  conversation_turns = 0
35
  patient_data = []
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def build_me_llama_prompt(system_prompt, history, user_input):
38
- """Format the conversation for Me-LLaMA chat model."""
 
 
 
 
 
 
39
  # Use standard Llama-2 chat format since Me-LLaMA is based on Llama-2
40
- prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
41
 
42
- # Add conversation history
43
- for user_msg, assistant_msg in history:
 
44
  prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
45
 
46
  # Add the current user input
@@ -80,12 +193,12 @@ def load_model_if_needed():
80
  print("Fallback model loaded successfully!")
81
 
82
  @spaces.GPU
83
- def generate_medicine_suggestions(patient_info):
84
- """Use Me-LLaMA to generate medicine and remedy suggestions."""
85
  load_model_if_needed()
86
 
87
- # Create a simple prompt for medicine suggestions
88
- prompt = f"<s>[INST] {MEDICINE_PROMPT.format(patient_info=patient_info)} [/INST] "
89
 
90
  inputs = me_llama_tokenizer(prompt, return_tensors="pt")
91
 
@@ -109,7 +222,7 @@ def generate_medicine_suggestions(patient_info):
109
 
110
  @spaces.GPU
111
  def generate_response(message, history):
112
- """Generate response using only Me-LLaMA for both consultation and medicine suggestions."""
113
  global conversation_turns, patient_data
114
 
115
  try:
@@ -119,12 +232,12 @@ def generate_response(message, history):
119
  # Track conversation turns
120
  conversation_turns += 1
121
 
122
- # Store patient data
123
  patient_data.append(message)
124
 
125
- # Phase 1-3: Information gathering
126
  if conversation_turns < 4:
127
- # Build consultation prompt
128
  prompt = build_me_llama_prompt(CONSULTATION_PROMPT, history, message)
129
 
130
  inputs = me_llama_tokenizer(prompt, return_tensors="pt")
@@ -149,13 +262,20 @@ def generate_response(message, history):
149
  full_response = me_llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
150
  response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
151
 
 
 
 
152
  return response
153
 
154
- # Phase 4+: Summary and medicine suggestions
155
  else:
156
- # First, get summary from consultation
 
 
 
 
157
  summary_prompt = build_me_llama_prompt(
158
- CONSULTATION_PROMPT + "\n\nNow summarize what you've learned and suggest when professional care may be needed.",
159
  history,
160
  message
161
  )
@@ -180,34 +300,81 @@ def generate_response(message, history):
180
  summary_response = me_llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
181
  summary = summary_response.split('[/INST]')[-1].split('</s>')[0].strip()
182
 
183
- # Then get medicine suggestions using the same model
184
- full_patient_info = "\n".join(patient_data) + f"\n\nMedical Summary: {summary}"
185
- medicine_suggestions = generate_medicine_suggestions(full_patient_info)
186
 
187
  # Combine both responses
188
  final_response = (
189
- f"**MEDICAL SUMMARY:**\n{summary}\n\n"
190
  f"**MEDICATION AND HOME CARE SUGGESTIONS:**\n{medicine_suggestions}\n\n"
 
191
  f"**DISCLAIMER:** This is AI-generated advice for informational purposes only. Please consult a licensed healthcare provider for proper medical diagnosis and treatment."
192
  )
193
 
 
 
 
194
  return final_response
195
 
196
  except Exception as e:
197
- return f"I apologize, but I'm experiencing technical difficulties. Please try again. Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- # Create the Gradio interface
200
- demo = gr.ChatInterface(
201
- fn=generate_response,
202
- title="πŸ₯ Complete Medical Assistant - Me-LLaMA 13B",
203
- description="Comprehensive medical consultation powered by Me-LLaMA 13B. One model handles both consultation and medicine suggestions. Tell me about your symptoms!",
204
- examples=[
205
- "I have a persistent cough and sore throat for 3 days",
206
- "I've been having severe headaches and feel dizzy",
207
- "My stomach hurts and I feel nauseous after eating"
208
- ],
209
- theme="soft"
210
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  if __name__ == "__main__":
213
- demo.launch()
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import spaces
5
+ from langchain.memory import ConversationBufferWindowMemory
6
+ from langchain.schema import HumanMessage, AIMessage
7
+ import json
8
+ from datetime import datetime
9
 
10
  # Model configuration - Using correct Me-LLaMA model identifier
11
+ ME_LLAMA_MODEL = "clinicalnlplab/me-llama-13b"
12
 
13
  # System prompts for different phases
14
  CONSULTATION_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
 
30
 
31
  Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
32
 
33
+ Patient information: {patient_info}
34
+ Previous conversation context: {memory_context}"""
35
 
36
  # Global variables
37
  me_llama_model = None
 
39
  conversation_turns = 0
40
  patient_data = []
41
 
42
+ # LangChain Memory Configuration
43
+ class MedicalMemoryManager:
44
+ def __init__(self, k=10): # Keep last 10 conversation turns
45
+ self.conversation_memory = ConversationBufferWindowMemory(k=k, return_messages=True)
46
+ self.patient_context = {
47
+ "symptoms": [],
48
+ "medical_history": [],
49
+ "medications": [],
50
+ "allergies": [],
51
+ "lifestyle_factors": [],
52
+ "timeline": [],
53
+ "severity_scores": {},
54
+ "session_start": datetime.now().isoformat()
55
+ }
56
+
57
+ def add_interaction(self, human_input, ai_response):
58
+ """Add human-AI interaction to memory"""
59
+ self.conversation_memory.chat_memory.add_user_message(human_input)
60
+ self.conversation_memory.chat_memory.add_ai_message(ai_response)
61
+
62
+ # Extract and categorize medical information
63
+ self._extract_medical_info(human_input)
64
+
65
+ def _extract_medical_info(self, user_input):
66
+ """Extract medical information from user input and categorize it"""
67
+ user_lower = user_input.lower()
68
+
69
+ # Extract symptoms (simple keyword matching)
70
+ symptom_keywords = ["pain", "ache", "hurt", "sore", "cough", "fever", "nausea",
71
+ "headache", "dizzy", "tired", "fatigue", "vomit", "swollen",
72
+ "rash", "itch", "burn", "cramp", "bleed", "shortness of breath"]
73
+
74
+ for keyword in symptom_keywords:
75
+ if keyword in user_lower and keyword not in [s.lower() for s in self.patient_context["symptoms"]]:
76
+ self.patient_context["symptoms"].append(user_input)
77
+ break
78
+
79
+ # Extract timeline information
80
+ time_keywords = ["days", "weeks", "months", "hours", "yesterday", "today", "started", "began"]
81
+ if any(keyword in user_lower for keyword in time_keywords):
82
+ self.patient_context["timeline"].append(user_input)
83
+
84
+ # Extract severity (look for numbers 1-10)
85
+ import re
86
+ severity_match = re.search(r'\b([1-9]|10)\b.*(?:pain|severity|scale)', user_lower)
87
+ if severity_match:
88
+ self.patient_context["severity_scores"][datetime.now().isoformat()] = severity_match.group(1)
89
+
90
+ # Extract medications
91
+ med_keywords = ["taking", "medication", "medicine", "pills", "prescribed", "drug"]
92
+ if any(keyword in user_lower for keyword in med_keywords):
93
+ self.patient_context["medications"].append(user_input)
94
+
95
+ # Extract allergies
96
+ allergy_keywords = ["allergic", "allergy", "allergies", "reaction"]
97
+ if any(keyword in user_lower for keyword in allergy_keywords):
98
+ self.patient_context["allergies"].append(user_input)
99
+
100
+ def get_memory_context(self):
101
+ """Get formatted memory context for the model"""
102
+ messages = self.conversation_memory.chat_memory.messages
103
+ context = []
104
+
105
+ for msg in messages[-6:]: # Last 6 messages (3 exchanges)
106
+ if isinstance(msg, HumanMessage):
107
+ context.append(f"Patient: {msg.content}")
108
+ elif isinstance(msg, AIMessage):
109
+ context.append(f"Doctor: {msg.content}")
110
+
111
+ return "\n".join(context)
112
+
113
+ def get_patient_summary(self):
114
+ """Get structured patient information summary"""
115
+ summary = {
116
+ "conversation_turns": len(self.conversation_memory.chat_memory.messages) // 2,
117
+ "session_duration": datetime.now().isoformat(),
118
+ "key_symptoms": self.patient_context["symptoms"][-3:], # Last 3 symptoms mentioned
119
+ "timeline_info": self.patient_context["timeline"][-2:], # Last 2 timeline mentions
120
+ "medications": self.patient_context["medications"],
121
+ "allergies": self.patient_context["allergies"],
122
+ "severity_scores": self.patient_context["severity_scores"]
123
+ }
124
+ return json.dumps(summary, indent=2)
125
+
126
+ def reset_session(self):
127
+ """Reset memory for new consultation"""
128
+ self.conversation_memory.clear()
129
+ self.patient_context = {
130
+ "symptoms": [],
131
+ "medical_history": [],
132
+ "medications": [],
133
+ "allergies": [],
134
+ "lifestyle_factors": [],
135
+ "timeline": [],
136
+ "severity_scores": {},
137
+ "session_start": datetime.now().isoformat()
138
+ }
139
+
140
+ # Initialize memory manager
141
+ memory_manager = MedicalMemoryManager()
142
+
143
  def build_me_llama_prompt(system_prompt, history, user_input):
144
+ """Format the conversation for Me-LLaMA chat model with memory context."""
145
+ # Get memory context from LangChain
146
+ memory_context = memory_manager.get_memory_context()
147
+
148
+ # Enhance system prompt with memory context
149
+ enhanced_system_prompt = f"{system_prompt}\n\nPrevious conversation context:\n{memory_context}"
150
+
151
  # Use standard Llama-2 chat format since Me-LLaMA is based on Llama-2
152
+ prompt = f"<s>[INST] <<SYS>>\n{enhanced_system_prompt}\n<</SYS>>\n\n"
153
 
154
+ # Add only recent history to avoid token limit issues
155
+ recent_history = history[-3:] if len(history) > 3 else history
156
+ for user_msg, assistant_msg in recent_history:
157
  prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
158
 
159
  # Add the current user input
 
193
  print("Fallback model loaded successfully!")
194
 
195
  @spaces.GPU
196
+ def generate_medicine_suggestions(patient_info, memory_context):
197
+ """Use Me-LLaMA to generate medicine and remedy suggestions with memory context."""
198
  load_model_if_needed()
199
 
200
+ # Create a prompt with both patient info and memory context
201
+ prompt = f"<s>[INST] {MEDICINE_PROMPT.format(patient_info=patient_info, memory_context=memory_context)} [/INST] "
202
 
203
  inputs = me_llama_tokenizer(prompt, return_tensors="pt")
204
 
 
222
 
223
  @spaces.GPU
224
  def generate_response(message, history):
225
+ """Generate response using Me-LLaMA with LangChain memory management."""
226
  global conversation_turns, patient_data
227
 
228
  try:
 
232
  # Track conversation turns
233
  conversation_turns += 1
234
 
235
+ # Store patient data (legacy support)
236
  patient_data.append(message)
237
 
238
+ # Phase 1-3: Information gathering with memory
239
  if conversation_turns < 4:
240
+ # Build consultation prompt with memory context
241
  prompt = build_me_llama_prompt(CONSULTATION_PROMPT, history, message)
242
 
243
  inputs = me_llama_tokenizer(prompt, return_tensors="pt")
 
262
  full_response = me_llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
263
  response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
264
 
265
+ # Add interaction to memory
266
+ memory_manager.add_interaction(message, response)
267
+
268
  return response
269
 
270
+ # Phase 4+: Summary and medicine suggestions with full memory context
271
  else:
272
+ # Get comprehensive patient summary from memory
273
+ patient_summary = memory_manager.get_patient_summary()
274
+ memory_context = memory_manager.get_memory_context()
275
+
276
+ # First, get summary from consultation with memory context
277
  summary_prompt = build_me_llama_prompt(
278
+ CONSULTATION_PROMPT + "\n\nNow provide a comprehensive summary based on all the information gathered. Include when professional care may be needed.",
279
  history,
280
  message
281
  )
 
300
  summary_response = me_llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
301
  summary = summary_response.split('[/INST]')[-1].split('</s>')[0].strip()
302
 
303
+ # Get medicine suggestions using memory context
304
+ full_patient_info = f"Patient Summary: {patient_summary}\n\nDetailed Summary: {summary}"
305
+ medicine_suggestions = generate_medicine_suggestions(full_patient_info, memory_context)
306
 
307
  # Combine both responses
308
  final_response = (
309
+ f"**COMPREHENSIVE MEDICAL SUMMARY:**\n{summary}\n\n"
310
  f"**MEDICATION AND HOME CARE SUGGESTIONS:**\n{medicine_suggestions}\n\n"
311
+ f"**PATIENT CONTEXT SUMMARY:**\n{patient_summary}\n\n"
312
  f"**DISCLAIMER:** This is AI-generated advice for informational purposes only. Please consult a licensed healthcare provider for proper medical diagnosis and treatment."
313
  )
314
 
315
+ # Add final interaction to memory
316
+ memory_manager.add_interaction(message, final_response)
317
+
318
  return final_response
319
 
320
  except Exception as e:
321
+ error_msg = f"I apologize, but I'm experiencing technical difficulties. Please try again. Error: {str(e)}"
322
+ # Still try to add to memory even on error
323
+ try:
324
+ memory_manager.add_interaction(message, error_msg)
325
+ except:
326
+ pass
327
+ return error_msg
328
+
329
+ def reset_consultation():
330
+ """Reset the consultation and memory for a new patient."""
331
+ global conversation_turns, patient_data, memory_manager
332
+
333
+ conversation_turns = 0
334
+ patient_data = []
335
+ memory_manager.reset_session()
336
+
337
+ return "New consultation started. Please tell me about your symptoms or health concerns."
338
 
339
+ # Create the Gradio interface with memory reset option
340
+ with gr.Blocks(theme="soft") as demo:
341
+ gr.Markdown("# πŸ₯ Complete Medical Assistant - Me-LLaMA 13B with Memory")
342
+ gr.Markdown("Comprehensive medical consultation powered by Me-LLaMA 13B with LangChain memory management. One model handles both consultation and medicine suggestions with full context awareness.")
343
+
344
+ with gr.Row():
345
+ with gr.Column(scale=4):
346
+ chatbot = gr.Chatbot(height=500)
347
+ msg = gr.Textbox(
348
+ placeholder="Tell me about your symptoms or health concerns...",
349
+ label="Your Message"
350
+ )
351
+
352
+ with gr.Column(scale=1):
353
+ reset_btn = gr.Button("πŸ”„ Start New Consultation", variant="secondary")
354
+ gr.Markdown("**Memory Features:**\n- Tracks symptoms & timeline\n- Remembers medications & allergies\n- Maintains conversation context\n- Provides comprehensive summaries")
355
+
356
+ # Examples
357
+ gr.Examples(
358
+ examples=[
359
+ "I have a persistent cough and sore throat for 3 days",
360
+ "I've been having severe headaches and feel dizzy",
361
+ "My stomach hurts and I feel nauseous after eating"
362
+ ],
363
+ inputs=msg
364
+ )
365
+
366
+ # Event handlers
367
+ def respond(message, chat_history):
368
+ bot_message = generate_response(message, chat_history)
369
+ chat_history.append((message, bot_message))
370
+ return "", chat_history
371
+
372
+ def reset_chat():
373
+ reset_msg = reset_consultation()
374
+ return [(None, reset_msg)], ""
375
+
376
+ msg.submit(respond, [msg, chatbot], [msg, chatbot])
377
+ reset_btn.click(reset_chat, [], [chatbot, msg])
378
 
379
  if __name__ == "__main__":
380
+ demo.launch()