File size: 7,293 Bytes
a23082c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import logging
from dotenv import load_dotenv

from llama_index.core.agent.workflow import ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.llms.openai import OpenAI

# Module-level logger. Handlers and levels are not configured here; when the
# file is run directly, the __main__ block calls logging.basicConfig().
logger = logging.getLogger(__name__)

# Read variables from a .env file (if one is found) into the process
# environment so the os.getenv() lookups in this module can see them.
load_dotenv()

# Helper function to load prompt from file
def load_prompt_from_file(filename: str, default_prompt: str) -> str:
    """Load a prompt from a text file located next to this module.

    Args:
        filename: Path of the prompt file, joined onto the directory that
            contains this script (an absolute path is used as-is, per
            ``os.path.join`` semantics).
        default_prompt: Text to return when the file is missing or cannot
            be read/decoded.

    Returns:
        The file's contents decoded as UTF-8, or ``default_prompt`` on failure.
    """
    # Resolve the path before the try block so it is always bound when the
    # error handlers below log it.
    script_dir = os.path.dirname(__file__)
    prompt_path = os.path.join(script_dir, filename)
    try:
        # Explicit UTF-8: without it, open() uses the locale's preferred
        # encoding, which varies by platform and can mis-decode prompt text.
        with open(prompt_path, "r", encoding="utf-8") as f:
            prompt = f.read()
    except FileNotFoundError:
        logger.warning("Prompt file %s not found at %s. Using default.", filename, prompt_path)
        return default_prompt
    except Exception as e:
        # Deliberate best-effort: any other read/decode failure falls back to
        # the default prompt instead of aborting agent setup.
        logger.error("Error loading prompt file %s: %s", filename, e, exc_info=True)
        return default_prompt
    else:
        logger.info("Successfully loaded prompt from %s", prompt_path)
        return prompt

# --- Tool Function ---

def reasoning_tool_fn(context: str) -> str:
    """
    Perform chain-of-thought reasoning over the provided context using a dedicated LLM.

    The reasoning model is an OpenAI model named by the ``REASONING_LLM_MODEL``
    environment variable (default ``"gpt-4o-mini"``), authenticated with
    ``ALPAFLOW_OPENAI_API_KEY``. Both are read on every call.

    Args:
        context (str): The conversation/workflow history and current problem statement.
    Returns:
        str: A structured reasoning trace and conclusion, or an error message
        starting with ``"Error"`` if the API key is missing or the LLM call
        fails. Those failures are returned as text rather than raised, so the
        calling agent always receives a string observation.
    """
    logger.info("Executing reasoning tool with context length: %d", len(context))

    reasoning_llm_model = os.getenv("REASONING_LLM_MODEL", "gpt-4o-mini")
    openai_api_key = os.getenv("ALPAFLOW_OPENAI_API_KEY")

    if not openai_api_key:
        logger.error("ALPAFLOW_OPENAI_API_KEY not found for reasoning tool LLM.")
        return "Error: ALPAFLOW_OPENAI_API_KEY must be set to use the reasoning tool."

    # Built from unindented pieces so the text sent to the model carries no
    # source-code indentation. A triple-quoted literal nested in this function
    # would prefix every template line with four spaces, but not the second
    # and later lines of a multi-line `context`, leaving the prompt ragged.
    reasoning_prompt = (
        "You are an expert reasoning engine. Analyze the following workflow context and problem statement:\n"
        "\n"
        "--- CONTEXT START ---\n"
        f"{context}\n"
        "--- CONTEXT END ---\n"
        "\n"
        "Perform the following steps:\n"
        "1. **Comprehension**: Identify the core question/problem and key constraints from the context.\n"
        "2. **Decomposition**: Break the problem into logical sub-steps.\n"
        "3. **Chain-of-Thought**: Reason through each sub-step, stating assumptions and deriving implications.\n"
        "4. **Verification**: Check conclusions against constraints.\n"
        "5. **Synthesis**: Integrate results into a cohesive answer/recommendation.\n"
        "6. **Clarity**: Use precise language.\n"
        "\n"
        "Respond with your numbered reasoning steps followed by a concise final conclusion or recommendation.\n"
    )

    try:
        # NOTE: an OpenAI `reasoning_effort` option could be passed here if the
        # installed llama_index OpenAI integration and chosen model support it.
        llm = OpenAI(model=reasoning_llm_model, api_key=openai_api_key)
        logger.info("Using reasoning LLM: %s", reasoning_llm_model)
        response = llm.complete(reasoning_prompt)
        logger.info("Reasoning tool execution successful.")
        return response.text
    except Exception as e:
        # Deliberate best-effort: surface the failure to the agent as text
        # instead of raising out of the tool call.
        logger.error("Error during reasoning tool LLM call: %s", e, exc_info=True)
        return f"Error during reasoning: {e}"

# --- Tool Definition ---
# Wrap reasoning_tool_fn as a tool named "reasoning_tool". The description is
# presumably surfaced to the agent's LLM as tool metadata when it chooses tools.
reasoning_tool = FunctionTool.from_defaults(
    name="reasoning_tool",
    description=(
        "Applies detailed chain-of-thought reasoning to the provided workflow context using a dedicated LLM. "
        "Input: context (str). Output: Reasoning steps and conclusion (str) or error message."
    ),
    fn=reasoning_tool_fn,
)

# --- Agent Initialization ---
def initialize_reasoning_agent() -> ReActAgent:
    """Initialize the Reasoning Agent.

    The agent runs on a Google GenAI model named by ``REASONING_AGENT_LLM_MODEL``
    (default ``"models/gemini-1.5-pro"``), has only ``reasoning_tool`` available,
    and may hand off to ``planner_agent``. Its system prompt is read from
    ``../prompts/reasoning_agent_prompt.txt`` relative to this module, with a
    built-in fallback if that file cannot be loaded.

    Returns:
        ReActAgent: The configured agent.

    Raises:
        ValueError: If ``GEMINI_API_KEY`` is not set.
        Exception: Any error raised while building the LLM or the agent is
            logged and re-raised unchanged.
    """
    logger.info("Initializing ReasoningAgent...")

    agent_llm_model = os.getenv("REASONING_AGENT_LLM_MODEL", "models/gemini-1.5-pro")
    gemini_api_key = os.getenv("GEMINI_API_KEY")

    if not gemini_api_key:
        logger.error("GEMINI_API_KEY not found for ReasoningAgent.")
        raise ValueError("GEMINI_API_KEY must be set for ReasoningAgent")

    try:
        llm = GoogleGenAI(api_key=gemini_api_key, model=agent_llm_model)
        logger.info("Using agent LLM: %s", agent_llm_model)

        # Fallback used only when the prompt file is missing or unreadable. It
        # describes the role configured below (single tool, hand-off target)
        # so the agent remains usable without the prompt file.
        default_system_prompt = (
            "You are ReasoningAgent, a pure reasoning agent. For each request, call "
            "`reasoning_tool` with the relevant context (the conversation/workflow "
            "history and the current problem statement) to obtain a detailed "
            "chain-of-thought analysis and conclusion. Then hand off the result to "
            "`planner_agent`."
        )
        system_prompt = load_prompt_from_file("../prompts/reasoning_agent_prompt.txt", default_system_prompt)
        if system_prompt == default_system_prompt:
            logger.warning("Using default/fallback system prompt for ReasoningAgent.")

        agent = ReActAgent(
            name="reasoning_agent",
            description=(
                "A pure reasoning agent that uses the `reasoning_tool` for detailed chain-of-thought analysis "
                "on the provided context, then hands off the result to the `planner_agent`."
            ),
            tools=[reasoning_tool],  # Only has access to the reasoning tool
            llm=llm,
            system_prompt=system_prompt,
            can_handoff_to=["planner_agent"],
        )
        logger.info("ReasoningAgent initialized successfully.")
        return agent

    except Exception as e:
        logger.error("Error during ReasoningAgent initialization: %s", e, exc_info=True)
        raise

# Example usage (for testing if run directly)
if __name__ == "__main__":
    # Smoke test when run directly: exits non-zero on any failure so callers
    # (shell scripts, CI) can detect it.
    import sys

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger.info("Running reasoning_agent.py directly for testing...")

    # GEMINI_API_KEY is needed by initialize_reasoning_agent(); the OpenAI key
    # is needed by reasoning_tool_fn(). Both are checked up front.
    required_keys = ["GEMINI_API_KEY", "ALPAFLOW_OPENAI_API_KEY"]
    missing_keys = [key for key in required_keys if not os.getenv(key)]
    if missing_keys:
        print(f"Error: Required environment variable(s) not set: {', '.join(missing_keys)}. Cannot run test.", file=sys.stderr)
        sys.exit(1)

    try:
        # Test the reasoning tool directly
        print("\nTesting reasoning_tool_fn...")
        test_context = "User asked: What is the capital of France? ResearchAgent found: Paris. VerifierAgent confirmed: High confidence."
        reasoning_output = reasoning_tool_fn(test_context)
        print(f"Reasoning Tool Output:\n{reasoning_output}")
    except Exception as e:
        print(f"Error during testing: {e}", file=sys.stderr)
        sys.exit(1)

    # reasoning_tool_fn reports failures as "Error..." strings instead of
    # raising, so inspect the text to give the smoke test a real exit status.
    if reasoning_output.startswith("Error"):
        sys.exit(1)

    # NOTE: initialize_reasoning_agent() is not exercised here; a meaningful
    # agent chat would need the surrounding workflow to supply context.