Spaces: Running on Zero
Thanush committed · Commit 43e5827 · 1 Parent(s): a7f6391
Refactor app.py to streamline user information collection by removing redundant prompts for name and age. Implement a simple state tracking mechanism for improved conversation flow and enhance medical consultation process with structured follow-up questions.
app.py CHANGED
@@ -9,9 +9,7 @@ import re
 LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
 MEDITRON_MODEL = "epfl-llm/meditron-7b"

-SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's
-
-Always begin by asking for the user's name and age if not already provided.
+SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.

 **IMPORTANT** Ask 1-2 follow-up questions at a time to gather more details about:
 - Detailed description of symptoms
@@ -22,7 +20,7 @@ Always begin by asking for the user's name and age if not already provided.
 - Medical history
 - Current medications and allergies

-After collecting sufficient information
+After collecting sufficient information, summarize findings, provide a likely diagnosis (if possible), and suggest when they should seek professional care.

 If enough information is collected, provide a concise, general diagnosis and a practical over-the-counter medicine and home remedy suggestion.

@@ -67,33 +65,14 @@ meditron_model = AutoModelForCausalLM.from_pretrained(
 )
 print("Meditron model loaded successfully!")

-#
-
-
-
-
-
-
-
-        elif msg.type == "ai":
-            prompt += f"{msg.content} </s><s>[INST] "
-
-    # Add a specific follow-up question if in followup stage
-    if followup_stage is not None:
-        followup_questions = [
-            "Can you describe your main symptoms in more detail? What exactly are you experiencing?",
-            "How long have you been experiencing these symptoms? When did they first start?",
-            "On a scale of 1-10, how would you rate the severity of your symptoms?",
-            "Have you noticed anything that makes your symptoms better or worse? Any triggers or relief factors?",
-            "Do you have any other related symptoms, such as fever, fatigue, nausea, or changes in appetite?"
-        ]
-        if followup_stage < len(followup_questions):
-            prompt += f"{followup_questions[followup_stage]} [/INST] "
-        else:
-            prompt += f"{user_input} [/INST] "
-    else:
-        prompt += f"{user_input} [/INST] "
-    return prompt
+# Simple conversation state tracking
+conversation_state = {
+    'name': None,
+    'age': None,
+    'medical_turns': 0,
+    'has_name': False,
+    'has_age': False
+}

 def get_meditron_suggestions(patient_info):
     """Use Meditron model to generate medicine and remedy suggestions."""
@@ -113,183 +92,141 @@ def get_meditron_suggestions(patient_info):
     suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
     return suggestion

-def
-    """
-
-    text_lower = text.lower().strip()
-
-    # Age extraction patterns (more comprehensive)
-    age_patterns = [
-        r'(?:i am|i\'m|im|age is|aged|my age is|years old|year old)\s*(\d{1,3})',
-        r'(\d{1,3})\s*(?:years old|year old|yrs old|yr old)',
-        r'\b(\d{1,3})\s*(?:and|,)?\s*(?:years|yrs|y\.o\.)',
-        r'(?:^|\s)(\d{1,3})(?:\s|$)', # standalone numbers
-    ]
-
-    for pattern in age_patterns:
-        match = re.search(pattern, text_lower)
-        if match:
-            potential_age = int(match.group(1))
-            if 1 <= potential_age <= 120: # reasonable age range
-                age = str(potential_age)
-                break
-
-    # Name extraction patterns (more comprehensive)
-    name_patterns = [
-        r'(?:my name is|name is|i am|i\'m|im|call me|this is)\s+([a-zA-Z][a-zA-Z\s]{1,20}?)(?:\s+and|\s+\d|\s*$)',
-        r'^([a-zA-Z][a-zA-Z\s]{1,20}?)\s+(?:and|,)?\s*\d', # name followed by number
-        r'(?:^|\s)([a-zA-Z]{2,15})(?:\s+and|\s+\d)', # simple name pattern
-    ]
-
-    for pattern in name_patterns:
-        match = re.search(pattern, text_lower)
-        if match:
-            potential_name = match.group(1).strip().title()
-            # Filter out common non-name words
-            non_names = ['it', 'is', 'am', 'my', 'me', 'the', 'and', 'or', 'but', 'yes', 'no']
-            if potential_name.lower() not in non_names and len(potential_name) >= 2:
-                name = potential_name
-                break
-
-    # Special case: handle "thanush and 23" or "it thanush and im 23" patterns
-    special_patterns = [
-        r'(?:it\s+)?([a-zA-Z]{2,15})\s+and\s+(?:im\s+|i\'m\s+)?(\d{1,3})',
-        r'([a-zA-Z]{2,15})\s+and\s+(\d{1,3})',
-    ]
-
-    for pattern in special_patterns:
-        match = re.search(pattern, text_lower)
-        if match:
-            potential_name = match.group(1).strip().title()
-            potential_age = int(match.group(2))
-            if potential_name.lower() not in ['it', 'is', 'am'] and 1 <= potential_age <= 120:
-                name = potential_name
-                age = str(potential_age)
-                break
-
-    return name, age
-
-def extract_name_age_from_all_messages(messages):
-    """Extract name and age from all conversation messages."""
-    name, age = None, None
+def build_simple_prompt(system_prompt, conversation_history, current_input):
+    """Build a simple prompt for Llama-2"""
+    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"

-
-
-
-        if extracted_name and not name:
-            name = extracted_name
-        if extracted_age and not age:
-            age = extracted_age
+    # Add conversation history
+    for i, (user_msg, bot_msg) in enumerate(conversation_history):
+        prompt += f"{user_msg} [/INST] {bot_msg} </s><s>[INST] "

-
-
-def is_medical_symptom_message(text):
-    """Check if the message contains medical symptoms rather than just name/age."""
-    medical_keywords = [
-        'hurt', 'pain', 'ache', 'sick', 'fever', 'cough', 'headache', 'stomach', 'throat',
-        'nausea', 'dizzy', 'tired', 'fatigue', 'breathe', 'chest', 'back', 'leg', 'arm',
-        'symptom', 'feel', 'suffering', 'problem', 'issue', 'uncomfortable', 'sore'
-    ]
+    # Add current input
+    prompt += f"{current_input} [/INST] "

-
-    return any(keyword in text_lower for keyword in medical_keywords)
+    return prompt

 @spaces.GPU
 def generate_response(message, history):
-    """Generate a response using
-
-
-
-
-
-
-
-
-
-
-
-
-    if not
-
-
-
-
-    #
-    if
-
-        return
-
-    #
-
-
-
-
-
-
+    """Generate a response using simple state tracking."""
+    global conversation_state
+
+    # Reset state if this is a new conversation
+    if not history:
+        conversation_state = {
+            'name': None,
+            'age': None,
+            'medical_turns': 0,
+            'has_name': False,
+            'has_age': False
+        }
+
+    # Step 1: Ask for name if not provided
+    if not conversation_state['has_name']:
+        conversation_state['has_name'] = True
+        return "Hello! Before we discuss your health concerns, could you please tell me your name?"
+
+    # Step 2: Store name and ask for age
+    if conversation_state['name'] is None:
+        conversation_state['name'] = message.strip()
+        return f"Nice to meet you, {conversation_state['name']}! Could you please tell me your age?"
+
+    # Step 3: Store age and start medical questions
+    if not conversation_state['has_age']:
+        conversation_state['age'] = message.strip()
+        conversation_state['has_age'] = True
+        return f"Thank you, {conversation_state['name']}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
+
+    # Step 4: Medical consultation phase
+    conversation_state['medical_turns'] += 1
+
+    # Prepare conversation history for the model
+    medical_history = []
+    if len(history) >= 3: # Skip name/age exchanges
+        medical_history = history[3:]
+
+    # Define follow-up questions based on turn number
+    followup_questions = [
+        "Can you describe your symptoms in more detail? What exactly are you experiencing?",
+        "How long have you been experiencing these symptoms? When did they first start?",
+        "On a scale of 1-10, how would you rate the severity of your symptoms?",
+        "Have you noticed anything that makes your symptoms better or worse?",
+        "Do you have any other symptoms, medical history, or are you taking any medications?"
+    ]

-    #
-    if
-
+    # Build the prompt for medical consultation
+    if conversation_state['medical_turns'] <= 5:
+        # Still gathering information
+        prompt = build_simple_prompt(SYSTEM_PROMPT, medical_history, message)
+
+        # Generate response
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                max_new_tokens=256,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id
+            )
+
+        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        llama_response = full_response.split('[/INST]')[-1].strip()
+
+        # Add a specific follow-up question
+        if conversation_state['medical_turns'] < len(followup_questions):
+            next_question = followup_questions[conversation_state['medical_turns']]
+            return f"{llama_response}\n\n{next_question}"
+        else:
+            return llama_response

-    # Ask up to 5 intelligent follow-up questions, then provide diagnosis and treatment
-    if medical_info_turns < 5:
-        prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message, followup_stage=medical_info_turns)
     else:
-        # Time for
-
-
-
-    # Generate response using Llama-2
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    with torch.no_grad():
-        outputs = model.generate(
-            inputs.input_ids,
-            attention_mask=inputs.attention_mask,
-            max_new_tokens=512,
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True,
-            pad_token_id=tokenizer.eos_token_id
-        )
-
-    full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
-    llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
-
-    # After 5 medical info turns, add Meditron suggestions
-    if medical_info_turns >= 4: # Start suggesting after 4+ turns
-        # Compile patient information for Meditron
-        patient_summary = f"Patient: {name}, Age: {age}\n\n"
-        patient_summary += "Medical Information:\n"
+        # Time for diagnosis and treatment (after 5+ turns)
+        # Compile patient information
+        patient_info = f"Patient: {conversation_state['name']}, Age: {conversation_state['age']}\n\n"
+        patient_info += "Symptoms and Information:\n"

-
-
-
+        # Add all medical conversation history
+        for user_msg, bot_msg in medical_history:
+            patient_info += f"Patient: {user_msg}\n"
+        patient_info += f"Patient: {message}\n"

-
-
+        # Generate diagnosis with Llama-2
+        diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on all the information provided, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {conversation_state['name']}.\n\nPatient Information:\n{patient_info} [/INST] "

-
-
+        inputs = tokenizer(diagnosis_prompt, return_tensors="pt").to(model.device)
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs.input_ids,
+                attention_mask=inputs.attention_mask,
+                max_new_tokens=384,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id
+            )
+
+        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        diagnosis = full_response.split('[/INST]')[-1].strip()
+
+        # Get treatment suggestions from Meditron
+        treatment_suggestions = get_meditron_suggestions(patient_info)
+
+        # Combine responses
+        final_response = f"{diagnosis}\n\n--- TREATMENT RECOMMENDATIONS ---\n\n{treatment_suggestions}\n\n**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice."

-        final_response = (
-            f"{llama_response}\n\n"
-            f"--- MEDICATION AND HOME CARE RECOMMENDATIONS ---\n\n"
-            f"{medicine_suggestions}\n\n"
-            f"**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice, especially if symptoms persist or worsen."
-        )
         return final_response

-    return llama_response
-
 # Create the Gradio interface
 demo = gr.ChatInterface(
     fn=generate_response,
-    title="🩺 AI Medical Assistant
-    description="
+    title="🩺 AI Medical Assistant",
+    description="I'll ask for your basic information first, then gather details about your symptoms to provide medical insights.",
     examples=[
-        "
-        "
-        "
+        "I have a persistent cough",
+        "I've been having headaches",
+        "My stomach hurts"
     ],
     theme="soft"
 )
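Note on the prompt format: the new build_simple_prompt emits Llama-2's [INST]/<<SYS>> chat template. A minimal standalone sketch (the loop lightly simplified from the diff, driven by a hypothetical one-exchange history) shows the exact string handed to the tokenizer:

def build_simple_prompt(system_prompt, conversation_history, current_input):
    """Build a simple prompt for Llama-2"""
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    for user_msg, bot_msg in conversation_history:  # prior (user, assistant) turns
        prompt += f"{user_msg} [/INST] {bot_msg} </s><s>[INST] "
    prompt += f"{current_input} [/INST] "  # left open for the model to complete
    return prompt

# Hypothetical transcript:
print(build_simple_prompt(
    "You are a professional virtual doctor.",
    [("I have a persistent cough", "How long have you had it?")],
    "About two weeks now",
))
# <s>[INST] <<SYS>>
# You are a professional virtual doctor.
# <</SYS>>
#
# I have a persistent cough [/INST] How long have you had it? </s><s>[INST] About two weeks now [/INST]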
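Note on the history slicing: gr.ChatInterface passes history as (user, assistant) pairs, and under the new flow the first three pairs are always the greeting/name/age exchanges, which is why generate_response drops them with history[3:]. A sketch with a hypothetical transcript (the name "Alex" is invented; assistant replies abbreviated from the diff):

history = [
    ("hi", "Hello! Before we discuss your health concerns, could you please tell me your name?"),
    ("Alex", "Nice to meet you, Alex! Could you please tell me your age?"),
    ("34", "Thank you, Alex! Now, what brings you here today? ..."),
    ("I have a persistent cough", "Can you describe your symptoms in more detail? ..."),
]

medical_history = []
if len(history) >= 3:  # skip the name/age exchanges, as in the diff
    medical_history = history[3:]
# -> [("I have a persistent cough", "Can you describe your symptoms in more detail? ...")]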
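Note on the Meditron handoff: after five medical turns, the same plain-text patient summary feeds both the Llama-2 diagnosis prompt and get_meditron_suggestions. A sketch with hypothetical values shows its shape; only the patient's side of the conversation is included, mirroring the loop in the diff:

name, age = "Alex", "34"  # hypothetical values captured by the state machine
message = "About a 6 out of 10, worse at night"
medical_history = [
    ("I have a persistent cough", "How long have you had it?"),
    ("About two weeks", "On a scale of 1-10, how severe is it?"),
]

patient_info = f"Patient: {name}, Age: {age}\n\n"
patient_info += "Symptoms and Information:\n"
for user_msg, bot_msg in medical_history:  # assistant replies intentionally dropped
    patient_info += f"Patient: {user_msg}\n"
patient_info += f"Patient: {message}\n"

print(patient_info)
# Patient: Alex, Age: 34
#
# Symptoms and Information:
# Patient: I have a persistent cough
# Patient: About two weeks
# Patient: About a 6 out of 10, worse at night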