doctorecord / test_cost_tracking.py
levalencia's picture
feat: enhance architecture and developer documentation for clarity and detail
665cc97
raw
history blame
5.37 kB
#!/usr/bin/env python3
"""Test script to verify cost tracking is working properly."""
import logging
import json
from unittest.mock import Mock, patch
from src.services.cost_tracker import CostTracker
from src.agents.unique_indices_combinator import UniqueIndicesCombinator
from src.agents.unique_indices_loop_agent import UniqueIndicesLoopAgent
from src.config.settings import settings
# Module-level logger for this script. The root logger is configured at INFO
# so the logger.info(...) progress messages emitted by the test are not
# filtered out by the default WARNING threshold.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def _mock_llm_response(payload, input_tokens, output_tokens):
    """Build a stand-in for the object returned by ``openai.responses.create``.

    The JSON-encoded *payload* is exposed at ``output[0].content[0].text`` and
    the token counts at ``usage.input_tokens`` / ``usage.output_tokens``.
    NOTE(review): presumably these are the attributes the agents and the cost
    tracker read; confirm against the agent implementations.
    """
    return Mock(
        output=[Mock(content=[Mock(text=json.dumps(payload))])],
        usage=Mock(input_tokens=input_tokens, output_tokens=output_tokens),
    )


def _log_tracker_totals(cost_tracker, stage):
    """Log the tracker's cumulative token counts and LLM call count after *stage*."""
    logger.info("Cost tracker after %s:", stage)
    logger.info(" Input tokens: %s", cost_tracker.llm_input_tokens)
    logger.info(" Output tokens: %s", cost_tracker.llm_output_tokens)
    logger.info(" LLM calls: %s", len(cost_tracker.llm_calls))


def test_cost_tracking():
    """Test that cost tracking works properly with the new agents.

    Runs ``UniqueIndicesCombinator`` and then ``UniqueIndicesLoopAgent`` against
    one shared context while ``openai.responses.create`` is patched to return
    canned responses with fixed token usage. After each agent, the
    ``CostTracker`` stored under ``ctx["cost_tracker"]`` is checked for the
    expected cumulative input tokens, output tokens and number of recorded
    calls. Finally the detailed costs table is checked for its row count and
    required columns.

    Raises:
        AssertionError: if any tracked total, call description, or costs-table
            property differs from the expected value.
    """
    # Token usage reported by the mocked LLM for each kind of call.
    combinator_input_tokens, combinator_output_tokens = 1500, 300
    loop_input_tokens, loop_output_tokens = 800, 150

    # Create a cost tracker
    cost_tracker = CostTracker()

    # Create mock context
    ctx = {
        "text": "This is a test document with some content.",
        "unique_indices": ["Protein Lot", "Peptide", "Timepoint", "Modification"],
        "unique_indices_descriptions": {
            "Protein Lot": {
                "description": "Protein lot identifier",
                "format": "String",
                "examples": "P066_L14_H31_0-hulgG-LALAPG-FJB",
                "possible_values": "",
            },
            "Peptide": {
                "description": "Peptide sequence",
                "format": "String",
                "examples": "QVQLQQSGPGLVQPSQSLSITCTVSDFSLAR",
                "possible_values": "",
            },
        },
        "fields": ["Chain", "Percentage", "Seq Loc"],
        "field_descriptions": {
            "Chain": {
                "description": "Heavy or Light chain",
                "format": "String",
                "examples": "Heavy",
                "possible_values": "Heavy, Light",
            },
        },
        "document_context": "Biotech document",
        "cost_tracker": cost_tracker,
    }

    # Mock LLM responses
    mock_combinations = [
        {
            "Protein Lot": "P066_L14_H31_0-hulgG-LALAPG-FJB",
            "Peptide": "PLTFGAGTK",
            "Timepoint": "0w",
            "Modification": "Clipping",
        },
        {
            "Protein Lot": "P066_L14_H31_0-hulgG-LALAPG-FJB",
            "Peptide": "PLTFGAGTK",
            "Timepoint": "4w",
            "Modification": "Clipping",
        },
    ]
    mock_additional_fields = {
        "Chain": "Heavy",
        "Percentage": "90.0",
        "Seq Loc": "HC(1-31)",
    }

    # Test UniqueIndicesCombinator
    logger.info("Testing UniqueIndicesCombinator cost tracking...")
    combinations_response = _mock_llm_response(
        mock_combinations, combinator_input_tokens, combinator_output_tokens
    )
    # Only the agent construction and run need the patched LLM entry point.
    with patch('openai.responses.create', return_value=combinations_response):
        combinator = UniqueIndicesCombinator()
        result = combinator.execute(ctx)

    logger.info("Combinator result: %s", result)
    _log_tracker_totals(cost_tracker, "combinator")

    # Verify cost tracking worked
    assert cost_tracker.llm_input_tokens == combinator_input_tokens
    assert cost_tracker.llm_output_tokens == combinator_output_tokens
    assert len(cost_tracker.llm_calls) == 1
    assert cost_tracker.llm_calls[0].description == "Unique Indices Combination Extraction"

    # Test UniqueIndicesLoopAgent
    logger.info("Testing UniqueIndicesLoopAgent cost tracking...")

    # Set the results from combinator
    ctx["results"] = mock_combinations

    # The same canned response is returned on every call; the loop agent is
    # expected to call the LLM once per combination.
    fields_response = _mock_llm_response(
        mock_additional_fields, loop_input_tokens, loop_output_tokens
    )
    with patch('openai.responses.create', return_value=fields_response):
        loop_agent = UniqueIndicesLoopAgent()
        result = loop_agent.execute(ctx)

    logger.info("Loop agent result: %s", result)
    _log_tracker_totals(cost_tracker, "loop agent")

    # Verify cost tracking worked for every call: 1 combinator call plus one
    # loop iteration per combination.
    loop_calls = len(mock_combinations)
    assert cost_tracker.llm_input_tokens == (
        combinator_input_tokens + loop_input_tokens * loop_calls
    )
    assert cost_tracker.llm_output_tokens == (
        combinator_output_tokens + loop_output_tokens * loop_calls
    )
    assert len(cost_tracker.llm_calls) == 1 + loop_calls

    # Test detailed costs table
    logger.info("Testing detailed costs table...")
    costs_df = cost_tracker.get_detailed_costs_table()
    logger.info("Costs table:\n%s", costs_df)

    # Verify the table has the expected structure
    assert len(costs_df) == 1 + loop_calls + 1  # one row per call + 1 total row
    for column in ("Description", "Input Tokens", "Output Tokens", "Total Cost"):
        assert column in costs_df.columns

    logger.info("All cost tracking tests passed!")
if __name__ == "__main__":
    # Direct invocation (python test_cost_tracking.py): run the check without
    # needing a test runner.
    test_cost_tracking()