Spaces:
Running
Running
File size: 4,893 Bytes
1721aea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
"""
Tool for interpreting causal queries in the context of a dataset.
This module provides a LangChain tool for matching query concepts to actual
dataset variables, identifying treatment, outcome, and covariate variables.
"""
# Removed Pydantic import, will be imported via models
# from pydantic import BaseModel, Field
from typing import Dict, List, Any, Optional, Union # Keep Any, Dict for workflow_state
import logging
# Import shared Pydantic models from the central location
from auto_causal.models import (
TemporalStructure,
DatasetInfo,
DatasetAnalysis,
QueryInfo,
QueryInterpreterInput,
Variables,
QueryInterpreterOutput
)
# --- Removed local Pydantic definitions ---
# class TemporalStructure(BaseModel): ...
# class DatasetInfo(BaseModel): ...
# class DatasetAnalysis(BaseModel): ...
# class QueryInfo(BaseModel): ...
# class QueryInterpreterInput(BaseModel): ...
# class Variables(BaseModel): ...
# class QueryInterpreterOutput(BaseModel): ...
logger = logging.getLogger(__name__)
from langchain.tools import tool
from auto_causal.components.query_interpreter import interpret_query
from auto_causal.components.state_manager import create_workflow_state_update
@tool()
# Signature takes the individual Pydantic models/values directly (not a single input object)
def query_interpreter_tool(
    query_info: QueryInfo,
    dataset_analysis: DatasetAnalysis,
    dataset_description: str,
    original_query: Optional[str] = None  # raw user query, passed through when available
) -> QueryInterpreterOutput:
    """
    Interpret a causal query in the context of a specific dataset.

    Matches query concepts to dataset variables via the `interpret_query`
    component and wraps the result (or a structured error fallback) in a
    `QueryInterpreterOutput` together with a workflow-state update.

    Args:
        query_info: Pydantic model with parsed query information.
        dataset_analysis: Pydantic model with dataset analysis results.
        dataset_description: String description of the dataset.
        original_query: The original user query string (optional).

    Returns:
        A Pydantic model containing identified variables (including is_rct),
        dataset analysis, description, and workflow state.
    """
    logger.info("Running query_interpreter_tool with direct arguments...")

    # The component consumes plain dicts, so serialise the Pydantic inputs first.
    query_dict = query_info.model_dump()
    analysis_dict = dataset_analysis.model_dump()

    try:
        # NOTE: interpret_query is expected to return a dict compatible with
        # the Variables model (and to attempt determining is_rct).
        raw_result = interpret_query(query_dict, analysis_dict, dataset_description)
        if not isinstance(raw_result, dict):
            raise TypeError(f"interpret_query component did not return a dictionary. Got: {type(raw_result)}")
        # Pydantic validation raises here if expected fields are missing.
        parsed_variables = Variables(**raw_result)
    except Exception as e:
        logger.error(f"Error during query interpretation component call: {e}", exc_info=True)
        failure_state = create_workflow_state_update(
            current_step="variable_identification",
            step_completed_flag=False,
            next_tool="query_interpreter_tool",  # Or error handler
            next_step_reason=f"Component execution failed: {e}"
        )
        # Error fallback: empty Variables plus the caller-supplied analysis,
        # still returned as a well-formed Pydantic output.
        return QueryInterpreterOutput(
            variables=Variables(),
            dataset_analysis=dataset_analysis,
            dataset_description=dataset_description,
            original_query=original_query,
            workflow_state=failure_state.get('workflow_state', {})
        )
    else:
        # Success path: advance the workflow towards method selection.
        success_state = create_workflow_state_update(
            current_step="variable_identification",
            step_completed_flag="variables_identified",
            next_tool="method_selector_tool",
            next_step_reason="Now that we have identified the variables, we can select an appropriate causal inference method"
        )
        output = QueryInterpreterOutput(
            variables=parsed_variables,
            # Pass the original dataset_analysis Pydantic model through unchanged
            dataset_analysis=dataset_analysis,
            dataset_description=dataset_description,
            original_query=original_query,
            workflow_state=success_state.get('workflow_state', {})  # extract just the state dict
        )
        logger.info("query_interpreter_tool finished successfully.")
        return output