shrutikaP8497 commited on
Commit
0375d91
·
verified ·
1 Parent(s): abecf76

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +74 -32
agent.py CHANGED
@@ -1,36 +1,78 @@
1
- from tools import get_tools
2
  from retriever import retrieve_context
3
- from config import LLM_MODEL
4
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
5
-
6
- class Agent:
7
- def __init__(self):
8
- self.model = AutoModelForCausalLM.from_pretrained(
9
- LLM_MODEL,
10
- device_map="auto",
11
- trust_remote_code=True
12
- )
13
- self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
14
- self.generator = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)
15
- self.tools = get_tools()
16
-
17
- def generate_answer(self, question: str, context: str = "") -> str:
18
- prompt = f"""
19
- You are an expert AI agent answering academic and logical questions concisely.
20
- Use the context below to help answer the user's question.
21
-
22
- Context:
23
- {context}
24
-
25
- Question:
26
- {question}
 
 
 
 
 
 
 
27
 
28
- Answer:
 
29
  """
30
- outputs = self.generator(prompt, max_new_tokens=100, do_sample=False)
31
- return outputs[0]['generated_text'].split("Answer:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- def run(self, task: dict) -> str:
34
- question = task.get("question", "")
35
- context = retrieve_context(task)
36
- return self.generate_answer(question, context)
 
 
1
  from retriever import retrieve_context
2
+ from tools import tools
3
+
4
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
5
+ from langchain_core.prompts import PromptTemplate
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_core.runnables import RunnableLambda
8
+ from langchain_core.messages import AIMessage, HumanMessage
9
+ from langchain_core.runnables.history import RunnableWithMessageHistory
10
+ from langchain_core.runnables import RunnableBranch
11
+ from langchain_community.chat_models import ChatOllama
12
+
13
+ from langgraph.graph import END, StateGraph
14
+ from typing import Annotated, TypedDict, List
15
+ import operator
16
+
17
+ model = ChatOllama(model="qwen:1.8b")
18
+
19
+ tools_with_names = {tool.name: tool for tool in tools}
20
+
21
+ class AgentState(TypedDict):
22
+ messages: Annotated[List], []
23
+ next: str
24
+
25
+ tool_chain = (
26
+ RunnableParallel({
27
+ "message": lambda x: x["messages"][-1].content,
28
+ "tool": lambda x: x["next"]
29
+ })
30
+ | (lambda x: tools_with_names[x["tool"]].invoke(x["message"]))
31
+ | (lambda x: {"messages": [AIMessage(content=str(x))], "next": "end"})
32
+ )
33
 
34
+ system = """
35
+ You are a helpful assistant. Use tools if needed. Keep responses short.
36
  """
37
+ prompt = PromptTemplate.from_template("""{context}
38
+
39
+ {question}
40
+ """)
41
+
42
+ context_chain = (
43
+ {
44
+ "context": RunnableLambda(retrieve_context),
45
+ "question": lambda x: x["messages"][-1].content,
46
+ }
47
+ | prompt
48
+ )
49
+
50
+ agent = context_chain | model | StrOutputParser() | (lambda x: {"messages": [AIMessage(content=x), HumanMessage(content="Do you want to use a tool?")], "next": "tool_picker"})
51
+
52
+ conditional_agent = RunnableBranch(
53
+ (lambda x: "tool" in x["next"], tool_chain),
54
+ agent
55
+ )
56
+
57
+ def create_graph():
58
+ graph_builder = StateGraph(AgentState)
59
+ graph_builder.add_node("agent", conditional_agent)
60
+ graph_builder.set_entry_point("agent")
61
+ graph_builder.add_node("tool_chain", tool_chain)
62
+ graph_builder.add_conditional_edges(
63
+ "agent", lambda x: x["next"], {
64
+ "tool": "tool_chain",
65
+ "end": END
66
+ }
67
+ )
68
+ graph_builder.add_edge("tool_chain", "agent")
69
+ return graph_builder.compile()
70
+
71
+ app = create_graph()
72
+ chain = RunnableWithMessageHistory(
73
+ app,
74
+ lambda session_id: {},
75
+ input_messages_key="messages",
76
+ history_messages_key="messages",
77
+ )
78