""" | |
Method validator tool for causal inference methods. | |
This tool validates the selected causal inference method against | |
dataset characteristics and available variables. | |
""" | |
from typing import Dict, Any, Optional, List, Union | |
from langchain.tools import tool | |
import logging | |
from auto_causal.components.method_validator import validate_method | |
from auto_causal.components.state_manager import create_workflow_state_update | |
from auto_causal.components.decision_tree import rule_based_select_method | |
# Import shared models from central location | |
from auto_causal.models import ( | |
Variables, | |
TemporalStructure, # Needed indirectly by DatasetAnalysis | |
DatasetInfo, # Needed indirectly by DatasetAnalysis | |
DatasetAnalysis, | |
MethodInfo, | |
MethodValidatorInput | |
) | |
logger = logging.getLogger(__name__) | |
def extract_properties_from_inputs(inputs: MethodValidatorInput) -> Dict[str, Any]:
    """
    Helper function to extract dataset properties from MethodValidatorInput
    for use with the decision tree.
    """
    variables_dict = inputs.variables.model_dump()
    dataset_analysis_dict = inputs.dataset_analysis.model_dump()

    return {
        "treatment_variable": variables_dict.get("treatment_variable"),
        "outcome_variable": variables_dict.get("outcome_variable"),
        "instrument_variable": variables_dict.get("instrument_variable"),
        "covariates": variables_dict.get("covariates", []),
        "time_variable": variables_dict.get("time_variable"),
        "running_variable": variables_dict.get("running_variable"),
        "treatment_variable_type": variables_dict.get("treatment_variable_type", "binary"),
        "has_temporal_structure": dataset_analysis_dict.get("temporal_structure", {}).get("has_temporal_structure", False),
        "frontdoor_criterion": variables_dict.get("frontdoor_criterion", False),
        "cutoff_value": variables_dict.get("cutoff_value"),
        "covariate_overlap_score": variables_dict.get("covariate_overlap_result", 0),
        "is_rct": variables_dict.get("is_rct", False)
    }

# --- Removed local Pydantic definitions ---
# class Variables(BaseModel): ...
# class TemporalStructure(BaseModel): ...
# class DatasetInfo(BaseModel): ...
# class DatasetAnalysis(BaseModel): ...
# class MethodInfo(BaseModel): ...
# class MethodValidatorInput(BaseModel): ...


# --- Tool Definition ---
def method_validator_tool(inputs: MethodValidatorInput) -> Dict[str, Any]:  # Use Pydantic Input
    """
    Validate the assumptions of the selected causal method using structured input.

    Args:
        inputs: Pydantic model containing method_info, dataset_analysis, variables,
            dataset_description, and original_query.

    Returns:
        Dictionary with validation results, context for the next step, and workflow state.
    """
logger.info(f"Running method_validator_tool for method: {inputs.method_info.selected_method}") | |
# Access data from input model (converting to dicts for component) | |
method_info_dict = inputs.method_info.model_dump() | |
dataset_analysis_dict = inputs.dataset_analysis.model_dump() | |
variables_dict = inputs.variables.model_dump() | |
dataset_description_str = inputs.dataset_description | |
# Call the component function to validate the method | |
try: | |
validation_results = validate_method(method_info_dict, dataset_analysis_dict, variables_dict) | |
if not isinstance(validation_results, dict): | |
raise TypeError(f"validate_method component did not return a dict. Got: {type(validation_results)}") | |
except Exception as e: | |
logger.error(f"Error during validate_method execution: {e}", exc_info=True) | |
# Construct error output | |
workflow_update = create_workflow_state_update( | |
current_step="method_validation", method_validated=False, error=f"Component failed: {e}" | |
) | |
# Pass context even on error | |
return {"error": f"Method validation component failed: {e}", | |
"variables": variables_dict, | |
"dataset_analysis": dataset_analysis_dict, | |
"dataset_description": dataset_description_str, | |
**workflow_update.get('workflow_state', {})} | |

    # Determine if assumptions are valid based on component output
    assumptions_valid = validation_results.get("valid", False)
    failed_assumptions = validation_results.get("concerns", [])
    original_method = method_info_dict.get("selected_method")
    recommended_method = validation_results.get("recommended_method", original_method)

    # If validation failed, attempt to backtrack through the decision tree
    if not assumptions_valid and failed_assumptions:
        logger.info(f"Method {original_method} failed validation due to: {failed_assumptions}")
        logger.info("Attempting to backtrack and select alternative method...")
        try:
            # Extract properties for decision tree
            # (note: not currently passed to rule_based_select_method below)
            dataset_props = extract_properties_from_inputs(inputs)

            # Get LLM instance (may be None)
            from auto_causal.config import get_llm_client
            try:
                llm_instance = get_llm_client()
            except Exception as e:
                logger.warning(f"Failed to get LLM instance: {e}")
                llm_instance = None

            # Re-run decision tree with the failed method excluded
            excluded_methods = [original_method]
            new_selection = rule_based_select_method(
                dataset_analysis=inputs.dataset_analysis.model_dump(),
                variables=inputs.variables.model_dump(),
                is_rct=inputs.variables.is_rct or False,
                llm=llm_instance,
                dataset_description=inputs.dataset_description,
                original_query=inputs.original_query,
                excluded_methods=excluded_methods
            )
            recommended_method = new_selection.get("selected_method", original_method)
            logger.info(f"Backtracking selected new method: {recommended_method}")

            # Update validation results to include backtracking info
            validation_results["backtrack_attempted"] = True
            validation_results["backtrack_method"] = recommended_method
            validation_results["excluded_methods"] = excluded_methods
        except Exception as e:
            logger.error(f"Backtracking failed: {e}")
            validation_results["backtrack_attempted"] = True
            validation_results["backtrack_error"] = str(e)
            # Keep original recommended method

    # Prepare output dictionary for the next tool (method_executor)
    result = {
        # --- Data for Method Executor ---
        "method": recommended_method,  # Use recommended method going forward
        "variables": variables_dict,  # Pass along all identified variables
        "dataset_path": dataset_analysis_dict.get('dataset_info', {}).get('file_path'),  # Extract path
        "dataset_analysis": dataset_analysis_dict,  # Pass full analysis
        "dataset_description": dataset_description_str,  # Pass description string
        "original_query": inputs.original_query,  # Pass original query
        # --- Validation Results ---
        "validation_info": {
            "original_method": method_info_dict.get("selected_method"),
            "recommended_method": recommended_method,
            "assumptions_valid": assumptions_valid,
            "failed_assumptions": failed_assumptions,
            "warnings": validation_results.get("warnings", []),
            "suggestions": validation_results.get("suggestions", [])
        }
    }

    # Determine workflow state
    method_validated_flag = assumptions_valid  # Or perhaps always True if validation ran?
    next_tool_name = "method_executor_tool" if method_validated_flag else "error_handler_tool"  # Go to executor even if assumptions failed?
    next_reason = ("Method assumptions checked. Proceeding to execution." if method_validated_flag
                   else "Method assumptions failed validation.")

    workflow_update = create_workflow_state_update(
        current_step="method_validation",
        step_completed_flag=method_validated_flag,
        next_tool=next_tool_name,
        next_step_reason=next_reason
    )
    result.update(workflow_update)  # Add workflow state

    logger.info(f"method_validator_tool finished. Assumptions valid: {assumptions_valid}")
    return result
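

# --- Illustrative usage sketch (not part of the workflow) ---
# A minimal example of invoking the tool directly. The field names below are
# assumptions inferred only from how the Pydantic models are accessed in this
# module; consult auto_causal.models for the authoritative schemas before
# relying on this sketch.
if __name__ == "__main__":
    example_input = MethodValidatorInput(
        method_info=MethodInfo(selected_method="propensity_score_matching"),
        dataset_analysis=DatasetAnalysis(
            dataset_info=DatasetInfo(file_path="data/example.csv"),  # hypothetical path
            temporal_structure=TemporalStructure(has_temporal_structure=False),
        ),
        variables=Variables(
            treatment_variable="treatment",
            outcome_variable="outcome",
            covariates=["age", "income"],
            is_rct=False,
        ),
        dataset_description="Observational study of a job training program.",
        original_query="Did the training program increase earnings?",
    )
    output = method_validator_tool(example_input)
    print(output.get("method"), output.get("validation_info"))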