""" Difference in Means / Simple Linear Regression Estimator. Estimates the Average Treatment Effect (ATE) by comparing the mean outcome between the treated and control groups. This is equivalent to a simple OLS regression of the outcome on the treatment indicator. Assumes no confounding (e.g., suitable for RCT data). """ import pandas as pd import statsmodels.api as sm import numpy as np import warnings from typing import Dict, Any, Optional import logging from langchain.chat_models.base import BaseChatModel # For type hinting llm from .diagnostics import run_dim_diagnostics from .llm_assist import interpret_dim_results logger = logging.getLogger(__name__) def estimate_effect( df: pd.DataFrame, treatment: str, outcome: str, query: Optional[str] = None, # For potential LLM use llm: Optional[BaseChatModel] = None, # For potential LLM use **kwargs # To capture any other potential arguments (e.g., covariates - which are ignored) ) -> Dict[str, Any]: """ Estimates the causal effect using Difference in Means (via OLS). Ignores any provided covariates. Args: df: Input DataFrame. treatment: Name of the binary treatment variable column (should be 0 or 1). outcome: Name of the outcome variable column. query: Optional user query for context. llm: Optional Language Model instance. **kwargs: Additional keyword arguments (ignored). Returns: Dictionary containing estimation results: - 'effect_estimate': The difference in means (treatment coefficient). - 'p_value': The p-value associated with the difference. - 'confidence_interval': The 95% confidence interval for the difference. - 'standard_error': The standard error of the difference. - 'formula': The regression formula used. - 'model_summary': Summary object from statsmodels. - 'diagnostics': Basic group statistics. - 'interpretation': LLM interpretation. """ required_cols = [treatment, outcome] missing_cols = [col for col in required_cols if col not in df.columns] if missing_cols: raise ValueError(f"Missing required columns: {missing_cols}") # Validate treatment is binary (or close to it) treat_vals = df[treatment].dropna().unique() if not np.all(np.isin(treat_vals, [0, 1])): warnings.warn(f"Treatment column '{treatment}' contains values other than 0 and 1: {treat_vals}. 
Proceeding, but results may be unreliable.", UserWarning) # Optional: could raise ValueError here if strict binary is required # Prepare data for statsmodels (add constant, handle potential NaNs) df_analysis = df[required_cols].dropna() if df_analysis.empty: raise ValueError("No data remaining after dropping NaNs for required columns.") X = df_analysis[[treatment]] X = sm.add_constant(X) # Add intercept y = df_analysis[outcome] formula = f"{outcome} ~ {treatment} + const" logger.info(f"Running Difference in Means regression: {formula}") try: model = sm.OLS(y, X) results = model.fit() effect_estimate = results.params[treatment] p_value = results.pvalues[treatment] conf_int = results.conf_int(alpha=0.05).loc[treatment].tolist() std_err = results.bse[treatment] # Run basic diagnostics (group means, stds, counts) diag_results = run_dim_diagnostics(df_analysis, treatment, outcome) # Get interpretation interpretation = interpret_dim_results(results, diag_results, treatment, llm=llm) return { 'effect_estimate': effect_estimate, 'p_value': p_value, 'confidence_interval': conf_int, 'standard_error': std_err, 'formula': formula, 'model_summary': results.summary(), 'diagnostics': diag_results, 'interpretation': interpretation, 'method_used': 'Difference in Means (OLS)' } except Exception as e: logger.error(f"Difference in Means failed: {e}") raise
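

# Illustrative sketch (not part of the estimator API): a minimal demonstration of
# the equivalence claimed in the module docstring, namely that the raw difference
# in group means equals the OLS coefficient on the treatment indicator. It uses
# only numpy / pandas / statsmodels imported above; the synthetic column names
# 'treatment' and 'outcome' are hypothetical. Because this module uses relative
# imports, run it as a module (e.g., python -m <package>.<this_module>).
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 1_000
    t = rng.integers(0, 2, size=n)           # randomized binary treatment
    y = 1.0 + 2.0 * t + rng.normal(size=n)   # true effect = 2.0
    demo = pd.DataFrame({"treatment": t, "outcome": y})

    # Difference in means, computed directly
    diff_means = (
        demo.loc[demo["treatment"] == 1, "outcome"].mean()
        - demo.loc[demo["treatment"] == 0, "outcome"].mean()
    )

    # Same quantity as the OLS coefficient on the treatment indicator
    exog = sm.add_constant(demo[["treatment"]])
    ols_coef = sm.OLS(demo["outcome"], exog).fit().params["treatment"]

    print(f"Difference in means: {diff_means:.4f}")
    print(f"OLS coefficient:     {ols_coef:.4f}")  # matches up to floating-point error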