import pytest
from unittest.mock import patch, MagicMock

from auto_causal.components.query_interpreter import interpret_query
from auto_causal.models import LLMTreatmentReferenceLevel

# Basic mock data setup
MOCK_QUERY_INFO_REF_LEVEL = {
    "query_text": "What is the effect of different fertilizers (Nitro, Phos, Control) on crop_yield, using Control as the baseline?",
    "potential_treatments": ["fertilizer_type"],
    "outcome_hints": ["crop_yield"],
    "covariates_hints": ["soil_ph", "rainfall"]
}

MOCK_DATASET_ANALYSIS_REF_LEVEL = {
    "columns": ["fertilizer_type", "crop_yield", "soil_ph", "rainfall"],
    "column_categories": {
        "fertilizer_type": "categorical_multi",  # Assuming a category type for multi-level
        "crop_yield": "continuous_numeric",
        "soil_ph": "continuous_numeric",
        "rainfall": "continuous_numeric"
    },
    "potential_treatments": ["fertilizer_type"],
    "potential_outcomes": ["crop_yield"],
    "value_counts": {  # Added for providing unique values to the prompt
        "fertilizer_type": {"values": ["Nitro", "Phos", "Control"]}
    },
    "columns_data_preview": {  # Fallback if value_counts isn't structured as expected
        "fertilizer_type": ["Nitro", "Phos", "Control", "Nitro", "Control"]
    }
    # Add other necessary fields from the DatasetAnalysis model if interpret_query uses them
}

MOCK_DATASET_DESCRIPTION_REF_LEVEL = "A dataset from an agricultural experiment."


def test_interpret_query_identifies_treatment_reference_level():
    """
    Test that interpret_query correctly identifies and returns the
    treatment_reference_level when the LLM simulation provides one.
    """
    # Mock the LLM client and its structured output
    mock_llm_instance = MagicMock()
    mock_structured_llm = MagicMock()
    # This will be the mock for the call related to
    # TREATMENT_REFERENCE_IDENTIFICATION_PROMPT_TEMPLATE. The other LLM calls
    # (for T, O, C, IV, RDD, RCT) also need to be considered, or made to return
    # benign defaults, for this specific test.
    # Simulate LLM responses for the different calls made within interpret_query
    def mock_llm_call_router(*args, **kwargs):
        # _call_llm_for_var is called positionally as:
        #   args[0] -> the llm instance
        #   args[1] -> the prompt string
        #   args[2] -> the Pydantic model for structured output
        pydantic_model_passed = args[2]
        if pydantic_model_passed == LLMTreatmentReferenceLevel:
            return LLMTreatmentReferenceLevel(
                reference_level="Control",
                reasoning="Identified from query text."
            )
        # Add mocks for other LLM calls if interpret_query strictly needs them to
        # proceed, e.g. for identifying treatment, outcome, covariates, IV, RDD, RCT:
        elif "most likely treatment variable" in args[1]:  # Simplified check for treatment prompt
            return MagicMock(variable_name="fertilizer_type")
        elif "most likely outcome variable" in args[1]:  # Simplified check for outcome prompt
            return MagicMock(variable_name="crop_yield")
        elif "valid covariates" in args[1]:  # Simplified check for covariates prompt
            return MagicMock(covariates=["soil_ph", "rainfall"])
        elif "Instrumental Variables" in args[1]:  # Check for IV prompt
            return MagicMock(instrument_variable=None)
        elif "Regression Discontinuity Design" in args[1]:  # Check for RDD prompt
            return MagicMock(running_variable=None, cutoff_value=None)
        elif "Randomized Controlled Trial" in args[1]:  # Check for RCT prompt
            return MagicMock(is_rct=False, reasoning="No indication of RCT.")
        return MagicMock()  # Default mock for any other call

    # Patch _call_llm_for_var, which is used internally by interpret_query's helpers
    with patch('auto_causal.components.query_interpreter._call_llm_for_var',
               side_effect=mock_llm_call_router) as mock_llm_call:
        # Patch get_llm_client to return our mock_llm_instance, ensuring that
        # _call_llm_for_var uses the intended LLM mock when called from within interpret_query
        with patch('auto_causal.components.query_interpreter.get_llm_client',
                   return_value=mock_llm_instance) as mock_get_llm:
            result = interpret_query(
                query_info=MOCK_QUERY_INFO_REF_LEVEL,
                dataset_analysis=MOCK_DATASET_ANALYSIS_REF_LEVEL,
                dataset_description=MOCK_DATASET_DESCRIPTION_REF_LEVEL
            )

    assert "treatment_reference_level" in result, "treatment_reference_level should be in the result"
    assert result["treatment_reference_level"] == "Control", "Incorrect treatment_reference_level identified"

    # Verify that the LLM was called to get the reference level by inspecting
    # the calls recorded on mock_llm_call.
    found_ref_level_call = False
    for call_args in mock_llm_call.call_args_list:
        # call_args is a tuple; call_args[0] holds positional args, call_args[1] holds kwargs.
        # The third positional argument to _call_llm_for_var is the pydantic_model.
        if len(call_args[0]) >= 3 and call_args[0][2] == LLMTreatmentReferenceLevel:
            found_ref_level_call = True
            # Optionally, check the prompt content here too if needed:
            # prompt_content = call_args[0][1]
            # assert "using Control as the baseline" in prompt_content
            break
    assert found_ref_level_call, "LLM call for treatment reference level was not made."

    # Basic checks for other essential variables (based on the simple mocks above)
    assert result["treatment_variable"] == "fertilizer_type"
    assert result["outcome_variable"] == "crop_yield"
    assert result["is_rct"] is False  # Based on mock
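

# A minimal negative-case sketch, not a definitive test: the query below names no
# baseline, so the router returns a reference level of None. This assumes that
# LLMTreatmentReferenceLevel.reference_level is Optional and that interpret_query
# then surfaces None for treatment_reference_level; if the component instead omits
# the key or behaves differently, adjust the final assertion to the real contract.
def test_interpret_query_handles_missing_treatment_reference_level():
    query_info_no_baseline = {
        **MOCK_QUERY_INFO_REF_LEVEL,
        "query_text": "What is the effect of different fertilizers on crop_yield?",
    }

    def mock_router_no_baseline(*args, **kwargs):
        # Same positional convention as in the test above: args[2] is the Pydantic model.
        if len(args) >= 3 and args[2] == LLMTreatmentReferenceLevel:
            # Assumption: reference_level may be None when no baseline is mentioned.
            return LLMTreatmentReferenceLevel(reference_level=None, reasoning="No baseline mentioned.")
        if "most likely treatment variable" in args[1]:
            return MagicMock(variable_name="fertilizer_type")
        if "most likely outcome variable" in args[1]:
            return MagicMock(variable_name="crop_yield")
        return MagicMock()  # Benign default for all other calls

    with patch('auto_causal.components.query_interpreter._call_llm_for_var',
               side_effect=mock_router_no_baseline):
        with patch('auto_causal.components.query_interpreter.get_llm_client',
                   return_value=MagicMock()):
            result = interpret_query(
                query_info=query_info_no_baseline,
                dataset_analysis=MOCK_DATASET_ANALYSIS_REF_LEVEL,
                dataset_description=MOCK_DATASET_DESCRIPTION_REF_LEVEL
            )

    # Assumed contract: no baseline in the query -> no concrete reference level.
    assert result.get("treatment_reference_level") is None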