Thanush committed on
Commit 031a3f5 · 1 Parent(s): 1bcbb86

Implement medical consultation app with LangChain memory management and model integration

app.py CHANGED
@@ -1,380 +1,5 @@
- import gradio as gr
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import spaces
- from langchain.memory import ConversationBufferWindowMemory
- from langchain.schema import HumanMessage, AIMessage
- import json
- from datetime import datetime
-
- # Model configuration - Using correct Me-LLaMA model identifier
- ME_LLAMA_MODEL = "clinicalnlplab/me-llama-13b"
-
- # System prompts for different phases
- CONSULTATION_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
- Ask 1-2 follow-up questions at a time to gather more details about:
- - Detailed description of symptoms
- - Duration (when did it start?)
- - Severity (scale of 1-10)
- - Aggravating or alleviating factors
- - Related symptoms
- - Medical history
- - Current medications and allergies
- After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.
- Respond empathetically and clearly. Always be professional and thorough."""
-
- MEDICINE_PROMPT = """You are a specialized medical assistant. Based on the patient information gathered, provide:
- 1. One specific over-the-counter medicine with proper adult dosing instructions
- 2. One practical home remedy that might help
- 3. Clear guidance on when to seek professional medical care
-
- Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
-
- Patient information: {patient_info}
- Previous conversation context: {memory_context}"""
-
- # Global variables
- me_llama_model = None
- me_llama_tokenizer = None
- conversation_turns = 0
- patient_data = []
-
- # LangChain Memory Configuration
- class MedicalMemoryManager:
-     def __init__(self, k=10):  # Keep last 10 conversation turns
-         self.conversation_memory = ConversationBufferWindowMemory(k=k, return_messages=True)
-         self.patient_context = {
-             "symptoms": [],
-             "medical_history": [],
-             "medications": [],
-             "allergies": [],
-             "lifestyle_factors": [],
-             "timeline": [],
-             "severity_scores": {},
-             "session_start": datetime.now().isoformat()
-         }
-
-     def add_interaction(self, human_input, ai_response):
-         """Add human-AI interaction to memory"""
-         self.conversation_memory.chat_memory.add_user_message(human_input)
-         self.conversation_memory.chat_memory.add_ai_message(ai_response)
-
-         # Extract and categorize medical information
-         self._extract_medical_info(human_input)
-
-     def _extract_medical_info(self, user_input):
-         """Extract medical information from user input and categorize it"""
-         user_lower = user_input.lower()
-
-         # Extract symptoms (simple keyword matching)
-         symptom_keywords = ["pain", "ache", "hurt", "sore", "cough", "fever", "nausea",
-                             "headache", "dizzy", "tired", "fatigue", "vomit", "swollen",
-                             "rash", "itch", "burn", "cramp", "bleed", "shortness of breath"]
-
-         for keyword in symptom_keywords:
-             if keyword in user_lower and keyword not in [s.lower() for s in self.patient_context["symptoms"]]:
-                 self.patient_context["symptoms"].append(user_input)
-                 break
-
-         # Extract timeline information
-         time_keywords = ["days", "weeks", "months", "hours", "yesterday", "today", "started", "began"]
-         if any(keyword in user_lower for keyword in time_keywords):
-             self.patient_context["timeline"].append(user_input)
-
-         # Extract severity (look for numbers 1-10)
-         import re
-         severity_match = re.search(r'\b([1-9]|10)\b.*(?:pain|severity|scale)', user_lower)
-         if severity_match:
-             self.patient_context["severity_scores"][datetime.now().isoformat()] = severity_match.group(1)
-
-         # Extract medications
-         med_keywords = ["taking", "medication", "medicine", "pills", "prescribed", "drug"]
-         if any(keyword in user_lower for keyword in med_keywords):
-             self.patient_context["medications"].append(user_input)
-
-         # Extract allergies
-         allergy_keywords = ["allergic", "allergy", "allergies", "reaction"]
-         if any(keyword in user_lower for keyword in allergy_keywords):
-             self.patient_context["allergies"].append(user_input)
-
-     def get_memory_context(self):
-         """Get formatted memory context for the model"""
-         messages = self.conversation_memory.chat_memory.messages
-         context = []
-
-         for msg in messages[-6:]:  # Last 6 messages (3 exchanges)
-             if isinstance(msg, HumanMessage):
-                 context.append(f"Patient: {msg.content}")
-             elif isinstance(msg, AIMessage):
-                 context.append(f"Doctor: {msg.content}")
-
-         return "\n".join(context)
-
-     def get_patient_summary(self):
-         """Get structured patient information summary"""
-         summary = {
-             "conversation_turns": len(self.conversation_memory.chat_memory.messages) // 2,
-             "session_duration": datetime.now().isoformat(),
-             "key_symptoms": self.patient_context["symptoms"][-3:],  # Last 3 symptoms mentioned
-             "timeline_info": self.patient_context["timeline"][-2:],  # Last 2 timeline mentions
-             "medications": self.patient_context["medications"],
-             "allergies": self.patient_context["allergies"],
-             "severity_scores": self.patient_context["severity_scores"]
-         }
-         return json.dumps(summary, indent=2)
-
-     def reset_session(self):
-         """Reset memory for new consultation"""
-         self.conversation_memory.clear()
-         self.patient_context = {
-             "symptoms": [],
-             "medical_history": [],
-             "medications": [],
-             "allergies": [],
-             "lifestyle_factors": [],
-             "timeline": [],
-             "severity_scores": {},
-             "session_start": datetime.now().isoformat()
-         }
-
- # Initialize memory manager
- memory_manager = MedicalMemoryManager()
-
- def build_me_llama_prompt(system_prompt, history, user_input):
-     """Format the conversation for Me-LLaMA chat model with memory context."""
-     # Get memory context from LangChain
-     memory_context = memory_manager.get_memory_context()
-
-     # Enhance system prompt with memory context
-     enhanced_system_prompt = f"{system_prompt}\n\nPrevious conversation context:\n{memory_context}"
-
-     # Use standard Llama-2 chat format since Me-LLaMA is based on Llama-2
-     prompt = f"<s>[INST] <<SYS>>\n{enhanced_system_prompt}\n<</SYS>>\n\n"
-
-     # Add only recent history to avoid token limit issues
-     recent_history = history[-3:] if len(history) > 3 else history
-     for user_msg, assistant_msg in recent_history:
-         prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
-
-     # Add the current user input
-     prompt += f"{user_input} [/INST] "
-
-     return prompt
-
- @spaces.GPU
- def load_model_if_needed():
-     """Load Me-LLaMA model only when GPU is available."""
-     global me_llama_model, me_llama_tokenizer
-
-     if me_llama_model is None:
-         print("Loading Me-LLaMA 13B model...")
-         try:
-             me_llama_tokenizer = AutoTokenizer.from_pretrained(
-                 ME_LLAMA_MODEL,
-                 trust_remote_code=True
-             )
-             me_llama_model = AutoModelForCausalLM.from_pretrained(
-                 ME_LLAMA_MODEL,
-                 torch_dtype=torch.float16,
-                 device_map="auto",
-                 trust_remote_code=True
-             )
-             print("Me-LLaMA 13B model loaded successfully!")
-         except Exception as e:
-             print(f"Error loading model: {e}")
-             # Fallback to a working medical model
-             print("Falling back to Llama-2-7b-chat-hf...")
-             me_llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-             me_llama_model = AutoModelForCausalLM.from_pretrained(
-                 "meta-llama/Llama-2-7b-chat-hf",
-                 torch_dtype=torch.float16,
-                 device_map="auto"
-             )
-             print("Fallback model loaded successfully!")
-
- @spaces.GPU
- def generate_medicine_suggestions(patient_info, memory_context):
-     """Use Me-LLaMA to generate medicine and remedy suggestions with memory context."""
-     load_model_if_needed()
-
-     # Create a prompt with both patient info and memory context
-     prompt = f"<s>[INST] {MEDICINE_PROMPT.format(patient_info=patient_info, memory_context=memory_context)} [/INST] "
-
-     inputs = me_llama_tokenizer(prompt, return_tensors="pt")
-
-     # Move inputs to the same device as the model
-     if torch.cuda.is_available():
-         inputs = {k: v.to(me_llama_model.device) for k, v in inputs.items()}
-
-     with torch.no_grad():
-         outputs = me_llama_model.generate(
-             inputs["input_ids"],
-             attention_mask=inputs["attention_mask"],
-             max_new_tokens=300,
-             temperature=0.7,
-             top_p=0.9,
-             do_sample=True,
-             pad_token_id=me_llama_tokenizer.eos_token_id
-         )
-
-     suggestion = me_llama_tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
-     return suggestion
-
- @spaces.GPU
- def generate_response(message, history):
-     """Generate response using Me-LLaMA with LangChain memory management."""
-     global conversation_turns, patient_data
-
-     try:
-         # Load model if needed
-         load_model_if_needed()
-
-         # Track conversation turns
-         conversation_turns += 1
-
-         # Store patient data (legacy support)
-         patient_data.append(message)
-
-         # Phase 1-3: Information gathering with memory
-         if conversation_turns < 4:
-             # Build consultation prompt with memory context
-             prompt = build_me_llama_prompt(CONSULTATION_PROMPT, history, message)
-
-             inputs = me_llama_tokenizer(prompt, return_tensors="pt")
-
-             # Move inputs to the same device as the model
-             if torch.cuda.is_available():
-                 inputs = {k: v.to(me_llama_model.device) for k, v in inputs.items()}
-
-             # Generate consultation response
-             with torch.no_grad():
-                 outputs = me_llama_model.generate(
-                     inputs["input_ids"],
-                     attention_mask=inputs["attention_mask"],
-                     max_new_tokens=400,
-                     temperature=0.7,
-                     top_p=0.9,
-                     do_sample=True,
-                     pad_token_id=me_llama_tokenizer.eos_token_id
-                 )
-
-             # Decode response
-             full_response = me_llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
-             response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
-
-             # Add interaction to memory
-             memory_manager.add_interaction(message, response)
-
-             return response
-
-         # Phase 4+: Summary and medicine suggestions with full memory context
-         else:
-             # Get comprehensive patient summary from memory
-             patient_summary = memory_manager.get_patient_summary()
-             memory_context = memory_manager.get_memory_context()
-
-             # First, get summary from consultation with memory context
-             summary_prompt = build_me_llama_prompt(
-                 CONSULTATION_PROMPT + "\n\nNow provide a comprehensive summary based on all the information gathered. Include when professional care may be needed.",
-                 history,
-                 message
-             )
-
-             inputs = me_llama_tokenizer(summary_prompt, return_tensors="pt")
-
-             if torch.cuda.is_available():
-                 inputs = {k: v.to(me_llama_model.device) for k, v in inputs.items()}
-
-             # Generate summary
-             with torch.no_grad():
-                 outputs = me_llama_model.generate(
-                     inputs["input_ids"],
-                     attention_mask=inputs["attention_mask"],
-                     max_new_tokens=400,
-                     temperature=0.7,
-                     top_p=0.9,
-                     do_sample=True,
-                     pad_token_id=me_llama_tokenizer.eos_token_id
-                 )
-
-             summary_response = me_llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
-             summary = summary_response.split('[/INST]')[-1].split('</s>')[0].strip()
-
-             # Get medicine suggestions using memory context
-             full_patient_info = f"Patient Summary: {patient_summary}\n\nDetailed Summary: {summary}"
-             medicine_suggestions = generate_medicine_suggestions(full_patient_info, memory_context)
-
-             # Combine both responses
-             final_response = (
-                 f"**COMPREHENSIVE MEDICAL SUMMARY:**\n{summary}\n\n"
-                 f"**MEDICATION AND HOME CARE SUGGESTIONS:**\n{medicine_suggestions}\n\n"
-                 f"**PATIENT CONTEXT SUMMARY:**\n{patient_summary}\n\n"
-                 f"**DISCLAIMER:** This is AI-generated advice for informational purposes only. Please consult a licensed healthcare provider for proper medical diagnosis and treatment."
-             )
-
-             # Add final interaction to memory
-             memory_manager.add_interaction(message, final_response)
-
-             return final_response
-
-     except Exception as e:
-         error_msg = f"I apologize, but I'm experiencing technical difficulties. Please try again. Error: {str(e)}"
-         # Still try to add to memory even on error
-         try:
-             memory_manager.add_interaction(message, error_msg)
-         except:
-             pass
-         return error_msg
-
- def reset_consultation():
-     """Reset the consultation and memory for a new patient."""
-     global conversation_turns, patient_data, memory_manager
-
-     conversation_turns = 0
-     patient_data = []
-     memory_manager.reset_session()
-
-     return "New consultation started. Please tell me about your symptoms or health concerns."
-
- # Create the Gradio interface with memory reset option
- with gr.Blocks(theme="soft") as demo:
-     gr.Markdown("# 🏥 Complete Medical Assistant - Me-LLaMA 13B with Memory")
-     gr.Markdown("Comprehensive medical consultation powered by Me-LLaMA 13B with LangChain memory management. One model handles both consultation and medicine suggestions with full context awareness.")
-
-     with gr.Row():
-         with gr.Column(scale=4):
-             chatbot = gr.Chatbot(height=500)
-             msg = gr.Textbox(
-                 placeholder="Tell me about your symptoms or health concerns...",
-                 label="Your Message"
-             )
-
-         with gr.Column(scale=1):
-             reset_btn = gr.Button("🔄 Start New Consultation", variant="secondary")
-             gr.Markdown("**Memory Features:**\n- Tracks symptoms & timeline\n- Remembers medications & allergies\n- Maintains conversation context\n- Provides comprehensive summaries")
-
-     # Examples
-     gr.Examples(
-         examples=[
-             "I have a persistent cough and sore throat for 3 days",
-             "I've been having severe headaches and feel dizzy",
-             "My stomach hurts and I feel nauseous after eating"
-         ],
-         inputs=msg
-     )
-
-     # Event handlers
-     def respond(message, chat_history):
-         bot_message = generate_response(message, chat_history)
-         chat_history.append((message, bot_message))
-         return "", chat_history
-
-     def reset_chat():
-         reset_msg = reset_consultation()
-         return [(None, reset_msg)], ""
-
-     msg.submit(respond, [msg, chatbot], [msg, chatbot])
-     reset_btn.click(reset_chat, [], [chatbot, msg])
+ from medbot.interface import build_interface

  if __name__ == "__main__":
+     demo = build_interface()
      demo.launch()
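With this change `app.py` shrinks to a thin entry point; everything else moves into the new `medbot` package below. A minimal local smoke test, assuming the `medbot/` directory sits next to the script and Gradio is installed (the port is just Gradio's default, not something this commit sets):

```python
# Hypothetical local run of the slimmed-down entry point.
from medbot.interface import build_interface

demo = build_interface()
demo.launch(server_port=7860, share=False)  # 7860 is Gradio's default port
```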
medbot/__init__.py ADDED
@@ -0,0 +1 @@
+
medbot/config.py ADDED
@@ -0,0 +1,2 @@
+ ME_LLAMA_MODEL = "clinicalnlplab/me-llama-13b"
+ FALLBACK_MODEL = "meta-llama/Llama-2-7b-chat-hf"
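Both constants are consumed in `medbot/model.py` via `from .config import ME_LLAMA_MODEL, FALLBACK_MODEL`, which binds them at import time. A sketch of overriding the fallback checkpoint, assuming the override runs before `medbot.model` is first imported (the substitute model id is purely illustrative):

```python
import medbot.config as config

# Hypothetical override; must happen before `medbot.model` is imported,
# because model.py copies the names with a from-import.
config.FALLBACK_MODEL = "example-org/llama-2-7b-alternative"  # illustrative id

from medbot.model import ModelManager  # now sees the overridden value
```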
medbot/handlers.py ADDED
@@ -0,0 +1,56 @@
+ from .model import ModelManager
+ from .memory import MedicalMemoryManager
+ from .prompts import CONSULTATION_PROMPT, MEDICINE_PROMPT
+
+ model_manager = ModelManager()
+ memory_manager = MedicalMemoryManager()
+ conversation_turns = 0
+
+
+ def build_me_llama_prompt(system_prompt, history, user_input):
+     memory_context = memory_manager.get_memory_context()
+     enhanced_system_prompt = f"{system_prompt}\n\nPrevious conversation context:\n{memory_context}"
+     prompt = f"<s>[INST] <<SYS>>\n{enhanced_system_prompt}\n<</SYS>>\n\n"
+     recent_history = history[-3:] if len(history) > 3 else history
+     for user_msg, assistant_msg in recent_history:
+         prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
+     prompt += f"{user_input} [/INST] "
+     return prompt
+
+ def respond(message, chat_history):
+     global conversation_turns
+     conversation_turns += 1
+     if conversation_turns < 4:
+         prompt = build_me_llama_prompt(CONSULTATION_PROMPT, chat_history, message)
+         response = model_manager.generate(prompt)
+         memory_manager.add_interaction(message, response)
+         chat_history.append((message, response))
+         return "", chat_history
+     else:
+         patient_summary = memory_manager.get_patient_summary()
+         memory_context = memory_manager.get_memory_context()
+         summary_prompt = build_me_llama_prompt(
+             CONSULTATION_PROMPT + "\n\nNow provide a comprehensive summary based on all the information gathered. Include when professional care may be needed.",
+             chat_history,
+             message
+         )
+         summary = model_manager.generate(summary_prompt)
+         full_patient_info = f"Patient Summary: {patient_summary}\n\nDetailed Summary: {summary}"
+         med_prompt = f"<s>[INST] {MEDICINE_PROMPT.format(patient_info=full_patient_info, memory_context=memory_context)} [/INST] "
+         medicine_suggestions = model_manager.generate(med_prompt, max_new_tokens=300)
+         final_response = (
+             f"**COMPREHENSIVE MEDICAL SUMMARY:**\n{summary}\n\n"
+             f"**MEDICATION AND HOME CARE SUGGESTIONS:**\n{medicine_suggestions}\n\n"
+             f"**PATIENT CONTEXT SUMMARY:**\n{patient_summary}\n\n"
+             f"**DISCLAIMER:** This is AI-generated advice for informational purposes only. Please consult a licensed healthcare provider for proper medical diagnosis and treatment."
+         )
+         memory_manager.add_interaction(message, final_response)
+         chat_history.append((message, final_response))
+         return "", chat_history
+
+ def reset_chat():
+     global conversation_turns
+     conversation_turns = 0
+     memory_manager.reset_session()
+     reset_msg = "New consultation started. Please tell me about your symptoms or health concerns."
+     return [(None, reset_msg)], ""
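One consequence of this layout: `conversation_turns`, `model_manager`, and `memory_manager` are module-level singletons, so all concurrent users of the Space share the same counter and memory. The prompt construction, at least, can be checked in isolation without loading weights; a sketch, assuming `langchain` is installed (importing the module instantiates `MedicalMemoryManager`):

```python
from medbot.handlers import build_me_llama_prompt
from medbot.prompts import CONSULTATION_PROMPT

# Inspect the Llama-2 chat prompt without touching the model.
history = [("I have a cough", "How long has it lasted?")]
prompt = build_me_llama_prompt(CONSULTATION_PROMPT, history, "About three days")
print(prompt[:300])  # starts with "<s>[INST] <<SYS>>" and inlines the memory context
```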
medbot/interface.py ADDED
@@ -0,0 +1,28 @@
+ import gradio as gr
+ from .handlers import respond, reset_chat
+
+ def build_interface():
+     with gr.Blocks(theme="soft") as demo:
+         gr.Markdown("# 🏥 Complete Medical Assistant - Me-LLaMA 13B with Memory")
+         gr.Markdown("Comprehensive medical consultation powered by Me-LLaMA 13B with LangChain memory management. One model handles both consultation and medicine suggestions with full context awareness.")
+         with gr.Row():
+             with gr.Column(scale=4):
+                 chatbot = gr.Chatbot(height=500)
+                 msg = gr.Textbox(
+                     placeholder="Tell me about your symptoms or health concerns...",
+                     label="Your Message"
+                 )
+             with gr.Column(scale=1):
+                 reset_btn = gr.Button("🔄 Start New Consultation", variant="secondary")
+                 gr.Markdown("**Memory Features:**\n- Tracks symptoms & timeline\n- Remembers medications & allergies\n- Maintains conversation context\n- Provides comprehensive summaries")
+         gr.Examples(
+             examples=[
+                 "I have a persistent cough and sore throat for 3 days",
+                 "I've been having severe headaches and feel dizzy",
+                 "My stomach hurts and I feel nauseous after eating"
+             ],
+             inputs=msg
+         )
+         msg.submit(respond, [msg, chatbot], [msg, chatbot])
+         reset_btn.click(reset_chat, [], [chatbot, msg])
+     return demo
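The wiring `msg.submit(respond, [msg, chatbot], [msg, chatbot])` fixes the handler contract: `respond` receives the textbox value and chat history, and must return replacements for both, in that order — which is why the handlers return `"", chat_history`. A hypothetical stand-in with the same contract, useful for testing the UI without loading any model:

```python
import gradio as gr

# Stub with the same (inputs, outputs) contract as medbot.handlers.respond.
def echo_respond(message, chat_history):
    chat_history.append((message, f"echo: {message}"))
    return "", chat_history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    msg.submit(echo_respond, [msg, chatbot], [msg, chatbot])
```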
medbot/memory.py ADDED
@@ -0,0 +1,79 @@
+ from langchain.memory import ConversationBufferWindowMemory
+ from langchain.schema import HumanMessage, AIMessage
+ from datetime import datetime
+ import json
+ import re
+
+ class MedicalMemoryManager:
+     def __init__(self, k=10):
+         self.conversation_memory = ConversationBufferWindowMemory(k=k, return_messages=True)
+         self.patient_context = {
+             "symptoms": [],
+             "medical_history": [],
+             "medications": [],
+             "allergies": [],
+             "lifestyle_factors": [],
+             "timeline": [],
+             "severity_scores": {},
+             "session_start": datetime.now().isoformat()
+         }
+
+     def add_interaction(self, human_input, ai_response):
+         self.conversation_memory.chat_memory.add_user_message(human_input)
+         self.conversation_memory.chat_memory.add_ai_message(ai_response)
+         self._extract_medical_info(human_input)
+
+     def _extract_medical_info(self, user_input):
+         user_lower = user_input.lower()
+         symptom_keywords = ["pain", "ache", "hurt", "sore", "cough", "fever", "nausea", "headache", "dizzy", "tired", "fatigue", "vomit", "swollen", "rash", "itch", "burn", "cramp", "bleed", "shortness of breath"]
+         for keyword in symptom_keywords:
+             if keyword in user_lower and keyword not in [s.lower() for s in self.patient_context["symptoms"]]:
+                 self.patient_context["symptoms"].append(user_input)
+                 break
+         time_keywords = ["days", "weeks", "months", "hours", "yesterday", "today", "started", "began"]
+         if any(keyword in user_lower for keyword in time_keywords):
+             self.patient_context["timeline"].append(user_input)
+         severity_match = re.search(r'\b([1-9]|10)\b.*(?:pain|severity|scale)', user_lower)
+         if severity_match:
+             self.patient_context["severity_scores"][datetime.now().isoformat()] = severity_match.group(1)
+         med_keywords = ["taking", "medication", "medicine", "pills", "prescribed", "drug"]
+         if any(keyword in user_lower for keyword in med_keywords):
+             self.patient_context["medications"].append(user_input)
+         allergy_keywords = ["allergic", "allergy", "allergies", "reaction"]
+         if any(keyword in user_lower for keyword in allergy_keywords):
+             self.patient_context["allergies"].append(user_input)
+
+     def get_memory_context(self):
+         messages = self.conversation_memory.chat_memory.messages
+         context = []
+         for msg in messages[-6:]:
+             if isinstance(msg, HumanMessage):
+                 context.append(f"Patient: {msg.content}")
+             elif isinstance(msg, AIMessage):
+                 context.append(f"Doctor: {msg.content}")
+         return "\n".join(context)
+
+     def get_patient_summary(self):
+         summary = {
+             "conversation_turns": len(self.conversation_memory.chat_memory.messages) // 2,
+             "session_duration": datetime.now().isoformat(),
+             "key_symptoms": self.patient_context["symptoms"][-3:],
+             "timeline_info": self.patient_context["timeline"][-2:],
+             "medications": self.patient_context["medications"],
+             "allergies": self.patient_context["allergies"],
+             "severity_scores": self.patient_context["severity_scores"]
+         }
+         return json.dumps(summary, indent=2)
+
+     def reset_session(self):
+         self.conversation_memory.clear()
+         self.patient_context = {
+             "symptoms": [],
+             "medical_history": [],
+             "medications": [],
+             "allergies": [],
+             "lifestyle_factors": [],
+             "timeline": [],
+             "severity_scores": {},
+             "session_start": datetime.now().isoformat()
+         }
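The memory manager has no model dependency, so the keyword extraction can be sanity-checked on its own. A quick sketch, assuming a `langchain` version that still ships `ConversationBufferWindowMemory` and `langchain.schema`:

```python
from medbot.memory import MedicalMemoryManager

mm = MedicalMemoryManager(k=10)
mm.add_interaction(
    "I've had a headache for three days, about a 7 on the pain scale",
    "I'm sorry to hear that. Does anything make it better or worse?",
)

print(mm.get_memory_context())   # "Patient: ..." / "Doctor: ..." transcript
print(mm.get_patient_summary())  # JSON with key_symptoms, timeline_info, severity_scores, ...
```

Note that extraction stores the full utterance rather than the matched keyword, so `key_symptoms` ends up holding raw sentences.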
medbot/model.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from .config import ME_LLAMA_MODEL, FALLBACK_MODEL
+
+ class ModelManager:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+
+     def load(self):
+         if self.model is not None and self.tokenizer is not None:
+             return
+         try:
+             self.tokenizer = AutoTokenizer.from_pretrained(ME_LLAMA_MODEL, trust_remote_code=True)
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 ME_LLAMA_MODEL,
+                 torch_dtype=torch.float16,
+                 device_map="auto",
+                 trust_remote_code=True
+             )
+         except Exception as e:
+             print(f"Error loading model: {e}")
+             print("Falling back to Llama-2-7b-chat-hf...")
+             self.tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 FALLBACK_MODEL,
+                 torch_dtype=torch.float16,
+                 device_map="auto"
+             )
+
+     def generate(self, prompt, max_new_tokens=400, temperature=0.7, top_p=0.9):
+         self.load()
+         inputs = self.tokenizer(prompt, return_tensors="pt")
+         if torch.cuda.is_available():
+             inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 inputs["input_ids"],
+                 attention_mask=inputs["attention_mask"],
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=True,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+         return self.tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
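`ModelManager` is lazy: nothing is downloaded at construction, `generate()` calls `load()` on first use, and the returned text has the prompt tokens sliced off. A usage sketch, assuming a GPU with room for the 13B checkpoint (or access to the gated fallback weights):

```python
from medbot.model import ModelManager

mm = ModelManager()  # no download yet; load() runs on the first generate()
text = mm.generate(
    "<s>[INST] Say hello in one sentence. [/INST] ",
    max_new_tokens=32,
)
print(text)  # only the newly generated tokens; the prompt is excluded
```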
medbot/prompts.py ADDED
@@ -0,0 +1,21 @@
+ CONSULTATION_PROMPT = '''You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
+ Ask 1-2 follow-up questions at a time to gather more details about:
+ - Detailed description of symptoms
+ - Duration (when did it start?)
+ - Severity (scale of 1-10)
+ - Aggravating or alleviating factors
+ - Related symptoms
+ - Medical history
+ - Current medications and allergies
+ After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.
+ Respond empathetically and clearly. Always be professional and thorough.'''
+
+ MEDICINE_PROMPT = '''You are a specialized medical assistant. Based on the patient information gathered, provide:
+ 1. One specific over-the-counter medicine with proper adult dosing instructions
+ 2. One practical home remedy that might help
+ 3. Clear guidance on when to seek professional medical care
+
+ Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
+
+ Patient information: {patient_info}
+ Previous conversation context: {memory_context}'''
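`MEDICINE_PROMPT` carries two `str.format` placeholders that `medbot/handlers.py` fills; the prompt body contains no other braces, so plain `.format` is safe. A sketch with illustrative values:

```python
from medbot.prompts import MEDICINE_PROMPT

filled = MEDICINE_PROMPT.format(
    patient_info="3-day dry cough, severity 4/10, no known allergies",  # illustrative
    memory_context="Patient: I have a cough\nDoctor: How long has it lasted?",
)
print(filled)
```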
medbot/utils.py ADDED
@@ -0,0 +1,5 @@
+ # Utility functions for medbot
+
+ def extract_symptoms(text):
+     # Placeholder for advanced symptom extraction logic
+     return []
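`extract_symptoms` is an empty placeholder in this commit. One possible direction, reusing the keyword list already hard-coded in `medbot/memory.py`, might look like the following sketch (not part of the commit):

```python
# Hypothetical implementation of the utils.py placeholder, mirroring
# MedicalMemoryManager._extract_medical_info's keyword matching.
SYMPTOM_KEYWORDS = [
    "pain", "ache", "hurt", "sore", "cough", "fever", "nausea",
    "headache", "dizzy", "tired", "fatigue", "vomit", "swollen",
    "rash", "itch", "burn", "cramp", "bleed", "shortness of breath",
]

def extract_symptoms(text):
    """Return the symptom keywords mentioned in `text`, case-insensitively."""
    lowered = text.lower()
    return [kw for kw in SYMPTOM_KEYWORDS if kw in lowered]
```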