nav13n committed
Commit 310f427 · 1 Parent(s): 0898a2e

first commit

Files changed (2):
  1. app.py +3 -3
  2. sage.py +261 -25
app.py CHANGED
@@ -20,11 +20,11 @@ async def main(message: cl.Message):
     msg = cl.Message(content="")
 
     input = {"question": message.content}
-
-    async for output in runnable.astream(input):
+
+    for output in runnable.stream(input):
         for key, value in output.items():
             print(f"Finished running: {key}:")
-            if key == "generator_agent":
+            if key == "generate":
                 answer = value["answer"]
                 await msg.stream_token(answer)
 
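For reference, a minimal self-contained sketch of the handler after this hunk. It assumes runnable is the graph compiled in sage.py; the import path and the closing msg.send() call are illustrative additions, not part of this commit:

import chainlit as cl

from sage import app as runnable  # assumed export name/path; adjust to the real module


@cl.on_message
async def main(message: cl.Message):
    msg = cl.Message(content="")

    input = {"question": message.content}

    # stream() yields one {node_name: state_update} mapping per executed graph node
    for output in runnable.stream(input):
        for key, value in output.items():
            print(f"Finished running: {key}:")
            if key == "generate":
                answer = value["answer"]
                await msg.stream_token(answer)

    await msg.send()  # illustrative: flush the streamed tokens to the UI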
sage.py CHANGED
@@ -328,14 +328,14 @@ from langchain_core.output_parsers import StrOutputParser
 ROUTER_AGENT_PROMPT_TEMPLATE = """
 <|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
-You are an expert at delegating user questions to one of the most appropriate agents 'policy_agent' or 'payroll_agent'.
+You are an expert at delegating user questions to one of the most appropriate agents 'raqa' or 'payroll'.
 
 Use the following criteria to determine the appropriate agents to answer the user que:
 
-- If the query is regarding payslips, salary, tax deductions, basepay of a given month, use payroll_agent'.
-- If the question is closely related to general human resource queries, organisational policies, prompt engineering, or adversarial attacks, even if the keywords are not explicitly mentioned, use the 'policyagent'.
+- If the query is regarding payslips, salary, tax deductions, basepay of a given month, use 'payroll'.
+- If the question is closely related to general human resource queries, organisational policies, prompt engineering, or adversarial attacks, even if the keywords are not explicitly mentioned, use the 'raqa'.
 
-Your output should be a JSON object with a single key 'agent' and a value of either 'policy_agent' or 'payroll_agent'. Do not include any preamble, explanation, or additional text.
+Your output should be a JSON object with a single key 'agent' and a value of either 'raqa' or 'payroll'. Do not include any preamble, explanation, or additional text.
 
 User's Question: {question}
 
@@ -549,6 +549,96 @@ api_result
 payroll_qa_chain.invoke({"question":"What is my salary on jan 2024 ?", "data":api_result, "schema":payroll_schema})
 
 
+
+
+### Retrieval Grader
+
+from langchain.prompts import PromptTemplate
+from langchain_community.chat_models import ChatOllama
+from langchain_core.output_parsers import JsonOutputParser
+
+
+RETREIVAL_GRADER_PROMPT = PromptTemplate(
+    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+    You are a grader assessing relevance of a retrieved document to a user question. \n
+    Here is the retrieved document: \n\n {document} \n\n
+    Here is the user question: {question} \n
+    If the document contains keywords related to the user question, grade it as relevant. \n
+    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
+    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
+    Provide the binary score as a JSON with a single key 'score',
+    Do not include any preamble, explanation, or additional text
+    <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+    input_variables=["question", "document"],
+)
+
+retrieval_grader = RETREIVAL_GRADER_PROMPT | llm | JsonOutputParser()
+
+# question = "agent memory"
+# docs = retriever.get_relevant_documents(question)
+# doc_txt = docs[1].page_content
+# print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
+
+
+### Hallucination Grader
+
+# Prompt
+HALLUCINATION_GRADER_PROMPT = PromptTemplate(
+    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+    You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n
+    Here are the facts:
+    \n ------- \n
+    {documents}
+    \n ------- \n
+    Here is the answer: {answer}
+    Give a binary score 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n
+    Provide the binary score as a JSON with a single key 'score'.
+    Do not include any preamble, explanation, or additional text
+    <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+    input_variables=["answer", "documents"],
+)
+
+hallucination_grader = HALLUCINATION_GRADER_PROMPT | llm | JsonOutputParser()
+# hallucination_grader.invoke({"documents": docs, "generation": generation})
+
+
+### Answer Grader
+
+# Prompt
+ANSWER_GRADER_PROMPT = PromptTemplate(
+    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+    You are a grader assessing whether an answer is useful to resolve a question. \n
+    Here is the answer:
+    \n ------- \n
+    {answer}
+    \n ------- \n
+    Here is the question: {question}
+    Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve a question. \n
+    Provide the binary score as a JSON with a single key 'score'.
+    Do not include any preamble, explanation, or additional text
+    <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+    input_variables=["answer", "question"],
+)
+
+answer_grader = ANSWER_GRADER_PROMPT | llm | JsonOutputParser()
+
+## Question Re-writer
+
+# Prompt
+REWRITER_PROMPT = PromptTemplate(
+    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+    You a question re-writer that converts an input question to a better version that is optimized \n
+    for vectorstore retrieval. Look at the initial and formulate an improved question. \n
+    Here is the initial question: \n\n {question}. Improved question with no preamble: \n
+    <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+    input_variables=["answer", "question"],
+)
+
+question_rewriter = REWRITER_PROMPT | llm | StrOutputParser()
+# question_rewriter.invoke({"question": question})
+
+# answer_grader.invoke({"question": question, "generation": generation})
+
 ########### Create Nodes and Actions ###########
 from typing_extensions import TypedDict
 from typing import List
@@ -576,20 +666,15 @@ def route_question(state):
 
     log.debug('Routing to {}....'.format(result["agent"]))
 
-    if result['agent'] == 'payroll_agent':
-        log.debug('Routing to {}....'.format(result["agent"]))
-        return "payroll_agent"
-    elif result['agent'] == 'policy_agent':
-        log.debug('Routing to {}....'.format(result["agent"]))
-        return "policy_agent"
+    return result["agent"]
 
 state = AgentState(question="What is my salary on jan 2024 ?", answer="", documents=None)
 route_question(state)
 
 from langchain.schema import Document
-def retrieve_policy(state):
+def retrieve(state):
     """
-    Retrieve policy documents from vectorstore
+    Retrieve documents from vectorstore
 
     Args:
         state (dict): The current graph state
@@ -597,7 +682,7 @@ def retrieve_policy(state):
     Returns:
         state (dict): New key added to state, documents, that contains retrieved documents
     """
-    log.debug("Retreiving policy documents.......")
+    print("---RETRIEVE DOCUMENTS---")
     question = state["question"]
     documents = compression_retriever.invoke(question)
     return {"documents": documents, "question": question}
@@ -605,7 +690,7 @@ def retrieve_policy(state):
 # state = AgentState(question="What is leave policy?", answer="", documents=None)
 # retrieve_policy(state)
 
-def generate_answer(state):
+def generate(state):
     """
     Generate answer using retrieved data
 
@@ -615,7 +700,7 @@ def generate_answer(state):
     Returns:
         state (dict): New key added to state, generation, that contains LLM generation
     """
-    log.debug("Generating answer.......")
+    print("---GENERATE ANSWER---")
     question = state["question"]
     documents = state["documents"]
 
@@ -627,7 +712,135 @@ def generate_answer(state):
 # state = AgentState(question="What is leave policy?", answer="", documents=[Document(page_content="According to leave policy, there are two types of leaves 1: PL 2: CL")])
 # generate_answer(state)
 
-def query_payroll(state):
+def grade_documents(state):
+    """
+    Determines whether the retrieved documents are relevant to the question.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): Updates documents key with only filtered relevant documents
+    """
+
+    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
+    question = state["question"]
+    documents = state["documents"]
+
+    # Score each doc
+    filtered_docs = []
+    for d in documents:
+        score = retrieval_grader.invoke(
+            {"question": question, "document": d.page_content}
+        )
+        grade = score["score"]
+        if grade == "yes":
+            print("---GRADE: DOCUMENT RELEVANT---")
+            filtered_docs.append(d)
+        else:
+            print("---GRADE: DOCUMENT NOT RELEVANT---")
+            continue
+    return {"documents": filtered_docs, "question": question}
+
+def transform_query(state):
+    """
+    Transform the query to produce a better question.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): Updates question key with a re-phrased question
+    """
+
+    print("---TRANSFORM QUERY---")
+    question = state["question"]
+    documents = state["documents"]
+
+    # Re-write question
+    better_question = question_rewriter.invoke({"question": question})
+    return {"documents": documents, "question": better_question}
+
+def decide_to_generate(state):
+    """
+    Determines whether to generate an answer, or re-generate a question.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        str: Binary decision for next node to call
+    """
+
+    print("---ASSESS GRADED DOCUMENTS---")
+    question = state["question"]
+    filtered_documents = state["documents"]
+
+    if not filtered_documents:
+        # All documents have been filtered check_relevance
+        # We will re-generate a new query
+        print(
+            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
+        )
+        return "transform_query"
+    else:
+        # We have relevant documents, so generate answer
+        print("---DECISION: GENERATE---")
+        return "generate"
+
+def fallback(state):
+    """
+    Fallback to default answer.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        str: Decision for next node to call
+    """
+
+
+    return {"answer": "Sorry,I don't know the answer to this question."}
+
+def grade_generation_v_documents_and_question(state):
+    """
+    Determines whether the generation is grounded in the document and answers question.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        str: Decision for next node to call
+    """
+
+    print("---CHECK HALLUCINATIONS---")
+    question = state["question"]
+    documents = state["documents"]
+    answer = state["answer"]
+
+    score = hallucination_grader.invoke(
+        {"documents": documents, "answer": answer}
+    )
+    grade = score["score"]
+
+    # Check hallucination
+    if grade == "yes":
+        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+        # Check question-answering
+        print("---GRADE GENERATION vs QUESTION---")
+        score = answer_grader.invoke({"question": question, "answer": answer})
+        grade = score["score"]
+        if grade == "yes":
+            print("---DECISION: GENERATION ADDRESSES QUESTION---")
+            return "useful"
+        else:
+            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+            return "not useful"
+    else:
+        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+        return "not supported"
+
+def payroll(state):
     """
     Query payroll api to retrieve payroll data
 
@@ -638,7 +851,7 @@ def query_payroll(state):
         state (dict): Updated state with retrived payroll data
     """
 
-
+    print("---QUERY PAYROLL API---")
     question = state["question"]
     payroll_query_filters = fiter_extraction_chain.invoke({"question":question})
     payroll_api_query_results = dummy_payroll_api_call(1234, result["month"], result["year"])
@@ -659,19 +872,42 @@ from langgraph.graph import END, StateGraph
 workflow = StateGraph(AgentState)
 
 # Define the nodes
-workflow.add_node("payroll_agent", query_payroll)
-workflow.add_node("policy_agent", retrieve_policy)
-workflow.add_node("generator_agent", generate_answer)
+workflow.add_node("payroll", payroll)
+workflow.add_node("retrieve", retrieve)
+workflow.add_node("generate", generate)
+# workflow.add_node("grade_documents", grade_documents)  # grade documents
+# workflow.add_node("transform_query", transform_query)  # transform_query
+# workflow.add_node("fallback", fallback)
 
 workflow.set_conditional_entry_point(
     route_question,
     {
-        "payroll_agent": "payroll_agent",
-        "policy_agent": "policy_agent",
+        "payroll": "payroll",
+        "raqa": "retrieve",
     },
 )
-workflow.add_edge("payroll_agent", "generator_agent")
-workflow.add_edge("policy_agent", "generator_agent")
-workflow.add_edge("generator_agent", END)
+workflow.add_edge("payroll", "generate")
+# workflow.add_edge("retrieve", "generate")
+# workflow.add_edge("generate", END)
+workflow.add_edge("retrieve", "generate")
+# workflow.add_conditional_edges(
+#     "grade_documents",
+#     decide_to_generate,
+#     {
+#         "transform_query": "transform_query",
+#         "generate": "generate",
+#     },
+# )
+# workflow.add_edge("transform_query", "retrieve")
+# workflow.add_conditional_edges(
+#     "generate",
+#     grade_generation_v_documents_and_question,
+#     {
+#         "not supported": "generate",
+#         "useful": END,
+#         "not useful": "fallback",
+#     },
+# )
+workflow.add_edge("generate", END)
 
 app = workflow.compile()
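As a rough usage sketch (not part of this commit), the compiled graph can be exercised directly, assuming the rest of sage.py (llm, compression_retriever, the payroll chain) has been initialised; the question string below is only an example:

if __name__ == "__main__":
    inputs = {"question": "What is the leave policy?"}

    # Each streamed item maps the node that just ran ("payroll", "retrieve" or "generate")
    # to the partial state it returned.
    for output in app.stream(inputs):
        for node, state_update in output.items():
            print(f"Finished running: {node}")

    # Or collect only the final state in a single call.
    final_state = app.invoke(inputs)
    print(final_state["answer"])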