# gaia-enhanced-agent/tests/test_calculator_prompt_enhancer.py
# GAIA Agent Deployment
# Deploy Complete Enhanced GAIA Agent with Phase 1-6 Improvements
# 9a6a4dc
"""
Test Calculator Prompt Enhancer - TDD Implementation
Tests the prompt enhancement functionality for exponentiation operations.
"""
import pytest
import sys
import os
import logging
# Add the parent of this tests directory to sys.path so that `utils` is importable
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from utils.calculator_prompt_enhancer import CalculatorPromptEnhancer
logger = logging.getLogger(__name__)
class TestCalculatorPromptEnhancer:
    """Test suite for calculator prompt enhancer.

    Each test walks a small table of inputs, calls one method of a fresh
    ``CalculatorPromptEnhancer`` and asserts on the returned value. A
    success line is logged for every table row that passes.
    """

    @pytest.fixture(autouse=True)
    def setup_method(self):
        """Set up test fixtures."""
        # autouse=True: pytest applies this fixture to every test in the
        # class, so each test gets its own enhancer instance.
        self.enhancer = CalculatorPromptEnhancer()

    def test_detect_exponentiation_patterns(self):
        """Test detection of various exponentiation patterns.

        ``detect_exponentiation`` must return True for the power-style
        phrasings below and False for the other arithmetic questions.
        """
        test_cases = [
            # Should detect exponentiation
            ("Calculate 2^8", True),
            ("What is 2**8?", True),
            ("2 to the power of 8", True),
            ("Compute 3 to the power of 4", True),
            ("What is 2 raised to 8?", True),
            ("Calculate power(2, 8)", True),
            ("Use pow(3, 4)", True),
            ("What is 5 squared?", True),
            ("Calculate 2 cubed", True),
            # Should NOT detect exponentiation
            ("Calculate 25 * 17", False),
            ("What is 144 divided by 12?", False),
            ("Add 100 and 50", False),
            ("Subtract 75 from 200", False),
            ("What is the square root of 16?", False),
        ]
        for question, expected in test_cases:
            result = self.enhancer.detect_exponentiation(question)
            assert result == expected, (
                f"Failed for '{question}': expected {expected}, got {result}"
            )
            logger.info("✅ Detection test passed: '%s' → %s", question, result)

    def test_extract_exponentiation_components(self):
        """Test extraction of base and exponent from questions.

        The returned mapping must be non-None and carry matching ``base``,
        ``exponent`` and ``expected_result`` entries.
        """
        test_cases = [
            ("Calculate 2^8", {'base': 2, 'exponent': 8, 'expected_result': 256}),
            ("What is 3**4?", {'base': 3, 'exponent': 4, 'expected_result': 81}),
            ("2 to the power of 8", {'base': 2, 'exponent': 8, 'expected_result': 256}),
            ("5 raised to 3", {'base': 5, 'exponent': 3, 'expected_result': 125}),
        ]
        for question, expected in test_cases:
            result = self.enhancer.extract_exponentiation_components(question)
            assert result is not None, f"Failed to extract components from '{question}'"
            assert result['base'] == expected['base'], f"Base mismatch for '{question}'"
            assert result['exponent'] == expected['exponent'], (
                f"Exponent mismatch for '{question}'"
            )
            assert result['expected_result'] == expected['expected_result'], (
                f"Expected result mismatch for '{question}'"
            )
            logger.info(
                "✅ Extraction test passed: '%s' → %s^%s = %s",
                question,
                result['base'],
                result['exponent'],
                result['expected_result'],
            )

    def test_enhance_prompt_for_exponentiation(self):
        """Test prompt enhancement for exponentiation questions.

        The enhanced prompt must be longer than the question, mention
        "Python" and "**", and still contain the original question text.
        """
        test_cases = [
            "Calculate 2^8",
            "What is 3 to the power of 4?",
            "Compute 5**2",
        ]
        for question in test_cases:
            enhanced = self.enhancer.enhance_prompt_for_exponentiation(question)
            # Check that enhancement occurred
            assert len(enhanced) > len(question), f"Prompt not enhanced for '{question}'"
            assert "Python" in enhanced, (
                f"Enhanced prompt should mention Python for '{question}'"
            )
            assert "**" in enhanced, (
                f"Enhanced prompt should mention ** operator for '{question}'"
            )
            assert question in enhanced, (
                f"Original question should be preserved in '{question}'"
            )
            logger.info("✅ Enhancement test passed: '%s'", question)
            logger.info(
                " Enhanced length: %s vs original: %s", len(enhanced), len(question)
            )

    def test_non_exponentiation_questions_unchanged(self):
        """Test that non-exponentiation questions are not enhanced."""
        test_cases = [
            "Calculate 25 * 17",
            "What is 144 divided by 12?",
            "Add 100 and 50",
        ]
        for question in test_cases:
            enhanced = self.enhancer.enhance_prompt_for_exponentiation(question)
            assert enhanced == question, (
                f"Non-exponentiation question should not be enhanced: '{question}'"
            )
            logger.info("✅ Non-enhancement test passed: '%s'", question)

    def test_validate_exponentiation_result(self):
        """Test validation of exponentiation results.

        The validation mapping must always have a ``valid`` key equal to the
        expected flag; for answers expected to be rejected it must also have
        ``expected`` and ``actual`` keys.
        """
        test_cases = [
            # Correct results
            ("Calculate 2^8", "256", True),
            ("What is 3**4?", "The answer is 81", True),
            ("2 to the power of 8", "Result: 256", True),
            # Incorrect results
            ("Calculate 2^8", "16", False),  # This is 2*8, not 2^8
            ("What is 3**4?", "12", False),  # This is 3*4, not 3^4
            ("2 to the power of 8", "128", False),  # Wrong result
        ]
        for question, result, expected_valid in test_cases:
            validation = self.enhancer.validate_exponentiation_result(question, result)
            assert 'valid' in validation, (
                f"Validation should include 'valid' key for '{question}'"
            )
            assert validation['valid'] == expected_valid, (
                f"Validation failed for '{question}' with result '{result}'"
            )
            if expected_valid:
                logger.info(
                    "✅ Validation test passed (correct): '%s' → '%s'",
                    question,
                    result,
                )
                continue
            # Check the diagnostic keys before reporting success, so a
            # "passed" line is never logged for a case that then fails.
            assert 'expected' in validation, (
                "Should include expected result for incorrect answer"
            )
            assert 'actual' in validation, (
                "Should include actual result for incorrect answer"
            )
            logger.info(
                "✅ Validation test passed (incorrect detected): '%s' → '%s'",
                question,
                result,
            )

    def test_create_python_calculation_prompt(self):
        """Test creation of Python calculation prompts.

        The prompt built from ``(base, exponent)`` must contain the base,
        the exponent, the expected power, the "**" operator and "Python".
        """
        test_cases = [
            (2, 8, 256),
            (3, 4, 81),
            (5, 3, 125),
        ]
        for base, exponent, expected_result in test_cases:
            prompt = self.enhancer.create_python_calculation_prompt(base, exponent)
            # Check that prompt contains necessary elements
            assert str(base) in prompt, f"Prompt should contain base {base}"
            assert str(exponent) in prompt, f"Prompt should contain exponent {exponent}"
            assert str(expected_result) in prompt, (
                f"Prompt should contain expected result {expected_result}"
            )
            assert "**" in prompt, "Prompt should contain ** operator"
            assert "Python" in prompt, "Prompt should mention Python"
            logger.info(
                "✅ Python prompt test passed: %s^%s = %s",
                base,
                exponent,
                expected_result,
            )
if __name__ == "__main__":
    # Run the prompt enhancer tests when this file is executed directly.
    # pytest.main returns the run's exit code; hand it to sys.exit so a
    # failing run gives the shell a non-zero status instead of always 0.
    sys.exit(pytest.main([__file__, "-v", "-s"]))