nav13n commited on
Commit
0fa51ba
·
1 Parent(s): bfbe9a0

first commit

Browse files
Files changed (2) hide show
  1. app.py +9 -1
  2. sage.py +15 -15
app.py CHANGED
@@ -7,6 +7,15 @@ from langchain.schema.runnable import RunnableConfig
7
  welcome_message = "Welcome! I'm Sage, your friendly AI assistant. I'm here to help you quickly find answers to your HR and policy questions. What can I assist you with today?"
8
  @cl.on_chat_start
9
  async def start_chat():
 
 
 
 
 
 
 
 
 
10
  await cl.Message(content=welcome_message).send()
11
  cl.user_session.set("runnable", app)
12
 
@@ -18,7 +27,6 @@ async def main(message: cl.Message):
18
 
19
  input = {"question": message.content}
20
 
21
- value = None
22
  for output in runnable.stream(input):
23
  for key, value in output.items():
24
  print(f"Finished running: {key}:")
 
7
  welcome_message = "Welcome! I'm Sage, your friendly AI assistant. I'm here to help you quickly find answers to your HR and policy questions. What can I assist you with today?"
8
  @cl.on_chat_start
9
  async def start_chat():
10
+
11
+ input = {"question": ""}
12
+
13
+ for output in app.stream(input):
14
+ for key, value in output.items():
15
+ print(f"Finished running...{key}:")
16
+
17
+ print("Initialised chain...")
18
+
19
  await cl.Message(content=welcome_message).send()
20
  cl.user_session.set("runnable", app)
21
 
 
27
 
28
  input = {"question": message.content}
29
 
 
30
  for output in runnable.stream(input):
31
  for key, value in output.items():
32
  print(f"Finished running: {key}:")
sage.py CHANGED
@@ -27,8 +27,6 @@ text_splitter = RecursiveCharacterTextSplitter(
27
  )
28
  doc_splits = text_splitter.split_documents(documents)
29
 
30
- print(len(doc_splits),doc_splits[0])
31
-
32
  vectorstore = FAISS.from_documents(documents=doc_splits,embedding=embed_model)
33
 
34
  from langchain.retrievers import ContextualCompressionRetriever
@@ -320,7 +318,7 @@ def dummy_payroll_api_call(employee_id, month, year):
320
 
321
  return data[year][month]
322
 
323
- print(dummy_payroll_api_call(1234, 'CUR', 2024))
324
 
325
  import time
326
  from langchain.prompts import PromptTemplate
@@ -350,9 +348,9 @@ router_prompt = PromptTemplate(
350
 
351
  router_chain = router_prompt | llm | JsonOutputParser()
352
 
353
- print(router_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
354
 
355
- print(router_chain.invoke({"question":"What is leave policy ?"}))
356
 
357
  payroll_schema= {
358
  "$schema": "http://json-schema.org/draft-07/schema#",
@@ -482,7 +480,7 @@ payroll_schema= {
482
  "required": ["employeeDetails", "paymentDetails", "companyDetails"]
483
  }
484
 
485
- print(str(payroll_schema))
486
 
487
  import time
488
  from langchain.prompts import PromptTemplate
@@ -511,7 +509,7 @@ filter_extraction_prompt = PromptTemplate(
511
 
512
  fiter_extraction_chain = filter_extraction_prompt | llm | JsonOutputParser()
513
 
514
- print(fiter_extraction_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
515
 
516
  import time
517
  from langchain.prompts import PromptTemplate
@@ -550,11 +548,11 @@ api_result
550
 
551
  payroll_qa_chain.invoke({"question":"What is my salary on jan 2024 ?", "data":api_result, "schema":payroll_schema})
552
 
 
 
553
  from typing_extensions import TypedDict
554
  from typing import List
555
 
556
- ### State
557
-
558
  class AgentState(TypedDict):
559
  question : str
560
  answer : str
@@ -604,8 +602,8 @@ def retrieve_policy(state):
604
  documents = compression_retriever.invoke(question)
605
  return {"documents": documents, "question": question}
606
 
607
- state = AgentState(question="What is leave policy?", answer="", documents=None)
608
- retrieve_policy(state)
609
 
610
  def generate_answer(state):
611
  """
@@ -626,8 +624,8 @@ def generate_answer(state):
626
 
627
  return {"documents": documents, "question": question, "answer": answer}
628
 
629
- state = AgentState(question="What is leave policy?", answer="", documents=[Document(page_content="According to leave policy, there are two types of leaves 1: PL 2: CL")])
630
- generate_answer(state)
631
 
632
  def query_payroll(state):
633
  """
@@ -652,9 +650,11 @@ def query_payroll(state):
652
  documents = [Document(page_content=context)]
653
  return {"documents": documents, "question": question}
654
 
655
- state = AgentState(question="Tell me salary for Jan 2024?", answer="", documents=None)
656
- query_payroll(state)
 
657
 
 
658
  from langgraph.graph import END, StateGraph
659
  workflow = StateGraph(AgentState)
660
 
 
27
  )
28
  doc_splits = text_splitter.split_documents(documents)
29
 
 
 
30
  vectorstore = FAISS.from_documents(documents=doc_splits,embedding=embed_model)
31
 
32
  from langchain.retrievers import ContextualCompressionRetriever
 
318
 
319
  return data[year][month]
320
 
321
+ # print(dummy_payroll_api_call(1234, 'CUR', 2024))
322
 
323
  import time
324
  from langchain.prompts import PromptTemplate
 
348
 
349
  router_chain = router_prompt | llm | JsonOutputParser()
350
 
351
+ # print(router_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
352
 
353
+ # print(router_chain.invoke({"question":"What is leave policy ?"}))
354
 
355
  payroll_schema= {
356
  "$schema": "http://json-schema.org/draft-07/schema#",
 
480
  "required": ["employeeDetails", "paymentDetails", "companyDetails"]
481
  }
482
 
483
+ # print(str(payroll_schema))
484
 
485
  import time
486
  from langchain.prompts import PromptTemplate
 
509
 
510
  fiter_extraction_chain = filter_extraction_prompt | llm | JsonOutputParser()
511
 
512
+ # print(fiter_extraction_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
513
 
514
  import time
515
  from langchain.prompts import PromptTemplate
 
548
 
549
  payroll_qa_chain.invoke({"question":"What is my salary on jan 2024 ?", "data":api_result, "schema":payroll_schema})
550
 
551
+
552
+ ########### Create Nodes and Actions ###########
553
  from typing_extensions import TypedDict
554
  from typing import List
555
 
 
 
556
  class AgentState(TypedDict):
557
  question : str
558
  answer : str
 
602
  documents = compression_retriever.invoke(question)
603
  return {"documents": documents, "question": question}
604
 
605
+ # state = AgentState(question="What is leave policy?", answer="", documents=None)
606
+ # retrieve_policy(state)
607
 
608
  def generate_answer(state):
609
  """
 
624
 
625
  return {"documents": documents, "question": question, "answer": answer}
626
 
627
+ # state = AgentState(question="What is leave policy?", answer="", documents=[Document(page_content="According to leave policy, there are two types of leaves 1: PL 2: CL")])
628
+ # generate_answer(state)
629
 
630
  def query_payroll(state):
631
  """
 
650
  documents = [Document(page_content=context)]
651
  return {"documents": documents, "question": question}
652
 
653
+ # state = AgentState(question="Tell me salary for Jan 2024?", answer="", documents=None)
654
+ # query_payroll(state)
655
+
656
 
657
+ ########### Build Execution Graph ###########
658
  from langgraph.graph import END, StateGraph
659
  workflow = StateGraph(AgentState)
660