File size: 8,250 Bytes
1721aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Method Selector Tool for selecting causal inference methods.

This module provides a LangChain tool for selecting appropriate
causal inference methods based on dataset characteristics and query details.
"""

import logging # Add logging
from typing import Dict, List, Any, Optional, Union
from langchain_core.tools import tool # Use langchain_core

# Import component function and central LLM factory
from auto_causal.components.decision_tree import rule_based_select_method # Rule-based
from auto_causal.components.decision_tree_llm import DecisionTreeLLMEngine # LLM-based
from auto_causal.config import get_llm_client # Updated import path
from auto_causal.components.state_manager import create_workflow_state_update

# Import shared models from central location
from auto_causal.models import (
    Variables, 
    DatasetAnalysis, 
    MethodSelectorInput # Still needed for args_schema
)

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

def _method_selection_failure_output(
    error_message: str,
    workflow_update: Dict[str, Any],
    variables_dict: Dict[str, Any],
    dataset_analysis_dict: Dict[str, Any],
    dataset_description: Optional[str],
    original_query: Optional[str],
) -> Dict[str, Any]:
    """Build the error payload returned when method selection cannot complete.

    The payload echoes the context that was passed in (variables, dataset
    analysis, description, original query) and then merges in the keys of
    ``workflow_update['workflow_state']`` (if present) at the top level.
    """
    return {
        "error": error_message,
        "variables": variables_dict,
        "dataset_analysis": dataset_analysis_dict,
        "dataset_description": dataset_description,
        # Mirrors the success payload, which also carries the original query.
        "original_query": original_query,
        **workflow_update.get('workflow_state', {}),
    }


# NOTE(review): per the LangChain `tool` decorator documentation, the function
# docstring is used as the tool description exposed to the agent, so the
# docstring below is kept unchanged; implementation notes live in comments.
# The signature mirrors the fields of the args_schema (MethodSelectorInput).
@tool(args_schema=MethodSelectorInput)
def method_selector_tool(
    variables: Variables, 
    dataset_analysis: DatasetAnalysis, 
    dataset_description: Optional[str] = None, 
    original_query: Optional[str] = None,
    excluded_methods: Optional[List[str]] = None
) -> Dict[str, Any]:
    """
    Select the most appropriate causal inference method based on structured input.
    
    Applies decision logic based on dataset analysis and identified variables (including is_rct).
    
    Args:
        variables: Pydantic model containing identified variables (T, O, C, IV, RDD, is_rct, etc.).
        dataset_analysis: Pydantic model containing results of dataset analysis.
        dataset_description: Optional textual description of the dataset.
        original_query: Optional original user query string.
        excluded_methods: Optional list of method names to exclude from selection.
        
    Returns:
        Dictionary with method selection details, context for next step, and workflow state.
    """
    # Local import keeps this block self-contained (no file-level import change).
    from collections.abc import Mapping

    logger.info("Running method_selector_tool with individual args...")

    # is_rct may be None (or otherwise non-bool); the components receive a
    # strict bool, defaulting to False.
    raw_is_rct = variables.is_rct
    is_rct = raw_is_rct if isinstance(raw_is_rct, bool) else False

    # The selection components are called with plain dicts, not Pydantic models.
    variables_dict = variables.model_dump()
    dataset_analysis_dict = dataset_analysis.model_dump()

    # --- Basic validation: both treatment and outcome must be present ---
    treatment = variables_dict.get("treatment_variable")
    outcome = variables_dict.get("outcome_variable")
    if not (treatment and outcome):
        logger.error("Missing treatment or outcome variable in input.")
        workflow_update = create_workflow_state_update(
            current_step="method_selection",
            step_completed_flag=False,
            next_tool="method_selector_tool",
            next_step_reason="Missing treatment/outcome variable in input",
            error="Missing treatment/outcome variable in input",
        )
        return _method_selection_failure_output(
            "Missing treatment/outcome",
            workflow_update,
            variables_dict,
            dataset_analysis_dict,
            dataset_description,
            original_query,
        )

    # --- LLM client is optional: failure to create it is logged, not fatal ---
    try:
        llm_instance = get_llm_client()
    except Exception as e:
        logger.warning(
            "Failed to initialize LLM for method_selector_tool: %s. Proceeding without LLM features.",
            e,
        )
        llm_instance = None

    # --- Configuration for switching ---
    # True  -> DecisionTreeLLMEngine (LLM-based selection)
    # False -> rule_based_select_method (original rule-based tree)
    USE_LLM_DECISION_TREE = False

    # --- Run the selected component ---
    try:
        if USE_LLM_DECISION_TREE:
            logger.info("Using LLM-based Decision Tree Engine for method selection.")
            if not llm_instance:
                # No explicit fallback happens here: the engine is still
                # invoked with llm=None and is left to deal with it.
                logger.warning("LLM instance is required for DecisionTreeLLMEngine but not available. Falling back to rule-based or error.")
            llm_engine = DecisionTreeLLMEngine(verbose=True)
            method_selection_dict = llm_engine.select_method_llm(
                dataset_analysis=dataset_analysis_dict,
                variables=variables_dict,
                is_rct=is_rct,
                llm=llm_instance,
                excluded_methods=excluded_methods,
            )
        else:
            logger.info("Using Rule-based Decision Tree Engine for method selection.")
            method_selection_dict = rule_based_select_method(
                dataset_analysis=dataset_analysis_dict,
                variables=variables_dict,
                is_rct=is_rct,
                llm=llm_instance,
                dataset_description=dataset_description,
                original_query=original_query,
                excluded_methods=excluded_methods,
            )

        # Guard the result here so a None / non-mapping return takes the same
        # failure path as any other component error, instead of raising an
        # uncaught AttributeError further down.
        if not isinstance(method_selection_dict, Mapping):
            raise TypeError(
                "method selection returned "
                f"{type(method_selection_dict).__name__}, expected a mapping"
            )
    except Exception as e:
        logger.error("Error during method selection execution: %s", e, exc_info=True)
        workflow_update = create_workflow_state_update(
            current_step="method_selection",
            step_completed_flag=False,
            next_tool="error_handler_tool",
            next_step_reason=f"Component failed: {e}",
            error=f"Component failed: {e}",
        )
        return _method_selection_failure_output(
            f"Method selection logic failed: {e}",
            workflow_update,
            variables_dict,
            dataset_analysis_dict,
            dataset_description,
            original_query,
        )

    # --- Prepare output dictionary ---
    selected_method = method_selection_dict.get("selected_method")
    # A missing/empty method or the literal "Error" counts as a failed selection.
    method_selected_flag = bool(selected_method and selected_method != "Error")

    # 'method_info' sub-dictionary for the next step (the method validator).
    method_info = {
        "selected_method": selected_method,
        # Human-readable name, e.g. "some_method" -> "Some Method".
        "method_name": (
            selected_method.replace("_", " ").title() if method_selected_flag else None
        ),
        "method_justification": method_selection_dict.get("method_justification"),
        "method_assumptions": method_selection_dict.get("method_assumptions", []),
        # The component reports alternatives under the key "alternatives".
        "alternative_methods": method_selection_dict.get("alternatives", []),
    }

    result = {
        "method_info": method_info,
        "variables": variables_dict,
        "dataset_analysis": dataset_analysis_dict,
        "dataset_description": dataset_description,
        "original_query": original_query,
    }

    # --- Workflow state for the next step ---
    if method_selected_flag:
        next_tool_name = "method_validator_tool"
        next_reason = "Now we need to validate the assumptions of the selected method"
    else:
        next_tool_name = "error_handler_tool"
        next_reason = "Method selection failed or returned an error."

    workflow_update = create_workflow_state_update(
        current_step="method_selection",
        step_completed_flag=method_selected_flag,
        next_tool=next_tool_name,
        next_step_reason=next_reason,
    )
    # Workflow-state keys are merged into the top level of the result.
    result.update(workflow_update.get('workflow_state', {}))

    logger.info("method_selector_tool finished. Selected: %s", method_info.get('selected_method'))
    return result