techindia2025 committed
Commit 1cb1a8e · verified · 1 Parent(s): c66e1bd

Update app.py

Files changed (1): app.py +110 -100
app.py CHANGED
@@ -3,11 +3,11 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import spaces
 
-# Model configuration
-LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
-MEDITRON_MODEL = "epfl-llm/meditron-7b"
 
-SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
 Ask 1-2 follow-up questions at a time to gather more details about:
 - Detailed description of symptoms
 - Duration (when did it start?)
@@ -19,30 +19,23 @@ Ask 1-2 follow-up questions at a time to gather more details about:
 After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.
 Respond empathetically and clearly. Always be professional and thorough."""
 
-MEDITRON_PROMPT = """<|im_start|>system
-You are a specialized medical assistant focusing ONLY on suggesting over-the-counter medicines and home remedies based on patient information.
-Based on the following patient information, provide ONLY:
 1. One specific over-the-counter medicine with proper adult dosing instructions
 2. One practical home remedy that might help
 3. Clear guidance on when to seek professional medical care
 Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
-<|im_end|>
-<|im_start|>user
-Patient information: {patient_info}
-<|im_end|>
-<|im_start|>assistant
-"""
 
-# Global variables to store models (will be loaded lazily)
-llama_model = None
-llama_tokenizer = None
-meditron_model = None
-meditron_tokenizer = None
 conversation_turns = 0
 patient_data = []
 
-def build_llama2_prompt(system_prompt, history, user_input):
-    """Format the conversation history and user input for Llama-2 chat models."""
     prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
 
     # Add conversation history
@@ -55,126 +48,143 @@ def build_llama2_prompt(system_prompt, history, user_input):
     return prompt
 
 @spaces.GPU
-def load_models_if_needed():
-    """Load models only when GPU is available and only if not already loaded."""
-    global llama_model, llama_tokenizer, meditron_model, meditron_tokenizer
-
-    if llama_model is None:
-        print("Loading Llama-2 model...")
-        llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
-        llama_model = AutoModelForCausalLM.from_pretrained(
-            LLAMA_MODEL,
-            torch_dtype=torch.float16,
-            device_map="auto"
-        )
-        print("Llama-2 model loaded successfully!")
-
-    if meditron_model is None:
-        print("Loading Meditron model...")
-        meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
-        meditron_model = AutoModelForCausalLM.from_pretrained(
-            MEDITRON_MODEL,
             torch_dtype=torch.float16,
-            device_map="auto"
         )
-        print("Meditron model loaded successfully!")
 
 @spaces.GPU
-def get_meditron_suggestions(patient_info):
-    """Use Meditron model to generate medicine and remedy suggestions."""
-    load_models_if_needed() # Ensure models are loaded
 
-    prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
-    inputs = meditron_tokenizer(prompt, return_tensors="pt")
 
     # Move inputs to the same device as the model
     if torch.cuda.is_available():
-        inputs = {k: v.to(meditron_model.device) for k, v in inputs.items()}
 
     with torch.no_grad():
-        outputs = meditron_model.generate(
             inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
-            max_new_tokens=256,
             temperature=0.7,
             top_p=0.9,
             do_sample=True,
-            pad_token_id=meditron_tokenizer.eos_token_id
         )
 
-    suggestion = meditron_tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
     return suggestion
 
 @spaces.GPU
 def generate_response(message, history):
-    """Generate a response using both models."""
     global conversation_turns, patient_data
 
-    # Load models if needed
-    load_models_if_needed()
 
     # Track conversation turns
     conversation_turns += 1
 
-    # Store the entire conversation for reference
     patient_data.append(message)
 
-    # Build the prompt with proper Llama-2 formatting
-    prompt = build_llama2_prompt(SYSTEM_PROMPT, history, message)
-
-    # Add summarization instruction after 4 turns
-    if conversation_turns >= 4:
-        prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
-
-    inputs = llama_tokenizer(prompt, return_tensors="pt")
-
-    # Move inputs to the same device as the model
-    if torch.cuda.is_available():
-        inputs = {k: v.to(llama_model.device) for k, v in inputs.items()}
-
-    # Generate the Llama-2 response
-    with torch.no_grad():
-        outputs = llama_model.generate(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            max_new_tokens=512,
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True,
-            pad_token_id=llama_tokenizer.eos_token_id
         )
-
-    # Decode and extract Llama-2's response
-    full_response = llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
-    llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
-
-    # After 4 turns, add medicine suggestions from Meditron
-    if conversation_turns >= 4:
-        # Collect full patient conversation
-        full_patient_info = "\n".join(patient_data) + "\n\nSummary: " + llama_response
 
-        # Get medicine suggestions
-        medicine_suggestions = get_meditron_suggestions(full_patient_info)
 
-        # Format final response
         final_response = (
-            f"{llama_response}\n\n"
-            f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
-            f"{medicine_suggestions}"
         )
         return final_response
-
-    return llama_response
 
 # Create the Gradio interface
 demo = gr.ChatInterface(
     fn=generate_response,
-    title="Medical Assistant with Medicine Suggestions",
-    description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies.",
     examples=[
-        "I have a cough and my throat hurts",
-        "I've been having headaches for a week",
-        "My stomach has been hurting since yesterday"
     ],
     theme="soft"
 )
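The removed Meditron path wrapped its instructions in ChatML-style markers (<|im_start|> ... <|im_end|>), whereas the Llama-2-style [INST]/<<SYS>> wrapper built by build_llama2_prompt is the format the replacement code below keeps. For orientation, a minimal sketch of the two prompt styles for the same system/user turn (illustrative only; the exact special tokens a checkpoint expects depend on how it was fine-tuned):

# ChatML style, as used by the removed MEDITRON_PROMPT
chatml_prompt = (
    "<|im_start|>system\nYou are a medical assistant.<|im_end|>\n"
    "<|im_start|>user\nPatient information: cough, 3 days<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Llama-2 chat style, as used by build_llama2_prompt / build_me_llama_prompt
llama2_prompt = (
    "<s>[INST] <<SYS>>\nYou are a medical assistant.\n<</SYS>>\n\n"
    "Patient information: cough, 3 days [/INST] "
)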
 
@@ -3,11 +3,11 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import spaces
 
+# Model configuration - Using only Me-LLaMA 13B-chat
+ME_LLAMA_MODEL = "clinicalnlplab/me-llama-13b-chat"
 
+# System prompts for different phases
+CONSULTATION_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
 Ask 1-2 follow-up questions at a time to gather more details about:
 - Detailed description of symptoms
 - Duration (when did it start?)
 
@@ -19,30 +19,23 @@ Ask 1-2 follow-up questions at a time to gather more details about:
 After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.
 Respond empathetically and clearly. Always be professional and thorough."""
 
+MEDICINE_PROMPT = """You are a specialized medical assistant. Based on the patient information gathered, provide:
 1. One specific over-the-counter medicine with proper adult dosing instructions
 2. One practical home remedy that might help
 3. Clear guidance on when to seek professional medical care
+
 Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
 
+Patient information: {patient_info}"""
+
+# Global variables
+me_llama_model = None
+me_llama_tokenizer = None
 conversation_turns = 0
 patient_data = []
 
+def build_me_llama_prompt(system_prompt, history, user_input):
+    """Format the conversation for Me-LLaMA chat model."""
     prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
 
     # Add conversation history
@@ -55,126 +48,143 @@ def build_llama2_prompt(system_prompt, history, user_input):
     return prompt
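The middle of build_me_llama_prompt falls outside the hunk context, so the loop that appends the conversation history is not visible in this diff. In the Llama-2 chat convention it would look roughly like the sketch below (an assumption about the unshown lines, with history taken to be Gradio's list of (user, assistant) pairs), which is consistent with the later split on '[/INST]':

    # Sketch of the elided history loop (assumed, not shown in this diff)
    for user_msg, assistant_msg in history:
        prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
    # Append the latest user turn and close the instruction block
    prompt += f"{user_input} [/INST] "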
 
 @spaces.GPU
+def load_model_if_needed():
+    """Load Me-LLaMA model only when GPU is available."""
+    global me_llama_model, me_llama_tokenizer
+
+    if me_llama_model is None:
+        print("Loading Me-LLaMA 13B-chat model...")
+        me_llama_tokenizer = AutoTokenizer.from_pretrained(ME_LLAMA_MODEL)
+        me_llama_model = AutoModelForCausalLM.from_pretrained(
+            ME_LLAMA_MODEL,
             torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True
         )
+        print("Me-LLaMA 13B-chat model loaded successfully!")
 
 @spaces.GPU
+def generate_medicine_suggestions(patient_info):
+    """Use Me-LLaMA to generate medicine and remedy suggestions."""
+    load_model_if_needed()
 
+    # Create a simple prompt for medicine suggestions
+    prompt = f"<s>[INST] {MEDICINE_PROMPT.format(patient_info=patient_info)} [/INST] "
+
+    inputs = me_llama_tokenizer(prompt, return_tensors="pt")
 
     # Move inputs to the same device as the model
     if torch.cuda.is_available():
+        inputs = {k: v.to(me_llama_model.device) for k, v in inputs.items()}
 
     with torch.no_grad():
+        outputs = me_llama_model.generate(
             inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
+            max_new_tokens=300,
             temperature=0.7,
             top_p=0.9,
             do_sample=True,
+            pad_token_id=me_llama_tokenizer.eos_token_id
         )
 
+    suggestion = me_llama_tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
     return suggestion
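generate() returns the prompt tokens followed by the newly sampled ones, so slicing outputs[0] from inputs["input_ids"].shape[1] onward decodes only the model's continuation rather than echoing the prompt. A small self-contained illustration of the same slicing with made-up token ids:

import torch

prompt_ids = torch.tensor([[101, 102, 103]])           # 3 prompt tokens
output_ids = torch.tensor([[101, 102, 103, 7, 8, 9]])  # prompt + 3 generated tokens
new_tokens = output_ids[0][prompt_ids.shape[1]:]       # tensor([7, 8, 9])
# tokenizer.decode(new_tokens, skip_special_tokens=True) then yields only the reply text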
 
 @spaces.GPU
 def generate_response(message, history):
+    """Generate response using only Me-LLaMA for both consultation and medicine suggestions."""
     global conversation_turns, patient_data
 
+    # Load model if needed
+    load_model_if_needed()
 
     # Track conversation turns
     conversation_turns += 1
 
+    # Store patient data
     patient_data.append(message)
 
+    # Phase 1-3: Information gathering
+    if conversation_turns < 4:
+        # Build consultation prompt
+        prompt = build_me_llama_prompt(CONSULTATION_PROMPT, history, message)
+
+        inputs = me_llama_tokenizer(prompt, return_tensors="pt")
+
+        # Move inputs to the same device as the model
+        if torch.cuda.is_available():
+            inputs = {k: v.to(me_llama_model.device) for k, v in inputs.items()}
+
+        # Generate consultation response
+        with torch.no_grad():
+            outputs = me_llama_model.generate(
+                inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                max_new_tokens=400,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True,
+                pad_token_id=me_llama_tokenizer.eos_token_id
+            )
+
+        # Decode response
+        full_response = me_llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
+        response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
+
+        return response
+
+    # Phase 4+: Summary and medicine suggestions
+    else:
+        # First, get summary from consultation
+        summary_prompt = build_me_llama_prompt(
+            CONSULTATION_PROMPT + "\n\nNow summarize what you've learned and suggest when professional care may be needed.",
+            history,
+            message
         )
 
+        inputs = me_llama_tokenizer(summary_prompt, return_tensors="pt")
+
+        if torch.cuda.is_available():
+            inputs = {k: v.to(me_llama_model.device) for k, v in inputs.items()}
+
+        # Generate summary
+        with torch.no_grad():
+            outputs = me_llama_model.generate(
+                inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                max_new_tokens=400,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True,
+                pad_token_id=me_llama_tokenizer.eos_token_id
+            )
 
+        summary_response = me_llama_tokenizer.decode(outputs[0], skip_special_tokens=False)
+        summary = summary_response.split('[/INST]')[-1].split('</s>')[0].strip()
+
+        # Then get medicine suggestions using the same model
+        full_patient_info = "\n".join(patient_data) + f"\n\nMedical Summary: {summary}"
+        medicine_suggestions = generate_medicine_suggestions(full_patient_info)
+
+        # Combine both responses
         final_response = (
+            f"**MEDICAL SUMMARY:**\n{summary}\n\n"
+            f"**MEDICATION AND HOME CARE SUGGESTIONS:**\n{medicine_suggestions}\n\n"
+            f"**DISCLAIMER:** This is AI-generated advice for informational purposes only. Please consult a licensed healthcare provider for proper medical diagnosis and treatment."
         )
+
         return final_response
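Note that conversation_turns and patient_data are module-level globals, so on a shared Space every visitor advances the same counter and contributes to the same patient record. The per-session turn count can instead be derived from the history argument that gr.ChatInterface already passes in; a sketch of that variant (not part of this commit):

def generate_response(message, history):
    # history holds this session's previous (user, assistant) pairs only
    turn = len(history) + 1                     # current turn, per session
    session_patient_data = [u for u, _ in history] + [message]
    if turn < 4:
        ...  # consultation phase
    else:
        ...  # summary + medicine suggestions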
 
 
 
 # Create the Gradio interface
 demo = gr.ChatInterface(
     fn=generate_response,
+    title="🏥 Complete Medical Assistant - Me-LLaMA 13B",
+    description="Comprehensive medical consultation powered by Me-LLaMA 13B-chat. One model handles both consultation and medicine suggestions. Tell me about your symptoms!",
     examples=[
+        "I have a persistent cough and sore throat for 3 days",
+        "I've been having severe headaches and feel dizzy",
+        "My stomach hurts and I feel nauseous after eating"
     ],
     theme="soft"
 )
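The launch call sits outside the changed hunks, so it does not appear in this diff. Assuming the file keeps the usual Gradio entry point, running the app locally would look something like:

if __name__ == "__main__":
    demo.launch()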