File size: 9,562 Bytes
b80af5b
71bcd31
9f6ac99
 
c4447f4
000ab02
71bcd31
 
 
 
 
43e5827
6e237a4
a985489
71bcd31
 
 
 
 
 
 
6e237a4
43e5827
6e237a4
 
 
 
 
71bcd31
 
 
5522bf8
 
 
 
 
 
 
 
 
 
 
 
71bcd31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43e5827
 
 
 
 
 
 
 
71bcd31
 
 
 
 
 
 
 
 
 
 
 
 
 
bdce857
 
71bcd31
 
 
43e5827
 
 
a7f6391
43e5827
 
 
a7f6391
43e5827
 
a7f6391
43e5827
a7f6391
aa89cd7
 
43e5827
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7f6391
43e5827
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7f6391
d6da22c
43e5827
 
 
 
a7f6391
43e5827
 
 
 
a7f6391
43e5827
 
a7f6391
43e5827
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7f6391
aa89cd7
c4447f4
71bcd31
6d5190c
71bcd31
43e5827
 
8b29c0d
43e5827
 
 
8b29c0d
71bcd31
6d5190c
b80af5b
 
71bcd31
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.memory import ConversationBufferMemory
import re

# Model configuration: Hugging Face Hub ids of the two checkpoints used by this app.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"

# System prompt for the Llama-2 interview/diagnosis turns. It is inserted
# verbatim between the <<SYS>> markers of the Llama-2 chat prompt.
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.

**IMPORTANT** Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies

After collecting sufficient information, summarize findings, provide a likely diagnosis (if possible), and suggest when they should seek professional care.

If enough information is collected, provide a concise, general diagnosis and a practical over-the-counter medicine and home remedy suggestion.

Do NOT make specific prescriptions for prescription-only drugs.

Respond empathetically and clearly. Always be professional and thorough."""

# ChatML-style template for Meditron. `{patient_info}` is the only
# placeholder and is filled in with str.format(). Every turn that is opened
# with <|im_start|> is closed with <|im_end|>, except the trailing assistant
# turn, which is deliberately left open so the model continues from there.
MEDITRON_PROMPT = """<|im_start|>system
You are a board-certified physician with extensive clinical experience. Your role is to provide evidence-based medical assessment and recommendations following standard medical practice.

For each patient case:
1. Analyze presented symptoms systematically using medical terminology
2. Create a structured differential diagnosis with most likely conditions first
3. Recommend appropriate next steps (testing, monitoring, or treatment)
4. Provide specific medication recommendations with precise dosing regimens
5. Include clear red flags that would necessitate urgent medical attention
6. Base all recommendations on current clinical guidelines and evidence-based medicine
7. Maintain professional, clear, and compassionate communication

Follow standard clinical documentation format when appropriate and prioritize patient safety at all times. Remember to include appropriate medical disclaimers.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""

def _load_tokenizer_and_model(model_id):
    """Load the tokenizer, then the half-precision causal LM, for ``model_id``."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return loaded_tokenizer, loaded_model


# Both checkpoints are loaded eagerly when the module is imported; the
# progress messages bracket each load.
print("Loading Llama-2 model...")
tokenizer, model = _load_tokenizer_and_model(LLAMA_MODEL)
print("Llama-2 model loaded successfully!")

print("Loading Meditron model...")
meditron_tokenizer, meditron_model = _load_tokenizer_and_model(MEDITRON_MODEL)
print("Meditron model loaded successfully!")

# Simple conversation state tracking.
# Holds the patient's name and age, the number of medical turns so far, and
# two flags recording whether name/age have been handled. The dict lives at
# module scope, so there is exactly one instance per process.
conversation_state = dict(
    name=None,
    age=None,
    medical_turns=0,
    has_name=False,
    has_age=False,
)

def get_meditron_suggestions(patient_info):
    """Generate medicine and remedy suggestions with the Meditron model.

    Args:
        patient_info: Plain-text patient summary; substituted into the
            ``{patient_info}`` placeholder of ``MEDITRON_PROMPT``.

    Returns:
        Only the newly generated text (the prompt tokens are sliced off
        before decoding), cut at the first ChatML turn marker if the model
        emitted one as literal text, with surrounding whitespace stripped.
    """
    prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)
    prompt_length = inputs.input_ids.shape[1]

    with torch.no_grad():
        outputs = meditron_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Explicit padding id, matching the Llama-2 generate() calls in
            # this file.
            pad_token_id=meditron_tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; keep the continuation only.
    suggestion = meditron_tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )

    # If the model runs on into a new turn, the markers may survive decoding
    # as plain text; everything from the first marker onwards is not part of
    # the answer. Text without markers is left untouched.
    for marker in ("<|im_end|>", "<|im_start|>"):
        suggestion = suggestion.split(marker, 1)[0]

    return suggestion.strip()

def build_simple_prompt(system_prompt, conversation_history, current_input):
    """Build a Llama-2 chat prompt from a system prompt, history and new input.

    Args:
        system_prompt: Text placed between the ``<<SYS>>`` markers.
        conversation_history: Iterable of ``(user_msg, bot_msg)`` pairs, oldest
            first. ``None`` or an empty iterable means "no previous turns".
            A ``None`` message inside a pair is rendered as an empty string.
        current_input: The user message the model should answer next.

    Returns:
        The prompt string, ending with an open ``[/INST] `` so that the model
        continues with the assistant reply.
    """
    parts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]

    # Each finished exchange closes its own sequence and opens the next
    # [INST] block, which is then filled by the following user message.
    for user_msg, bot_msg in conversation_history or ():
        user_text = "" if user_msg is None else user_msg
        bot_text = "" if bot_msg is None else bot_msg
        parts.append(f"{user_text} [/INST] {bot_text} </s><s>[INST] ")

    parts.append(f"{current_input} [/INST] ")
    return "".join(parts)

# Canned follow-up questions. Exactly one is appended to the model reply on
# each information-gathering turn, in this order.
_FOLLOWUP_QUESTIONS = (
    "Can you describe your symptoms in more detail? What exactly are you experiencing?",
    "How long have you been experiencing these symptoms? When did they first start?",
    "On a scale of 1-10, how would you rate the severity of your symptoms?",
    "Have you noticed anything that makes your symptoms better or worse?",
    "Do you have any other symptoms, medical history, or are you taking any medications?",
)

# Exchanges that precede the medical phase: greeting, name, age.
_INTAKE_EXCHANGES = 3


def _history_to_pairs(history):
    """Normalize chat history to a list of ``(user_text, bot_text)`` strings.

    Accepts pair-style history (a sequence of ``[user, bot]`` items) as well
    as role/content dict history (``{"role": ..., "content": ...}``).
    ``None`` values become empty strings.
    """
    def as_text(value):
        return "" if value is None else str(value)

    pairs = []
    pending_user = None
    for item in history or []:
        if isinstance(item, dict):
            role = item.get("role")
            content = as_text(item.get("content"))
            if role == "user":
                if pending_user is not None:
                    # Previous user message never got an answer.
                    pairs.append((pending_user, ""))
                pending_user = content
            elif role == "assistant":
                pairs.append((pending_user or "", content))
                pending_user = None
        else:
            user_msg, bot_msg = item
            pairs.append((as_text(user_msg), as_text(bot_msg)))
    if pending_user is not None:
        pairs.append((pending_user, ""))
    return pairs


def _generate_llama_reply(prompt, max_new_tokens):
    """Sample a Llama-2 continuation of ``prompt`` and return only the new text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode the continuation only (same approach as get_meditron_suggestions)
    # instead of splitting the full text on '[/INST]', which would truncate
    # the reply whenever the generated text itself contains that marker.
    # NOTE(review): the prompt starts with a literal "<s>" and the tokenizer
    # may add its own BOS as well - confirm whether add_special_tokens=False
    # is wanted here.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


@spaces.GPU
def generate_response(message, history):
    """Produce the assistant's next chat message.

    The conversation step is derived from ``history`` alone, so concurrent
    chat sessions cannot overwrite each other's name, age or turn counter:

    * 0 completed exchanges: ask for the name (the first message is ignored).
    * 1 completed exchange: ``message`` is the name; ask for the age.
    * 2 completed exchanges: ``message`` is the age; ask for the symptoms.
    * next ``len(_FOLLOWUP_QUESTIONS)`` turns: Llama-2 reply plus one canned
      follow-up question.
    * afterwards: Llama-2 assessment plus Meditron treatment suggestions.

    Args:
        message: The user's latest message.
        history: Previous exchanges, excluding ``message``.

    Returns:
        The assistant reply as a string.
    """
    exchanges = _history_to_pairs(history)
    completed = len(exchanges)

    name = None
    age = None
    if completed == 1:
        name = message.strip()
    elif completed >= 2:
        # The user's message in exchange 1 was the name, in exchange 2 the age.
        name = exchanges[1][0].strip()
        age = message.strip() if completed == 2 else exchanges[2][0].strip()
    medical_turns = max(completed - _INTAKE_EXCHANGES + 1, 0)

    # Kept only as a snapshot of the most recent call for anything that still
    # reads the module-level dict; no decision below depends on it.
    conversation_state.update({
        'name': name,
        'age': age,
        'medical_turns': medical_turns,
        'has_name': True,
        'has_age': age is not None,
    })

    # Step 1: Ask for name
    if completed == 0:
        return "Hello! Before we discuss your health concerns, could you please tell me your name?"

    # Step 2: Name received, ask for age
    if completed == 1:
        return f"Nice to meet you, {name}! Could you please tell me your age?"

    # Step 3: Age received, start medical questions
    if completed == 2:
        return f"Thank you, {name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."

    # Step 4: Medical consultation phase (name/age exchanges are skipped)
    medical_history = exchanges[_INTAKE_EXCHANGES:]

    if medical_turns <= len(_FOLLOWUP_QUESTIONS):
        # Still gathering information. Turn numbers are 1-based, the question
        # list is 0-based; indexing with the raw turn number would skip the
        # first question and leave the last gathering turn without one.
        prompt = build_simple_prompt(SYSTEM_PROMPT, medical_history, message)
        llama_response = _generate_llama_reply(prompt, max_new_tokens=256)
        next_question = _FOLLOWUP_QUESTIONS[medical_turns - 1]
        return f"{llama_response}\n\n{next_question}"

    # Time for diagnosis and treatment: compile everything the patient said
    # during the medical phase, including the current message.
    patient_statements = [user_msg for user_msg, _ in medical_history]
    patient_statements.append(message)
    patient_info = (
        f"Patient: {name}, Age: {age}\n\n"
        "Symptoms and Information:\n"
        + "".join(f"Patient: {statement}\n" for statement in patient_statements)
    )

    # Generate diagnosis with Llama-2
    diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on all the information provided, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {name}.\n\nPatient Information:\n{patient_info} [/INST] "
    diagnosis = _generate_llama_reply(diagnosis_prompt, max_new_tokens=384)

    # Get treatment suggestions from Meditron
    treatment_suggestions = get_meditron_suggestions(patient_info)

    # Combine responses
    return f"{diagnosis}\n\n--- TREATMENT RECOMMENDATIONS ---\n\n{treatment_suggestions}\n\n**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice."

# Create the Gradio interface.
# The example prompts and the remaining options are collected first and then
# handed to ChatInterface in a single call.
_example_messages = [
    "I have a persistent cough",
    "I've been having headaches",
    "My stomach hurts"
]

_chat_interface_options = dict(
    fn=generate_response,
    title="🩺 AI Medical Assistant",
    description="I'll ask for your basic information first, then gather details about your symptoms to provide medical insights.",
    examples=_example_messages,
    theme="soft",
)

demo = gr.ChatInterface(**_chat_interface_options)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()