Thanush committed on
Commit d6da22c · 1 Parent(s): a985489

Enhance prompt building in app.py to include intelligent follow-up questions and adjust response generation logic based on user information turns.

Files changed (1): app.py +24 -13
app.py CHANGED
@@ -65,15 +65,28 @@ print("Meditron model loaded successfully!")
 # Initialize LangChain memory
 memory = ConversationBufferMemory(return_messages=True)

-def build_llama2_prompt(system_prompt, messages, user_input):
-    """Format the conversation history and user input for Llama-2 chat models, using the full message sequence."""
+def build_llama2_prompt(system_prompt, messages, user_input, followup_stage=None):
     prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
     for msg in messages:
         if msg.type == "human":
             prompt += f"{msg.content} [/INST] "
         elif msg.type == "ai":
             prompt += f"{msg.content} </s><s>[INST] "
-    prompt += f"{user_input} [/INST] "
+    # Add a specific follow-up question if in followup stage
+    if followup_stage is not None:
+        followup_questions = [
+            "Can you describe your main symptoms in detail?",
+            "How long have you been experiencing these symptoms?",
+            "On a scale of 1-10, how severe are your symptoms?",
+            "Have you noticed anything that makes your symptoms better or worse?",
+            "Do you have any other related symptoms, such as fever, fatigue, or shortness of breath?"
+        ]
+        if followup_stage < len(followup_questions):
+            prompt += f"{followup_questions[followup_stage]} [/INST] "
+        else:
+            prompt += f"{user_input} [/INST] "
+    else:
+        prompt += f"{user_input} [/INST] "
     return prompt

 def get_meditron_suggestions(patient_info):
@@ -133,14 +146,14 @@ def generate_response(message, history):
         if not re.fullmatch(r".*(name|age|years? old|I'm|I am|my name is).*", msg.content, re.IGNORECASE):
             info_turns += 1

-    prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
-    # Only add summarization ONCE, not on every turn after 4 info turns
-    if info_turns == 4:
-        prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
+    # Ask up to 5 intelligent follow-up questions, then summarize/diagnose
+    if info_turns < 5:
+        prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message, followup_stage=info_turns)
+    else:
+        prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message)
+        prompt = prompt.replace("[/INST] ", "[/INST] Now, based on all the information, provide a likely diagnosis (if possible), and suggest when professional care may be needed. ")

     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-    # Generate the Llama-2 response
     with torch.no_grad():
         outputs = model.generate(
             inputs.input_ids,
@@ -151,13 +164,11 @@ def generate_response(message, history):
             do_sample=True,
             pad_token_id=tokenizer.eos_token_id
         )
-
-    # Decode and extract Llama-2's response
     full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
     llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()

-    # After 4 info turns, add medicine suggestions from Meditron, but only once
-    if info_turns == 4:
+    # After 5 info turns, add medicine suggestions from Meditron, but only once
+    if info_turns == 5:
         full_patient_info = "\n".join([
             m.content for m in messages if m.type == "human" and not re.fullmatch(r".*(name|age|years? old|I'm|I am|my name is).*", m.content, re.IGNORECASE)
         ] + [message]) + "\n\nSummary: " + llama_response
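As a sanity check on the new prompt layout, here is a minimal standalone sketch of build_llama2_prompt that needs no model. It assumes the messages come from LangChain's ConversationBufferMemory, whose HumanMessage/AIMessage objects expose .type ("human"/"ai") and .content; the Msg namedtuple below is a hypothetical stand-in used only to keep the sketch self-contained.

from collections import namedtuple

# Hypothetical stand-in for LangChain's HumanMessage/AIMessage objects.
Msg = namedtuple("Msg", ["type", "content"])

FOLLOWUP_QUESTIONS = [
    "Can you describe your main symptoms in detail?",
    "How long have you been experiencing these symptoms?",
    "On a scale of 1-10, how severe are your symptoms?",
    "Have you noticed anything that makes your symptoms better or worse?",
    "Do you have any other related symptoms, such as fever, fatigue, or shortness of breath?",
]

def build_llama2_prompt(system_prompt, messages, user_input, followup_stage=None):
    # System block, then alternating "human [/INST]" / "ai </s><s>[INST]" turns.
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    for msg in messages:
        if msg.type == "human":
            prompt += f"{msg.content} [/INST] "
        elif msg.type == "ai":
            prompt += f"{msg.content} </s><s>[INST] "
    # On follow-up turns the canned question replaces the user's latest message.
    if followup_stage is not None and followup_stage < len(FOLLOWUP_QUESTIONS):
        prompt += f"{FOLLOWUP_QUESTIONS[followup_stage]} [/INST] "
    else:
        prompt += f"{user_input} [/INST] "
    return prompt

history = [Msg("human", "I have a headache."), Msg("ai", "I'm sorry to hear that.")]
print(build_llama2_prompt("You are a helpful medical assistant.", history,
                          "It started yesterday.", followup_stage=1))
# The stage-1 question is asked; "It started yesterday." never enters the prompt.

Two consequences are visible here: generate_response only passes followup_stage values below 5, so the inner fallback to user_input is effectively unreachable, and on follow-up turns the user's current message is dropped from the prompt in favor of the scripted question.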
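One behavior of the info-turn counter is worth flagging: wrapping the alternation in .*(...).* and using re.fullmatch amounts to a substring search, so short tokens such as "age" or "I'm" match inside ordinary symptom reports, which are then excluded from info_turns. A quick check:

import re

# The exact pattern generate_response uses to skip name/age turns.
IDENTITY = re.compile(r".*(name|age|years? old|I'm|I am|my name is).*", re.IGNORECASE)

print(bool(IDENTITY.fullmatch("My name is Ravi")))             # True  - intended match
print(bool(IDENTITY.fullmatch("The pain is worse at night")))  # False - counted as an info turn
print(bool(IDENTITY.fullmatch("I'm coughing a lot")))          # True  - skipped, though it carries symptom info

If tighter matching is wanted, word boundaries (e.g. re.search with r"\b(name|age|years? old)\b") would fix the "age"-inside-a-word cases, though contractions like "I'm" need separate handling; that adjustment is a suggestion, not part of this commit.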
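A caveat on the summarize/diagnose branch: Python's str.replace substitutes every occurrence, and "[/INST] " appears once per human turn, so the diagnosis instruction gets spliced after each earlier turn of the prompt rather than only the final one. A sketch of a last-occurrence-only variant (append_instruction is a hypothetical helper, not in this commit):

def append_instruction(prompt, instruction):
    # Inject text after the final "[/INST] " only; earlier turn markers stay intact.
    head, sep, tail = prompt.rpartition("[/INST] ")
    return head + sep + instruction + tail

demo = "<s>[INST] q1 [/INST] a1 </s><s>[INST] q2 [/INST] "
print(append_instruction(demo, "Now summarize. "))
# -> <s>[INST] q1 [/INST] a1 </s><s>[INST] q2 [/INST] Now summarize.

Since build_llama2_prompt always ends its output with "[/INST] ", plain concatenation (prompt + instruction) would behave identically here.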
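The decode-and-extract step is untouched by this commit, but its behavior is easy to verify in isolation: splitting on '[/INST]' and keeping the last piece isolates the newly generated turn, and the '</s>' split trims the end-of-sequence marker that survives skip_special_tokens=False. A sketch with a fabricated decoded string:

full_response = (
    "<s>[INST] <<SYS>>\nYou are a helpful medical assistant.\n<</SYS>>\n\n"
    "I have a headache. [/INST] How long has the headache lasted? </s>"
)
llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
print(llama_response)  # -> How long has the headache lasted?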