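"""Integration tests for auto_causal.components.input_parser.parse_input.

These tests invoke a real LLM, so they require the OPENAI_API_KEY environment
variable to be set and will make live API calls.
"""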
import os

import pytest

# Import the refactored parse_input function
from auto_causal.components import input_parser

# Check if the OpenAI API key is available; skip the test if not
api_key_present = bool(os.environ.get("OPENAI_API_KEY"))
skip_if_no_key = pytest.mark.skipif(
    not api_key_present, reason="OPENAI_API_KEY environment variable not set"
)
@skip_if_no_key
def test_parse_input_with_real_llm():
    """Test the parse_input function by invoking the actual LLM.

    Note: this test requires the OPENAI_API_KEY environment variable to be set
    and will make a real API call.
    """
    # --- Test Case 1: Effect query with dataset and constraint ---
    query1 = (
        "analyze the effect of 'Minimum Wage Increase' on 'Unemployment Rate' "
        "using data/county_data.csv where year > 2010"
    )

    # Provide some dummy dataset context
    dataset_info1 = {
        'columns': ['County', 'Year', 'Minimum Wage Increase', 'Unemployment Rate', 'Population'],
        'column_types': {
            'County': 'object',
            'Year': 'int64',
            'Minimum Wage Increase': 'int64',
            'Unemployment Rate': 'float64',
            'Population': 'int64',
        },
        'sample_rows': [
            {'County': 'A', 'Year': 2009, 'Minimum Wage Increase': 0, 'Unemployment Rate': 5.5, 'Population': 10000},
            {'County': 'A', 'Year': 2011, 'Minimum Wage Increase': 1, 'Unemployment Rate': 6.0, 'Population': 10200},
        ],
    }

    # Create a dummy data file for path checking (relative to the workspace root)
    dummy_file_path = "data/county_data.csv"
    os.makedirs(os.path.dirname(dummy_file_path), exist_ok=True)
    with open(dummy_file_path, 'w') as f:
        f.write("County,Year,Minimum Wage Increase,Unemployment Rate,Population\n")
        f.write("A,2009,0,5.5,10000\n")
        f.write("A,2011,1,6.0,10200\n")
    # Run the parser; clean up the dummy file even if parse_input raises
    try:
        result1 = input_parser.parse_input(query=query1, dataset_info=dataset_info1)
    finally:
        if os.path.exists(dummy_file_path):
            os.remove(dummy_file_path)
        # Try removing the directory if empty
        try:
            os.rmdir(os.path.dirname(dummy_file_path))
        except OSError:
            pass  # Ignore if the directory is not empty or cannot be removed
    # Assertions for Test Case 1
    assert result1 is not None
    assert result1['original_query'] == query1
    assert result1['query_type'] == "EFFECT_ESTIMATION"
    assert result1['dataset_path'] == dummy_file_path  # Check that path extraction worked

    # Check variables (allowing for some LLM interpretation flexibility):
    # the core variable names should appear in the extracted lists
    assert 'treatment' in result1['extracted_variables']
    assert 'outcome' in result1['extracted_variables']
    assert any('Minimum Wage Increase' in t for t in result1['extracted_variables'].get('treatment', []))
    assert any('Unemployment Rate' in o for o in result1['extracted_variables'].get('outcome', []))

    # Check constraints: the LLM may phrase 'year > 2010' differently,
    # so only look for the key tokens
    assert isinstance(result1['constraints'], list)
    assert any(
        'year' in c.lower() and '2010' in c for c in result1.get('constraints', [])
    ), "Constraint 'year > 2010' not found or not parsed correctly."
    # --- Test Case 2: Counterfactual without a dataset path ---
    query2 = "What would sales have been if we hadn't run the 'Summer Sale' campaign?"
    dataset_info2 = {
        'columns': ['Date', 'Sales', 'Summer Sale', 'Competitor Activity'],
        'column_types': {
            'Date': 'datetime64[ns]',
            'Sales': 'float64',
            'Summer Sale': 'int64',
            'Competitor Activity': 'float64',
        },
    }
    result2 = input_parser.parse_input(query=query2, dataset_info=dataset_info2)

    # Assertions for Test Case 2
    assert result2 is not None
    assert result2['query_type'] == "COUNTERFACTUAL"
    assert result2['dataset_path'] is None  # No path mentioned or inferrable here
    assert any('Summer Sale' in t for t in result2['extracted_variables'].get('treatment', []))
    assert any('Sales' in o for o in result2['extracted_variables'].get('outcome', []))
    assert not result2['constraints']  # No constraints expected
    # --- Test Case 3: Ambiguous query where the LLM may fail validation ---
    # This checks that the retry/failure mechanism logs warnings but does not
    # crash (the LLM may struggle to extract a treatment/outcome from "sales vs ads")
    query3 = "sales vs ads"
    dataset_info3 = {
        'columns': ['sales', 'ads'],
        'column_types': {'sales': 'float', 'ads': 'float'},
    }
    result3 = input_parser.parse_input(query=query3, dataset_info=dataset_info3)
    assert result3 is not None
    # The LLM may fail extraction, so only default/fallback values are checked:
    # the query type may fall back to OTHER, CORRELATION, or DESCRIPTIVE, and
    # the variables may be empty or partially filled. This mainly verifies that
    # the function completes without error even when the LLM fails.
    print(f"Result for ambiguous query: {result3}")