"""
Difference in Means / Simple Linear Regression Estimator.
Estimates the Average Treatment Effect (ATE) by comparing the mean outcome
between the treated and control groups. This is equivalent to a simple OLS
regression of the outcome on the treatment indicator.
Assumes no confounding (e.g., suitable for RCT data).
"""
import pandas as pd
import statsmodels.api as sm
import numpy as np
import warnings
from typing import Dict, Any, Optional
import logging
from langchain.chat_models.base import BaseChatModel # For type hinting llm
from .diagnostics import run_dim_diagnostics
from .llm_assist import interpret_dim_results
logger = logging.getLogger(__name__)
def estimate_effect(
    df: pd.DataFrame,
    treatment: str,
    outcome: str,
    query: Optional[str] = None,  # For potential LLM use
    llm: Optional[BaseChatModel] = None,  # For potential LLM use
    **kwargs  # Captures any other arguments (e.g., covariates), which are ignored
) -> Dict[str, Any]:
    """
    Estimates the causal effect using Difference in Means (via OLS).

    Regresses the outcome on an intercept plus the binary treatment indicator;
    the treatment coefficient is the difference in group means. Any covariates
    passed through **kwargs are ignored (this estimator assumes no confounding,
    e.g., RCT data).

    Args:
        df: Input DataFrame.
        treatment: Name of the binary treatment variable column (should be 0 or 1).
        outcome: Name of the outcome variable column.
        query: Optional user query for context.
        llm: Optional Language Model instance.
        **kwargs: Additional keyword arguments (ignored).

    Returns:
        Dictionary containing estimation results:
            - 'effect_estimate': The difference in means (treatment coefficient).
            - 'p_value': The p-value associated with the difference.
            - 'confidence_interval': The 95% confidence interval for the difference.
            - 'standard_error': The standard error of the difference.
            - 'formula': The regression formula used.
            - 'model_summary': Summary object from statsmodels.
            - 'diagnostics': Basic group statistics.
            - 'interpretation': LLM interpretation.
            - 'method_used': Constant identifier for this estimator.

    Raises:
        ValueError: If required columns are missing, no rows survive the NaN
            drop, or the treatment column has only one distinct value (which
            would make the OLS design singular).
    """
    required_cols = [treatment, outcome]
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    # Surface (rather than silently swallow) ignored extras such as covariates.
    if kwargs:
        logger.debug(
            "Difference in Means ignores extra arguments: %s", sorted(kwargs)
        )

    # Validate treatment is binary (or close to it)
    treat_vals = df[treatment].dropna().unique()
    if not np.all(np.isin(treat_vals, [0, 1])):
        warnings.warn(f"Treatment column '{treatment}' contains values other than 0 and 1: {treat_vals}. Proceeding, but results may be unreliable.", UserWarning)
        # Optional: could raise ValueError here if strict binary is required

    # Prepare data for statsmodels (add constant, handle potential NaNs)
    df_analysis = df[required_cols].dropna()
    if df_analysis.empty:
        raise ValueError("No data remaining after dropping NaNs for required columns.")

    # With a single treatment group, the indicator is perfectly collinear with
    # the intercept: OLS would produce a singular design / NaN inference.
    # Fail fast with a clear message instead.
    if df_analysis[treatment].nunique() < 2:
        raise ValueError(
            f"Treatment column '{treatment}' has only one distinct value after "
            "dropping NaNs; both treated and control observations are required "
            "to compute a difference in means."
        )

    X = sm.add_constant(df_analysis[[treatment]])  # Add intercept
    y = df_analysis[outcome]
    formula = f"{outcome} ~ {treatment} + const"
    logger.info(f"Running Difference in Means regression: {formula}")
    try:
        results = sm.OLS(y, X).fit()
        # Treatment coefficient == difference in group means.
        effect_estimate = results.params[treatment]
        p_value = results.pvalues[treatment]
        conf_int = results.conf_int(alpha=0.05).loc[treatment].tolist()
        std_err = results.bse[treatment]

        # Run basic diagnostics (group means, stds, counts)
        diag_results = run_dim_diagnostics(df_analysis, treatment, outcome)
        # Get interpretation
        interpretation = interpret_dim_results(results, diag_results, treatment, llm=llm)

        return {
            'effect_estimate': effect_estimate,
            'p_value': p_value,
            'confidence_interval': conf_int,
            'standard_error': std_err,
            'formula': formula,
            'model_summary': results.summary(),
            'diagnostics': diag_results,
            'interpretation': interpretation,
            'method_used': 'Difference in Means (OLS)'
        }
    except Exception as e:
        logger.error(f"Difference in Means failed: {e}")
        raise