techindia2025 committed on
Commit
bdce857
·
verified ·
1 Parent(s): fffe187

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +412 -75
app.py CHANGED
@@ -2,96 +2,433 @@ import gradio as gr
2
  import spaces
3
  import torch
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
5
 
6
- MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
 
 
 
 
 
 
 
 
 
9
 
10
- Ask 1-2 follow-up questions at a time to gather more details about:
11
- - Detailed description of symptoms
12
- - Duration (when did it start?)
13
- - Severity (scale of 1-10)
14
- - Aggravating or alleviating factors
15
- - Related symptoms
16
- - Medical history
17
- - Current medications and allergies
18
 
19
- After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.
 
 
 
 
 
 
 
 
20
 
21
- Respond empathetically and clearly. Always be professional and thorough."""
22
-
23
- print("Loading model...")
24
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
25
- model = AutoModelForCausalLM.from_pretrained(
26
- MODEL_NAME,
27
- torch_dtype=torch.float16,
28
- device_map="auto"
29
- )
30
- print("Model loaded successfully!")
31
-
32
- # Conversation state tracking
33
- conversation_turns = {}
34
-
35
- def build_llama2_prompt(system_prompt, history, user_input):
36
- """Format the conversation history and user input for Llama-2 chat models."""
37
- prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Add conversation history
40
- for user_msg, assistant_msg in history:
41
- prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Add the current user input
44
- prompt += f"{user_input} [/INST] "
 
 
 
 
 
 
 
 
 
45
 
46
- return prompt
47
-
48
- @spaces.GPU
49
- def generate_response(message, history):
50
- """Generate a response using the Llama-2 model with proper formatting."""
51
- # Track conversation turns
52
- session_id = "default-session"
53
- if session_id not in conversation_turns:
54
- conversation_turns[session_id] = 0
55
- conversation_turns[session_id] += 1
56
-
57
- # Build the prompt with proper Llama-2 formatting
58
- prompt = build_llama2_prompt(SYSTEM_PROMPT, history, message)
59
-
60
- # Add summarization instruction after 4 turns
61
- if conversation_turns[session_id] >= 4:
62
- prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
63
-
64
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
65
-
66
- # Generate the response
67
- with torch.no_grad():
68
- outputs = model.generate(
69
- inputs.input_ids,
70
- max_new_tokens=512,
71
- temperature=0.7,
72
- top_p=0.9,
73
- do_sample=True,
74
- pad_token_id=tokenizer.eos_token_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- # Decode and extract only the assistant's response
78
- full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
79
- assistant_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- return assistant_response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- # Create the Gradio interface
 
 
 
 
 
 
 
 
84
  demo = gr.ChatInterface(
85
- fn=generate_response,
86
- title="Medical Assistant Chatbot",
87
- description="Ask about your symptoms and I'll help gather relevant information.",
 
 
 
 
 
88
  examples=[
89
- "I have a cough and my throat hurts",
90
- "I've been having headaches for a week",
91
- "My stomach has been hurting since yesterday"
 
92
  ],
93
- theme="soft"
 
 
 
 
94
  )
95
 
96
  if __name__ == "__main__":
97
- demo.launch()
 
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict

import spaces
import torch
from langgraph.graph import StateGraph, END
from transformers import AutoModelForCausalLM, AutoTokenizer
9
 
10
# Enhanced State Management
class MedicalState(TypedDict):
    """Consultation state threaded through every LangGraph workflow node."""

    patient_id: str                      # opaque per-session identifier
    conversation_history: List[Dict]     # [{"role": ..., "content": ...}, ...]
    # BUGFIX: was `Dict[str, any]` — `any` is the builtin function, not a
    # type; the intended parameter is `typing.Any`.
    symptoms: Dict[str, Any]             # symptom name -> details (category, timestamp, context)
    vital_questions_asked: List[str]     # keys of VITAL_QUESTIONS already covered
    medical_history: Dict                # chronic conditions, prior illnesses
    current_medications: List[str]
    allergies: List[str]
    severity_scores: Dict[str, int]      # symptom -> 1-10 self-rating
    red_flags: List[str]                 # matched emergency keywords
    assessment_complete: bool            # set once symptom assessment finished
    suggested_actions: List[str]         # final recommendations / triage payload
    consultation_stage: str  # intake, assessment, summary, recommendations
24
 
25
# Medical Knowledge Base
# Keyword lists used for naive substring-based symptom categorisation:
# a symptom keyword appearing anywhere in the patient's message is filed
# under its body-system category.
MEDICAL_CATEGORIES = dict(
    respiratory=["cough", "shortness of breath", "chest pain", "wheezing"],
    gastrointestinal=["nausea", "vomiting", "diarrhea", "stomach pain", "heartburn"],
    neurological=["headache", "dizziness", "numbness", "tingling"],
    musculoskeletal=["joint pain", "muscle pain", "back pain", "stiffness"],
    cardiovascular=["chest pain", "palpitations", "swelling", "fatigue"],
    dermatological=["rash", "itching", "skin changes", "wounds"],
    mental_health=["anxiety", "depression", "sleep issues", "stress"],
)
35
 
36
# Symptom keywords that should immediately escalate the consultation to
# emergency triage (substring match against the patient's message).
RED_FLAGS = [
    # cardio-respiratory
    "chest pain",
    "difficulty breathing",
    # neurological / systemic
    "severe headache",
    "high fever",
    # bleeding / acute abdomen
    "blood in stool",
    "blood in urine",
    "severe abdominal pain",
    # other emergencies
    "sudden vision changes",
    "loss of consciousness",
    "severe allergic reaction",
]
 
 
 
41
 
42
# Canonical intake questions; keys are tracked in
# MedicalState["vital_questions_asked"] so each is asked at most once.
VITAL_QUESTIONS = dict(
    symptom_onset="When did your symptoms first start?",
    severity="On a scale of 1-10, how severe would you rate your symptoms?",
    triggers="What makes your symptoms better or worse?",
    associated_symptoms="Are you experiencing any other symptoms?",
    medical_history="Do you have any chronic medical conditions?",
    medications="Are you currently taking any medications?",
    allergies="Do you have any known allergies?",
)
51
 
52
class EnhancedMedicalAssistant:
    """LangGraph-driven medical intake assistant.

    Combines two causal LMs — Llama-2-7b-chat for empathetic dialogue and
    Meditron-7b for care recommendations — and pushes every user message
    through a small state machine:

        intake -> symptom_assessment -> risk_evaluation
               -> generate_recommendations | emergency_triage

    The final ``consultation_stage`` selects how the reply is formatted.
    """

    def __init__(self):
        self.load_models()
        self.setup_langgraph()

    def load_models(self):
        """Load tokenizer/model pairs for both LLMs (fp16, auto device map)."""
        print("Loading models...")
        # Llama-2 for conversation
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
        self.model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-chat-hf",
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Meditron for medical suggestions
        self.meditron_tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b")
        self.meditron_model = AutoModelForCausalLM.from_pretrained(
            "epfl-llm/meditron-7b",
            torch_dtype=torch.float16,
            device_map="auto",
        )
        print("Models loaded successfully!")

    def setup_langgraph(self):
        """Build and compile the consultation workflow graph."""
        workflow = StateGraph(MedicalState)

        # Nodes: each mutates and returns the shared MedicalState.
        workflow.add_node("intake", self.patient_intake)
        workflow.add_node("symptom_assessment", self.assess_symptoms)
        workflow.add_node("risk_evaluation", self.evaluate_risks)
        workflow.add_node("generate_recommendations", self.generate_recommendations)
        workflow.add_node("emergency_triage", self.emergency_triage)

        # Edges: intake fans out on red flags / remaining questions;
        # risk evaluation decides between triage and recommendations.
        workflow.set_entry_point("intake")
        workflow.add_conditional_edges(
            "intake",
            self.route_after_intake,
            {
                "continue_assessment": "symptom_assessment",
                "emergency": "emergency_triage",
                "complete": "generate_recommendations",
            },
        )
        workflow.add_edge("symptom_assessment", "risk_evaluation")
        workflow.add_conditional_edges(
            "risk_evaluation",
            self.route_after_risk_eval,
            {
                "emergency": "emergency_triage",
                "continue": "generate_recommendations",
                "need_more_info": "symptom_assessment",
            },
        )
        workflow.add_edge("generate_recommendations", END)
        workflow.add_edge("emergency_triage", END)

        self.workflow = workflow.compile()

    def patient_intake(self, state: MedicalState) -> MedicalState:
        """Initial intake: extract symptoms and screen for red flags."""
        history = state["conversation_history"]
        last_message = history[-1]["content"] if history else ""

        # Extract symptoms and categorize them
        state["symptoms"].update(self.extract_symptoms(last_message))

        # Emergency keywords short-circuit the rest of the consultation.
        red_flags = self.check_red_flags(last_message)
        if red_flags:
            state["red_flags"].extend(red_flags)

        # Stay in intake while vital questions remain and the chat is short.
        missing_questions = self.get_missing_vital_questions(state)
        if missing_questions and len(history) < 6:
            state["consultation_stage"] = "intake"
        else:
            state["consultation_stage"] = "assessment"
        return state

    def assess_symptoms(self, state: MedicalState) -> MedicalState:
        """Flag symptoms that still lack a severity rating.

        BUGFIX: the original early-returned without ever setting
        ``assessment_complete`` (and nothing else ever set it), so
        ``route_after_risk_eval`` returned "need_more_info" and the graph
        bounced between assessment and risk evaluation until LangGraph's
        recursion limit raised.  One assessment pass per user turn is all
        this node can do (it cannot ask the user mid-graph), so mark the
        assessment complete either way.
        """
        missing_severity = [
            name for name, details in state["symptoms"].items()
            if "severity" not in details
        ]
        if missing_severity:
            # The reply layer will ask for the missing severity rating.
            state["consultation_stage"] = "assessment"
        state["assessment_complete"] = True
        return state

    def evaluate_risks(self, state: MedicalState) -> MedicalState:
        """Score urgency from red flags and reported severities."""
        risk_score = 0

        # Each red flag weighs heavily.
        if state["red_flags"]:
            risk_score += len(state["red_flags"]) * 3

        # Graded weight for high self-reported severities (1-10 scale).
        for severity in state["severity_scores"].values():
            if severity >= 8:
                risk_score += 2
            elif severity >= 6:
                risk_score += 1

        # TODO: factor in symptom duration and progression once tracked.

        state["consultation_stage"] = "emergency" if risk_score >= 5 else "recommendations"
        return state

    def generate_recommendations(self, state: MedicalState) -> MedicalState:
        """Produce care suggestions for the collected picture via Meditron."""
        patient_summary = self.create_patient_summary(state)
        state["suggested_actions"] = self.get_meditron_recommendations(patient_summary)
        return state

    def emergency_triage(self, state: MedicalState) -> MedicalState:
        """Attach an urgent-care payload when red flags were detected."""
        emergency_response = {
            "urgent_care_needed": True,
            "recommended_action": "Seek immediate medical attention",
            "reasons": state["red_flags"],
            "instructions": "Go to the nearest emergency room or call emergency services",
        }
        state["suggested_actions"] = [emergency_response]
        return state

    def route_after_intake(self, state: MedicalState):
        """Route decision after intake.

        NOTE(review): ``vital_questions_asked`` is never appended to
        anywhere, so the "complete" branch is currently unreachable —
        confirm intended behavior before wiring question tracking in.
        """
        if state["red_flags"]:
            return "emergency"
        elif len(state["vital_questions_asked"]) < 5:
            return "continue_assessment"
        else:
            return "complete"

    def route_after_risk_eval(self, state: MedicalState):
        """Route decision after risk evaluation."""
        if state["consultation_stage"] == "emergency":
            return "emergency"
        elif state["assessment_complete"]:
            return "continue"
        else:
            return "need_more_info"

    def extract_symptoms(self, text: str) -> Dict:
        """Extract and categorize symptoms via substring keyword match."""
        symptoms = {}
        text_lower = text.lower()

        for category, symptom_list in MEDICAL_CATEGORIES.items():
            for symptom in symptom_list:
                if symptom in text_lower:
                    symptoms[symptom] = {
                        "category": category,
                        "mentioned_at": datetime.now().isoformat(),
                        "context": text,
                    }

        return symptoms

    def check_red_flags(self, text: str) -> List[str]:
        """Return every emergency keyword found in *text* (case-insensitive)."""
        text_lower = text.lower()
        return [flag for flag in RED_FLAGS if flag in text_lower]

    def get_missing_vital_questions(self, state: MedicalState) -> List[str]:
        """Determine which vital questions haven't been asked yet."""
        asked = state["vital_questions_asked"]
        return [q for q in VITAL_QUESTIONS.keys() if q not in asked]

    def create_patient_summary(self, state: MedicalState) -> str:
        """Create a comprehensive patient summary for the Meditron prompt."""
        summary = f"""
        Patient Summary:
        Symptoms: {json.dumps(state['symptoms'], indent=2)}
        Medical History: {state['medical_history']}
        Current Medications: {state['current_medications']}
        Allergies: {state['allergies']}
        Severity Scores: {state['severity_scores']}
        Conversation History: {[msg['content'] for msg in state['conversation_history'][-3:]]}
        """
        return summary

    def get_meditron_recommendations(self, patient_summary: str) -> List[str]:
        """Get medical recommendations using the Meditron model."""
        prompt = f"""
        Based on the following patient information, provide:
        1. Specific over-the-counter medications with dosing
        2. Home remedies and self-care measures
        3. When to seek professional medical care
        4. Follow-up recommendations

        Patient Information:
        {patient_summary}

        Response:"""

        inputs = self.meditron_tokenizer(prompt, return_tensors="pt").to(self.meditron_model.device)

        with torch.no_grad():
            outputs = self.meditron_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=400,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )

        # Decode only the newly generated tokens, not the prompt.
        recommendation = self.meditron_tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )

        return [recommendation]

    @staticmethod
    def _normalize_history(history: List) -> List[Dict]:
        """Coerce Gradio chat history into [{"role", "content"}] dicts.

        BUGFIX: Gradio's ChatInterface may deliver history as
        ``(user, assistant)`` pairs; downstream nodes index
        ``msg["content"]`` and crashed on tuples.
        """
        normalized: List[Dict] = []
        for entry in history or []:
            if isinstance(entry, dict):
                normalized.append(entry)
            else:
                user_msg, assistant_msg = entry
                if user_msg:
                    normalized.append({"role": "user", "content": user_msg})
                if assistant_msg:
                    normalized.append({"role": "assistant", "content": assistant_msg})
        return normalized

    def generate_response(self, message: str, history: List) -> str:
        """Main response generation function: run one turn through the graph."""
        conversation = self._normalize_history(history)
        conversation.append({"role": "user", "content": message})

        # NOTE(review): state is rebuilt from scratch on every message, so
        # vital_questions_asked / severity_scores never accumulate across
        # turns — persisting state per session would fix this.
        state = MedicalState(
            patient_id="session_001",
            conversation_history=conversation,
            symptoms={},
            vital_questions_asked=[],
            medical_history={},
            current_medications=[],
            allergies=[],
            severity_scores={},
            red_flags=[],
            assessment_complete=False,
            suggested_actions=[],
            consultation_stage="intake",
        )

        # Run through LangGraph workflow
        result = self.workflow.invoke(state)

        # Generate contextual response
        return self.generate_contextual_response(result, message)

    def generate_contextual_response(self, state: MedicalState, user_message: str) -> str:
        """Pick a reply formatter based on the final consultation stage."""
        if state["consultation_stage"] == "emergency":
            return self.format_emergency_response(state)
        elif state["consultation_stage"] == "intake":
            return self.format_intake_response(state, user_message)
        elif state["consultation_stage"] == "assessment":
            return self.format_assessment_response(state)
        elif state["consultation_stage"] == "recommendations":
            return self.format_recommendations_response(state)
        else:
            return self.format_default_response(user_message)

    def format_emergency_response(self, state: MedicalState) -> str:
        """Format emergency response."""
        return f"""
        🚨 URGENT MEDICAL ATTENTION NEEDED 🚨

        Based on your symptoms, I recommend seeking immediate medical care because:
        {', '.join(state['red_flags'])}

        Please:
        - Go to the nearest emergency room, OR
        - Call emergency services (911), OR
        - Contact your doctor immediately

        This is not a diagnosis, but these symptoms warrant immediate professional evaluation.
        """

    def format_intake_response(self, state: MedicalState, user_message: str) -> str:
        """Format intake response with follow-up questions via Llama-2."""
        prompt = f"""
        You are a caring virtual doctor. The patient said: "{user_message}"

        Respond empathetically and ask 1-2 specific follow-up questions about:
        - Symptom details (duration, severity, triggers)
        - Associated symptoms
        - Medical history if relevant

        Be professional, caring, and thorough.
        """

        return self.generate_llama_response(prompt)

    def format_assessment_response(self, state: MedicalState) -> str:
        """Format detailed assessment response."""
        return "Let me gather a bit more information to better understand your condition..."

    def format_recommendations_response(self, state: MedicalState) -> str:
        """Format final recommendations.

        BUGFIX: ``suggested_actions`` can hold non-string payloads (the
        emergency dict); join on their ``str()`` form instead of crashing.
        """
        recommendations = "\n".join(str(action) for action in state["suggested_actions"])
        return f"""
        Based on our consultation, here's my assessment and recommendations:

        {recommendations}

        **Important Disclaimer:** I am an AI assistant, not a licensed medical professional.
        These suggestions are for informational purposes only. Please consult with a
        healthcare provider for proper diagnosis and treatment.
        """

    def format_default_response(self, user_message: str) -> str:
        """Format default response."""
        return self.generate_llama_response(f"Respond professionally to: {user_message}")

    def generate_llama_response(self, prompt: str) -> str:
        """Generate a reply with Llama-2 using its [INST] chat template."""
        formatted_prompt = f"<s>[INST] {prompt} [/INST] "
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=300,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the continuation; the '</s>' split is defensive in case
        # special tokens survive decoding.
        response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return response.split('</s>')[0].strip()
401
 
402
# Build the assistant once at import time (loads both models), so the
# GPU-decorated handler below only performs inference per message.
medical_assistant = EnhancedMedicalAssistant()


@spaces.GPU
def chat_interface(message, history):
    """Gradio ChatInterface callback: delegate the turn to the assistant."""
    reply = medical_assistant.generate_response(message, history)
    return reply
409
+
410
# Create Gradio interface
# Example prompts shown beneath the chat box.
_EXAMPLE_PROMPTS = [
    "I've been having severe chest pain for the last hour",
    "I have a persistent cough that's been going on for 2 weeks",
    "I'm experiencing nausea and stomach pain after eating",
    "I have a headache and feel dizzy",
]

# Light background tints distinguishing user vs. bot bubbles.
_CHAT_CSS = """
    .message.user { background-color: #e3f2fd; }
    .message.bot { background-color: #f1f8e9; }
    """

demo = gr.ChatInterface(
    fn=chat_interface,
    title="🏥 Advanced Medical AI Assistant",
    description="""
    I'm an AI medical assistant that can help assess your symptoms and provide guidance.
    I'll ask relevant questions to better understand your condition and provide appropriate recommendations.

    ⚠️ **Important**: I'm not a replacement for professional medical care. Always consult healthcare providers for serious concerns.
    """,
    examples=_EXAMPLE_PROMPTS,
    theme="soft",
    css=_CHAT_CSS,
)
432
 
433
if __name__ == "__main__":
    # `share=True` was removed: on Hugging Face Spaces the app is already
    # publicly served (the flag is ignored there), and when run locally it
    # would open an unauthenticated public tunnel to this machine.
    demo.launch()