import pytest
import os

# Import the refactored parse_input function
from auto_causal.components import input_parser

# Check if OpenAI API key is available, skip if not
api_key_present = bool(os.environ.get("OPENAI_API_KEY"))
skip_if_no_key = pytest.mark.skipif(not api_key_present, reason="OPENAI_API_KEY environment variable not set")
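
# Example invocation against a real LLM (assumed usage, not project-specific;
# substitute the actual path to this test module):
#   OPENAI_API_KEY=... pytest path/to/this_test_module.py -s
# The -s flag disables pytest's output capture so the print() in the
# ambiguous-query case below stays visible.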

@skip_if_no_key
def test_parse_input_with_real_llm():
    """Tests the parse_input function invoking the actual LLM.
    
    Note: This test requires the OPENAI_API_KEY environment variable to be set 
    and will make a real API call.
    """
    # --- Test Case 1: Effect query with dataset and constraint ---
    query1 = "analyze the effect of 'Minimum Wage Increase' on 'Unemployment Rate' using data/county_data.csv where year > 2010"
    
    # Provide some dummy dataset context
    dataset_info1 = {
        'columns': ['County', 'Year', 'Minimum Wage Increase', 'Unemployment Rate', 'Population'],
        'column_types': {'County': 'object', 'Year': 'int64', 'Minimum Wage Increase': 'int64', 'Unemployment Rate': 'float64', 'Population': 'int64'},
        'sample_rows': [
            {'County': 'A', 'Year': 2009, 'Minimum Wage Increase': 0, 'Unemployment Rate': 5.5, 'Population': 10000},
            {'County': 'A', 'Year': 2011, 'Minimum Wage Increase': 1, 'Unemployment Rate': 6.0, 'Population': 10200}
        ]
    }
    
    # Create a dummy data file for path checking (relative to workspace root)
    dummy_file_path = "data/county_data.csv"
    os.makedirs(os.path.dirname(dummy_file_path), exist_ok=True)
    with open(dummy_file_path, 'w') as f:
        f.write("County,Year,Minimum Wage Increase,Unemployment Rate,Population\n")
        f.write("A,2009,0,5.5,10000\n")
        f.write("A,2011,1,6.0,10200\n")

    try:
        result1 = input_parser.parse_input(query=query1, dataset_info=dataset_info1)
    finally:
        # Clean up the dummy file even if parse_input raises
        if os.path.exists(dummy_file_path):
            os.remove(dummy_file_path)
            # Try removing the directory if it is now empty
            try:
                os.rmdir(os.path.dirname(dummy_file_path))
            except OSError:
                pass  # Ignore if directory is not empty or other error

    # Assertions for Test Case 1
    assert result1 is not None
    assert result1['original_query'] == query1
    assert result1['query_type'] == "EFFECT_ESTIMATION"
    assert result1['dataset_path'] == dummy_file_path # Check if path extraction worked
    
    # Check variables (allowing for some LLM interpretation flexibility)
    assert 'treatment' in result1['extracted_variables']
    assert 'outcome' in result1['extracted_variables']
    # Check if the core variable names are present in the extracted lists
    assert any('Minimum Wage Increase' in t for t in result1['extracted_variables'].get('treatment', []))
    assert any('Unemployment Rate' in o for o in result1['extracted_variables'].get('outcome', []))
    
    # Check constraints
    assert isinstance(result1['constraints'], list)
    # Check if a constraint related to 'year > 2010' was captured (LLM might phrase it differently)
    assert any('year' in c.lower() and '2010' in c for c in result1.get('constraints', [])), "Constraint 'year > 2010' not found or not parsed correctly."

    # --- Test Case 2: Counterfactual without dataset path ---
    query2 = "What would sales have been if we hadn't run the 'Summer Sale' campaign?"
    dataset_info2 = {
        'columns': ['Date', 'Sales', 'Summer Sale', 'Competitor Activity'],
        'column_types': { 'Date': 'datetime64[ns]', 'Sales': 'float64', 'Summer Sale': 'int64', 'Competitor Activity': 'float64'}
    }
    
    result2 = input_parser.parse_input(query=query2, dataset_info=dataset_info2)
    
    # Assertions for Test Case 2
    assert result2 is not None
    assert result2['query_type'] == "COUNTERFACTUAL"
    assert result2['dataset_path'] is None # No path mentioned or inferrable here
    assert any('Summer Sale' in t for t in result2['extracted_variables'].get('treatment', []))
    assert any('Sales' in o for o in result2['extracted_variables'].get('outcome', []))
    assert not result2['constraints'] # No constraints expected

    # --- Test Case 3: Ambiguous query where the LLM may fail extraction ---
    # Exercises the retry/failure path: the parser should log warnings rather
    # than crash when the LLM cannot confidently extract a treatment/outcome
    # from a terse query like "sales vs ads".
    query3 = "sales vs ads"
    dataset_info3 = {
        'columns': ['sales', 'ads'],
        'column_types': {'sales': 'float', 'ads': 'float'}
    }
    result3 = input_parser.parse_input(query=query3, dataset_info=dataset_info3)
    assert result3 is not None
    # Extraction may fail here: the query type may fall back to OTHER,
    # CORRELATION, or DESCRIPTIVE, and the extracted variables may be empty or
    # only partially filled. The key check is that parse_input completes
    # without raising; print the result for manual inspection.
    print(f"Result for ambiguous query: {result3}")