""" Tool for analyzing datasets for causal inference. This module provides a LangChain tool for analyzing datasets to detect characteristics relevant for causal inference, such as temporal structure, potential instrumental variables, and variable relationships. """ from typing import Dict, Any, Optional from langchain.tools import tool import logging from auto_causal.components.dataset_analyzer import analyze_dataset from auto_causal.components.state_manager import create_workflow_state_update from langchain_core.language_models import BaseChatModel from auto_causal.config import get_llm_client # Import the required Pydantic models from auto_causal.models import DatasetAnalysis, DatasetAnalyzerOutput from auto_causal import models logger = logging.getLogger(__name__) @tool def dataset_analyzer_tool(dataset_path: str, dataset_description: Optional[str] = None, original_query: Optional[str] = None) -> DatasetAnalyzerOutput: """ Analyze dataset to identify important characteristics for causal inference. This tool loads the dataset, calculates summary statistics, checks for temporal structure, identifies potential treatments/outcomes/instruments, and assesses variable relationships relevant for selecting a causal method. Args: dataset_path: Path to the dataset file. dataset_description: Optional description string from input. llm: Optional LLM client for enhanced analysis. Returns: A Pydantic model containing the structured dataset analysis results and workflow state. """ logger.info(f"Running dataset_analyzer_tool on path: {dataset_path}") # Call the component function with the LLM if available llm = get_llm_client() try: # Call the component function analysis_dict = analyze_dataset(dataset_path, llm_client=llm, dataset_description=dataset_description, original_query=original_query) # Check for errors returned explicitly by the component if isinstance(analysis_dict, dict) and "error" in analysis_dict: logger.error(f"Dataset analysis component failed: {analysis_dict['error']}") raise ValueError(analysis_dict['error']) # Validate and structure the analysis using Pydantic # This assumes analyze_dataset returns a dict compatible with DatasetAnalysis # Handle potential missing keys or type mismatches gracefully analysis_results_model = DatasetAnalysis(**analysis_dict) except Exception as e: logger.error(f"Error during dataset analysis or Pydantic model creation: {e}", exc_info=True) error_state = create_workflow_state_update( current_step="data_analysis", step_completed_flag=False, next_tool="dataset_analyzer_tool", # Retry or error handler? next_step_reason=f"Dataset analysis failed: {e}" ) minimal_info = models.DatasetInfo(num_rows=0, num_columns=0, file_path=dataset_path, file_name="unknown") empty_temporal = models.TemporalStructure(has_temporal_structure=False, temporal_columns=[], is_panel_data=False) error_analysis = models.DatasetAnalysis( dataset_info=minimal_info, columns=[], potential_treatments=[], potential_outcomes=[], temporal_structure_detected=False, panel_data_detected=False, potential_instruments_detected=False, discontinuities_detected=False, temporal_structure=empty_temporal, sample_size=0, num_covariates_estimate=0 ) return DatasetAnalyzerOutput( analysis_results=error_analysis, dataset_description=dataset_description, workflow_state=error_state.get('workflow_state', {}) ) # Create workflow state update for success workflow_update = create_workflow_state_update( current_step="data_analysis", step_completed_flag="dataset_analyzed", next_tool="query_interpreter_tool", next_step_reason="Now we need to map query concepts to actual dataset variables" ) # Construct the final Pydantic output object output = DatasetAnalyzerOutput( analysis_results=analysis_results_model, dataset_description=dataset_description, dataset_path=dataset_path, workflow_state=workflow_update.get('workflow_state', {}) ) # print(output) logger.info("dataset_analyzer_tool finished successfully.") return output