Thanush commited on
Commit
a7f6391
·
1 Parent(s): 5522bf8

Enhance user interaction in app.py by refining follow-up questions for symptom collection and implementing intelligent extraction of user name and age from messages. Improve response generation logic to ensure comprehensive medical assessments and treatment recommendations.

Browse files
Files changed (1) hide show
  1. app.py +132 -47
app.py CHANGED
@@ -77,14 +77,15 @@ def build_llama2_prompt(system_prompt, messages, user_input, followup_stage=None
77
  prompt += f"{msg.content} [/INST] "
78
  elif msg.type == "ai":
79
  prompt += f"{msg.content} </s><s>[INST] "
 
80
  # Add a specific follow-up question if in followup stage
81
  if followup_stage is not None:
82
  followup_questions = [
83
- "Can you describe your main symptoms in detail?",
84
- "How long have you been experiencing these symptoms?",
85
- "On a scale of 1-10, how severe are your symptoms?",
86
- "Have you noticed anything that makes your symptoms better or worse?",
87
- "Do you have any other related symptoms, such as fever, fatigue, or shortness of breath?"
88
  ]
89
  if followup_stage < len(followup_questions):
90
  prompt += f"{followup_questions[followup_stage]} [/INST] "
@@ -112,27 +113,87 @@ def get_meditron_suggestions(patient_info):
112
  suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
113
  return suggestion
114
 
115
- def extract_name_age(messages):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  name, age = None, None
 
117
  for msg in messages:
118
  if msg.type == "human":
119
- # Try to extract age
120
- age_match = re.search(r"(?:I am|I'm|age is|aged|My age is|im|i'm)\s*(\d{1,3})", msg.content, re.IGNORECASE)
121
- if age_match and not age:
122
- age = age_match.group(1)
123
- # Try to extract name (avoid matching 'I'm' as name if age is present)
124
- name_match = re.search(r"my name is\s*([A-Za-z]+)", msg.content, re.IGNORECASE)
125
- if name_match and not name:
126
- name = name_match.group(1)
127
- # Fallback: if user says 'I'm <name> and <age>'
128
- fallback_match = re.search(r"i['’`]?m\s*([A-Za-z]+)\s*(?:and|,)?\s*(\d{1,3})", msg.content, re.IGNORECASE)
129
- if fallback_match:
130
- if not name:
131
- name = fallback_match.group(1)
132
- if not age:
133
- age = fallback_match.group(2)
134
  return name, age
135
 
 
 
 
 
 
 
 
 
 
 
 
136
  @spaces.GPU
137
  def generate_response(message, history):
138
  """Generate a response using both models, with full context."""
@@ -142,31 +203,43 @@ def generate_response(message, history):
142
  memory.save_context({"input": message}, {"output": ""})
143
 
144
  messages = memory.chat_memory.messages
145
- name, age = extract_name_age(messages)
 
 
 
 
146
  missing_info = []
147
  if not name:
148
  missing_info.append("your name")
149
  if not age:
150
  missing_info.append("your age")
 
 
151
  if missing_info:
152
- ask = "Before we continue, could you please tell me " + " and ".join(missing_info) + "?"
153
  return ask
154
-
155
- # Count how many user turns have actually provided new info (not just name/age)
156
- info_turns = 0
157
  for msg in messages:
158
  if msg.type == "human":
159
- # Ignore turns that only provide name/age
160
- if not re.fullmatch(r".*(name|age|years? old|I'm|I am|my name is).*", msg.content, re.IGNORECASE):
161
- info_turns += 1
162
-
163
- # Ask up to 5 intelligent follow-up questions, then summarize/diagnose
164
- if info_turns < 5:
165
- prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message, followup_stage=info_turns)
 
 
 
 
166
  else:
 
167
  prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
168
- prompt = prompt.replace("[/INST] ", "[/INST] Now, based on all the information, provide a likely diagnosis (if possible), and suggest when professional care may be needed. ")
169
 
 
170
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
171
  with torch.no_grad():
172
  outputs = model.generate(
@@ -178,19 +251,31 @@ def generate_response(message, history):
178
  do_sample=True,
179
  pad_token_id=tokenizer.eos_token_id
180
  )
 
181
  full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
182
  llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
183
 
184
- # After 5 info turns, add medicine suggestions from Meditron, but only once
185
- if info_turns == 5:
186
- full_patient_info = "\n".join([
187
- m.content for m in messages if m.type == "human" and not re.fullmatch(r".*(name|age|years? old|I'm|I am|my name is).*", m.content, re.IGNORECASE)
188
- ] + [message]) + "\n\nSummary: " + llama_response
189
- medicine_suggestions = get_meditron_suggestions(full_patient_info)
 
 
 
 
 
 
 
 
 
 
190
  final_response = (
191
  f"{llama_response}\n\n"
192
- f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
193
- f"{medicine_suggestions}"
 
194
  )
195
  return final_response
196
 
@@ -199,12 +284,12 @@ def generate_response(message, history):
199
  # Create the Gradio interface
200
  demo = gr.ChatInterface(
201
  fn=generate_response,
202
- title="Medical Assistant with Medicine Suggestions",
203
- description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies.",
204
  examples=[
205
- "I have a cough and my throat hurts",
206
- "I've been having headaches for a week",
207
- "My stomach has been hurting since yesterday"
208
  ],
209
  theme="soft"
210
  )
 
77
  prompt += f"{msg.content} [/INST] "
78
  elif msg.type == "ai":
79
  prompt += f"{msg.content} </s><s>[INST] "
80
+
81
  # Add a specific follow-up question if in followup stage
82
  if followup_stage is not None:
83
  followup_questions = [
84
+ "Can you describe your main symptoms in more detail? What exactly are you experiencing?",
85
+ "How long have you been experiencing these symptoms? When did they first start?",
86
+ "On a scale of 1-10, how would you rate the severity of your symptoms?",
87
+ "Have you noticed anything that makes your symptoms better or worse? Any triggers or relief factors?",
88
+ "Do you have any other related symptoms, such as fever, fatigue, nausea, or changes in appetite?"
89
  ]
90
  if followup_stage < len(followup_questions):
91
  prompt += f"{followup_questions[followup_stage]} [/INST] "
 
113
  suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
114
  return suggestion
115
 
116
+ def extract_name_age_intelligent(text):
117
+ """Intelligently extract name and age from user input using multiple patterns."""
118
+ name, age = None, None
119
+ text_lower = text.lower().strip()
120
+
121
+ # Age extraction patterns (more comprehensive)
122
+ age_patterns = [
123
+ r'(?:i am|i\'m|im|age is|aged|my age is|years old|year old)\s*(\d{1,3})',
124
+ r'(\d{1,3})\s*(?:years old|year old|yrs old|yr old)',
125
+ r'\b(\d{1,3})\s*(?:and|,)?\s*(?:years|yrs|y\.o\.)',
126
+ r'(?:^|\s)(\d{1,3})(?:\s|$)', # standalone numbers
127
+ ]
128
+
129
+ for pattern in age_patterns:
130
+ match = re.search(pattern, text_lower)
131
+ if match:
132
+ potential_age = int(match.group(1))
133
+ if 1 <= potential_age <= 120: # reasonable age range
134
+ age = str(potential_age)
135
+ break
136
+
137
+ # Name extraction patterns (more comprehensive)
138
+ name_patterns = [
139
+ r'(?:my name is|name is|i am|i\'m|im|call me|this is)\s+([a-zA-Z][a-zA-Z\s]{1,20}?)(?:\s+and|\s+\d|\s*$)',
140
+ r'^([a-zA-Z][a-zA-Z\s]{1,20}?)\s+(?:and|,)?\s*\d', # name followed by number
141
+ r'(?:^|\s)([a-zA-Z]{2,15})(?:\s+and|\s+\d)', # simple name pattern
142
+ ]
143
+
144
+ for pattern in name_patterns:
145
+ match = re.search(pattern, text_lower)
146
+ if match:
147
+ potential_name = match.group(1).strip().title()
148
+ # Filter out common non-name words
149
+ non_names = ['it', 'is', 'am', 'my', 'me', 'the', 'and', 'or', 'but', 'yes', 'no']
150
+ if potential_name.lower() not in non_names and len(potential_name) >= 2:
151
+ name = potential_name
152
+ break
153
+
154
+ # Special case: handle "thanush and 23" or "it thanush and im 23" patterns
155
+ special_patterns = [
156
+ r'(?:it\s+)?([a-zA-Z]{2,15})\s+and\s+(?:im\s+|i\'m\s+)?(\d{1,3})',
157
+ r'([a-zA-Z]{2,15})\s+and\s+(\d{1,3})',
158
+ ]
159
+
160
+ for pattern in special_patterns:
161
+ match = re.search(pattern, text_lower)
162
+ if match:
163
+ potential_name = match.group(1).strip().title()
164
+ potential_age = int(match.group(2))
165
+ if potential_name.lower() not in ['it', 'is', 'am'] and 1 <= potential_age <= 120:
166
+ name = potential_name
167
+ age = str(potential_age)
168
+ break
169
+
170
+ return name, age
171
+
172
+ def extract_name_age_from_all_messages(messages):
173
+ """Extract name and age from all conversation messages."""
174
  name, age = None, None
175
+
176
  for msg in messages:
177
  if msg.type == "human":
178
+ extracted_name, extracted_age = extract_name_age_intelligent(msg.content)
179
+ if extracted_name and not name:
180
+ name = extracted_name
181
+ if extracted_age and not age:
182
+ age = extracted_age
183
+
 
 
 
 
 
 
 
 
 
184
  return name, age
185
 
186
+ def is_medical_symptom_message(text):
187
+ """Check if the message contains medical symptoms rather than just name/age."""
188
+ medical_keywords = [
189
+ 'hurt', 'pain', 'ache', 'sick', 'fever', 'cough', 'headache', 'stomach', 'throat',
190
+ 'nausea', 'dizzy', 'tired', 'fatigue', 'breathe', 'chest', 'back', 'leg', 'arm',
191
+ 'symptom', 'feel', 'suffering', 'problem', 'issue', 'uncomfortable', 'sore'
192
+ ]
193
+
194
+ text_lower = text.lower()
195
+ return any(keyword in text_lower for keyword in medical_keywords)
196
+
197
  @spaces.GPU
198
  def generate_response(message, history):
199
  """Generate a response using both models, with full context."""
 
203
  memory.save_context({"input": message}, {"output": ""})
204
 
205
  messages = memory.chat_memory.messages
206
+
207
+ # Extract name and age from all messages
208
+ name, age = extract_name_age_from_all_messages(messages)
209
+
210
+ # Check what information is missing
211
  missing_info = []
212
  if not name:
213
  missing_info.append("your name")
214
  if not age:
215
  missing_info.append("your age")
216
+
217
+ # If missing basic info, ask for it
218
  if missing_info:
219
+ ask = "Hello! Before we discuss your health concerns, could you please tell me " + " and ".join(missing_info) + "?"
220
  return ask
221
+
222
+ # Count meaningful medical information exchanges (exclude name/age only messages)
223
+ medical_info_turns = 0
224
  for msg in messages:
225
  if msg.type == "human":
226
+ # Count only if it's not just name/age info and contains medical content
227
+ if is_medical_symptom_message(msg.content) or not any(keyword in msg.content.lower() for keyword in ['name', 'age', 'years', 'old', 'im', 'i am']):
228
+ medical_info_turns += 1
229
+
230
+ # Ensure we have at least one medical symptom mentioned
231
+ if medical_info_turns == 0 and not is_medical_symptom_message(message):
232
+ return f"Thank you, {name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
233
+
234
+ # Ask up to 5 intelligent follow-up questions, then provide diagnosis and treatment
235
+ if medical_info_turns < 5:
236
+ prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message, followup_stage=medical_info_turns)
237
  else:
238
+ # Time for final diagnosis and treatment recommendations
239
  prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
240
+ prompt = prompt.replace("[/INST] ", "[/INST] Based on all the information provided, please provide a comprehensive assessment including: 1) Most likely diagnosis, 2) Recommended next steps, and 3) When to seek immediate medical attention. ")
241
 
242
+ # Generate response using Llama-2
243
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
244
  with torch.no_grad():
245
  outputs = model.generate(
 
251
  do_sample=True,
252
  pad_token_id=tokenizer.eos_token_id
253
  )
254
+
255
  full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
256
  llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
257
 
258
+ # After 5 medical info turns, add Meditron suggestions
259
+ if medical_info_turns >= 4: # Start suggesting after 4+ turns
260
+ # Compile patient information for Meditron
261
+ patient_summary = f"Patient: {name}, Age: {age}\n\n"
262
+ patient_summary += "Medical Information:\n"
263
+
264
+ for msg in messages:
265
+ if msg.type == "human" and is_medical_symptom_message(msg.content):
266
+ patient_summary += f"- {msg.content}\n"
267
+
268
+ patient_summary += f"\nLatest input: {message}\n"
269
+ patient_summary += f"\nInitial Assessment: {llama_response}"
270
+
271
+ # Get Meditron suggestions
272
+ medicine_suggestions = get_meditron_suggestions(patient_summary)
273
+
274
  final_response = (
275
  f"{llama_response}\n\n"
276
+ f"--- MEDICATION AND HOME CARE RECOMMENDATIONS ---\n\n"
277
+ f"{medicine_suggestions}\n\n"
278
+ f"**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice, especially if symptoms persist or worsen."
279
  )
280
  return final_response
281
 
 
284
  # Create the Gradio interface
285
  demo = gr.ChatInterface(
286
  fn=generate_response,
287
+ title="🩺 AI Medical Assistant with Treatment Suggestions",
288
+ description="Describe your symptoms and I'll gather information to provide medical insights and treatment recommendations.",
289
  examples=[
290
+ "Hi, I'm Sarah and I'm 25. I have a persistent cough and sore throat.",
291
+ "My name is John, I'm 35, and I've been having severe headaches.",
292
+ "I'm Lisa, 28 years old, and my stomach has been hurting since yesterday."
293
  ],
294
  theme="soft"
295
  )