Commit 9d9d55a
Parent(s): b3d741d

Adding memory management using sentenceTransformer

src/manager/manager.py CHANGED (+42 -1)
@@ -8,6 +8,9 @@ from src.manager.tool_manager import ToolManager
 from src.manager.utils.suppress_outputs import suppress_output
 import logging
 import gradio as gr
+from sentence_transformers import SentenceTransformer
+import torch
+from src.tools.default_tools.memory_manager import MemoryManager
 
 logger = logging.getLogger(__name__)
 handler = logging.StreamHandler(sys.stdout)
@@ -35,6 +38,7 @@ class GeminiManager:
         self.client = genai.Client(api_key=self.API_KEY)
         self.toolsLoader.load_tools()
         self.model_name = gemini_model
+        self.memory_manager = MemoryManager()
         with open(system_prompt_file, 'r', encoding="utf8") as f:
             self.system_prompt = f.read()
         self.messages = []
@@ -131,6 +135,9 @@ class GeminiManager:
             match message.get("role"):
                 case "user":
                     role = "user"
+                case "memories":
+                    role = "user"
+                    parts = [types.Part.from_text(text="User memories: "+message.get("content", ""))]
                 case "tool":
                     role = "tool"
                     formatted_history.append(
@@ -149,7 +156,41 @@ class GeminiManager:
                    ))
         return formatted_history
 
+    def get_k_memories(self, query, k=5, threshold=0.0):
+        memories = MemoryManager().get_memories()
+        if len(memories) == 0:
+            return []
+        top_k = min(k, len(memories))
+        # Semantic Retrieval with GPU
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
+        doc_embeddings = model.encode(memories, convert_to_tensor=True, device=device)
+        query_embedding = model.encode(query, convert_to_tensor=True, device=device)
+        similarity_scores = model.similarity(query_embedding, doc_embeddings)[0]
+        scores, indices = torch.topk(similarity_scores, k=top_k)
+        results = []
+        for score, idx in zip(scores, indices):
+            print(memories[idx], f"(Score: {score:.4f})")
+            if score >= threshold:
+                results.append(memories[idx])
+        return results
+
     def run(self, messages):
+        memories = self.get_k_memories(messages[-1]['content'], k=5, threshold=0.0)
+        if len(memories) > 0:
+            messages.append({
+                "role": "memories",
+                "content": f"{memories}",
+            })
+            messages.append({
+                "role": "assistant",
+                "content": f"Memories: {memories}",
+                "metadata": {"title": "Memories"}
+            })
+            yield messages
+        yield from self.invoke_manager(messages)
+
+    def invoke_manager(self, messages):
         chat_history = self.format_chat_history(messages)
         logger.debug(f"Chat history: {chat_history}")
         try:
@@ -195,6 +236,6 @@ class GeminiManager:
             if (call.get("role") == "tool"
                     or (call.get("role") == "assistant" and call.get("metadata", {}).get("status") == "done")):
                 messages.append(call)
-            yield from self.run(messages)
+            yield from self.invoke_manager(messages)
             return
         yield messages
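
For reference, the retrieval step that get_k_memories implements is plain embedding search: encode the stored memories and the query with the same model, score them by cosine similarity, and keep the top-k hits above a threshold. A minimal self-contained sketch of the same technique (assuming sentence-transformers >= 3.0, where model.similarity defaults to cosine similarity; the helper name and sample data are illustrative, not part of the commit):

import torch
from sentence_transformers import SentenceTransformer

def top_k_memories(query, memories, k=5, threshold=0.0):
    # Guard: nothing stored yet.
    if not memories:
        return []
    # Same model and device selection as the commit.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    # Embed memories and query into the same vector space.
    doc_embeddings = model.encode(memories, convert_to_tensor=True, device=device)
    query_embedding = model.encode(query, convert_to_tensor=True, device=device)
    # Row 0: similarity of the single query against every memory.
    scores = model.similarity(query_embedding, doc_embeddings)[0]
    values, indices = torch.topk(scores, k=min(k, len(memories)))
    return [memories[i] for v, i in zip(values, indices) if v >= threshold]

print(top_k_memories("What editor do I use?", ["User prefers vim", "User is in UTC+2"]))

Note that, as committed, get_k_memories builds a fresh SentenceTransformer and re-encodes every memory on each call; caching the model (for instance on the GeminiManager instance) would avoid reloading the weights per query.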
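After this change, run is a thin wrapper: it injects retrieved memories (both as a "memories" message, which format_chat_history renders as a user part, and as a titled assistant message for the UI) and then delegates to invoke_manager, which recurses after tool calls. Both are generators yielding the growing message list, so a caller streams intermediate states. A hypothetical consumption sketch (the chat helper and its parameters are illustrative):

def chat(user_input, history, manager):
    # manager is an already-constructed GeminiManager instance.
    # Each yielded value is the full message list, so a UI (e.g. the
    # Gradio app this project uses) can re-render the conversation as
    # memories, tool calls, and replies arrive.
    history.append({"role": "user", "content": user_input})
    yield from manager.run(history)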