File size: 7,791 Bytes
9a6a4dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""
Test integration of calculator prompt enhancer with Fixed GAIA Agent.
Verifies that exponentiation operations are properly enhanced.
"""

import pytest
import logging
from unittest.mock import Mock, patch, MagicMock

# Configure the test environment before the project imports below.
# NOTE(review): the agent/enhancer imports are placed after this call,
# presumably because they depend on the environment it prepares — confirm
# against utils.environment_setup before reordering these lines.
from utils.environment_setup import setup_test_environment
setup_test_environment()

from agents.fixed_enhanced_unified_agno_agent import FixedGAIAAgent
from utils.calculator_prompt_enhancer import CalculatorPromptEnhancer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class TestAgentPromptEnhancerIntegration:
    """Integration tests for ``CalculatorPromptEnhancer`` and ``FixedGAIAAgent``.

    The enhancer is exercised directly through
    ``enhance_prompt_for_exponentiation``. The agent is exercised with
    ``MistralChat`` and ``Agent`` patched in the agent module, and the prompt
    handed to the mocked ``Agent.run`` is inspected.
    """

    def setup_method(self):
        """Create a fresh enhancer before each test method."""
        self.enhancer = CalculatorPromptEnhancer()

    @staticmethod
    def _has_exponentiation_guidance(prompt):
        """Return True if *prompt* mentions Python, ``pow(`` or ``**``.

        NOTE(review): a question that already contains one of these markers
        (e.g. ``3**4``) can satisfy this check without any added guidance,
        assuming the enhancer keeps the question text in its output. Callers
        therefore also assert that the prompt actually changed.
        """
        return 'python' in prompt.lower() or 'pow(' in prompt or '**' in prompt

    @staticmethod
    def _prompt_passed_to_run(mock_agent_class, mock_mistral_class, question, answer):
        """Send *question* through a mocked agent and return the prompt it ran.

        Args:
            mock_agent_class: Patched ``Agent`` class from the agent module.
            mock_mistral_class: Patched ``MistralChat`` class from the agent module.
            question: Question string passed to the ``FixedGAIAAgent`` instance.
            answer: Text used as ``content`` of the mocked ``run`` result.

        Returns:
            The first positional argument of the (last) call to the mocked
            ``Agent.run``.
        """
        mock_agent_instance = Mock()
        mock_agent_instance.run = Mock(return_value=Mock(content=answer))
        mock_agent_class.return_value = mock_agent_instance
        mock_mistral_class.return_value = Mock()

        with patch.dict('os.environ', {'MISTRAL_API_KEY': 'test-key'}):
            agent = FixedGAIAAgent()
            # The agent's return value is irrelevant here; only the prompt
            # forwarded to Agent.run is inspected.
            agent(question)

            assert mock_agent_instance.run.called, "agent.run was never called"
            return mock_agent_instance.run.call_args[0][0]

    def test_agent_has_prompt_enhancer(self):
        """Test that agent initializes with prompt enhancer."""
        with patch.dict('os.environ', {'MISTRAL_API_KEY': 'test-key'}), \
                patch('agents.fixed_enhanced_unified_agno_agent.MistralChat'), \
                patch('agents.fixed_enhanced_unified_agno_agent.Agent'):
            agent = FixedGAIAAgent()

            # Verify prompt enhancer is initialized
            assert hasattr(agent, 'prompt_enhancer')
            assert isinstance(agent.prompt_enhancer, CalculatorPromptEnhancer)
            logger.info("✅ Agent has prompt enhancer initialized")

    def test_exponentiation_question_enhancement(self):
        """Test that exponentiation questions are enhanced and others are not."""
        cases = [
            # (question, should_enhance, description)
            ('Calculate 2^8', True, 'caret notation'),
            ('What is 3**4?', True, 'double asterisk notation'),
            ('Compute 2 to the power of 8', True, 'power of notation'),
            ('What is 5 squared?', True, 'squared notation'),
            ('Calculate 25 * 17', False, 'regular multiplication'),
            ('What is 144 / 12?', False, 'division operation'),
        ]

        for question, should_enhance, description in cases:
            enhanced = self.enhancer.enhance_prompt_for_exponentiation(question)
            # "Enhanced" means the enhancer returned something other than
            # the question it was given.
            is_enhanced = enhanced != question

            assert is_enhanced == should_enhance, (
                f"Enhancement mismatch for {description}: '{question}'"
            )

            if should_enhance:
                # Verify enhancement contains Python guidance
                assert self._has_exponentiation_guidance(enhanced), (
                    f"No Python guidance in enhanced prompt for {description}: '{question}'"
                )
                logger.info("✅ Enhanced %s: %s", description, question)
            else:
                logger.info("✅ No enhancement needed for %s: %s", description, question)

    @patch('agents.fixed_enhanced_unified_agno_agent.MistralChat')
    @patch('agents.fixed_enhanced_unified_agno_agent.Agent')
    def test_agent_uses_enhanced_prompt(self, mock_agent_class, mock_mistral_class):
        """Test that agent uses enhanced prompts for exponentiation."""
        question = "Calculate 2^8"
        actual_prompt = self._prompt_passed_to_run(
            mock_agent_class, mock_mistral_class, question, "FINAL ANSWER: 256"
        )

        # Verify the prompt was enhanced (should be longer and contain guidance)
        assert len(actual_prompt) > len(question)
        assert self._has_exponentiation_guidance(actual_prompt), (
            "Prompt passed to agent.run lacks Python exponentiation guidance"
        )

        logger.info("✅ Agent used enhanced prompt for exponentiation")
        logger.info("   Original: %s", question)
        logger.info("   Enhanced length: %s vs %s", len(actual_prompt), len(question))

    @patch('agents.fixed_enhanced_unified_agno_agent.MistralChat')
    @patch('agents.fixed_enhanced_unified_agno_agent.Agent')
    def test_agent_no_enhancement_for_regular_math(self, mock_agent_class, mock_mistral_class):
        """Test that agent doesn't enhance regular math questions."""
        question = "Calculate 25 * 17"
        actual_prompt = self._prompt_passed_to_run(
            mock_agent_class, mock_mistral_class, question, "FINAL ANSWER: 425"
        )

        # Verify the prompt was NOT enhanced (should be the same)
        assert actual_prompt == question

        logger.info("✅ Agent did not enhance regular math question")
        logger.info("   Question: %s", question)

    def test_enhancement_preserves_file_context(self):
        """Test that enhancement keeps appended file context intact."""
        # Simulate a question with file context
        base_question = "Calculate 2^8"
        file_context = "File 1: data.csv (CSV format), 1024 bytes\nCSV Data: numbers,values\n1,2\n3,4"
        question_with_files = f"{base_question}\n\nFile Context:\n{file_context}"

        enhanced = self.enhancer.enhance_prompt_for_exponentiation(question_with_files)

        # Verify enhancement occurred
        assert enhanced != question_with_files
        assert len(enhanced) > len(question_with_files)

        # Verify file context is preserved
        assert file_context in enhanced

        # Verify exponentiation guidance is added
        assert self._has_exponentiation_guidance(enhanced), (
            "Enhanced prompt with file context lacks Python exponentiation guidance"
        )

        logger.info("✅ Enhancement preserves file context")
        logger.info("   Original length: %s", len(question_with_files))
        logger.info("   Enhanced length: %s", len(enhanced))


if __name__ == "__main__":
    # Run this module's tests directly: "-v" for verbose output, "-s" to
    # disable output capturing. pytest.main() returns the exit code rather
    # than exiting, so propagate it; otherwise the script would exit 0 even
    # when tests fail.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))