""" Method Executor Tool for the causal inference agent. Executes the selected causal inference method using its implementation function. """ import pandas as pd from typing import Dict, Any, Optional, List, Union from langchain.tools import tool import traceback # For error logging import logging # Add logging # Import the mapping and potentially preprocessing utils from auto_causal.methods import METHOD_MAPPING from auto_causal.methods.utils import preprocess_data # Assuming preprocess exists from auto_causal.components.state_manager import create_workflow_state_update from auto_causal.config import get_llm_client # IMPORT LLM Client Factory # Import shared models from central location from auto_causal.models import ( Variables, TemporalStructure, # Needed indirectly by DatasetAnalysis DatasetInfo, # Needed indirectly by DatasetAnalysis DatasetAnalysis, MethodExecutorInput ) # Add this module-level variable, typically near imports or at the top CURRENT_OUTPUT_LOG_FILE = None logger = logging.getLogger(__name__) @tool def method_executor_tool(inputs: MethodExecutorInput, original_query: Optional[str] = None) -> Dict[str, Any]: # Use Pydantic Input '''Execute the selected causal inference method function using structured input. Args: inputs: Pydantic model containing method, variables, dataset_path, dataset_analysis, and dataset_description. Returns: Dict with numerical results, context for next step, and workflow state. ''' # Access data from input model method = inputs.method variables_dict = inputs.variables.model_dump() dataset_path = inputs.dataset_path dataset_analysis_dict = inputs.dataset_analysis.model_dump() dataset_description_str = inputs.dataset_description validation_info = inputs.validation_info # Can be passed if needed logger.info(f"Executing method: {method}") try: # --- Get LLM Instance --- llm_instance = None try: llm_instance = get_llm_client() except Exception as llm_e: logger.warning(f"Could not get LLM client in method_executor_tool: {llm_e}. LLM-dependent features in method will be disabled.") # 1. Load Data if not dataset_path: raise ValueError("Dataset path is missing.") df = pd.read_csv(dataset_path) # 2. Extract Key Variables needed by estimate_func signature treatment = variables_dict.get("treatment_variable") outcome = variables_dict.get("outcome_variable") covariates = variables_dict.get("covariates", []) query_str = original_query if original_query is not None else inputs.original_query if not all([treatment, outcome]): raise ValueError("Treatment or Outcome variable not found in 'variables' dict.") # 3. Preprocess Data required_cols_for_method = [treatment, outcome] + covariates # Add method-specific required vars from the variables_dict if method == "instrumental_variable" and variables_dict.get("instrument_variable"): required_cols_for_method.append(variables_dict["instrument_variable"]) elif method == "regression_discontinuity" and variables_dict.get("running_variable"): required_cols_for_method.append(variables_dict["running_variable"]) missing_df_cols = [col for col in required_cols_for_method if col not in df.columns] if missing_df_cols: raise ValueError(f"Dataset at {dataset_path} is missing required columns for method '{method}': {missing_df_cols}") df_processed, updated_treatment, updated_outcome, updated_covariates, column_mappings = \ preprocess_data(df, treatment, outcome, covariates, verbose=False) # 4. Get the correct method execution function if method not in METHOD_MAPPING: raise ValueError(f"Method '{method}' not found in METHOD_MAPPING.") estimate_func = METHOD_MAPPING[method] # 5. Execute the method # Pass only necessary args from variables_dict as kwargs # (e.g., instrument_variable, running_variable, cutoff_value, etc.) # Avoid passing the entire variables_dict as estimate_func expects specific args kwargs_for_method = {} for key in ["instrument_variable", "time_variable", "group_variable", "running_variable", "cutoff_value"]: if key in variables_dict and variables_dict[key] is not None: kwargs_for_method[key] = variables_dict[key] # Add new fields from the Variables model (which is inputs.variables) if hasattr(inputs, 'variables'): # ensure variables object exists on inputs if inputs.variables.treatment_reference_level is not None: kwargs_for_method['treatment_reference_level'] = inputs.variables.treatment_reference_level if inputs.variables.interaction_term_suggested is not None: # boolean, so check for None to allow False kwargs_for_method['interaction_term_suggested'] = inputs.variables.interaction_term_suggested if inputs.variables.interaction_variable_candidate is not None: kwargs_for_method['interaction_variable_candidate'] = inputs.variables.interaction_variable_candidate # Add query if needed by llm_assist functions within the method kwargs_for_method['query'] = query_str kwargs_for_method['column_mappings'] = column_mappings results_dict = estimate_func( df=df_processed, treatment=updated_treatment, outcome=updated_outcome, covariates=updated_covariates, dataset_description=dataset_description_str, query_str=query_str, llm=llm_instance, **kwargs_for_method # Pass specific args needed by the method ) # 6. Prepare output logger.info(f"Method execution successful. Effect estimate: {results_dict.get('effect_estimate')}") # Add workflow state workflow_update = create_workflow_state_update( current_step="method_execution", step_completed_flag="method_executed", next_tool="explainer_tool", next_step_reason="Now we need to explain the results and their implications" ) # --- Prepare Output Dictionary --- # Structure required by explainer_tool: context + nested "results" final_output = { # Nested dictionary for numerical results and diagnostics "results": { # Core estimation results (extracted from results_dict) "effect_estimate": results_dict.get("effect_estimate"), "confidence_interval": results_dict.get("confidence_interval"), "standard_error": results_dict.get("standard_error"), "p_value": results_dict.get("p_value"), "method_used": results_dict.get("method_used"), "llm_assumption_check": results_dict.get("llm_assumption_check"), "raw_results": results_dict.get("raw_results"), # Diagnostics and Refutation results "diagnostics": results_dict.get("diagnostics"), "refutation_results": results_dict.get("refutation_results") }, # Top-level context to be passed along "variables": variables_dict, "dataset_analysis": dataset_analysis_dict, "dataset_description": dataset_description_str, "validation_info": validation_info, # Pass validation info "original_query": inputs.original_query, "column_mappings": column_mappings # Add column_mappings to the output # Workflow state will be added next } # Add workflow state to the final output final_output.update(workflow_update.get('workflow_state', {})) # --- Logging logic (moved from output_formatter.py) --- # Prepare a summary dict for logging summary_keys = {"query", "method_used", "causal_effect", "standard_error", "confidence_interval"} # Try to get these from the available context summary_dict = { "query": inputs.original_query if hasattr(inputs, 'original_query') else None, "method_used": results_dict.get("method_used"), "causal_effect": results_dict.get("effect_estimate"), "standard_error": results_dict.get("standard_error"), "confidence_interval": results_dict.get("confidence_interval") } print(f"summary_dict: {summary_dict}") print(f"CURRENT_OUTPUT_LOG_FILE: {CURRENT_OUTPUT_LOG_FILE}") if CURRENT_OUTPUT_LOG_FILE and summary_dict: try: import json log_entry = {"type": "analysis_result", "data": summary_dict} with open(CURRENT_OUTPUT_LOG_FILE, mode='a', encoding='utf-8') as log_file: log_file.write('\n' + json.dumps(log_entry) + '\n') except Exception as e: print(f"[ERROR] method_executor_tool.py: Failed to write analysis results to log file '{CURRENT_OUTPUT_LOG_FILE}': {e}") return final_output except Exception as e: error_message = f"Error executing method {method}: {str(e)}" logger.error(error_message, exc_info=True) # Return error state, include context if available workflow_update = create_workflow_state_update( current_step="method_execution", step_completed_flag=False, next_tool="explainer_tool", # Or error handler? next_step_reason=f"Failed during method execution: {error_message}" ) # Ensure error output still contains necessary context keys if possible error_result = {"error": error_message, "variables": variables_dict if 'variables_dict' in locals() else {}, "dataset_analysis": dataset_analysis_dict if 'dataset_analysis_dict' in locals() else {}, "dataset_description": dataset_description_str if 'dataset_description_str' in locals() else None, "original_query": inputs.original_query if hasattr(inputs, 'original_query') else None, "column_mappings": column_mappings if 'column_mappings' in locals() else {} # Also add to error output } error_result.update(workflow_update.get('workflow_state', {})) return error_result