# gaia-enhanced-agent / tests/test_agent_prompt_enhancer_integration.py
# GAIA Agent Deployment
# Deploy Complete Enhanced GAIA Agent with Phase 1-6 Improvements
# 9a6a4dc
"""
Test integration of the calculator prompt enhancer with the Fixed GAIA Agent.

Verifies that exponentiation questions (``^``, ``**``, "to the power of",
"squared") are enhanced with Python calculation guidance before being passed
to the agent, while plain arithmetic questions are passed through unchanged.
"""
import pytest
import logging
from unittest.mock import Mock, patch, MagicMock
# Configure test environment
# NOTE(review): setup_test_environment() is deliberately called *before* the
# agent/enhancer imports below; presumably those modules depend on what it
# configures at import time — keep this ordering (confirm in utils.environment_setup).
from utils.environment_setup import setup_test_environment
setup_test_environment()
from agents.fixed_enhanced_unified_agno_agent import FixedGAIAAgent
from utils.calculator_prompt_enhancer import CalculatorPromptEnhancer
# Module-level logger used by the test methods to report progress.
logger = logging.getLogger(__name__)
class TestAgentPromptEnhancerIntegration:
    """Test integration of prompt enhancer with GAIA agent.

    The agent's model backend (``MistralChat``) and the agno ``Agent`` class
    are patched out, so the agent-level tests only inspect the prompt that
    ``FixedGAIAAgent`` hands to ``Agent.run``.
    """

    def setup_method(self):
        """Create a fresh CalculatorPromptEnhancer before each test method."""
        self.enhancer = CalculatorPromptEnhancer()

    @staticmethod
    def _has_exponentiation_guidance(prompt):
        """Return True if *prompt* mentions Python, ``pow(`` or the ``**`` operator.

        This is the single definition of "contains exponentiation guidance"
        shared by every test in this class.
        """
        return 'python' in prompt.lower() or 'pow(' in prompt or '**' in prompt

    @staticmethod
    def _run_agent_and_capture_prompt(mock_agent_class, mock_mistral_class, question, answer):
        """Run FixedGAIAAgent on *question* with patched backends.

        Args:
            mock_agent_class: the patched ``Agent`` class; its instance's ``run``
                is replaced by a Mock returning an object whose ``content`` is *answer*.
            mock_mistral_class: the patched ``MistralChat`` class.
            question: the question passed to the agent.
            answer: the canned model response content.

        Returns:
            The first positional argument the agent passed to ``Agent.run``.
        """
        mock_agent_instance = Mock()
        mock_agent_instance.run = Mock(return_value=Mock(content=answer))
        mock_agent_class.return_value = mock_agent_instance
        mock_mistral_class.return_value = Mock()

        with patch.dict('os.environ', {'MISTRAL_API_KEY': 'test-key'}):
            agent = FixedGAIAAgent()
            agent(question)

        assert mock_agent_instance.run.called, "Agent.run was never called"
        # call_args is an (args, kwargs) pair; the prompt is the first positional arg.
        positional_args, _kwargs = mock_agent_instance.run.call_args
        return positional_args[0]

    def test_agent_has_prompt_enhancer(self):
        """Test that agent initializes with prompt enhancer."""
        with patch.dict('os.environ', {'MISTRAL_API_KEY': 'test-key'}):
            with patch('agents.fixed_enhanced_unified_agno_agent.MistralChat'):
                with patch('agents.fixed_enhanced_unified_agno_agent.Agent'):
                    agent = FixedGAIAAgent()

                    # Verify prompt enhancer is initialized
                    assert hasattr(agent, 'prompt_enhancer')
                    assert isinstance(agent.prompt_enhancer, CalculatorPromptEnhancer)

        logger.info("✅ Agent has prompt enhancer initialized")

    def test_exponentiation_question_enhancement(self):
        """Test that exponentiation questions are enhanced and others are not."""
        # (question, should_enhance, description)
        test_cases = [
            ('Calculate 2^8', True, 'caret notation'),
            ('What is 3**4?', True, 'double asterisk notation'),
            ('Compute 2 to the power of 8', True, 'power of notation'),
            ('What is 5 squared?', True, 'squared notation'),
            ('Calculate 25 * 17', False, 'regular multiplication'),
            ('What is 144 / 12?', False, 'division operation'),
        ]

        for question, should_enhance, description in test_cases:
            enhanced = self.enhancer.enhance_prompt_for_exponentiation(question)
            is_enhanced = enhanced != question

            assert is_enhanced == should_enhance, (
                f"Enhancement mismatch for {description}: '{question}'"
            )

            if should_enhance:
                # Verify enhancement contains Python guidance
                assert self._has_exponentiation_guidance(enhanced), (
                    f"Enhanced prompt lacks Python guidance for {description}: '{question}'"
                )
                logger.info("✅ Enhanced %s: %s", description, question)
            else:
                logger.info("✅ No enhancement needed for %s: %s", description, question)

    @patch('agents.fixed_enhanced_unified_agno_agent.MistralChat')
    @patch('agents.fixed_enhanced_unified_agno_agent.Agent')
    def test_agent_uses_enhanced_prompt(self, mock_agent_class, mock_mistral_class):
        """Test that agent uses enhanced prompts for exponentiation."""
        question = "Calculate 2^8"
        actual_prompt = self._run_agent_and_capture_prompt(
            mock_agent_class, mock_mistral_class, question, "FINAL ANSWER: 256"
        )

        # The prompt should be longer than the question and carry guidance.
        assert len(actual_prompt) > len(question)
        assert self._has_exponentiation_guidance(actual_prompt), (
            "Agent prompt lacks exponentiation guidance"
        )

        logger.info("✅ Agent used enhanced prompt for exponentiation")
        logger.info(" Original: %s", question)
        logger.info(" Enhanced length: %d vs %d", len(actual_prompt), len(question))

    @patch('agents.fixed_enhanced_unified_agno_agent.MistralChat')
    @patch('agents.fixed_enhanced_unified_agno_agent.Agent')
    def test_agent_no_enhancement_for_regular_math(self, mock_agent_class, mock_mistral_class):
        """Test that agent doesn't enhance regular math questions."""
        question = "Calculate 25 * 17"
        actual_prompt = self._run_agent_and_capture_prompt(
            mock_agent_class, mock_mistral_class, question, "FINAL ANSWER: 425"
        )

        # Verify the prompt was NOT enhanced (should be the same)
        assert actual_prompt == question

        logger.info("✅ Agent did not enhance regular math question")
        logger.info(" Question: %s", question)

    def test_enhancement_preserves_file_context(self):
        """Test that enhancement keeps appended file context intact."""
        base_question = "Calculate 2^8"
        file_context = "File 1: data.csv (CSV format), 1024 bytes\nCSV Data: numbers,values\n1,2\n3,4"
        question_with_files = f"{base_question}\n\nFile Context:\n{file_context}"

        enhanced = self.enhancer.enhance_prompt_for_exponentiation(question_with_files)

        # Verify enhancement occurred
        assert enhanced != question_with_files
        assert len(enhanced) > len(question_with_files)

        # Verify file context is preserved verbatim
        assert file_context in enhanced

        # Verify exponentiation guidance is added
        assert self._has_exponentiation_guidance(enhanced), (
            "Enhanced prompt lacks exponentiation guidance"
        )

        logger.info("✅ Enhancement preserves file context")
        logger.info(" Original length: %d", len(question_with_files))
        logger.info(" Enhanced length: %d", len(enhanced))
if __name__ == "__main__":
    # Run tests with verbose output (-v) and without output capture (-s).
    # Propagate pytest's exit code so callers (shell/CI) see failures;
    # pytest.main() returns the code instead of exiting.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))