Thanush committed on
Commit
43e5827
·
1 Parent(s): a7f6391

Refactor app.py to streamline user information collection by removing redundant prompts for name and age. Implement a simple state tracking mechanism for improved conversation flow and enhance medical consultation process with structured follow-up questions.

Browse files
Files changed (1) hide show
  1. app.py +126 -189
app.py CHANGED
@@ -9,9 +9,7 @@ import re
9
  LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
10
  MEDITRON_MODEL = "epfl-llm/meditron-7b"
11
 
12
- SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's name, age, health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
13
-
14
- Always begin by asking for the user's name and age if not already provided.
15
 
16
  **IMPORTANT** Ask 1-2 follow-up questions at a time to gather more details about:
17
  - Detailed description of symptoms
@@ -22,7 +20,7 @@ Always begin by asking for the user's name and age if not already provided.
22
  - Medical history
23
  - Current medications and allergies
24
 
25
- After collecting sufficient information (at least 4-5 exchanges, but continue up to 10 if the user keeps responding), summarize findings, provide a likely diagnosis (if possible), and suggest when they should seek professional care.
26
 
27
  If enough information is collected, provide a concise, general diagnosis and a practical over-the-counter medicine and home remedy suggestion.
28
 
@@ -67,33 +65,14 @@ meditron_model = AutoModelForCausalLM.from_pretrained(
67
  )
68
  print("Meditron model loaded successfully!")
69
 
70
- # Initialize LangChain memory
71
- memory = ConversationBufferMemory(return_messages=True)
72
-
73
- def build_llama2_prompt(system_prompt, messages, user_input, followup_stage=None):
74
- prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
75
- for msg in messages:
76
- if msg.type == "human":
77
- prompt += f"{msg.content} [/INST] "
78
- elif msg.type == "ai":
79
- prompt += f"{msg.content} </s><s>[INST] "
80
-
81
- # Add a specific follow-up question if in followup stage
82
- if followup_stage is not None:
83
- followup_questions = [
84
- "Can you describe your main symptoms in more detail? What exactly are you experiencing?",
85
- "How long have you been experiencing these symptoms? When did they first start?",
86
- "On a scale of 1-10, how would you rate the severity of your symptoms?",
87
- "Have you noticed anything that makes your symptoms better or worse? Any triggers or relief factors?",
88
- "Do you have any other related symptoms, such as fever, fatigue, nausea, or changes in appetite?"
89
- ]
90
- if followup_stage < len(followup_questions):
91
- prompt += f"{followup_questions[followup_stage]} [/INST] "
92
- else:
93
- prompt += f"{user_input} [/INST] "
94
- else:
95
- prompt += f"{user_input} [/INST] "
96
- return prompt
97
 
98
  def get_meditron_suggestions(patient_info):
99
  """Use Meditron model to generate medicine and remedy suggestions."""
@@ -113,183 +92,141 @@ def get_meditron_suggestions(patient_info):
113
  suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
114
  return suggestion
115
 
116
- def extract_name_age_intelligent(text):
117
- """Intelligently extract name and age from user input using multiple patterns."""
118
- name, age = None, None
119
- text_lower = text.lower().strip()
120
-
121
- # Age extraction patterns (more comprehensive)
122
- age_patterns = [
123
- r'(?:i am|i\'m|im|age is|aged|my age is|years old|year old)\s*(\d{1,3})',
124
- r'(\d{1,3})\s*(?:years old|year old|yrs old|yr old)',
125
- r'\b(\d{1,3})\s*(?:and|,)?\s*(?:years|yrs|y\.o\.)',
126
- r'(?:^|\s)(\d{1,3})(?:\s|$)', # standalone numbers
127
- ]
128
-
129
- for pattern in age_patterns:
130
- match = re.search(pattern, text_lower)
131
- if match:
132
- potential_age = int(match.group(1))
133
- if 1 <= potential_age <= 120: # reasonable age range
134
- age = str(potential_age)
135
- break
136
-
137
- # Name extraction patterns (more comprehensive)
138
- name_patterns = [
139
- r'(?:my name is|name is|i am|i\'m|im|call me|this is)\s+([a-zA-Z][a-zA-Z\s]{1,20}?)(?:\s+and|\s+\d|\s*$)',
140
- r'^([a-zA-Z][a-zA-Z\s]{1,20}?)\s+(?:and|,)?\s*\d', # name followed by number
141
- r'(?:^|\s)([a-zA-Z]{2,15})(?:\s+and|\s+\d)', # simple name pattern
142
- ]
143
-
144
- for pattern in name_patterns:
145
- match = re.search(pattern, text_lower)
146
- if match:
147
- potential_name = match.group(1).strip().title()
148
- # Filter out common non-name words
149
- non_names = ['it', 'is', 'am', 'my', 'me', 'the', 'and', 'or', 'but', 'yes', 'no']
150
- if potential_name.lower() not in non_names and len(potential_name) >= 2:
151
- name = potential_name
152
- break
153
-
154
- # Special case: handle "thanush and 23" or "it thanush and im 23" patterns
155
- special_patterns = [
156
- r'(?:it\s+)?([a-zA-Z]{2,15})\s+and\s+(?:im\s+|i\'m\s+)?(\d{1,3})',
157
- r'([a-zA-Z]{2,15})\s+and\s+(\d{1,3})',
158
- ]
159
-
160
- for pattern in special_patterns:
161
- match = re.search(pattern, text_lower)
162
- if match:
163
- potential_name = match.group(1).strip().title()
164
- potential_age = int(match.group(2))
165
- if potential_name.lower() not in ['it', 'is', 'am'] and 1 <= potential_age <= 120:
166
- name = potential_name
167
- age = str(potential_age)
168
- break
169
-
170
- return name, age
171
-
172
- def extract_name_age_from_all_messages(messages):
173
- """Extract name and age from all conversation messages."""
174
- name, age = None, None
175
 
176
- for msg in messages:
177
- if msg.type == "human":
178
- extracted_name, extracted_age = extract_name_age_intelligent(msg.content)
179
- if extracted_name and not name:
180
- name = extracted_name
181
- if extracted_age and not age:
182
- age = extracted_age
183
 
184
- return name, age
185
-
186
- def is_medical_symptom_message(text):
187
- """Check if the message contains medical symptoms rather than just name/age."""
188
- medical_keywords = [
189
- 'hurt', 'pain', 'ache', 'sick', 'fever', 'cough', 'headache', 'stomach', 'throat',
190
- 'nausea', 'dizzy', 'tired', 'fatigue', 'breathe', 'chest', 'back', 'leg', 'arm',
191
- 'symptom', 'feel', 'suffering', 'problem', 'issue', 'uncomfortable', 'sore'
192
- ]
193
 
194
- text_lower = text.lower()
195
- return any(keyword in text_lower for keyword in medical_keywords)
196
 
197
  @spaces.GPU
198
  def generate_response(message, history):
199
- """Generate a response using both models, with full context."""
200
- # Save the latest user message and last assistant response to memory
201
- if history and len(history[-1]) == 2:
202
- memory.save_context({"input": history[-1][0]}, {"output": history[-1][1]})
203
- memory.save_context({"input": message}, {"output": ""})
204
-
205
- messages = memory.chat_memory.messages
206
-
207
- # Extract name and age from all messages
208
- name, age = extract_name_age_from_all_messages(messages)
209
-
210
- # Check what information is missing
211
- missing_info = []
212
- if not name:
213
- missing_info.append("your name")
214
- if not age:
215
- missing_info.append("your age")
216
-
217
- # If missing basic info, ask for it
218
- if missing_info:
219
- ask = "Hello! Before we discuss your health concerns, could you please tell me " + " and ".join(missing_info) + "?"
220
- return ask
221
-
222
- # Count meaningful medical information exchanges (exclude name/age only messages)
223
- medical_info_turns = 0
224
- for msg in messages:
225
- if msg.type == "human":
226
- # Count only if it's not just name/age info and contains medical content
227
- if is_medical_symptom_message(msg.content) or not any(keyword in msg.content.lower() for keyword in ['name', 'age', 'years', 'old', 'im', 'i am']):
228
- medical_info_turns += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
- # Ensure we have at least one medical symptom mentioned
231
- if medical_info_turns == 0 and not is_medical_symptom_message(message):
232
- return f"Thank you, {name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
- # Ask up to 5 intelligent follow-up questions, then provide diagnosis and treatment
235
- if medical_info_turns < 5:
236
- prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message, followup_stage=medical_info_turns)
237
  else:
238
- # Time for final diagnosis and treatment recommendations
239
- prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
240
- prompt = prompt.replace("[/INST] ", "[/INST] Based on all the information provided, please provide a comprehensive assessment including: 1) Most likely diagnosis, 2) Recommended next steps, and 3) When to seek immediate medical attention. ")
241
-
242
- # Generate response using Llama-2
243
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
244
- with torch.no_grad():
245
- outputs = model.generate(
246
- inputs.input_ids,
247
- attention_mask=inputs.attention_mask,
248
- max_new_tokens=512,
249
- temperature=0.7,
250
- top_p=0.9,
251
- do_sample=True,
252
- pad_token_id=tokenizer.eos_token_id
253
- )
254
-
255
- full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
256
- llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
257
-
258
- # After 5 medical info turns, add Meditron suggestions
259
- if medical_info_turns >= 4: # Start suggesting after 4+ turns
260
- # Compile patient information for Meditron
261
- patient_summary = f"Patient: {name}, Age: {age}\n\n"
262
- patient_summary += "Medical Information:\n"
263
 
264
- for msg in messages:
265
- if msg.type == "human" and is_medical_symptom_message(msg.content):
266
- patient_summary += f"- {msg.content}\n"
 
267
 
268
- patient_summary += f"\nLatest input: {message}\n"
269
- patient_summary += f"\nInitial Assessment: {llama_response}"
270
 
271
- # Get Meditron suggestions
272
- medicine_suggestions = get_meditron_suggestions(patient_summary)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
- final_response = (
275
- f"{llama_response}\n\n"
276
- f"--- MEDICATION AND HOME CARE RECOMMENDATIONS ---\n\n"
277
- f"{medicine_suggestions}\n\n"
278
- f"**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice, especially if symptoms persist or worsen."
279
- )
280
  return final_response
281
 
282
- return llama_response
283
-
284
  # Create the Gradio interface
285
  demo = gr.ChatInterface(
286
  fn=generate_response,
287
- title="🩺 AI Medical Assistant with Treatment Suggestions",
288
- description="Describe your symptoms and I'll gather information to provide medical insights and treatment recommendations.",
289
  examples=[
290
- "Hi, I'm Sarah and I'm 25. I have a persistent cough and sore throat.",
291
- "My name is John, I'm 35, and I've been having severe headaches.",
292
- "I'm Lisa, 28 years old, and my stomach has been hurting since yesterday."
293
  ],
294
  theme="soft"
295
  )
 
9
  LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
10
  MEDITRON_MODEL = "epfl-llm/meditron-7b"
11
 
12
+ SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
 
 
13
 
14
  **IMPORTANT** Ask 1-2 follow-up questions at a time to gather more details about:
15
  - Detailed description of symptoms
 
20
  - Medical history
21
  - Current medications and allergies
22
 
23
+ After collecting sufficient information, summarize findings, provide a likely diagnosis (if possible), and suggest when they should seek professional care.
24
 
25
  If enough information is collected, provide a concise, general diagnosis and a practical over-the-counter medicine and home remedy suggestion.
26
 
 
65
  )
66
  print("Meditron model loaded successfully!")
67
 
68
+ # Simple conversation state tracking
69
+ conversation_state = {
70
+ 'name': None,
71
+ 'age': None,
72
+ 'medical_turns': 0,
73
+ 'has_name': False,
74
+ 'has_age': False
75
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def get_meditron_suggestions(patient_info):
78
  """Use Meditron model to generate medicine and remedy suggestions."""
 
92
  suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
93
  return suggestion
94
 
95
+ def build_simple_prompt(system_prompt, conversation_history, current_input):
96
+ """Build a simple prompt for Llama-2"""
97
+ prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ # Add conversation history
100
+ for i, (user_msg, bot_msg) in enumerate(conversation_history):
101
+ prompt += f"{user_msg} [/INST] {bot_msg} </s><s>[INST] "
 
 
 
 
102
 
103
+ # Add current input
104
+ prompt += f"{current_input} [/INST] "
 
 
 
 
 
 
 
105
 
106
+ return prompt
 
107
 
108
  @spaces.GPU
109
  def generate_response(message, history):
110
+ """Generate a response using simple state tracking."""
111
+ global conversation_state
112
+
113
+ # Reset state if this is a new conversation
114
+ if not history:
115
+ conversation_state = {
116
+ 'name': None,
117
+ 'age': None,
118
+ 'medical_turns': 0,
119
+ 'has_name': False,
120
+ 'has_age': False
121
+ }
122
+
123
+ # Step 1: Ask for name if not provided
124
+ if not conversation_state['has_name']:
125
+ conversation_state['has_name'] = True
126
+ return "Hello! Before we discuss your health concerns, could you please tell me your name?"
127
+
128
+ # Step 2: Store name and ask for age
129
+ if conversation_state['name'] is None:
130
+ conversation_state['name'] = message.strip()
131
+ return f"Nice to meet you, {conversation_state['name']}! Could you please tell me your age?"
132
+
133
+ # Step 3: Store age and start medical questions
134
+ if not conversation_state['has_age']:
135
+ conversation_state['age'] = message.strip()
136
+ conversation_state['has_age'] = True
137
+ return f"Thank you, {conversation_state['name']}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
138
+
139
+ # Step 4: Medical consultation phase
140
+ conversation_state['medical_turns'] += 1
141
+
142
+ # Prepare conversation history for the model
143
+ medical_history = []
144
+ if len(history) >= 3: # Skip name/age exchanges
145
+ medical_history = history[3:]
146
+
147
+ # Define follow-up questions based on turn number
148
+ followup_questions = [
149
+ "Can you describe your symptoms in more detail? What exactly are you experiencing?",
150
+ "How long have you been experiencing these symptoms? When did they first start?",
151
+ "On a scale of 1-10, how would you rate the severity of your symptoms?",
152
+ "Have you noticed anything that makes your symptoms better or worse?",
153
+ "Do you have any other symptoms, medical history, or are you taking any medications?"
154
+ ]
155
 
156
+ # Build the prompt for medical consultation
157
+ if conversation_state['medical_turns'] <= 5:
158
+ # Still gathering information
159
+ prompt = build_simple_prompt(SYSTEM_PROMPT, medical_history, message)
160
+
161
+ # Generate response
162
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
163
+ with torch.no_grad():
164
+ outputs = model.generate(
165
+ inputs.input_ids,
166
+ attention_mask=inputs.attention_mask,
167
+ max_new_tokens=256,
168
+ temperature=0.7,
169
+ top_p=0.9,
170
+ do_sample=True,
171
+ pad_token_id=tokenizer.eos_token_id
172
+ )
173
+
174
+ full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
175
+ llama_response = full_response.split('[/INST]')[-1].strip()
176
+
177
+ # Add a specific follow-up question
178
+ if conversation_state['medical_turns'] < len(followup_questions):
179
+ next_question = followup_questions[conversation_state['medical_turns']]
180
+ return f"{llama_response}\n\n{next_question}"
181
+ else:
182
+ return llama_response
183
 
 
 
 
184
  else:
185
+ # Time for diagnosis and treatment (after 5+ turns)
186
+ # Compile patient information
187
+ patient_info = f"Patient: {conversation_state['name']}, Age: {conversation_state['age']}\n\n"
188
+ patient_info += "Symptoms and Information:\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ # Add all medical conversation history
191
+ for user_msg, bot_msg in medical_history:
192
+ patient_info += f"Patient: {user_msg}\n"
193
+ patient_info += f"Patient: {message}\n"
194
 
195
+ # Generate diagnosis with Llama-2
196
+ diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on all the information provided, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {conversation_state['name']}.\n\nPatient Information:\n{patient_info} [/INST] "
197
 
198
+ inputs = tokenizer(diagnosis_prompt, return_tensors="pt").to(model.device)
199
+ with torch.no_grad():
200
+ outputs = model.generate(
201
+ inputs.input_ids,
202
+ attention_mask=inputs.attention_mask,
203
+ max_new_tokens=384,
204
+ temperature=0.7,
205
+ top_p=0.9,
206
+ do_sample=True,
207
+ pad_token_id=tokenizer.eos_token_id
208
+ )
209
+
210
+ full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
211
+ diagnosis = full_response.split('[/INST]')[-1].strip()
212
+
213
+ # Get treatment suggestions from Meditron
214
+ treatment_suggestions = get_meditron_suggestions(patient_info)
215
+
216
+ # Combine responses
217
+ final_response = f"{diagnosis}\n\n--- TREATMENT RECOMMENDATIONS ---\n\n{treatment_suggestions}\n\n**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice."
218
 
 
 
 
 
 
 
219
  return final_response
220
 
 
 
221
  # Create the Gradio interface
222
  demo = gr.ChatInterface(
223
  fn=generate_response,
224
+ title="🩺 AI Medical Assistant",
225
+ description="I'll ask for your basic information first, then gather details about your symptoms to provide medical insights.",
226
  examples=[
227
+ "I have a persistent cough",
228
+ "I've been having headaches",
229
+ "My stomach hurts"
230
  ],
231
  theme="soft"
232
  )