Spaces:
Running
on
Zero
Running
on
Zero
Thanush
Refactor MEDITRON_PROMPT in app.py to enhance medical assessment and recommendations, ensuring evidence-based practices and clear communication. Update name and age extraction logic for improved accuracy in user responses.
5522bf8
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from langchain.memory import ConversationBufferMemory | |
import re | |
# Model configuration
# Hugging Face Hub ids of the two checkpoints this app loads: the Llama-2 chat
# model drives the conversation, the Meditron model writes the final
# medication / home-care suggestions.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"

# System prompt placed inside the <<SYS>> section of every Llama-2 prompt.
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's name, age, health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
Always begin by asking for the user's name and age if not already provided.
**IMPORTANT** Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies
After collecting sufficient information (at least 4-5 exchanges, but continue up to 10 if the user keeps responding), summarize findings, provide a likely diagnosis (if possible), and suggest when they should seek professional care.
If enough information is collected, provide a concise, general diagnosis and a practical over-the-counter medicine and home remedy suggestion.
Do NOT make specific prescriptions for prescription-only drugs.
Respond empathetically and clearly. Always be professional and thorough."""

# ChatML-style template for the Meditron model. `{patient_info}` is the only
# placeholder and is filled with str.format(). Every message opened with
# <|im_start|> must be closed with <|im_end|>; the system message was
# previously left open, so it ran straight into the user message.
MEDITRON_PROMPT = """<|im_start|>system
You are a board-certified physician with extensive clinical experience. Your role is to provide evidence-based medical assessment and recommendations following standard medical practice.
For each patient case:
1. Analyze presented symptoms systematically using medical terminology
2. Create a structured differential diagnosis with most likely conditions first
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Provide specific medication recommendations with precise dosing regimens
5. Include clear red flags that would necessitate urgent medical attention
6. Base all recommendations on current clinical guidelines and evidence-based medicine
7. Maintain professional, clear, and compassionate communication
Follow standard clinical documentation format when appropriate and prioritize patient safety at all times. Remember to include appropriate medical disclaimers.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""
def _load_causal_lm(checkpoint):
    """Load the tokenizer, then the half-precision causal LM, for *checkpoint*.

    Returns a ``(tokenizer, model)`` tuple. The model is requested in
    ``torch.float16`` with ``device_map="auto"``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return loaded_tokenizer, loaded_model


# Both checkpoints are loaded once, at import time, Llama-2 first.
print("Loading Llama-2 model...")
tokenizer, model = _load_causal_lm(LLAMA_MODEL)
print("Llama-2 model loaded successfully!")

print("Loading Meditron model...")
meditron_tokenizer, meditron_model = _load_causal_lm(MEDITRON_MODEL)
print("Meditron model loaded successfully!")

# Module-level LangChain conversation buffer; return_messages=True makes it
# hand back message objects rather than one concatenated string.
memory = ConversationBufferMemory(return_messages=True)
def build_llama2_prompt(system_prompt, messages, user_input, followup_stage=None):
    """Render a conversation as a Llama-2 chat prompt string.

    Args:
        system_prompt: Text placed between the ``<<SYS>>`` markers.
        messages: Iterable of objects exposing ``.type`` and ``.content``.
            ``"human"`` entries are closed with ``[/INST]``, ``"ai"`` entries
            with ``</s><s>[INST]``; any other type is skipped.
        user_input: Text used for the final instruction when no canned
            follow-up question applies.
        followup_stage: Optional index into the canned follow-up questions.
            When it is not ``None`` and is below the number of questions, that
            question is used as the final instruction instead of *user_input*.

    Returns:
        The prompt, always ending in ``" [/INST] "``.
    """
    canned_followups = (
        "Can you describe your main symptoms in detail?",
        "How long have you been experiencing these symptoms?",
        "On a scale of 1-10, how severe are your symptoms?",
        "Have you noticed anything that makes your symptoms better or worse?",
        "Do you have any other related symptoms, such as fever, fatigue, or shortness of breath?",
    )

    # Collect the pieces and join once at the end.
    pieces = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    for turn in messages:
        role = turn.type
        if role == "human":
            pieces.append(f"{turn.content} [/INST] ")
        elif role == "ai":
            pieces.append(f"{turn.content} </s><s>[INST] ")

    # The closing instruction is a canned question while one is available for
    # this stage, and the caller's text otherwise.
    use_canned = followup_stage is not None and followup_stage < len(canned_followups)
    closing_text = canned_followups[followup_stage] if use_canned else user_input
    pieces.append(f"{closing_text} [/INST] ")
    return "".join(pieces)
def get_meditron_suggestions(patient_info):
    """Ask the Meditron model for medicine and remedy suggestions.

    Args:
        patient_info: Text substituted for ``{patient_info}`` in
            ``MEDITRON_PROMPT``.

    Returns:
        The decoded continuation only: prompt tokens are sliced off before
        decoding and special tokens are skipped.
    """
    filled_prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    encoded = meditron_tokenizer(filled_prompt, return_tensors="pt").to(meditron_model.device)
    prompt_length = encoded.input_ids.shape[1]

    # Sampling settings, gathered in one place.
    sampling_options = {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
    }
    with torch.no_grad():
        generated = meditron_model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            **sampling_options,
        )

    # Everything before prompt_length is the prompt echoed back by generate().
    continuation = generated[0][prompt_length:]
    return meditron_tokenizer.decode(continuation, skip_special_tokens=True)
# Apostrophe variants accepted in "I'm": ASCII ', typographic \u2019 and backtick.
_NAME_AGE_APOSTROPHE = r"['\u2019`]"

# "my name is <Name>"
_STATED_NAME_RE = re.compile(r"\bmy\s+name\s+is\s*([A-Za-z]+)", re.IGNORECASE)

# "I am 25", "I'm 25", "im 25", "(my) age is 25", "aged 25".
# \b stops "im" from matching inside words such as "him" or "claim";
# (?!\d) stops a longer number such as "1000" from being truncated to "100".
_STATED_AGE_RE = re.compile(
    r"\b(?:i\s+am|i" + _NAME_AGE_APOSTROPHE + r"?m|age\s+is|aged)\s*(\d{1,3})(?!\d)",
    re.IGNORECASE,
)

# "25 years old", "25-year-old", "25 yrs old"
_YEARS_OLD_RE = re.compile(r"\b(\d{1,3})[\s-]*(?:years?|yrs?)[\s-]*old\b", re.IGNORECASE)

# "I'm John and 25", "I am John, 25"
_NAME_THEN_AGE_RE = re.compile(
    r"\b(?:i\s+am|i" + _NAME_AGE_APOSTROPHE + r"?m)\s+([A-Za-z]+)\s*(?:and|,)?\s*(\d{1,3})(?!\d)",
    re.IGNORECASE,
)


def extract_name_age(messages):
    """Find the user's name and age in the human turns of a conversation.

    Args:
        messages: Iterable of objects exposing ``.type`` and ``.content``.
            Only entries whose type is ``"human"`` and whose content is a
            string are inspected.

    Returns:
        ``(name, age)``. Each element is the first matching string found (the
        age is returned as a digit string, not an int), or ``None`` if nothing
        matched.
    """
    name, age = None, None
    for msg in messages:
        if msg.type != "human":
            continue
        text = msg.content
        if not isinstance(text, str):
            continue

        # Explicit age statements take priority over the "N years old" form.
        if not age:
            age_match = _STATED_AGE_RE.search(text) or _YEARS_OLD_RE.search(text)
            if age_match:
                age = age_match.group(1)

        if not name:
            name_match = _STATED_NAME_RE.search(text)
            if name_match:
                name = name_match.group(1)

        # Fallback for the combined form "I'm <name> and <age>"; it only fills
        # whichever value is still missing.
        if not (name and age):
            combined = _NAME_THEN_AGE_RE.search(text)
            if combined:
                name = name or combined.group(1)
                age = age or combined.group(2)

        if name and age:
            break
    return name, age
# Phrases that mark a user message as an introduction (name / age) rather than
# a symptom report. Word boundaries keep "dosage" or "message" from matching
# "age", and "I'm" / "I am" only count when followed by an age or by
# "<name> <age>", so "I'm having headaches" is still treated as a symptom report.
_INTRO_PHRASE_RE = re.compile(
    r"\bmy\s+name\s+is\b"
    r"|\b(?:my\s+)?age\s+is\b"
    r"|\b\d{1,3}[\s-]*(?:years?|yrs?)[\s-]*old\b"
    r"|\b(?:i\s+am|i['\u2019`]?m|aged)\s*\d{1,3}\b"
    r"|\b(?:i\s+am|i['\u2019`]?m)\s+[A-Za-z]+\s*(?:and|,)?\s*\d{1,3}\b",
    re.IGNORECASE,
)

# Appended to the user's final instruction once enough information is gathered.
_DIAGNOSIS_INSTRUCTION = (
    "Now, based on all the information, provide a likely diagnosis (if possible), "
    "and suggest when professional care may be needed."
)

# Number of informative user turns to collect before summarising.
_INFO_TURNS_BEFORE_SUMMARY = 5


def _replay_history(chat_history, history):
    """Copy Gradio's per-session *history* into a LangChain chat history.

    Assumes *history* is either a list of ``(user, assistant)`` pairs or a list
    of ``{"role", "content"}`` dicts; entries of any other shape, and
    non-string contents, are skipped.
    """
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if not isinstance(content, str):
                continue
            if role == "user":
                chat_history.add_user_message(content)
            elif role == "assistant":
                chat_history.add_ai_message(content)
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_text, bot_text = turn
            if isinstance(user_text, str):
                chat_history.add_user_message(user_text)
            if isinstance(bot_text, str):
                chat_history.add_ai_message(bot_text)


def _is_info_turn(msg):
    """Return True for a human message that carries more than name/age."""
    return (
        msg.type == "human"
        and isinstance(msg.content, str)
        and not _INTRO_PHRASE_RE.search(msg.content)
    )


def generate_response(message, history):
    """Generate a response using both models, with full context.

    Args:
        message: The user's latest chat message.
        history: The earlier turns of this chat session as supplied by Gradio.

    Returns:
        The reply text: a request for name/age while either is missing,
        otherwise the Llama-2 reply, with Meditron's suggestions appended on
        the turn where the fifth informative user message is reached.
    """
    # Build the conversation from this session's history on every call, using
    # a fresh buffer rather than the shared module-level `memory`. That way
    # turns are neither duplicated across calls nor mixed between sessions.
    session = ConversationBufferMemory(return_messages=True)
    _replay_history(session.chat_memory, history)
    prior_messages = list(session.chat_memory.messages)
    # The current message is recorded with an empty reply, as before, so the
    # name/age extraction and the turn count below can see it.
    session.save_context({"input": message}, {"output": ""})
    messages = session.chat_memory.messages

    name, age = extract_name_age(messages)
    missing_info = []
    if not name:
        missing_info.append("your name")
    if not age:
        missing_info.append("your age")
    if missing_info:
        return "Before we continue, could you please tell me " + " and ".join(missing_info) + "?"

    # Count how many user turns have provided new info (not just name/age).
    info_turns = sum(1 for msg in messages if _is_info_turn(msg))

    # Ask up to 5 follow-up questions, then summarize/diagnose.
    if info_turns < _INFO_TURNS_BEFORE_SUMMARY:
        prompt = build_llama2_prompt(SYSTEM_PROMPT, messages, message, followup_stage=info_turns)
    else:
        # The diagnosis instruction goes inside the final [INST] block, once.
        # The current message is passed as user_input only, so it is not
        # repeated from the history.
        prompt = build_llama2_prompt(
            SYSTEM_PROMPT,
            prior_messages,
            f"{message}\n\n{_DIAGNOSIS_INSTRUCTION}",
        )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (same approach as
    # get_meditron_suggestions) so no prompt text can leak into the reply.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    llama_response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # On the turn that reaches 5 info turns, add Meditron's suggestions.
    if info_turns == _INFO_TURNS_BEFORE_SUMMARY:
        patient_lines = [m.content for m in prior_messages if _is_info_turn(m)]
        patient_lines.append(message)
        full_patient_info = "\n".join(patient_lines) + "\n\nSummary: " + llama_response
        medicine_suggestions = get_meditron_suggestions(full_patient_info)
        return (
            f"{llama_response}\n\n"
            f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
            f"{medicine_suggestions}"
        )
    return llama_response
# Create the Gradio interface
# Starter messages offered to the user in the chat UI.
_STARTER_PROMPTS = [
    "I have a cough and my throat hurts",
    "I've been having headaches for a week",
    "My stomach has been hurting since yesterday",
]

# Chat UI that routes every user message through generate_response.
demo = gr.ChatInterface(
    fn=generate_response,
    examples=_STARTER_PROMPTS,
    title="Medical Assistant with Medicine Suggestions",
    description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies.",
    theme="soft",
)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()