Update agent.py
Browse files
agent.py
CHANGED
@@ -17,10 +17,28 @@ import duckduckgo_search as ddg
|
|
17 |
import re
|
18 |
from llama_index.core.agent.workflow import ReActAgent
|
19 |
from llama_index.llms.openrouter import OpenRouter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
text_llm = OpenRouter(
|
22 |
-
model="mistralai/mistral-small-3.1-24b-instruct:free",
|
23 |
-
api_key=os.getenv("OPENROUTER_API_KEY"),
|
24 |
)
|
25 |
multimodal_llm = text_llm
|
26 |
|
@@ -97,7 +115,8 @@ class EnhancedRAGQueryEngine:
|
|
97 |
|
98 |
index = VectorStoreIndex(
|
99 |
nodes,
|
100 |
-
embed_model=self.embed_model
|
|
|
101 |
)
|
102 |
|
103 |
return index
|
@@ -112,7 +131,8 @@ class EnhancedRAGQueryEngine:
|
|
112 |
query_engine = RetrieverQueryEngine(
|
113 |
retriever=retriever,
|
114 |
node_postprocessors=[self.reranker],
|
115 |
-
llm=multimodal_llm
|
|
|
116 |
)
|
117 |
|
118 |
return query_engine
|
@@ -217,7 +237,7 @@ analysis_agent = FunctionAgent(
|
|
217 |
llm=multimodal_llm,
|
218 |
tools=[enhanced_rag_tool, cross_document_tool],
|
219 |
max_steps=5,
|
220 |
-
|
221 |
)
|
222 |
|
223 |
|
@@ -362,7 +382,8 @@ code_agent = ReActAgent(
|
|
362 |
""",
|
363 |
llm=text_llm,
|
364 |
tools=[code_execution_tool],
|
365 |
-
max_steps = 5
|
|
|
366 |
)
|
367 |
|
368 |
# Créer des outils à partir des agents
|
@@ -421,7 +442,8 @@ class EnhancedGAIAAgent:
|
|
421 |
""",
|
422 |
llm=text_llm,
|
423 |
tools=[analysis_tool, research_tool, code_tool],
|
424 |
-
max_steps = 10
|
|
|
425 |
)
|
426 |
|
427 |
async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
|
|
|
17 |
import re
|
18 |
from llama_index.core.agent.workflow import ReActAgent
|
19 |
from llama_index.llms.openrouter import OpenRouter
|
20 |
+
import wandb
|
21 |
+
from llama_index.callbacks.wandb import WandbCallbackHandler
|
22 |
+
from llama_index.callbacks.base import CallbackManager
|
23 |
+
from llama_index.callbacks.llama_debug import LlamaDebugHandler
|
24 |
+
from llama_index import ServiceContext
|
25 |
+
|
26 |
+
wandb.init(project="gaia-llamaindex-agents") # Choisis ton nom de projet
|
27 |
+
wandb_callback = WandbCallbackHandler(run_args={"project": "gaia-llamaindex-agents"})
|
28 |
+
llama_debug = LlamaDebugHandler(print_trace_on_end=True)
|
29 |
+
callback_manager = CallbackManager([wandb_callback, llama_debug])
|
30 |
+
|
31 |
+
service_context = ServiceContext.from_defaults(
|
32 |
+
llm=text_llm,
|
33 |
+
embed_model=HuggingFaceEmbedding("BAAI/bge-small-en-v1.5"),
|
34 |
+
callback_manager=callback_manager
|
35 |
+
)
|
36 |
+
# Puis passe service_context=service_context à tes agents ou query engines
|
37 |
+
|
38 |
|
39 |
text_llm = OpenRouter(
|
40 |
+
model="mistralai/mistral-small-3.1-24b-instruct:free",
|
41 |
+
api_key=os.getenv("OPENROUTER_API_KEY"),
|
42 |
)
|
43 |
multimodal_llm = text_llm
|
44 |
|
|
|
115 |
|
116 |
index = VectorStoreIndex(
|
117 |
nodes,
|
118 |
+
embed_model=self.embed_model,
|
119 |
+
service_context=service_context
|
120 |
)
|
121 |
|
122 |
return index
|
|
|
131 |
query_engine = RetrieverQueryEngine(
|
132 |
retriever=retriever,
|
133 |
node_postprocessors=[self.reranker],
|
134 |
+
llm=multimodal_llm,
|
135 |
+
service_context=service_context
|
136 |
)
|
137 |
|
138 |
return query_engine
|
|
|
237 |
llm=multimodal_llm,
|
238 |
tools=[enhanced_rag_tool, cross_document_tool],
|
239 |
max_steps=5,
|
240 |
+
service_context=service_context
|
241 |
)
|
242 |
|
243 |
|
|
|
382 |
""",
|
383 |
llm=text_llm,
|
384 |
tools=[code_execution_tool],
|
385 |
+
max_steps = 5,
|
386 |
+
service_context=service_context
|
387 |
)
|
388 |
|
389 |
# Créer des outils à partir des agents
|
|
|
442 |
""",
|
443 |
llm=text_llm,
|
444 |
tools=[analysis_tool, research_tool, code_tool],
|
445 |
+
max_steps = 10,
|
446 |
+
service_context=service_context
|
447 |
)
|
448 |
|
449 |
async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
|