nav13n commited on
Commit
c8e6eb5
·
1 Parent(s): 310f427

first commit

Browse files
Files changed (2) hide show
  1. sage.py +35 -681
  2. utils.py +376 -0
sage.py CHANGED
@@ -1,24 +1,44 @@
1
 
2
- from langchain_community.vectorstores import FAISS
 
 
 
 
 
 
 
 
 
3
 
 
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain_community.document_loaders import PyMuPDFLoader
6
  from langchain_community.vectorstores import FAISS
7
  from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
8
- import os
9
- from dotenv import load_dotenv
 
 
 
 
 
 
 
 
10
  load_dotenv()
11
 
 
12
  embed_model = FastEmbedEmbeddings(model_name="snowflake/snowflake-arctic-embed-m")
13
 
14
- from groq import Groq
15
- from langchain_groq import ChatGroq
16
 
17
 
18
  llm = ChatGroq(temperature=0,
19
  model_name="Llama3-8b-8192",
20
  api_key=os.getenv("GROQ_API_KEY"),)
21
 
 
 
 
22
  loader = PyMuPDFLoader("https://home.synise.com/HRUtility/Documents/HRA/UmaP/Synise%20Handbook.pdf")
23
  documents = loader.load()
24
 
@@ -29,21 +49,15 @@ doc_splits = text_splitter.split_documents(documents)
29
 
30
  vectorstore = FAISS.from_documents(documents=doc_splits,embedding=embed_model)
31
 
32
- from langchain.retrievers import ContextualCompressionRetriever
33
- from langchain.retrievers.document_compressors import FlashrankRerank
34
-
35
  compressor = FlashrankRerank()
36
  retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 20})
37
  compression_retriever = ContextualCompressionRetriever(
38
  base_compressor=compressor, base_retriever=retriever
39
  )
40
 
41
- from operator import itemgetter
42
- from langchain.prompts import PromptTemplate
43
- from langchain.schema.runnable import RunnablePassthrough
44
- from langchain_core.output_parsers import StrOutputParser
45
-
46
 
 
47
  RAG_PROMPT_TEMPLATE = """
48
  <|begin_of_text|><|start_header_id|>system<|end_header_id|>
49
 
@@ -71,260 +85,7 @@ response_chain = (rag_prompt
71
 
72
  )
73
 
74
- def dummy_payroll_api_call(employee_id, month, year):
75
-
76
- data = {
77
- 2023: {
78
- "MAY": {
79
- "employeeDetails": {
80
- "employeeId": "E2468",
81
- "firstName": "Sarah",
82
- "lastName": "Thompson",
83
- "designation": "Product Manager"
84
- },
85
- "paymentDetails": {
86
- "year": 2023,
87
- "month": "JAN",
88
- "basicSalary": 5500,
89
- "allowances": [
90
- {
91
- "type": "Housing Allowance",
92
- "amount": 1500
93
- },
94
- {
95
- "type": "Travel Allowance",
96
- "amount": 800
97
- }
98
- ],
99
- "deductions": [
100
- {
101
- "type": "Provident Fund",
102
- "amount": 650
103
- },
104
- {
105
- "type": "Health Insurance",
106
- "amount": 300
107
- }
108
- ],
109
- "taxes": [
110
- {
111
- "type": "Income Tax",
112
- "amount": 1300
113
- }
114
- ],
115
- "grossSalary": 7800,
116
- "totalDeductions": 2250,
117
- "netSalary": 6650
118
- },
119
- "companyDetails": {
120
- "companyName": "Tech Solutions Ltd.",
121
- "address": "789 Maple Avenue, City"
122
- }
123
- }
124
- },
125
- 2024: {
126
- "JAN": {
127
- "employeeDetails": {
128
- "employeeId": "E2468",
129
- "firstName": "Sarah",
130
- "lastName": "Thompson",
131
- "designation": "Product Manager"
132
- },
133
- "paymentDetails": {
134
- "year": 2024,
135
- "month": "JAN",
136
- "basicSalary": 6500,
137
- "allowances": [
138
- {
139
- "type": "Housing Allowance",
140
- "amount": 1500
141
- },
142
- {
143
- "type": "Travel Allowance",
144
- "amount": 800
145
- }
146
- ],
147
- "deductions": [
148
- {
149
- "type": "Provident Fund",
150
- "amount": 650
151
- },
152
- {
153
- "type": "Health Insurance",
154
- "amount": 300
155
- }
156
- ],
157
- "taxes": [
158
- {
159
- "type": "Income Tax",
160
- "amount": 1300
161
- }
162
- ],
163
- "grossSalary": 8800,
164
- "totalDeductions": 2250,
165
- "netSalary": 6550
166
- },
167
- "companyDetails": {
168
- "companyName": "Tech Solutions Ltd.",
169
- "address": "789 Maple Avenue, City"
170
- }
171
- },
172
- "FEB": {
173
- "employeeDetails": {
174
- "employeeId": "E2468",
175
- "firstName": "Sarah",
176
- "lastName": "Thompson",
177
- "designation": "Product Manager"
178
- },
179
- "paymentDetails": {
180
- "year": 2024,
181
- "month": "FEB",
182
- "basicSalary": 6500,
183
- "allowances": [
184
- {
185
- "type": "Housing Allowance",
186
- "amount": 1500
187
- },
188
- {
189
- "type": "Travel Allowance",
190
- "amount": 800
191
- }
192
- ],
193
- "deductions": [
194
- {
195
- "type": "Provident Fund",
196
- "amount": 650
197
- },
198
- {
199
- "type": "Health Insurance",
200
- "amount": 300
201
- }
202
- ],
203
- "taxes": [
204
- {
205
- "type": "Income Tax",
206
- "amount": 1300
207
- }
208
- ],
209
- "grossSalary": 8800,
210
- "totalDeductions": 2250,
211
- "netSalary": 6550
212
- },
213
- "companyDetails": {
214
- "companyName": "Tech Solutions Ltd.",
215
- "address": "789 Maple Avenue, City"
216
- }
217
- },
218
- "MAY": {
219
- "employeeDetails": {
220
- "employeeId": "E2468",
221
- "firstName": "Sarah",
222
- "lastName": "Thompson",
223
- "designation": "Product Manager"
224
- },
225
- "paymentDetails": {
226
- "year": 2024,
227
- "month": "MAY",
228
- "basicSalary": 6500,
229
- "allowances": [
230
- {
231
- "type": "Housing Allowance",
232
- "amount": 1500
233
- },
234
- {
235
- "type": "Travel Allowance",
236
- "amount": 800
237
- }
238
- ],
239
- "deductions": [
240
- {
241
- "type": "Provident Fund",
242
- "amount": 650
243
- },
244
- {
245
- "type": "Health Insurance",
246
- "amount": 300
247
- }
248
- ],
249
- "taxes": [
250
- {
251
- "type": "Income Tax",
252
- "amount": 1500
253
- }
254
- ],
255
- "grossSalary": 8800,
256
- "totalDeductions": 2450,
257
- "netSalary": 6350
258
- },
259
- "companyDetails": {
260
- "companyName": "Tech Solutions Ltd.",
261
- "address": "789 Maple Avenue, City"
262
- }
263
- },
264
- "APR": {
265
- "employeeDetails": {
266
- "employeeId": "E2468",
267
- "firstName": "Sarah",
268
- "lastName": "Thompson",
269
- "designation": "Product Manager"
270
- },
271
- "paymentDetails": {
272
- "year": 2024,
273
- "month": "APR",
274
- "basicSalary": 6500,
275
- "allowances": [
276
- {
277
- "type": "Housing Allowance",
278
- "amount": 1500
279
- },
280
- {
281
- "type": "Travel Allowance",
282
- "amount": 800
283
- }
284
- ],
285
- "deductions": [
286
- {
287
- "type": "Provident Fund",
288
- "amount": 650
289
- },
290
- {
291
- "type": "Health Insurance",
292
- "amount": 300
293
- }
294
- ],
295
- "taxes": [
296
- {
297
- "type": "Income Tax",
298
- "amount": 1500
299
- }
300
- ],
301
- "grossSalary": 8800,
302
- "totalDeductions": 2450,
303
- "netSalary": 6350
304
- },
305
- "companyDetails": {
306
- "companyName": "Tech Solutions Ltd.",
307
- "address": "789 Maple Avenue, City"
308
- }
309
- }
310
- }
311
- }
312
- year= 2024 if year == "CUR" else year
313
- year= 2023 if year == "PREV" else year
314
-
315
- month= "MAY" if month == "CUR" else month
316
- month= "APR" if month == "PREV" else month
317
-
318
-
319
- return data[year][month]
320
-
321
- # print(dummy_payroll_api_call(1234, 'CUR', 2024))
322
-
323
- import time
324
- from langchain.prompts import PromptTemplate
325
- from langchain_core.output_parsers import JsonOutputParser
326
- from langchain_core.output_parsers import StrOutputParser
327
-
328
  ROUTER_AGENT_PROMPT_TEMPLATE = """
329
  <|begin_of_text|><|start_header_id|>system<|end_header_id|>
330
 
@@ -348,145 +109,10 @@ router_prompt = PromptTemplate(
348
 
349
  router_chain = router_prompt | llm | JsonOutputParser()
350
 
351
- # print(router_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
352
-
353
- # print(router_chain.invoke({"question":"What is leave policy ?"}))
354
-
355
- payroll_schema= {
356
- "$schema": "http://json-schema.org/draft-07/schema#",
357
- "title": "Monthly Payslip",
358
- "description": "A schema for a monthly payslip",
359
- "type": "object",
360
- "properties": {
361
- "employeeDetails": {
362
- "type": "object",
363
- "properties": {
364
- "employeeId": {
365
- "type": "string",
366
- "description": "Unique identifier for the employee"
367
- },
368
- "firstName": {
369
- "type": "string",
370
- "description": "First name of the employee"
371
- },
372
- "lastName": {
373
- "type": "string",
374
- "description": "Last name of the employee"
375
- },
376
- "designation": {
377
- "type": "string",
378
- "description": "Designation or job title of the employee"
379
- }
380
- },
381
- "required": ["employeeId", "firstName", "lastName", "designation"]
382
- },
383
- "paymentDetails": {
384
- "type": "object",
385
- "properties": {
386
- "year": {
387
- "type": "integer",
388
- "description": "Year of the pay period"
389
- },
390
- "month": {
391
- "type": "string",
392
- "enum": ["JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC"],
393
- "description": "Month of the pay period"
394
- },
395
- "basicSalary": {
396
- "type": "number",
397
- "description": "Basic salary of the employee"
398
- },
399
- "allowances": {
400
- "type": "array",
401
- "items": {
402
- "type": "object",
403
- "properties": {
404
- "type": {
405
- "type": "string",
406
- "description": "Type of allowance"
407
- },
408
- "amount": {
409
- "type": "number",
410
- "description": "Amount of the allowance"
411
- }
412
- },
413
- "required": ["type", "amount"]
414
- }
415
- },
416
- "deductions": {
417
- "type": "array",
418
- "items": {
419
- "type": "object",
420
- "properties": {
421
- "type": {
422
- "type": "string",
423
- "description": "Type of deduction"
424
- },
425
- "amount": {
426
- "type": "number",
427
- "description": "Amount of the deduction"
428
- }
429
- },
430
- "required": ["type", "amount"]
431
- }
432
- },
433
- "taxes": {
434
- "type": "array",
435
- "items": {
436
- "type": "object",
437
- "properties": {
438
- "type": {
439
- "type": "string",
440
- "description": "Type of tax"
441
- },
442
- "amount": {
443
- "type": "number",
444
- "description": "Amount of the tax"
445
- }
446
- },
447
- "required": ["type", "amount"]
448
- }
449
- },
450
- "grossSalary": {
451
- "type": "number",
452
- "description": "Gross salary (basic salary + allowances)"
453
- },
454
- "totalDeductions": {
455
- "type": "number",
456
- "description": "Total deductions (including taxes)"
457
- },
458
- "netSalary": {
459
- "type": "number",
460
- "description": "Net salary (gross salary - total deductions)"
461
- }
462
- },
463
- "required": ["year", "month", "basicSalary", "allowances", "deductions", "taxes", "grossSalary", "totalDeductions", "netSalary"]
464
- },
465
- "companyDetails": {
466
- "type": "object",
467
- "properties": {
468
- "companyName": {
469
- "type": "string",
470
- "description": "Name of the company"
471
- },
472
- "address": {
473
- "type": "string",
474
- "description": "Address of the company"
475
- }
476
- },
477
- "required": ["companyName", "address"]
478
- }
479
- },
480
- "required": ["employeeDetails", "paymentDetails", "companyDetails"]
481
- }
482
-
483
- # print(str(payroll_schema))
484
 
485
- import time
486
- from langchain.prompts import PromptTemplate
487
- from langchain_core.output_parsers import JsonOutputParser
488
- from langchain_core.output_parsers import StrOutputParser
489
 
 
490
  FILTER_EXTTRACTION_PROMPT = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
491
  Extract the month and year from a given user question about payroll. Use the following schema instructions to guide your extraction.
492
 
@@ -509,12 +135,8 @@ filter_extraction_prompt = PromptTemplate(
509
 
510
  fiter_extraction_chain = filter_extraction_prompt | llm | JsonOutputParser()
511
 
512
- # print(fiter_extraction_chain.invoke({"question":"What is my salary on 6 2024 ?"}))
513
 
514
- import time
515
- from langchain.prompts import PromptTemplate
516
- from langchain_core.output_parsers import JsonOutputParser
517
- from langchain_core.output_parsers import StrOutputParser
518
 
519
  PAYROLL_QA_PROMPT = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
520
 
@@ -536,120 +158,13 @@ payroll_qa_prompt = PromptTemplate(
536
  template=PAYROLL_QA_PROMPT, input_variables=["question", "data", "schema"]
537
  )
538
 
539
- payroll_qa_chain = payroll_qa_prompt | llm | StrOutputParser()
540
-
541
- result = fiter_extraction_chain.invoke({"question":"What is my salary on jan 2024 ?"})
542
-
543
- result
544
-
545
- api_result = dummy_payroll_api_call(1234, result["month"], result["year"])
546
-
547
- api_result
548
-
549
- payroll_qa_chain.invoke({"question":"What is my salary on jan 2024 ?", "data":api_result, "schema":payroll_schema})
550
-
551
-
552
-
553
-
554
- ### Retrieval Grader
555
-
556
- from langchain.prompts import PromptTemplate
557
- from langchain_community.chat_models import ChatOllama
558
- from langchain_core.output_parsers import JsonOutputParser
559
-
560
-
561
- RETREIVAL_GRADER_PROMPT = PromptTemplate(
562
- template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
563
- You are a grader assessing relevance of a retrieved document to a user question. \n
564
- Here is the retrieved document: \n\n {document} \n\n
565
- Here is the user question: {question} \n
566
- If the document contains keywords related to the user question, grade it as relevant. \n
567
- It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
568
- Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
569
- Provide the binary score as a JSON with a single key 'score',
570
- Do not include any preamble, explanation, or additional text
571
- <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
572
- input_variables=["question", "document"],
573
- )
574
-
575
- retrieval_grader = RETREIVAL_GRADER_PROMPT | llm | JsonOutputParser()
576
-
577
- # question = "agent memory"
578
- # docs = retriever.get_relevant_documents(question)
579
- # doc_txt = docs[1].page_content
580
- # print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
581
-
582
-
583
- ### Hallucination Grader
584
-
585
- # Prompt
586
- HALLUCINATION_GRADER_PROMPT = PromptTemplate(
587
- template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
588
- You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n
589
- Here are the facts:
590
- \n ------- \n
591
- {documents}
592
- \n ------- \n
593
- Here is the answer: {answer}
594
- Give a binary score 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n
595
- Provide the binary score as a JSON with a single key 'score'.
596
- Do not include any preamble, explanation, or additional text
597
- <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
598
- input_variables=["answer", "documents"],
599
- )
600
-
601
- hallucination_grader = HALLUCINATION_GRADER_PROMPT | llm | JsonOutputParser()
602
- # hallucination_grader.invoke({"documents": docs, "generation": generation})
603
-
604
-
605
- ### Answer Grader
606
-
607
- # Prompt
608
- ANSWER_GRADER_PROMPT = PromptTemplate(
609
- template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
610
- You are a grader assessing whether an answer is useful to resolve a question. \n
611
- Here is the answer:
612
- \n ------- \n
613
- {answer}
614
- \n ------- \n
615
- Here is the question: {question}
616
- Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve a question. \n
617
- Provide the binary score as a JSON with a single key 'score'.
618
- Do not include any preamble, explanation, or additional text
619
- <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
620
- input_variables=["answer", "question"],
621
- )
622
-
623
- answer_grader = ANSWER_GRADER_PROMPT | llm | JsonOutputParser()
624
-
625
- ## Question Re-writer
626
-
627
- # Prompt
628
- REWRITER_PROMPT = PromptTemplate(
629
- template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
630
- You a question re-writer that converts an input question to a better version that is optimized \n
631
- for vectorstore retrieval. Look at the initial and formulate an improved question. \n
632
- Here is the initial question: \n\n {question}. Improved question with no preamble: \n
633
- <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
634
- input_variables=["answer", "question"],
635
- )
636
-
637
- question_rewriter = REWRITER_PROMPT | llm | StrOutputParser()
638
- # question_rewriter.invoke({"question": question})
639
-
640
- # answer_grader.invoke({"question": question, "generation": generation})
641
-
642
- ########### Create Nodes and Actions ###########
643
- from typing_extensions import TypedDict
644
- from typing import List
645
 
646
  class AgentState(TypedDict):
647
  question : str
648
  answer : str
649
  documents : List[str]
650
 
651
- import logging as log
652
-
653
  def route_question(state):
654
  """
655
  Route question to payroll_agent or policy_agent to retrieve reevant data
@@ -660,18 +175,16 @@ def route_question(state):
660
  Returns:
661
  str: Next node to call
662
  """
663
-
664
  question = state["question"]
665
  result = router_chain.invoke({"question": question})
666
 
667
- log.debug('Routing to {}....'.format(result["agent"]))
668
-
669
  return result["agent"]
670
 
671
  state = AgentState(question="What is my salary on jan 2024 ?", answer="", documents=None)
672
  route_question(state)
673
 
674
- from langchain.schema import Document
675
  def retrieve(state):
676
  """
677
  Retrieve documents from vectorstore
@@ -709,137 +222,6 @@ def generate(state):
709
 
710
  return {"documents": documents, "question": question, "answer": answer}
711
 
712
- # state = AgentState(question="What is leave policy?", answer="", documents=[Document(page_content="According to leave policy, there are two types of leaves 1: PL 2: CL")])
713
- # generate_answer(state)
714
-
715
- def grade_documents(state):
716
- """
717
- Determines whether the retrieved documents are relevant to the question.
718
-
719
- Args:
720
- state (dict): The current graph state
721
-
722
- Returns:
723
- state (dict): Updates documents key with only filtered relevant documents
724
- """
725
-
726
- print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
727
- question = state["question"]
728
- documents = state["documents"]
729
-
730
- # Score each doc
731
- filtered_docs = []
732
- for d in documents:
733
- score = retrieval_grader.invoke(
734
- {"question": question, "document": d.page_content}
735
- )
736
- grade = score["score"]
737
- if grade == "yes":
738
- print("---GRADE: DOCUMENT RELEVANT---")
739
- filtered_docs.append(d)
740
- else:
741
- print("---GRADE: DOCUMENT NOT RELEVANT---")
742
- continue
743
- return {"documents": filtered_docs, "question": question}
744
-
745
- def transform_query(state):
746
- """
747
- Transform the query to produce a better question.
748
-
749
- Args:
750
- state (dict): The current graph state
751
-
752
- Returns:
753
- state (dict): Updates question key with a re-phrased question
754
- """
755
-
756
- print("---TRANSFORM QUERY---")
757
- question = state["question"]
758
- documents = state["documents"]
759
-
760
- # Re-write question
761
- better_question = question_rewriter.invoke({"question": question})
762
- return {"documents": documents, "question": better_question}
763
-
764
- def decide_to_generate(state):
765
- """
766
- Determines whether to generate an answer, or re-generate a question.
767
-
768
- Args:
769
- state (dict): The current graph state
770
-
771
- Returns:
772
- str: Binary decision for next node to call
773
- """
774
-
775
- print("---ASSESS GRADED DOCUMENTS---")
776
- question = state["question"]
777
- filtered_documents = state["documents"]
778
-
779
- if not filtered_documents:
780
- # All documents have been filtered check_relevance
781
- # We will re-generate a new query
782
- print(
783
- "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
784
- )
785
- return "transform_query"
786
- else:
787
- # We have relevant documents, so generate answer
788
- print("---DECISION: GENERATE---")
789
- return "generate"
790
-
791
- def fallback(state):
792
- """
793
- Fallback to default answer.
794
-
795
- Args:
796
- state (dict): The current graph state
797
-
798
- Returns:
799
- str: Decision for next node to call
800
- """
801
-
802
-
803
- return {"answer": "Sorry,I don't know the answer to this question."}
804
-
805
- def grade_generation_v_documents_and_question(state):
806
- """
807
- Determines whether the generation is grounded in the document and answers question.
808
-
809
- Args:
810
- state (dict): The current graph state
811
-
812
- Returns:
813
- str: Decision for next node to call
814
- """
815
-
816
- print("---CHECK HALLUCINATIONS---")
817
- question = state["question"]
818
- documents = state["documents"]
819
- answer = state["answer"]
820
-
821
- score = hallucination_grader.invoke(
822
- {"documents": documents, "answer": answer}
823
- )
824
- grade = score["score"]
825
-
826
- # Check hallucination
827
- if grade == "yes":
828
- print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
829
- # Check question-answering
830
- print("---GRADE GENERATION vs QUESTION---")
831
- score = answer_grader.invoke({"question": question, "answer": answer})
832
- grade = score["score"]
833
- if grade == "yes":
834
- print("---DECISION: GENERATION ADDRESSES QUESTION---")
835
- return "useful"
836
- else:
837
- print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
838
- return "not useful"
839
- else:
840
- print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
841
- return "not supported"
842
-
843
  def payroll(state):
844
  """
845
  Query payroll api to retrieve payroll data
@@ -854,7 +236,7 @@ def payroll(state):
854
  print("---QUERY PAYROLL API---")
855
  question = state["question"]
856
  payroll_query_filters = fiter_extraction_chain.invoke({"question":question})
857
- payroll_api_query_results = dummy_payroll_api_call(1234, result["month"], result["year"])
858
 
859
 
860
  context = context = 'PAYROLL DATA SCHEMA: \n {payroll_schema} \n PAYROLL DATA: {payroll_api_query_results}'.format(
@@ -863,21 +245,13 @@ def payroll(state):
863
  documents = [Document(page_content=context)]
864
  return {"documents": documents, "question": question}
865
 
866
- # state = AgentState(question="Tell me salary for Jan 2024?", answer="", documents=None)
867
- # query_payroll(state)
868
-
869
-
870
  ########### Build Execution Graph ###########
871
- from langgraph.graph import END, StateGraph
872
  workflow = StateGraph(AgentState)
873
 
874
  # Define the nodes
875
  workflow.add_node("payroll", payroll)
876
  workflow.add_node("retrieve", retrieve)
877
  workflow.add_node("generate", generate)
878
- # workflow.add_node("grade_documents", grade_documents) # grade documents
879
- # workflow.add_node("transform_query", transform_query) # transform_query
880
- # workflow.add_node("fallback", fallback)
881
 
882
  workflow.set_conditional_entry_point(
883
  route_question,
@@ -887,27 +261,7 @@ workflow.set_conditional_entry_point(
887
  },
888
  )
889
  workflow.add_edge("payroll", "generate")
890
- # workflow.add_edge("retrieve", "generate")
891
- # workflow.add_edge("generate", END)
892
  workflow.add_edge("retrieve", "generate")
893
- # workflow.add_conditional_edges(
894
- # "grade_documents",
895
- # decide_to_generate,
896
- # {
897
- # "transform_query": "transform_query",
898
- # "generate": "generate",
899
- # },
900
- # )
901
- # workflow.add_edge("transform_query", "retrieve")
902
- # workflow.add_conditional_edges(
903
- # "generate",
904
- # grade_generation_v_documents_and_question,
905
- # {
906
- # "not supported": "generate",
907
- # "useful": END,
908
- # "not useful": "fallback",
909
- # },
910
- # )
911
  workflow.add_edge("generate", END)
912
 
913
  app = workflow.compile()
 
1
 
2
+ import os
3
+ import time
4
+ from dotenv import load_dotenv
5
+ from operator import itemgetter
6
+ from typing_extensions import TypedDict
7
+ from typing import List
8
+
9
+ from langchain.prompts import PromptTemplate
10
+ from langchain_core.output_parsers import JsonOutputParser
11
+ from langchain_core.output_parsers import StrOutputParser
12
 
13
+ from langchain_community.vectorstores import FAISS
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
  from langchain_community.document_loaders import PyMuPDFLoader
16
  from langchain_community.vectorstores import FAISS
17
  from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
18
+ from langchain.retrievers import ContextualCompressionRetriever
19
+ from langchain.retrievers.document_compressors import FlashrankRerank
20
+ from langchain.schema import Document
21
+ from langgraph.graph import END, StateGraph
22
+
23
+ from groq import Groq
24
+ from langchain_groq import ChatGroq
25
+
26
+ from utils import get_payroll_api_schema, dummy_payroll_api_call
27
+
28
  load_dotenv()
29
 
30
+ # Setup the models
31
  embed_model = FastEmbedEmbeddings(model_name="snowflake/snowflake-arctic-embed-m")
32
 
 
 
33
 
34
 
35
  llm = ChatGroq(temperature=0,
36
  model_name="Llama3-8b-8192",
37
  api_key=os.getenv("GROQ_API_KEY"),)
38
 
39
+
40
+
41
+ # Load the documents
42
  loader = PyMuPDFLoader("https://home.synise.com/HRUtility/Documents/HRA/UmaP/Synise%20Handbook.pdf")
43
  documents = loader.load()
44
 
 
49
 
50
  vectorstore = FAISS.from_documents(documents=doc_splits,embedding=embed_model)
51
 
52
+ # Setup the retriever
 
 
53
  compressor = FlashrankRerank()
54
  retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 20})
55
  compression_retriever = ContextualCompressionRetriever(
56
  base_compressor=compressor, base_retriever=retriever
57
  )
58
 
 
 
 
 
 
59
 
60
+ # Define RAG Chain
61
  RAG_PROMPT_TEMPLATE = """
62
  <|begin_of_text|><|start_header_id|>system<|end_header_id|>
63
 
 
85
 
86
  )
87
 
88
+ # Setup Router Chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  ROUTER_AGENT_PROMPT_TEMPLATE = """
90
  <|begin_of_text|><|start_header_id|>system<|end_header_id|>
91
 
 
109
 
110
  router_chain = router_prompt | llm | JsonOutputParser()
111
 
112
+ payroll_schema = get_payroll_api_schema()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
 
 
 
 
114
 
115
+ # Define Filter Extraction Chain
116
  FILTER_EXTTRACTION_PROMPT = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
117
  Extract the month and year from a given user question about payroll. Use the following schema instructions to guide your extraction.
118
 
 
135
 
136
  fiter_extraction_chain = filter_extraction_prompt | llm | JsonOutputParser()
137
 
 
138
 
139
+ # Define Payroll QA Chain
 
 
 
140
 
141
  PAYROLL_QA_PROMPT = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
142
 
 
158
  template=PAYROLL_QA_PROMPT, input_variables=["question", "data", "schema"]
159
  )
160
 
161
+ ########### Create Nodes Actions ###########
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  class AgentState(TypedDict):
164
  question : str
165
  answer : str
166
  documents : List[str]
167
 
 
 
168
  def route_question(state):
169
  """
170
  Route question to payroll_agent or policy_agent to retrieve reevant data
 
175
  Returns:
176
  str: Next node to call
177
  """
178
+ print("---ROUTING---")
179
  question = state["question"]
180
  result = router_chain.invoke({"question": question})
181
 
 
 
182
  return result["agent"]
183
 
184
  state = AgentState(question="What is my salary on jan 2024 ?", answer="", documents=None)
185
  route_question(state)
186
 
187
+
188
  def retrieve(state):
189
  """
190
  Retrieve documents from vectorstore
 
222
 
223
  return {"documents": documents, "question": question, "answer": answer}
224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  def payroll(state):
226
  """
227
  Query payroll api to retrieve payroll data
 
236
  print("---QUERY PAYROLL API---")
237
  question = state["question"]
238
  payroll_query_filters = fiter_extraction_chain.invoke({"question":question})
239
+ payroll_api_query_results = dummy_payroll_api_call(1234, payroll_query_filters["month"], payroll_query_filters["year"])
240
 
241
 
242
  context = context = 'PAYROLL DATA SCHEMA: \n {payroll_schema} \n PAYROLL DATA: {payroll_api_query_results}'.format(
 
245
  documents = [Document(page_content=context)]
246
  return {"documents": documents, "question": question}
247
 
 
 
 
 
248
  ########### Build Execution Graph ###########
 
249
  workflow = StateGraph(AgentState)
250
 
251
  # Define the nodes
252
  workflow.add_node("payroll", payroll)
253
  workflow.add_node("retrieve", retrieve)
254
  workflow.add_node("generate", generate)
 
 
 
255
 
256
  workflow.set_conditional_entry_point(
257
  route_question,
 
261
  },
262
  )
263
  workflow.add_edge("payroll", "generate")
 
 
264
  workflow.add_edge("retrieve", "generate")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  workflow.add_edge("generate", END)
266
 
267
  app = workflow.compile()
utils.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def dummy_payroll_api_call(employee_id, month, year):
2
+
3
+ data = {
4
+ 2023: {
5
+ "MAY": {
6
+ "employeeDetails": {
7
+ "employeeId": "E2468",
8
+ "firstName": "Sarah",
9
+ "lastName": "Thompson",
10
+ "designation": "Product Manager"
11
+ },
12
+ "paymentDetails": {
13
+ "year": 2023,
14
+ "month": "JAN",
15
+ "basicSalary": 5500,
16
+ "allowances": [
17
+ {
18
+ "type": "Housing Allowance",
19
+ "amount": 1500
20
+ },
21
+ {
22
+ "type": "Travel Allowance",
23
+ "amount": 800
24
+ }
25
+ ],
26
+ "deductions": [
27
+ {
28
+ "type": "Provident Fund",
29
+ "amount": 650
30
+ },
31
+ {
32
+ "type": "Health Insurance",
33
+ "amount": 300
34
+ }
35
+ ],
36
+ "taxes": [
37
+ {
38
+ "type": "Income Tax",
39
+ "amount": 1300
40
+ }
41
+ ],
42
+ "grossSalary": 7800,
43
+ "totalDeductions": 2250,
44
+ "netSalary": 6650
45
+ },
46
+ "companyDetails": {
47
+ "companyName": "Tech Solutions Ltd.",
48
+ "address": "789 Maple Avenue, City"
49
+ }
50
+ }
51
+ },
52
+ 2024: {
53
+ "JAN": {
54
+ "employeeDetails": {
55
+ "employeeId": "E2468",
56
+ "firstName": "Sarah",
57
+ "lastName": "Thompson",
58
+ "designation": "Product Manager"
59
+ },
60
+ "paymentDetails": {
61
+ "year": 2024,
62
+ "month": "JAN",
63
+ "basicSalary": 6500,
64
+ "allowances": [
65
+ {
66
+ "type": "Housing Allowance",
67
+ "amount": 1500
68
+ },
69
+ {
70
+ "type": "Travel Allowance",
71
+ "amount": 800
72
+ }
73
+ ],
74
+ "deductions": [
75
+ {
76
+ "type": "Provident Fund",
77
+ "amount": 650
78
+ },
79
+ {
80
+ "type": "Health Insurance",
81
+ "amount": 300
82
+ }
83
+ ],
84
+ "taxes": [
85
+ {
86
+ "type": "Income Tax",
87
+ "amount": 1300
88
+ }
89
+ ],
90
+ "grossSalary": 8800,
91
+ "totalDeductions": 2250,
92
+ "netSalary": 6550
93
+ },
94
+ "companyDetails": {
95
+ "companyName": "Tech Solutions Ltd.",
96
+ "address": "789 Maple Avenue, City"
97
+ }
98
+ },
99
+ "FEB": {
100
+ "employeeDetails": {
101
+ "employeeId": "E2468",
102
+ "firstName": "Sarah",
103
+ "lastName": "Thompson",
104
+ "designation": "Product Manager"
105
+ },
106
+ "paymentDetails": {
107
+ "year": 2024,
108
+ "month": "FEB",
109
+ "basicSalary": 6500,
110
+ "allowances": [
111
+ {
112
+ "type": "Housing Allowance",
113
+ "amount": 1500
114
+ },
115
+ {
116
+ "type": "Travel Allowance",
117
+ "amount": 800
118
+ }
119
+ ],
120
+ "deductions": [
121
+ {
122
+ "type": "Provident Fund",
123
+ "amount": 650
124
+ },
125
+ {
126
+ "type": "Health Insurance",
127
+ "amount": 300
128
+ }
129
+ ],
130
+ "taxes": [
131
+ {
132
+ "type": "Income Tax",
133
+ "amount": 1300
134
+ }
135
+ ],
136
+ "grossSalary": 8800,
137
+ "totalDeductions": 2250,
138
+ "netSalary": 6550
139
+ },
140
+ "companyDetails": {
141
+ "companyName": "Tech Solutions Ltd.",
142
+ "address": "789 Maple Avenue, City"
143
+ }
144
+ },
145
+ "MAY": {
146
+ "employeeDetails": {
147
+ "employeeId": "E2468",
148
+ "firstName": "Sarah",
149
+ "lastName": "Thompson",
150
+ "designation": "Product Manager"
151
+ },
152
+ "paymentDetails": {
153
+ "year": 2024,
154
+ "month": "MAY",
155
+ "basicSalary": 6500,
156
+ "allowances": [
157
+ {
158
+ "type": "Housing Allowance",
159
+ "amount": 1500
160
+ },
161
+ {
162
+ "type": "Travel Allowance",
163
+ "amount": 800
164
+ }
165
+ ],
166
+ "deductions": [
167
+ {
168
+ "type": "Provident Fund",
169
+ "amount": 650
170
+ },
171
+ {
172
+ "type": "Health Insurance",
173
+ "amount": 300
174
+ }
175
+ ],
176
+ "taxes": [
177
+ {
178
+ "type": "Income Tax",
179
+ "amount": 1500
180
+ }
181
+ ],
182
+ "grossSalary": 8800,
183
+ "totalDeductions": 2450,
184
+ "netSalary": 6350
185
+ },
186
+ "companyDetails": {
187
+ "companyName": "Tech Solutions Ltd.",
188
+ "address": "789 Maple Avenue, City"
189
+ }
190
+ },
191
+ "APR": {
192
+ "employeeDetails": {
193
+ "employeeId": "E2468",
194
+ "firstName": "Sarah",
195
+ "lastName": "Thompson",
196
+ "designation": "Product Manager"
197
+ },
198
+ "paymentDetails": {
199
+ "year": 2024,
200
+ "month": "APR",
201
+ "basicSalary": 6500,
202
+ "allowances": [
203
+ {
204
+ "type": "Housing Allowance",
205
+ "amount": 1500
206
+ },
207
+ {
208
+ "type": "Travel Allowance",
209
+ "amount": 800
210
+ }
211
+ ],
212
+ "deductions": [
213
+ {
214
+ "type": "Provident Fund",
215
+ "amount": 650
216
+ },
217
+ {
218
+ "type": "Health Insurance",
219
+ "amount": 300
220
+ }
221
+ ],
222
+ "taxes": [
223
+ {
224
+ "type": "Income Tax",
225
+ "amount": 1500
226
+ }
227
+ ],
228
+ "grossSalary": 8800,
229
+ "totalDeductions": 2450,
230
+ "netSalary": 6350
231
+ },
232
+ "companyDetails": {
233
+ "companyName": "Tech Solutions Ltd.",
234
+ "address": "789 Maple Avenue, City"
235
+ }
236
+ }
237
+ }
238
+ }
239
+ year= 2024 if year == "CUR" else year
240
+ year= 2023 if year == "PREV" else year
241
+
242
+ month= "MAY" if month == "CUR" else month
243
+ month= "APR" if month == "PREV" else month
244
+
245
+
246
+ return data[year][month]
247
+
248
+ def get_payroll_api_schema():
249
+ schema = {
250
+ "$schema": "http://json-schema.org/draft-07/schema#",
251
+ "title": "Monthly Payslip",
252
+ "description": "A schema for a monthly payslip",
253
+ "type": "object",
254
+ "properties": {
255
+ "employeeDetails": {
256
+ "type": "object",
257
+ "properties": {
258
+ "employeeId": {
259
+ "type": "string",
260
+ "description": "Unique identifier for the employee"
261
+ },
262
+ "firstName": {
263
+ "type": "string",
264
+ "description": "First name of the employee"
265
+ },
266
+ "lastName": {
267
+ "type": "string",
268
+ "description": "Last name of the employee"
269
+ },
270
+ "designation": {
271
+ "type": "string",
272
+ "description": "Designation or job title of the employee"
273
+ }
274
+ },
275
+ "required": ["employeeId", "firstName", "lastName", "designation"]
276
+ },
277
+ "paymentDetails": {
278
+ "type": "object",
279
+ "properties": {
280
+ "year": {
281
+ "type": "integer",
282
+ "description": "Year of the pay period"
283
+ },
284
+ "month": {
285
+ "type": "string",
286
+ "enum": ["JAN", "FEB", "MAR", "APR", "MAY", "JUN", "JUL", "AUG", "SEP", "OCT", "NOV", "DEC"],
287
+ "description": "Month of the pay period"
288
+ },
289
+ "basicSalary": {
290
+ "type": "number",
291
+ "description": "Basic salary of the employee"
292
+ },
293
+ "allowances": {
294
+ "type": "array",
295
+ "items": {
296
+ "type": "object",
297
+ "properties": {
298
+ "type": {
299
+ "type": "string",
300
+ "description": "Type of allowance"
301
+ },
302
+ "amount": {
303
+ "type": "number",
304
+ "description": "Amount of the allowance"
305
+ }
306
+ },
307
+ "required": ["type", "amount"]
308
+ }
309
+ },
310
+ "deductions": {
311
+ "type": "array",
312
+ "items": {
313
+ "type": "object",
314
+ "properties": {
315
+ "type": {
316
+ "type": "string",
317
+ "description": "Type of deduction"
318
+ },
319
+ "amount": {
320
+ "type": "number",
321
+ "description": "Amount of the deduction"
322
+ }
323
+ },
324
+ "required": ["type", "amount"]
325
+ }
326
+ },
327
+ "taxes": {
328
+ "type": "array",
329
+ "items": {
330
+ "type": "object",
331
+ "properties": {
332
+ "type": {
333
+ "type": "string",
334
+ "description": "Type of tax"
335
+ },
336
+ "amount": {
337
+ "type": "number",
338
+ "description": "Amount of the tax"
339
+ }
340
+ },
341
+ "required": ["type", "amount"]
342
+ }
343
+ },
344
+ "grossSalary": {
345
+ "type": "number",
346
+ "description": "Gross salary (basic salary + allowances)"
347
+ },
348
+ "totalDeductions": {
349
+ "type": "number",
350
+ "description": "Total deductions (including taxes)"
351
+ },
352
+ "netSalary": {
353
+ "type": "number",
354
+ "description": "Net salary (gross salary - total deductions)"
355
+ }
356
+ },
357
+ "required": ["year", "month", "basicSalary", "allowances", "deductions", "taxes", "grossSalary", "totalDeductions", "netSalary"]
358
+ },
359
+ "companyDetails": {
360
+ "type": "object",
361
+ "properties": {
362
+ "companyName": {
363
+ "type": "string",
364
+ "description": "Name of the company"
365
+ },
366
+ "address": {
367
+ "type": "string",
368
+ "description": "Address of the company"
369
+ }
370
+ },
371
+ "required": ["companyName", "address"]
372
+ }
373
+ },
374
+ "required": ["employeeDetails", "paymentDetails", "companyDetails"]
375
+ }
376
+ return schema