Spaces:
Running
Running
""" | |
LLM assistance functions for Instrumental Variable (IV) analysis.

This module provides LLM-based helpers for instrumental variable analysis,
including suggesting potential instruments, qualitatively assessing the IV
assumptions (exclusion restriction and exogeneity), and interpreting results.
""" | |
from typing import List, Dict, Any, Optional | |
import logging | |
# Imported for type hinting | |
from langchain.chat_models.base import BaseChatModel | |
# Import shared LLM helpers | |
from auto_causal.utils.llm_helpers import call_llm_with_json_output | |
# Module-level logger; the helpers below emit warnings through it when no LLM
# is supplied or when the LLM response cannot be used.
logger = logging.getLogger(__name__)
def identify_instrument_variable(
    df_cols: List[str],
    query: str,
    llm: Optional[BaseChatModel] = None
) -> List[str]:
    """
    Use an LLM to suggest potential instrumental variables from the available columns.

    Args:
        df_cols: List of column names from the dataset.
        query: User's causal query text.
        llm: Optional LLM model instance. If None, no suggestion is attempted.

    Returns:
        Column names suggested as potential instruments. Only names that appear in
        ``df_cols`` are kept, in the order the LLM listed them, with duplicates
        removed. An empty list is returned when no LLM is given or its response
        is unusable.
    """
    if llm is None:
        logger.warning("No LLM provided for instrument identification")
        return []

    prompt = f"""
You are assisting with an instrumental variable analysis.
Available columns in the dataset: {df_cols}
User query: {query}
Identify potential instrumental variable(s) from the available columns based on the query.
The treatment and outcome should NOT be included as instruments.
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
{{
"potential_instruments": ["column_name1", "column_name2", ...]
}}
"""
    response = call_llm_with_json_output(llm, prompt)

    # Check the type first (as the sibling helpers do): a list or str response
    # would pass a bare `in` test and then fail when indexed with a string key.
    suggested = response.get("potential_instruments") if isinstance(response, dict) else None
    if not isinstance(suggested, list):
        logger.warning("Failed to get valid instrument recommendations from LLM. Response: %s", response)
        return []

    # Basic validation: ensure items are strings (column names)
    names = [item for item in suggested if isinstance(item, str)]
    if len(names) != len(suggested):
        logger.warning("LLM returned non-string items in potential_instruments list.")

    # The LLM can invent column names; keep only real columns so callers never
    # receive an instrument that cannot be looked up in the dataset.
    known_columns = set(df_cols)
    unknown = [name for name in names if name not in known_columns]
    if unknown:
        logger.warning("LLM suggested instruments not present in the dataset columns, ignoring: %s", unknown)

    # dict.fromkeys de-duplicates while preserving the LLM's ordering.
    return list(dict.fromkeys(name for name in names if name in known_columns))
def validate_instrument_assumptions_qualitative(
    treatment: str,
    outcome: str,
    instrument: List[str],
    covariates: List[str],
    query: str,
    llm: Optional[BaseChatModel] = None
) -> Dict[str, str]:
    """
    Use an LLM to give a qualitative assessment of the IV assumptions.

    Args:
        treatment: Treatment variable name.
        outcome: Outcome variable name.
        instrument: List of instrumental variable names.
        covariates: List of covariate variable names.
        query: User's causal query text.
        llm: Optional LLM model instance.

    Returns:
        A dict with exactly two string entries, "exclusion_assessment" and
        "exogeneity_assessment". Both are "LLM Not Provided" when ``llm`` is None,
        and "LLM Check Failed" when the LLM response lacks either string field.
    """
    default_fail = {
        "exclusion_assessment": "LLM Check Failed",
        "exogeneity_assessment": "LLM Check Failed"
    }
    if llm is None:
        return {
            "exclusion_assessment": "LLM Not Provided",
            "exogeneity_assessment": "LLM Not Provided"
        }

    prompt = f"""
You are assisting with assessing the validity of instrumental variable assumptions.
Treatment variable: {treatment}
Outcome variable: {outcome}
Instrumental variable(s): {instrument}
Covariates: {covariates}
User query: {query}
Assess the core Instrumental Variable (IV) assumptions based *only* on the provided variable names and query context:
1. Exclusion restriction: Plausibility that the instrument(s) affect the outcome ONLY through the treatment.
2. Exogeneity (also called Independence): Plausibility that the instrument(s) are not correlated with unobserved confounders that also affect the outcome.
Provide a brief, qualitative assessment (e.g., 'Plausible', 'Unlikely', 'Requires Domain Knowledge', 'Potentially Violated').
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
{{
"exclusion_assessment": "<brief assessment of exclusion restriction>",
"exogeneity_assessment": "<brief assessment of exogeneity assumption>"
}}
"""
    response = call_llm_with_json_output(llm, prompt)

    if isinstance(response, dict):
        exclusion = response.get("exclusion_assessment")
        exogeneity = response.get("exogeneity_assessment")
        if isinstance(exclusion, str) and isinstance(exogeneity, str):
            # Return only the two validated keys: passing the raw LLM dict through
            # would leak any extra (possibly non-str) keys and break Dict[str, str].
            return {
                "exclusion_assessment": exclusion,
                "exogeneity_assessment": exogeneity,
            }

    logger.warning("Failed to get valid assumption assessment from LLM. Response: %s", response)
    return default_fail
def _format_number(value: Any, spec: str) -> str:
    """Format ``value`` as a float using ``spec`` for the prompt summary.

    Returns 'N/A' for None and 'N/A (Invalid Format)' when ``value`` cannot be
    converted to a float.
    """
    if value is None:
        return 'N/A'
    try:
        return format(float(value), spec)
    except (ValueError, TypeError, OverflowError):
        # OverflowError covers ints too large for a float (e.g. 10**400).
        return 'N/A (Invalid Format)'


def _format_interval(ci: Any) -> str:
    """Render a confidence interval as '[lo, hi]' (3 decimals) for the prompt summary.

    A 2-element list/tuple is formatted numerically. If either bound is not
    float-convertible, the whole interval becomes 'N/A (Invalid Format)'. Any
    other non-None value is passed through ``str()``, and None becomes 'N/A'.
    """
    if isinstance(ci, (list, tuple)) and len(ci) == 2:
        try:
            return f"[{float(ci[0]):.3f}, {float(ci[1]):.3f}]"
        except (ValueError, TypeError, OverflowError):
            return 'N/A (Invalid Format)'
    return str(ci) if ci is not None else 'N/A'


def _summarize_results(results: Dict[str, Any]) -> Dict[str, Any]:
    """Build the ordered estimation-results summary that is embedded in the prompt."""
    summary: Dict[str, Any] = {
        'Effect Estimate': _format_number(results.get('effect_estimate'), '.3f'),
        'P-value': _format_number(results.get('p_value'), '.3f'),
        'Confidence Interval': _format_interval(results.get('confidence_interval')),
    }
    # Variable names are copied verbatim, and only when the caller supplied them.
    if 'treatment_variable' in results:
        summary['Treatment'] = results['treatment_variable']
    if 'outcome_variable' in results:
        summary['Outcome'] = results['outcome_variable']
    return summary


def _summarize_diagnostics(diagnostics: Dict[str, Any]) -> Dict[str, Any]:
    """Build the ordered diagnostic-tests summary that is embedded in the prompt."""
    summary: Dict[str, Any] = {
        'First-Stage F-statistic': _format_number(diagnostics.get('first_stage_f_statistic'), '.2f'),
    }
    if 'weak_instrument_test_status' in diagnostics:
        summary['Weak Instrument Test'] = diagnostics['weak_instrument_test_status']

    overid_p = diagnostics.get('overid_test_p_value')
    applicable = diagnostics.get('overid_test_applicable', 'N/A')
    # Only when no p-value was reported can the test be marked "Not Applicable".
    # `==` (not `is`) on purpose so numpy.bool_(False) or 0 also count as False.
    if overid_p is None and applicable == False:  # noqa: E712
        summary['Overidentification Test'] = 'Not Applicable'
    else:
        summary['Overidentification Test P-value'] = _format_number(overid_p, '.3f')
        summary['Overidentification Test Applicable'] = applicable
    return summary


def interpret_iv_results(
    results: Dict[str, Any],
    diagnostics: Dict[str, Any],
    llm: Optional[BaseChatModel] = None
) -> str:
    """
    Use an LLM to interpret IV results in natural language.

    Args:
        results: Dictionary of estimation results. Keys read: 'effect_estimate',
            'p_value', 'confidence_interval', 'treatment_variable',
            'outcome_variable'. None is treated as an empty dict.
        diagnostics: Dictionary of diagnostic test results. Keys read:
            'first_stage_f_statistic', 'weak_instrument_test_status',
            'overid_test_p_value', 'overid_test_applicable'. None is treated as
            an empty dict.
        llm: Optional LLM model instance.

    Returns:
        The LLM's interpretation text, or a fixed fallback message when no LLM
        is provided or its response lacks a string "interpretation" field.
    """
    if llm is None:
        return "LLM was not available to provide interpretation. Please review the numeric results manually."

    # Tolerate callers passing None for either mapping instead of crashing on `.get`.
    results_summary = _summarize_results(results if results is not None else {})
    diagnostics_summary = _summarize_diagnostics(diagnostics if diagnostics is not None else {})

    prompt = f"""
You are assisting with interpreting instrumental variable (IV) analysis results.
Estimation results summary: {results_summary}
Diagnostic test results summary: {diagnostics_summary}
Explain these Instrumental Variable (IV) results in clear, concise language (2-4 sentences).
Focus on:
1. The estimated causal effect (magnitude, direction, statistical significance based on p-value < 0.05).
2. The strength of the instrument(s) (based on F-statistic, typically > 10 indicates strength).
3. Any implications from other diagnostic tests (e.g., overidentification test suggesting instrument validity issues if p < 0.05).
Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
{{
"interpretation": "<your concise interpretation text>"
}}
"""
    response = call_llm_with_json_output(llm, prompt)

    if isinstance(response, dict) and isinstance(response.get("interpretation"), str):
        return response["interpretation"]

    logger.warning("Failed to get valid interpretation from LLM. Response: %s", response)
    return "LLM interpretation could not be generated. Please review the numeric results manually."