techindia2025 committed on
Commit
afe76d4
·
verified ·
1 Parent(s): 1728da9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +214 -263
app.py CHANGED
@@ -1,306 +1,257 @@
1
  import gradio as gr
2
- import spaces
3
  import torch
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
- import re
 
 
 
6
 
7
  # Model configuration
8
  LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
9
  MEDITRON_MODEL = "epfl-llm/meditron-7b"
10
 
11
- SYSTEM_PROMPT = """You are a professional virtual doctor conducting a medical consultation. Your role is to gather comprehensive information about the patient's condition through intelligent questioning.
 
 
 
 
 
 
 
 
 
 
12
 
13
- **CONSULTATION APPROACH:**
14
- - Ask thoughtful, relevant follow-up questions based on the patient's responses
15
- - Prioritize gathering information about: symptom details, duration, severity, triggers, related symptoms, medical history, medications, and lifestyle factors
16
- - Ask 1-2 specific questions at a time that build naturally on their previous answers
17
- - Be empathetic, professional, and thorough in your questioning
18
- - Adapt your questions based on the symptoms they describe
19
-
20
- **IMPORTANT GUIDELINES:**
21
- - Generate intelligent follow-up questions that are contextually relevant to their specific symptoms
22
- - Don't ask generic questions - tailor each question to their particular situation
23
- - If they mention pain, ask about location, type, and triggers
24
- - If they mention duration, ask about progression or changes
25
- - Build each question logically from their previous responses
26
-
27
- After 4-5 meaningful exchanges, provide assessment and recommendations.
28
- Do NOT make specific prescriptions for prescription-only drugs.
29
- Always maintain a professional, caring tone throughout the consultation."""
30
-
31
- MEDITRON_PROMPT = """You are a board-certified physician providing evidence-based medical assessment.
32
-
33
- Based on the patient information provided, please:
34
- 1. Analyze the symptoms systematically
35
- 2. Provide a differential diagnosis with most likely conditions
36
- 3. Recommend appropriate next steps (testing, monitoring, or treatment)
37
- 4. Suggest appropriate medications or remedies with dosing if applicable
38
- 5. Include red flags that would require urgent medical attention
39
- 6. Base recommendations on clinical guidelines
40
-
41
- Patient Information: {patient_info}
42
-
43
- Please provide a structured medical assessment:"""
44
 
45
- # Load models
46
- print("Loading models...")
47
- try:
48
- tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
49
- if tokenizer.pad_token is None:
50
- tokenizer.pad_token = tokenizer.eos_token
51
-
52
- model = AutoModelForCausalLM.from_pretrained(
53
- LLAMA_MODEL,
54
- torch_dtype=torch.float16,
55
- device_map="auto"
56
- )
57
- print("Llama-2 model loaded successfully!")
58
 
59
- meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
60
- if meditron_tokenizer.pad_token is None:
61
- meditron_tokenizer.pad_token = meditron_tokenizer.eos_token
62
-
63
- meditron_model = AutoModelForCausalLM.from_pretrained(
64
- MEDITRON_MODEL,
65
- torch_dtype=torch.float16,
66
- device_map="auto"
67
- )
68
- print("Meditron model loaded successfully!")
69
- except Exception as e:
70
- print(f"Error loading models: {e}")
71
 
72
- class MedicalConsultationBot:
73
- def __init__(self):
74
- self.reset_conversation()
 
75
 
76
- def reset_conversation(self):
77
- """Reset all conversation state"""
78
- self.conversation_history = []
79
- self.patient_name = None
80
- self.patient_age = None
81
- self.medical_turns = 0
82
- self.stage = "greeting" # greeting -> name -> age -> symptoms -> diagnosis
83
-
84
- def add_to_history(self, user_message, bot_response):
85
- """Add exchange to conversation history"""
86
- self.conversation_history.append({
87
- "user": user_message,
88
- "bot": bot_response
89
- })
90
-
91
- def get_conversation_context(self):
92
- """Get full conversation context as string"""
93
- context = ""
94
- if self.patient_name:
95
- context += f"Patient Name: {self.patient_name}\n"
96
- if self.patient_age:
97
- context += f"Patient Age: {self.patient_age}\n"
98
-
99
- context += "\nConversation History:\n"
100
- for exchange in self.conversation_history:
101
- context += f"Patient: {exchange['user']}\n"
102
- context += f"Doctor: {exchange['bot']}\n"
103
-
104
- return context
105
-
106
- def build_llama_prompt(self, current_message):
107
- """Build prompt for Llama model with conversation context"""
108
- prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
109
-
110
- # Add conversation context
111
- context = self.get_conversation_context()
112
- if context.strip():
113
- prompt += f"Previous conversation context:\n{context}\n\n"
114
-
115
- prompt += f"Current patient message: {current_message}\n\nProvide a professional medical response with appropriate follow-up questions. [/INST]"
116
-
117
- return prompt
118
-
119
- # Global bot instance
120
- medical_bot = MedicalConsultationBot()
121
 
122
- def get_meditron_diagnosis(patient_info):
123
- """Use Meditron model to generate medical assessment"""
124
- try:
125
- prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
126
- inputs = meditron_tokenizer(
127
- prompt,
128
- return_tensors="pt",
129
- max_length=512,
130
- truncation=True
131
- ).to(meditron_model.device)
132
-
133
- with torch.no_grad():
134
- outputs = meditron_model.generate(
135
- inputs.input_ids,
136
- attention_mask=inputs.attention_mask,
137
- max_new_tokens=300,
138
- temperature=0.7,
139
- top_p=0.9,
140
- do_sample=True,
141
- pad_token_id=meditron_tokenizer.pad_token_id
142
- )
143
-
144
- response = meditron_tokenizer.decode(
145
- outputs[0][inputs.input_ids.shape[1]:],
146
- skip_special_tokens=True
147
- ).strip()
148
-
149
- return response
150
- except Exception as e:
151
- return f"Error generating medical assessment: {str(e)}"
152
 
153
- @spaces.GPU
154
- def medical_chat_response(message, history):
155
- """Main chat response function with proper state management"""
156
- global medical_bot
157
 
158
- # If this is a new conversation (empty history), reset the bot
159
- if not history:
160
- medical_bot.reset_conversation()
 
 
 
 
 
 
161
 
162
- user_message = message.strip()
 
 
 
 
 
 
 
163
 
164
- # Stage 1: Initial greeting and ask for name
165
- if medical_bot.stage == "greeting":
166
- bot_response = "Hello! I'm your AI medical assistant. Before we discuss your health concerns, could you please tell me your name?"
167
- medical_bot.stage = "name"
168
- medical_bot.add_to_history(user_message, bot_response)
169
- return bot_response
170
 
171
- # Stage 2: Collect name and ask for age
172
- elif medical_bot.stage == "name":
173
- medical_bot.patient_name = user_message
174
- bot_response = f"Nice to meet you, {medical_bot.patient_name}! Could you please tell me your age?"
175
- medical_bot.stage = "age"
176
- medical_bot.add_to_history(user_message, bot_response)
177
- return bot_response
 
 
 
178
 
179
- # Stage 3: Collect age and start medical consultation
180
- elif medical_bot.stage == "age":
181
- medical_bot.patient_age = user_message
182
- bot_response = f"Thank you, {medical_bot.patient_name}! Now, what brings you here today? Please tell me about any symptoms or health concerns you're experiencing."
183
- medical_bot.stage = "symptoms"
184
- medical_bot.add_to_history(user_message, bot_response)
185
- return bot_response
186
 
187
- # Stage 4: Medical consultation - gather symptoms with intelligent follow-ups
188
- elif medical_bot.stage == "symptoms":
189
- medical_bot.medical_turns += 1
190
-
191
- # If we've had enough turns, move to diagnosis
192
- if medical_bot.medical_turns >= 4:
193
- medical_bot.stage = "diagnosis"
194
- return generate_final_diagnosis(user_message)
195
-
196
- # Generate intelligent follow-up questions
197
- try:
198
- prompt = medical_bot.build_llama_prompt(user_message)
199
- inputs = tokenizer(
200
- prompt,
201
- return_tensors="pt",
202
- max_length=1024,
203
- truncation=True
204
- ).to(model.device)
205
-
206
- with torch.no_grad():
207
- outputs = model.generate(
208
- inputs.input_ids,
209
- attention_mask=inputs.attention_mask,
210
- max_new_tokens=200,
211
- temperature=0.8,
212
- top_p=0.95,
213
- do_sample=True,
214
- pad_token_id=tokenizer.pad_token_id
215
- )
216
-
217
- full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
218
- bot_response = full_response.split('[/INST]')[-1].strip()
219
-
220
- # Clean up the response
221
- bot_response = bot_response.replace('<s>', '').replace('</s>', '').strip()
222
-
223
- medical_bot.add_to_history(user_message, bot_response)
224
- return bot_response
225
-
226
- except Exception as e:
227
- bot_response = f"I understand. Could you tell me more about how long you've been experiencing this and if there are any specific triggers or patterns you've noticed?"
228
- medical_bot.add_to_history(user_message, bot_response)
229
- return bot_response
230
 
231
- # Stage 5: Final diagnosis and treatment recommendations
232
- elif medical_bot.stage == "diagnosis":
233
- return generate_final_diagnosis(user_message)
 
 
 
 
 
 
 
234
 
235
- # Handle any questions after diagnosis
236
- else:
237
- # Check if they're asking about their name or previous information
238
- if "name" in user_message.lower() and medical_bot.patient_name:
239
- return f"Your name is {medical_bot.patient_name}."
240
- elif "age" in user_message.lower() and medical_bot.patient_age:
241
- return f"You told me you are {medical_bot.patient_age} years old."
242
- else:
243
- return "Is there anything else about your health concerns I can help you with today?"
244
 
245
- def generate_final_diagnosis(current_message):
246
- """Generate final diagnosis using both models"""
247
- global medical_bot
 
 
248
 
249
- # Add current message to history
250
- medical_bot.add_to_history(current_message, "")
 
 
 
 
 
 
 
 
251
 
252
- # Compile complete patient information
253
- patient_info = f"""
254
- Patient Name: {medical_bot.patient_name}
255
- Patient Age: {medical_bot.patient_age}
256
 
257
- Complete Consultation History:
258
- """
 
 
 
 
 
 
 
 
259
 
260
- for exchange in medical_bot.conversation_history[:-1]: # Exclude the empty last entry
261
- patient_info += f"Doctor: {exchange['bot']}\n"
262
- patient_info += f"Patient: {exchange['user']}\n"
263
 
264
- patient_info += f"Patient: {current_message}\n"
 
265
 
266
- # Get diagnosis from Meditron
267
- meditron_assessment = get_meditron_diagnosis(patient_info)
 
 
 
 
268
 
269
- # Generate comprehensive response
270
- final_response = f"""Thank you for providing all this information, {medical_bot.patient_name}. Based on our consultation, here is my assessment:
271
 
272
- **MEDICAL ASSESSMENT & RECOMMENDATIONS:**
273
-
274
- {meditron_assessment}
275
-
276
- **IMPORTANT DISCLAIMER:** This assessment is for informational purposes only and should not replace professional medical advice. Please consult with a healthcare provider for proper diagnosis and treatment.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
- **NEXT STEPS:** I recommend scheduling an appointment with your primary care physician or appropriate specialist for further evaluation and personalized treatment.
 
279
 
280
- Is there anything specific about this assessment you'd like me to clarify?"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
- # Update conversation history with final response
283
- medical_bot.conversation_history[-1]["bot"] = final_response
284
- medical_bot.stage = "complete"
285
 
286
- return final_response
 
287
 
288
- # Create Gradio interface
289
  demo = gr.ChatInterface(
290
- fn=medical_chat_response,
291
- title="🩺 AI Medical Assistant with Memory",
292
- description="I'm an AI medical assistant that will remember our conversation. I'll first ask for your basic information, then gather details about your symptoms through intelligent follow-up questions, and finally provide a medical assessment.",
293
  examples=[
294
- "Hello, I need medical help",
295
- "I have a persistent cough",
296
- "I've been having headaches",
297
- "My stomach hurts"
298
  ],
299
- theme="soft",
300
- retry_btn=None,
301
- undo_btn=None,
302
- clear_btn="🔄 Start New Consultation"
303
  )
304
 
305
  if __name__ == "__main__":
306
- demo.launch()
 
1
  import gradio as gr
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from typing import Annotated, List, Dict, Any
5
+ from typing_extensions import TypedDict
6
+ from langgraph.graph import StateGraph, START
7
+ from langgraph.graph.message import add_messages
8
 
9
  # Model configuration
10
  LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
11
  MEDITRON_MODEL = "epfl-llm/meditron-7b"
12
 
13
+ SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
14
+ Ask 1-2 follow-up questions at a time to gather more details about:
15
+ - Detailed description of symptoms
16
+ - Duration (when did it start?)
17
+ - Severity (scale of 1-10)
18
+ - Aggravating or alleviating factors
19
+ - Related symptoms
20
+ - Medical history
21
+ - Current medications and allergies
22
+ After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.
23
+ Respond empathetically and clearly. Always be professional and thorough."""
24
 
25
+ MEDITRON_PROMPT = """<|im_start|>system
26
+ You are a specialized medical assistant focusing ONLY on suggesting over-the-counter medicines and home remedies based on patient information.
27
+ Based on the following patient information, provide ONLY:
28
+ 1. One specific over-the-counter medicine with proper adult dosing instructions
29
+ 2. One practical home remedy that might help
30
+ 3. Clear guidance on when to seek professional medical care
31
+ Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
32
+ <|im_end|>
33
+ <|im_start|>user
34
+ Patient information: {patient_info}
35
+ <|im_end|>
36
+ <|im_start|>assistant
37
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ print("Loading Llama-2 model...")
40
+ tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
41
+ if tokenizer.pad_token is None:
42
+ tokenizer.pad_token = tokenizer.eos_token
 
 
 
 
 
 
 
 
 
43
 
44
+ model = AutoModelForCausalLM.from_pretrained(
45
+ LLAMA_MODEL,
46
+ torch_dtype=torch.float16,
47
+ device_map="auto"
48
+ )
49
+ print("Llama-2 model loaded successfully!")
 
 
 
 
 
 
50
 
51
+ print("Loading Meditron model...")
52
+ meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
53
+ if meditron_tokenizer.pad_token is None:
54
+ meditron_tokenizer.pad_token = meditron_tokenizer.eos_token
55
 
56
+ meditron_model = AutoModelForCausalLM.from_pretrained(
57
+ MEDITRON_MODEL,
58
+ torch_dtype=torch.float16,
59
+ device_map="auto"
60
+ )
61
+ print("Meditron model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ # Define the state for our LangGraph
64
class ChatbotState(TypedDict):
    """State dictionary passed between the nodes of the chatbot graph."""

    # Conversation messages; the ``add_messages`` reducer is attached through
    # ``Annotated`` metadata.
    messages: Annotated[List, add_messages]
    # Running count of user turns (incremented by ``process_user_input``).
    turn_count: int
    # Texts of the user's messages collected so far.
    patient_info: List[str]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ # Function to build Llama-2 prompt
70
def build_llama2_prompt(messages):
    """Render the conversation as a single Llama-2 chat prompt string.

    The system prompt opens the first ``[INST]`` block. Every message except
    the last is treated as history, with even positions taken to be user turns
    and odd positions assistant turns. The last message is the pending user
    input and is closed with ``[/INST]`` so the model answers it.

    Args:
        messages: Non-empty sequence of objects exposing a ``content``
            attribute, ordered oldest to newest.

    Returns:
        The formatted prompt string.
    """
    pieces = [f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"]

    # History: user turns close the instruction, assistant turns close the
    # sequence and open the next instruction.
    for position, past_message in enumerate(messages[:-1]):
        is_user_turn = position % 2 == 0
        closing = " [/INST] " if is_user_turn else " </s><s>[INST] "
        pieces.append(f"{past_message.content}{closing}")

    # Pending user input that the model should respond to.
    pieces.append(f"{messages[-1].content} [/INST] ")

    return "".join(pieces)
85
+
86
+ # Function to get Llama-2 response
87
def get_llama2_response(prompt, turn_count):
    """Generate the assistant's next reply with the Llama-2 chat model.

    Args:
        prompt: Fully formatted Llama-2 chat prompt whose final user turn is
            closed by ``[/INST] ``.
        turn_count: Number of user turns so far. From the fourth turn on, the
            model is asked to summarize instead of asking further questions.

    Returns:
        The newly generated reply text with surrounding whitespace stripped.
    """
    if turn_count >= 4:
        # Insert the summarization request into the *last* user instruction
        # only. A blanket str.replace would rewrite every historical turn and
        # would place the request after the closing tag, i.e. in the
        # assistant's position.
        head, tag, tail = prompt.rpartition("[/INST] ")
        if tag:
            prompt = (
                f"{head}Now summarize what you've learned and suggest when "
                f"professional care may be needed. {tag}{tail}"
            )

    # The prompt already spells out <s>/</s> markers, so keep the tokenizer
    # from prepending an additional BOS token.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the tokens produced after the prompt. Splitting the full
    # decoded text on "[/INST]" breaks if the model itself emits that tag.
    prompt_length = inputs.input_ids.shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    ).strip()

    return response
110
+
111
+ # Function to get Meditron suggestions
112
def get_meditron_suggestions(patient_info):
    """Ask the Meditron model for medicine and home-remedy suggestions.

    Args:
        patient_info: Text describing the patient; substituted into
            ``MEDITRON_PROMPT``.

    Returns:
        The decoded text generated after the prompt, special tokens removed.
    """
    filled_prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    encoded = meditron_tokenizer(filled_prompt, return_tensors="pt").to(
        meditron_model.device
    )

    sampling_options = {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "pad_token_id": meditron_tokenizer.pad_token_id,
    }

    with torch.no_grad():
        generated = meditron_model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            **sampling_options,
        )

    # Skip the prompt tokens so only the model's continuation is decoded.
    prompt_token_count = encoded.input_ids.shape[1]
    continuation = generated[0][prompt_token_count:]
    return meditron_tokenizer.decode(continuation, skip_special_tokens=True)
 
 
 
 
 
 
 
130
 
131
+ # Define LangGraph nodes
132
def process_user_input(state: ChatbotState) -> ChatbotState:
    """Record the newest user message and advance the turn counter.

    Args:
        state: Current graph state; its last message is read via ``.content``.

    Returns:
        A partial state update with the extended ``patient_info`` list and
        ``turn_count`` incremented by one. The input list is not mutated.
    """
    latest_text = state["messages"][-1].content
    collected = state["patient_info"] + [latest_text]
    next_turn = state["turn_count"] + 1
    return {"patient_info": collected, "turn_count": next_turn}
142
+
143
def generate_llama_response(state: ChatbotState) -> ChatbotState:
    """Produce the assistant's reply for the current conversation state.

    Builds the Llama-2 prompt from ``state["messages"]`` and generates a
    reply, passing along ``state["turn_count"]``.

    Returns:
        A partial state update holding one assistant message.
    """
    reply_text = get_llama2_response(
        build_llama2_prompt(state["messages"]),
        state["turn_count"],
    )
    assistant_message = {"role": "assistant", "content": reply_text}
    return {"messages": [assistant_message]}
 
 
 
149
 
150
def check_turn_count(state: ChatbotState) -> str:
    """Pick the routing key that follows response generation.

    Returns:
        ``"add_suggestions"`` once ``turn_count`` has reached 4, otherwise
        ``"continue"``.
    """
    enough_turns = state["turn_count"] >= 4
    return "add_suggestions" if enough_turns else "continue"
155
+
156
def add_medicine_suggestions(state: ChatbotState) -> ChatbotState:
    """Extend the latest assistant reply with Meditron's suggestions.

    The text sent to Meditron is the collected ``patient_info`` entries, one
    per line, followed by the latest reply labelled as the summary.

    Returns:
        A partial state update holding one assistant message: the latest
        reply followed by a suggestions section.
    """
    summary_text = state["messages"][-1].content

    # Everything the patient reported, then the assistant's summary.
    reported = "\n".join(state["patient_info"])
    meditron_input = "".join([reported, "\n\nSummary: ", summary_text])

    remedies = get_meditron_suggestions(meditron_input)

    combined = (
        f"{summary_text}\n\n"
        f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
        f"{remedies}"
    )
    return {"messages": [{"role": "assistant", "content": combined}]}
176
 
177
+ # Build the LangGraph
178
def build_graph():
    """Build and compile the LangGraph workflow for the chatbot.

    Flow: START -> process_input -> generate_response, then either
    add_suggestions -> END or straight to END, as decided by
    ``check_turn_count``.

    Returns:
        The compiled graph.
    """
    # END is used below, but the module-level import only brings in
    # StateGraph and START; without this import the module fails with
    # NameError as soon as the graph is built.
    from langgraph.graph import END

    graph = StateGraph(ChatbotState)

    # Nodes
    graph.add_node("process_input", process_user_input)
    graph.add_node("generate_response", generate_llama_response)
    graph.add_node("add_suggestions", add_medicine_suggestions)

    # Edges
    graph.add_edge(START, "process_input")
    graph.add_edge("process_input", "generate_response")
    graph.add_conditional_edges(
        "generate_response",
        check_turn_count,
        {
            "add_suggestions": "add_suggestions",
            "continue": END,
        },
    )
    graph.add_edge("add_suggestions", END)

    return graph.compile()
201
 
202
+ # Initialize the graph
203
+ chatbot_graph = build_graph()
204
 
205
+ # Function for Gradio interface
206
def _history_to_messages(history):
    """Flatten Gradio chat history into a list of role/content dicts."""
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            # "messages"-format history: one dict per message already.
            messages.append({"role": entry["role"], "content": entry["content"]})
        else:
            # Pair-format history: (user_message, bot_message).
            user_msg, bot_msg = entry
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


def chat_response(message, history):
    """Generate the chatbot reply for the Gradio chat interface.

    Args:
        message: The user's new message text.
        history: Previous exchanges supplied by Gradio, either as
            ``(user, bot)`` pairs or as role/content dicts. May be empty.

    Returns:
        The content of the last message in the graph's resulting state.
    """
    messages = _history_to_messages(history)

    # Earlier user messages seed both the turn counter and the patient info;
    # the graph's first node accounts for the current message itself.
    prior_user_texts = [m["content"] for m in messages if m["role"] == "user"]

    messages.append({"role": "user", "content": message})

    state = {
        "messages": messages,
        "turn_count": len(prior_user_texts),
        "patient_info": prior_user_texts,
    }

    # Process through LangGraph
    result = chatbot_graph.invoke(state)

    # Return the latest assistant message
    return result["messages"][-1].content
242
 
243
+ # Create the Gradio interface
244
  demo = gr.ChatInterface(
245
+ fn=chat_response,
246
+ title="Medical Assistant with LangGraph",
247
+ description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies.",
248
  examples=[
249
+ "I have a cough and my throat hurts",
250
+ "I've been having headaches for a week",
251
+ "My stomach has been hurting since yesterday"
 
252
  ],
253
+ theme="soft"
 
 
 
254
  )
255
 
256
  if __name__ == "__main__":
257
+ demo.launch()