Commit a7f6391 · Thanush committed
Parent(s): 5522bf8

Enhance user interaction in app.py by refining follow-up questions for symptom collection and implementing intelligent extraction of user name and age from messages. Improve response generation logic to ensure comprehensive medical assessments and treatment recommendations.
app.py
CHANGED
@@ -77,14 +77,15 @@ def build_llama2_prompt(system_prompt, messages, user_input, followup_stage=None
             prompt += f"{msg.content} [/INST] "
         elif msg.type == "ai":
             prompt += f"{msg.content} </s><s>[INST] "
+
     # Add a specific follow-up question if in followup stage
     if followup_stage is not None:
         followup_questions = [
-            "Can you describe your main symptoms in detail?",
-            "How long have you been experiencing these symptoms?",
-            "On a scale of 1-10, how …
-            "Have you noticed anything that makes your symptoms better or worse?",
-            "Do you have any other related symptoms, such as fever, fatigue, or …
+            "Can you describe your main symptoms in more detail? What exactly are you experiencing?",
+            "How long have you been experiencing these symptoms? When did they first start?",
+            "On a scale of 1-10, how would you rate the severity of your symptoms?",
+            "Have you noticed anything that makes your symptoms better or worse? Any triggers or relief factors?",
+            "Do you have any other related symptoms, such as fever, fatigue, nausea, or changes in appetite?"
         ]
         if followup_stage < len(followup_questions):
             prompt += f"{followup_questions[followup_stage]} [/INST] "
@@ -112,27 +113,87 @@ def get_meditron_suggestions(patient_info):
     suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
     return suggestion
 
-def …
+def extract_name_age_intelligent(text):
+    """Intelligently extract name and age from user input using multiple patterns."""
+    name, age = None, None
+    text_lower = text.lower().strip()
+
+    # Age extraction patterns (more comprehensive)
+    age_patterns = [
+        r'(?:i am|i\'m|im|age is|aged|my age is|years old|year old)\s*(\d{1,3})',
+        r'(\d{1,3})\s*(?:years old|year old|yrs old|yr old)',
+        r'\b(\d{1,3})\s*(?:and|,)?\s*(?:years|yrs|y\.o\.)',
+        r'(?:^|\s)(\d{1,3})(?:\s|$)',  # standalone numbers
+    ]
+
+    for pattern in age_patterns:
+        match = re.search(pattern, text_lower)
+        if match:
+            potential_age = int(match.group(1))
+            if 1 <= potential_age <= 120:  # reasonable age range
+                age = str(potential_age)
+                break
+
+    # Name extraction patterns (more comprehensive)
+    name_patterns = [
+        r'(?:my name is|name is|i am|i\'m|im|call me|this is)\s+([a-zA-Z][a-zA-Z\s]{1,20}?)(?:\s+and|\s+\d|\s*$)',
+        r'^([a-zA-Z][a-zA-Z\s]{1,20}?)\s+(?:and|,)?\s*\d',  # name followed by number
+        r'(?:^|\s)([a-zA-Z]{2,15})(?:\s+and|\s+\d)',  # simple name pattern
+    ]
+
+    for pattern in name_patterns:
+        match = re.search(pattern, text_lower)
+        if match:
+            potential_name = match.group(1).strip().title()
+            # Filter out common non-name words
+            non_names = ['it', 'is', 'am', 'my', 'me', 'the', 'and', 'or', 'but', 'yes', 'no']
+            if potential_name.lower() not in non_names and len(potential_name) >= 2:
+                name = potential_name
+                break
+
+    # Special case: handle "thanush and 23" or "it thanush and im 23" patterns
+    special_patterns = [
+        r'(?:it\s+)?([a-zA-Z]{2,15})\s+and\s+(?:im\s+|i\'m\s+)?(\d{1,3})',
+        r'([a-zA-Z]{2,15})\s+and\s+(\d{1,3})',
+    ]
+
+    for pattern in special_patterns:
+        match = re.search(pattern, text_lower)
+        if match:
+            potential_name = match.group(1).strip().title()
+            potential_age = int(match.group(2))
+            if potential_name.lower() not in ['it', 'is', 'am'] and 1 <= potential_age <= 120:
+                name = potential_name
+                age = str(potential_age)
+                break
+
+    return name, age
+
+def extract_name_age_from_all_messages(messages):
+    """Extract name and age from all conversation messages."""
     name, age = None, None
+
     for msg in messages:
         if msg.type == "human":
-            …
-            if name_match and not name:
-                name = name_match.group(1)
-            # Fallback: if user says 'I'm <name> and <age>'
-            fallback_match = re.search(r"i['’`]?m\s*([A-Za-z]+)\s*(?:and|,)?\s*(\d{1,3})", msg.content, re.IGNORECASE)
-            if fallback_match:
-                if not name:
-                    name = fallback_match.group(1)
-                if not age:
-                    age = fallback_match.group(2)
+            extracted_name, extracted_age = extract_name_age_intelligent(msg.content)
+            if extracted_name and not name:
+                name = extracted_name
+            if extracted_age and not age:
+                age = extracted_age
+
     return name, age
 
+def is_medical_symptom_message(text):
+    """Check if the message contains medical symptoms rather than just name/age."""
+    medical_keywords = [
+        'hurt', 'pain', 'ache', 'sick', 'fever', 'cough', 'headache', 'stomach', 'throat',
+        'nausea', 'dizzy', 'tired', 'fatigue', 'breathe', 'chest', 'back', 'leg', 'arm',
+        'symptom', 'feel', 'suffering', 'problem', 'issue', 'uncomfortable', 'sore'
+    ]
+
+    text_lower = text.lower()
+    return any(keyword in text_lower for keyword in medical_keywords)
+
 @spaces.GPU
 def generate_response(message, history):
     """Generate a response using both models, with full context."""
@@ -142,31 +203,43 @@ def generate_response(message, history):
     memory.save_context({"input": message}, {"output": ""})
 
     messages = memory.chat_memory.messages
-    …
+
+    # Extract name and age from all messages
+    name, age = extract_name_age_from_all_messages(messages)
+
+    # Check what information is missing
     missing_info = []
     if not name:
         missing_info.append("your name")
     if not age:
         missing_info.append("your age")
+
+    # If missing basic info, ask for it
     if missing_info:
-        ask = "Before we …
+        ask = "Hello! Before we discuss your health concerns, could you please tell me " + " and ".join(missing_info) + "?"
         return ask
-
-    # Count …
-    …
+
+    # Count meaningful medical information exchanges (exclude name/age only messages)
+    medical_info_turns = 0
     for msg in messages:
         if msg.type == "human":
-            # …
-            if not …
-            …
-            …
-            # …
-            if …
-            …
+            # Count only if it's not just name/age info and contains medical content
+            if is_medical_symptom_message(msg.content) or not any(keyword in msg.content.lower() for keyword in ['name', 'age', 'years', 'old', 'im', 'i am']):
+                medical_info_turns += 1
+
+    # Ensure we have at least one medical symptom mentioned
+    if medical_info_turns == 0 and not is_medical_symptom_message(message):
+        return f"Thank you, {name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
+
+    # Ask up to 5 intelligent follow-up questions, then provide diagnosis and treatment
+    if medical_info_turns < 5:
+        prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message, followup_stage=medical_info_turns)
     else:
+        # Time for final diagnosis and treatment recommendations
         prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
-        prompt = prompt.replace("[/INST] ", "[/INST] …
+        prompt = prompt.replace("[/INST] ", "[/INST] Based on all the information provided, please provide a comprehensive assessment including: 1) Most likely diagnosis, 2) Recommended next steps, and 3) When to seek immediate medical attention. ")
 
+    # Generate response using Llama-2
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     with torch.no_grad():
         outputs = model.generate(
@@ -178,19 +251,31 @@ def generate_response(message, history):
             do_sample=True,
             pad_token_id=tokenizer.eos_token_id
         )
+
     full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
     llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
 
-    # After 5 info turns, add …
-    if …
-    …
-    …
-    …
-    …
+    # After 5 medical info turns, add Meditron suggestions
+    if medical_info_turns >= 4:  # Start suggesting after 4+ turns
+        # Compile patient information for Meditron
+        patient_summary = f"Patient: {name}, Age: {age}\n\n"
+        patient_summary += "Medical Information:\n"
+
+        for msg in messages:
+            if msg.type == "human" and is_medical_symptom_message(msg.content):
+                patient_summary += f"- {msg.content}\n"
+
+        patient_summary += f"\nLatest input: {message}\n"
+        patient_summary += f"\nInitial Assessment: {llama_response}"
+
+        # Get Meditron suggestions
+        medicine_suggestions = get_meditron_suggestions(patient_summary)
+
         final_response = (
             f"{llama_response}\n\n"
-            f"--- MEDICATION AND HOME CARE …
-            f"{medicine_suggestions}"
+            f"--- MEDICATION AND HOME CARE RECOMMENDATIONS ---\n\n"
+            f"{medicine_suggestions}\n\n"
+            f"**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice, especially if symptoms persist or worsen."
         )
         return final_response
 
@@ -199,12 +284,12 @@ def generate_response(message, history):
 # Create the Gradio interface
 demo = gr.ChatInterface(
     fn=generate_response,
-    title="Medical Assistant with …
-    description="…
+    title="🩺 AI Medical Assistant with Treatment Suggestions",
+    description="Describe your symptoms and I'll gather information to provide medical insights and treatment recommendations.",
     examples=[
-        "I have a cough and …
-        "I've been having headaches …
-        "…
+        "Hi, I'm Sarah and I'm 25. I have a persistent cough and sore throat.",
+        "My name is John, I'm 35, and I've been having severe headaches.",
+        "I'm Lisa, 28 years old, and my stomach has been hurting since yesterday."
    ],
     theme="soft"
 )
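
For review purposes, a minimal standalone sketch (not part of the commit; the inputs are made-up examples) of how the first special-case pattern added in extract_name_age_intelligent pulls a name and age out of terse replies like "it thanush and im 23":

import re

# Pattern copied from the special_patterns list in the diff above;
# the example inputs are hypothetical.
pattern = r'(?:it\s+)?([a-zA-Z]{2,15})\s+and\s+(?:im\s+|i\'m\s+)?(\d{1,3})'

for text in ["it thanush and im 23", "sarah and 25"]:
    match = re.search(pattern, text.lower().strip())
    if match:
        name = match.group(1).strip().title()
        age = int(match.group(2))
        # Same sanity checks as the diff: reject filler words and implausible ages
        if name.lower() not in ['it', 'is', 'am'] and 1 <= age <= 120:
            print(name, age)  # -> Thanush 23, then Sarah 25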
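Likewise, a simplified sketch (abbreviated keyword list, hypothetical conversation) of the gating logic generate_response now applies: keyword matching decides whether a turn counts as medical content, and the running count either selects the next scripted follow-up question or switches to the final comprehensive assessment:

# Abbreviated keyword list from is_medical_symptom_message in the diff above.
MEDICAL_KEYWORDS = ['hurt', 'pain', 'ache', 'sick', 'fever', 'cough', 'headache']

def is_medical(text):
    return any(keyword in text.lower() for keyword in MEDICAL_KEYWORDS)

# Hypothetical human turns from a conversation history.
human_turns = ["Hi, I'm Sarah and I'm 25.", "I have a cough and a mild fever."]
medical_info_turns = sum(1 for turn in human_turns if is_medical(turn))

if medical_info_turns < 5:
    # Drives followup_stage in build_llama2_prompt: question index == turn count.
    print(f"ask follow-up question #{medical_info_turns}")  # -> ask follow-up question #1
else:
    print("append the comprehensive-assessment instruction to the prompt")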