# causal-agent/tests/auto_causal/test_components/test_query_interpreter.py
import pytest
from unittest.mock import patch, MagicMock
from auto_causal.components.query_interpreter import interpret_query
from auto_causal.models import LLMTreatmentReferenceLevel

# Basic mock data setup
MOCK_QUERY_INFO_REF_LEVEL = {
"query_text": "What is the effect of different fertilizers (Nitro, Phos, Control) on crop_yield, using Control as the baseline?",
"potential_treatments": ["fertilizer_type"],
"outcome_hints": ["crop_yield"],
"covariates_hints": ["soil_ph", "rainfall"]
}

MOCK_DATASET_ANALYSIS_REF_LEVEL = {
"columns": ["fertilizer_type", "crop_yield", "soil_ph", "rainfall"],
"column_categories": {
"fertilizer_type": "categorical_multi", # Assuming a category type for multi-level
"crop_yield": "continuous_numeric",
"soil_ph": "continuous_numeric",
"rainfall": "continuous_numeric"
},
"potential_treatments": ["fertilizer_type"],
"potential_outcomes": ["crop_yield"],
"value_counts": { # Added for providing unique values to the prompt
"fertilizer_type": {
"values": ["Nitro", "Phos", "Control"]
}
},
"columns_data_preview": { # Fallback if value_counts isn't structured as expected
"fertilizer_type": ["Nitro", "Phos", "Control", "Nitro", "Control"]
}
# Add other necessary fields from DatasetAnalysis model if interpret_query uses them
}

MOCK_DATASET_DESCRIPTION_REF_LEVEL = "A dataset from an agricultural experiment."
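
# Variant query fixture (a sketch, used by the negative-path test at the bottom
# of this file): same dataset, but the query text gives no baseline hint, so the
# simulated reference-level call should come back empty. The field layout
# mirrors MOCK_QUERY_INFO_REF_LEVEL above.
MOCK_QUERY_INFO_NO_REF_LEVEL = {
    "query_text": "What is the effect of fertilizer type on crop_yield?",
    "potential_treatments": ["fertilizer_type"],
    "outcome_hints": ["crop_yield"],
    "covariates_hints": ["soil_ph", "rainfall"]
}
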
def test_interpret_query_identifies_treatment_reference_level():
"""
Test that interpret_query correctly identifies and returns the treatment_reference_level
when the LLM simulation provides one.
"""
    # Mock the LLM client handed to interpret_query via get_llm_client.
    mock_llm_instance = MagicMock()

    # interpret_query makes several internal LLM calls (treatment, outcome,
    # covariates, IV, RDD, RCT, plus the reference-level identification driven
    # by TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE). The router below
    # dispatches on the Pydantic model or the prompt text: the reference-level
    # call returns "Control" and the rest return benign defaults.
def mock_llm_call_router(*args, **kwargs):
        # _call_llm_for_var is called positionally: args[0] is the llm instance,
        # args[1] the prompt string, args[2] the Pydantic model for structured
        # output.
        pydantic_model_passed = args[2]
if pydantic_model_passed == LLMTreatmentReferenceLevel:
return LLMTreatmentReferenceLevel(reference_level="Control", reasoning="Identified from query text.")
        # Benign defaults for the remaining identification calls (treatment,
        # outcome, covariates, IV, RDD, RCT) so interpret_query can proceed:
elif "most likely treatment variable" in args[1]: # Simplified check for treatment prompt
return MagicMock(variable_name="fertilizer_type")
elif "most likely outcome variable" in args[1]: # Simplified check for outcome prompt
return MagicMock(variable_name="crop_yield")
elif "valid covariates" in args[1]: # Simplified check for covariates prompt
return MagicMock(covariates=["soil_ph", "rainfall"])
elif "Instrumental Variables" in args[1]: # Check for IV prompt
return MagicMock(instrument_variable=None)
elif "Regression Discontinuity Design" in args[1]: # Check for RDD prompt
return MagicMock(running_variable=None, cutoff_value=None)
elif "Randomized Controlled Trial" in args[1]: # Check for RCT prompt
return MagicMock(is_rct=False, reasoning="No indication of RCT.")
return MagicMock() # Default mock for other calls
# Patch _call_llm_for_var which is used internally by interpret_query's helpers
with patch('auto_causal.components.query_interpreter._call_llm_for_var', side_effect=mock_llm_call_router) as mock_llm_call:
        # Patch get_llm_client so interpret_query receives our mock instead of
        # constructing a real LLM client (_call_llm_for_var itself is patched
        # above, so the instance is never actually invoked).
        with patch('auto_causal.components.query_interpreter.get_llm_client', return_value=mock_llm_instance):
result = interpret_query(
query_info=MOCK_QUERY_INFO_REF_LEVEL,
dataset_analysis=MOCK_DATASET_ANALYSIS_REF_LEVEL,
dataset_description=MOCK_DATASET_DESCRIPTION_REF_LEVEL
)
assert "treatment_reference_level" in result, "treatment_reference_level should be in the result"
assert result["treatment_reference_level"] == "Control", "Incorrect treatment_reference_level identified"
            # Verify the reference-level LLM call was actually made by
            # inspecting the calls recorded on mock_llm_call.
found_ref_level_call = False
for call_args in mock_llm_call.call_args_list:
# call_args is a tuple; call_args[0] contains positional args, call_args[1] has kwargs
# The third positional argument to _call_llm_for_var is the pydantic_model
if len(call_args[0]) >= 3 and call_args[0][2] == LLMTreatmentReferenceLevel:
found_ref_level_call = True
# Optionally, check the prompt content here too if needed
# prompt_content = call_args[0][1]
# assert "using Control as the baseline" in prompt_content
break
assert found_ref_level_call, "LLM call for treatment reference level was not made."
# Basic checks for other essential variables (assuming they are mocked simply)
assert result["treatment_variable"] == "fertilizer_type"
assert result["outcome_variable"] == "crop_yield"
assert result["is_rct"] is False # Based on mock