File size: 4,893 Bytes
1721aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Tool for interpreting causal queries in the context of a dataset.

This module provides a LangChain tool for matching query concepts to actual
dataset variables, identifying treatment, outcome, and covariate variables.
"""

# Pydantic is no longer imported directly here; the models are imported from auto_causal.models below
# from pydantic import BaseModel, Field
from typing import Dict, List, Any, Optional, Union # Keep Any, Dict for workflow_state
import logging

# Import shared Pydantic models from the central location
from auto_causal.models import (
    TemporalStructure,
    DatasetInfo,
    DatasetAnalysis,
    QueryInfo,
    QueryInterpreterInput,
    Variables,
    QueryInterpreterOutput
)

# --- Local Pydantic definitions removed; these models are now imported from auto_causal.models ---
# class TemporalStructure(BaseModel): ...
# class DatasetInfo(BaseModel): ...
# class DatasetAnalysis(BaseModel): ...
# class QueryInfo(BaseModel): ...
# class QueryInterpreterInput(BaseModel): ...
# class Variables(BaseModel): ...
# class QueryInterpreterOutput(BaseModel): ...

# Module-level logger named after this module; query_interpreter_tool uses it
# for its start/finish messages and for reporting component failures.
logger = logging.getLogger(__name__)

from langchain.tools import tool
from auto_causal.components.query_interpreter import interpret_query
from auto_causal.components.state_manager import create_workflow_state_update


@tool()
# The tool takes the individual Pydantic models / plain values as separate arguments.
def query_interpreter_tool(
    query_info: QueryInfo,
    dataset_analysis: DatasetAnalysis,
    dataset_description: str,
    original_query: Optional[str] = None
) -> QueryInterpreterOutput:
    """
    Interpret a causal query in the context of a specific dataset.

    Args:
        query_info: Pydantic model with parsed query information.
        dataset_analysis: Pydantic model with dataset analysis results.
        dataset_description: String description of the dataset.
        original_query: The original user query string (optional).
        
    Returns:
        A Pydantic model containing identified variables (including is_rct), dataset analysis, description, and workflow state.
    """
    # NOTE: LangChain's @tool uses the docstring above as the tool description,
    # so its wording is effectively runtime text and is kept unchanged.
    logger.info("Running query_interpreter_tool with direct arguments...")

    def _package(variables, state_update):
        # Success and failure results share the same envelope: the caller's
        # dataset_analysis, dataset_description and original_query are echoed
        # back as received; only the variables and the workflow state differ.
        return QueryInterpreterOutput(
            variables=variables,
            dataset_analysis=dataset_analysis,
            dataset_description=dataset_description,
            original_query=original_query,
            workflow_state=state_update.get('workflow_state', {}),
        )

    # interpret_query is handed plain dicts, so both models are dumped up front
    # (outside the try: a failure here is not converted into an error result).
    query_payload = query_info.model_dump()
    analysis_payload = dataset_analysis.model_dump()

    try:
        raw_interpretation = interpret_query(query_payload, analysis_payload, dataset_description)
        if not isinstance(raw_interpretation, dict):
            returned_type = type(raw_interpretation)
            raise TypeError(f"interpret_query component did not return a dictionary. Got: {returned_type}")

        # Building the Variables model validates the component's output; a
        # validation error is handled like any other component failure below.
        identified = Variables(**raw_interpretation)

    except Exception as exc:
        logger.error(f"Error during query interpretation component call: {exc}", exc_info=True)
        failure_state = create_workflow_state_update(
            current_step="variable_identification",
            step_completed_flag=False,
            next_tool="query_interpreter_tool",  # points back at this tool (original note: "or error handler")
            next_step_reason=f"Component execution failed: {exc}"
        )
        # Even on failure the tool returns a QueryInterpreterOutput, carrying a
        # Variables instance constructed with no arguments.
        return _package(Variables(), failure_state)

    # Variables were identified: mark the step done and hand over to method selection.
    success_state = create_workflow_state_update(
        current_step="variable_identification",
        step_completed_flag="variables_identified",
        next_tool="method_selector_tool",
        next_step_reason="Now that we have identified the variables, we can select an appropriate causal inference method"
    )
    result = _package(identified, success_state)

    logger.info("query_interpreter_tool finished successfully.")
    return result