techindia2025 committed
Commit 1728da9 · verified · 1 Parent(s): f3b4260

Update app.py

Files changed (1)
  1. app.py +239 -188
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
  import spaces
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer
- from langchain.memory import ConversationBufferMemory
  import re

  # Model configuration
@@ -26,229 +25,281 @@ SYSTEM_PROMPT = """You are a professional virtual doctor conducting a medical co
  - Build each question logically from their previous responses

  After 4-5 meaningful exchanges, provide assessment and recommendations.
-
  Do NOT make specific prescriptions for prescription-only drugs.
-
  Always maintain a professional, caring tone throughout the consultation."""

- MEDITRON_PROMPT = """<|im_start|>system
- You are a board-certified physician with extensive clinical experience. Your role is to provide evidence-based medical assessment and recommendations following standard medical practice.
+ MEDITRON_PROMPT = """You are a board-certified physician providing evidence-based medical assessment.

- For each patient case:
- 1. Analyze presented symptoms systematically using medical terminology
- 2. Create a structured differential diagnosis with most likely conditions first
+ Based on the patient information provided, please:
+ 1. Analyze the symptoms systematically
+ 2. Provide a differential diagnosis with most likely conditions
  3. Recommend appropriate next steps (testing, monitoring, or treatment)
- 4. Provide specific medication recommendations with precise dosing regimens
- 5. Include clear red flags that would necessitate urgent medical attention
- 6. Base all recommendations on current clinical guidelines and evidence-based medicine
- 7. Maintain professional, clear, and compassionate communication
-
- Follow standard clinical documentation format when appropriate and prioritize patient safety at all times. Remember to include appropriate medical disclaimers.
- <|im_start|>user
- Patient information: {patient_info}
- <|im_end|>
- <|im_start|>assistant
- """
+ 4. Suggest appropriate medications or remedies with dosing if applicable
+ 5. Include red flags that would require urgent medical attention
+ 6. Base recommendations on clinical guidelines

- print("Loading Llama-2 model...")
- tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
- model = AutoModelForCausalLM.from_pretrained(
-     LLAMA_MODEL,
-     torch_dtype=torch.float16,
-     device_map="auto"
- )
- print("Llama-2 model loaded successfully!")
-
- print("Loading Meditron model...")
- meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
- meditron_model = AutoModelForCausalLM.from_pretrained(
-     MEDITRON_MODEL,
-     torch_dtype=torch.float16,
-     device_map="auto"
- )
- print("Meditron model loaded successfully!")
-
- # Initialize LangChain memory for conversation tracking
- memory = ConversationBufferMemory(return_messages=True)
-
- # Simple state for basic info tracking
- conversation_state = {
-     'name': None,
-     'age': None,
-     'medical_turns': 0,
-     'has_name': False,
-     'has_age': False
- }
-
- def get_meditron_suggestions(patient_info):
-     """Use Meditron model to generate medicine and remedy suggestions."""
-     prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
-     inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)
-
-     with torch.no_grad():
-         outputs = meditron_model.generate(
-             inputs.input_ids,
-             attention_mask=inputs.attention_mask,
-             max_new_tokens=256,
-             temperature=0.7,
-             top_p=0.9,
-             do_sample=True
-         )
-
-     suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-     return suggestion
-
- def build_prompt_with_memory(system_prompt, current_input):
-     """Build prompt using LangChain memory for full conversation context"""
-     prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
-
-     # Get conversation history from memory
-     messages = memory.chat_memory.messages
-
-     # Add conversation history to prompt
-     for msg in messages:
-         if msg.type == "human":
-             prompt += f"{msg.content} [/INST] "
-         elif msg.type == "ai":
-             prompt += f"{msg.content} </s><s>[INST] "
-
-     # Add current input
-     prompt += f"{current_input} [/INST] "
-
-     return prompt
+ Patient Information: {patient_info}

- @spaces.GPU
- def generate_response(message, history):
-     """Generate a response using LangChain ConversationBufferMemory."""
-     global conversation_state
-
-     # Reset state if this is a new conversation
-     if not history:
-         conversation_state = {
-             'name': None,
-             'age': None,
-             'medical_turns': 0,
-             'has_name': False,
-             'has_age': False
-         }
-         # Clear memory for new conversation
-         memory.clear()
-
-     # Save current user message to memory (we'll save bot response later)
-     memory.save_context({"input": message}, {"output": ""})
-
-     # Step 1: Ask for name if not provided
-     if not conversation_state['has_name']:
-         conversation_state['has_name'] = True
-         bot_response = "Hello! Before we discuss your health concerns, could you please tell me your name?"
-         # Update memory with bot response
-         memory.save_context({"input": message}, {"output": bot_response})
-         return bot_response
-
-     # Step 2: Store name and ask for age
-     if conversation_state['name'] is None:
-         conversation_state['name'] = message.strip()
-         bot_response = f"Nice to meet you, {conversation_state['name']}! Could you please tell me your age?"
-         # Update memory with bot response
-         memory.save_context({"input": message}, {"output": bot_response})
-         return bot_response
-
-     # Step 3: Store age and start medical questions
-     if not conversation_state['has_age']:
-         conversation_state['age'] = message.strip()
-         conversation_state['has_age'] = True
-         bot_response = f"Thank you, {conversation_state['name']}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
-         # Update memory with bot response
-         memory.save_context({"input": message}, {"output": bot_response})
-         return bot_response
+ Please provide a structured medical assessment:"""
+
+ # Load models
+ print("Loading models...")
+ try:
+     tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token

-     # Step 4: Medical consultation phase using ConversationBufferMemory
-     conversation_state['medical_turns'] += 1
+     model = AutoModelForCausalLM.from_pretrained(
+         LLAMA_MODEL,
+         torch_dtype=torch.float16,
+         device_map="auto"
+     )
+     print("Llama-2 model loaded successfully!")

-     # Build the prompt using memory for full conversation context
-     if conversation_state['medical_turns'] <= 5:
-         # Still gathering information - let LLM ask intelligent follow-up questions
-         prompt = build_prompt_with_memory(SYSTEM_PROMPT, message)
+     meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
+     if meditron_tokenizer.pad_token is None:
+         meditron_tokenizer.pad_token = meditron_tokenizer.eos_token

-         # Generate response with intelligent follow-up questions
-         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-         with torch.no_grad():
-             outputs = model.generate(
-                 inputs.input_ids,
-                 attention_mask=inputs.attention_mask,
-                 max_new_tokens=384,
-                 temperature=0.8,
-                 top_p=0.95,
-                 do_sample=True,
-                 pad_token_id=tokenizer.eos_token_id
-             )
+     meditron_model = AutoModelForCausalLM.from_pretrained(
+         MEDITRON_MODEL,
+         torch_dtype=torch.float16,
+         device_map="auto"
+     )
+     print("Meditron model loaded successfully!")
+ except Exception as e:
+     print(f"Error loading models: {e}")
+
+ class MedicalConsultationBot:
+     def __init__(self):
+         self.reset_conversation()
+
+     def reset_conversation(self):
+         """Reset all conversation state"""
+         self.conversation_history = []
+         self.patient_name = None
+         self.patient_age = None
+         self.medical_turns = 0
+         self.stage = "greeting"  # greeting -> name -> age -> symptoms -> diagnosis

-         full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         llama_response = full_response.split('[/INST]')[-1].strip()
+     def add_to_history(self, user_message, bot_response):
+         """Add exchange to conversation history"""
+         self.conversation_history.append({
+             "user": user_message,
+             "bot": bot_response
+         })
+
+     def get_conversation_context(self):
+         """Get full conversation context as string"""
+         context = ""
+         if self.patient_name:
+             context += f"Patient Name: {self.patient_name}\n"
+         if self.patient_age:
+             context += f"Patient Age: {self.patient_age}\n"

-         # Save bot response to memory
-         memory.save_context({"input": message}, {"output": llama_response})
+         context += "\nConversation History:\n"
+         for exchange in self.conversation_history:
+             context += f"Patient: {exchange['user']}\n"
+             context += f"Doctor: {exchange['bot']}\n"

-         return llama_response
+         return context

-     else:
-         # Time for diagnosis and treatment (after 5+ turns)
-         # Get all conversation messages from memory
-         all_messages = memory.chat_memory.messages
-
-         # Compile patient information from memory
-         patient_info = f"Patient: {conversation_state['name']}, Age: {conversation_state['age']}\n\n"
-         patient_info += "Complete Conversation History:\n"
+     def build_llama_prompt(self, current_message):
+         """Build prompt for Llama model with conversation context"""
+         prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"

-         # Add all messages from memory
-         for msg in all_messages:
-             if msg.type == "human":
-                 patient_info += f"Patient: {msg.content}\n"
-             elif msg.type == "ai":
-                 patient_info += f"Doctor: {msg.content}\n"
+         # Add conversation context
+         context = self.get_conversation_context()
+         if context.strip():
+             prompt += f"Previous conversation context:\n{context}\n\n"

-         patient_info += f"Current: {message}\n"
+         prompt += f"Current patient message: {current_message}\n\nProvide a professional medical response with appropriate follow-up questions. [/INST]"

-         # Generate diagnosis with full conversation context
-         diagnosis_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\nBased on the complete conversation history, please provide a comprehensive medical assessment including likely diagnosis and recommendations for {conversation_state['name']}.\n\nComplete Patient Information:\n{patient_info} [/INST] "
+         return prompt
+
+ # Global bot instance
+ medical_bot = MedicalConsultationBot()
+
+ def get_meditron_diagnosis(patient_info):
+     """Use Meditron model to generate medical assessment"""
+     try:
+         prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
+         inputs = meditron_tokenizer(
+             prompt,
+             return_tensors="pt",
+             max_length=512,
+             truncation=True
+         ).to(meditron_model.device)

-         inputs = tokenizer(diagnosis_prompt, return_tensors="pt").to(model.device)
          with torch.no_grad():
-             outputs = model.generate(
+             outputs = meditron_model.generate(
                  inputs.input_ids,
                  attention_mask=inputs.attention_mask,
-                 max_new_tokens=384,
+                 max_new_tokens=300,
                  temperature=0.7,
                  top_p=0.9,
                  do_sample=True,
-                 pad_token_id=tokenizer.eos_token_id
+                 pad_token_id=meditron_tokenizer.pad_token_id
              )

-         full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         diagnosis = full_response.split('[/INST]')[-1].strip()
+         response = meditron_tokenizer.decode(
+             outputs[0][inputs.input_ids.shape[1]:],
+             skip_special_tokens=True
+         ).strip()

-         # Get treatment suggestions from Meditron using memory context
-         treatment_suggestions = get_meditron_suggestions(patient_info)
-
-         # Combine responses
-         final_response = f"{diagnosis}\n\n--- TREATMENT RECOMMENDATIONS ---\n\n{treatment_suggestions}\n\n**Important:** These are general recommendations. Please consult with a healthcare professional for personalized medical advice."
+         return response
+     except Exception as e:
+         return f"Error generating medical assessment: {str(e)}"
+
+ @spaces.GPU
+ def medical_chat_response(message, history):
+     """Main chat response function with proper state management"""
+     global medical_bot
+
+     # If this is a new conversation (empty history), reset the bot
+     if not history:
+         medical_bot.reset_conversation()
+
+     user_message = message.strip()
+
+     # Stage 1: Initial greeting and ask for name
+     if medical_bot.stage == "greeting":
+         bot_response = "Hello! I'm your AI medical assistant. Before we discuss your health concerns, could you please tell me your name?"
+         medical_bot.stage = "name"
+         medical_bot.add_to_history(user_message, bot_response)
+         return bot_response
+
+     # Stage 2: Collect name and ask for age
+     elif medical_bot.stage == "name":
+         medical_bot.patient_name = user_message
+         bot_response = f"Nice to meet you, {medical_bot.patient_name}! Could you please tell me your age?"
+         medical_bot.stage = "age"
+         medical_bot.add_to_history(user_message, bot_response)
+         return bot_response
+
+     # Stage 3: Collect age and start medical consultation
+     elif medical_bot.stage == "age":
+         medical_bot.patient_age = user_message
+         bot_response = f"Thank you, {medical_bot.patient_name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
+         medical_bot.stage = "symptoms"
+         medical_bot.add_to_history(user_message, bot_response)
+         return bot_response
+
+     # Stage 4: Medical consultation - gather symptoms with intelligent follow-ups
+     elif medical_bot.stage == "symptoms":
+         medical_bot.medical_turns += 1

-         # Save final response to memory
-         memory.save_context({"input": message}, {"output": final_response})
+         # If we've had enough turns, move to diagnosis
+         if medical_bot.medical_turns >= 4:
+             medical_bot.stage = "diagnosis"
+             return generate_final_diagnosis(user_message)

-         return final_response
+         # Generate intelligent follow-up questions
+         try:
+             prompt = medical_bot.build_llama_prompt(user_message)
+             inputs = tokenizer(
+                 prompt,
+                 return_tensors="pt",
+                 max_length=1024,
+                 truncation=True
+             ).to(model.device)
+
+             with torch.no_grad():
+                 outputs = model.generate(
+                     inputs.input_ids,
+                     attention_mask=inputs.attention_mask,
+                     max_new_tokens=200,
+                     temperature=0.8,
+                     top_p=0.95,
+                     do_sample=True,
+                     pad_token_id=tokenizer.pad_token_id
+                 )
+
+             full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+             bot_response = full_response.split('[/INST]')[-1].strip()
+
+             # Clean up the response
+             bot_response = bot_response.replace('<s>', '').replace('</s>', '').strip()
+
+             medical_bot.add_to_history(user_message, bot_response)
+             return bot_response
+
+         except Exception as e:
+             bot_response = f"I understand. Could you tell me more about how long you've been experiencing this and if there are any specific triggers or patterns you've noticed?"
+             medical_bot.add_to_history(user_message, bot_response)
+             return bot_response
+
+     # Stage 5: Final diagnosis and treatment recommendations
+     elif medical_bot.stage == "diagnosis":
+         return generate_final_diagnosis(user_message)
+
+     # Handle any questions after diagnosis
+     else:
+         # Check if they're asking about their name or previous information
+         if "name" in user_message.lower() and medical_bot.patient_name:
+             return f"Your name is {medical_bot.patient_name}."
+         elif "age" in user_message.lower() and medical_bot.patient_age:
+             return f"You told me you are {medical_bot.patient_age} years old."
+         else:
+             return "Is there anything else about your health concerns I can help you with today?"
+
+ def generate_final_diagnosis(current_message):
+     """Generate final diagnosis using both models"""
+     global medical_bot
+
+     # Add current message to history
+     medical_bot.add_to_history(current_message, "")
+
+     # Compile complete patient information
+     patient_info = f"""
+ Patient Name: {medical_bot.patient_name}
+ Patient Age: {medical_bot.patient_age}
+
+ Complete Consultation History:
+ """
+
+     for exchange in medical_bot.conversation_history[:-1]:  # Exclude the empty last entry
+         patient_info += f"Doctor: {exchange['bot']}\n"
+         patient_info += f"Patient: {exchange['user']}\n"
+
+     patient_info += f"Patient: {current_message}\n"
+
+     # Get diagnosis from Meditron
+     meditron_assessment = get_meditron_diagnosis(patient_info)
+
+     # Generate comprehensive response
+     final_response = f"""Thank you for providing all this information, {medical_bot.patient_name}. Based on our consultation, here is my assessment:
+
+ **MEDICAL ASSESSMENT & RECOMMENDATIONS:**
+
+ {meditron_assessment}
+
+ **IMPORTANT DISCLAIMER:** This assessment is for informational purposes only and should not replace professional medical advice. Please consult with a healthcare provider for proper diagnosis and treatment.
+
+ **NEXT STEPS:** I recommend scheduling an appointment with your primary care physician or appropriate specialist for further evaluation and personalized treatment.
+
+ Is there anything specific about this assessment you'd like me to clarify?"""
+
+     # Update conversation history with final response
+     medical_bot.conversation_history[-1]["bot"] = final_response
+     medical_bot.stage = "complete"
+
+     return final_response

- # Create the Gradio interface
+ # Create Gradio interface
  demo = gr.ChatInterface(
-     fn=generate_response,
-     title="🩺 AI Medical Assistant",
-     description="I'll ask for your basic information first, then gather details about your symptoms to provide medical insights.",
+     fn=medical_chat_response,
+     title="🩺 AI Medical Assistant with Memory",
+     description="I'm an AI medical assistant that will remember our conversation. I'll first ask for your basic information, then gather details about your symptoms through intelligent follow-up questions, and finally provide a medical assessment.",
      examples=[
+         "Hello, I need medical help",
          "I have a persistent cough",
          "I've been having headaches",
          "My stomach hurts"
      ],
-     theme="soft"
+     theme="soft",
+     retry_btn=None,
+     undo_btn=None,
+     clear_btn="🔄 Start New Consultation"
  )

  if __name__ == "__main__":
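The staged flow added above replaces the deleted ConversationBufferMemory with a plain state machine, so the name/age stages can be exercised without loading either model: the GPU-backed branches are only reached once the bot enters the "symptoms" stage. A minimal standalone sketch of that state machine (StagedBot is a hypothetical trim of MedicalConsultationBot; the model-backed stages are stubbed out, so it runs without torch, transformers, or gradio):

# Sketch only, not part of the commit: a trimmed version of the
# MedicalConsultationBot stage machine from the diff above.
class StagedBot:
    def __init__(self):
        self.conversation_history = []
        self.patient_name = None
        self.patient_age = None
        self.stage = "greeting"  # greeting -> name -> age -> symptoms

    def respond(self, message):
        if self.stage == "greeting":
            self.stage = "name"
            reply = "Hello! Could you please tell me your name?"
        elif self.stage == "name":
            self.patient_name = message.strip()
            self.stage = "age"
            reply = f"Nice to meet you, {self.patient_name}! Your age?"
        elif self.stage == "age":
            self.patient_age = message.strip()
            self.stage = "symptoms"
            reply = f"Thank you, {self.patient_name}! What brings you here today?"
        else:
            # In app.py this branch calls the Llama/Meditron models.
            reply = "(model-backed stage stubbed out in this sketch)"
        self.conversation_history.append({"user": message, "bot": reply})
        return reply

bot = StagedBot()
for turn in ["hi", "Alice", "34", "I have a cough"]:
    print(f"> {turn}\n{bot.respond(turn)}\n")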