# medbot_2 / app.py
# Thanush
# Refactor app.py to streamline user information collection by removing redundant prompts for name and age. Implement a simple state tracking mechanism for improved conversation flow and enhance medical consultation process with structured follow-up questions.
# 43e5827
# raw
# history blame
# 9.56 kB
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.memory import ConversationBufferMemory
import re
# Model configuration: checkpoint identifiers handed to `from_pretrained`.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"

# System prompt for the Llama-2 model: drives the symptom interview and the
# final assessment.
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
**IMPORTANT** Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies
After collecting sufficient information, summarize findings, provide a likely diagnosis (if possible), and suggest when they should seek professional care.
If enough information is collected, provide a concise, general diagnosis and a practical over-the-counter medicine and home remedy suggestion.
Do NOT make specific prescriptions for prescription-only drugs.
Respond empathetically and clearly. Always be professional and thorough."""

# ChatML-style template for the Meditron model; `{patient_info}` is filled in
# with `str.format`. Every turn opened with <|im_start|> is closed with
# <|im_end|> (the system turn previously was not), except the trailing
# assistant turn, which is deliberately left open for the model to complete.
MEDITRON_PROMPT = """<|im_start|>system
You are a board-certified physician with extensive clinical experience. Your role is to provide evidence-based medical assessment and recommendations following standard medical practice.
For each patient case:
1. Analyze presented symptoms systematically using medical terminology
2. Create a structured differential diagnosis with most likely conditions first
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Provide specific medication recommendations with precise dosing regimens
5. Include clear red flags that would necessitate urgent medical attention
6. Base all recommendations on current clinical guidelines and evidence-based medicine
7. Maintain professional, clear, and compassionate communication
Follow standard clinical documentation format when appropriate and prioritize patient safety at all times. Remember to include appropriate medical disclaimers.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""
def _load_tokenizer_and_model(model_id):
    """Return ``(tokenizer, model)`` for *model_id*.

    The model is requested in float16 with ``device_map="auto"``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return loaded_tokenizer, loaded_model


# Both checkpoints are loaded once, at import time, tokenizer first.
print("Loading Llama-2 model...")
tokenizer, model = _load_tokenizer_and_model(LLAMA_MODEL)
print("Llama-2 model loaded successfully!")

print("Loading Meditron model...")
meditron_tokenizer, meditron_model = _load_tokenizer_and_model(MEDITRON_MODEL)
print("Meditron model loaded successfully!")
# Simple conversation state tracking: a single module-level dict holding the
# collected name/age, the matching "has_*" flags, and a medical-turn counter.
conversation_state = dict(
    name=None,
    age=None,
    medical_turns=0,
    has_name=False,
    has_age=False,
)
def get_meditron_suggestions(patient_info):
    """Use the Meditron model to generate medicine and remedy suggestions.

    Args:
        patient_info: Free-text patient summary substituted into
            ``MEDITRON_PROMPT``.

    Returns:
        The sampled completion as a string: prompt tokens removed, special
        tokens skipped, surrounding whitespace stripped.
    """
    prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)

    with torch.no_grad():
        outputs = meditron_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Explicit, like the Llama-2 generate calls in this file, instead
            # of relying on generate()'s implicit fallback.
            pad_token_id=meditron_tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion; keep only the completion.
    prompt_length = inputs.input_ids.shape[1]
    completion_ids = outputs[0][prompt_length:]
    suggestion = meditron_tokenizer.decode(completion_ids, skip_special_tokens=True)
    # Stripped for parity with how the Llama-2 replies are post-processed.
    return suggestion.strip()
def build_simple_prompt(system_prompt, conversation_history, current_input):
    """Assemble a Llama-2 chat prompt.

    Args:
        system_prompt: Text placed inside the ``<<SYS>>`` section.
        conversation_history: Iterable of ``(user_msg, bot_msg)`` pairs,
            oldest first.
        current_input: The new user message; the prompt ends right after its
            ``[/INST]`` marker.

    Returns:
        The full prompt string.
    """
    header = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    # Each finished exchange closes its turn and opens the next [INST].
    past_turns = "".join(
        f"{user_msg} [/INST] {bot_msg} </s><s>[INST] "
        for user_msg, bot_msg in conversation_history
    )
    return f"{header}{past_turns}{current_input} [/INST] "
# Structured follow-up questions; one is appended per information-gathering
# turn, in order.
_FOLLOWUP_QUESTIONS = (
    "Can you describe your symptoms in more detail? What exactly are you experiencing?",
    "How long have you been experiencing these symptoms? When did they first start?",
    "On a scale of 1-10, how would you rate the severity of your symptoms?",
    "Have you noticed anything that makes your symptoms better or worse?",
    "Do you have any other symptoms, medical history, or are you taking any medications?",
)

# Every chat opens with three non-medical exchanges: greeting, name, age.
_INTAKE_EXCHANGES = 3


def _user_text(exchange):
    """Return the stripped user side of one ``(user, bot)`` history pair."""
    return (exchange[0] or "").strip()


def _llama_reply(prompt, max_new_tokens):
    """Sample a completion for *prompt* from the Llama-2 model.

    Only the newly generated tokens are decoded (as get_meditron_suggestions
    does), so a literal "[/INST]" inside the generated text cannot truncate
    the reply the way splitting the full decoded text on that marker could.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


@spaces.GPU
def generate_response(message, history):
    """Produce the assistant's next message for one chat.

    The stage of the conversation is derived from ``history`` itself rather
    than from the module-level ``conversation_state`` dict: that dict is one
    object shared by everything in the process, so using it for control flow
    mixes up simultaneous chats and drifts from the transcript whenever a
    turn is retried or undone.

    Flow, by number of completed exchanges in ``history``:
        0      -> ask for the name (the first message itself is not used)
        1      -> ``message`` is the name; ask for the age
        2      -> ``message`` is the age; ask what the problem is
        3..7   -> information gathering: Llama-2 reply plus one structured
                  follow-up question per turn
        8+     -> Llama-2 assessment followed by Meditron suggestions

    Args:
        message: The user's newest message.
        history: Earlier exchanges of this chat. Assumed to be a sequence of
            ``(user_msg, bot_msg)`` pairs, the same shape build_simple_prompt
            consumes.

    Returns:
        The reply text to show in the chat.
    """
    exchanges_done = len(history)

    # Name and age live in fixed positions of the transcript: exchange 1 and
    # exchange 2 (exchange 0 is the greeting).
    if exchanges_done > 1:
        name = _user_text(history[1])
    else:
        name = message.strip() if exchanges_done == 1 else None
    if exchanges_done > 2:
        age = _user_text(history[2])
    else:
        age = message.strip() if exchanges_done == 2 else None
    # The message being answered now counts as the current medical turn.
    medical_turns = max(0, exchanges_done - _INTAKE_EXCHANGES + 1)

    # Write-only snapshot of the latest call, kept so the module-level dict
    # still reflects progress; nothing below reads it.
    conversation_state.update(
        name=name,
        age=age,
        medical_turns=medical_turns,
        has_name=True,
        has_age=exchanges_done >= 2,
    )

    # Step 1: Ask for name
    if exchanges_done == 0:
        return "Hello! Before we discuss your health concerns, could you please tell me your name?"

    # Step 2: Name received, ask for age
    if exchanges_done == 1:
        return f"Nice to meet you, {name}! Could you please tell me your age?"

    # Step 3: Age received, start medical questions
    if exchanges_done == 2:
        return f"Thank you, {name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."

    # Step 4: Medical consultation phase (name/age exchanges are skipped)
    medical_history = history[_INTAKE_EXCHANGES:]

    if medical_turns <= len(_FOLLOWUP_QUESTIONS):
        # Still gathering information.
        prompt = build_simple_prompt(SYSTEM_PROMPT, medical_history, message)
        llama_response = _llama_reply(prompt, max_new_tokens=256)
        # Turns are 1-based, the question list is 0-based. Indexing with the
        # raw turn number skipped the first question and left the last
        # gathering turn without one.
        next_question = _FOLLOWUP_QUESTIONS[medical_turns - 1]
        return f"{llama_response}\n\n{next_question}"

    # Time for diagnosis and treatment: compile everything the patient said
    # during the medical phase, newest message last.
    patient_statements = [user_msg for user_msg, _bot_msg in medical_history]
    patient_statements.append(message)
    patient_info = (
        f"Patient: {name}, Age: {age}\n\n"
        "Symptoms and Information:\n"
        + "".join(f"Patient: {statement}\n" for statement in patient_statements)
    )

    # Generate diagnosis with Llama-2
    diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on all the information provided, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {name}.\n\nPatient Information:\n{patient_info} [/INST] "
    diagnosis = _llama_reply(diagnosis_prompt, max_new_tokens=384)

    # Get treatment suggestions from Meditron
    treatment_suggestions = get_meditron_suggestions(patient_info)

    # Combine responses
    return f"{diagnosis}\n\n--- TREATMENT RECOMMENDATIONS ---\n\n{treatment_suggestions}\n\n**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice."
# Create the Gradio interface: a chat UI whose callback is generate_response.
_chat_ui_options = {
    "title": "🩺 AI Medical Assistant",
    "description": "I'll ask for your basic information first, then gather details about your symptoms to provide medical insights.",
    "examples": [
        "I have a persistent cough",
        "I've been having headaches",
        "My stomach hurts",
    ],
    "theme": "soft",
}
demo = gr.ChatInterface(fn=generate_response, **_chat_ui_options)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()