#!/usr/bin/env python3
"""Test script to verify that cost tracking works properly."""
import logging | |
import json | |
from unittest.mock import Mock, patch | |
from src.services.cost_tracker import CostTracker | |
from src.agents.unique_indices_combinator import UniqueIndicesCombinator | |
from src.agents.unique_indices_loop_agent import UniqueIndicesLoopAgent | |
from src.config.settings import settings | |
# Configure logging: request INFO level on the root logger so the progress
# messages emitted by this script are shown, then create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _mock_llm_response(payload, input_tokens, output_tokens):
    """Build a Mock standing in for the patched ``openai.responses.create`` result.

    The JSON-encoded *payload* is exposed at ``output[0].content[0].text`` and
    the token counts at ``usage.input_tokens`` / ``usage.output_tokens``.
    """
    return Mock(
        output=[Mock(content=[Mock(text=json.dumps(payload))])],
        usage=Mock(input_tokens=input_tokens, output_tokens=output_tokens),
    )


def _log_tracker_totals(cost_tracker, stage):
    """Log the running token totals and LLM call count after *stage*."""
    logger.info("Cost tracker after %s:", stage)
    logger.info(" Input tokens: %s", cost_tracker.llm_input_tokens)
    logger.info(" Output tokens: %s", cost_tracker.llm_output_tokens)
    logger.info(" LLM calls: %s", len(cost_tracker.llm_calls))


def test_cost_tracking():
    """Test that cost tracking works properly with the new agents.

    Runs ``UniqueIndicesCombinator`` and then ``UniqueIndicesLoopAgent``
    against a shared context whose ``"cost_tracker"`` entry is a fresh
    ``CostTracker``, with ``openai.responses.create`` patched to return canned
    responses. After each agent the tracker's cumulative token totals and
    call list are checked, and finally the shape of the detailed costs table.

    Raises:
        AssertionError: If any tracked total, call count, call description,
            or costs-table column differs from the expected value.
    """
    # Create a cost tracker
    cost_tracker = CostTracker()

    # Create mock context. Only two of the four unique indices and one of the
    # three fields carry a description entry here.
    ctx = {
        "text": "This is a test document with some content.",
        "unique_indices": ["Protein Lot", "Peptide", "Timepoint", "Modification"],
        "unique_indices_descriptions": {
            "Protein Lot": {
                "description": "Protein lot identifier",
                "format": "String",
                "examples": "P066_L14_H31_0-hulgG-LALAPG-FJB",
                "possible_values": "",
            },
            "Peptide": {
                "description": "Peptide sequence",
                "format": "String",
                "examples": "QVQLQQSGPGLVQPSQSLSITCTVSDFSLAR",
                "possible_values": "",
            },
        },
        "fields": ["Chain", "Percentage", "Seq Loc"],
        "field_descriptions": {
            "Chain": {
                "description": "Heavy or Light chain",
                "format": "String",
                "examples": "Heavy",
                "possible_values": "Heavy, Light",
            },
        },
        "document_context": "Biotech document",
        "cost_tracker": cost_tracker,
    }

    # Mock LLM responses
    mock_combinations = [
        {
            "Protein Lot": "P066_L14_H31_0-hulgG-LALAPG-FJB",
            "Peptide": "PLTFGAGTK",
            "Timepoint": "0w",
            "Modification": "Clipping",
        },
        {
            "Protein Lot": "P066_L14_H31_0-hulgG-LALAPG-FJB",
            "Peptide": "PLTFGAGTK",
            "Timepoint": "4w",
            "Modification": "Clipping",
        },
    ]
    mock_additional_fields = {
        "Chain": "Heavy",
        "Percentage": "90.0",
        "Seq Loc": "HC(1-31)",
    }

    # Test UniqueIndicesCombinator
    logger.info("Testing UniqueIndicesCombinator cost tracking...")
    with patch('openai.responses.create') as mock_create:
        # Mock the LLM response for combinations
        mock_create.return_value = _mock_llm_response(
            mock_combinations, input_tokens=1500, output_tokens=300
        )
        combinator = UniqueIndicesCombinator()
        result = combinator.execute(ctx)
    logger.info("Combinator result: %s", result)
    _log_tracker_totals(cost_tracker, "combinator")

    # Verify cost tracking worked
    assert cost_tracker.llm_input_tokens == 1500
    assert cost_tracker.llm_output_tokens == 300
    assert len(cost_tracker.llm_calls) == 1
    assert cost_tracker.llm_calls[0].description == "Unique Indices Combination Extraction"

    # Test UniqueIndicesLoopAgent
    logger.info("Testing UniqueIndicesLoopAgent cost tracking...")
    # Set the results from combinator
    ctx["results"] = mock_combinations
    with patch('openai.responses.create') as mock_create:
        # Mock the LLM response for additional fields (expected to be used
        # twice, once for each combination)
        mock_create.return_value = _mock_llm_response(
            mock_additional_fields, input_tokens=800, output_tokens=150
        )
        loop_agent = UniqueIndicesLoopAgent()
        result = loop_agent.execute(ctx)
    logger.info("Loop agent result: %s", result)
    _log_tracker_totals(cost_tracker, "loop agent")

    # Verify cost tracking worked for both calls; totals are cumulative
    assert cost_tracker.llm_input_tokens == 1500 + (800 * 2)  # Combinator + 2 loop iterations
    assert cost_tracker.llm_output_tokens == 300 + (150 * 2)  # Combinator + 2 loop iterations
    assert len(cost_tracker.llm_calls) == 3  # 1 combinator + 2 loop iterations

    # Test detailed costs table
    logger.info("Testing detailed costs table...")
    costs_df = cost_tracker.get_detailed_costs_table()
    logger.info("Costs table:\n%s", costs_df)

    # Verify the table has the expected structure
    assert len(costs_df) == 4  # 3 calls + 1 total row
    assert "Description" in costs_df.columns
    assert "Input Tokens" in costs_df.columns
    assert "Output Tokens" in costs_df.columns
    assert "Total Cost" in costs_df.columns

    logger.info("All cost tracking tests passed!")
if __name__ == "__main__":
    # Allow running the check directly as a script, without a test runner.
    test_cost_tracking()