Spaces: Running on Zero
Thanush
Refactor app.py to implement LangChain memory for enhanced conversation tracking. Update prompt building and response generation logic to utilize full conversation context, improving user interaction and medical assessment accuracy.
f3b4260
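A minimal sketch of the LangChain ConversationBufferMemory API that the refactored app.py below relies on (save_context to record an exchange, chat_memory.messages to read the transcript back). This snippet is illustrative only and is not part of app.py:

# Illustrative sketch, not part of app.py
from langchain.memory import ConversationBufferMemory

mem = ConversationBufferMemory(return_messages=True)
mem.save_context({"input": "I have a headache"}, {"output": "How long has it lasted?"})

# Each saved exchange becomes a HumanMessage/AIMessage pair
for msg in mem.chat_memory.messages:
    print(msg.type, msg.content)  # prints "human ..." then "ai ..."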
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.memory import ConversationBufferMemory
import re

# Model configuration
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"
SYSTEM_PROMPT = """You are a professional virtual doctor conducting a medical consultation. Your role is to gather comprehensive information about the patient's condition through intelligent questioning.

**CONSULTATION APPROACH:**
- Ask thoughtful, relevant follow-up questions based on the patient's responses
- Prioritize gathering information about: symptom details, duration, severity, triggers, related symptoms, medical history, medications, and lifestyle factors
- Ask 1-2 specific questions at a time that build naturally on their previous answers
- Be empathetic, professional, and thorough in your questioning
- Adapt your questions based on the symptoms they describe

**IMPORTANT GUIDELINES:**
- Generate intelligent follow-up questions that are contextually relevant to their specific symptoms
- Don't ask generic questions - tailor each question to their particular situation
- If they mention pain, ask about location, type, and triggers
- If they mention duration, ask about progression or changes
- Build each question logically from their previous responses

After 4-5 meaningful exchanges, provide assessment and recommendations.
Do NOT make specific prescriptions for prescription-only drugs.
Always maintain a professional, caring tone throughout the consultation."""
MEDITRON_PROMPT = """<|im_start|>system
You are a board-certified physician with extensive clinical experience. Your role is to provide evidence-based medical assessment and recommendations following standard medical practice.

For each patient case:
1. Analyze presented symptoms systematically using medical terminology
2. Create a structured differential diagnosis with most likely conditions first
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Provide specific medication recommendations with precise dosing regimens
5. Include clear red flags that would necessitate urgent medical attention
6. Base all recommendations on current clinical guidelines and evidence-based medicine
7. Maintain professional, clear, and compassionate communication

Follow standard clinical documentation format when appropriate and prioritize patient safety at all times. Remember to include appropriate medical disclaimers.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""
print("Loading Llama-2 model...") | |
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL) | |
model = AutoModelForCausalLM.from_pretrained( | |
LLAMA_MODEL, | |
torch_dtype=torch.float16, | |
device_map="auto" | |
) | |
print("Llama-2 model loaded successfully!") | |
print("Loading Meditron model...") | |
meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL) | |
meditron_model = AutoModelForCausalLM.from_pretrained( | |
MEDITRON_MODEL, | |
torch_dtype=torch.float16, | |
device_map="auto" | |
) | |
print("Meditron model loaded successfully!") | |
# Initialize LangChain memory for conversation tracking
memory = ConversationBufferMemory(return_messages=True)

# Simple state for basic info tracking
conversation_state = {
    'name': None,
    'age': None,
    'medical_turns': 0,
    'has_name': False,
    'has_age': False
}
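# Note: this dict only tracks lightweight intake state (name, age, number of medical turns);
# the full transcript of the consultation is kept in the ConversationBufferMemory above.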
def get_meditron_suggestions(patient_info):
    """Use Meditron model to generate medicine and remedy suggestions."""
    prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)

    with torch.no_grad():
        outputs = meditron_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )

    # Decode only the newly generated tokens (everything after the prompt)
    suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return suggestion
def build_prompt_with_memory(system_prompt, current_input):
    """Build prompt using LangChain memory for full conversation context."""
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"

    # Get conversation history from memory
    messages = memory.chat_memory.messages

    # Add conversation history to prompt
    for msg in messages:
        if msg.type == "human":
            prompt += f"{msg.content} [/INST] "
        elif msg.type == "ai":
            prompt += f"{msg.content} </s><s>[INST] "

    # Add current input
    prompt += f"{current_input} [/INST] "
    return prompt
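# Illustrative only: the prompt assembled above follows the Llama-2 chat layout, e.g.
#   <s>[INST] <<SYS>>\n{system prompt}\n<</SYS>>\n\n
#   first user turn [/INST] first assistant reply </s><s>[INST] current user turn [/INST]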
# ZeroGPU: request a GPU for the duration of each chat turn
@spaces.GPU
def generate_response(message, history):
    """Generate a response using LangChain ConversationBufferMemory."""
    global conversation_state

    # Reset state if this is a new conversation
    if not history:
        conversation_state = {
            'name': None,
            'age': None,
            'medical_turns': 0,
            'has_name': False,
            'has_age': False
        }
        # Clear memory for new conversation
        memory.clear()

    # Each branch below records the user message together with the bot reply
    # via memory.save_context, so every exchange is stored exactly once.
    # Step 1: Ask for name if not provided
    if not conversation_state['has_name']:
        conversation_state['has_name'] = True
        bot_response = "Hello! Before we discuss your health concerns, could you please tell me your name?"
        # Update memory with bot response
        memory.save_context({"input": message}, {"output": bot_response})
        return bot_response

    # Step 2: Store name and ask for age
    if conversation_state['name'] is None:
        conversation_state['name'] = message.strip()
        bot_response = f"Nice to meet you, {conversation_state['name']}! Could you please tell me your age?"
        # Update memory with bot response
        memory.save_context({"input": message}, {"output": bot_response})
        return bot_response

    # Step 3: Store age and start medical questions
    if not conversation_state['has_age']:
        conversation_state['age'] = message.strip()
        conversation_state['has_age'] = True
        bot_response = f"Thank you, {conversation_state['name']}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
        # Update memory with bot response
        memory.save_context({"input": message}, {"output": bot_response})
        return bot_response
    # Step 4: Medical consultation phase using ConversationBufferMemory
    conversation_state['medical_turns'] += 1

    # Build the prompt using memory for full conversation context
    if conversation_state['medical_turns'] <= 5:
        # Still gathering information - let the LLM ask intelligent follow-up questions
        prompt = build_prompt_with_memory(SYSTEM_PROMPT, message)

        # Generate response with intelligent follow-up questions
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=384,
                temperature=0.8,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        llama_response = full_response.split('[/INST]')[-1].strip()

        # Save bot response to memory
        memory.save_context({"input": message}, {"output": llama_response})
        return llama_response
    else:
        # Time for diagnosis and treatment (after 5+ turns)
        # Get all conversation messages from memory
        all_messages = memory.chat_memory.messages

        # Compile patient information from memory
        patient_info = f"Patient: {conversation_state['name']}, Age: {conversation_state['age']}\n\n"
        patient_info += "Complete Conversation History:\n"

        # Add all messages from memory
        for msg in all_messages:
            if msg.type == "human":
                patient_info += f"Patient: {msg.content}\n"
            elif msg.type == "ai":
                patient_info += f"Doctor: {msg.content}\n"
        patient_info += f"Current: {message}\n"

        # Generate diagnosis with full conversation context
        diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on the complete conversation history, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {conversation_state['name']}.\n\nComplete Patient Information:\n{patient_info} [/INST] "

        inputs = tokenizer(diagnosis_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=384,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        diagnosis = full_response.split('[/INST]')[-1].strip()

        # Get treatment suggestions from Meditron using memory context
        treatment_suggestions = get_meditron_suggestions(patient_info)

        # Combine responses
        final_response = f"{diagnosis}\n\n--- TREATMENT RECOMMENDATIONS ---\n\n{treatment_suggestions}\n\n**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice."

        # Save final response to memory
        memory.save_context({"input": message}, {"output": final_response})
        return final_response
# Create the Gradio interface
demo = gr.ChatInterface(
    fn=generate_response,
    title="🩺 AI Medical Assistant",
    description="I'll ask for your basic information first, then gather details about your symptoms to provide medical insights.",
    examples=[
        "I have a persistent cough",
        "I've been having headaches",
        "My stomach hurts"
    ],
    theme="soft"
)

if __name__ == "__main__":
    demo.launch()