File size: 4,372 Bytes
1721aea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Output formatter tool for causal inference results.

This tool provides the LangChain interface for the output formatter component.
"""

# NOTE: The Pydantic input-schema approach is temporarily reverted for this tool;
# the tool takes plain individual arguments instead.

from typing import Dict, Any, Optional#, List, Union # Keep only needed
# from pydantic import BaseModel, Field # REVERT
import logging
import json # Ensure json is imported

# Add import for @tool decorator
from langchain.tools import tool

from auto_causal.components import output_formatter
# Import the Pydantic model returned by the component
from auto_causal.models import FormattedOutput

# --- REVERT: Remove Pydantic Model Definitions --- 
# class Variables(BaseModel): 
# ... (Remove all re-defined models)
# class OutputFormatterInput(BaseModel):
# ... (Remove definition)

# --- Tool Definition --- 
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

@tool
# Plain individual arguments are used here rather than a Pydantic input schema.
# NOTE(review): LangChain's @tool normally exposes the docstring as the tool
# description shown to the agent, so the docstring text is kept unchanged.
def output_formatter_tool(
    query: str,
    method: str,
    results: Dict[str, Any],  # Output from method_executor_tool
    explanation: Dict[str, Any],  # Output from explainer_tool
    dataset_analysis: Optional[Dict[str, Any]] = None,
    dataset_description: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Formats the final explanation and results using the output_formatter component,
    packages it into a dictionary, adds workflow state, and a JSON representation.

    Args:
        query: Original user query.
        method: The method used (string name).
        results: Numerical results dict from method_executor_tool.
        explanation: Structured explanation dict from explainer_tool.
        dataset_analysis: Optional results from dataset_analyzer_tool.
        dataset_description: Optional initial description string.
        
    Returns:
        Dict containing the formatted output fields, workflow state, and a JSON string.
    """
    logger.info("Running output_formatter_tool...")

    try:
        # The component returns a FormattedOutput Pydantic model.
        formatted_output_model: FormattedOutput = output_formatter.format_output(
            query=query,
            method=method,
            results=results,
            explanation=explanation,
            # An empty/falsy analysis is normalised to None for the component.
            dataset_analysis=dataset_analysis or None,
            dataset_description=dataset_description,
        )

        # Convert the model to a plain dict. Detect the Pydantic v2 API
        # explicitly instead of catching AttributeError, so that an
        # AttributeError raised *inside* serialization is not mistaken for
        # "method missing" and silently retried through the v1 path.
        model_dump = getattr(formatted_output_model, "model_dump", None)
        if callable(model_dump):
            # mode='json' asks Pydantic v2 for JSON-compatible values.
            formatted_output_dict = model_dump(mode="json")
        else:
            # Pydantic v1 fallback.
            formatted_output_dict = formatted_output_model.dict()

        # Attach a JSON rendering of the payload under "json_output".
        try:
            # Leave out workflow_state in case the model dump contained one;
            # the authoritative value is set further below.
            dict_for_json = {
                key: value
                for key, value in formatted_output_dict.items()
                if key != "workflow_state"
            }
            formatted_output_dict["json_output"] = json.dumps(dict_for_json, indent=4)
        except (TypeError, ValueError) as json_err:
            # TypeError: a value is not JSON serializable.
            # ValueError: e.g. a circular reference in the payload.
            logger.error("Failed to serialize output to JSON: %s", json_err)
            # Build the error payload with json.dumps so that quotes,
            # backslashes or newlines in the exception text cannot produce
            # an invalid JSON string.
            formatted_output_dict["json_output"] = json.dumps(
                {"error": f"Failed to serialize output to JSON: {json_err}"}
            )

        # Mark the workflow as finished: formatting is the final step.
        formatted_output_dict["workflow_state"] = {
            "current_step": "output_formatting",
            "analysis_complete": True,
        }

        logger.info("Output formatting successful.")
        return formatted_output_dict

    except Exception as e:
        # Tool boundary: report the failure in the returned structure rather
        # than raising, and log the traceback.
        logger.error("Error during output formatting: %s", e, exc_info=True)
        return {
            "error": f"Failed to format output: {e}",
            "workflow_state": {
                "current_step": "output_formatting",
                "analysis_complete": False,  # Indicate failure
                "error": f"Formatting component failed: {e}",
            },
        }