""" Tool for interpreting causal queries in the context of a dataset. This module provides a LangChain tool for matching query concepts to actual dataset variables, identifying treatment, outcome, and covariate variables. """ # Removed Pydantic import, will be imported via models # from pydantic import BaseModel, Field from typing import Dict, List, Any, Optional, Union # Keep Any, Dict for workflow_state import logging # Import shared Pydantic models from the central location from auto_causal.models import ( TemporalStructure, DatasetInfo, DatasetAnalysis, QueryInfo, QueryInterpreterInput, Variables, QueryInterpreterOutput ) # --- Removed local Pydantic definitions --- # class TemporalStructure(BaseModel): ... # class DatasetInfo(BaseModel): ... # class DatasetAnalysis(BaseModel): ... # class QueryInfo(BaseModel): ... # class QueryInterpreterInput(BaseModel): ... # class Variables(BaseModel): ... # class QueryInterpreterOutput(BaseModel): ... logger = logging.getLogger(__name__) from langchain.tools import tool from auto_causal.components.query_interpreter import interpret_query from auto_causal.components.state_manager import create_workflow_state_update @tool() # Modify signature to accept individual Pydantic models/types as arguments def query_interpreter_tool( query_info: QueryInfo, dataset_analysis: DatasetAnalysis, dataset_description: str, original_query: Optional[str] = None # Keep optional original_query ) -> QueryInterpreterOutput: """ Interpret a causal query in the context of a specific dataset. Args: query_info: Pydantic model with parsed query information. dataset_analysis: Pydantic model with dataset analysis results. dataset_description: String description of the dataset. original_query: The original user query string (optional). Returns: A Pydantic model containing identified variables (including is_rct), dataset analysis, description, and workflow state. """ logger.info("Running query_interpreter_tool with direct arguments...") # Use arguments directly, dump models to dicts for the component call query_info_dict = query_info.model_dump() dataset_analysis_dict = dataset_analysis.model_dump() # dataset_description is already a string # Call the component function try: # Assume interpret_query returns a dictionary compatible with Variables model # AND that interpret_query now attempts to determine is_rct interpretation_dict = interpret_query(query_info_dict, dataset_analysis_dict, dataset_description) if not isinstance(interpretation_dict, dict): raise TypeError(f"interpret_query component did not return a dictionary. Got: {type(interpretation_dict)}") # Validate and structure the interpretation using Pydantic # This will raise validation error if interpret_query didn't return expected fields variables_output = Variables(**interpretation_dict) except Exception as e: logger.error(f"Error during query interpretation component call: {e}", exc_info=True) workflow_update = create_workflow_state_update( current_step="variable_identification", step_completed_flag=False, next_tool="query_interpreter_tool", # Or error handler next_step_reason=f"Component execution failed: {e}" ) error_vars = Variables() # Use the passed dataset_analysis object directly in case of error error_analysis = dataset_analysis # Return Pydantic output even on error return QueryInterpreterOutput( variables=error_vars, dataset_analysis=error_analysis, dataset_description=dataset_description, original_query=original_query, # Pass original query if available workflow_state=workflow_update.get('workflow_state', {}) ) # Create workflow state update for success workflow_update = create_workflow_state_update( current_step="variable_identification", step_completed_flag="variables_identified", next_tool="method_selector_tool", next_step_reason="Now that we have identified the variables, we can select an appropriate causal inference method" ) # Construct the Pydantic output object output = QueryInterpreterOutput( variables=variables_output, # Pass the original dataset_analysis Pydantic model dataset_analysis=dataset_analysis, dataset_description=dataset_description, original_query=original_query, # Pass along original query workflow_state=workflow_update.get('workflow_state', {}) # Extract state dict ) logger.info("query_interpreter_tool finished successfully.") return output