Spaces:
Running
Running
#!/usr/bin/env python3
"""
Test Phase 1 Improvements - Tool Execution and Answer Formatting

This script tests the critical fixes implemented in Phase 1:
1. Tool execution debugging and validation
2. Enhanced answer formatting with multiple patterns
3. GAIA format compliance validation
4. Comprehensive error handling and fallback systems

Usage:
    python test_phase1_improvements.py
"""
import os | |
import sys | |
import logging | |
from pathlib import Path | |
# Put the directory that contains this script at the front of the import
# search path; the test functions below import `utils.*` and `agents.*`
# lazily and are presumably expected to find them here.
sys.path.insert(0, str(Path(__file__).parent))

# Root logging configuration: INFO and above, each record prefixed with
# timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
def test_tool_execution_debugger():
    """Exercise ToolExecutionDebugger: JSON detection, tool validation, stats.

    The debugger is imported inside the function so that a missing or broken
    ``utils.tool_execution_debugger`` module is reported as a failed test
    instead of aborting the whole script at import time.

    The results of the debugger calls are only logged; nothing is compared
    against an expected value.

    Returns:
        bool: True when the import and every debugger call completed without
        raising, False when any of them raised an exception.
    """
    logger.info("π§ Testing ToolExecutionDebugger...")
    try:
        from utils.tool_execution_debugger import ToolExecutionDebugger

        debugger = ToolExecutionDebugger()

        # Test JSON syntax detection
        test_responses = [
            "The answer is 42",  # Normal response
            '{"function": "calculator", "parameters": {"expression": "2+2"}}',  # JSON syntax issue
            "FINAL ANSWER: 42",  # Proper format
            "I need to use the calculator tool: {\"tool\": \"calc\"}",  # Mixed content
        ]
        for number, response in enumerate(test_responses, start=1):
            issues = debugger.detect_json_syntax_in_response(response)
            logger.info(f" Test {number}: {'β Issues detected' if issues else 'β Clean'} - {issues}")

        # Test tool validation
        class MockTool:
            """Stand-in tool whose reported class is named after ``name``."""

            def __init__(self, name):
                self.name = name
                # Built once so repeated ``__class__`` lookups return the
                # same type object.
                self._reported_class = type(name, (), {})

            # ``__class__`` must be a property, not a plain method: with a
            # plain method, ``tool.__class__`` evaluates to a bound method
            # and ``tool.__class__.__name__`` is '__class__' rather than the
            # intended tool name.
            @property
            def __class__(self):
                return self._reported_class

        mock_tool = MockTool("TestTool")
        # NOTE(review): what validate_tool_registration inspects on the tool
        # is not visible from this file - confirm it reads ``__class__``.
        validation = debugger.validate_tool_registration("TestTool", mock_tool)
        logger.info(f" Tool validation: {validation}")

        # Get debug stats
        stats = debugger.get_debug_stats()
        logger.info(f" Debug stats: {stats}")

        logger.info("β ToolExecutionDebugger tests passed")
        return True
    except Exception as e:
        # Top-level test boundary: any failure (including ImportError) is
        # logged and turned into a failed result.
        logger.error(f"β ToolExecutionDebugger test failed: {e}")
        return False
def test_enhanced_answer_formatter():
    """Run EnhancedGAIAAnswerFormatter over a fixed set of sample answers.

    Each sample pairs a raw response with the question it answers. A sample
    counts as successful when ``format_answer`` (and the logging of its
    result) completes without raising; the formatted text is logged but is
    not compared with an expected value. The ``expected_type`` field of a
    sample is descriptive only and is never read.

    Returns:
        bool: True when every sample succeeded, False when at least one
        sample raised or when the formatter could not be imported or
        constructed.
    """
    logger.info("π― Testing EnhancedGAIAAnswerFormatter...")
    try:
        from utils.enhanced_gaia_answer_formatter import EnhancedGAIAAnswerFormatter

        formatter = EnhancedGAIAAnswerFormatter()

        def make_record(number, sample, output, status):
            # Build one result entry; inputs longer than 50 characters are
            # truncated and suffixed with "...".
            raw = sample['input']
            shown = raw[:50] + "..." if len(raw) > 50 else raw
            return {
                'test': number,
                'description': sample['description'],
                'input': shown,
                'output': output,
                'status': status,
            }

        # Samples covering different answer types and formats
        samples = [
            # Number formatting
            {
                'input': "The calculation gives us 1,234.50 as the result.",
                'question': "What is 1000 + 234.5?",
                'expected_type': 'number',
                'description': 'Number with comma removal',
            },
            {
                'input': "FINAL ANSWER: 42",
                'question': "How many items are there?",
                'expected_type': 'number',
                'description': 'Simple FINAL ANSWER format',
            },
            # String formatting
            {
                'input': "The capital of France is Paris.",
                'question': "What is the capital of France?",
                'expected_type': 'string',
                'description': 'String extraction from sentence',
            },
            {
                'input': 'FINAL ANSWER: "The Eiffel Tower"',
                'question': "What is the famous tower in Paris?",
                'expected_type': 'string',
                'description': 'String with quotes removal',
            },
            # List formatting
            {
                'input': "The colors are red, blue, and green.",
                'question': "List three primary colors",
                'expected_type': 'list',
                'description': 'List with "and" removal',
            },
            {
                'input': "FINAL ANSWER: apple; banana; orange",
                'question': "Name three fruits",
                'expected_type': 'list',
                'description': 'List with semicolon separation',
            },
            # Boolean formatting
            {
                'input': "Yes, Paris is in France.",
                'question': "Is Paris in France?",
                'expected_type': 'boolean',
                'description': 'Boolean yes answer',
            },
            {
                'input': "No, that is incorrect.",
                'question': "Is London in Germany?",
                'expected_type': 'boolean',
                'description': 'Boolean no answer',
            },
            # Complex cases
            {
                'input': "After analyzing the data, I can conclude that the answer is 3.14159.",
                'question': "What is the value of pi to 5 decimal places?",
                'expected_type': 'number',
                'description': 'Number extraction from complex text',
            },
            {
                'input': "Let me search for this information... The result shows that Einstein was born in 1879.",
                'question': "When was Einstein born?",
                'expected_type': 'number',
                'description': 'Year extraction from narrative',
            },
        ]

        outcomes = []
        for number, sample in enumerate(samples, start=1):
            try:
                formatted = formatter.format_answer(sample['input'], sample['question'])
                outcomes.append(make_record(number, sample, formatted, 'β Success'))
                logger.info(f" Test {number}: β {sample['description']} β '{formatted}'")
            except Exception as e:
                # A failing sample is recorded and logged; the remaining
                # samples still run.
                outcomes.append(make_record(number, sample, f"Error: {e}", 'β Failed'))
                logger.error(f" Test {number}: β {sample['description']} failed: {e}")

        # Get formatting statistics
        stats = formatter.get_formatting_stats()
        logger.info(f" Formatting stats: {stats}")

        # Summary
        succeeded = [entry for entry in outcomes if entry['status'] == 'β Success']
        logger.info(f"β Enhanced formatter tests: {len(succeeded)}/{len(samples)} passed")
        return len(succeeded) == len(samples)
    except Exception as e:
        logger.error(f"β EnhancedGAIAAnswerFormatter test failed: {e}")
        return False
def test_agent_integration():
    """Smoke-test GAIAAgent with a single arithmetic question.

    When the ``MISTRAL_API_KEY`` environment variable is unset or empty the
    test is skipped and reported as passing. Otherwise the agent is
    imported, constructed, asked "What is 2 + 2?", and its response is
    checked against the two error strings compared below.

    Returns:
        bool: True when the test was skipped, or when the agent returned a
        truthy response that is neither "Agent not available" nor
        "Unable to process this question". False when the agent reports
        itself unavailable, returns an error/empty response, or any step
        raises.
    """
    logger.info("π€ Testing agent integration...")
    try:
        # No API key: skip rather than fail.
        api_key = os.getenv("MISTRAL_API_KEY")
        if not api_key:
            logger.warning("β οΈ MISTRAL_API_KEY not found - skipping agent integration test")
            return True

        from agents.enhanced_unified_agno_agent import GAIAAgent

        # Initialize agent
        agent = GAIAAgent()
        if not agent.available:
            logger.warning("β οΈ Agent not available - check API key and dependencies")
            return False

        # Report tool status
        logger.info(f" Tool status: {agent.get_tool_status()}")

        # Ask one simple question
        test_question = "What is 2 + 2?"
        logger.info(f" Testing question: {test_question}")
        try:
            response = agent(test_question)
            logger.info(f" Response: {response}")

            # The response must be non-empty and must not be one of the
            # two error strings.
            usable = (
                response
                and response != "Agent not available"
                and response != "Unable to process this question"
            )
            if not usable:
                logger.warning("β οΈ Agent returned error response")
                return False

            logger.info("β Agent integration test passed")
            return True
        except Exception as e:
            logger.error(f"β Agent execution failed: {e}")
            return False
    except Exception as e:
        logger.error(f"β Agent integration test failed: {e}")
        return False
def run_phase1_tests():
    """Run the three Phase 1 test functions and log a pass/fail summary.

    The tests run in a fixed order: tool debugger, answer formatter, agent
    integration. Each result is logged by name, followed by an overall count.

    Returns:
        bool: True only when every test function returned a truthy result.
    """
    separator = "=" * 60
    logger.info("π Starting Phase 1 Improvement Tests")
    logger.info(separator)

    # Dict display entries are evaluated top to bottom, so the tests execute
    # in the order listed here.
    test_results = {
        'tool_debugger': test_tool_execution_debugger(),
        'answer_formatter': test_enhanced_answer_formatter(),
        'agent_integration': test_agent_integration(),
    }

    # Summary
    logger.info(separator)
    logger.info("π Phase 1 Test Results Summary:")
    for test_name, outcome in test_results.items():
        logger.info(f" {test_name}: {'β PASSED' if outcome else 'β FAILED'}")

    total_tests = len(test_results)
    passed_tests = len([outcome for outcome in test_results.values() if outcome])
    logger.info(f"\nOverall: {passed_tests}/{total_tests} tests passed")

    all_passed = passed_tests == total_tests
    if not all_passed:
        logger.warning("β οΈ Some tests failed - review logs and fix issues before proceeding")
    else:
        logger.info("π All Phase 1 improvements are working correctly!")
        logger.info("π Ready to proceed with Phase 2 (Answer Formatting Enhancement)")
    return all_passed
if __name__ == "__main__":
    # Process exit status: 0 when every test passed, 1 otherwise.
    sys.exit(0 if run_phase1_tests() else 1)