""" | |
Utility functions for causal inference methods. | |
This module provides common utility functions used across | |
different causal inference methods. | |
""" | |
from typing import Dict, List, Set, Optional, Union, Any, Tuple

import numpy as np
import pandas as pd
import scipy.stats as stats
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.stats.outliers_influence import variance_inflation_factor
from sklearn.linear_model import LogisticRegression
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def check_binary_treatment(treatment_series: pd.Series) -> bool:
    """
    Check if treatment variable is binary.

    Args:
        treatment_series: Series containing treatment variable

    Returns:
        Boolean indicating if treatment is binary
    """
    unique_values = set(treatment_series.unique())
    # Remove NaN values if present
    unique_values = {x for x in unique_values if pd.notna(x)}

    # Check if there are exactly 2 unique values
    if len(unique_values) != 2:
        return False

    # Check if values are 0/1 or a similar binary encoding
    sorted_vals = sorted(unique_values)

    # Common binary encodings: 0/1, False/True, etc.
    binary_pairs = [
        (0, 1),
        (False, True),
        ("0", "1"),
        ("no", "yes"),
        ("false", "true")
    ]

    # Compare as lowercase strings when values are not numeric/boolean
    if not all(isinstance(v, (int, float, bool)) for v in sorted_vals):
        str_vals = [str(v).lower() for v in sorted_vals]
        for pair in binary_pairs:
            str_pair = [str(v).lower() for v in pair]
            if str_vals == str_pair:
                return True
        return False

    # For numeric values, check if they're 0/1 or can be easily mapped to 0/1
    if sorted_vals == [0, 1]:
        return True

    # Any other pair of numeric values can be easily mapped to 0/1
    return len(unique_values) == 2

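# Usage sketch (illustrative inputs, not part of the original module):
#
#     check_binary_treatment(pd.Series(["yes", "no", "yes"]))   # -> True
#     check_binary_treatment(pd.Series([1, 2, 3]))              # -> False
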
def calculate_standardized_differences(df: pd.DataFrame, treatment: str, covariates: List[str]) -> Dict[str, float]:
    """
    Calculate standardized differences between treated and control groups.

    Args:
        df: DataFrame containing the data
        treatment: Name of treatment variable
        covariates: List of covariate variable names

    Returns:
        Dictionary with standardized differences for each covariate
    """
    treated = df[df[treatment] == 1]
    control = df[df[treatment] == 0]

    std_diffs = {}
    for cov in covariates:
        # Skip if covariate has missing values
        if df[cov].isna().any():
            std_diffs[cov] = np.nan
            continue

        t_mean = treated[cov].mean()
        c_mean = control[cov].mean()
        t_var = treated[cov].var()
        c_var = control[cov].var()

        # Pooled standard deviation
        pooled_std = np.sqrt((t_var + c_var) / 2)

        # Avoid division by zero
        if pooled_std == 0:
            std_diffs[cov] = 0
        else:
            std_diffs[cov] = (t_mean - c_mean) / pooled_std

    return std_diffs

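# Usage sketch (illustrative toy data, not part of the original module):
#
#     toy = pd.DataFrame({"treated": [1, 1, 0, 0],
#                         "age": [30, 40, 25, 35],
#                         "income": [50.0, 60.0, 45.0, 55.0]})
#     calculate_standardized_differences(toy, "treated", ["age", "income"])
#     # -> {'age': 0.707..., 'income': 0.707...}
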
def check_overlap(df: pd.DataFrame, treatment: str, propensity_scores: np.ndarray,
                  threshold: float = 0.5) -> Dict[str, Any]:
    """
    Check overlap in propensity scores between treated and control groups.

    Args:
        df: DataFrame containing the data
        treatment: Name of treatment variable
        propensity_scores: Array of propensity scores
        threshold: Threshold for sufficient overlap (proportion of range)

    Returns:
        Dictionary with overlap statistics
    """
    df_copy = df.copy()
    df_copy['propensity_score'] = propensity_scores

    treated = df_copy[df_copy[treatment] == 1]['propensity_score']
    control = df_copy[df_copy[treatment] == 0]['propensity_score']

    min_treated = treated.min()
    max_treated = treated.max()
    min_control = control.min()
    max_control = control.max()

    overall_min = min(min_treated, min_control)
    overall_max = max(max_treated, max_control)

    # Range of overlap
    overlap_min = max(min_treated, min_control)
    overlap_max = min(max_treated, max_control)

    # Check if there is any overlap
    if overlap_max < overlap_min:
        overlap_proportion = 0
        sufficient_overlap = False
    else:
        # Calculate proportion of overall range that has overlap
        overall_range = overall_max - overall_min
        if overall_range == 0:
            # All values are the same
            overlap_proportion = 1.0
            sufficient_overlap = True
        else:
            overlap_proportion = (overlap_max - overlap_min) / overall_range
            sufficient_overlap = overlap_proportion >= threshold

    return {
        "treated_range": (float(min_treated), float(max_treated)),
        "control_range": (float(min_control), float(max_control)),
        "overlap_range": (float(overlap_min), float(overlap_max)),
        "overlap_proportion": float(overlap_proportion),
        "sufficient_overlap": sufficient_overlap
    }

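# Usage sketch (illustrative; in practice the propensity scores would come from a
# fitted model such as LogisticRegression rather than being hard-coded):
#
#     toy = pd.DataFrame({"treated": [0, 0, 1, 1]})
#     ps = np.array([0.2, 0.7, 0.3, 0.8])
#     check_overlap(toy, "treated", ps)
#     # -> overlap_range (0.3, 0.7), overlap_proportion ~0.67, sufficient_overlap True
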
def plot_propensity_overlap(df: pd.DataFrame, treatment: str, propensity_scores: np.ndarray,
                            save_path: Optional[str] = None) -> None:
    """
    Plot overlap in propensity scores.

    Args:
        df: DataFrame containing the data
        treatment: Name of treatment variable
        propensity_scores: Array of propensity scores
        save_path: Optional path to save the plot
    """
    df_copy = df.copy()
    df_copy['propensity_score'] = propensity_scores

    plt.figure(figsize=(10, 6))

    # Plot histograms
    sns.histplot(df_copy.loc[df_copy[treatment] == 1, 'propensity_score'],
                 bins=20, alpha=0.5, label='Treated', color='blue', kde=True)
    sns.histplot(df_copy.loc[df_copy[treatment] == 0, 'propensity_score'],
                 bins=20, alpha=0.5, label='Control', color='red', kde=True)

    plt.title('Propensity Score Distributions')
    plt.xlabel('Propensity Score')
    plt.ylabel('Count')
    plt.legend()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

def plot_covariate_balance(standardized_diffs: Dict[str, float], threshold: float = 0.1,
                           save_path: Optional[str] = None) -> None:
    """
    Plot standardized differences for covariates before and after matching.

    Args:
        standardized_diffs: Dictionary with standardized differences
        threshold: Threshold for acceptable balance
        save_path: Optional path to save the plot
    """
    # Convert to DataFrame for plotting
    df = pd.DataFrame({
        'Covariate': list(standardized_diffs.keys()),
        'Standardized Difference': list(standardized_diffs.values())
    })

    # Sort by absolute standardized difference
    df['Absolute Difference'] = np.abs(df['Standardized Difference'])
    df = df.sort_values('Absolute Difference', ascending=False)

    plt.figure(figsize=(12, len(standardized_diffs) * 0.4 + 2))

    # Plot horizontal bars, colored red when a covariate exceeds the balance threshold
    ax = sns.barplot(x='Standardized Difference', y='Covariate', data=df,
                     palette=['red' if abs(x) > threshold else 'green' for x in df['Standardized Difference']])

    # Add vertical lines for thresholds
    plt.axvline(x=threshold, color='red', linestyle='--', alpha=0.7)
    plt.axvline(x=-threshold, color='red', linestyle='--', alpha=0.7)
    plt.axvline(x=0, color='black', linestyle='-', alpha=0.7)

    plt.title('Covariate Balance: Standardized Differences')
    plt.xlabel('Standardized Difference')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

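# Usage sketch (illustrative; pairs with calculate_standardized_differences above,
# and "balance.png" is just a hypothetical output path):
#
#     diffs = calculate_standardized_differences(toy, "treated", ["age", "income"])
#     plot_covariate_balance(diffs, threshold=0.1, save_path="balance.png")
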
def check_temporal_structure(df: pd.DataFrame) -> Dict[str, Any]:
    """
    Check if dataset has temporal structure.

    Args:
        df: DataFrame to check

    Returns:
        Dictionary with temporal structure information
    """
    # Check for date/time columns
    date_cols = []
    for col in df.columns:
        # Check if column has a date-related term in its name
        if any(date_term in col.lower() for date_term in ['date', 'time', 'year', 'month', 'day', 'period']):
            date_cols.append(col)
            continue
        # Otherwise, check if the column can be converted to datetime
        if df[col].dtype == 'object':
            try:
                pd.to_datetime(df[col], errors='raise')
                date_cols.append(col)
            except (ValueError, TypeError):
                pass

    # Check for panel structure - look for ID columns
    id_cols = []
    for col in df.columns:
        # Check if column has an ID-related term in its name
        if any(id_term in col.lower() for id_term in ['id', 'identifier', 'key', 'code']):
            unique_count = df[col].nunique()
            # If column has multiple values but fewer than 10% of rows, likely an ID
            if 1 < unique_count < len(df) * 0.1:
                id_cols.append(col)

    # Check if there are multiple observations per unit
    is_panel = False
    panel_units = None
    if id_cols and date_cols:
        # For each ID column, check if there are multiple time periods
        for id_col in id_cols:
            obs_per_id = df.groupby(id_col).size()
            if (obs_per_id > 1).any():
                is_panel = True
                panel_units = id_col
                break

    return {
        "has_temporal_structure": len(date_cols) > 0,
        "temporal_columns": date_cols,
        "potential_id_columns": id_cols,
        "is_panel_data": is_panel,
        "panel_units": panel_units
    }

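# Usage sketch (illustrative long panel; column names are made up):
#
#     panel = pd.DataFrame({"firm_id": np.repeat([1, 2, 3], 12),
#                           "year": np.tile(np.arange(2010, 2022), 3),
#                           "sales": np.random.rand(36)})
#     check_temporal_structure(panel)
#     # -> {"has_temporal_structure": True, "temporal_columns": ["year"],
#     #     "potential_id_columns": ["firm_id"], "is_panel_data": True, "panel_units": "firm_id"}
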
def check_for_discontinuities(df: pd.DataFrame, outcome: str,
                              threshold_zscore: float = 3.0) -> Dict[str, Any]:
    """
    Check for potential discontinuities in continuous variables.

    Args:
        df: DataFrame to check
        outcome: Name of outcome variable
        threshold_zscore: Z-score threshold for detecting discontinuities

    Returns:
        Dictionary with discontinuity information
    """
    potential_running_vars = []

    # Check only numeric columns that aren't the outcome
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    numeric_cols = [col for col in numeric_cols if col != outcome]

    for col in numeric_cols:
        # Skip if too many unique values (unlikely to be a running variable)
        if df[col].nunique() > 100:
            continue

        # Sort values and calculate differences
        sorted_vals = np.sort(df[col].unique())
        if len(sorted_vals) <= 1:
            continue

        diffs = np.diff(sorted_vals)
        mean_diff = np.mean(diffs)
        std_diff = np.std(diffs)

        # Skip if all differences are the same
        if std_diff == 0:
            continue

        # Calculate z-scores of differences
        zscores = (diffs - mean_diff) / std_diff

        # Check if any z-score exceeds threshold
        if np.any(np.abs(zscores) > threshold_zscore):
            # Potential discontinuity found
            max_idx = np.argmax(np.abs(zscores))
            threshold = (sorted_vals[max_idx] + sorted_vals[max_idx + 1]) / 2

            # Check if outcome means differ across threshold
            below_mean = df[df[col] < threshold][outcome].mean()
            above_mean = df[df[col] >= threshold][outcome].mean()

            # Only include if outcome means differ substantially
            if abs(above_mean - below_mean) > 0.1 * df[outcome].std():
                potential_running_vars.append({
                    "variable": col,
                    "threshold": float(threshold),
                    "z_score": float(zscores[max_idx]),
                    "outcome_diff": float(above_mean - below_mean)
                })

    return {
        "has_discontinuities": len(potential_running_vars) > 0,
        "potential_running_variables": potential_running_vars
    }

def find_potential_instruments(df: pd.DataFrame, treatment: str, outcome: str,
                               correlation_threshold: float = 0.3) -> Dict[str, Any]:
    """
    Find potential instrumental variables.

    Args:
        df: DataFrame to check
        treatment: Name of treatment variable
        outcome: Name of outcome variable
        correlation_threshold: Threshold for correlation with treatment

    Returns:
        Dictionary with potential instruments information
    """
    # Get numeric columns that aren't treatment or outcome
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    potential_ivs = [col for col in numeric_cols if col != treatment and col != outcome]

    iv_results = []
    for col in potential_ivs:
        # Skip if column has too many missing values
        if df[col].isna().mean() > 0.1:
            continue

        # Check correlation with treatment
        corr_treatment = df[[col, treatment]].corr().iloc[0, 1]
        # Check correlation with outcome
        corr_outcome = df[[col, outcome]].corr().iloc[0, 1]

        # A potential IV should be correlated with treatment but not directly with the outcome
        if abs(corr_treatment) > correlation_threshold and abs(corr_outcome) < correlation_threshold / 2:
            iv_results.append({
                "variable": col,
                "correlation_with_treatment": float(corr_treatment),
                "correlation_with_outcome": float(corr_outcome),
                "strength": "Strong" if abs(corr_treatment) > 0.5 else "Moderate"
            })

    return {
        "has_potential_instruments": len(iv_results) > 0,
        "potential_instruments": iv_results
    }

def test_parallel_trends(df: pd.DataFrame, treatment: str, outcome: str,
                         time_var: str, unit_var: str) -> Dict[str, Any]:
    """
    Test for parallel trends assumption in difference-in-differences.

    Args:
        df: DataFrame to check
        treatment: Name of treatment variable
        outcome: Name of outcome variable
        time_var: Name of time variable
        unit_var: Name of unit variable

    Returns:
        Dictionary with parallel trends test results
    """
    # Ensure time_var is properly formatted
    df = df.copy()
    if df[time_var].dtype != 'int64':
        # Try to convert to datetime and then to integer-coded monthly periods
        try:
            df[time_var] = pd.to_datetime(df[time_var])
            # Get unique periods and map to integers
            periods = df[time_var].dt.to_period('M').unique()
            period_dict = {p: i for i, p in enumerate(sorted(periods))}
            df['time_period'] = df[time_var].dt.to_period('M').map(period_dict)
            time_var = 'time_period'
        except (ValueError, TypeError):
            # If conversion fails, map unique values to integers instead
            unique_times = df[time_var].unique()
            time_dict = {t: i for i, t in enumerate(sorted(unique_times))}
            df['time_period'] = df[time_var].map(time_dict)
            time_var = 'time_period'

    # Identify treatment and control groups.
    # The treatment indicator is taken as 0 or 1 per unit (not time-varying).
    unit_treatment = df.groupby(unit_var)[treatment].max()
    treatment_units = unit_treatment[unit_treatment == 1].index
    control_units = unit_treatment[unit_treatment == 0].index

    # Find time of treatment implementation
    if len(treatment_units) > 0:
        treatment_time = df[df[unit_var].isin(treatment_units) & (df[treatment] == 1)][time_var].min()
    else:
        # No treated units found
        return {
            "parallel_trends": False,
            "reason": "No treated units found",
            "pre_trend_correlation": None,
            "pre_trend_p_value": None
        }

    # Select pre-treatment periods
    pre_treatment = df[df[time_var] < treatment_time]

    # Calculate average outcome by time and group
    treated_means = pre_treatment[pre_treatment[unit_var].isin(treatment_units)].groupby(time_var)[outcome].mean()
    control_means = pre_treatment[pre_treatment[unit_var].isin(control_units)].groupby(time_var)[outcome].mean()

    # Need enough pre-treatment periods to test
    if len(treated_means) < 3:
        return {
            "parallel_trends": None,
            "reason": "Insufficient pre-treatment periods",
            "pre_trend_correlation": None,
            "pre_trend_p_value": None
        }

    # Align indices and calculate trends
    common_periods = sorted(set(treated_means.index).intersection(set(control_means.index)))
    if len(common_periods) < 3:
        return {
            "parallel_trends": None,
            "reason": "Insufficient common pre-treatment periods",
            "pre_trend_correlation": None,
            "pre_trend_p_value": None
        }

    treated_trends = np.diff(treated_means.loc[common_periods])
    control_trends = np.diff(control_means.loc[common_periods])

    # Calculate correlation between the two groups' pre-treatment trends
    correlation, p_value = stats.pearsonr(treated_trends, control_trends)

    # Treat trends as parallel if the correlation is high and statistically significant
    parallel_trends = correlation > 0.7 and p_value < 0.05

    return {
        "parallel_trends": parallel_trends,
        "reason": "Trends are parallel" if parallel_trends else "Trends are not parallel",
        "pre_trend_correlation": float(correlation),
        "pre_trend_p_value": float(p_value)
    }

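# Usage sketch (illustrative column names; assumes a unit-by-period panel `panel_df`
# where the treatment column is 1 for treated units in post-treatment periods):
#
#     result = test_parallel_trends(panel_df, treatment="post_treatment", outcome="sales",
#                                   time_var="year", unit_var="firm_id")
#     result["parallel_trends"], result["pre_trend_correlation"]
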
def preprocess_data(df: pd.DataFrame, treatment_var: str, outcome_var: str,
                    covariates: List[str], verbose: bool = True
                    ) -> Tuple[pd.DataFrame, str, str, List[str], Dict[str, Any]]:
    """
    Preprocess the dataset to handle missing values and encode categorical variables.

    Args:
        df (pd.DataFrame): The dataset
        treatment_var (str): The treatment variable name
        outcome_var (str): The outcome variable name
        covariates (list): List of covariate variable names
        verbose (bool): Whether to print verbose output

    Returns:
        Tuple[pd.DataFrame, str, str, List[str], Dict[str, Any]]:
            Preprocessed dataset, updated treatment var name,
            updated outcome var name, updated covariates list,
            and column mappings.
    """
    df_processed = df.copy()
    column_mappings: Dict[str, Any] = {}
    # Store original dtypes for mapping
    original_dtypes = {col: str(df_processed[col].dtype) for col in df_processed.columns}

    # Report missing values
    all_vars = [treatment_var, outcome_var] + covariates
    missing_data = df_processed[all_vars].isnull().sum()
    total_missing = missing_data.sum()

    if total_missing > 0:
        if verbose:
            logger.info(f"Dataset contains {total_missing} missing values:")
        for col in missing_data[missing_data > 0].index:
            percent = (missing_data[col] / len(df_processed)) * 100
            if verbose:
                logger.info(f"  - {col}: {missing_data[col]} missing values ({percent:.2f}%)")
    else:
        if verbose:
            logger.info("No missing values found in relevant columns.")
        # return df_processed  # No preprocessing needed if no missing values

    # Handle missing values in treatment variable
    if df_processed[treatment_var].isnull().sum() > 0:
        if verbose:
            logger.info(f"Filling missing values in treatment variable '{treatment_var}' with mode")
        # For treatment, use mode (most common value)
        mode_val = df_processed[treatment_var].mode()[0] if not df_processed[treatment_var].mode().empty else 0
        df_processed[treatment_var] = df_processed[treatment_var].fillna(mode_val)

    # Handle missing values in outcome variable
    if df_processed[outcome_var].isnull().sum() > 0:
        if verbose:
            logger.info(f"Filling missing values in outcome variable '{outcome_var}' with mean")
        # For outcome, use mean
        mean_val = df_processed[outcome_var].mean()
        df_processed[outcome_var] = df_processed[outcome_var].fillna(mean_val)

    # Handle missing values in covariates
    for col in covariates:
        if df_processed[col].isnull().sum() > 0:
            if pd.api.types.is_numeric_dtype(df_processed[col]):
                # For numeric covariates, use mean
                if verbose:
                    logger.info(f"Filling missing values in numeric covariate '{col}' with mean")
                mean_val = df_processed[col].mean()
                df_processed[col] = df_processed[col].fillna(mean_val)
            elif pd.api.types.is_categorical_dtype(df_processed[col]) or df_processed[col].dtype == 'object':
                # For categorical covariates, use mode
                mode_val = df_processed[col].mode()[0] if not df_processed[col].mode().empty else "Missing"
                if verbose:
                    logger.info(f"Filling missing values in categorical covariate '{col}' with mode ('{mode_val}')")
                df_processed[col] = df_processed[col].fillna(mode_val)
            else:
                # For other types, create a "Missing" category
                if verbose:
                    logger.info(f"Filling missing values in covariate '{col}' of type {df_processed[col].dtype} with 'Missing' category")
                # Ensure the column is of object type before filling with a string
                if df_processed[col].dtype != 'object':
                    try:
                        df_processed[col] = df_processed[col].astype(object)
                    except Exception as e:
                        logger.warning(f"Could not convert column {col} to object type to fill NAs: {e}. Skipping fill.")
                        continue
                df_processed[col] = df_processed[col].fillna("Missing")
    # --- Categorical Encoding ---
    updated_treatment_var = treatment_var
    updated_outcome_var = outcome_var

    # Helper function for label encoding binary categoricals
    def label_encode_binary(series: pd.Series, var_name: str) -> Tuple[pd.Series, Dict[int, Any]]:
        uniques = series.dropna().unique()
        mapping = {}
        if len(uniques) == 2:
            # Booleans map directly: False -> 0, True -> 1
            if series.dtype == 'bool':
                mapping = {0: False, 1: True}
                return series.astype(int), mapping
            # For non-boolean values, sort (as strings) to make the 0/1 assignment deterministic.
            # A more robust approach might involve explicit mapping rules or user input.
            sorted_uniques = sorted(uniques, key=lambda x: str(x))
            map_dict = {sorted_uniques[0]: 0, sorted_uniques[1]: 1}
            mapping = {v: k for k, v in map_dict.items()}  # Inverse map for column_mappings
            if verbose:
                logger.info(f"Label encoding binary variable '{var_name}': {map_dict}")
            return series.map(map_dict), mapping
        elif len(uniques) == 1:
            # Single unique value: treat as constant and encode as 0
            if verbose:
                logger.info(f"Binary variable '{var_name}' has only one unique value '{uniques[0]}'. Encoding as 0.")
            map_dict = {uniques[0]: 0}
            mapping = {0: uniques[0]}
            return series.map(map_dict), mapping
        return series, mapping  # No change if not binary
    # Encode Treatment Variable
    if df_processed[treatment_var].dtype == 'object' or df_processed[treatment_var].dtype == 'category' or df_processed[treatment_var].dtype == 'bool':
        df_processed[treatment_var], value_map = label_encode_binary(df_processed[treatment_var], treatment_var)
        if value_map:  # If encoding happened
            column_mappings[treatment_var] = {
                'original_dtype': original_dtypes[treatment_var],
                'transformed_as': 'label_encoded_binary',
                'new_column_name': treatment_var,  # Name doesn't change
                'value_map': value_map
            }
            if verbose:
                logger.info(f"Encoded treatment variable '{treatment_var}' to numeric.")

    # Encode Outcome Variable
    if df_processed[outcome_var].dtype == 'object' or df_processed[outcome_var].dtype == 'category' or df_processed[outcome_var].dtype == 'bool':
        df_processed[outcome_var], value_map = label_encode_binary(df_processed[outcome_var], outcome_var)
        if value_map:  # If encoding happened
            column_mappings[outcome_var] = {
                'original_dtype': original_dtypes[outcome_var],
                'transformed_as': 'label_encoded_binary',
                'new_column_name': outcome_var,  # Name doesn't change
                'value_map': value_map
            }
            if verbose:
                logger.info(f"Encoded outcome variable '{outcome_var}' to numeric.")
    # Encode Covariates (One-Hot Encoding for non-numeric)
    updated_covariates = []
    categorical_covariates_to_encode = []

    for cov in covariates:
        if cov not in df_processed.columns:
            # Covariate was dropped earlier or is handled elsewhere (e.g., an instrument)
            if verbose:
                logger.warning(f"Covariate '{cov}' not found in DataFrame columns after initial processing. Skipping encoding for it.")
            continue
        if df_processed[cov].dtype == 'object' or df_processed[cov].dtype == 'category' or pd.api.types.is_bool_dtype(df_processed[cov]):
            # Binary categoricals could also be label encoded, but for consistency
            # (and to handle multi-category variables) all categorical covariates
            # are one-hot encoded via pd.get_dummies.
            if len(df_processed[cov].dropna().unique()) > 1:
                # Only encode if more than 1 unique value
                categorical_covariates_to_encode.append(cov)
            else:
                # Only one unique value (or all NaNs, already handled): constant-like
                if verbose:
                    logger.info(f"Categorical covariate '{cov}' has <= 1 unique value after NA handling. Treating as constant-like, not one-hot encoding.")
                updated_covariates.append(cov)  # Keep as is
        else:
            # Already numeric
            updated_covariates.append(cov)

    if categorical_covariates_to_encode:
        if verbose:
            logger.info(f"One-hot encoding categorical covariates: {categorical_covariates_to_encode} using pd.get_dummies (drop_first=True)")
        # Store original columns before get_dummies to identify new ones
        original_df_columns = set(df_processed.columns)
        df_processed = pd.get_dummies(df_processed, columns=categorical_covariates_to_encode,
                                      prefix_sep='_', drop_first=True, dummy_na=False)  # dummy_na=False since NAs were handled above

        # Identify new columns created by get_dummies
        new_dummy_columns = list(set(df_processed.columns) - original_df_columns)
        updated_covariates.extend(new_dummy_columns)

        for original_cov_name in categorical_covariates_to_encode:
            # Find which dummy columns correspond to this original covariate
            related_dummies = [col for col in new_dummy_columns if col.startswith(original_cov_name + '_')]
            column_mappings[original_cov_name] = {
                'original_dtype': original_dtypes[original_cov_name],
                'transformed_as': 'one_hot_encoded',
                'encoded_columns': related_dummies,
                # 'dropped_category': can be inferred if needed, but not explicitly stored for simplicity here
            }
            if verbose:
                logger.info(f"  Original covariate '{original_cov_name}' resulted in dummy variables: {related_dummies}")

    if verbose:
        logger.info("Preprocessing complete.")
        if column_mappings:
            logger.info(f"Column mappings generated: {column_mappings}")
        else:
            logger.info("No column encodings were applied.")

    return df_processed, updated_treatment_var, updated_outcome_var, list(dict.fromkeys(updated_covariates)), column_mappings

def check_collinearity(df: pd.DataFrame, covariates: List[str], vif_threshold: float = 10.0) -> Optional[List[str]]:
    """Return covariates whose variance inflation factor exceeds ``vif_threshold``, or None if none do."""
    # Minimal VIF-based check; only numeric covariates with complete cases are considered.
    numeric = df[covariates].select_dtypes(include=[np.number]).dropna()
    if numeric.shape[0] == 0 or numeric.shape[1] < 2:
        return None
    vifs = [variance_inflation_factor(numeric.values, i) for i in range(numeric.shape[1])]
    collinear = [col for col, vif in zip(numeric.columns, vifs) if vif > vif_threshold]
    return collinear if collinear else None

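# Minimal smoke test on synthetic data (illustrative only; not part of the original
# module's public interface). Run as a script, e.g. `python utils.py` -- the file
# name is an assumption.
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    n = 500
    demo = pd.DataFrame({
        "age": rng.normal(50, 10, n),
        "income": rng.normal(60, 15, n),
    })
    # Treatment assignment depends on age, so the groups are imbalanced on purpose.
    logits = 0.05 * (demo["age"] - 50)
    demo["treated"] = (rng.random(n) < 1 / (1 + np.exp(-logits))).astype(int)
    demo["outcome"] = 2.0 * demo["treated"] + 0.1 * demo["age"] + rng.normal(0, 1, n)

    logger.info("Binary treatment: %s", check_binary_treatment(demo["treated"]))
    logger.info("Standardized differences: %s",
                calculate_standardized_differences(demo, "treated", ["age", "income"]))

    # Propensity scores from a simple logistic regression (uses the imported LogisticRegression).
    ps_model = LogisticRegression().fit(demo[["age", "income"]], demo["treated"])
    ps = ps_model.predict_proba(demo[["age", "income"]])[:, 1]
    logger.info("Overlap diagnostics: %s", check_overlap(demo, "treated", ps))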