File size: 8,507 Bytes
1721aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Method validator tool for causal inference methods.

This tool validates the selected causal inference method against
dataset characteristics and available variables.
"""

from typing import Dict, Any, Optional, List, Union
from langchain.tools import tool
import logging

from auto_causal.components.method_validator import validate_method
from auto_causal.components.state_manager import create_workflow_state_update
from auto_causal.components.decision_tree import rule_based_select_method

# Import shared models from central location
from auto_causal.models import (
    Variables, 
    TemporalStructure, # Needed indirectly by DatasetAnalysis
    DatasetInfo,       # Needed indirectly by DatasetAnalysis
    DatasetAnalysis,
    MethodInfo, 
    MethodValidatorInput
)

logger = logging.getLogger(__name__)

def extract_properties_from_inputs(inputs: MethodValidatorInput) -> Dict[str, Any]:
    """
    Flatten a ``MethodValidatorInput`` into the property dict the decision
    tree consumes.

    Pulls fields out of the serialized ``variables`` and ``dataset_analysis``
    models, applying the same defaults the decision tree expects for any
    missing values.
    """
    var_data = inputs.variables.model_dump()
    analysis_data = inputs.dataset_analysis.model_dump()
    temporal = analysis_data.get("temporal_structure", {})

    # Fields copied through verbatim (default None when absent).
    properties: Dict[str, Any] = {
        key: var_data.get(key)
        for key in (
            "treatment_variable",
            "outcome_variable",
            "instrument_variable",
            "time_variable",
            "running_variable",
            "cutoff_value",
        )
    }
    # Fields with explicit fallback defaults; note the score is sourced from
    # the "covariate_overlap_result" key on the variables model.
    properties.update({
        "covariates": var_data.get("covariates", []),
        "treatment_variable_type": var_data.get("treatment_variable_type", "binary"),
        "has_temporal_structure": temporal.get("has_temporal_structure", False),
        "frontdoor_criterion": var_data.get("frontdoor_criterion", False),
        "covariate_overlap_score": var_data.get("covariate_overlap_result", 0),
        "is_rct": var_data.get("is_rct", False),
    })
    return properties

# --- Removed local Pydantic definitions --- 
# class Variables(BaseModel): ...
# class TemporalStructure(BaseModel): ...
# class DatasetInfo(BaseModel): ...
# class DatasetAnalysis(BaseModel): ...
# class MethodInfo(BaseModel): ...
# class MethodValidatorInput(BaseModel): ...

# --- Tool Definition --- 
@tool
def method_validator_tool(inputs: MethodValidatorInput) -> Dict[str, Any]: # Use Pydantic Input
    """
    Validate the assumptions of the selected causal method using structured input.

    Runs the ``validate_method`` component against the selected method. If the
    assumptions fail, attempts to backtrack through the rule-based decision
    tree (with the failed method excluded) to pick an alternative. The
    variables/analysis/description context is always passed through — even on
    error — so the next tool in the workflow has what it needs.

    Args:
        inputs: Pydantic model containing method_info, dataset_analysis,
            variables, and dataset_description.

    Returns:
        Dictionary with validation results, context for next step, and workflow
        state. On component failure, an ``error`` key is returned alongside the
        same context plus a workflow state flagging the failure.
    """
    logger.info("Running method_validator_tool for method: %s",
                inputs.method_info.selected_method)

    # Serialize input models to dicts once; both the component calls and the
    # result payload consume plain dicts.
    method_info_dict = inputs.method_info.model_dump()
    dataset_analysis_dict = inputs.dataset_analysis.model_dump()
    variables_dict = inputs.variables.model_dump()
    dataset_description_str = inputs.dataset_description

    # Call the component function to validate the method
    try:
        validation_results = validate_method(method_info_dict, dataset_analysis_dict, variables_dict)
        if not isinstance(validation_results, dict):
            raise TypeError(f"validate_method component did not return a dict. Got: {type(validation_results)}")

    except Exception as e:
        logger.error("Error during validate_method execution: %s", e, exc_info=True)
        # Construct error output, preserving context so downstream tools can
        # still inspect what was being validated.
        workflow_update = create_workflow_state_update(
            current_step="method_validation", method_validated=False, error=f"Component failed: {e}"
        )
        return {"error": f"Method validation component failed: {e}",
                "variables": variables_dict,
                "dataset_analysis": dataset_analysis_dict,
                "dataset_description": dataset_description_str,
                **workflow_update.get('workflow_state', {})}

    # Determine if assumptions are valid based on component output
    assumptions_valid = validation_results.get("valid", False)
    failed_assumptions = validation_results.get("concerns", [])
    original_method = method_info_dict.get("selected_method")
    recommended_method = validation_results.get("recommended_method", original_method)

    # If validation failed, attempt to backtrack through decision tree
    if not assumptions_valid and failed_assumptions:
        logger.info("Method %s failed validation due to: %s", original_method, failed_assumptions)
        logger.info("Attempting to backtrack and select alternative method...")

        try:
            # Extract properties for decision tree.
            # NOTE(review): this result is never consumed below — confirm
            # whether rule_based_select_method was meant to receive it.
            dataset_props = extract_properties_from_inputs(inputs)

            # Get LLM instance (may be None; the decision tree tolerates that)
            from auto_causal.config import get_llm_client
            try:
                llm_instance = get_llm_client()
            except Exception as e:
                logger.warning("Failed to get LLM instance: %s", e)
                llm_instance = None

            # Re-run decision tree with failed method excluded, reusing the
            # dicts serialized above instead of re-dumping the models.
            excluded_methods = [original_method]
            new_selection = rule_based_select_method(
                dataset_analysis=dataset_analysis_dict,
                variables=variables_dict,
                is_rct=inputs.variables.is_rct or False,
                llm=llm_instance,
                dataset_description=dataset_description_str,
                original_query=inputs.original_query,
                excluded_methods=excluded_methods
            )

            recommended_method = new_selection.get("selected_method", original_method)
            logger.info("Backtracking selected new method: %s", recommended_method)

            # Update validation results to include backtracking info
            validation_results["backtrack_attempted"] = True
            validation_results["backtrack_method"] = recommended_method
            validation_results["excluded_methods"] = excluded_methods

        except Exception as e:
            logger.error("Backtracking failed: %s", e)
            validation_results["backtrack_attempted"] = True
            validation_results["backtrack_error"] = str(e)
            # Keep original recommended method

    # Prepare output dictionary for the next tool (method_executor)
    result = {
        # --- Data for Method Executor ---
        "method": recommended_method,  # Use recommended method going forward
        "variables": variables_dict,  # Pass along all identified variables
        "dataset_path": dataset_analysis_dict.get('dataset_info', {}).get('file_path'),  # Extract path
        "dataset_analysis": dataset_analysis_dict,  # Pass full analysis
        "dataset_description": dataset_description_str,  # Pass description string
        "original_query": inputs.original_query,  # Pass original query

        # --- Validation Results ---
        "validation_info": {
            "original_method": original_method,
            "recommended_method": recommended_method,
            "assumptions_valid": assumptions_valid,
            "failed_assumptions": failed_assumptions,
            "warnings": validation_results.get("warnings", []),
            "suggestions": validation_results.get("suggestions", [])
        }
    }

    # Determine workflow state.
    # NOTE(review): when assumptions fail we route to error_handler_tool even
    # if backtracking successfully selected an alternative method — confirm
    # this is intended rather than proceeding to the executor with the
    # backtracked method.
    method_validated_flag = assumptions_valid
    next_tool_name = "method_executor_tool" if method_validated_flag else "error_handler_tool"
    next_reason = ("Method assumptions checked. Proceeding to execution."
                   if method_validated_flag
                   else "Method assumptions failed validation.")
    workflow_update = create_workflow_state_update(
        current_step="method_validation",
        step_completed_flag=method_validated_flag,
        next_tool=next_tool_name,
        next_step_reason=next_reason
    )
    result.update(workflow_update)  # Add workflow state

    logger.info("method_validator_tool finished. Assumptions valid: %s", assumptions_valid)
    return result