import re
from typing import Any, Literal

from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.messages import AnyMessage, SystemMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_ollama import ChatOllama
from langgraph.constants import START, END
from langgraph.graph import MessagesState, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel

from tools import (
    execute_python_script,
    get_excel_table_content,
    get_youtube_video_transcript,
    reverse_string,
    sum_list,
    transcribe_audio_file,
    web_page_info_retriever,
    youtube_video_to_frame_captions,
)


class AgentFactory:
    """
    A factory for the agent. It is assumed that an Ollama server is running
    on the machine where the factory is used.
    """

    __system_prompt: str = (
        "You have to answer some test questions.\n"
        "Sometimes auxiliary files may be attached to the question.\n"
        "Each question is a JSON string with the following fields:\n"
        "1. task_id: unique hash identifier of the question.\n"
        "2. question: the text of the question.\n"
        "3. Level: ignore this field.\n"
        "4. file_name: the name of the file needed to answer the question. "
        "This is empty if the question does not refer to any file. "
        "IMPORTANT: The text of the question may mention a file name that "
        "is different from the one reported in the \"file_name\" JSON "
        "field. YOU HAVE TO IGNORE THE FILE NAME MENTIONED IN \"question\" "
        "AND YOU MUST USE THE FILE NAME PROVIDED IN THE \"file_name\" "
        "FIELD.\n"
        "\n"
        "Achieve the solution by dividing your reasoning into steps, and\n"
        "provide an explanation for each step.\n"
        "\n"
        "The format of your final answer must be\n"
        "\n"
        "<ANSWER>your_final_answer</ANSWER>, where your_final_answer is a\n"
        "number OR as few words as possible OR a comma separated list of\n"
        "numbers and/or strings. If you are asked for a number, don't\n"
        "write it with commas and don't use units such as $ or percent\n"
        "signs unless specified otherwise. If you are asked for a string,\n"
        "don't use articles or abbreviations (e.g. for cities), and write\n"
        "digits in plain text unless specified otherwise. If you are asked\n"
        "for a comma separated list, apply the above rules depending on\n"
        "whether each element of the list is a number or a string.\n"
        "ALWAYS PRESENT THE FINAL ANSWER BETWEEN THE <ANSWER> AND </ANSWER>\n"
        "TAGS.\n"
        "\n"
        "When the solution requires you to perform a sum, DON'T compute it\n"
        "yourself. Use the tool that sums a list of numbers. If you have\n"
        "to sum the results of previous sums, call the same tool again.\n"
        "You may cycle between reasoning and tool calling multiple times.\n"
        "Provide an answer only when you are sure you don't have to call\n"
        "any tool again."
    )
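    # For example, a well-formed reply ends with "<ANSWER>42</ANSWER>".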

    __llm: Runnable
    __tools: list[BaseTool]

    def __init__(
        self,
        model: str = "qwen2.5-coder:32b",
        temperature: float = 0.0,
        num_ctx: int = 8192
    ) -> None:
        """
        Constructor.

        Args:
            model: The name of the Ollama model to use.
            temperature: Temperature parameter.
            num_ctx: Size of the context window used to generate the
                next token.
        """
        search_tool = DuckDuckGoSearchResults(
            description=(
                "A wrapper around DuckDuckGo Search. Useful for when you "
                "need to answer questions about information you can find "
                "on the web. Input should be a search query. It is "
                "advisable to use this tool to retrieve web page URLs and "
                "use another tool to analyze the pages. If a web source is "
                "suggested by the user query, prefer retrieving "
                "information from that source. For example, the query may "
                "suggest searching on Wikipedia or Medium. In those cases, "
                "prepend the query with 'site: <name of the source>'. For "
                "example: 'site: wikipedia.org'"
            ),
            output_format="list"
        )
        self.__tools = [
            execute_python_script,
            get_excel_table_content,
            get_youtube_video_transcript,
            reverse_string,
            search_tool,
            sum_list,
            transcribe_audio_file,
            web_page_info_retriever,
            youtube_video_to_frame_captions
        ]
        self.__llm = ChatOllama(
            model=model,
            temperature=temperature,
            num_ctx=num_ctx
        ).bind_tools(tools=self.__tools)

    def __run_llm(self, state: MessagesState) -> dict[str, Any]:
        answer = self.__llm.invoke(state["messages"])

        # Strip the <think>...</think> reasoning block (and surrounding
        # newlines) that reasoning models may emit before the answer.
        pattern = r'\n*<think>.*?</think>\n*'
        answer.content = re.sub(
            pattern, "", answer.content, flags=re.DOTALL
        )
        return {"messages": [answer]}

    @staticmethod
    def __extract_last_message(
        state: list[AnyMessage] | dict[str, Any] | BaseModel,
        messages_key: str
    ) -> AnyMessage:
        if isinstance(state, list):
            last_message = state[-1]
        elif (
            isinstance(state, dict)
            and (messages := state.get(messages_key, []))
        ):
            last_message = messages[-1]
        elif messages := getattr(state, messages_key, []):
            last_message = messages[-1]
        else:
            raise ValueError(
                f"No messages found in input state to tool_edge: {state}"
            )
        return last_message

    def __route_from_llm(
        self,
        state: list[AnyMessage] | dict[str, Any] | BaseModel,
        messages_key: str = "messages",
    ) -> Literal["tools", "extract_final_answer"]:
        # Route to the tool node when the last AI message requests tool
        # calls; otherwise move on to final answer extraction.
        ai_message = self.__extract_last_message(state, messages_key)
        if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
            return "tools"
        return "extract_final_answer"

    @staticmethod
    def __extract_final_answer(state: MessagesState) -> dict[str, Any]:
        last_message = state["messages"][-1].content
        pattern = r"<ANSWER>(?P<answer>.*?)</ANSWER>"
        m = re.search(pattern, last_message, flags=re.DOTALL)
        answer = m.group("answer").strip() if m else ""
        return {"messages": [answer]}

    @property
    def system_prompt(self) -> SystemMessage:
        """
        Returns:
            The system prompt to use with the agent.
        """
        return SystemMessage(content=self.__system_prompt)

    def get(self) -> CompiledGraph:
        """
        Factory method.

        Returns:
            The compiled graph implementing the agent.
        """
        graph_builder = StateGraph(MessagesState)

        graph_builder.add_node("LLM", self.__run_llm)
        graph_builder.add_node("tools", ToolNode(tools=self.__tools))
        graph_builder.add_node(
            "extract_final_answer",
            self.__extract_final_answer
        )

        graph_builder.add_edge(start_key=START, end_key="LLM")
        graph_builder.add_conditional_edges(
            source="LLM",
            path=self.__route_from_llm,
            path_map={
                "tools": "tools",
                "extract_final_answer": "extract_final_answer"
            }
        )
        graph_builder.add_edge(start_key="tools", end_key="LLM")
        graph_builder.add_edge(start_key="extract_final_answer", end_key=END)

        return graph_builder.compile()
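

if __name__ == "__main__":
    # Minimal usage sketch, assuming an Ollama server is running locally
    # and the default model has been pulled. The question payload is a
    # hypothetical example that mirrors the JSON shape described in the
    # system prompt; the last message of the result carries the extracted
    # final answer.
    factory = AgentFactory()
    agent = factory.get()
    result = agent.invoke(
        {
            "messages": [
                factory.system_prompt,
                (
                    "human",
                    '{"task_id": "demo", '
                    '"question": "What is 2 + 3 + 4?", '
                    '"Level": "1", "file_name": ""}'
                ),
            ]
        }
    )
    print(result["messages"][-1].content)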
|
|