helloparthshah committed
Commit 9d9d55a · 1 Parent(s): b3d741d

Adding memory management using SentenceTransformer

Files changed (1): src/manager/manager.py (+42 -1)
src/manager/manager.py CHANGED

@@ -8,6 +8,9 @@ from src.manager.tool_manager import ToolManager
 from src.manager.utils.suppress_outputs import suppress_output
 import logging
 import gradio as gr
+from sentence_transformers import SentenceTransformer
+import torch
+from src.tools.default_tools.memory_manager import MemoryManager
 
 logger = logging.getLogger(__name__)
 handler = logging.StreamHandler(sys.stdout)
@@ -35,6 +38,7 @@ class GeminiManager:
         self.client = genai.Client(api_key=self.API_KEY)
         self.toolsLoader.load_tools()
         self.model_name = gemini_model
+        self.memory_manager = MemoryManager()
         with open(system_prompt_file, 'r', encoding="utf8") as f:
             self.system_prompt = f.read()
         self.messages = []
@@ -131,6 +135,9 @@ class GeminiManager:
             match message.get("role"):
                 case "user":
                     role = "user"
+                case "memories":
+                    role = "user"
+                    parts = [types.Part.from_text(text="User memories: "+message.get("content", ""))]
                 case "tool":
                     role = "tool"
                     formatted_history.append(
@@ -149,7 +156,41 @@ class GeminiManager:
             ))
         return formatted_history
 
+    def get_k_memories(self, query, k=5, threshold=0.0):
+        memories = MemoryManager().get_memories()
+        if len(memories) == 0:
+            return []
+        top_k = min(k, len(memories))
+        # Semantic Retrieval with GPU
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
+        doc_embeddings = model.encode(memories, convert_to_tensor=True, device=device)
+        query_embedding = model.encode(query, convert_to_tensor=True, device=device)
+        similarity_scores = model.similarity(query_embedding, doc_embeddings)[0]
+        scores, indices = torch.topk(similarity_scores, k=top_k)
+        results = []
+        for score, idx in zip(scores, indices):
+            print(memories[idx], f"(Score: {score:.4f})")
+            if score >= threshold:
+                results.append(memories[idx])
+        return results
+
     def run(self, messages):
+        memories = self.get_k_memories(messages[-1]['content'], k=5, threshold=0.0)
+        if len(memories) > 0:
+            messages.append({
+                "role": "memories",
+                "content": f"{memories}",
+            })
+            messages.append({
+                "role": "assistant",
+                "content": f"Memories: {memories}",
+                "metadata": {"title": "Memories"}
+            })
+            yield messages
+        yield from self.invoke_manager(messages)
+
+    def invoke_manager(self, messages):
         chat_history = self.format_chat_history(messages)
         logger.debug(f"Chat history: {chat_history}")
         try:
@@ -195,6 +236,6 @@ class GeminiManager:
             if (call.get("role") == "tool"
                     or (call.get("role") == "assistant" and call.get("metadata", {}).get("status") == "done")):
                 messages.append(call)
-                yield from self.run(messages)
+                yield from self.invoke_manager(messages)
                 return
             yield messages
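
For context, the new get_k_memories helper follows the standard sentence-transformers retrieval pattern: embed the stored memories and the query, score each pair with cosine similarity, and keep the top-k results above a threshold. Below is a minimal standalone sketch of that pattern using the same all-MiniLM-L6-v2 model the commit uses; the memory strings and the query are made-up examples, not the app's real data.

    import torch
    from sentence_transformers import SentenceTransformer

    # Made-up example data; the real app pulls these from MemoryManager.get_memories().
    memories = [
        "User prefers concise answers.",
        "User is building a Gradio app.",
        "User's main language is Python.",
    ]
    query = "How should I format my replies?"

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    doc_embeddings = model.encode(memories, convert_to_tensor=True, device=device)
    query_embedding = model.encode(query, convert_to_tensor=True, device=device)

    # similarity() returns a (1, N) matrix of cosine scores for a single query.
    scores = model.similarity(query_embedding, doc_embeddings)[0]
    top = torch.topk(scores, k=min(2, len(memories)))
    for score, idx in zip(top.values, top.indices):
        print(f"{memories[idx]} (score: {score:.4f})")

One design note: get_k_memories constructs a fresh SentenceTransformer on every call, which re-loads the model weights per query; caching the model once (for example on the GeminiManager instance, next to self.memory_manager) is a common way to avoid that cost.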