# AI Medical Assistant: app for a Hugging Face Space running on ZeroGPU.
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.memory import ConversationBufferMemory
import re
# Model configuration
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"
SYSTEM_PROMPT = """You are a professional virtual doctor conducting a medical consultation. Your role is to gather comprehensive information about the patient's condition through intelligent questioning.
**CONSULTATION APPROACH:**
- Ask thoughtful, relevant follow-up questions based on the patient's responses
- Prioritize gathering information about: symptom details, duration, severity, triggers, related symptoms, medical history, medications, and lifestyle factors
- Ask 1-2 specific questions at a time that build naturally on their previous answers
- Be empathetic, professional, and thorough in your questioning
- Adapt your questions based on the symptoms they describe
**IMPORTANT GUIDELINES:**
- Generate intelligent follow-up questions that are contextually relevant to their specific symptoms
- Don't ask generic questions - tailor each question to their particular situation
- If they mention pain, ask about location, type, and triggers
- If they mention duration, ask about progression or changes
- Build each question logically from their previous responses
After 4-5 meaningful exchanges, provide assessment and recommendations.
Do NOT make specific prescriptions for prescription-only drugs.
Always maintain a professional, caring tone throughout the consultation."""
MEDITRON_PROMPT = """<|im_start|>system
You are a board-certified physician with extensive clinical experience. Your role is to provide evidence-based medical assessment and recommendations following standard medical practice.
For each patient case:
1. Analyze presented symptoms systematically using medical terminology
2. Create a structured differential diagnosis with most likely conditions first
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Provide specific medication recommendations with precise dosing regimens
5. Include clear red flags that would necessitate urgent medical attention
6. Base all recommendations on current clinical guidelines and evidence-based medicine
7. Maintain professional, clear, and compassionate communication
Follow standard clinical documentation format when appropriate and prioritize patient safety at all times. Remember to include appropriate medical disclaimers.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""
print("Loading Llama-2 model...")
tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
model = AutoModelForCausalLM.from_pretrained(
    LLAMA_MODEL,
    torch_dtype=torch.float16,
    device_map="auto"
)
print("Llama-2 model loaded successfully!")
print("Loading Meditron model...")
meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
meditron_model = AutoModelForCausalLM.from_pretrained(
    MEDITRON_MODEL,
    torch_dtype=torch.float16,
    device_map="auto"
)
print("Meditron model loaded successfully!")
# Initialize LangChain memory for conversation tracking
memory = ConversationBufferMemory(return_messages=True)
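# ConversationBufferMemory keeps the entire transcript, so very long consultations
# could overflow Llama-2's 4k-token context. A possible alternative (not used here)
# is LangChain's ConversationBufferWindowMemory, which keeps only the last k turns:
#
#     from langchain.memory import ConversationBufferWindowMemory
#     memory = ConversationBufferWindowMemory(k=10, return_messages=True)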
# Simple state for basic info tracking
conversation_state = {
    'name': None,
    'age': None,
    'medical_turns': 0,
    'has_name': False,
    'has_age': False
}
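# The `re` import above is unused in the original flow; a minimal sketch of how it
# could validate the age reply in Step 3 (hypothetical helper, not wired in):
def extract_age(text):
    """Return the first 1-3 digit number in `text` as a string, or None if absent."""
    match = re.search(r"\b(\d{1,3})\b", text)  # assumes the age is given as plain digits
    return match.group(1) if match else None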
def get_meditron_suggestions(patient_info):
    """Use Meditron model to generate medicine and remedy suggestions."""
    prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)
    with torch.no_grad():
        outputs = meditron_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
    # Decode only the newly generated tokens, slicing off the prompt
    suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return suggestion
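# Example of calling the helper directly as a local smoke test
# (hypothetical patient data, not part of the Gradio flow):
#
#     info = "Patient: Jane, Age: 34\n\nPatient: dry cough and mild fever for two weeks"
#     print(get_meditron_suggestions(info))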
def build_prompt_with_memory(system_prompt, current_input):
    """Build prompt using LangChain memory for full conversation context."""
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"

    # Add conversation history from memory in Llama-2 chat format
    messages = memory.chat_memory.messages
    for msg in messages:
        if msg.type == "human":
            prompt += f"{msg.content} [/INST] "
        elif msg.type == "ai":
            prompt += f"{msg.content} </s><s>[INST] "

    # Add current input
    prompt += f"{current_input} [/INST] "
    return prompt
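# For reference, the assembled string follows the Llama-2 chat template, e.g.
# (illustrative turns, not real data):
#
#     <s>[INST] <<SYS>>
#     ...system prompt...
#     <</SYS>>
#
#     hi [/INST] Hello! Could you tell me your name? </s><s>[INST] John [/INST]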
@spaces.GPU
def generate_response(message, history):
    """Generate a response using LangChain ConversationBufferMemory."""
    global conversation_state

    # Reset state if this is a new conversation
    if not history:
        conversation_state = {
            'name': None,
            'age': None,
            'medical_turns': 0,
            'has_name': False,
            'has_age': False
        }
        # Clear memory for new conversation
        memory.clear()
    # Step 1: Ask for name if not provided
    if not conversation_state['has_name']:
        conversation_state['has_name'] = True
        bot_response = "Hello! Before we discuss your health concerns, could you please tell me your name?"
        # Save this exchange to memory
        memory.save_context({"input": message}, {"output": bot_response})
        return bot_response

    # Step 2: Store name and ask for age
    if conversation_state['name'] is None:
        conversation_state['name'] = message.strip()
        bot_response = f"Nice to meet you, {conversation_state['name']}! Could you please tell me your age?"
        memory.save_context({"input": message}, {"output": bot_response})
        return bot_response

    # Step 3: Store age and start medical questions
    if not conversation_state['has_age']:
        conversation_state['age'] = message.strip()
        conversation_state['has_age'] = True
        bot_response = f"Thank you, {conversation_state['name']}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
        memory.save_context({"input": message}, {"output": bot_response})
        return bot_response
    # Step 4: Medical consultation phase using ConversationBufferMemory
    conversation_state['medical_turns'] += 1

    if conversation_state['medical_turns'] <= 5:
        # Still gathering information - build the prompt from memory so the
        # LLM can ask intelligent, contextually relevant follow-up questions
        prompt = build_prompt_with_memory(SYSTEM_PROMPT, message)

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=384,
                temperature=0.8,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        llama_response = full_response.split('[/INST]')[-1].strip()

        # Save bot response to memory
        memory.save_context({"input": message}, {"output": llama_response})
        return llama_response
    else:
        # Time for diagnosis and treatment (after 5+ turns):
        # compile patient information from the full memory transcript
        all_messages = memory.chat_memory.messages
        patient_info = f"Patient: {conversation_state['name']}, Age: {conversation_state['age']}\n\n"
        patient_info += "Complete Conversation History:\n"
        for msg in all_messages:
            if msg.type == "human":
                patient_info += f"Patient: {msg.content}\n"
            elif msg.type == "ai":
                patient_info += f"Doctor: {msg.content}\n"
        patient_info += f"Current: {message}\n"

        # Generate diagnosis with full conversation context
        diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on the complete conversation history, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {conversation_state['name']}.\n\nComplete Patient Information:\n{patient_info} [/INST] "

        inputs = tokenizer(diagnosis_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=384,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        diagnosis = full_response.split('[/INST]')[-1].strip()

        # Get treatment suggestions from Meditron using the compiled context
        treatment_suggestions = get_meditron_suggestions(patient_info)

        # Combine responses
        final_response = f"{diagnosis}\n\n--- TREATMENT RECOMMENDATIONS ---\n\n{treatment_suggestions}\n\n**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice."

        # Save final response to memory
        memory.save_context({"input": message}, {"output": final_response})
        return final_response
# Create the Gradio interface
demo = gr.ChatInterface(
    fn=generate_response,
    title="🩺 AI Medical Assistant",
    description="I'll ask for your basic information first, then gather details about your symptoms to provide medical insights.",
    examples=[
        "I have a persistent cough",
        "I've been having headaches",
        "My stomach hurts"
    ],
    theme="soft"
)
if __name__ == "__main__":
demo.launch() |