Spaces: Running on Zero
Thanush committed
Commit · d6da22c
1 Parent(s): a985489
Enhance prompt building in app.py to include intelligent follow-up questions and adjust response generation logic based on user information turns.
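For context, the change steers the model during information-gathering turns by splicing a canned follow-up question into the Llama-2 chat format in place of the raw user turn. The sketch below is illustrative only and is not code from this commit: SimpleMessage is a hypothetical stand-in for the LangChain message objects app.py pulls from ConversationBufferMemory, and the demo question is taken from the list added in the diff.

from dataclasses import dataclass

@dataclass
class SimpleMessage:
    type: str      # "human" or "ai", matching the msg.type checks in app.py
    content: str

def sketch_prompt(system_prompt, messages, user_input, followup_question=None):
    # Same <s>[INST] <<SYS>> framing the diff builds, reduced to its essentials.
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    for msg in messages:
        if msg.type == "human":
            prompt += f"{msg.content} [/INST] "
        elif msg.type == "ai":
            prompt += f"{msg.content} </s><s>[INST] "
    # During follow-up turns, a canned question replaces the raw user input.
    prompt += f"{followup_question or user_input} [/INST] "
    return prompt

history = [
    SimpleMessage("human", "I've had a headache since yesterday."),
    SimpleMessage("ai", "I'm sorry to hear that. Could you tell me a bit more?"),
]
print(sketch_prompt(
    "You are a helpful medical assistant.",
    history,
    "It gets worse in the evening.",
    followup_question="On a scale of 1-10, how severe are your symptoms?",
))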
app.py
CHANGED
@@ -65,15 +65,28 @@ print("Meditron model loaded successfully!")
 # Initialize LangChain memory
 memory = ConversationBufferMemory(return_messages=True)
 
-def build_llama2_prompt(system_prompt, messages, user_input):
-    """Format the conversation history and user input for Llama-2 chat models, using the full message sequence."""
+def build_llama2_prompt(system_prompt, messages, user_input, followup_stage=None):
     prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
     for msg in messages:
         if msg.type == "human":
             prompt += f"{msg.content} [/INST] "
         elif msg.type == "ai":
             prompt += f"{msg.content} </s><s>[INST] "
-
+    # Add a specific follow-up question if in followup stage
+    if followup_stage is not None:
+        followup_questions = [
+            "Can you describe your main symptoms in detail?",
+            "How long have you been experiencing these symptoms?",
+            "On a scale of 1-10, how severe are your symptoms?",
+            "Have you noticed anything that makes your symptoms better or worse?",
+            "Do you have any other related symptoms, such as fever, fatigue, or shortness of breath?"
+        ]
+        if followup_stage < len(followup_questions):
+            prompt += f"{followup_questions[followup_stage]} [/INST] "
+        else:
+            prompt += f"{user_input} [/INST] "
+    else:
+        prompt += f"{user_input} [/INST] "
     return prompt
 
 def get_meditron_suggestions(patient_info):
@@ -133,14 +146,14 @@ def generate_response(message, history):
         if not re.fullmatch(r".*(name|age|years? old|I'm|I am|my name is).*", msg.content, re.IGNORECASE):
             info_turns += 1
 
-
-
-
-
+    # Ask up to 5 intelligent follow-up questions, then summarize/diagnose
+    if info_turns < 5:
+        prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message, followup_stage=info_turns)
+    else:
+        prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
+        prompt = prompt.replace("[/INST] ", "[/INST] Now, based on all the information, provide a likely diagnosis (if possible), and suggest when professional care may be needed. ")
 
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-    # Generate the Llama-2 response
     with torch.no_grad():
         outputs = model.generate(
             inputs.input_ids,
@@ -151,13 +164,11 @@ def generate_response(message, history):
             do_sample=True,
             pad_token_id=tokenizer.eos_token_id
         )
-
-    # Decode and extract Llama-2's response
     full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
     llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
 
-    # After
-    if info_turns ==
+    # After 5 info turns, add medicine suggestions from Meditron, but only once
+    if info_turns == 5:
         full_patient_info = "\n".join([
             m.content for m in messages if m.type == "human" and not re.fullmatch(r".*(name|age|years? old|I'm|I am|my name is).*", m.content, re.IGNORECASE)
         ] + [message]) + "\n\nSummary: " + llama_response
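A rough usage sketch of the new turn gating follows. It is an illustration under stated assumptions rather than code from app.py: no model is loaded, the question list mirrors the one added above, and the thresholds simply restate the info_turns checks from the diff.

# Illustrative only: mirrors the info_turns gate added above, without any model calls.
FOLLOWUP_QUESTIONS = [
    "Can you describe your main symptoms in detail?",
    "How long have you been experiencing these symptoms?",
    "On a scale of 1-10, how severe are your symptoms?",
    "Have you noticed anything that makes your symptoms better or worse?",
    "Do you have any other related symptoms, such as fever, fatigue, or shortness of breath?",
]

def next_step(info_turns, max_followups=5):
    """Return what the assistant would do next for a given number of informative turns."""
    if info_turns < max_followups:
        return f"ask follow-up: {FOLLOWUP_QUESTIONS[info_turns]}"
    if info_turns == max_followups:
        return "summarize, give a likely diagnosis, and append Meditron medicine suggestions"
    return "continue the conversation with the diagnosis-style prompt"

for turns in range(7):
    print(f"info_turns={turns}: {next_step(turns)}")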