Thanush committed
Commit c4447f4 · 1 Parent(s): 6196bed

Refactor app.py to integrate LangChain memory for conversation tracking and update requirements.txt for LangChain dependency

Files changed (2):
  1. app.py +33 -28
  2. requirements.txt +3 -0
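For context before the diff: a minimal, self-contained sketch of the LangChain calls the commit relies on (ConversationBufferMemory, save_context, chat_memory.messages), assuming the classic langchain>=0.1 memory module added to requirements.txt below. The example inputs are illustrative and not taken from app.py.

from langchain.memory import ConversationBufferMemory

# return_messages=True keeps history as typed message objects
# (HumanMessage / AIMessage) instead of one flattened string.
memory = ConversationBufferMemory(return_messages=True)

# Each save_context() call appends one human and one AI message.
memory.save_context({"input": "I have a headache."},
                    {"output": "How long has it lasted?"})
memory.save_context({"input": "About two days."},
                    {"output": "Any nausea or sensitivity to light?"})

# Messages come back in order, typed "human" / "ai".
for msg in memory.chat_memory.messages:
    print(msg.type, "->", msg.content)

# The same history is what chains receive via load_memory_variables().
print(memory.load_memory_variables({})["history"])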
app.py CHANGED
@@ -2,12 +2,13 @@ import gradio as gr
 import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from langchain.memory import ConversationBufferMemory
 
 # Model configuration
 LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
 MEDITRON_MODEL = "epfl-llm/meditron-7b"
 
-SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
+SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's Name,age,health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
 Ask 1-2 follow-up questions at a time to gather more details about:
 - Detailed description of symptoms
 - Duration (when did it start?)
@@ -51,9 +52,8 @@ meditron_model = AutoModelForCausalLM.from_pretrained(
 )
 print("Meditron model loaded successfully!")
 
-# Conversation state tracking
-conversation_turns = {}
-patient_data = {}
+# Initialize LangChain memory
+memory = ConversationBufferMemory(return_messages=True)
 
 def build_llama2_prompt(system_prompt, history, user_input):
     """Format the conversation history and user input for Llama-2 chat models."""
@@ -89,26 +89,31 @@ def get_meditron_suggestions(patient_info):
 @spaces.GPU
 def generate_response(message, history):
     """Generate a response using both models."""
-    # Track conversation turns
-    session_id = "default-session"
-    if session_id not in conversation_turns:
-        conversation_turns[session_id] = 0
-    conversation_turns[session_id] += 1
-
-    # Store the entire conversation for reference
-    if session_id not in patient_data:
-        patient_data[session_id] = []
-    patient_data[session_id].append(message)
-
-    # Build the prompt with proper Llama-2 formatting
-    prompt = build_llama2_prompt(SYSTEM_PROMPT, history, message)
-
+    # Save the latest user message and last assistant response to memory
+    if history and len(history[-1]) == 2:
+        memory.save_context({"input": history[-1][0]}, {"output": history[-1][1]})
+    memory.save_context({"input": message}, {"output": ""})
+
+    # Build conversation history from memory
+    lc_history = []
+    user_msg = None
+    for msg in memory.chat_memory.messages:
+        if msg.type == "human":
+            user_msg = msg.content
+        elif msg.type == "ai" and user_msg is not None:
+            assistant_msg = msg.content
+            lc_history.append((user_msg, assistant_msg))
+            user_msg = None
+
+    # Build the prompt with LangChain memory history
+    prompt = build_llama2_prompt(SYSTEM_PROMPT, lc_history, message)
+
     # Add summarization instruction after 4 turns
-    if conversation_turns[session_id] >= 4:
+    if len(lc_history) >= 4:
         prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")
-
+
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
+
     # Generate the Llama-2 response
     with torch.no_grad():
         outputs = model.generate(
@@ -120,19 +125,19 @@ def generate_response(message, history):
         do_sample=True,
         pad_token_id=tokenizer.eos_token_id
     )
-
+
     # Decode and extract Llama-2's response
     full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
     llama_response = full_response.split('[/INST]')[-1].split('</s>')[0].strip()
-
+
     # After 4 turns, add medicine suggestions from Meditron
-    if conversation_turns[session_id] >= 4:
+    if len(lc_history) >= 4:
         # Collect full patient conversation
-        full_patient_info = "\n".join(patient_data[session_id]) + "\n\nSummary: " + llama_response
-
+        full_patient_info = "\n".join([h[0] for h in lc_history] + [message]) + "\n\nSummary: " + llama_response
+
         # Get medicine suggestions
         medicine_suggestions = get_meditron_suggestions(full_patient_info)
-
+
         # Format final response
         final_response = (
             f"{llama_response}\n\n"
@@ -140,7 +145,7 @@ def generate_response(message, history):
             f"{medicine_suggestions}"
         )
         return final_response
-
+
     return llama_response
 
 # Create the Gradio interface
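To illustrate what the new generate_response code does with memory, here is a standalone sketch of the pairing loop followed by a typical Llama-2 chat prompt assembly. The build_prompt helper and the sample inputs are assumptions for illustration only; build_llama2_prompt's actual body is not part of this diff and may format the prompt differently.

from langchain.memory import ConversationBufferMemory

# Illustrative only: mimics a common Llama-2 chat template with a
# <<SYS>> system block and [INST]/[/INST] turn markers.
def build_prompt(system_prompt, history, user_input):
    prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
    for user_msg, assistant_msg in history:
        prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
    prompt += f"{user_input} [/INST] "
    return prompt

memory = ConversationBufferMemory(return_messages=True)
memory.save_context({"input": "I have a sore throat."},
                    {"output": "How long have you had it?"})

# Same pairing loop as the diff: walk memory and rebuild (user, assistant) tuples.
lc_history = []
user_msg = None
for msg in memory.chat_memory.messages:
    if msg.type == "human":
        user_msg = msg.content
    elif msg.type == "ai" and user_msg is not None:
        lc_history.append((user_msg, msg.content))
        user_msg = None

print(build_prompt("You are a virtual doctor.", lc_history, "About three days."))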
requirements.txt CHANGED
@@ -21,3 +21,6 @@ aiofiles>=23.1.0
 
 # For better tensor operations
 numpy>=1.24.0
+
+# For LangChain memory
+langchain>=0.1.0