"""
Method validator tool for causal inference methods.

This tool validates the selected causal inference method against
dataset characteristics and available variables.
"""

from typing import Dict, Any, Optional, List, Union
from langchain.tools import tool
import logging

from auto_causal.components.method_validator import validate_method
from auto_causal.components.state_manager import create_workflow_state_update
from auto_causal.components.decision_tree import rule_based_select_method

# Import shared models from central location
from auto_causal.models import (
    Variables,
    TemporalStructure,  # Needed indirectly by DatasetAnalysis
    DatasetInfo,  # Needed indirectly by DatasetAnalysis
    DatasetAnalysis,
    MethodInfo,
    MethodValidatorInput
)

logger = logging.getLogger(__name__)


def extract_properties_from_inputs(inputs: MethodValidatorInput) -> Dict[str, Any]:
    """
    Helper function to extract dataset properties from MethodValidatorInput
    for use with the decision tree.

    Values that the dumped models may carry as ``None`` (the nested
    ``temporal_structure``, ``covariates`` and ``is_rct``) are normalised to
    the same defaults that apply when the key is absent.

    Args:
        inputs: Validator input. Only ``inputs.variables`` and
            ``inputs.dataset_analysis`` are read (via ``model_dump()``).

    Returns:
        Flat dict of variable roles and dataset flags.
    """
    variables_dict = inputs.variables.model_dump()
    dataset_analysis_dict = inputs.dataset_analysis.model_dump()

    # `or {}`: model_dump() keeps the key with a None value when the nested
    # model is unset. A plain `.get(key, {})` would then return None, and the
    # chained `.get` would raise AttributeError.
    temporal_structure = dataset_analysis_dict.get("temporal_structure") or {}

    return {
        "treatment_variable": variables_dict.get("treatment_variable"),
        "outcome_variable": variables_dict.get("outcome_variable"),
        "instrument_variable": variables_dict.get("instrument_variable"),
        # `or []`: a None value gets the same empty-list default as a missing key.
        "covariates": variables_dict.get("covariates") or [],
        "time_variable": variables_dict.get("time_variable"),
        "running_variable": variables_dict.get("running_variable"),
        "treatment_variable_type": variables_dict.get("treatment_variable_type", "binary"),
        "has_temporal_structure": temporal_structure.get("has_temporal_structure", False),
        "frontdoor_criterion": variables_dict.get("frontdoor_criterion", False),
        "cutoff_value": variables_dict.get("cutoff_value"),
        # NOTE(review): the "covariate_overlap_result" field feeds the
        # "covariate_overlap_score" key. Confirm the field name on Variables.
        "covariate_overlap_score": variables_dict.get("covariate_overlap_result", 0),
        # `or False` matches how method_validator_tool normalises is_rct.
        "is_rct": variables_dict.get("is_rct") or False,
    }

# --- Removed local Pydantic definitions ---
# class Variables(BaseModel): ...
# class TemporalStructure(BaseModel): ...
# class DatasetInfo(BaseModel): ...
# class DatasetAnalysis(BaseModel): ...
# class MethodInfo(BaseModel): ...
# class MethodValidatorInput(BaseModel): ...


def _backtrack_to_alternative_method(
    inputs: MethodValidatorInput,
    original_method: Optional[str],
    fallback_method: Optional[str],
    validation_results: Dict[str, Any],
) -> Optional[str]:
    """
    Re-run the rule-based decision tree with the failed method excluded.

    Backtracking metadata is written into ``validation_results`` in place:
    ``backtrack_attempted`` always, plus either ``backtrack_method`` and
    ``excluded_methods`` (success) or ``backtrack_error`` (failure).

    Args:
        inputs: The tool's structured input.
        original_method: Method that failed validation; it is excluded from
            the new selection.
        fallback_method: Method to keep if backtracking raises.
        validation_results: Output of ``validate_method``; mutated in place.

    Returns:
        The method selected by the decision tree. If the selection has no
        ``selected_method`` key, ``original_method`` is returned. If
        anything raises, ``fallback_method`` is returned.
    """
    excluded_methods = [original_method]
    try:
        # Local import kept from the original code. NOTE(review): presumably
        # this avoids an import cycle or config side effects at module load.
        from auto_causal.config import get_llm_client

        try:
            llm_instance = get_llm_client()
        except Exception as e:
            # The decision tree is still run, just without an LLM.
            logger.warning(f"Failed to get LLM instance: {e}")
            llm_instance = None

        new_selection = rule_based_select_method(
            dataset_analysis=inputs.dataset_analysis.model_dump(),
            variables=inputs.variables.model_dump(),
            is_rct=inputs.variables.is_rct or False,
            llm=llm_instance,
            dataset_description=inputs.dataset_description,
            original_query=inputs.original_query,
            excluded_methods=excluded_methods,
        )
        new_method = new_selection.get("selected_method", original_method)
    except Exception as e:
        logger.error(f"Backtracking failed: {e}", exc_info=True)
        validation_results["backtrack_attempted"] = True
        validation_results["backtrack_error"] = str(e)
        # Keep the previously recommended method.
        return fallback_method

    logger.info(f"Backtracking selected new method: {new_method}")
    validation_results["backtrack_attempted"] = True
    validation_results["backtrack_method"] = new_method
    validation_results["excluded_methods"] = excluded_methods
    return new_method


# --- Tool Definition ---
@tool
def method_validator_tool(inputs: MethodValidatorInput) -> Dict[str, Any]:  # Use Pydantic Input
    """
    Validate the assumptions of the selected causal method using structured input.

    Args:
        inputs: Pydantic model containing method_info, dataset_analysis, variables, and dataset_description.

    Returns:
        Dictionary with validation results, context for next step, and workflow state.
    """
    # Behaviour notes (kept out of the docstring, which @tool sends as the
    # tool description):
    # - If the validate_method component fails or returns a non-dict, the
    #   result carries an "error" key, the input context, and the workflow
    #   update.
    # - If assumptions fail and concerns were reported, the decision tree is
    #   re-run with the failed method excluded. The chosen method is passed on
    #   under "method", and backtracking metadata is reported in
    #   "validation_info".
    logger.info(f"Running method_validator_tool for method: {inputs.method_info.selected_method}")

    # Plain-dict views of the inputs; validate_method works on dicts.
    method_info_dict = inputs.method_info.model_dump()
    dataset_analysis_dict = inputs.dataset_analysis.model_dump()
    variables_dict = inputs.variables.model_dump()
    dataset_description_str = inputs.dataset_description

    try:
        validation_results = validate_method(method_info_dict, dataset_analysis_dict, variables_dict)
        if not isinstance(validation_results, dict):
            # Raised here so the handler below builds the standard error payload.
            raise TypeError(f"validate_method component did not return a dict. Got: {type(validation_results)}")
    except Exception as e:
        logger.error(f"Error during validate_method execution: {e}", exc_info=True)
        # Same keyword set as the success-path call below. The update is also
        # merged the same way (its own top-level keys), so both paths return
        # one shape.
        workflow_update = create_workflow_state_update(
            current_step="method_validation",
            step_completed_flag=False,
            next_tool="error_handler_tool",
            next_step_reason=f"Method validation component failed: {e}",
            error=f"Component failed: {e}",
        )
        # Pass context even on error.
        return {
            "error": f"Method validation component failed: {e}",
            "variables": variables_dict,
            "dataset_analysis": dataset_analysis_dict,
            "dataset_description": dataset_description_str,
            "original_query": inputs.original_query,
            **workflow_update,
        }

    # Interpret the component output.
    assumptions_valid = validation_results.get("valid", False)
    failed_assumptions = validation_results.get("concerns", [])
    original_method = method_info_dict.get("selected_method")
    recommended_method = validation_results.get("recommended_method", original_method)

    # Backtrack only when validation failed with concrete concerns.
    if not assumptions_valid and failed_assumptions:
        logger.info(f"Method {original_method} failed validation due to: {failed_assumptions}")
        logger.info("Attempting to backtrack and select alternative method...")
        recommended_method = _backtrack_to_alternative_method(
            inputs, original_method, recommended_method, validation_results
        )

    # `or {}` guards against dataset_info being present but None.
    dataset_info = dataset_analysis_dict.get("dataset_info") or {}

    # Output for the next tool (method_executor).
    result = {
        # --- Data for Method Executor ---
        "method": recommended_method,  # Use recommended method going forward
        "variables": variables_dict,
        "dataset_path": dataset_info.get("file_path"),
        "dataset_analysis": dataset_analysis_dict,
        "dataset_description": dataset_description_str,
        "original_query": inputs.original_query,
        # --- Validation Results ---
        "validation_info": {
            "original_method": original_method,
            "recommended_method": recommended_method,
            "assumptions_valid": assumptions_valid,
            "failed_assumptions": failed_assumptions,
            "warnings": validation_results.get("warnings", []),
            "suggestions": validation_results.get("suggestions", []),
            # Backtracking metadata recorded on validation_results. It was
            # previously computed but never returned.
            **{
                key: validation_results[key]
                for key in ("backtrack_attempted", "backtrack_method", "excluded_methods", "backtrack_error")
                if key in validation_results
            },
        },
    }

    # Routing depends only on whether the ORIGINAL method's assumptions held.
    # NOTE(review): a method chosen by backtracking still routes to the error
    # handler. Confirm whether it should go to the executor instead.
    method_validated_flag = assumptions_valid
    next_tool_name = "method_executor_tool" if method_validated_flag else "error_handler_tool"
    next_reason = "Method assumptions checked. Proceeding to execution." if method_validated_flag else "Method assumptions failed validation."
    workflow_update = create_workflow_state_update(
        current_step="method_validation",
        step_completed_flag=method_validated_flag,
        next_tool=next_tool_name,
        next_step_reason=next_reason,
    )
    result.update(workflow_update)  # Add workflow state

    logger.info(f"method_validator_tool finished. Assumptions valid: {assumptions_valid}")
    return result