causal-agent / tests /auto_causal /components /test_dataset_analyzer.py
FireShadow's picture
Initial clean commit
1721aea
import os
import tempfile
import unittest

import numpy as np
import pandas as pd

# Import the function to test
from auto_causal.components.dataset_analyzer import analyze_dataset
# Helper to create dummy dataset files
def create_dummy_csv_for_analysis(path, data_dict):
    '''Write ``data_dict`` to ``path`` as a CSV file and return ``path``.

    The mapping is converted to a pandas DataFrame and saved without the
    index column.
    '''
    pd.DataFrame(data_dict).to_csv(path, index=False)
    return path
class TestDatasetAnalyzer(unittest.TestCase):
    '''Tests for the summarized dictionary returned by ``analyze_dataset``.

    Every test runs against small CSV fixtures written into a per-test
    temporary directory, so nothing is created in (or deleted from) the
    current working directory.
    '''

    def setUp(self):
        '''Create a temporary directory and write the fixture CSV files into it.'''
        # Register the directory cleanup before writing anything: unittest does
        # not call tearDown when setUp raises, but registered cleanups still run.
        tmp = tempfile.TemporaryDirectory(prefix="analyzer_test_")
        self.addCleanup(tmp.cleanup)
        self._tmp_dir = tmp.name
        self.test_files = []

        # Basic data
        self.basic_data_path = self._write_fixture("analyzer_test_basic.csv", {
            'treatment': [0, 1, 0, 1, 0, 1],
            'outcome': [10, 12, 11, 13, 9, 14],
            'cov1': ['A', 'B', 'A', 'B', 'A', 'B'],
            'numeric_cov': [1.1, 2.2, 1.3, 2.5, 1.0, 2.9],
        })

        # Panel data
        self.panel_data_path = self._write_fixture("analyzer_test_panel.csv", {
            'unit': [1, 1, 2, 2],
            'year': [2000, 2001, 2000, 2001],
            'treat': [0, 1, 0, 0],
            'value': [5, 6, 7, 7.5],
        })

        # Data with potential instrument
        self.iv_data_path = self._write_fixture("analyzer_test_iv.csv", {
            'Z_assigned': [0, 1, 0, 1],
            'D_actual': [0, 0, 0, 1],
            'Y_outcome': [10, 11, 12, 15],
        })

        # Data with discontinuity
        self.rdd_data_path = self._write_fixture("analyzer_test_rdd.csv", {
            'running_var': [-1.5, -0.5, 0.5, 1.5, -1.1, 0.8],
            'outcome_rdd': [4, 5, 10, 11, 4.5, 10.5],
        })

    def tearDown(self):
        '''Remove the fixture files listed in ``self.test_files``.

        The temporary directory itself is removed afterwards by the cleanup
        registered in ``setUp``.
        '''
        for f in self.test_files:
            if os.path.exists(f):
                os.remove(f)

    def _write_fixture(self, filename, data_dict):
        '''Write ``data_dict`` as a CSV in the temp directory; return its path.'''
        path = os.path.join(self._tmp_dir, filename)
        create_dummy_csv_for_analysis(path, data_dict)
        self.test_files.append(path)
        return path

    def _analyze(self, path):
        '''Run ``analyze_dataset`` on ``path`` and fail clearly on an error result.'''
        result = analyze_dataset(path)
        self.assertIsInstance(result, dict)
        self.assertNotIn("error", result, f"Analysis failed: {result.get('error')}")
        return result

    def test_analyze_basic_structure(self):
        '''Test the basic structure and keys of the summarized output.'''
        result = self._analyze(self.basic_data_path)

        expected_keys = [
            "dataset_info", "columns", "potential_treatments", "potential_outcomes",
            "temporal_structure_detected", "panel_data_detected",
            "potential_instruments_detected", "discontinuities_detected",
        ]
        # Check old detailed keys are NOT present
        unexpected_keys = [
            "column_types", "column_categories", "missing_values", "correlations",
            "discontinuities", "variable_relationships", "column_type_summary",
            "missing_value_summary", "discontinuity_summary", "relationship_summary",
        ]

        # subTest reports every offending key instead of stopping at the first.
        for key in expected_keys:
            with self.subTest(expected_key=key):
                self.assertIn(key, result, f"Expected key '{key}' missing.")
        for key in unexpected_keys:
            with self.subTest(unexpected_key=key):
                self.assertNotIn(key, result, f"Unexpected key '{key}' present.")

        # Check some types
        for key in ("columns", "potential_treatments", "potential_outcomes"):
            with self.subTest(list_key=key):
                self.assertIsInstance(result[key], list)
        for key in (
            "temporal_structure_detected", "panel_data_detected",
            "potential_instruments_detected", "discontinuities_detected",
        ):
            with self.subTest(bool_key=key):
                self.assertIsInstance(result[key], bool)

    def test_analyze_panel_data(self):
        '''Test detection of panel data structure.'''
        result = self._analyze(self.panel_data_path)
        self.assertTrue(result["temporal_structure_detected"])
        self.assertTrue(result["panel_data_detected"])
        # Check columns list is correct
        self.assertIn('year', result["columns"])
        self.assertIn('unit', result["columns"])

    def test_analyze_iv_data(self):
        '''Test detection of potential IV.'''
        result = self._analyze(self.iv_data_path)
        self.assertTrue(result["potential_instruments_detected"])

    def test_analyze_rdd_data(self):
        '''Test that the discontinuity flag is reported for RDD-like data.'''
        # The summarized output only exposes a boolean flag. Whether it is True
        # for this tiny fixture depends on the detection thresholds, so the
        # flag's value is deliberately not asserted; only its presence is.
        result = analyze_dataset(self.rdd_data_path)
        self.assertIn("discontinuities_detected", result)

    def test_analyze_file_not_found(self):
        '''Test handling of non-existent file.'''
        # A path inside the fresh temp directory is guaranteed not to exist,
        # regardless of what the current working directory contains.
        missing_path = os.path.join(self._tmp_dir, "non_existent_file.csv")
        result = analyze_dataset(missing_path)
        self.assertIn("error", result)
        self.assertIn("not found", result["error"])
# Run the suite with unittest's default runner when this file is executed directly.
if "__main__" == __name__:
    unittest.main()