Thanush committed on
Commit aa89cd7 · 1 Parent(s): 71bcd31

Refactor app.py to streamline conversation state management and update requirements.txt for package versions

Files changed (2):
  1. app.py +35 -130
  2. requirements.txt +6 -18
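The refactor drops the LangGraph StateGraph pipeline in favor of two module-level dicts keyed by a session id, with a turn counter deciding when Meditron suggestions get appended. A minimal sketch of that pattern, with the model calls stubbed out so it runs without the GPU models (fake_llama and fake_meditron are hypothetical stand-ins, not functions from app.py):

    # Sketch of the refactored flow; fake_llama/fake_meditron are
    # hypothetical stand-ins for the real model calls in app.py.
    conversation_turns = {}   # session_id -> number of user turns seen
    patient_data = {}         # session_id -> list of user messages

    def fake_llama(message):
        return f"(assistant reply to: {message})"

    def fake_meditron(patient_info):
        return "(medication and home care suggestions)"

    def generate_response(message, history, session_id="default-session"):
        # Plain dict lookups replace the old StateGraph nodes and edges.
        conversation_turns[session_id] = conversation_turns.get(session_id, 0) + 1
        patient_data.setdefault(session_id, []).append(message)

        reply = fake_llama(message)
        # After 4 turns, append suggestions to the summary.
        if conversation_turns[session_id] >= 4:
            info = "\n".join(patient_data[session_id]) + "\n\nSummary: " + reply
            return f"{reply}\n\n--- SUGGESTIONS ---\n\n{fake_meditron(info)}"
        return reply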
app.py CHANGED
@@ -2,9 +2,6 @@ import gradio as gr
  import spaces
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer
- from langgraph.graph import StateGraph, END
- from typing import TypedDict, List, Tuple
- import json
 
  # Model configuration
  LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
@@ -36,7 +33,6 @@ Patient information: {patient_info}
  <|im_start|>assistant
  """
 
- # Load models
  print("Loading Llama-2 model...")
  tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
  model = AutoModelForCausalLM.from_pretrained(
@@ -55,16 +51,9 @@ meditron_model = AutoModelForCausalLM.from_pretrained(
  )
  print("Meditron model loaded successfully!")
 
- # Define the state for LangGraph
- class ConversationState(TypedDict):
-     messages: List[str]
-     history: List[Tuple[str, str]]
-     current_message: str
-     conversation_turns: int
-     patient_data: List[str]
-     llama_response: str
-     final_response: str
-     should_get_suggestions: bool
+ # Conversation state tracking
+ conversation_turns = {}
+ patient_data = {}
 
  def build_llama2_prompt(system_prompt, history, user_input):
      """Format the conversation history and user input for Llama-2 chat models."""
@@ -97,29 +86,25 @@ def get_meditron_suggestions(patient_info):
      suggestion = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
      return suggestion
 
- # LangGraph Node Functions
- def initialize_conversation(state: ConversationState) -> ConversationState:
-     """Initialize or update conversation state."""
-     # Update conversation turns
-     state["conversation_turns"] = state.get("conversation_turns", 0) + 1
-
-     # Add current message to patient data
-     if "patient_data" not in state:
-         state["patient_data"] = []
-     state["patient_data"].append(state["current_message"])
-
-     # Determine if we should get suggestions (after 4 turns)
-     state["should_get_suggestions"] = state["conversation_turns"] >= 4
-
-     return state
-
- def generate_llama_response(state: ConversationState) -> ConversationState:
-     """Generate response using Llama-2 model."""
+ @spaces.GPU
+ def generate_response(message, history):
+     """Generate a response using both models."""
+     # Track conversation turns
+     session_id = "default-session"
+     if session_id not in conversation_turns:
+         conversation_turns[session_id] = 0
+     conversation_turns[session_id] += 1
+
+     # Store the entire conversation for reference
+     if session_id not in patient_data:
+         patient_data[session_id] = []
+     patient_data[session_id].append(message)
+
      # Build the prompt with proper Llama-2 formatting
-     prompt = build_llama2_prompt(SYSTEM_PROMPT, state["history"], state["current_message"])
+     prompt = build_llama2_prompt(SYSTEM_PROMPT, history, message)
 
      # Add summarization instruction after 4 turns
-     if state["conversation_turns"] >= 4:
+     if conversation_turns[session_id] >= 4:
          prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
 
      inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -140,109 +125,29 @@ def generate_llama_response(state: ConversationState) -> ConversationState:
      full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
      llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
 
-     state["llama_response"] = llama_response
-     return state
-
- def generate_medicine_suggestions(state: ConversationState) -> ConversationState:
-     """Generate medicine suggestions using Meditron model."""
-     # Collect full patient conversation
-     full_patient_info = "\n".join(state["patient_data"]) + "\n\nSummary: " + state["llama_response"]
-
-     # Get medicine suggestions
-     medicine_suggestions = get_meditron_suggestions(full_patient_info)
-
-     # Format final response
-     final_response = (
-         f"{state['llama_response']}\n\n"
-         f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
-         f"{medicine_suggestions}"
-     )
-
-     state["final_response"] = final_response
-     return state
-
- def finalize_response(state: ConversationState) -> ConversationState:
-     """Finalize the response without medicine suggestions."""
-     state["final_response"] = state["llama_response"]
-     return state
-
- def should_get_suggestions(state: ConversationState) -> str:
-     """Conditional edge to determine next step."""
-     if state["should_get_suggestions"]:
-         return "get_suggestions"
-     else:
-         return "finalize"
-
- # Create the LangGraph workflow
- def create_medical_workflow():
-     """Create the LangGraph workflow for medical assistant."""
-     workflow = StateGraph(ConversationState)
-
-     # Add nodes
-     workflow.add_node("initialize", initialize_conversation)
-     workflow.add_node("generate_llama", generate_llama_response)
-     workflow.add_node("get_suggestions", generate_medicine_suggestions)
-     workflow.add_node("finalize", finalize_response)
-
-     # Define the flow
-     workflow.set_entry_point("initialize")
-     workflow.add_edge("initialize", "generate_llama")
-     workflow.add_conditional_edges(
-         "generate_llama",
-         should_get_suggestions,
-         {
-             "get_suggestions": "get_suggestions",
-             "finalize": "finalize"
-         }
-     )
-     workflow.add_edge("get_suggestions", END)
-     workflow.add_edge("finalize", END)
-
-     return workflow.compile()
-
- # Initialize the workflow
- medical_workflow = create_medical_workflow()
-
- # Conversation state tracking (for Gradio session management)
- conversation_states = {}
-
- @spaces.GPU
- def generate_response(message, history):
-     """Generate a response using the LangGraph workflow."""
-     session_id = "default-session"
-
-     # Initialize or get existing conversation state
-     if session_id not in conversation_states:
-         conversation_states[session_id] = {
-             "messages": [],
-             "history": [],
-             "conversation_turns": 0,
-             "patient_data": []
-         }
-
-     # Update state with current message and history
-     state = conversation_states[session_id].copy()
-     state["current_message"] = message
-     state["history"] = history
-
-     # Run the workflow
-     result = medical_workflow.invoke(state)
-
-     # Update the stored conversation state
-     conversation_states[session_id] = {
-         "messages": result["messages"] if "messages" in result else [],
-         "history": history,
-         "conversation_turns": result["conversation_turns"],
-         "patient_data": result["patient_data"]
-     }
-
-     return result["final_response"]
+     # After 4 turns, add medicine suggestions from Meditron
+     if conversation_turns[session_id] >= 4:
+         # Collect full patient conversation
+         full_patient_info = "\n".join(patient_data[session_id]) + "\n\nSummary: " + llama_response
+
+         # Get medicine suggestions
+         medicine_suggestions = get_meditron_suggestions(full_patient_info)
+
+         # Format final response
+         final_response = (
+             f"{llama_response}\n\n"
+             f"--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
+             f"{medicine_suggestions}"
+         )
+         return final_response
+
+     return llama_response
 
  # Create the Gradio interface
  demo = gr.ChatInterface(
      fn=generate_response,
-     title="Medical Assistant with LangGraph & Medicine Suggestions",
-     description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies using an AI workflow.",
+     title="Medical Assistant with Medicine Suggestions",
+     description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies.",
      examples=[
          "I have a cough and my throat hurts",
          "I've been having headaches for a week",
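The hunk is cut off inside the examples list at the context boundary, so the end of app.py is not shown. A Gradio Space like this one typically closes the ChatInterface call and launches it at the bottom of the module; a sketch of that assumed (not shown in this diff) tail:

    # Hypothetical tail of app.py, after the gr.ChatInterface(...) call closes.
    if __name__ == "__main__":
        demo.launch()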
requirements.txt CHANGED
@@ -1,18 +1,6 @@
- # Core packages
- gradio==4.24.0
- spaces==0.21.1
-
- # Transformers & tokenization
- transformers==4.40.1
- torch>=2.1.0
-
- # LangGraph
- langgraph==0.0.41
-
- # Optional but often required for transformers
- accelerate==0.30.1
- sentencepiece==0.1.99
- protobuf==4.25.3
-
- # Utility
- typing-extensions>=4.5.0
+ gradio>=4.0
+ torch>=2.1
+ transformers>=4.37
+ spaces
+ sentencepiece
+ accelerate>=0.21.0
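The new requirements trade exact pins for minimum versions, so the Space resolves whatever recent releases satisfy the ranges. A quick standard-library check of what actually got installed (a sketch; run inside the Space's environment):

    # Print the resolved version of each dependency listed in requirements.txt.
    import importlib.metadata as md

    for pkg in ("gradio", "torch", "transformers", "spaces", "sentencepiece", "accelerate"):
        print(pkg, md.version(pkg))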