Thanush
Refactor app.py to streamline user information collection by removing redundant prompts for name and age. Implement a simple state-tracking mechanism for improved conversation flow and enhance the medical consultation process with structured follow-up questions.
43e5827
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Model configuration
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data. | |
**IMPORTANT** Ask 1-2 follow-up questions at a time to gather more details about: | |
- Detailed description of symptoms | |
- Duration (when did it start?) | |
- Severity (scale of 1-10) | |
- Aggravating or alleviating factors | |
- Related symptoms | |
- Medical history | |
- Current medications and allergies | |
After collecting sufficient information, summarize findings, provide a likely diagnosis (if possible), and suggest when they should seek professional care. | |
If enough information is collected, provide a concise, general diagnosis and a practical over-the-counter medicine and home remedy suggestion. | |
Do NOT make specific prescriptions for prescription-only drugs. | |
Respond empathetically and clearly. Always be professional and thorough.""" | |
MEDITRON_PROMPT = """<|im_start|>system
You are a board-certified physician with extensive clinical experience. Your role is to provide evidence-based medical assessment and recommendations following standard medical practice.
For each patient case:
1. Analyze presented symptoms systematically using medical terminology
2. Create a structured differential diagnosis with the most likely conditions first
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Provide specific medication recommendations with precise dosing regimens
5. Include clear red flags that would necessitate urgent medical attention
6. Base all recommendations on current clinical guidelines and evidence-based medicine
7. Maintain professional, clear, and compassionate communication
Follow standard clinical documentation format when appropriate and prioritize patient safety at all times. Remember to include appropriate medical disclaimers.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""
print("Loading Llama-2 model...") | |
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL) | |
model = AutoModelForCausalLM.from_pretrained( | |
LLAMA_MODEL, | |
torch_dtype=torch.float16, | |
device_map="auto" | |
) | |
print("Llama-2 model loaded successfully!") | |
print("Loading Meditron model...") | |
meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL) | |
meditron_model = AutoModelForCausalLM.from_pretrained( | |
MEDITRON_MODEL, | |
torch_dtype=torch.float16, | |
device_map="auto" | |
) | |
print("Meditron model loaded successfully!") | |
# Simple conversation state tracking
conversation_state = {
    'name': None,
    'age': None,
    'medical_turns': 0,
    'has_name': False,  # True once the assistant has asked for the name
    'has_age': False    # True once the age has been recorded
}
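# Intended flow, as a sketch (each arrow is one user turn):
#   first message -> ask for name -> store name, ask for age
#   -> store age, ask for the chief complaint
#   -> up to 5 information-gathering turns (Llama-2 plus scripted follow-ups)
#   -> final assessment (Llama-2) with treatment suggestions (Meditron)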
def get_meditron_suggestions(patient_info):
    """Use the Meditron model to generate medicine and remedy suggestions."""
    prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)
    with torch.no_grad():
        outputs = meditron_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
    # Decode only the newly generated tokens, skipping the prompt
    suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return suggestion
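# Hypothetical usage, for illustration only:
#   get_meditron_suggestions(
#       "Patient: Sam, Age: 34\nSymptoms and Information:\nPatient: dry cough for 3 days"
#   )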
def build_simple_prompt(system_prompt, conversation_history, current_input):
    """Build a simple Llama-2 chat prompt from the history and current input."""
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    # Add conversation history
    for user_msg, bot_msg in conversation_history:
        prompt += f"{user_msg} [/INST] {bot_msg} </s><s>[INST] "
    # Add current input
    prompt += f"{current_input} [/INST] "
    return prompt
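# For a single prior exchange this yields, roughly:
#   <s>[INST] <<SYS>>
#   ...system prompt...
#   <</SYS>>
#
#   I have a cough [/INST] How long have you had it? </s><s>[INST] Two days now [/INST]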
@spaces.GPU  # assumed ZeroGPU entry point, since this Space runs on Zero
def generate_response(message, history):
    """Generate a response using simple state tracking."""
    global conversation_state

    # Reset state if this is a new conversation
    if not history:
        conversation_state = {
            'name': None,
            'age': None,
            'medical_turns': 0,
            'has_name': False,
            'has_age': False
        }
    # Step 1: Ask for the name if we haven't yet ('has_name' marks that the
    # question has been asked; the answer arrives on the next turn)
    if not conversation_state['has_name']:
        conversation_state['has_name'] = True
        return "Hello! Before we discuss your health concerns, could you please tell me your name?"

    # Step 2: Store the name and ask for the age
    if conversation_state['name'] is None:
        conversation_state['name'] = message.strip()
        return f"Nice to meet you, {conversation_state['name']}! Could you please tell me your age?"

    # Step 3: Store the age and start the medical questions
    if not conversation_state['has_age']:
        conversation_state['age'] = message.strip()
        conversation_state['has_age'] = True
        return f"Thank you, {conversation_state['name']}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
    # Step 4: Medical consultation phase
    conversation_state['medical_turns'] += 1

    # Prepare conversation history for the model, skipping the three
    # introductory exchanges (greeting/name, name/age, age/complaint)
    medical_history = []
    if len(history) >= 3:
        medical_history = history[3:]
    # Scripted follow-up questions, asked in order over the first few turns
    followup_questions = [
        "Can you describe your symptoms in more detail? What exactly are you experiencing?",
        "How long have you been experiencing these symptoms? When did they first start?",
        "On a scale of 1-10, how would you rate the severity of your symptoms?",
        "Have you noticed anything that makes your symptoms better or worse?",
        "Do you have any other symptoms, medical history, or are you taking any medications?"
    ]
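    # Medical turn n appends followup_questions[n-1], so turns 1-5 each end
    # with one scripted question; from turn 6 onward the final assessment runs.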
    # Build the prompt for the medical consultation
    if conversation_state['medical_turns'] <= 5:
        # Still gathering information
        prompt = build_simple_prompt(SYSTEM_PROMPT, medical_history, message)

        # Generate response
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        llama_response = full_response.split('[/INST]')[-1].strip()

        # Append the scripted follow-up question for this turn
        question_index = conversation_state['medical_turns'] - 1
        if question_index < len(followup_questions):
            next_question = followup_questions[question_index]
            return f"{llama_response}\n\n{next_question}"
        else:
            return llama_response
    else:
        # Time for diagnosis and treatment (after 5 information-gathering turns)
        # Compile the patient information
        patient_info = f"Patient: {conversation_state['name']}, Age: {conversation_state['age']}\n\n"
        patient_info += "Symptoms and Information:\n"

        # Add all patient messages from the medical conversation history
        for user_msg, _bot_msg in medical_history:
            patient_info += f"Patient: {user_msg}\n"
        patient_info += f"Patient: {message}\n"

        # Generate the diagnosis with Llama-2
        diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on all the information provided, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {conversation_state['name']}.\n\nPatient Information:\n{patient_info} [/INST] "

        inputs = tokenizer(diagnosis_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=384,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        diagnosis = full_response.split('[/INST]')[-1].strip()

        # Get treatment suggestions from Meditron
        treatment_suggestions = get_meditron_suggestions(patient_info)

        # Combine responses
        final_response = f"{diagnosis}\n\n--- TREATMENT RECOMMENDATIONS ---\n\n{treatment_suggestions}\n\n**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice."
        return final_response
# Create the Gradio interface
demo = gr.ChatInterface(
    fn=generate_response,
    title="🩺 AI Medical Assistant",
    description="I'll ask for your basic information first, then gather details about your symptoms to provide medical insights.",
    examples=[
        "I have a persistent cough",
        "I've been having headaches",
        "My stomach hurts"
    ],
    theme="soft"
)
if __name__ == "__main__":
    demo.launch()