Update agent.py
Browse files
agent.py
CHANGED
@@ -27,23 +27,23 @@ proj_llm = OpenRouter(
|
|
27 |
api_key=os.getenv("OPENROUTER_API_KEY"),
|
28 |
)
|
29 |
|
|
|
|
|
30 |
wandb.init(project="gaia-llamaindex-agents") # Choisis ton nom de projet
|
31 |
wandb_callback = WandbCallbackHandler(run_args={"project": "gaia-llamaindex-agents"})
|
32 |
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
|
33 |
callback_manager = CallbackManager([wandb_callback, llama_debug])
|
34 |
|
35 |
-
from llama_index.core import Settings
|
36 |
|
37 |
Settings.llm = proj_llm
|
38 |
Settings.embed_model = embed_model
|
39 |
Settings.callback_manager = callback_manager
|
40 |
|
41 |
|
42 |
-
|
43 |
class EnhancedRAGQueryEngine:
|
44 |
def __init__(self, task_context: str = ""):
|
45 |
self.task_context = task_context
|
46 |
-
self.embed_model =
|
47 |
self.reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=5)
|
48 |
|
49 |
self.readers = {
|
|
|
27 |
api_key=os.getenv("OPENROUTER_API_KEY"),
|
28 |
)
|
29 |
|
30 |
+
embed_model = HuggingFaceEmbedding("BAAI/bge-small-en-v1.5")
|
31 |
+
|
32 |
wandb.init(project="gaia-llamaindex-agents") # Choisis ton nom de projet
|
33 |
wandb_callback = WandbCallbackHandler(run_args={"project": "gaia-llamaindex-agents"})
|
34 |
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
|
35 |
callback_manager = CallbackManager([wandb_callback, llama_debug])
|
36 |
|
|
|
37 |
|
38 |
Settings.llm = proj_llm
|
39 |
Settings.embed_model = embed_model
|
40 |
Settings.callback_manager = callback_manager
|
41 |
|
42 |
|
|
|
43 |
class EnhancedRAGQueryEngine:
|
44 |
def __init__(self, task_context: str = ""):
|
45 |
self.task_context = task_context
|
46 |
+
self.embed_model = embed_model
|
47 |
self.reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=5)
|
48 |
|
49 |
self.readers = {
|