"""
LLM assistance functions for Backdoor Adjustment analysis.
"""
from typing import List, Dict, Any, Optional
import logging
# Imported for type hinting
from langchain.chat_models.base import BaseChatModel
from statsmodels.regression.linear_model import RegressionResultsWrapper
# Import shared LLM helpers
from auto_causal.utils.llm_helpers import call_llm_with_json_output
logger = logging.getLogger(__name__)
def identify_backdoor_set(
    df_cols: List[str],
    treatment: str,
    outcome: str,
    query: Optional[str] = None,
    existing_covariates: Optional[List[str]] = None,  # Allow user to provide some
    llm: Optional[BaseChatModel] = None
) -> List[str]:
    """
    Use LLM to suggest a potential backdoor adjustment set (confounders).

    Tries to identify variables that affect both treatment and outcome.

    Args:
        df_cols: List of available column names in the dataset.
        treatment: Treatment variable name.
        outcome: Outcome variable name.
        query: User's causal query text (provides context).
        existing_covariates: Covariates already considered/provided by user.
        llm: Optional LLM model instance.

    Returns:
        List of suggested variable names for the backdoor adjustment set:
        user-provided covariates first, then validated LLM suggestions,
        de-duplicated. The treatment and outcome are never included.
    """
    # The treatment and outcome can never be part of their own adjustment set,
    # even if the caller accidentally listed them in existing_covariates.
    user_covariates = [c for c in (existing_covariates or [])
                       if c not in (treatment, outcome)]

    if llm is None:
        logger.warning("No LLM provided for backdoor set identification.")
        return user_covariates

    # Exclude treatment and outcome from potential confounders
    potential_confounders = [c for c in df_cols if c not in (treatment, outcome)]
    if not potential_confounders:
        return user_covariates

    prompt = f"""
You are assisting with identifying a backdoor adjustment set for causal inference.
The goal is to find observed variables that confound the relationship between the treatment and outcome.
Assume the causal effect of '{treatment}' on '{outcome}' is of interest.
User query context (optional): {query}
Available variables in the dataset (excluding treatment and outcome): {potential_confounders}
Variables already specified as covariates by user (if any): {existing_covariates}
Based *only* on the variable names and the query context, identify which of the available variables are likely to be common causes (confounders) of both '{treatment}' and '{outcome}'.
These variables should be included in the backdoor adjustment set.
Consider variables that likely occurred *before* or *at the same time as* the treatment.
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
{{
 "suggested_backdoor_set": ["confounder1", "confounder2", ...]
}}
Include variables from the user-provided list if they seem appropriate as confounders.
If no plausible confounders are identified among the available variables, return an empty list.
"""

    response = call_llm_with_json_output(llm, prompt)

    suggested_set: List[str] = []
    # Check the response is a dict BEFORE key lookup; `"key" in response` on a
    # string would silently do a substring test.
    if isinstance(response, dict) and isinstance(response.get("suggested_backdoor_set"), list):
        raw_suggestions = response["suggested_backdoor_set"]
        # Basic validation: keep string entries only.
        string_vars = [item for item in raw_suggestions if isinstance(item, str)]
        if len(string_vars) != len(raw_suggestions):
            logger.warning("LLM returned non-string items in suggested_backdoor_set list.")
        # Guard against hallucinated variable names: accept only columns that
        # actually exist in the dataset (and are eligible confounders).
        allowed = set(potential_confounders)
        suggested_set = [v for v in string_vars if v in allowed]
        dropped = [v for v in string_vars if v not in allowed]
        if dropped:
            logger.warning(f"Dropped LLM-suggested variables not present in dataset: {dropped}")
    else:
        logger.warning(f"Failed to get valid backdoor set recommendations from LLM. Response: {response}")

    # Combine with existing covariates, removing duplicates while preserving order.
    return list(dict.fromkeys(user_covariates + suggested_set))
def interpret_backdoor_results(
    results: RegressionResultsWrapper,
    diagnostics: Dict[str, Any],
    treatment_var: str,
    covariates: List[str],
    llm: Optional[BaseChatModel] = None
) -> str:
    """
    Use LLM to interpret Backdoor Adjustment results.

    Args:
        results: Fitted statsmodels OLS results object.
        diagnostics: Dictionary of diagnostic results.
        treatment_var: Name of the treatment variable.
        covariates: List of covariates used in the adjustment set.
        llm: Optional LLM model instance.

    Returns:
        String containing natural language interpretation, or a default
        message / error description when the LLM is unavailable or fails.
    """
    default_interpretation = "LLM interpretation not available for Backdoor Adjustment."
    if llm is None:
        logger.info("LLM not provided for Backdoor Adjustment interpretation.")
        return default_interpretation
    try:
        # --- Prepare summary for LLM ---
        results_summary = {}
        diag_details = diagnostics.get('details', {})

        # .get() keeps us safe if the treatment term is absent from the fit.
        effect = results.params.get(treatment_var)
        pval = results.pvalues.get(treatment_var)
        results_summary['Treatment Effect Estimate'] = f"{effect:.3f}" if isinstance(effect, (int, float)) else str(effect)
        results_summary['P-value'] = f"{pval:.3f}" if isinstance(pval, (int, float)) else str(pval)
        try:
            conf_int = results.conf_int().loc[treatment_var]
            results_summary['95% Confidence Interval'] = f"[{conf_int[0]:.3f}, {conf_int[1]:.3f}]"
        except KeyError:
            # Treatment row missing from the confidence-interval table.
            results_summary['95% Confidence Interval'] = "Not Found"
        except Exception as ci_e:
            results_summary['95% Confidence Interval'] = f"Error ({ci_e})"
        results_summary['Adjustment Set (Covariates Used)'] = covariates

        # Reuse diag_details fetched above instead of re-reading diagnostics twice.
        r_squared = diag_details.get('r_squared')
        results_summary['Model R-squared'] = f"{r_squared:.3f}" if isinstance(r_squared, (int, float)) else "N/A"

        diag_summary = {}
        if diagnostics.get("status") == "Success":
            diag_summary['Residuals Normality Status'] = diag_details.get('residuals_normality_status', 'N/A')
            diag_summary['Homoscedasticity Status'] = diag_details.get('homoscedasticity_status', 'N/A')
            diag_summary['Multicollinearity Status'] = diag_details.get('multicollinearity_status', 'N/A')
        else:
            diag_summary['Status'] = diagnostics.get("status", "Unknown")

        # --- Construct Prompt ---
        prompt = f"""
You are assisting with interpreting Backdoor Adjustment (Regression) results.
The key assumption is that the specified adjustment set (covariates) blocks all confounding paths between the treatment ('{treatment_var}') and outcome.
Results Summary:
{results_summary}
Diagnostics Summary (OLS model checks):
{diag_summary}
Explain these results in 2-4 concise sentences. Focus on:
1. The estimated average treatment effect after adjusting for the specified covariates (magnitude, direction, statistical significance based on p-value < 0.05).
2. **Crucially, mention that this estimate relies heavily on the assumption that the included covariates ('{str(covariates)[:100]}...') are sufficient to control for confounding (i.e., satisfy the backdoor criterion).**
3. Briefly mention any major OLS diagnostic issues noted (e.g., non-normal residuals, heteroscedasticity, high multicollinearity).
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
{{
 "interpretation": "<your concise interpretation text>"
}}
"""

        # --- Call LLM ---
        response = call_llm_with_json_output(llm, prompt)

        # --- Process Response ---
        if response and isinstance(response, dict) and \
           "interpretation" in response and isinstance(response["interpretation"], str):
            return response["interpretation"]
        else:
            logger.warning(f"Failed to get valid interpretation from LLM for Backdoor Adj. Response: {response}")
            return default_interpretation
    except Exception as e:
        logger.error(f"Error during LLM interpretation for Backdoor Adj: {e}")
        return f"Error generating interpretation: {e}"