# Page metadata captured with this file (not part of the script):
#   Spaces: Sleeping
#   File size: 5,368 Bytes
#   665cc97
#!/usr/bin/env python3
"""Test script to verify cost tracking is working properly."""
import logging
import json
from unittest.mock import Mock, patch
from src.services.cost_tracker import CostTracker
from src.agents.unique_indices_combinator import UniqueIndicesCombinator
from src.agents.unique_indices_loop_agent import UniqueIndicesLoopAgent
from src.config.settings import settings
# Configure logging: set the root logger to INFO so the progress messages
# emitted by the test below are not filtered out.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def test_cost_tracking():
    """Check that token usage from mocked LLM calls lands in a shared CostTracker.

    A single ``CostTracker`` is handed to both agents through
    ``ctx["cost_tracker"]``. ``openai.responses.create`` is patched with canned
    replies, so the token counts asserted below are exactly the ones set on
    the mocks:

    1. ``UniqueIndicesCombinator.execute`` -- expected to record one call
       (1500 input / 300 output tokens) described as
       "Unique Indices Combination Extraction".
    2. ``UniqueIndicesLoopAgent.execute`` -- expected to record one call per
       entry in ``ctx["results"]`` (two entries, 800 / 150 tokens each).
    3. ``CostTracker.get_detailed_costs_table`` -- expected to hold one row
       per recorded call plus a total row, with the four cost columns.

    Raises:
        AssertionError: If any of the expectations above is not met.
    """

    def fake_llm_reply(payload, tokens_in, tokens_out):
        """Build a Mock exposing ``output[0].content[0].text`` and ``usage``."""
        text_part = Mock(text=json.dumps(payload))
        message = Mock(content=[text_part])
        token_usage = Mock(input_tokens=tokens_in, output_tokens=tokens_out)
        return Mock(output=[message], usage=token_usage)

    # One tracker shared by both agents via the context dict.
    tracker = CostTracker()

    # Descriptions cover only some of the listed indices/fields, as in the
    # fixture this test has always used.
    index_specs = {
        "Protein Lot": {
            "description": "Protein lot identifier",
            "format": "String",
            "examples": "P066_L14_H31_0-hulgG-LALAPG-FJB",
            "possible_values": "",
        },
        "Peptide": {
            "description": "Peptide sequence",
            "format": "String",
            "examples": "QVQLQQSGPGLVQPSQSLSITCTVSDFSLAR",
            "possible_values": "",
        },
    }
    field_specs = {
        "Chain": {
            "description": "Heavy or Light chain",
            "format": "String",
            "examples": "Heavy",
            "possible_values": "Heavy, Light",
        },
    }
    ctx = {
        "text": "This is a test document with some content.",
        "unique_indices": ["Protein Lot", "Peptide", "Timepoint", "Modification"],
        "unique_indices_descriptions": index_specs,
        "fields": ["Chain", "Percentage", "Seq Loc"],
        "field_descriptions": field_specs,
        "document_context": "Biotech document",
        "cost_tracker": tracker,
    }

    # Canned LLM payloads: two combinations that differ only in "Timepoint".
    canned_combinations = [
        {
            "Protein Lot": "P066_L14_H31_0-hulgG-LALAPG-FJB",
            "Peptide": "PLTFGAGTK",
            "Timepoint": timepoint,
            "Modification": "Clipping",
        }
        for timepoint in ("0w", "4w")
    ]
    canned_extra_fields = {
        "Chain": "Heavy",
        "Percentage": "90.0",
        "Seq Loc": "HC(1-31)",
    }

    # --- Stage 1: UniqueIndicesCombinator -------------------------------
    logger.info("Testing UniqueIndicesCombinator cost tracking...")
    with patch('openai.responses.create') as create_stub:
        create_stub.return_value = fake_llm_reply(canned_combinations, 1500, 300)
        combinator_result = UniqueIndicesCombinator().execute(ctx)

    logger.info(f"Combinator result: {combinator_result}")
    logger.info(f"Cost tracker after combinator:")
    logger.info(f" Input tokens: {tracker.llm_input_tokens}")
    logger.info(f" Output tokens: {tracker.llm_output_tokens}")
    logger.info(f" LLM calls: {len(tracker.llm_calls)}")

    # The single mocked call must be fully reflected in the tracker.
    assert tracker.llm_input_tokens == 1500
    assert tracker.llm_output_tokens == 300
    assert len(tracker.llm_calls) == 1
    first_call = tracker.llm_calls[0]
    assert first_call.description == "Unique Indices Combination Extraction"

    # --- Stage 2: UniqueIndicesLoopAgent --------------------------------
    logger.info("Testing UniqueIndicesLoopAgent cost tracking...")
    # Feed the loop agent the combinations as if the combinator produced them.
    ctx["results"] = canned_combinations
    with patch('openai.responses.create') as create_stub:
        # Same reply for every call; one call is expected per combination.
        create_stub.return_value = fake_llm_reply(canned_extra_fields, 800, 150)
        loop_result = UniqueIndicesLoopAgent().execute(ctx)

    logger.info(f"Loop agent result: {loop_result}")
    logger.info(f"Cost tracker after loop agent:")
    logger.info(f" Input tokens: {tracker.llm_input_tokens}")
    logger.info(f" Output tokens: {tracker.llm_output_tokens}")
    logger.info(f" LLM calls: {len(tracker.llm_calls)}")

    # Totals are cumulative: the combinator call plus two loop iterations.
    loop_iterations = 2
    assert tracker.llm_input_tokens == 1500 + 800 * loop_iterations
    assert tracker.llm_output_tokens == 300 + 150 * loop_iterations
    assert len(tracker.llm_calls) == 1 + loop_iterations

    # --- Stage 3: detailed costs table ----------------------------------
    logger.info("Testing detailed costs table...")
    costs_df = tracker.get_detailed_costs_table()
    logger.info(f"Costs table:\n{costs_df}")

    # Three recorded calls plus one total row.
    assert len(costs_df) == 4
    for column in ("Description", "Input Tokens", "Output Tokens", "Total Cost"):
        assert column in costs_df.columns

    logger.info("All cost tracking tests passed!")
# Allow running this check directly as a script, outside a test runner.
if __name__ == "__main__":
    test_cost_tracking()