Coool2 commited on
Commit
e6723aa
·
verified ·
1 Parent(s): fa9bd8d

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +3 -3
agent.py CHANGED
@@ -27,23 +27,23 @@ proj_llm = OpenRouter(
27
  api_key=os.getenv("OPENROUTER_API_KEY"),
28
  )
29
 
 
 
30
  wandb.init(project="gaia-llamaindex-agents") # Choisis ton nom de projet
31
  wandb_callback = WandbCallbackHandler(run_args={"project": "gaia-llamaindex-agents"})
32
  llama_debug = LlamaDebugHandler(print_trace_on_end=True)
33
  callback_manager = CallbackManager([wandb_callback, llama_debug])
34
 
35
- from llama_index.core import Settings
36
 
37
  Settings.llm = proj_llm
38
  Settings.embed_model = embed_model
39
  Settings.callback_manager = callback_manager
40
 
41
 
42
-
43
  class EnhancedRAGQueryEngine:
44
  def __init__(self, task_context: str = ""):
45
  self.task_context = task_context
46
- self.embed_model = HuggingFaceEmbedding("BAAI/bge-small-en-v1.5")
47
  self.reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=5)
48
 
49
  self.readers = {
 
27
  api_key=os.getenv("OPENROUTER_API_KEY"),
28
  )
29
 
30
+ embed_model = HuggingFaceEmbedding("BAAI/bge-small-en-v1.5")
31
+
32
  wandb.init(project="gaia-llamaindex-agents") # Choisis ton nom de projet
33
  wandb_callback = WandbCallbackHandler(run_args={"project": "gaia-llamaindex-agents"})
34
  llama_debug = LlamaDebugHandler(print_trace_on_end=True)
35
  callback_manager = CallbackManager([wandb_callback, llama_debug])
36
 
 
37
 
38
  Settings.llm = proj_llm
39
  Settings.embed_model = embed_model
40
  Settings.callback_manager = callback_manager
41
 
42
 
 
43
  class EnhancedRAGQueryEngine:
44
  def __init__(self, task_context: str = ""):
45
  self.task_context = task_context
46
+ self.embed_model = embed_model
47
  self.reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=5)
48
 
49
  self.readers = {