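"""Unit tests for DecisionTreeLLMEngine.select_method.

The LangChain chat model is mocked throughout, so no real LLM is called. The
tests cover method selection for RCT and observational settings, fallback to
correlation analysis when the LLM is missing or returns bad output, and
parsing of JSON responses wrapped in code fences.
"""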
import unittest
from unittest.mock import patch, MagicMock
import json

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, AIMessage

from auto_causal.components.decision_tree_llm import DecisionTreeLLMEngine
from auto_causal.components.decision_tree import (
    METHOD_ASSUMPTIONS,
    CORRELATION_ANALYSIS,
    DIFF_IN_DIFF,
    INSTRUMENTAL_VARIABLE,
    LINEAR_REGRESSION,
    PROPENSITY_SCORE_MATCHING,
    REGRESSION_DISCONTINUITY,
    DIFF_IN_MEANS
)

class TestDecisionTreeLLMEngine(unittest.TestCase):

    def setUp(self):
        self.engine = DecisionTreeLLMEngine(verbose=False)
        self.mock_dataset_analysis = {
            "temporal_structure": {"has_temporal_structure": True, "time_variables": ["year"]},
            "potential_instruments": ["Z1"],
            "running_variable_analysis": {"is_candidate": False}
        }
        self.mock_variables = {
            "treatment_variable": "T",
            "outcome_variable": "Y",
            "covariates": ["X1", "X2"],
            "time_variable": "year",
            "instrument_variable": "Z1",
            "treatment_variable_type": "binary"
        }
        self.mock_llm = MagicMock(spec=BaseChatModel)

    def _create_mock_llm_response(self, response_dict):
        """Make the mock LLM return `response_dict` serialised as a JSON AIMessage."""
        ai_message = AIMessage(content=json.dumps(response_dict))
        self.mock_llm.invoke = MagicMock(return_value=ai_message)

    def _create_mock_llm_raw_response(self, raw_content_str):
        """Make the mock LLM return `raw_content_str` verbatim as an AIMessage."""
        ai_message = AIMessage(content=raw_content_str)
        self.mock_llm.invoke = MagicMock(return_value=ai_message)

    def test_select_method_rct_no_covariates_llm_selects_diff_in_means(self):
        self._create_mock_llm_response({
            "selected_method": DIFF_IN_MEANS,
            "method_justification": "LLM: RCT with no covariates, DiM is appropriate.",
            "alternative_methods": []
        })
        rct_variables = self.mock_variables.copy()
        rct_variables["covariates"] = []
        result = self.engine.select_method(
            self.mock_dataset_analysis, rct_variables, is_rct=True, llm=self.mock_llm
        )
        self.assertEqual(result["selected_method"], DIFF_IN_MEANS)
        self.assertEqual(result["method_justification"], "LLM: RCT with no covariates, DiM is appropriate.")
        self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[DIFF_IN_MEANS])
        self.mock_llm.invoke.assert_called_once()

    def test_select_method_rct_with_covariates_llm_selects_linear_regression(self):
        self._create_mock_llm_response({
            "selected_method": LINEAR_REGRESSION,
            "method_justification": "LLM: RCT with covariates, Linear Regression for precision.",
            "alternative_methods": []
        })
        result = self.engine.select_method(
            self.mock_dataset_analysis, self.mock_variables, is_rct=True, llm=self.mock_llm
        )
        self.assertEqual(result["selected_method"], LINEAR_REGRESSION)
        self.assertEqual(result["method_justification"], "LLM: RCT with covariates, Linear Regression for precision.")
        self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[LINEAR_REGRESSION])

    def test_select_method_observational_temporal_llm_selects_did(self):
        self._create_mock_llm_response({
            "selected_method": DIFF_IN_DIFF,
            "method_justification": "LLM: Observational with temporal data, DiD selected.",
            "alternative_methods": [INSTRUMENTAL_VARIABLE]
        })
        result = self.engine.select_method(
            self.mock_dataset_analysis, self.mock_variables, is_rct=False, llm=self.mock_llm
        )
        self.assertEqual(result["selected_method"], DIFF_IN_DIFF)
        self.assertEqual(result["method_justification"], "LLM: Observational with temporal data, DiD selected.")
        self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[DIFF_IN_DIFF])
        self.assertEqual(result["alternative_methods"], [INSTRUMENTAL_VARIABLE])

    def test_select_method_observational_instrument_llm_selects_iv(self):
        # Modify the dataset analysis so it does not strongly suggest DiD
        no_temporal_analysis = self.mock_dataset_analysis.copy()
        no_temporal_analysis["temporal_structure"] = {"has_temporal_structure": False}
        self._create_mock_llm_response({
            "selected_method": INSTRUMENTAL_VARIABLE,
            "method_justification": "LLM: Observational with instrument, IV selected.",
            "alternative_methods": []
        })
        result = self.engine.select_method(
            no_temporal_analysis, self.mock_variables, is_rct=False, llm=self.mock_llm
        )
        self.assertEqual(result["selected_method"], INSTRUMENTAL_VARIABLE)
        self.assertEqual(result["method_justification"], "LLM: Observational with instrument, IV selected.")
        self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[INSTRUMENTAL_VARIABLE])

    def test_select_method_observational_running_var_llm_selects_rdd(self):
        rdd_analysis = self.mock_dataset_analysis.copy()
        rdd_analysis["temporal_structure"] = {"has_temporal_structure": False}  # Make DiD less likely
        rdd_variables = self.mock_variables.copy()
        rdd_variables["instrument_variable"] = None  # Make IV less likely
        rdd_variables["running_variable"] = "age"
        rdd_variables["cutoff_value"] = 65
        self._create_mock_llm_response({
            "selected_method": REGRESSION_DISCONTINUITY,
            "method_justification": "LLM: Running var and cutoff, RDD selected.",
            "alternative_methods": []
        })
        result = self.engine.select_method(
            rdd_analysis, rdd_variables, is_rct=False, llm=self.mock_llm
        )
        self.assertEqual(result["selected_method"], REGRESSION_DISCONTINUITY)
        self.assertEqual(result["method_justification"], "LLM: Running var and cutoff, RDD selected.")
        self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[REGRESSION_DISCONTINUITY])

    def test_select_method_observational_covariates_llm_selects_psm(self):
        psm_analysis = {"temporal_structure": {"has_temporal_structure": False}}
        psm_variables = {
            "treatment_variable": "T", "outcome_variable": "Y", "covariates": ["X1", "X2"],
            "treatment_variable_type": "binary"
        }
        self._create_mock_llm_response({
            "selected_method": PROPENSITY_SCORE_MATCHING,
            "method_justification": "LLM: Observational with covariates, PSM.",
            "alternative_methods": []
        })
        result = self.engine.select_method(
            psm_analysis, psm_variables, is_rct=False, llm=self.mock_llm
        )
        self.assertEqual(result["selected_method"], PROPENSITY_SCORE_MATCHING)
        self.assertEqual(result["method_justification"], "LLM: Observational with covariates, PSM.")
        self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[PROPENSITY_SCORE_MATCHING])

    def test_select_method_no_llm_provided_defaults_to_correlation(self):
        result = self.engine.select_method(
            self.mock_dataset_analysis, self.mock_variables, is_rct=False, llm=None
        )
        self.assertEqual(result["selected_method"], CORRELATION_ANALYSIS)
        self.assertIn("LLM client not provided", result["method_justification"])
        self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[CORRELATION_ANALYSIS])

    def test_select_method_llm_returns_malformed_json_defaults_to_correlation(self):
        self._create_mock_llm_raw_response("This is not a valid JSON")
        result = self.engine.select_method(
            self.mock_dataset_analysis, self.mock_variables, is_rct=False, llm=self.mock_llm
        )
        self.assertEqual(result["selected_method"], CORRELATION_ANALYSIS)
        self.assertIn("LLM response was not valid JSON", result["method_justification"])
        self.assertIn("This is not a valid JSON", result["method_justification"])
        self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[CORRELATION_ANALYSIS])

    def test_select_method_llm_returns_unknown_method_defaults_to_correlation(self):
        self._create_mock_llm_response({
            "selected_method": "SUPER_NOVEL_METHOD_X",
            "method_justification": "LLM thinks this is best.",
            "alternative_methods": []
        })
        result = self.engine.select_method(
            self.mock_dataset_analysis, self.mock_variables, is_rct=False, llm=self.mock_llm
        )
        self.assertEqual(result["selected_method"], CORRELATION_ANALYSIS)
        self.assertIn("LLM output was problematic (selected: SUPER_NOVEL_METHOD_X)", result["method_justification"])
        self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[CORRELATION_ANALYSIS])

    def test_select_method_llm_call_raises_exception_defaults_to_correlation(self):
        self.mock_llm.invoke = MagicMock(side_effect=Exception("LLM API Error"))
        result = self.engine.select_method(
            self.mock_dataset_analysis, self.mock_variables, is_rct=False, llm=self.mock_llm
        )
        self.assertEqual(result["selected_method"], CORRELATION_ANALYSIS)
        self.assertIn("An unexpected error occurred during LLM method selection.", result["method_justification"])
        self.assertIn("LLM API Error", result["method_justification"])
        self.assertEqual(result["method_assumptions"], METHOD_ASSUMPTIONS[CORRELATION_ANALYSIS])

    def test_prompt_construction_content(self):
        actual_prompt_generated = []  # Captures the prompt built by _construct_prompt
        # Keep a reference to the original (bound) method before patching it
        original_construct_prompt = self.engine._construct_prompt

        def side_effect_for_construct_prompt(dataset_analysis, variables, is_rct):
            # Delegate to the original method and record the prompt it produces
            prompt = original_construct_prompt(dataset_analysis, variables, is_rct)
            actual_prompt_generated.append(prompt)
            return prompt

        with patch.object(self.engine, '_construct_prompt', side_effect=side_effect_for_construct_prompt) as mock_construct_prompt:
            self._create_mock_llm_response({  # select_method still needs a mock LLM response to run
                "selected_method": DIFF_IN_DIFF, "method_justification": "Test", "alternative_methods": []
            })
            self.engine.select_method(self.mock_dataset_analysis, self.mock_variables, False, self.mock_llm)

        mock_construct_prompt.assert_called_once_with(self.mock_dataset_analysis, self.mock_variables, False)
        self.assertTrue(actual_prompt_generated, "Prompt was not generated or captured by side_effect")
        prompt_string = actual_prompt_generated[0]
        self.assertIn("You are an expert in causal inference.", prompt_string)
        self.assertIn(json.dumps(self.mock_dataset_analysis, indent=2), prompt_string)
        self.assertIn(json.dumps(self.mock_variables, indent=2), prompt_string)
        self.assertIn("Is the data from a Randomized Controlled Trial (RCT)? No", prompt_string)
        self.assertIn(f"- {DIFF_IN_DIFF}", prompt_string)  # Method descriptions should be listed
        self.assertIn(f"- {INSTRUMENTAL_VARIABLE}", prompt_string)
        self.assertIn("Output your final decision as a JSON object", prompt_string)

    def test_llm_response_with_triple_backticks_json(self):
        raw_response = """
        Some conversational text before the JSON.
        ```json
        {
            "selected_method": "difference_in_differences",
            "method_justification": "LLM reasoned and selected DiD.",
            "alternative_methods": ["instrumental_variable"]
        }
        ```
        And some text after.
        """
        self._create_mock_llm_raw_response(raw_response)
        result = self.engine.select_method(self.mock_dataset_analysis, self.mock_variables, False, self.mock_llm)
        self.assertEqual(result["selected_method"], DIFF_IN_DIFF)
        self.assertEqual(result["method_justification"], "LLM reasoned and selected DiD.")

    def test_llm_response_with_triple_backticks_only(self):
        raw_response = """
        ```
        {
            "selected_method": "difference_in_differences",
            "method_justification": "LLM reasoned and selected DiD with only triple backticks.",
            "alternative_methods": ["instrumental_variable"]
        }
        ```
        """
        self._create_mock_llm_raw_response(raw_response)
        result = self.engine.select_method(self.mock_dataset_analysis, self.mock_variables, False, self.mock_llm)
        self.assertEqual(result["selected_method"], DIFF_IN_DIFF)
        self.assertEqual(result["method_justification"], "LLM reasoned and selected DiD with only triple backticks.")

    def test_llm_response_plain_json(self):
        raw_response = """
        {
            "selected_method": "difference_in_differences",
            "method_justification": "LLM reasoned and selected DiD plain JSON.",
            "alternative_methods": ["instrumental_variable"]
        }
        """
        self._create_mock_llm_raw_response(raw_response)
        result = self.engine.select_method(self.mock_dataset_analysis, self.mock_variables, False, self.mock_llm)
        self.assertEqual(result["selected_method"], DIFF_IN_DIFF)
        self.assertEqual(result["method_justification"], "LLM reasoned and selected DiD plain JSON.")

if __name__ == '__main__':
    unittest.main()