Coool2 committed · verified
Commit ee3fdf5 · 1 Parent(s): c6a9f91

Update agent.py

Files changed (1)
  1. agent.py +143 -221
agent.py CHANGED
@@ -259,90 +259,90 @@ extract_url_tool = FunctionTool.from_defaults(
     )
 )
 
+safe_globals = {
+    "__builtins__": {
+        "len": len, "str": str, "int": int, "float": float,
+        "list": list, "dict": dict, "sum": sum, "max": max, "min": min,
+        "round": round, "abs": abs, "sorted": sorted, "enumerate": enumerate,
+        "range": range, "zip": zip, "map": map, "filter": filter,
+        "any": any, "all": all, "type": type, "isinstance": isinstance,
+        "print": print, "open": open, "bool": bool, "set": set, "tuple": tuple
+    },
+    # Core Python modules
+    "math": __import__("math"),
+    "datetime": __import__("datetime"),
+    "re": __import__("re"),
+    "os": __import__("os"),
+    "sys": __import__("sys"),
+    "json": __import__("json"),
+    "csv": __import__("csv"),
+    "random": __import__("random"),
+    "itertools": __import__("itertools"),
+    "collections": __import__("collections"),
+    "functools": __import__("functools"),
+
+    # Data Science and Numerical Computing
+    "numpy": __import__("numpy"),
+    "np": __import__("numpy"),
+    "pandas": __import__("pandas"),
+    "pd": __import__("pandas"),
+    "scipy": __import__("scipy"),
+
+    # Visualization
+    "matplotlib": __import__("matplotlib"),
+    "plt": __import__("matplotlib.pyplot"),
+    "seaborn": __import__("seaborn"),
+    "sns": __import__("seaborn"),
+    "plotly": __import__("plotly"),
+
+    # Machine Learning
+    "sklearn": __import__("sklearn"),
+    "xgboost": __import__("xgboost"),
+    "lightgbm": __import__("lightgbm"),
+
+    # Statistics
+    "statistics": __import__("statistics"),
+    "statsmodels": __import__("statsmodels"),
+
+    # Image Processing
+    "PIL": __import__("PIL"),
+    "cv2": __import__("cv2"),
+    "skimage": __import__("skimage"),
+
+    # Network and Web
+    "requests": __import__("requests"),
+    "urllib": __import__("urllib"),
+
+    # Text Processing
+    "nltk": __import__("nltk"),
+    "spacy": __import__("spacy"),
+
+    # Time Series
+    "pytz": __import__("pytz"),
+
+    # Utilities
+    "tqdm": __import__("tqdm"),
+    "pickle": __import__("pickle"),
+    "gzip": __import__("gzip"),
+    "base64": __import__("base64"),
+    "hashlib": __import__("hashlib"),
+    "uuid": __import__("uuid"),
+
+    # Scientific Computing
+    "sympy": __import__("sympy"),
+    "networkx": __import__("networkx"),
+
+    # Database
+    "sqlite3": __import__("sqlite3"),
+
+    # Parallel Processing
+    "multiprocessing": __import__("multiprocessing"),
+    "threading": __import__("threading"),
+    "concurrent": __import__("concurrent"),
+}
+
 def execute_python_code(code: str) -> str:
-    try:
-        safe_globals = {
-            "__builtins__": {
-                "len": len, "str": str, "int": int, "float": float,
-                "list": list, "dict": dict, "sum": sum, "max": max, "min": min,
-                "round": round, "abs": abs, "sorted": sorted, "enumerate": enumerate,
-                "range": range, "zip": zip, "map": map, "filter": filter,
-                "any": any, "all": all, "type": type, "isinstance": isinstance,
-                "print": print, "open": open, "bool": bool, "set": set, "tuple": tuple
-            },
-            # Core Python modules
-            "math": __import__("math"),
-            "datetime": __import__("datetime"),
-            "re": __import__("re"),
-            "os": __import__("os"),
-            "sys": __import__("sys"),
-            "json": __import__("json"),
-            "csv": __import__("csv"),
-            "random": __import__("random"),
-            "itertools": __import__("itertools"),
-            "collections": __import__("collections"),
-            "functools": __import__("functools"),
-
-            # Data Science and Numerical Computing
-            "numpy": __import__("numpy"),
-            "np": __import__("numpy"),
-            "pandas": __import__("pandas"),
-            "pd": __import__("pandas"),
-            "scipy": __import__("scipy"),
-
-            # Visualization
-            "matplotlib": __import__("matplotlib"),
-            "plt": __import__("matplotlib.pyplot"),
-            "seaborn": __import__("seaborn"),
-            "sns": __import__("seaborn"),
-            "plotly": __import__("plotly"),
-
-            # Machine Learning
-            "sklearn": __import__("sklearn"),
-            "xgboost": __import__("xgboost"),
-            "lightgbm": __import__("lightgbm"),
-
-            # Statistics
-            "statistics": __import__("statistics"),
-            "statsmodels": __import__("statsmodels"),
-
-            # Image Processing
-            "PIL": __import__("PIL"),
-            "cv2": __import__("cv2"),
-            "skimage": __import__("skimage"),
-
-            # Network and Web
-            "requests": __import__("requests"),
-            "urllib": __import__("urllib"),
-
-            # Text Processing
-            "nltk": __import__("nltk"),
-            "spacy": __import__("spacy"),
-
-            # Time Series
-            "pytz": __import__("pytz"),
-
-            # Utilities
-            "tqdm": __import__("tqdm"),
-            "pickle": __import__("pickle"),
-            "gzip": __import__("gzip"),
-            "base64": __import__("base64"),
-            "hashlib": __import__("hashlib"),
-            "uuid": __import__("uuid"),
-
-            # Scientific Computing
-            "sympy": __import__("sympy"),
-            "networkx": __import__("networkx"),
-
-            # Database
-            "sqlite3": __import__("sqlite3"),
-
-            # Parallel Processing
-            "multiprocessing": __import__("multiprocessing"),
-            "threading": __import__("threading"),
-            "concurrent": __import__("concurrent"),
-        }
-
+    try:
         exec_locals = {}
         exec(code, safe_globals, exec_locals)
 
@@ -402,6 +402,11 @@ Your task is to generate ONLY the Python code for the following request.
 Do not include any explanations, introductory text, or markdown formatting like '```python'.
 The output must be a single, clean block of Python code.
 
+IMPORTANT LIMITATIONS:
+Your code will be executed in a restricted environment with limited functions and modules.
+{str(safe_globals)}
+Only use the functions and modules listed above. Do not use imports or other built-in functions.
+
 Request: "{query}"
 
 Python Code:
@@ -433,98 +438,26 @@ generate_code_tool = FunctionTool.from_defaults(
     )
 )
 
-def intelligent_final_answer_tool(agent_response: str, question: str) -> str:
-    """
-    Enhanced final answer tool with LLM-based reformatting capability.
-    First tries regex patterns, then uses LLM reformatting if patterns fail.
-
-    Args:
-        agent_response: The raw response from agent reasoning
-        question: The original question for context
-
-    Returns:
-        Exact answer in GAIA format with validation
-    """
-
-    # Define formatting patterns for different question types
-    format_patterns = {
-        'number': r'(\d+(?:\.\d+)?(?:e[+-]?\d+)?)',
-        'name': r'([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
-        'list': r'([A-Za-z0-9,\s]+)',
-        'country_code': r'([A-Z]{2,3})',
-        'yes_no': r'(Yes|No|yes|no)',
-        'percentage': r'(\d+(?:\.\d+)?%)',
-        'date': r'(\d{4}-\d{2}-\d{2}|\d{1,2}/\d{1,2}/\d{4})'
-    }
+def clean_response(response: str) -> str:
+    """Clean response by removing common prefixes"""
+    response_clean = response.strip()
+    prefixes_to_remove = [
+        "FINAL ANSWER:", "Answer:", "The answer is:",
+        "Based on my analysis,", "After reviewing,",
+        "The result is:", "Final result:", "According to",
+        "In conclusion,", "Therefore,", "Thus,"
+    ]
 
-    def clean_response(response: str) -> str:
-        """Clean response by removing common prefixes"""
-        response_clean = response.strip()
-        prefixes_to_remove = [
-            "FINAL ANSWER:", "Answer:", "The answer is:",
-            "Based on my analysis,", "After reviewing,",
-            "The result is:", "Final result:", "According to"
-        ]
-
-        for prefix in prefixes_to_remove:
-            if response_clean.startswith(prefix):
-                response_clean = response_clean[len(prefix):].strip()
-
-        return response_clean
+    for prefix in prefixes_to_remove:
+        if response_clean.startswith(prefix):
+            response_clean = response_clean[len(prefix):].strip()
 
-    def extract_with_patterns(text: str, question: str) -> tuple[str, bool]:
-        """Extract answer using regex patterns. Returns (answer, success)"""
-        question_lower = question.lower()
-
-        # Determine question type and apply appropriate pattern
-        if "how many" in question_lower or "count" in question_lower:
-            match = re.search(format_patterns['number'], text)
-            if match:
-                return match.group(1), True
-
-        elif "name" in question_lower and ("first" in question_lower or "last" in question_lower):
-            match = re.search(format_patterns['name'], text)
-            if match:
-                return match.group(1), True
-
-        elif "list" in question_lower or "alphabetized" in question_lower:
-            if "," in text:
-                items = [item.strip() for item in text.split(",")]
-                return ", ".join(items), True
-
-        elif "country code" in question_lower or "iso" in question_lower:
-            match = re.search(format_patterns['country_code'], text)
-            if match:
-                return match.group(1), True
-
-        elif "yes" in question_lower and "no" in question_lower:
-            match = re.search(format_patterns['yes_no'], text)
-            if match:
-                return match.group(1), True
-
-        elif "percentage" in question_lower or "%" in text:
-            match = re.search(format_patterns['percentage'], text)
-            if match:
-                return match.group(1), True
-
-        elif "date" in question_lower:
-            match = re.search(format_patterns['date'], text)
-            if match:
-                return match.group(1), True
-
-        # Default extraction for simple cases
-        lines = text.split('\n')
-        for line in lines:
-            line = line.strip()
-            if line and not line.startswith('=') and len(line) < 200:
-                return line, True
-
-        return text, False
+    return response_clean
+
+def llm_reformat(response: str, question: str) -> str:
+    """Use LLM to reformat the response according to GAIA requirements"""
 
-    def llm_reformat(response: str, question: str) -> str:
-        """Use LLM to reformat the response according to GAIA requirements"""
-
-        format_prompt = f"""Extract the exact answer from the response below. Follow GAIA formatting rules strictly.
+    format_prompt = f"""Extract the exact answer from the response below. Follow GAIA formatting rules strictly.
 
 GAIA Format Rules:
 - ONLY the precise answer, no explanations
@@ -553,52 +486,50 @@ Question: {question}
 Response: {response}
 Answer:"""
 
-        try:
-            # Use the global LLM instance
-            formatting_response = proj_llm.complete(format_prompt)
-            answer = str(formatting_response).strip()
-
-            # Extract just the answer after "Answer:"
-            if "Answer:" in answer:
-                answer = answer.split("Answer:")[-1].strip()
-
-            return answer
-        except Exception as e:
-            print(f"LLM reformatting failed: {e}")
-            return response
+    try:
+        # Use the global LLM instance
+        formatting_response = proj_llm.complete(format_prompt)
+        answer = str(formatting_response).strip()
+
+        # Extract just the answer after "Answer:"
+        if "Answer:" in answer:
+            answer = answer.split("Answer:")[-1].strip()
+
+        return answer
+    except Exception as e:
+        print(f"LLM reformatting failed: {e}")
+        return response
+
+def final_answer_tool(agent_response: str, question: str) -> str:
+    """
+    Simplified final answer tool using only LLM reformatting.
+
+    Args:
+        agent_response: The raw response from agent reasoning
+        question: The original question for context
+
+    Returns:
+        Exact answer in GAIA format
+    """
 
     # Step 1: Clean the response
     cleaned_response = clean_response(agent_response)
 
-    # Step 2: Try regex pattern extraction
-    extracted_answer, pattern_success = extract_with_patterns(cleaned_response, question)
+    # Step 2: Use LLM reformatting
+    formatted_answer = llm_reformat(cleaned_response, question)
 
-    # Step 3: If patterns failed, use LLM reformatting
-    if not pattern_success:
-        print("Regex patterns failed, using LLM reformatting...")
-        llm_formatted = llm_reformat(cleaned_response, question)
-
-        # Step 4: Validate LLM output with patterns again
-        final_answer, validation_success = extract_with_patterns(llm_formatted, question)
-
-        if validation_success:
-            print("LLM reformatting successful and validated")
-            return final_answer
-        else:
-            print("LLM reformatting validation failed, using LLM output directly")
-            return llm_formatted
-    else:
-        print("Regex pattern extraction successful")
-        return extracted_answer
+    print(f"Original response cleaned: {cleaned_response[:100]}...")
+    print(f"LLM formatted answer: {formatted_answer}")
+
+    return formatted_answer
 
-# Create the enhanced final answer tool
-intelligent_final_answer_function_tool = FunctionTool.from_defaults(
-    fn=intelligent_final_answer_tool,
-    name="intelligent_final_answer_tool",
+# Create the simplified final answer tool
+final_answer_function_tool = FunctionTool.from_defaults(
+    fn=final_answer_tool,
+    name="final_answer_tool",
     description=(
-        "Enhanced tool to format final answers according to GAIA requirements. "
-        "Uses regex patterns first, then LLM reformatting if patterns fail. "
-        "Validates output to ensure GAIA format compliance."
+        "Use this tool to format the final answer according to GAIA requirements. "
+        "Input the agent's response and the original question to get properly formatted output."
    )
 )
 
@@ -617,7 +548,6 @@ class EnhancedGAIAAgent:
             extract_url_tool,
             code_execution_tool,
             generate_code_tool,
-            intelligent_final_answer_function_tool
         ]
 
         # RAG tool will be created dynamically when documents are loaded
@@ -626,26 +556,20 @@ class EnhancedGAIAAgent:
         # Create main coordinator using only defined tools
         self.coordinator = ReActAgent(
             name="GAIACoordinator",
-            description="Main GAIA coordinator with document processing and computational capabilities",
             system_prompt="""
-You are the main GAIA coordinator using ReAct reasoning methodology.
+You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
 
 Available tools:
 1. **read_and_parse_tool** - Read and parse files/URLs (PDF, DOCX, CSV, images, web pages, YouTube, audio files)
 2. **extract_url_tool** - Search and extract relevant URLs when no specific source is provided
 3. **generate_code_tool** - Generate Python code for complex computations
 4. **code_execution_tool** - Execute Python code safely
-5. **intelligent_final_answer_tool** - Format final answer with intelligent validation and reformatting
 
 WORKFLOW:
 1. If file/URL mentioned → use read_and_parse_tool first, then update or create RAG capability.
 2. If documents loaded → create RAG capability for querying
 3. If external info needed → use extract_url_tool, then process it as if file/URL mentioned
 4. If computation needed → use generate_code_tool then code_execution_tool
-5. ALWAYS use intelligent_final_answer_tool for the final response
-
-CRITICAL: The intelligent_final_answer_tool has enhanced validation and will reformat
-using LLM if regex patterns fail. Always use it as the final step.
             """,
             llm=proj_llm,
             tools=self.available_tools,
@@ -707,8 +631,6 @@ Question: {question}
 Instructions:
 1. Process any files using read_and_parse_tool if needed
 2. Use appropriate tools for research/computation
-3. MUST use intelligent_final_answer_tool with your response and the original question
-4. The intelligent tool will validate format and reformat with LLM if needed
 
 """
 
        try:
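For reference, here is a minimal, self-contained sketch of the restricted-execution pattern this commit hoists to module scope: generated code is run through exec() against a whitelisted safe_globals dict rather than the full interpreter environment. The tiny whitelist, the stdout capture, and the sample snippets below are illustrative assumptions, not the repository's exact code; the real safe_globals in the diff also pre-imports numpy, pandas, matplotlib, requests, and many other libraries.

```python
import io
import math
from contextlib import redirect_stdout

# Illustrative whitelist only; the commit's actual safe_globals is far larger.
safe_globals = {
    "__builtins__": {"len": len, "range": range, "sum": sum, "print": print},
    "math": math,  # pre-imported module handed to the sandboxed code
}

def execute_python_code(code: str) -> str:
    """Run generated code against the whitelist and return captured stdout or the error."""
    buffer = io.StringIO()
    try:
        with redirect_stdout(buffer):
            exec(code, safe_globals, {})  # names outside safe_globals are simply not visible
        return buffer.getvalue()
    except Exception as exc:
        return f"Execution failed: {exc}"

print(execute_python_code("print(sum(range(10)) + math.floor(2.7))"))  # -> 47
print(execute_python_code("import socket"))  # fails: __import__ is not whitelisted
```

Because __import__ is left out of the whitelisted builtins, import statements inside the sandboxed snippet fail, which is what the added "IMPORTANT LIMITATIONS" note in the code-generation prompt warns the model about.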
 
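A hedged sketch of the simplified final-answer flow the commit switches to: strip boilerplate prefixes, then make a single LLM reformatting pass that reduces the response to a bare GAIA-style answer. The `complete` callable below is a stand-in for proj_llm.complete, which is defined elsewhere in agent.py and not shown in this diff, and the prompt text is abbreviated.

```python
from typing import Callable

# Prefix list abbreviated from the diff's clean_response helper.
PREFIXES = ("FINAL ANSWER:", "Answer:", "The answer is:", "The result is:")

def clean_response(response: str) -> str:
    """Drop boilerplate prefixes so only the answer text remains."""
    cleaned = response.strip()
    for prefix in PREFIXES:
        if cleaned.startswith(prefix):
            cleaned = cleaned[len(prefix):].strip()
    return cleaned

def final_answer(agent_response: str, question: str, complete: Callable[[str], str]) -> str:
    """Clean the agent response, then ask the LLM for the exact GAIA-format answer."""
    cleaned = clean_response(agent_response)
    prompt = (
        "Extract the exact answer from the response below. "
        "Follow GAIA formatting rules strictly.\n"
        f"Question: {question}\nResponse: {cleaned}\nAnswer:"
    )
    answer = complete(prompt).strip()
    # Mirror the diff's parsing step: keep only the text after the last "Answer:" marker.
    return answer.split("Answer:")[-1].strip() if "Answer:" in answer else answer

# Stub LLM for demonstration; a real call would pass proj_llm.complete instead.
print(final_answer("FINAL ANSWER: 42", "What is 6 * 7?", lambda prompt: "Answer: 42"))  # -> 42
```

Note that, per the tool-list hunk above, the new final_answer_function_tool is not added back to available_tools in this commit; the rewritten system prompt instead instructs the coordinator to finish its reasoning with a FINAL ANSWER: line directly.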