""" LLM assistance functions for Instrumental Variable (IV) analysis. This module provides functions for LLM-based assistance in instrumental variable analysis, including identifying potential instruments, validating IV assumptions, and interpreting results. """ from typing import List, Dict, Any, Optional import logging # Imported for type hinting from langchain.chat_models.base import BaseChatModel # Import shared LLM helpers from auto_causal.utils.llm_helpers import call_llm_with_json_output logger = logging.getLogger(__name__) def identify_instrument_variable( df_cols: List[str], query: str, llm: Optional[BaseChatModel] = None ) -> List[str]: """ Use LLM to identify potential instrumental variables from available columns. Args: df_cols: List of column names from the dataset query: User's causal query text llm: Optional LLM model instance Returns: List of column names identified as potential instruments """ if llm is None: logger.warning("No LLM provided for instrument identification") return [] prompt = f""" You are assisting with an instrumental variable analysis. Available columns in the dataset: {df_cols} User query: {query} Identify potential instrumental variable(s) from the available columns based on the query. The treatment and outcome should NOT be included as instruments. Return ONLY a valid JSON object with the following structure (no explanations or surrounding text): {{ "potential_instruments": ["column_name1", "column_name2", ...] }} """ response = call_llm_with_json_output(llm, prompt) if response and "potential_instruments" in response and isinstance(response["potential_instruments"], list): # Basic validation: ensure items are strings (column names) valid_instruments = [item for item in response["potential_instruments"] if isinstance(item, str)] if len(valid_instruments) != len(response["potential_instruments"]): logger.warning("LLM returned non-string items in potential_instruments list.") return valid_instruments logger.warning(f"Failed to get valid instrument recommendations from LLM. Response: {response}") return [] def validate_instrument_assumptions_qualitative( treatment: str, outcome: str, instrument: List[str], covariates: List[str], query: str, llm: Optional[BaseChatModel] = None ) -> Dict[str, str]: """ Use LLM to provide qualitative assessment of IV assumptions. Args: treatment: Treatment variable name outcome: Outcome variable name instrument: List of instrumental variable names covariates: List of covariate variable names query: User's causal query text llm: Optional LLM model instance Returns: Dictionary with qualitative assessments of exclusion and exogeneity assumptions """ default_fail = { "exclusion_assessment": "LLM Check Failed", "exogeneity_assessment": "LLM Check Failed" } if llm is None: return { "exclusion_assessment": "LLM Not Provided", "exogeneity_assessment": "LLM Not Provided" } prompt = f""" You are assisting with assessing the validity of instrumental variable assumptions. Treatment variable: {treatment} Outcome variable: {outcome} Instrumental variable(s): {instrument} Covariates: {covariates} User query: {query} Assess the core Instrumental Variable (IV) assumptions based *only* on the provided variable names and query context: 1. Exclusion restriction: Plausibility that the instrument(s) affect the outcome ONLY through the treatment. 2. Exogeneity (also called Independence): Plausibility that the instrument(s) are not correlated with unobserved confounders that also affect the outcome. 

def validate_instrument_assumptions_qualitative(
    treatment: str,
    outcome: str,
    instrument: List[str],
    covariates: List[str],
    query: str,
    llm: Optional[BaseChatModel] = None
) -> Dict[str, str]:
    """
    Use LLM to provide qualitative assessment of IV assumptions.

    Args:
        treatment: Treatment variable name
        outcome: Outcome variable name
        instrument: List of instrumental variable names
        covariates: List of covariate variable names
        query: User's causal query text
        llm: Optional LLM model instance

    Returns:
        Dictionary with qualitative assessments of exclusion and exogeneity assumptions
    """
    default_fail = {
        "exclusion_assessment": "LLM Check Failed",
        "exogeneity_assessment": "LLM Check Failed"
    }

    if llm is None:
        return {
            "exclusion_assessment": "LLM Not Provided",
            "exogeneity_assessment": "LLM Not Provided"
        }

    prompt = f"""
    You are assisting with assessing the validity of instrumental variable assumptions.

    Treatment variable: {treatment}
    Outcome variable: {outcome}
    Instrumental variable(s): {instrument}
    Covariates: {covariates}

    User query: {query}

    Assess the core Instrumental Variable (IV) assumptions based *only* on the provided variable names and query context:
    1. Exclusion restriction: Plausibility that the instrument(s) affect the outcome ONLY through the treatment.
    2. Exogeneity (also called Independence): Plausibility that the instrument(s) are not correlated with unobserved confounders that also affect the outcome.

    Provide a brief, qualitative assessment (e.g., 'Plausible', 'Unlikely', 'Requires Domain Knowledge', 'Potentially Violated').

    Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
    {{
        "exclusion_assessment": "",
        "exogeneity_assessment": ""
    }}
    """

    response = call_llm_with_json_output(llm, prompt)

    if response and isinstance(response, dict) and \
       "exclusion_assessment" in response and isinstance(response["exclusion_assessment"], str) and \
       "exogeneity_assessment" in response and isinstance(response["exogeneity_assessment"], str):
        return response

    logger.warning(f"Failed to get valid assumption assessment from LLM. Response: {response}")
    return default_fail
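
# Illustrative example only: a reply that
# validate_instrument_assumptions_qualitative() would return unchanged. The
# wording is hypothetical; the LLM chooses the actual phrasing.
#   {"exclusion_assessment": "Plausible",
#    "exogeneity_assessment": "Requires Domain Knowledge"}
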

def interpret_iv_results(
    results: Dict[str, Any],
    diagnostics: Dict[str, Any],
    llm: Optional[BaseChatModel] = None
) -> str:
    """
    Use LLM to interpret IV results in natural language.

    Args:
        results: Dictionary of estimation results (e.g., effect_estimate, p_value, confidence_interval)
        diagnostics: Dictionary of diagnostic test results (e.g., first_stage_f_statistic, overid_test)
        llm: Optional LLM model instance

    Returns:
        String containing natural language interpretation of results
    """
    if llm is None:
        return "LLM was not available to provide interpretation. Please review the numeric results manually."

    # Construct a concise summary of inputs for the prompt
    results_summary = {}

    effect = results.get('effect_estimate')
    if effect is not None:
        try:
            results_summary['Effect Estimate'] = f"{float(effect):.3f}"
        except (ValueError, TypeError):
            results_summary['Effect Estimate'] = 'N/A (Invalid Format)'
    else:
        results_summary['Effect Estimate'] = 'N/A'

    p_value = results.get('p_value')
    if p_value is not None:
        try:
            results_summary['P-value'] = f"{float(p_value):.3f}"
        except (ValueError, TypeError):
            results_summary['P-value'] = 'N/A (Invalid Format)'
    else:
        results_summary['P-value'] = 'N/A'

    ci = results.get('confidence_interval')
    if ci is not None and isinstance(ci, (list, tuple)) and len(ci) == 2:
        try:
            results_summary['Confidence Interval'] = f"[{float(ci[0]):.3f}, {float(ci[1]):.3f}]"
        except (ValueError, TypeError):
            results_summary['Confidence Interval'] = 'N/A (Invalid Format)'
    else:
        # Handle cases where CI is None or not a 2-element list/tuple
        results_summary['Confidence Interval'] = str(ci) if ci is not None else 'N/A'

    if 'treatment_variable' in results:
        results_summary['Treatment'] = results['treatment_variable']
    if 'outcome_variable' in results:
        results_summary['Outcome'] = results['outcome_variable']

    diagnostics_summary = {}

    f_stat = diagnostics.get('first_stage_f_statistic')
    if f_stat is not None:
        try:
            diagnostics_summary['First-Stage F-statistic'] = f"{float(f_stat):.2f}"
        except (ValueError, TypeError):
            diagnostics_summary['First-Stage F-statistic'] = 'N/A (Invalid Format)'
    else:
        diagnostics_summary['First-Stage F-statistic'] = 'N/A'

    if 'weak_instrument_test_status' in diagnostics:
        diagnostics_summary['Weak Instrument Test'] = diagnostics['weak_instrument_test_status']

    overid_p = diagnostics.get('overid_test_p_value')
    if overid_p is not None:
        try:
            diagnostics_summary['Overidentification Test P-value'] = f"{float(overid_p):.3f}"
            diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
        except (ValueError, TypeError):
            diagnostics_summary['Overidentification Test P-value'] = 'N/A (Invalid Format)'
            diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')
    else:
        # Explicitly state if not applicable or not available
        if diagnostics.get('overid_test_applicable') == False:
            diagnostics_summary['Overidentification Test'] = 'Not Applicable'
        else:
            diagnostics_summary['Overidentification Test P-value'] = 'N/A'
            diagnostics_summary['Overidentification Test Applicable'] = diagnostics.get('overid_test_applicable', 'N/A')

    prompt = f"""
    You are assisting with interpreting instrumental variable (IV) analysis results.

    Estimation results summary: {results_summary}

    Diagnostic test results summary: {diagnostics_summary}

    Explain these Instrumental Variable (IV) results in clear, concise language (2-4 sentences). Focus on:
    1. The estimated causal effect (magnitude, direction, statistical significance based on p-value < 0.05).
    2. The strength of the instrument(s) (based on F-statistic, typically > 10 indicates strength).
    3. Any implications from other diagnostic tests (e.g., overidentification test suggesting instrument validity issues if p < 0.05).

    Return ONLY a valid JSON object with the following structure (no explanations or surrounding text):
    {{
        "interpretation": ""
    }}
    """

    response = call_llm_with_json_output(llm, prompt)

    if response and isinstance(response, dict) and \
       "interpretation" in response and isinstance(response["interpretation"], str):
        return response["interpretation"]

    logger.warning(f"Failed to get valid interpretation from LLM. Response: {response}")
    return "LLM interpretation could not be generated. Please review the numeric results manually."
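
# --- Minimal usage sketch ---
# Illustrative only: shows how the three helpers might be wired together.
# The column names, query, variable roles, and the numeric results/diagnostics
# below are hypothetical placeholders. Any chat model implementing
# BaseChatModel could be passed as `llm`; with llm=None each function falls
# back to its non-LLM default behavior.
if __name__ == "__main__":
    example_cols = ["wage", "education", "distance_to_college", "age"]
    example_query = "What is the effect of education on wages?"

    # Returns [] and logs a warning when no LLM is supplied.
    instruments = identify_instrument_variable(example_cols, example_query, llm=None)

    assumption_checks = validate_instrument_assumptions_qualitative(
        treatment="education",
        outcome="wage",
        instrument=instruments,
        covariates=["age"],
        query=example_query,
        llm=None,
    )

    interpretation = interpret_iv_results(
        results={"effect_estimate": 0.08, "p_value": 0.03, "confidence_interval": [0.01, 0.15]},
        diagnostics={"first_stage_f_statistic": 15.2},
        llm=None,
    )

    print(instruments)
    print(assumption_checks)
    print(interpretation)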