Spaces:
Running
Running
""" | |
Tool for interpreting causal queries in the context of a dataset. | |
This module provides a LangChain tool for matching query concepts to actual | |
dataset variables, identifying treatment, outcome, and covariate variables. | |
""" | |
# Removed Pydantic import, will be imported via models | |
# from pydantic import BaseModel, Field | |
from typing import Dict, List, Any, Optional, Union # Keep Any, Dict for workflow_state | |
import logging | |
# Import shared Pydantic models from the central location | |
from auto_causal.models import ( | |
TemporalStructure, | |
DatasetInfo, | |
DatasetAnalysis, | |
QueryInfo, | |
QueryInterpreterInput, | |
Variables, | |
QueryInterpreterOutput | |
) | |
# --- Removed local Pydantic definitions --- | |
# class TemporalStructure(BaseModel): ... | |
# class DatasetInfo(BaseModel): ... | |
# class DatasetAnalysis(BaseModel): ... | |
# class QueryInfo(BaseModel): ... | |
# class QueryInterpreterInput(BaseModel): ... | |
# class Variables(BaseModel): ... | |
# class QueryInterpreterOutput(BaseModel): ... | |
logger = logging.getLogger(__name__) | |
from langchain.tools import tool | |
from auto_causal.components.query_interpreter import interpret_query | |
from auto_causal.components.state_manager import create_workflow_state_update | |
# Modify signature to accept individual Pydantic models/types as arguments | |
def _build_interpreter_output(
    variables: Variables,
    dataset_analysis: DatasetAnalysis,
    dataset_description: str,
    original_query: Optional[str],
    workflow_update: Dict[str, Any],
) -> QueryInterpreterOutput:
    """Assemble a QueryInterpreterOutput from its parts.

    ``workflow_update`` is the dict returned by ``create_workflow_state_update``;
    only its nested ``"workflow_state"`` entry (``{}`` when absent) is stored.
    """
    return QueryInterpreterOutput(
        variables=variables,
        dataset_analysis=dataset_analysis,
        dataset_description=dataset_description,
        original_query=original_query,
        workflow_state=workflow_update.get('workflow_state', {}),
    )


def query_interpreter_tool(
    query_info: QueryInfo,
    dataset_analysis: DatasetAnalysis,
    dataset_description: str,
    original_query: Optional[str] = None,
) -> QueryInterpreterOutput:
    """Interpret a causal query in the context of a specific dataset.

    Dumps ``query_info`` and ``dataset_analysis`` to plain dicts and passes them,
    together with ``dataset_description``, to ``interpret_query``. The returned
    dict is then validated into a ``Variables`` model (expected to include
    ``is_rct``). ``dataset_analysis``, ``dataset_description`` and
    ``original_query`` are passed through unchanged to the output.

    Failures inside the interpretation step do not propagate. If
    ``interpret_query`` raises, returns a non-dict, or its result fails
    ``Variables`` validation, the error is logged and an output is returned
    with an empty ``Variables()`` and a workflow state whose ``next_tool`` is
    ``"query_interpreter_tool"``. On success ``next_tool`` is
    ``"method_selector_tool"``.

    Args:
        query_info: Pydantic model with parsed query information.
        dataset_analysis: Pydantic model with dataset analysis results.
        dataset_description: String description of the dataset.
        original_query: The original user query string (optional).

    Returns:
        A ``QueryInterpreterOutput`` with the identified variables, the
        passed-through dataset analysis, description and original query, and
        the ``workflow_state`` dict produced by ``create_workflow_state_update``.
    """
    logger.info("Running query_interpreter_tool with direct arguments...")

    # The component operates on plain dicts rather than Pydantic models.
    # These dumps sit outside the try, so an argument without model_dump()
    # raises to the caller instead of producing the error output below.
    query_info_dict = query_info.model_dump()
    dataset_analysis_dict = dataset_analysis.model_dump()

    try:
        interpretation_dict = interpret_query(
            query_info_dict, dataset_analysis_dict, dataset_description
        )
        if not isinstance(interpretation_dict, dict):
            raise TypeError(
                "interpret_query component did not return a dictionary. "
                f"Got: {type(interpretation_dict)}"
            )
        # Raises a validation error if the component's dict lacks/mistypes fields.
        variables_output = Variables(**interpretation_dict)
    except Exception as e:
        # Tool boundary: surface the failure through the workflow state instead
        # of raising, so the caller always receives a QueryInterpreterOutput.
        logger.exception("Error during query interpretation component call: %s", e)
        workflow_update = create_workflow_state_update(
            current_step="variable_identification",
            # NOTE(review): the success path passes a flag *name* string
            # ("variables_identified") while this path passes False; confirm
            # which type create_workflow_state_update expects here.
            step_completed_flag=False,
            next_tool="query_interpreter_tool",  # Retry this step (or route to an error handler).
            next_step_reason=f"Component execution failed: {e}",
        )
        return _build_interpreter_output(
            Variables(),  # Empty variables on failure.
            dataset_analysis,
            dataset_description,
            original_query,
            workflow_update,
        )

    workflow_update = create_workflow_state_update(
        current_step="variable_identification",
        step_completed_flag="variables_identified",
        next_tool="method_selector_tool",
        next_step_reason="Now that we have identified the variables, we can select an appropriate causal inference method",
    )
    output = _build_interpreter_output(
        variables_output,
        dataset_analysis,
        dataset_description,
        original_query,
        workflow_update,
    )
    logger.info("query_interpreter_tool finished successfully.")
    return output