Coool2 committed
Commit 6395478 · verified · Parent: a5e7aff

Rename tools.py to agent.py

Files changed (2)
  1. agent.py +492 -0
  2. tools.py +0 -0
agent.py ADDED
@@ -0,0 +1,492 @@
from llama_index.core.agent.workflow import FunctionAgent
from llama_index.core.tools import FunctionTool
from llama_index.core import VectorStoreIndex, Document
from llama_index.core.node_parser import SentenceWindowNodeParser, HierarchicalNodeParser
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.readers.file import PDFReader, DocxReader, CSVReader, ImageReader
import os
from typing import List, Dict, Any

# LLM definitions
multimodal_llm = HuggingFaceInferenceAPI(
    model_name="microsoft/Phi-3.5-vision-instruct",
    token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)

# Text LLM for routing, research, and code reasoning
text_llm = HuggingFaceInferenceAPI(
    model_name="Qwen/Qwen2.5-72B-Instruct",
    token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)

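# Optional smoke test (a sketch; assumes HUGGINGFACEHUB_API_TOKEN is set and
# both models are reachable through the HF Inference API):
#
#     print(text_llm.complete("Reply with the single word: ready").text)
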
class EnhancedRAGQueryEngine:
    def __init__(self, task_context: str = ""):
        self.task_context = task_context
        self.embed_model = HuggingFaceEmbedding("BAAI/bge-small-en-v1.5")
        self.reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=5)

        # File-type specific readers; .txt falls through to a plain-text loader
        self.readers = {
            '.pdf': PDFReader(),
            '.docx': DocxReader(),
            '.doc': DocxReader(),
            '.csv': CSVReader(),
            '.txt': lambda file_path: [Document(text=open(file_path, 'r', encoding='utf-8').read())],
            '.jpg': ImageReader(),
            '.jpeg': ImageReader(),
            '.png': ImageReader()
        }

        self.sentence_window_parser = SentenceWindowNodeParser.from_defaults(
            window_size=3,
            window_metadata_key="window",
            original_text_metadata_key="original_text"
        )

        self.hierarchical_parser = HierarchicalNodeParser.from_defaults(
            chunk_sizes=[2048, 512, 128]
        )

    def load_and_process_documents(self, file_paths: List[str]) -> List[Document]:
        documents = []

        for file_path in file_paths:
            file_ext = os.path.splitext(file_path)[1].lower()

            try:
                if file_ext in self.readers:
                    reader = self.readers[file_ext]
                    if callable(reader):
                        docs = reader(file_path)
                    else:
                        docs = reader.load_data(file=file_path)

                    # Add metadata to all documents
                    for doc in docs:
                        doc.metadata.update({
                            "file_path": file_path,
                            "file_type": file_ext[1:],
                            "task_context": self.task_context
                        })
                    documents.extend(docs)

            except Exception as e:
                # Fallback to plain-text reading
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        content = f.read()
                    documents.append(Document(
                        text=content,
                        metadata={"file_path": file_path, "file_type": "text", "error": str(e)}
                    ))
                except Exception:
                    print(f"Failed to process {file_path}: {e}")

        return documents

    def create_advanced_index(self, documents: List[Document], use_hierarchical: bool = False) -> VectorStoreIndex:
        # Large collections get hierarchical chunking; small ones use sentence windows
        if use_hierarchical or len(documents) > 10:
            nodes = self.hierarchical_parser.get_nodes_from_documents(documents)
        else:
            nodes = self.sentence_window_parser.get_nodes_from_documents(documents)

        index = VectorStoreIndex(
            nodes,
            embed_model=self.embed_model
        )

        return index

    def create_context_aware_query_engine(self, index: VectorStoreIndex):
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=10,
            embed_model=self.embed_model
        )

        # from_args wires the LLM into a response synthesizer; the bare
        # constructor does not accept an llm keyword
        query_engine = RetrieverQueryEngine.from_args(
            retriever=retriever,
            node_postprocessors=[self.reranker],
            llm=multimodal_llm
        )

        return query_engine

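# Typical flow for the engine above (a sketch; file names are illustrative):
#
#     engine = EnhancedRAGQueryEngine(task_context="GAIA document QA")
#     docs = engine.load_and_process_documents(["report.pdf", "data.csv"])
#     index = engine.create_advanced_index(docs)
#     query_engine = engine.create_context_aware_query_engine(index)
#     print(query_engine.query("Summarize the key figures").response)
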
def comprehensive_rag_analysis(file_paths: List[str], query: str, task_context: str = "") -> str:
    try:
        rag_engine = EnhancedRAGQueryEngine(task_context)
        documents = rag_engine.load_and_process_documents(file_paths)

        if not documents:
            return "No documents could be processed successfully."

        # Switch to hierarchical chunking for large corpora
        total_text_length = sum(len(doc.text) for doc in documents)
        use_hierarchical = total_text_length > 50000 or len(documents) > 5

        index = rag_engine.create_advanced_index(documents, use_hierarchical)
        query_engine = rag_engine.create_context_aware_query_engine(index)

        enhanced_query = f"""
        Task Context: {task_context}
        Original Query: {query}

        Please analyze the provided documents and answer the query with precise, factual information.
        """

        response = query_engine.query(enhanced_query)

        result = "**RAG Analysis Results:**\n\n"
        result += f"**Documents Processed:** {len(documents)}\n"
        result += f"**Answer:**\n{response.response}\n\n"

        return result

    except Exception as e:
        return f"RAG analysis failed: {str(e)}"

def cross_document_analysis(file_paths: List[str], query: str, task_context: str = "") -> str:
    try:
        rag_engine = EnhancedRAGQueryEngine(task_context)
        all_documents = []
        document_groups = {}

        # Load each file separately so documents can be tagged with their source group
        for file_path in file_paths:
            docs = rag_engine.load_and_process_documents([file_path])
            doc_key = os.path.basename(file_path)
            document_groups[doc_key] = docs

            for doc in docs:
                doc.metadata.update({
                    "document_group": doc_key,
                    "total_documents": len(file_paths)
                })
            all_documents.extend(docs)

        index = rag_engine.create_advanced_index(all_documents, use_hierarchical=True)
        query_engine = rag_engine.create_context_aware_query_engine(index)

        response = query_engine.query(f"Task: {task_context}\nQuery: {query}")

        result = "**Cross-Document Analysis:**\n"
        result += f"**Documents:** {list(document_groups.keys())}\n"
        result += f"**Answer:**\n{response.response}\n"

        return result

    except Exception as e:
        return f"Cross-document analysis failed: {str(e)}"

# Create tools (identifier-style names keep function-calling schemas valid)
enhanced_rag_tool = FunctionTool.from_defaults(
    fn=comprehensive_rag_analysis,
    name="enhanced_rag_analysis",
    description="Comprehensive document analysis using advanced RAG with hybrid search and context-aware processing"
)

cross_document_tool = FunctionTool.from_defaults(
    fn=cross_document_analysis,
    name="cross_document_analysis",
    description="Advanced analysis across multiple documents with cross-referencing capabilities"
)

# Analysis Agent
analysis_agent = FunctionAgent(
    name="AnalysisAgent",
    description="Advanced multimodal analysis using enhanced RAG with hybrid search and cross-document capabilities",
    system_prompt="""
    You are an advanced analysis specialist with access to:
    - Enhanced RAG with hybrid search and reranking
    - Multi-format document processing (PDF, Word, CSV, images, text)
    - Cross-document analysis and synthesis
    - Context-aware query processing

    Your capabilities:
    1. Process multiple file types simultaneously
    2. Perform semantic search across document collections
    3. Cross-reference information between documents
    4. Extract precise information with source attribution
    5. Handle both text and visual content analysis

    Always consider the GAIA task context and provide precise, well-sourced answers.
    """,
    llm=multimodal_llm,
    tools=[enhanced_rag_tool, cross_document_tool],
    can_handoff_to=["CodeAgent", "ResearchAgent"]
)

from llama_index.readers.web import SimpleWebPageReader
from llama_index.core.tools.ondemand_loader_tool import OnDemandLoaderTool
from llama_index.tools.arxiv import ArxivToolSpec
import duckduckgo_search as ddg
import re

class IntelligentSourceRouter:
    def __init__(self):
        # Initialize tools - only ArXiv and web search
        # (max_results is a constructor argument on the spec, not a query argument)
        self.arxiv_spec = ArxivToolSpec(max_results=3)

        # Add web content loader
        self.web_reader = SimpleWebPageReader()

        # Create OnDemandLoaderTool for web content
        self.web_loader_tool = OnDemandLoaderTool.from_defaults(
            self.web_reader,
            name="web_content_loader",
            description="Load and analyze web page content with intelligent chunking and search"
        )

    def web_search_fallback(self, query: str, max_results: int = 5) -> str:
        try:
            results = ddg.DDGS().text(query, max_results=max_results)
            return "\n".join([f"{i}. **{r['title']}**\n   URL: {r['href']}\n   {r['body']}" for i, r in enumerate(results, 1)])
        except Exception as e:
            return f"Search failed: {str(e)}"

    def extract_web_content(self, urls: List[str], query: str) -> str:
        """Extract and analyze content from web URLs"""
        try:
            content_results = []
            for url in urls[:3]:  # Limit to top 3 URLs
                try:
                    result = self.web_loader_tool.call(
                        urls=[url],
                        query=f"Extract information relevant to: {query}"
                    )
                    content_results.append(f"**Content from {url}:**\n{result}")
                except Exception as e:
                    content_results.append(f"**Failed to load {url}**: {str(e)}")

            return "\n\n".join(content_results)
        except Exception as e:
            return f"Content extraction failed: {str(e)}"

    def detect_intent_and_route(self, query: str) -> str:
        # Simple LLM-based discrimination: scientific vs non-scientific
        intent_prompt = f"""
        Analyze this query and determine if it's scientific research or general information:
        Query: "{query}"

        Choose ONE source:
        - arxiv: For scientific research, academic papers, technical studies, algorithms, experiments
        - web_search: For all other information (current events, general facts, weather, how-to guides, etc.)

        Respond with ONLY "arxiv" or "web_search".
        """

        response = text_llm.complete(intent_prompt)
        # Tolerate extra tokens around the label the LLM was asked to emit
        selected_source = "arxiv" if "arxiv" in response.text.strip().lower() else "web_search"

        # Execute search and extract content
        results = [f"**Query**: {query}", f"**Selected Source**: {selected_source}", "=" * 50]

        try:
            if selected_source == 'arxiv':
                result = self.arxiv_spec.to_tool_list()[0].call(query=query)
                results.append(f"**ArXiv Research:**\n{result}")

            else:  # Default to web_search for everything else
                # Get search results
                search_results = self.web_search_fallback(query, 5)
                results.append(f"**Web Search Results:**\n{search_results}")

                # Extract URLs and load content
                urls = re.findall(r'URL: (https?://[^\s]+)', search_results)
                if urls:
                    web_content = self.extract_web_content(urls, query)
                    results.append(f"**Extracted Web Content:**\n{web_content}")

        except Exception as e:
            results.append(f"**Search failed**: {str(e)}")

        return "\n\n".join(results)

# Initialize router
intelligent_router = IntelligentSourceRouter()

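# Illustrative routing behavior (the actual choice is made by text_llm at runtime):
#
#     intelligent_router.detect_intent_and_route("attention mechanisms in transformers")
#     # -> likely routed to arxiv
#     intelligent_router.detect_intent_and_route("current weather in Paris")
#     # -> routed to web_search, with content extracted from the top URLs
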
# Create enhanced research tool
def enhanced_smart_research_tool(query: str, task_context: str = "", max_results: int = 5) -> str:
    full_query = f"{query} {task_context}".strip()
    return intelligent_router.detect_intent_and_route(full_query)

enhanced_research_tool_func = FunctionTool.from_defaults(
    fn=enhanced_smart_research_tool,
    name="enhanced_research_tool",
    description="Intelligent research tool that discriminates between scientific (ArXiv) and general (web) research with deep content extraction"
)

# Research agent with automatic source routing
research_agent = FunctionAgent(
    name="ResearchAgent",
    description="Advanced research agent that automatically routes between scientific and general research sources",
    system_prompt="""
    You are an advanced research specialist that automatically discriminates between:

    **Scientific Research** → ArXiv
    - Academic papers, research studies
    - Technical algorithms and methods
    - Scientific experiments and theories

    **General Research** → Web Search with Content Extraction
    - Current events and news
    - General factual information
    - How-to guides and technical documentation
    - Weather, locations, biographical info

    You automatically:
    1. **Route queries** to the most appropriate source
    2. **Extract deep content** from web pages (not just snippets)
    3. **Analyze and synthesize** information comprehensively
    4. **Provide detailed answers** with source attribution

    Always focus on extracting the most relevant information for the GAIA task.
    """,
    llm=text_llm,
    tools=[enhanced_research_tool_func],
    can_handoff_to=["AnalysisAgent", "CodeAgent"]
)

from llama_index.core.agent.workflow import ReActAgent

def execute_python_code(code: str) -> str:
    try:
        # Restricted execution environment: a small builtin whitelist plus
        # pre-imported math/datetime/re modules
        safe_globals = {
            "__builtins__": {
                "len": len, "str": str, "int": int, "float": float,
                "list": list, "dict": dict, "sum": sum, "max": max, "min": min,
                "round": round, "abs": abs, "sorted": sorted
            },
            "math": __import__("math"),
            "datetime": __import__("datetime"),
            "re": __import__("re")
        }

        exec_locals = {}
        exec(code, safe_globals, exec_locals)

        # By convention, the snippet communicates its output via a `result` variable
        if 'result' in exec_locals:
            return str(exec_locals['result'])
        else:
            return "Code executed successfully"

    except Exception as e:
        return f"Code execution failed: {str(e)}"

code_execution_tool = FunctionTool.from_defaults(
    fn=execute_python_code,
    name="execute_python_code",
    description="Execute Python code safely for calculations and data processing"
)

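# Example: snippets assign to `result` to return a value, and the pre-imported
# modules (math, datetime, re) are usable without import statements:
#
#     execute_python_code("result = max(3, 7) + math.sqrt(16)")
#     # -> "11.0"
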
# Code Agent as ReActAgent
code_agent = ReActAgent(
    name="CodeAgent",
    description="Advanced calculations, data processing, and final answer synthesis using ReAct reasoning",
    system_prompt="""
    You are a coding and reasoning specialist using ReAct methodology.

    For each task:
    1. THINK: Analyze what needs to be calculated or processed
    2. ACT: Execute appropriate code or calculations
    3. OBSERVE: Review results and determine if more work is needed
    4. REPEAT: Continue until you have the final answer

    Always show your reasoning process clearly and provide exact answers as required by GAIA.
    """,
    llm=text_llm,
    tools=[code_execution_tool],
    can_handoff_to=["ResearchAgent", "AnalysisAgent"]
)

class TaskRouter:
    def __init__(self):
        self.agents = {
            "AnalysisAgent": analysis_agent,
            "ResearchAgent": research_agent,
            "CodeAgent": code_agent
        }

    def route_task(self, question_data: Dict[str, Any]) -> str:
        question = question_data.get("Question", "").lower()
        has_files = "file_name" in question_data

        # Keyword-based routing: files with document/visual cues go to analysis,
        # numeric work to code, open questions to research
        if has_files:
            if any(keyword in question for keyword in ["image", "chart", "graph", "picture", "pdf", "document", "csv"]):
                return "AnalysisAgent"

        if any(keyword in question for keyword in ["calculate", "compute", "math", "number", "formula"]):
            return "CodeAgent"

        if any(keyword in question for keyword in ["search", "find", "who", "what", "when", "where", "research"]):
            return "ResearchAgent"

        return "AnalysisAgent"  # Default

    def get_agent(self, agent_name: str):
        return self.agents.get(agent_name, self.agents["AnalysisAgent"])

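# Illustrative routing decisions (GAIA-style payloads, made up here):
#
#     router = TaskRouter()
#     router.route_task({"Question": "Calculate the average of column B"})
#     # -> "CodeAgent"
#     router.route_task({"Question": "What is shown in the chart?", "file_name": "sales.png"})
#     # -> "AnalysisAgent"
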
class EnhancedGAIAAgent:
    def __init__(self):
        self.router = TaskRouter()

        # Main ReActAgent that coordinates everything
        self.main_agent = ReActAgent(
            name="MainGAIAAgent",
            description="Main GAIA agent that coordinates research, analysis, and computation to solve complex questions",
            system_prompt="""
            You are the main GAIA agent coordinator using ReAct reasoning methodology.

            Your process:
            1. THINK: Analyze the GAIA question and determine what information/analysis is needed
            2. ACT: Delegate to appropriate specialist agents (Research, Analysis, Code)
            3. OBSERVE: Review the results from specialist agents
            4. THINK: Determine if you have enough information for a final answer
            5. ACT: Either request more information or provide the final answer

            Available specialist agents:
            - ResearchAgent: For web search, ArXiv research
            - AnalysisAgent: For document/image analysis using RAG
            - CodeAgent: For calculations and data processing

            Always provide precise, exact answers as required by GAIA format.
            """,
            llm=text_llm,
            tools=[
                enhanced_research_tool_func,
                enhanced_rag_tool,
                cross_document_tool,
                code_execution_tool
            ]
        )

    async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
        question = question_data.get("Question", "")
        task_id = question_data.get("task_id", "")

        # Prepare comprehensive context
        context_prompt = f"""
        GAIA Task ID: {task_id}
        Question: {question}

        {'Associated files: ' + question_data.get('file_name', '') if 'file_name' in question_data else 'No files provided'}

        Instructions:
        1. Analyze this GAIA question carefully using ReAct reasoning
        2. Determine what information, analysis, or calculations are needed
        3. Use appropriate tools to gather information and perform analysis
        4. Synthesize findings into a precise, exact answer
        5. Ensure your answer format matches GAIA requirements (exact, concise)

        Begin your ReAct reasoning process now.
        """

        # Workflow agents are driven through an awaitable run() call, not chat()
        response = await self.main_agent.run(user_msg=context_prompt)

        return str(response)
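
# Minimal driver sketch (assumes the llama_index agent-workflow API where
# ReActAgent.run(...) is awaitable; the task payload below is illustrative):
if __name__ == "__main__":
    import asyncio

    async def _demo():
        agent = EnhancedGAIAAgent()
        answer = await agent.solve_gaia_question({
            "task_id": "demo-001",
            "Question": "What is 17 squared minus 12 squared?",
        })
        print(answer)

    asyncio.run(_demo())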
tools.py DELETED
File without changes