first commit
Browse files
app.py
CHANGED
@@ -7,6 +7,15 @@ from langchain.schema.runnable import RunnableConfig
|
|
7 |
welcome_message = "Welcome! I'm Sage, your friendly AI assistant. I'm here to help you quickly find answers to your HR and policy questions. What can I assist you with today?"
|
8 |
@cl.on_chat_start
|
9 |
async def start_chat():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
await cl.Message(content=welcome_message).send()
|
11 |
cl.user_session.set("runnable", app)
|
12 |
|
@@ -18,7 +27,6 @@ async def main(message: cl.Message):
|
|
18 |
|
19 |
input = {"question": message.content}
|
20 |
|
21 |
-
value = None
|
22 |
for output in runnable.stream(input):
|
23 |
for key, value in output.items():
|
24 |
print(f"Finished running: {key}:")
|
|
|
7 |
welcome_message = "Welcome! I'm Sage, your friendly AI assistant. I'm here to help you quickly find answers to your HR and policy questions. What can I assist you with today?"
|
8 |
@cl.on_chat_start
|
9 |
async def start_chat():
|
10 |
+
|
11 |
+
input = {"question": ""}
|
12 |
+
|
13 |
+
for output in app.stream(input):
|
14 |
+
for key, value in output.items():
|
15 |
+
print(f"Finished running...{key}:")
|
16 |
+
|
17 |
+
print("Initialised chain...")
|
18 |
+
|
19 |
await cl.Message(content=welcome_message).send()
|
20 |
cl.user_session.set("runnable", app)
|
21 |
|
|
|
27 |
|
28 |
input = {"question": message.content}
|
29 |
|
|
|
30 |
for output in runnable.stream(input):
|
31 |
for key, value in output.items():
|
32 |
print(f"Finished running: {key}:")
|
sage.py
CHANGED
@@ -27,8 +27,6 @@ text_splitter = RecursiveCharacterTextSplitter(
|
|
27 |
)
|
28 |
doc_splits = text_splitter.split_documents(documents)
|
29 |
|
30 |
-
print(len(doc_splits),doc_splits[0])
|
31 |
-
|
32 |
vectorstore = FAISS.from_documents(documents=doc_splits,embedding=embed_model)
|
33 |
|
34 |
from langchain.retrievers import ContextualCompressionRetriever
|
@@ -320,7 +318,7 @@ def dummy_payroll_api_call(employee_id, month, year):
|
|
320 |
|
321 |
return data[year][month]
|
322 |
|
323 |
-
print(dummy_payroll_api_call(1234, 'CUR', 2024))
|
324 |
|
325 |
import time
|
326 |
from langchain.prompts import PromptTemplate
|
@@ -350,9 +348,9 @@ router_prompt = PromptTemplate(
|
|
350 |
|
351 |
router_chain = router_prompt | llm | JsonOutputParser()
|
352 |
|
353 |
-
print(router_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
|
354 |
|
355 |
-
print(router_chain.invoke({"question":"What is leave policy ?"}))
|
356 |
|
357 |
payroll_schema= {
|
358 |
"$schema": "http://json-schema.org/draft-07/schema#",
|
@@ -482,7 +480,7 @@ payroll_schema= {
|
|
482 |
"required": ["employeeDetails", "paymentDetails", "companyDetails"]
|
483 |
}
|
484 |
|
485 |
-
print(str(payroll_schema))
|
486 |
|
487 |
import time
|
488 |
from langchain.prompts import PromptTemplate
|
@@ -511,7 +509,7 @@ filter_extraction_prompt = PromptTemplate(
|
|
511 |
|
512 |
fiter_extraction_chain = filter_extraction_prompt | llm | JsonOutputParser()
|
513 |
|
514 |
-
print(fiter_extraction_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
|
515 |
|
516 |
import time
|
517 |
from langchain.prompts import PromptTemplate
|
@@ -550,11 +548,11 @@ api_result
|
|
550 |
|
551 |
payroll_qa_chain.invoke({"question":"What is my salary on jan 2024 ?", "data":api_result, "schema":payroll_schema})
|
552 |
|
|
|
|
|
553 |
from typing_extensions import TypedDict
|
554 |
from typing import List
|
555 |
|
556 |
-
### State
|
557 |
-
|
558 |
class AgentState(TypedDict):
|
559 |
question : str
|
560 |
answer : str
|
@@ -604,8 +602,8 @@ def retrieve_policy(state):
|
|
604 |
documents = compression_retriever.invoke(question)
|
605 |
return {"documents": documents, "question": question}
|
606 |
|
607 |
-
state = AgentState(question="What is leave policy?", answer="", documents=None)
|
608 |
-
retrieve_policy(state)
|
609 |
|
610 |
def generate_answer(state):
|
611 |
"""
|
@@ -626,8 +624,8 @@ def generate_answer(state):
|
|
626 |
|
627 |
return {"documents": documents, "question": question, "answer": answer}
|
628 |
|
629 |
-
state = AgentState(question="What is leave policy?", answer="", documents=[Document(page_content="According to leave policy, there are two types of leaves 1: PL 2: CL")])
|
630 |
-
generate_answer(state)
|
631 |
|
632 |
def query_payroll(state):
|
633 |
"""
|
@@ -652,9 +650,11 @@ def query_payroll(state):
|
|
652 |
documents = [Document(page_content=context)]
|
653 |
return {"documents": documents, "question": question}
|
654 |
|
655 |
-
state = AgentState(question="Tell me salary for Jan 2024?", answer="", documents=None)
|
656 |
-
query_payroll(state)
|
|
|
657 |
|
|
|
658 |
from langgraph.graph import END, StateGraph
|
659 |
workflow = StateGraph(AgentState)
|
660 |
|
|
|
27 |
)
|
28 |
doc_splits = text_splitter.split_documents(documents)
|
29 |
|
|
|
|
|
30 |
vectorstore = FAISS.from_documents(documents=doc_splits,embedding=embed_model)
|
31 |
|
32 |
from langchain.retrievers import ContextualCompressionRetriever
|
|
|
318 |
|
319 |
return data[year][month]
|
320 |
|
321 |
+
# print(dummy_payroll_api_call(1234, 'CUR', 2024))
|
322 |
|
323 |
import time
|
324 |
from langchain.prompts import PromptTemplate
|
|
|
348 |
|
349 |
router_chain = router_prompt | llm | JsonOutputParser()
|
350 |
|
351 |
+
# print(router_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
|
352 |
|
353 |
+
# print(router_chain.invoke({"question":"What is leave policy ?"}))
|
354 |
|
355 |
payroll_schema= {
|
356 |
"$schema": "http://json-schema.org/draft-07/schema#",
|
|
|
480 |
"required": ["employeeDetails", "paymentDetails", "companyDetails"]
|
481 |
}
|
482 |
|
483 |
+
# print(str(payroll_schema))
|
484 |
|
485 |
import time
|
486 |
from langchain.prompts import PromptTemplate
|
|
|
509 |
|
510 |
fiter_extraction_chain = filter_extraction_prompt | llm | JsonOutputParser()
|
511 |
|
512 |
+
# print(fiter_extraction_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
|
513 |
|
514 |
import time
|
515 |
from langchain.prompts import PromptTemplate
|
|
|
548 |
|
549 |
payroll_qa_chain.invoke({"question":"What is my salary on jan 2024 ?", "data":api_result, "schema":payroll_schema})
|
550 |
|
551 |
+
|
552 |
+
########### Create Nodes and Actions ###########
|
553 |
from typing_extensions import TypedDict
|
554 |
from typing import List
|
555 |
|
|
|
|
|
556 |
class AgentState(TypedDict):
|
557 |
question : str
|
558 |
answer : str
|
|
|
602 |
documents = compression_retriever.invoke(question)
|
603 |
return {"documents": documents, "question": question}
|
604 |
|
605 |
+
# state = AgentState(question="What is leave policy?", answer="", documents=None)
|
606 |
+
# retrieve_policy(state)
|
607 |
|
608 |
def generate_answer(state):
|
609 |
"""
|
|
|
624 |
|
625 |
return {"documents": documents, "question": question, "answer": answer}
|
626 |
|
627 |
+
# state = AgentState(question="What is leave policy?", answer="", documents=[Document(page_content="According to leave policy, there are two types of leaves 1: PL 2: CL")])
|
628 |
+
# generate_answer(state)
|
629 |
|
630 |
def query_payroll(state):
|
631 |
"""
|
|
|
650 |
documents = [Document(page_content=context)]
|
651 |
return {"documents": documents, "question": question}
|
652 |
|
653 |
+
# state = AgentState(question="Tell me salary for Jan 2024?", answer="", documents=None)
|
654 |
+
# query_payroll(state)
|
655 |
+
|
656 |
|
657 |
+
########### Build Execution Graph ###########
|
658 |
from langgraph.graph import END, StateGraph
|
659 |
workflow = StateGraph(AgentState)
|
660 |
|