techindia2025 committed
Commit 5067011 · verified · 1 Parent(s): d73b8dc

Update app.py

Files changed (1): app.py (+83 -186)
app.py CHANGED
@@ -1,15 +1,16 @@
- import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from typing import Annotated, List, Dict, Any
- from typing_extensions import TypedDict
- from langgraph.graph import StateGraph, START, END
- from langgraph.graph.message import add_messages

  # Model configuration
  LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
  MEDITRON_MODEL = "epfl-llm/meditron-7b"

  SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
  Ask 1-2 follow-up questions at a time to gather more details about:
  - Detailed description of symptoms
@@ -37,213 +38,109 @@ Patient information: {patient_info}
  """

  print("Loading Llama-2 model...")
- tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
- if tokenizer.pad_token is None:
-     tokenizer.pad_token = tokenizer.eos_token
-
- model = AutoModelForCausalLM.from_pretrained(
      LLAMA_MODEL,
      torch_dtype=torch.float16,
      device_map="auto"
  )
  print("Llama-2 model loaded successfully!")

  print("Loading Meditron model...")
  meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
- if meditron_tokenizer.pad_token is None:
-     meditron_tokenizer.pad_token = meditron_tokenizer.eos_token
-
  meditron_model = AutoModelForCausalLM.from_pretrained(
      MEDITRON_MODEL,
      torch_dtype=torch.float16,
      device_map="auto"
  )
  print("Meditron model loaded successfully!")

- # Define the state for our LangGraph
- class ChatbotState(TypedDict):
-     messages: Annotated[List, add_messages]
-     turn_count: int
-     patient_info: List[str]
-
- # Function to build Llama-2 prompt
- def build_llama2_prompt(messages):
-     """Format the conversation history for Llama-2 chat models."""
-     prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
-
-     # Add conversation history
-     for i, msg in enumerate(messages[:-1]):
-         if i % 2 == 0:  # User message
-             prompt += f"{msg.content} [/INST] "
-         else:  # Assistant message
-             prompt += f"{msg.content} </s><s>[INST] "
-
-     # Add the current user input
-     prompt += f"{messages[-1].content} [/INST] "
-
-     return prompt
-
- # Function to get Llama-2 response
- def get_llama2_response(prompt, turn_count):
-     """Generate response from Llama-2 model."""
-     # Add summarization instruction after 4 turns
-     if turn_count >= 4:
-         prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
-
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-     with torch.no_grad():
-         outputs = model.generate(
-             inputs.input_ids,
-             attention_mask=inputs.attention_mask,
-             max_new_tokens=512,
-             temperature=0.7,
-             top_p=0.9,
-             do_sample=True,
-             pad_token_id=tokenizer.pad_token_id
-         )
-
-     full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
-     response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
-
-     return response
-
- # Function to get Meditron suggestions
- def get_meditron_suggestions(patient_info):
-     """Generate medicine and remedy suggestions from Meditron model."""
-     prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
-     inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)
-
-     with torch.no_grad():
-         outputs = meditron_model.generate(
-             inputs.input_ids,
-             attention_mask=inputs.attention_mask,
-             max_new_tokens=256,
-             temperature=0.7,
-             top_p=0.9,
-             do_sample=True,
-             pad_token_id=meditron_tokenizer.pad_token_id
-         )
-
-     suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-     return suggestion
-
- # Define LangGraph nodes
- def process_user_input(state: ChatbotState) -> ChatbotState:
-     """Process user input and update state."""
-     # Extract the latest user message
-     user_message = state["messages"][-1].content
-
-     # Update patient info
-     return {
-         "patient_info": state["patient_info"] + [user_message],
-         "turn_count": state["turn_count"] + 1
-     }

- def generate_llama_response(state: ChatbotState) -> ChatbotState:
-     """Generate response using Llama-2 model."""
-     prompt = build_llama2_prompt(state["messages"])
-     response = get_llama2_response(prompt, state["turn_count"])
-
-     return {"messages": [{"role": "assistant", "content": response}]}

- def check_turn_count(state: ChatbotState) -> str:
-     """Check if we need to add medicine suggestions."""
-     if state["turn_count"] >= 4:
-         return "add_suggestions"
-     return "continue"

- def add_medicine_suggestions(state: ChatbotState) -> ChatbotState:
-     """Add medicine suggestions from Meditron model."""
-     # Get the last assistant response
-     last_response = state["messages"][-1].content
-
-     # Collect full patient conversation
-     full_patient_info = "\n".join(state["patient_info"]) + "\n\nSummary: " + last_response

-     # Get medicine suggestions
-     medicine_suggestions = get_meditron_suggestions(full_patient_info)

-     # Format final response
-     final_response = (
-         f"{last_response}\n\n"
-         f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
-         f"{medicine_suggestions}"
-     )
-
-     # Return updated message
-     return {"messages": [{"role": "assistant", "content": final_response}]}
-
- # Build the LangGraph
- def build_graph():
-     """Build and return the LangGraph for our chatbot."""
-     graph = StateGraph(ChatbotState)
-
-     # Add nodes
-     graph.add_node("process_input", process_user_input)
-     graph.add_node("generate_response", generate_llama_response)
-     graph.add_node("add_suggestions", add_medicine_suggestions)

-     # Add edges
-     graph.add_edge(START, "process_input")
-     graph.add_edge("process_input", "generate_response")
-     graph.add_conditional_edges(
-         "generate_response",
-         check_turn_count,
-         {
-             "add_suggestions": "add_suggestions",
-             "continue": END
-         }
-     )
-     graph.add_edge("add_suggestions", END)

-     return graph.compile()
-
- # Initialize the graph
- chatbot_graph = build_graph()
-
- # Function for Gradio interface
- def chat_response(message, history):
-     """Generate chatbot response using LangGraph."""
-     # Initialize state if this is the first message
-     if not history:
-         state = {
-             "messages": [{"role": "user", "content": message}],
-             "turn_count": 0,
-             "patient_info": []
-         }
-     else:
-         # Convert history to messages format
-         messages = []
-         for user_msg, bot_msg in history:
-             messages.append({"role": "user", "content": user_msg})
-             messages.append({"role": "assistant", "content": bot_msg})

-         # Add current message
-         messages.append({"role": "user", "content": message})

-         # Get turn count from history
-         turn_count = len(history)
-
-         # Build patient info from history
-         patient_info = [user_msg for user_msg, _ in history]
-
-         state = {
-             "messages": messages,
-             "turn_count": turn_count,
-             "patient_info": patient_info
-         }
-
-     # Process through LangGraph
-     result = chatbot_graph.invoke(state)

-     # Return the latest assistant message
-     return result["messages"][-1].content

  # Create the Gradio interface
  demo = gr.ChatInterface(
-     fn=chat_response,
-     title="Medical Assistant with LangGraph",
      description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies.",
      examples=[
          "I have a cough and my throat hurts",
@@ -254,4 +151,4 @@
  )

  if __name__ == "__main__":
-     demo.launch()
 
+ from langchain.chains import ConversationChain, LLMChain
+ from langchain.prompts import PromptTemplate
+ from langchain.llms import HuggingFacePipeline
+ from langchain.memory import ConversationBufferMemory
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
  import torch
+ import gradio as gr

  # Model configuration
  LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
  MEDITRON_MODEL = "epfl-llm/meditron-7b"

+ # System prompts
  SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
  Ask 1-2 follow-up questions at a time to gather more details about:
  - Detailed description of symptoms

  """

  print("Loading Llama-2 model...")
+ # Create LangChain wrapper for Llama-2
+ llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
+ llama_model = AutoModelForCausalLM.from_pretrained(
      LLAMA_MODEL,
      torch_dtype=torch.float16,
      device_map="auto"
  )
+
+ # Create a pipeline for LangChain
+ llama_pipeline = pipeline(
+     "text-generation",
+     model=llama_model,
+     tokenizer=llama_tokenizer,
+     max_new_tokens=512,
+     temperature=0.7,
+     top_p=0.9,
+     do_sample=True
+ )
+ llama_llm = HuggingFacePipeline(pipeline=llama_pipeline)
  print("Llama-2 model loaded successfully!")

  print("Loading Meditron model...")
  meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
  meditron_model = AutoModelForCausalLM.from_pretrained(
      MEDITRON_MODEL,
      torch_dtype=torch.float16,
      device_map="auto"
  )
+ # Create a pipeline for Meditron
+ meditron_pipeline = pipeline(
+     "text-generation",
+     model=meditron_model,
+     tokenizer=meditron_tokenizer,
+     max_new_tokens=256,
+     temperature=0.7,
+     top_p=0.9,
+     do_sample=True
+ )
+ meditron_llm = HuggingFacePipeline(pipeline=meditron_pipeline)
  print("Meditron model loaded successfully!")

+ # Create LangChain conversation with memory
+ memory = ConversationBufferMemory(return_messages=True)
+ conversation = ConversationChain(
+     llm=llama_llm,
+     memory=memory,
+     verbose=True
+ )

+ # Create a template for the Meditron model
+ meditron_template = PromptTemplate(
+     input_variables=["patient_info"],
+     template=MEDITRON_PROMPT
+ )
+ meditron_chain = LLMChain(
+     llm=meditron_llm,
+     prompt=meditron_template,
+     verbose=True
+ )

+ # Track conversation turns
+ conversation_turns = 0
+ patient_data = []

+ def generate_response(message, history):
+     global conversation_turns, patient_data
+     conversation_turns += 1

+     # Store patient message
+     patient_data.append(message)

+     # Format the prompt with system instructions
+     if conversation_turns >= 4:
+         # Add summarization instruction after 4 turns
+         prompt = f"{SYSTEM_PROMPT}\n\nNow summarize what you've learned and suggest when professional care may be needed.\n\n{message}"
+     else:
+         prompt = f"{SYSTEM_PROMPT}\n\n{message}"

+     # Generate response using LangChain conversation
+     llama_response = conversation.predict(input=prompt)

+     # After 4 turns, add medicine suggestions from Meditron
+     if conversation_turns >= 4:
+         # Collect full patient conversation
+         full_patient_info = "\n".join(patient_data) + "\n\nSummary: " + llama_response

+         # Get medicine suggestions using LangChain
+         medicine_suggestions = meditron_chain.run(patient_info=full_patient_info)

+         # Format final response
+         final_response = (
+             f"{llama_response}\n\n"
+             f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
+             f"{medicine_suggestions}"
+         )
+         return final_response

+     return llama_response

  # Create the Gradio interface
  demo = gr.ChatInterface(
+     fn=generate_response,
+     title="Medical Assistant with Medicine Suggestions",
      description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies.",
      examples=[
          "I have a cough and my throat hurts",

  )

  if __name__ == "__main__":
+     demo.launch()
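For anyone adapting this commit, the chain wiring above can be smoke-tested without downloading the two 7B checkpoints. The sketch below is not part of the commit: it mirrors the same legacy langchain imports app.py uses, with sshleifer/tiny-gpt2 as a stand-in model chosen only so the test runs quickly on CPU, and a trimmed placeholder template in place of MEDITRON_PROMPT.

from langchain.chains import ConversationChain, LLMChain
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from transformers import pipeline

# Tiny stand-in model: loads in seconds; output quality is irrelevant here.
tiny_pipeline = pipeline("text-generation", model="sshleifer/tiny-gpt2", max_new_tokens=16)
tiny_llm = HuggingFacePipeline(pipeline=tiny_pipeline)

# Same shape as the Llama-2 conversation chain in app.py.
chat = ConversationChain(llm=tiny_llm, memory=ConversationBufferMemory())
print(chat.predict(input="I have a cough and my throat hurts"))

# Same shape as the Meditron suggestion chain; the template is a
# placeholder for MEDITRON_PROMPT, not the Space's actual prompt.
suggest_prompt = PromptTemplate(
    input_variables=["patient_info"],
    template="Patient information: {patient_info}\nSuggestions:",
)
suggest_chain = LLMChain(llm=tiny_llm, prompt=suggest_prompt)
print(suggest_chain.run(patient_info="cough, sore throat, 3 days"))

If this stand-in run echoes prompts back or fails on the memory format, the full 7B setup will hit the same problem, which is the point of checking the wiring before loading the real weights.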