# Page-scrape residue from a hosted "Spaces" file view: commit 1721aea, 5,264 bytes.
import os
import tempfile
import unittest

import numpy as np
import pandas as pd

# Import the function to test
from auto_causal.components.dataset_analyzer import analyze_dataset
def create_dummy_csv_for_analysis(path, data_dict):
    """Write a small dummy dataset to *path* as CSV (no index column).

    Args:
        path: Destination file path for the CSV.
        data_dict: Column name -> list of values, in any form accepted by
            the ``pandas.DataFrame`` constructor.

    Returns:
        The same *path* that was written, so callers can use it inline.
    """
    pd.DataFrame(data_dict).to_csv(path, index=False)
    return path
class TestDatasetAnalyzer(unittest.TestCase):
    """Tests for ``analyze_dataset`` against small synthetic CSV fixtures."""

    def setUp(self):
        """Write the dummy CSV fixtures into a fresh temporary directory.

        A per-test temporary directory (rather than fixed file names in the
        current working directory) keeps concurrent runs from clobbering each
        other and leaves nothing behind. The cleanup is registered with
        ``addCleanup`` *before* any file is written, so it still runs if
        ``setUp`` fails part-way -- ``tearDown`` would be skipped in that case.
        """
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        self._tmp_dir = tmp_dir.name
        self.test_files = []

        # Basic data
        self.basic_data_path = self._write_csv("analyzer_test_basic.csv", {
            'treatment': [0, 1, 0, 1, 0, 1],
            'outcome': [10, 12, 11, 13, 9, 14],
            'cov1': ['A', 'B', 'A', 'B', 'A', 'B'],
            'numeric_cov': [1.1, 2.2, 1.3, 2.5, 1.0, 2.9],
        })
        # Panel data: two units observed in two years each.
        self.panel_data_path = self._write_csv("analyzer_test_panel.csv", {
            'unit': [1, 1, 2, 2],
            'year': [2000, 2001, 2000, 2001],
            'treat': [0, 1, 0, 0],
            'value': [5, 6, 7, 7.5],
        })
        # Data with potential instrument
        self.iv_data_path = self._write_csv("analyzer_test_iv.csv", {
            'Z_assigned': [0, 1, 0, 1],
            'D_actual': [0, 0, 0, 1],
            'Y_outcome': [10, 11, 12, 15],
        })
        # Data with discontinuity: outcome jumps where running_var crosses 0.
        self.rdd_data_path = self._write_csv("analyzer_test_rdd.csv", {
            'running_var': [-1.5, -0.5, 0.5, 1.5, -1.1, 0.8],
            'outcome_rdd': [4, 5, 10, 11, 4.5, 10.5],
        })

    def _write_csv(self, filename, data_dict):
        """Write *data_dict* as *filename* inside the temp dir, track it, return its path."""
        path = os.path.join(self._tmp_dir, filename)
        create_dummy_csv_for_analysis(path, data_dict)
        self.test_files.append(path)
        return path

    def _analyze_ok(self, path):
        """Run ``analyze_dataset`` on *path*, failing with its message if it reports an error.

        Without this, an analyzer failure surfaces as an opaque ``KeyError``
        when a test indexes a result key that the error dict does not have.
        """
        result = analyze_dataset(path)
        self.assertNotIn("error", result, f"Analysis failed: {result.get('error')}")
        return result

    def test_analyze_basic_structure(self):
        """Test the basic structure and keys of the summarized output."""
        result = analyze_dataset(self.basic_data_path)
        self.assertIsInstance(result, dict)
        self.assertNotIn("error", result, f"Analysis failed: {result.get('error')}")

        expected_keys = [
            "dataset_info", "columns", "potential_treatments", "potential_outcomes",
            "temporal_structure_detected", "panel_data_detected",
            "potential_instruments_detected", "discontinuities_detected",
        ]
        # Old detailed keys must NOT be present in the summarized output.
        unexpected_keys = [
            "column_types", "column_categories", "missing_values", "correlations",
            "discontinuities", "variable_relationships", "column_type_summary",
            "missing_value_summary", "discontinuity_summary", "relationship_summary",
        ]
        for key in expected_keys:
            with self.subTest(expected_key=key):
                self.assertIn(key, result, f"Expected key '{key}' missing.")
        for key in unexpected_keys:
            with self.subTest(unexpected_key=key):
                self.assertNotIn(key, result, f"Unexpected key '{key}' present.")

        # Check some types
        expected_types = {
            "columns": list,
            "potential_treatments": list,
            "potential_outcomes": list,
            "temporal_structure_detected": bool,
            "panel_data_detected": bool,
            "potential_instruments_detected": bool,
            "discontinuities_detected": bool,
        }
        for key, expected_type in expected_types.items():
            with self.subTest(key=key):
                self.assertIsInstance(result[key], expected_type)

    def test_analyze_panel_data(self):
        """Test detection of panel data structure."""
        result = self._analyze_ok(self.panel_data_path)
        self.assertTrue(result["temporal_structure_detected"])
        self.assertTrue(result["panel_data_detected"])
        # Check columns list is correct
        self.assertIn('year', result["columns"])
        self.assertIn('unit', result["columns"])

    def test_analyze_iv_data(self):
        """Test detection of potential IV."""
        result = self._analyze_ok(self.iv_data_path)
        self.assertTrue(result["potential_instruments_detected"])

    def test_analyze_rdd_data(self):
        """Test that the discontinuity flag is reported for RDD-shaped data.

        The summarized output only exposes a boolean flag. Whether it is True
        depends on the thresholds inside the analyzer's discontinuity
        detection, which this small synthetic sample may not trigger
        reliably, so only the flag's presence is asserted here.
        """
        result = self._analyze_ok(self.rdd_data_path)
        self.assertIn("discontinuities_detected", result)

    def test_analyze_file_not_found(self):
        """Test handling of non-existent file."""
        # A name inside the fresh temp dir is guaranteed not to exist, unlike
        # a bare relative name that depends on the current working directory.
        missing_path = os.path.join(self._tmp_dir, "non_existent_file.csv")
        result = analyze_dataset(missing_path)
        self.assertIn("error", result)
        self.assertIn("not found", result["error"])
# Run the suite when this file is executed directly (e.g. ``python <file>``).
# The stray trailing " |" scrape residue made the original line a syntax error.
if __name__ == '__main__':
    unittest.main()