techindia2025 committed on
Commit 71bcd31 · 1 Parent(s): 9ae5fda

update meditron

Files changed (2)
  1. app.py +237 -512
  2. requirements.txt +14 -41
app.py CHANGED
@@ -1,530 +1,255 @@
  import gradio as gr
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from langgraph.graph import StateGraph, END
- from typing import TypedDict, List, Dict
- from datetime import datetime
- import re
-
- # Enhanced State Management with RAG
- class MedicalState(TypedDict):
-     patient_id: str
-     conversation_history: List[Dict]
-     symptoms: Dict[str, any]
-     vital_questions_asked: List[str]
-     medical_history: Dict
-     current_medications: List[str]
-     allergies: List[str]
-     severity_scores: Dict[str, int]
-     red_flags: List[str]
-     assessment_complete: bool
-     suggested_actions: List[str]
-     consultation_stage: str
-     retrieved_knowledge: List[Dict]
-     confidence_scores: Dict[str, float]
-
- # Medical Knowledge Base for RAG
- MEDICAL_KNOWLEDGE_BASE = {
-     "conditions": {
-         "common_cold": {
-             "symptoms": ["runny nose", "cough", "sneezing", "sore throat", "mild fever"],
-             "treatment": "Rest, fluids, OTC pain relievers",
-             "otc_medications": [
-                 {"name": "Acetaminophen", "dose": "500-1000mg every 4-6 hours", "max_daily": "3000mg"},
-                 {"name": "Ibuprofen", "dose": "200-400mg every 4-6 hours", "max_daily": "1200mg"}
-             ],
-             "home_remedies": ["Warm salt water gargle", "Honey and lemon tea", "Steam inhalation"],
-             "when_to_seek_care": "If symptoms worsen after 7-10 days or fever above 101.3°F"
-         },
-         "headache": {
-             "symptoms": ["head pain", "pressure", "throbbing"],
-             "treatment": "Pain relief, rest, hydration",
-             "otc_medications": [
-                 {"name": "Acetaminophen", "dose": "500-1000mg every 4-6 hours", "max_daily": "3000mg"},
-                 {"name": "Ibuprofen", "dose": "400-600mg every 6-8 hours", "max_daily": "1200mg"}
-             ],
-             "home_remedies": ["Cold or warm compress", "Rest in dark room", "Stay hydrated"],
-             "when_to_seek_care": "Sudden severe headache, fever, neck stiffness, vision changes"
-         },
-         "stomach_pain": {
-             "symptoms": ["abdominal pain", "nausea", "bloating", "cramps"],
-             "treatment": "Bland diet, rest, hydration",
-             "otc_medications": [
-                 {"name": "Pepto-Bismol", "dose": "525mg every 30 minutes as needed", "max_daily": "8 doses"},
-                 {"name": "TUMS", "dose": "2-4 tablets as needed", "max_daily": "15 tablets"}
-             ],
-             "home_remedies": ["BRAT diet", "Ginger tea", "Warm compress on stomach"],
-             "when_to_seek_care": "Severe pain, fever, vomiting, blood in stool"
-         }
-     }
- }
-
- MEDICAL_CATEGORIES = {
-     "respiratory": ["cough", "shortness of breath", "chest pain", "wheezing", "runny nose", "sore throat"],
-     "gastrointestinal": ["nausea", "vomiting", "diarrhea", "stomach pain", "heartburn", "bloating"],
-     "neurological": ["headache", "dizziness", "numbness", "tingling"],
-     "musculoskeletal": ["joint pain", "muscle pain", "back pain", "stiffness"],
-     "cardiovascular": ["chest pain", "palpitations", "swelling", "fatigue"],
-     "dermatological": ["rash", "itching", "skin changes", "wounds"],
-     "mental_health": ["anxiety", "depression", "sleep issues", "stress"]
- }
-
- RED_FLAGS = [
-     "chest pain", "difficulty breathing", "severe headache", "high fever",
-     "blood in stool", "blood in urine", "severe abdominal pain",
-     "sudden vision changes", "loss of consciousness", "severe allergic reaction"
- ]
-
- class SimpleRAGSystem:
-     def __init__(self):
-         self.knowledge_base = MEDICAL_KNOWLEDGE_BASE
-         self.setup_simple_retrieval()
-
-     def setup_simple_retrieval(self):
-         """Setup simple keyword-based retrieval system"""
-         self.symptom_to_condition = {}
-
-         for condition, data in self.knowledge_base["conditions"].items():
-             for symptom in data["symptoms"]:
-                 if symptom not in self.symptom_to_condition:
-                     self.symptom_to_condition[symptom] = []
-                 self.symptom_to_condition[symptom].append(condition)
-
-     def retrieve_relevant_conditions(self, symptoms: List[str]) -> List[Dict]:
-         """Retrieve relevant medical conditions based on symptoms"""
-         relevant_conditions = {}
-
-         for symptom in symptoms:
-             symptom_lower = symptom.lower()
-
-             # Direct match
-             if symptom_lower in self.symptom_to_condition:
-                 for condition in self.symptom_to_condition[symptom_lower]:
-                     if condition not in relevant_conditions:
-                         relevant_conditions[condition] = self.knowledge_base["conditions"][condition]
-
-             # Partial match
-             for kb_symptom, conditions in self.symptom_to_condition.items():
-                 if symptom_lower in kb_symptom or kb_symptom in symptom_lower:
-                     for condition in conditions:
-                         if condition not in relevant_conditions:
-                             relevant_conditions[condition] = self.knowledge_base["conditions"][condition]
-
-         return [{"condition": k, "data": v} for k, v in relevant_conditions.items()]
-
- class EnhancedMedicalAssistant:
-     def __init__(self):
-         self.load_models()
-         self.rag_system = SimpleRAGSystem()
-         self.setup_langgraph()
-         self.conversation_count = {}
-
-     def load_models(self):
-         """Load the AI models with fallback options"""
-         print("Loading models...")
-         try:
-             # Llama-2 for conversation
-             self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-             if self.tokenizer.pad_token is None:
-                 self.tokenizer.pad_token = self.tokenizer.eos_token
-
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 "meta-llama/Llama-2-7b-chat-hf",
-                 torch_dtype=torch.float16,
-                 device_map="auto"
-             )
-
-             # Meditron for medical suggestions
-             self.meditron_tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b")
-             if self.meditron_tokenizer.pad_token is None:
-                 self.meditron_tokenizer.pad_token = self.meditron_tokenizer.eos_token
-
-             self.meditron_model = AutoModelForCausalLM.from_pretrained(
-                 "epfl-llm/meditron-7b",
-                 torch_dtype=torch.float16,
-                 device_map="auto"
-             )
-             print("Models loaded successfully!")
-
-         except Exception as e:
-             print(f"Error loading models: {e}")
-             # Fallback - use smaller models
-             self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
-             self.model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
-             self.meditron_tokenizer = self.tokenizer
-             self.meditron_model = self.model
-
-     def setup_langgraph(self):
-         """Setup LangGraph workflow"""
-         workflow = StateGraph(MedicalState)
-
-         workflow.add_node("intake", self.patient_intake)
-         workflow.add_node("generate_recommendations", self.generate_recommendations)
-         workflow.add_node("emergency_triage", self.emergency_triage)
-
-         workflow.set_entry_point("intake")
-         workflow.add_conditional_edges(
-             "intake",
-             self.route_after_intake,
-             {
-                 "emergency": "emergency_triage",
-                 "recommendations": "generate_recommendations"
-             }
          )
-         workflow.add_edge("generate_recommendations", END)
-         workflow.add_edge("emergency_triage", END)
-
-         self.workflow = workflow.compile()
-
-     def patient_intake(self, state: MedicalState) -> MedicalState:
-         """Enhanced patient intake with RAG"""
-         last_message = state["conversation_history"][-1]["content"] if state["conversation_history"] else ""
-
-         # Extract symptoms
-         detected_symptoms = self.extract_symptoms(last_message)
-         state["symptoms"].update(detected_symptoms)
-
-         # Use RAG to get relevant medical knowledge
-         if detected_symptoms:
-             symptom_names = list(detected_symptoms.keys())
-             relevant_conditions = self.rag_system.retrieve_relevant_conditions(symptom_names)
-             state["retrieved_knowledge"] = relevant_conditions
-
-         # Check for red flags
-         red_flags = self.check_red_flags(last_message)
-         state["red_flags"].extend(red_flags)
-
-         # Determine consultation stage
-         if red_flags:
-             state["consultation_stage"] = "emergency"
-         else:
-             state["consultation_stage"] = "recommendations"
-
-         return state
-
-     def generate_recommendations(self, state: MedicalState) -> MedicalState:
-         """Generate RAG-enhanced recommendations"""
-         # Create structured recommendations from RAG knowledge
-         recommendations = self.create_structured_recommendations(state)
-         state["suggested_actions"] = recommendations
-         return state
-
-     def create_structured_recommendations(self, state: MedicalState) -> List[str]:
-         """Create structured recommendations using RAG knowledge"""
-         recommendations = []
-
-         if not state["retrieved_knowledge"]:
-             recommendations.append("I need more specific information about your symptoms to provide targeted recommendations.")
-             return recommendations
-
-         # Process each relevant condition
-         for knowledge_item in state["retrieved_knowledge"][:2]:  # Limit to top 2 conditions
-             condition = knowledge_item["condition"]
-             data = knowledge_item["data"]
-
-             # Format condition information
-             condition_info = f"\n**Possible Condition: {condition.replace('_', ' ').title()}**\n"
-
-             # Add medications
-             if "otc_medications" in data:
-                 condition_info += "\n**💊 Over-the-Counter Medications:**\n"
-                 for med in data["otc_medications"]:
-                     condition_info += f"• **{med['name']}**: {med['dose']} (Max daily: {med['max_daily']})\n"
-
-             # Add home remedies
-             if "home_remedies" in data:
-                 condition_info += "\n**🏠 Home Remedies:**\n"
-                 for remedy in data["home_remedies"]:
-                     condition_info += f"• {remedy}\n"
-
-             # Add when to seek care
-             if "when_to_seek_care" in data:
-                 condition_info += f"\n**⚠️ Seek Medical Care If:** {data['when_to_seek_care']}\n"
-
-             recommendations.append(condition_info)
-
-         # Add general advice
-         recommendations.append("""
- **📋 General Recommendations:**
- • Monitor your symptoms for any changes
- • Stay hydrated and get adequate rest
- • Follow medication instructions carefully
- • Don't exceed recommended dosages
-
- **🚨 Emergency Warning Signs:**
- • Severe worsening of symptoms
- • High fever (>101.3°F/38.5°C)
- • Difficulty breathing
- • Severe pain
- • Signs of dehydration
- """)
-
-         return recommendations
-
-     def emergency_triage(self, state: MedicalState) -> MedicalState:
-         """Handle emergency situations"""
-         emergency_response = f"""
- 🚨 **URGENT MEDICAL ATTENTION NEEDED** 🚨
-
- Based on your symptoms, I strongly recommend seeking immediate medical care because you mentioned: {', '.join(state['red_flags'])}
-
- **Immediate Actions:**
- • Go to the nearest emergency room, OR
- • Call emergency services (911), OR
- • Contact your doctor immediately
-
- **Why This is Urgent:**
- These symptoms can indicate serious conditions that require professional medical evaluation and treatment.
-
- ⚠️ **Disclaimer:** This is not a medical diagnosis, but these symptoms warrant immediate professional assessment.
- """
-
-         state["suggested_actions"] = [emergency_response]
-         return state

-     def route_after_intake(self, state: MedicalState):
-         """Route decision after intake"""
-         if state["red_flags"]:
-             return "emergency"
-         else:
-             return "recommendations"
-
-     def extract_symptoms(self, text: str) -> Dict:
-         """Extract and categorize symptoms from patient text"""
-         symptoms = {}
-         text_lower = text.lower()
-
-         for category, symptom_list in MEDICAL_CATEGORIES.items():
-             for symptom in symptom_list:
-                 if symptom in text_lower:
-                     symptoms[symptom] = {
-                         "category": category,
-                         "mentioned_at": datetime.now().isoformat(),
-                         "context": text
-                     }
-
-         return symptoms

-     def check_red_flags(self, text: str) -> List[str]:
-         """Check for emergency red flags"""
-         found_flags = []
-         text_lower = text.lower()
-
-         for flag in RED_FLAGS:
-             if flag in text_lower:
-                 found_flags.append(flag)
-
-         return found_flags

-     def generate_response(self, message: str, history: List) -> str:
-         """Main response generation function"""
-         session_id = "default_session"
-
-         # Track conversation count
-         if session_id not in self.conversation_count:
-             self.conversation_count[session_id] = 0
-         self.conversation_count[session_id] += 1
-
-         # Initialize state
-         state = MedicalState(
-             patient_id=session_id,
-             conversation_history=history + [{"role": "user", "content": message}],
-             symptoms={},
-             vital_questions_asked=[],
-             medical_history={},
-             current_medications=[],
-             allergies=[],
-             severity_scores={},
-             red_flags=[],
-             assessment_complete=False,
-             suggested_actions=[],
-             consultation_stage="intake",
-             retrieved_knowledge=[],
-             confidence_scores={}
-         )
-
-         # For first few messages, do conversational intake
-         if self.conversation_count[session_id] <= 3:
-             return self.generate_conversational_response(message, history)
-
-         # After gathering info, run workflow for recommendations
-         try:
-             result = self.workflow.invoke(state)
-             return self.format_final_response(result)
-         except Exception as e:
-             print(f"Workflow error: {e}")
-             return self.generate_conversational_response(message, history)

-     def generate_conversational_response(self, message: str, history: List) -> str:
-         """Generate conversational response for intake phase"""
-         # Extract symptoms for context
-         symptoms = self.extract_symptoms(message)
-         red_flags = self.check_red_flags(message)
-
-         # Handle emergencies immediately
-         if red_flags:
-             return f"""
- 🚨 **URGENT MEDICAL ATTENTION NEEDED** 🚨
-
- I notice you mentioned: {', '.join(red_flags)}
-
- Please seek immediate medical care:
- • Go to the nearest emergency room
- • Call emergency services (911)
- • Contact your doctor immediately
-
- These symptoms require professional medical evaluation right away.
- """
-
-         # Generate contextual questions based on symptoms
-         if symptoms:
-             symptom_names = list(symptoms.keys())
-             prompt = f"""
- You are a caring medical assistant. The patient mentioned these symptoms: {', '.join(symptom_names)}.
-
- Respond empathetically and ask 1-2 relevant follow-up questions about:
- - How long they've had these symptoms
- - Severity (mild, moderate, severe)
- - What makes it better or worse
- - Any other symptoms they're experiencing
-
- Be professional, caring, and concise. Don't provide treatment advice yet.
- """
-         else:
-             prompt = f"""
- You are a caring medical assistant. The patient said: "{message}"
-
- Respond empathetically and ask relevant questions to understand their health concern better.
- Be professional and caring.
- """
-
-         return self.generate_llama_response(prompt)

-     def generate_llama_response(self, prompt: str) -> str:
-         """Generate response using Llama-2 with better formatting"""
-         try:
-             formatted_prompt = f"<s>[INST] {prompt} [/INST]"
-             inputs = self.tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=512)
-
-             if torch.cuda.is_available():
-                 inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 outputs = self.model.generate(
-                     **inputs,
-                     max_new_tokens=200,
-                     temperature=0.7,
-                     top_p=0.9,
-                     do_sample=True,
-                     pad_token_id=self.tokenizer.eos_token_id
-                 )
-
-             # Decode response
-             response = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
-
-             # Clean up the response
-             response = response.split('</s>')[0].strip()
-             response = response.replace('<s>', '').replace('[INST]', '').replace('[/INST]', '').strip()
-
-             # Remove any XML-like tags
-             response = re.sub(r'<[^>]+>', '', response)
-
-             return response if response else "I understand your concern. Can you tell me more about what you're experiencing?"
-
-         except Exception as e:
-             print(f"Error generating response: {e}")
-             return "I understand your concern. Can you tell me more about your symptoms?"

-     def format_final_response(self, state: MedicalState) -> str:
-         """Format the final response with recommendations"""
-         if state["consultation_stage"] == "emergency":
-             return state["suggested_actions"][0] if state["suggested_actions"] else "Please seek immediate medical attention."
-
-         # Format recommendations nicely
-         if state["suggested_actions"]:
-             response = "## 🏥 Medical Assessment & Recommendations\n\n"
-             response += "Based on our conversation, here's what I recommend:\n"
-
-             for action in state["suggested_actions"]:
-                 response += f"{action}\n"
-
-             response += "\n---\n"
-             response += "**Important Disclaimer:** I'm an AI assistant providing general health information. "
-             response += "This is not a substitute for professional medical advice, diagnosis, or treatment. "
-             response += "Always consult with qualified healthcare providers for medical concerns."
-
-             return response
-         else:
-             return "Please provide more details about your symptoms so I can offer better guidance."

- # Initialize the medical assistant
- medical_assistant = EnhancedMedicalAssistant()
-
- # Gradio chat interface function
- def chat_interface(message, history):
-     """Gradio chat interface"""
-     try:
-         return medical_assistant.generate_response(message, history)
-     except Exception as e:
-         print(f"Chat interface error: {e}")
-         return f"I apologize, but I encountered an error. Please try rephrasing your question. Error: {str(e)}"
-
- # Create Gradio interface with enhanced styling
  demo = gr.ChatInterface(
-     fn=chat_interface,
-     title="🏥 Medical AI Assistant with medRAG",
-     description="""
- I'm an AI medical assistant powered by medical knowledge retrieval (medRAG).
- I can help assess your symptoms and provide evidence-based recommendations.
-
- **How it works:**
- 1. Tell me about your symptoms
- 2. I'll ask follow-up questions
- 3. I'll provide personalized recommendations based on medical knowledge
-
- ⚠️ **Important**: I'm not a replacement for professional medical care. Always consult healthcare providers for serious concerns.
- """,
      examples=[
-         "I have a bad cough and sore throat",
-         "I've been having headaches for the past few days",
-         "My stomach has been hurting after meals",
-         "I have chest pain and trouble breathing"
      ],
-     theme="soft",
-     css="""
-     /* Main container styling */
-     .gradio-container {
-         background: linear-gradient(to right, #f0f8ff, #f5f5f5);
-         font-family: 'Arial', sans-serif;
-     }
-
-     /* Chat message styling */
-     .message.user {
-         background-color: #e3f2fd;
-         border-radius: 12px;
-         padding: 12px;
-         margin: 8px;
-         box-shadow: 0 2px 5px rgba(0,0,0,0.1);
-         border-left: 4px solid #2196F3;
-     }
-
-     .message.bot {
-         background-color: #f1f8e9;
-         border-radius: 12px;
-         padding: 12px;
-         margin: 8px;
-         box-shadow: 0 2px 5px rgba(0,0,0,0.1);
-         border-left: 4px solid #4CAF50;
-     }
-
-     /* Enhanced medical styling */
-     .bot h2 {
-         color: #1976D2 !important;
-         border-bottom: 2px solid #E0E0E0 !important;
-         padding-bottom: 8px !important;
-     }
-     """
  )

  if __name__ == "__main__":
-     demo.launch(share=True)

  import gradio as gr
+ import spaces
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from langgraph.graph import StateGraph, END
+ from typing import TypedDict, List, Tuple
+ import json
+
+ # Model configuration
+ LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
+ MEDITRON_MODEL = "epfl-llm/meditron-7b"
+
+ SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
+ Ask 1-2 follow-up questions at a time to gather more details about:
+ - Detailed description of symptoms
+ - Duration (when did it start?)
+ - Severity (scale of 1-10)
+ - Aggravating or alleviating factors
+ - Related symptoms
+ - Medical history
+ - Current medications and allergies
+ After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.
+ Respond empathetically and clearly. Always be professional and thorough."""
+
+ MEDITRON_PROMPT = """<|im_start|>system
+ You are a specialized medical assistant focusing ONLY on suggesting over-the-counter medicines and home remedies based on patient information.
+ Based on the following patient information, provide ONLY:
+ 1. One specific over-the-counter medicine with proper adult dosing instructions
+ 2. One practical home remedy that might help
+ 3. Clear guidance on when to seek professional medical care
+ Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
+ <|im_end|>
+ <|im_start|>user
+ Patient information: {patient_info}
+ <|im_end|>
+ <|im_start|>assistant
+ """
+
+ # Load models
+ print("Loading Llama-2 model...")
+ tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
+ model = AutoModelForCausalLM.from_pretrained(
+     LLAMA_MODEL,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+ print("Llama-2 model loaded successfully!")
+
+ print("Loading Meditron model...")
+ meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
+ meditron_model = AutoModelForCausalLM.from_pretrained(
+     MEDITRON_MODEL,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+ print("Meditron model loaded successfully!")
+
+ # Define the state for LangGraph
+ class ConversationState(TypedDict):
+     messages: List[str]
+     history: List[Tuple[str, str]]
+     current_message: str
+     conversation_turns: int
+     patient_data: List[str]
+     llama_response: str
+     final_response: str
+     should_get_suggestions: bool
+
+ def build_llama2_prompt(system_prompt, history, user_input):
+     """Format the conversation history and user input for Llama-2 chat models."""
+     prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
+
+     # Add conversation history
+     for user_msg, assistant_msg in history:
+         prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
+
+     # Add the current user input
+     prompt += f"{user_input} [/INST] "
+
+     return prompt
+
+ def get_meditron_suggestions(patient_info):
+     """Use Meditron model to generate medicine and remedy suggestions."""
+     prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
+     inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)
+
+     with torch.no_grad():
+         outputs = meditron_model.generate(
+             inputs.input_ids,
+             attention_mask=inputs.attention_mask,
+             max_new_tokens=256,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True
+         )
+
+     suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+     return suggestion
+
+ # LangGraph Node Functions
+ def initialize_conversation(state: ConversationState) -> ConversationState:
+     """Initialize or update conversation state."""
+     # Update conversation turns
+     state["conversation_turns"] = state.get("conversation_turns", 0) + 1
+
+     # Add current message to patient data
+     if "patient_data" not in state:
+         state["patient_data"] = []
+     state["patient_data"].append(state["current_message"])
+
+     # Determine if we should get suggestions (after 4 turns)
+     state["should_get_suggestions"] = state["conversation_turns"] >= 4
+
+     return state
+
+ def generate_llama_response(state: ConversationState) -> ConversationState:
+     """Generate response using Llama-2 model."""
+     # Build the prompt with proper Llama-2 formatting
+     prompt = build_llama2_prompt(SYSTEM_PROMPT, state["history"], state["current_message"])
+
+     # Add the summarization instruction after 4 turns by appending to the
+     # prompt, which already ends with "[/INST] "; a plain str.replace on
+     # "[/INST] " would also rewrite every occurrence inside the history
+     if state["conversation_turns"] >= 4:
+         prompt += "Now summarize what you've learned and suggest when professional care may be needed. "
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     # Generate the Llama-2 response
+     with torch.no_grad():
+         outputs = model.generate(
+             inputs.input_ids,
+             attention_mask=inputs.attention_mask,
+             max_new_tokens=512,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+     # Decode and extract Llama-2's response
+     full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+     llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
+
+     state["llama_response"] = llama_response
+     return state
+
+ def generate_medicine_suggestions(state: ConversationState) -> ConversationState:
+     """Generate medicine suggestions using Meditron model."""
+     # Collect full patient conversation
+     full_patient_info = "\n".join(state["patient_data"]) + "\n\nSummary: " + state["llama_response"]
+
+     # Get medicine suggestions
+     medicine_suggestions = get_meditron_suggestions(full_patient_info)
+
+     # Format final response
+     final_response = (
+         f"{state['llama_response']}\n\n"
+         f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
+         f"{medicine_suggestions}"
+     )
+
+     state["final_response"] = final_response
+     return state
+
+ def finalize_response(state: ConversationState) -> ConversationState:
+     """Finalize the response without medicine suggestions."""
+     state["final_response"] = state["llama_response"]
+     return state
+
+ def should_get_suggestions(state: ConversationState) -> str:
+     """Conditional edge to determine next step."""
+     if state["should_get_suggestions"]:
+         return "get_suggestions"
+     else:
+         return "finalize"
+
+ # Create the LangGraph workflow
+ def create_medical_workflow():
+     """Create the LangGraph workflow for medical assistant."""
+     workflow = StateGraph(ConversationState)
+
+     # Add nodes
+     workflow.add_node("initialize", initialize_conversation)
+     workflow.add_node("generate_llama", generate_llama_response)
+     workflow.add_node("get_suggestions", generate_medicine_suggestions)
+     workflow.add_node("finalize", finalize_response)
+
+     # Define the flow
+     workflow.set_entry_point("initialize")
+     workflow.add_edge("initialize", "generate_llama")
+     workflow.add_conditional_edges(
+         "generate_llama",
+         should_get_suggestions,
+         {
+             "get_suggestions": "get_suggestions",
+             "finalize": "finalize"
+         }
+     )
+     workflow.add_edge("get_suggestions", END)
+     workflow.add_edge("finalize", END)
+
+     return workflow.compile()
+
+ # Initialize the workflow
+ medical_workflow = create_medical_workflow()
+
+ # Conversation state tracking (for Gradio session management)
+ conversation_states = {}
+
+ @spaces.GPU
+ def generate_response(message, history):
+     """Generate a response using the LangGraph workflow."""
+     session_id = "default-session"
+
+     # Initialize or get existing conversation state
+     if session_id not in conversation_states:
+         conversation_states[session_id] = {
+             "messages": [],
+             "history": [],
+             "conversation_turns": 0,
+             "patient_data": []
+         }
+
+     # Update state with current message and history
+     state = conversation_states[session_id].copy()
+     state["current_message"] = message
+     state["history"] = history
+
+     # Run the workflow
+     result = medical_workflow.invoke(state)
+
+     # Update the stored conversation state
+     conversation_states[session_id] = {
+         "messages": result["messages"] if "messages" in result else [],
+         "history": history,
+         "conversation_turns": result["conversation_turns"],
+         "patient_data": result["patient_data"]
+     }
+
+     return result["final_response"]
+
+ # Create the Gradio interface
  demo = gr.ChatInterface(
+     fn=generate_response,
+     title="Medical Assistant with LangGraph & Medicine Suggestions",
+     description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies using an AI workflow.",
      examples=[
+         "I have a cough and my throat hurts",
+         "I've been having headaches for a week",
+         "My stomach has been hurting since yesterday"
      ],
+     theme="soft"
  )

  if __name__ == "__main__":
+     demo.launch()
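
Reviewer note: the compiled graph can be smoke-tested outside Gradio. Below is a minimal sketch (a hypothetical test script, not part of this commit); it assumes the file above is saved as app.py, that both gated 7B checkpoints actually load, and it inherits the design choice that session state is keyed by the hard-coded "default-session", so all users share one turn counter.

    # test_workflow.py -- hypothetical smoke test, not part of the commit.
    # Importing app triggers the module-level model loading above.
    from app import medical_workflow

    state = {
        "messages": [],
        "history": [],  # (user, assistant) tuples, as in ConversationState
        "current_message": "I have a cough and my throat hurts",
        "conversation_turns": 0,
        "patient_data": [],
        "llama_response": "",
        "final_response": "",
        "should_get_suggestions": False,
    }

    # Turn 1 of 4: initialize_conversation leaves should_get_suggestions
    # False, so the graph routes to finalize and skips the Meditron node;
    # final_response is just the Llama-2 intake reply.
    result = medical_workflow.invoke(state)
    print(result["final_response"])
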
requirements.txt CHANGED
@@ -1,45 +1,18 @@
- # Core dependencies
- gradio>=4.0.0
- torch>=2.0.0
- transformers>=4.30.0
- langgraph>=0.3.27
- langchain-core>=0.2.38
-
- # Vector search and embeddings
- sentence-transformers>=2.2.2
- faiss-cpu>=1.7.4
-
- # Data processing
- numpy>=1.24.0
- typing-extensions>=4.5.0
-
- # LangGraph ecosystem components
- langsmith>=0.1.63
- langgraph-sdk>=0.1.66
- langgraph-checkpoint>=2.0.23
-
- # Web serving utilities
- httpx>=0.25.0
- uvicorn>=0.26.0
- sse-starlette>=2.1.0,<2.2.0
- uvloop>=0.18.0
- httptools>=0.5.0
-
- # Serialization and utilities
- orjson>=3.9.7,<3.10.17
- jsonschema-rs>=0.20.0
- structlog>=24.1.0
- cloudpickle>=3.0.0
- tenacity>=8.0.0
-
- # Model acceleration (optional but recommended)
- accelerate>=0.20.0
- safetensors>=0.3.1
- bitsandbytes>=0.40.0
-
- # For Hugging Face model access
- huggingface_hub>=0.16.0
-
- # Optional - specific model support
- langchain_anthropic>=0.0.5
- langchain_openai>=0.0.2
+ # Core packages
+ gradio==4.24.0
+ spaces==0.21.1
+
+ # Transformers & tokenization
+ transformers==4.40.1
+ torch>=2.1.0
+
+ # LangGraph
+ langgraph==0.0.41
+
+ # Optional but often required for transformers
+ accelerate==0.30.1
+ sentencepiece==0.1.99
+ protobuf==4.25.3
+
+ # Utility
+ typing-extensions>=4.5.0
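
One more reviewer note, on app.py's build_llama2_prompt: the string it emits follows the Llama-2 chat template. A worked example, with hypothetical argument values, checked by hand against the function body above:

    # With one prior exchange in the history:
    build_llama2_prompt(
        "You are a doctor.",                                # system_prompt
        [("I have a cough", "How long have you had it?")],  # history
        "About three days",                                 # user_input
    )
    # returns:
    # "<s>[INST] <<SYS>>\nYou are a doctor.\n<</SYS>>\n\n"
    # "I have a cough [/INST] How long have you had it? </s><s>[INST] "
    # "About three days [/INST] "

Each user turn sits inside [INST] ... [/INST], the assistant reply follows, and </s><s> separates exchanges; the prompt deliberately ends after "[/INST] " so generation continues as the assistant.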