""" | |
Dataset analyzer component for causal inference. | |
This module provides functionality to analyze datasets to detect characteristics | |
relevant for causal inference methods, including temporal structure, potential | |
instrumental variables, discontinuities, and variable relationships. | |
""" | |
import os | |
import pandas as pd | |
import numpy as np | |
from typing import Dict, List, Any, Optional, Tuple | |
from scipy import stats | |
import logging | |
import json | |
from langchain_core.language_models import BaseChatModel | |
from auto_causal.utils.llm_helpers import llm_identify_temporal_and_unit_vars | |
logger = logging.getLogger(__name__) | |

def _calculate_per_group_stats(df: pd.DataFrame, potential_treatments: List[str]) -> Dict[str, Dict]:
    """Calculate summary stats for numeric covariates grouped by potential binary treatments."""
    stats_dict = {}
    numeric_cols = df.select_dtypes(include=np.number).columns.tolist()

    for treat_var in potential_treatments:
        if treat_var not in df.columns:
            logger.warning(f"Potential treatment '{treat_var}' not found in DataFrame columns.")
            continue

        # Ensure the treatment is binary (0/1 or similar)
        unique_vals = df[treat_var].dropna().unique()
        if len(unique_vals) != 2:
            logger.info(f"Skipping stats for potential treatment '{treat_var}' as it is not binary ({len(unique_vals)} unique values).")
            continue

        # Attempt to map values to 0 and 1 if possible
        try:
            # Convert boolean to int so groups are keyed by 0/1
            if df[treat_var].dtype == 'bool':
                df[treat_var] = df[treat_var].astype(int)
                unique_vals = df[treat_var].dropna().unique()
            # Basic check that values are interpretable as 0/1
            if not set(unique_vals).issubset({0, 1}):
                logger.warning(f"Potential treatment '{treat_var}' has values {unique_vals}, not {{0, 1}}. Cannot calculate group stats reliably.")
                continue
        except Exception as e:
            logger.warning(f"Could not process potential treatment '{treat_var}' values ({unique_vals}): {e}")
            continue

        logger.info(f"Calculating group stats for treatment: '{treat_var}'")
        treat_stats = {'group_sizes': {}, 'covariate_stats': {}}
        try:
            grouped = df.groupby(treat_var)
            sizes = grouped.size()
            treat_stats['group_sizes']['treated'] = int(sizes.get(1, 0))
            treat_stats['group_sizes']['control'] = int(sizes.get(0, 0))

            if treat_stats['group_sizes']['treated'] == 0 or treat_stats['group_sizes']['control'] == 0:
                logger.warning(f"Treatment '{treat_var}' has zero samples in one group. Skipping covariate stats.")
                stats_dict[treat_var] = treat_stats
                continue

            # Mean and std of each numeric covariate per group. The columns of
            # `cov_stats` form a (covariate, statistic) MultiIndex and the index
            # holds the group labels 0 and 1, so look values up in that order.
            cov_stats = grouped[numeric_cols].agg(['mean', 'std'])
            for cov in numeric_cols:
                if cov == treat_var:
                    continue  # Skip the treatment variable itself
                mean_control = cov_stats[cov]['mean'].get(0, np.nan)
                std_control = cov_stats[cov]['std'].get(0, np.nan)
                mean_treated = cov_stats[cov]['mean'].get(1, np.nan)
                std_treated = cov_stats[cov]['std'].get(1, np.nan)
                treat_stats['covariate_stats'][cov] = {
                    'mean_control': float(mean_control) if pd.notna(mean_control) else None,
                    'std_control': float(std_control) if pd.notna(std_control) else None,
                    'mean_treat': float(mean_treated) if pd.notna(mean_treated) else None,
                    'std_treat': float(std_treated) if pd.notna(std_treated) else None,
                }
            stats_dict[treat_var] = treat_stats
        except Exception as e:
            logger.error(f"Error calculating stats for treatment '{treat_var}': {e}", exc_info=True)
            # Store partial info if possible
            if treat_var not in stats_dict:
                stats_dict[treat_var] = {'error': str(e)}
            elif 'error' not in stats_dict[treat_var]:
                stats_dict[treat_var]['error'] = str(e)

    return stats_dict
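
# Illustrative sketch (toy data, hypothetical names) of the structure this
# helper returns: per-group sizes plus mean/std of each numeric covariate.
#   toy = pd.DataFrame({"treat": [0, 1, 0, 1], "age": [30.0, 35.0, 40.0, 45.0]})
#   _calculate_per_group_stats(toy, ["treat"])
#   -> {"treat": {"group_sizes": {"treated": 2, "control": 2},
#                 "covariate_stats": {"age": {"mean_control": 35.0, "std_control": ...,
#                                             "mean_treat": 40.0, "std_treat": ...}}}}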

def analyze_dataset(
    dataset_path: str,
    llm_client: Optional[BaseChatModel] = None,
    dataset_description: Optional[str] = None,
    original_query: Optional[str] = None
) -> Dict[str, Any]:
    """
    Analyze a dataset to identify important characteristics for causal inference.

    Args:
        dataset_path: Path to the dataset file
        llm_client: Optional LLM client for enhanced analysis
        dataset_description: Optional description of the dataset for context
        original_query: Optional user query providing additional context

    Returns:
        Dict containing dataset analysis results:
        - dataset_info: Basic information about the dataset
        - columns: List of column names
        - potential_treatments: List of potential treatment variables (possibly LLM augmented)
        - potential_outcomes: List of potential outcome variables (possibly LLM augmented)
        - temporal_structure_detected: Whether temporal structure was detected
        - panel_data_detected: Whether panel data structure was detected
        - potential_instruments_detected: Whether potential instruments were detected
        - discontinuities_detected: Whether discontinuities were detected
        - llm_augmentation: Status of LLM augmentation if used
    """
llm_augmentation = "Not used" if not llm_client else "Initialized" | |
# Check if file exists | |
if not os.path.exists(dataset_path): | |
logger.error(f"Dataset file not found at {dataset_path}") | |
return {"error": f"Dataset file not found at {dataset_path}"} | |
try: | |
# Load the dataset | |
df = pd.read_csv(dataset_path) | |
# Basic dataset information | |
sample_size = len(df) | |
columns_list = df.columns.tolist() | |
num_covariates = len(columns_list) - 2 # Rough estimate (total - T - Y) | |
dataset_info = { | |
"num_rows": sample_size, | |
"num_columns": len(columns_list), | |
"file_path": dataset_path, | |
"file_name": os.path.basename(dataset_path) | |
} | |
# --- Detailed Analysis (Keep internal) --- | |
column_types_detailed = {col: str(df[col].dtype) for col in df.columns} | |
missing_values_detailed = df.isnull().sum().to_dict() | |
column_categories_detailed = _categorize_columns(df) | |
column_nunique_counts_detailed = {col: df[col].nunique() for col in df.columns} # Calculate nunique | |
numeric_cols = df.select_dtypes(include=['number']).columns.tolist() | |
correlations_detailed = df[numeric_cols].corr() if numeric_cols else pd.DataFrame() | |
temporal_structure_detailed = detect_temporal_structure(df, llm_client, dataset_description, original_query) | |
# First, identify potential treatment and outcome variables | |
potential_variables = _identify_potential_variables( | |
df, | |
column_categories_detailed, | |
llm_client=llm_client, | |
dataset_description=dataset_description | |
) | |
if llm_client: | |
llm_augmentation = "Used for variable identification" | |
# Then use that info to help find potential instrumental variables | |
potential_instruments_detailed = find_potential_instruments( | |
df, | |
llm_client=llm_client, | |
potential_treatments=potential_variables.get("potential_treatments", []), | |
potential_outcomes=potential_variables.get("potential_outcomes", []), | |
dataset_description=dataset_description | |
) | |
# Other analyses | |
discontinuities_detailed = detect_discontinuities(df) | |
variable_relationships_detailed = assess_variable_relationships(df, correlations_detailed) | |
# Calculate per-group stats for potential binary treatments | |
potential_binary_treatments = [ | |
t for t in potential_variables["potential_treatments"] | |
if column_categories_detailed.get(t) == 'binary' | |
or column_categories_detailed.get(t) == 'binary_categorical' | |
] | |
per_group_stats = _calculate_per_group_stats(df.copy(), potential_binary_treatments) | |
# --- Summarized Analysis (For Output) --- | |
# Get boolean flags and essential lists | |
has_temporal = temporal_structure_detailed.get("has_temporal_structure", False) | |
is_panel = temporal_structure_detailed.get("is_panel_data", False) | |
logger.info(f"iv is {potential_instruments_detailed}") | |
has_instruments = len(potential_instruments_detailed) > 0 | |
has_discontinuities = discontinuities_detailed.get("has_discontinuities", False) | |
# --- Extract only instrument names for the final output --- | |
potential_instrument_names = [ | |
inst_dict.get('variable') | |
for inst_dict in potential_instruments_detailed | |
if isinstance(inst_dict, dict) and 'variable' in inst_dict | |
] | |
logger.info(f"iv is {potential_instrument_names}") | |

        # --- Final output dictionary (highly summarized) ---
        return {
            "dataset_info": dataset_info,  # Keep basic info
            "columns": columns_list,
            "potential_treatments": potential_variables["potential_treatments"],
            "potential_outcomes": potential_variables["potential_outcomes"],
            # Return concise flags instead of detailed dicts/lists
            "temporal_structure_detected": has_temporal,
            "panel_data_detected": is_panel,
            "potential_instruments_detected": has_instruments,
            "discontinuities_detected": has_discontinuities,
            # Use the extracted list of names here
            "potential_instruments": potential_instrument_names,
            "discontinuities": discontinuities_detailed,
            "temporal_structure": temporal_structure_detailed,
            "column_categories": column_categories_detailed,
            "column_nunique_counts": column_nunique_counts_detailed,
            "sample_size": sample_size,
            "num_covariates_estimate": num_covariates,
            "llm_augmentation": llm_augmentation
        }
    except Exception as e:
        logger.error(f"Error analyzing dataset '{dataset_path}': {e}", exc_info=True)
        return {
            "error": f"Error analyzing dataset: {str(e)}",
            "llm_augmentation": llm_augmentation
        }

def _categorize_columns(df: pd.DataFrame) -> Dict[str, str]:
    """
    Categorize columns into types relevant for causal inference.

    Args:
        df: DataFrame to analyze

    Returns:
        Dict mapping column names to their types
    """
    result = {}
    for col in df.columns:
        # Check if column is numeric
        if pd.api.types.is_numeric_dtype(df[col]):
            # Count number of unique values
            n_unique = df[col].nunique()
            # Binary numeric variable
            if n_unique == 2:
                result[col] = "binary"
            # Likely categorical represented as numeric
            elif n_unique < 10:
                result[col] = "categorical_numeric"
            # Discrete numeric (integers)
            elif pd.api.types.is_integer_dtype(df[col]):
                result[col] = "discrete_numeric"
            # Continuous numeric
            else:
                result[col] = "continuous_numeric"
        # Check for datetime
        elif pd.api.types.is_datetime64_any_dtype(df[col]) or _is_date_string(df, col):
            result[col] = "datetime"
        # Check for categorical (explicit categorical dtype or low cardinality)
        elif isinstance(df[col].dtype, pd.CategoricalDtype) or df[col].nunique() < 20:
            if df[col].nunique() == 2:
                result[col] = "binary_categorical"
            else:
                result[col] = "categorical"
        # Must be text or other
        else:
            result[col] = "text_or_other"
    return result
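
# Illustrative sketch (toy data, not from any shipped dataset) of the mapping
# _categorize_columns produces. Note that "score" lands in "categorical_numeric"
# because it has fewer than 10 distinct values:
#   df = pd.DataFrame({"treat": [0, 1, 0], "score": [1.2, 3.4, 5.6], "grp": ["a", "b", "a"]})
#   _categorize_columns(df)
#   -> {"treat": "binary", "score": "categorical_numeric", "grp": "binary_categorical"}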

def _is_date_string(df: pd.DataFrame, col: str) -> bool:
    """
    Check if a column contains date strings.

    Args:
        df: DataFrame to check
        col: Column name to check

    Returns:
        True if the column appears to contain date strings
    """
    if not pd.api.types.is_string_dtype(df[col]):
        return False

    # Check a sample of non-null values; an empty column is not date-like
    non_null = df[col].dropna()
    if non_null.empty:
        return False
    sample = non_null.sample(min(10, len(non_null))).tolist()
    try:
        for val in sample:
            pd.to_datetime(val)
        return True
    except (ValueError, TypeError):
        return False

def _identify_potential_variables(
    df: pd.DataFrame,
    column_categories: Dict[str, str],
    llm_client: Optional[BaseChatModel] = None,
    dataset_description: Optional[str] = None
) -> Dict[str, List[str]]:
    """
    Identify potential treatment and outcome variables in the dataset, using LLM if available.
    Falls back to a heuristic method if the LLM fails or is not available.

    Args:
        df: DataFrame to analyze
        column_categories: Dictionary mapping column names to their types
        llm_client: Optional LLM client for enhanced identification
        dataset_description: Optional description of the dataset for context

    Returns:
        Dict with potential treatment and outcome variables
    """
    # Try LLM approach if client is provided
    if llm_client:
        try:
            logger.info("Using LLM to identify potential treatment and outcome variables")

            # Create a concise prompt with just column information
            columns_list = df.columns.tolist()
            column_types = {col: str(df[col].dtype) for col in columns_list}

            # Get binary columns for extra context
            binary_cols = [col for col in columns_list
                           if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() == 2]

            # Add dataset description if available
            description_text = f"\nDataset Description: {dataset_description}" if dataset_description else ""

            prompt = f"""
            You are an expert causal inference data scientist. Identify potential treatment and outcome variables from this dataset.{description_text}

            Dataset columns:
            {columns_list}

            Column types:
            {column_types}

            Binary columns (good treatment candidates):
            {binary_cols}

            Instructions:
            1. Identify TREATMENT variables: interventions, treatments, programs, policies, or binary state changes.
               Look for binary variables or names with 'treatment', 'intervention', 'program', 'policy', etc.
            2. Identify OUTCOME variables: results, effects, or responses to treatments.
               Look for numeric variables (especially non-binary) or names with 'outcome', 'result', 'effect', 'score', etc.

            Return ONLY a valid JSON object with two lists: "potential_treatments" and "potential_outcomes".
            Example: {{"potential_treatments": ["treatment_a", "program_b"], "potential_outcomes": ["result_score", "outcome_measure"]}}
            """

            # Call the LLM and parse the response
            response = llm_client.invoke(prompt)
            response_text = response.content if hasattr(response, 'content') else str(response)

            # Extract JSON from the response text
            json_match = re.search(r'{.*}', response_text, re.DOTALL)
            if json_match:
                result = json.loads(json_match.group(0))
                # Validate the response
                if (isinstance(result, dict) and
                        "potential_treatments" in result and
                        "potential_outcomes" in result and
                        isinstance(result["potential_treatments"], list) and
                        isinstance(result["potential_outcomes"], list)):
                    # Ensure all suggestions are valid columns
                    valid_treatments = [col for col in result["potential_treatments"] if col in df.columns]
                    valid_outcomes = [col for col in result["potential_outcomes"] if col in df.columns]
                    if valid_treatments and valid_outcomes:
                        logger.info(f"LLM identified {len(valid_treatments)} treatments and {len(valid_outcomes)} outcomes")
                        return {
                            "potential_treatments": valid_treatments,
                            "potential_outcomes": valid_outcomes
                        }
                    else:
                        logger.warning("LLM suggested invalid columns, falling back to heuristic method")
                else:
                    logger.warning("Invalid LLM response format, falling back to heuristic method")
            else:
                logger.warning("Could not extract JSON from LLM response, falling back to heuristic method")
        except Exception as e:
            logger.error(f"Error in LLM identification: {e}", exc_info=True)
            logger.info("Falling back to heuristic method")

    # Fallback to heuristic method
    logger.info("Using heuristic method to identify potential treatment and outcome variables")

    # Identify potential treatment variables
    potential_treatments = []

    # Look for binary variables (good treatment candidates)
    binary_cols = [col for col in df.columns
                   if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() == 2]

    # Look for variables with names suggesting treatment
    treatment_keywords = ['treatment', 'treat', 'intervention', 'program', 'policy',
                          'exposed', 'assigned', 'received', 'participated']
    for col in df.columns:
        col_lower = col.lower()
        if any(keyword in col_lower for keyword in treatment_keywords):
            potential_treatments.append(col)

    # Add binary variables if we don't have enough candidates
    if len(potential_treatments) < 3:
        for col in binary_cols:
            if col not in potential_treatments:
                potential_treatments.append(col)
            if len(potential_treatments) >= 3:
                break

    # Identify potential outcome variables
    potential_outcomes = []

    # Look for numeric variables that aren't binary
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    non_binary_numeric = [col for col in numeric_cols if col not in binary_cols]

    # Look for variables with names suggesting outcomes
    outcome_keywords = ['outcome', 'result', 'effect', 'response', 'score', 'performance',
                        'achievement', 'success', 'failure', 'improvement']
    for col in df.columns:
        col_lower = col.lower()
        if any(keyword in col_lower for keyword in outcome_keywords):
            potential_outcomes.append(col)

    # Add numeric non-binary variables if we don't have enough candidates
    if len(potential_outcomes) < 3:
        for col in non_binary_numeric:
            if col not in potential_outcomes and col not in potential_treatments:
                potential_outcomes.append(col)
            if len(potential_outcomes) >= 3:
                break

    return {
        "potential_treatments": potential_treatments,
        "potential_outcomes": potential_outcomes
    }
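
# Illustrative sketch (toy data; names are hypothetical) of the heuristic
# fallback path when no LLM client is passed:
#   toy = pd.DataFrame({"treatment": [0, 1, 0], "outcome_score": [1.0, 2.0, 3.0]})
#   _identify_potential_variables(toy, _categorize_columns(toy))
#   -> {"potential_treatments": ["treatment"], "potential_outcomes": ["outcome_score"]}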

def detect_temporal_structure(
    df: pd.DataFrame,
    llm_client: Optional[BaseChatModel] = None,
    dataset_description: Optional[str] = None,
    original_query: Optional[str] = None
) -> Dict[str, Any]:
    """
    Detect temporal structure in the dataset, using LLM for enhanced identification.

    Args:
        df: DataFrame to analyze
        llm_client: Optional LLM client for enhanced identification
        dataset_description: Optional description of the dataset for context
        original_query: Optional user query providing additional context

    Returns:
        Dict with information about temporal structure:
        - has_temporal_structure: Whether temporal structure exists
        - temporal_columns: Primary time column identified (or list if multiple from heuristic)
        - is_panel_data: Whether data is in panel format
        - time_column: Primary time column identified for panel data
        - id_column: Primary unit ID column identified for panel data
        - time_periods: Number of time periods (if panel data)
        - units: Number of unique units (if panel data)
        - identification_method: How time/unit vars were identified ('LLM', 'Heuristic', 'None')
    """
    result = {
        "has_temporal_structure": False,
        "temporal_columns": [],  # Will store the primary time column or heuristic list
        "is_panel_data": False,
        "time_column": None,
        "id_column": None,
        "time_periods": None,
        "units": None,
        "identification_method": "None"
    }

    # --- Step 1: Heuristic identification (currently disabled) ---
    # The keyword-based heuristics below are commented out; the empty candidate
    # lists defined in Step 2 make the heuristic fallbacks in Step 3 no-ops.
    #heuristic_datetime_cols = []
    #for col in df.columns:
    #    if pd.api.types.is_datetime64_any_dtype(df[col]):
    #        heuristic_datetime_cols.append(col)
    #    elif pd.api.types.is_string_dtype(df[col]):
    #        try:
    #            if pd.to_datetime(df[col], errors='coerce').notna().any():
    #                heuristic_datetime_cols.append(col)
    #        except:
    #            pass  # Ignore conversion errors
    #time_keywords = ['year', 'month', 'day', 'date', 'time', 'period', 'quarter', 'week']
    #for col in df.columns:
    #    col_lower = col.lower()
    #    if any(keyword in col_lower for keyword in time_keywords) and col not in heuristic_datetime_cols:
    #        heuristic_datetime_cols.append(col)
    #id_keywords = ['id', 'individual', 'person', 'unit', 'entity', 'firm', 'company', 'state', 'country']
    #heuristic_potential_id_cols = []
    #for col in df.columns:
    #    col_lower = col.lower()
    #    # Exclude columns already identified as time-related by heuristics
    #    if any(keyword in col_lower for keyword in id_keywords) and col not in heuristic_datetime_cols:
    #        heuristic_potential_id_cols.append(col)

    # --- Step 2: LLM-assisted identification ---
    llm_identified_time_var = None
    llm_identified_unit_var = None
    heuristic_datetime_cols = []
    heuristic_potential_id_cols = []
    dataset_summary = df.describe(include='all')

    if llm_client:
        logger.info("Attempting LLM-assisted identification of temporal/unit variables.")
        column_names = df.columns.tolist()
        column_dtypes_dict = {col: str(df[col].dtype) for col in column_names}
        try:
            llm_suggestions = llm_identify_temporal_and_unit_vars(
                column_names=column_names,
                column_dtypes=column_dtypes_dict,
                dataset_description=dataset_description if dataset_description else "No dataset description provided.",
                dataset_summary=dataset_summary,
                heuristic_time_candidates=heuristic_datetime_cols,
                heuristic_id_candidates=heuristic_potential_id_cols,
                query=original_query if original_query else "No query provided.",
                llm=llm_client
            )
            llm_identified_time_var = llm_suggestions.get("time_variable")
            llm_identified_unit_var = llm_suggestions.get("unit_variable")
            result["identification_method"] = "LLM"
            if not llm_identified_time_var and not llm_identified_unit_var:
                result["identification_method"] = "LLM_NoIdentification"
        except Exception as e:
            logger.warning(f"LLM call for temporal/unit vars failed: {e}. Falling back to heuristics.")
            result["identification_method"] = "Heuristic_LLM_Error"
    else:
        result["identification_method"] = "Heuristic_NoLLM"

    # --- Step 3: Combine LLM and heuristic results ---
    final_time_var = None
    final_unit_var = None

    if llm_identified_time_var:
        final_time_var = llm_identified_time_var
        logger.info(f"Prioritizing LLM identified time variable: {final_time_var}")
    elif heuristic_datetime_cols:
        final_time_var = heuristic_datetime_cols[0]  # Fall back to first heuristic time column
        logger.info(f"Using heuristic time variable: {final_time_var}")

    if llm_identified_unit_var:
        final_unit_var = llm_identified_unit_var
        logger.info(f"Prioritizing LLM identified unit variable: {final_unit_var}")
    elif heuristic_potential_id_cols:
        final_unit_var = heuristic_potential_id_cols[0]  # Fall back to first heuristic ID column
        logger.info(f"Using heuristic unit variable: {final_unit_var}")

    # Update results based on final selections
    if final_time_var:
        result["has_temporal_structure"] = True
        result["temporal_columns"] = [final_time_var]  # Store as a list with the primary time var
        result["time_column"] = final_time_var
    else:  # If no time var found by LLM or heuristic, use the heuristic list for temporal_columns
        if heuristic_datetime_cols:
            result["has_temporal_structure"] = True
            result["temporal_columns"] = heuristic_datetime_cols
            # time_column remains None

    if final_unit_var:
        result["id_column"] = final_unit_var

    # --- Step 4: Panel data logic (based on final_time_var and final_unit_var) ---
    if final_time_var and final_unit_var:
        # Check if there are multiple time periods per unit using the identified variables
        try:
            # Ensure columns exist before groupby
            if final_time_var in df.columns and final_unit_var in df.columns:
                if df.groupby(final_unit_var)[final_time_var].nunique().mean() > 1.0:
                    result["is_panel_data"] = True
                    result["time_periods"] = df[final_time_var].nunique()
                    result["units"] = df[final_unit_var].nunique()
                    logger.info(f"Panel data detected: Time='{final_time_var}', Unit='{final_unit_var}', Periods={result['time_periods']}, Units={result['units']}")
                else:
                    logger.info("Not panel data: Each unit does not have multiple time periods.")
            else:
                logger.warning(f"Final time ('{final_time_var}') or unit ('{final_unit_var}') var not in DataFrame. Cannot confirm panel structure.")
        except Exception as e:
            logger.error(f"Error checking panel data structure with time='{final_time_var}', unit='{final_unit_var}': {e}")
            result["is_panel_data"] = False  # Default to False on error
    else:
        logger.info("Not panel data: Missing either time or unit variable for panel structure.")

    logger.debug(f"Final temporal structure detection result: {result}")
    return result
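
# Illustrative sketch (toy data, hypothetical names): without an LLM client the
# disabled Step 1 heuristics leave the result at its defaults.
#   panel = pd.DataFrame({"unit": [1, 1, 2, 2], "year": [2000, 2001, 2000, 2001]})
#   detect_temporal_structure(panel)
#   -> {"has_temporal_structure": False, ..., "identification_method": "Heuristic_NoLLM"}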

def find_potential_instruments(
    df: pd.DataFrame,
    llm_client: Optional[BaseChatModel] = None,
    potential_treatments: Optional[List[str]] = None,
    potential_outcomes: Optional[List[str]] = None,
    dataset_description: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Find potential instrumental variables in the dataset, using LLM if available.
    Falls back to a heuristic method if the LLM fails or is not available.

    Args:
        df: DataFrame to analyze
        llm_client: Optional LLM client for enhanced identification
        potential_treatments: Optional list of potential treatment variables
        potential_outcomes: Optional list of potential outcome variables
        dataset_description: Optional description of the dataset for context

    Returns:
        List of potential instrumental variables with their properties
    """
    # Try LLM approach if client is provided
    if llm_client:
        try:
            logger.info("Using LLM to identify potential instrumental variables")

            # Create a concise prompt with just column information
            columns_list = df.columns.tolist()

            # Exclude known treatment and outcome variables from consideration
            excluded_columns = []
            if potential_treatments:
                excluded_columns.extend(potential_treatments)
            if potential_outcomes:
                excluded_columns.extend(potential_outcomes)

            # Filter columns to exclude treatments and outcomes
            candidate_columns = [col for col in columns_list if col not in excluded_columns]
            if not candidate_columns:
                logger.warning("No eligible columns for instrumental variables after filtering treatments and outcomes")
                return []

            # Get column types for context
            column_types = {col: str(df[col].dtype) for col in candidate_columns}

            # Add dataset description if available
            description_text = f"\nDataset Description: {dataset_description}" if dataset_description else ""

            prompt = f"""
            You are an expert causal inference data scientist. Identify potential instrumental variables from this dataset.{description_text}

            DEFINITION: Instrumental variables must:
            1. Be correlated with the treatment variable (relevance)
            2. Only affect the outcome through the treatment (exclusion restriction)
            3. Not be correlated with unmeasured confounders (exogeneity)

            Treatment variables: {potential_treatments if potential_treatments else "Unknown"}
            Outcome variables: {potential_outcomes if potential_outcomes else "Unknown"}

            Available columns (excluding treatments and outcomes):
            {candidate_columns}

            Column types:
            {column_types}

            Look for variables likely to be:
            - Random assignments
            - Policy changes
            - Geographic or temporal variations
            - Variables with names containing: 'instrument', 'iv', 'assigned', 'random', 'lottery', 'exogenous'

            Return ONLY a JSON array of objects, each with "variable", "reason", and "data_type" fields.
            Example:
            [
                {{"variable": "random_assignment", "reason": "Random assignment variable", "data_type": "int64"}},
                {{"variable": "distance_to_facility", "reason": "Geographic variation", "data_type": "float64"}}
            ]
            """

            # Call the LLM and parse the response
            response = llm_client.invoke(prompt)
            response_text = response.content if hasattr(response, 'content') else str(response)

            # Extract JSON from the response text
            json_match = re.search(r'\[\s*{.*}\s*\]', response_text, re.DOTALL)
            if json_match:
                result = json.loads(json_match.group(0))
                # Validate the response
                if isinstance(result, list) and len(result) > 0:
                    # Filter for valid entries
                    valid_instruments = []
                    for item in result:
                        if not isinstance(item, dict) or "variable" not in item:
                            continue
                        if item["variable"] not in df.columns:
                            continue
                        # Ensure all required fields are present
                        if "reason" not in item:
                            item["reason"] = "Identified by LLM"
                        if "data_type" not in item:
                            item["data_type"] = str(df[item["variable"]].dtype)
                        valid_instruments.append(item)
                    if valid_instruments:
                        logger.info(f"LLM identified {len(valid_instruments)} potential instrumental variables: {valid_instruments}")
                        return valid_instruments
                    else:
                        logger.warning("No valid instruments found by LLM, falling back to heuristic method")
                else:
                    logger.warning("Invalid LLM response format, falling back to heuristic method")
            else:
                logger.warning("Could not extract JSON from LLM response, falling back to heuristic method")
        except Exception as e:
            logger.error(f"Error in LLM identification of instruments: {e}", exc_info=True)
            logger.info("Falling back to heuristic method")

    # Fallback to heuristic method
    logger.info("Using heuristic method to identify potential instrumental variables")
    potential_instruments = []

    # Look for variables with instrument-related names
    instrument_keywords = ['instrument', 'iv', 'assigned', 'random', 'lottery', 'exogenous']
    for col in df.columns:
        # Skip treatment and outcome variables
        if potential_treatments and col in potential_treatments:
            continue
        if potential_outcomes and col in potential_outcomes:
            continue
        col_lower = col.lower()
        if any(keyword in col_lower for keyword in instrument_keywords):
            potential_instruments.append({
                "variable": col,
                "reason": "Name contains instrument-related keyword",
                "data_type": str(df[col].dtype)
            })
    return potential_instruments
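
# Illustrative sketch (toy data, hypothetical names) of the heuristic fallback:
#   toy = pd.DataFrame({"lottery_win": [0, 1], "treat": [0, 1], "y": [1.0, 2.0]})
#   find_potential_instruments(toy, potential_treatments=["treat"], potential_outcomes=["y"])
#   -> [{"variable": "lottery_win", "reason": "Name contains instrument-related keyword",
#        "data_type": "int64"}]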

def detect_discontinuities(df: pd.DataFrame) -> Dict[str, Any]:
    """
    Identify discontinuities in continuous variables (for RDD).

    Args:
        df: DataFrame to analyze

    Returns:
        Dict with information about detected discontinuities
    """
    discontinuities = []

    # For each numeric column, check for potential discontinuities
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    for col in numeric_cols:
        # Skip columns with too many unique values
        if df[col].nunique() > 100:
            continue

        values = df[col].dropna().sort_values().values

        # Calculate gaps between consecutive values
        if len(values) > 10:
            gaps = np.diff(values)
            mean_gap = np.mean(gaps)
            std_gap = np.std(gaps)

            # Look for unusually large gaps (potential discontinuities)
            large_gaps = np.where(gaps > mean_gap + 2 * std_gap)[0]
            if len(large_gaps) > 0:
                for idx in large_gaps:
                    cutpoint = (values[idx] + values[idx + 1]) / 2
                    discontinuities.append({
                        "variable": col,
                        "cutpoint": float(cutpoint),
                        "gap_size": float(gaps[idx]),
                        "mean_gap": float(mean_gap)
                    })

    return {
        "has_discontinuities": len(discontinuities) > 0,
        "discontinuities": discontinuities
    }
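
# Illustrative sketch (synthetic running variable with one visible jump):
#   rv = pd.DataFrame({"x": list(range(10)) + list(range(50, 60))})
#   detect_discontinuities(rv)
#   -> {"has_discontinuities": True,
#       "discontinuities": [{"variable": "x", "cutpoint": 29.5, "gap_size": 41.0, ...}]}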

def assess_variable_relationships(df: pd.DataFrame, corr_matrix: pd.DataFrame) -> Dict[str, Any]:
    """
    Assess relationships between variables in the dataset.

    Args:
        df: DataFrame to analyze
        corr_matrix: Precomputed correlation matrix for numeric columns

    Returns:
        Dict with information about variable relationships:
        - strongly_correlated_pairs: Pairs of strongly correlated variables
        - potential_confounders: Variables that might be confounders
    """
    result = {"strongly_correlated_pairs": [], "potential_confounders": []}
    numeric_cols = corr_matrix.columns.tolist()
    if len(numeric_cols) < 2:
        return result

    # Use the precomputed correlation matrix
    corr_matrix_abs = corr_matrix.abs()

    # Find strongly correlated variable pairs
    for i in range(len(numeric_cols)):
        for j in range(i + 1, len(numeric_cols)):
            if corr_matrix_abs.iloc[i, j] > 0.7:  # Correlation threshold
                result["strongly_correlated_pairs"].append({
                    "variables": [numeric_cols[i], numeric_cols[j]],
                    "correlation": float(corr_matrix.iloc[i, j])
                })

    # Identify potential confounders (variables correlated with multiple others)
    confounder_counts = {col: 0 for col in numeric_cols}
    for pair in result["strongly_correlated_pairs"]:
        confounder_counts[pair["variables"][0]] += 1
        confounder_counts[pair["variables"][1]] += 1

    # Variables correlated with multiple others are potential confounders
    for col, count in confounder_counts.items():
        if count >= 2:
            result["potential_confounders"].append({"variable": col, "num_correlations": count})

    return result
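

# Minimal usage sketch. The CSV path below is hypothetical and only illustrates
# the expected call pattern; pass an LLM client to enable the augmented
# identification paths.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    analysis = analyze_dataset("data/example.csv")
    print(json.dumps(analysis, indent=2, default=str))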