File size: 5,223 Bytes
03ab668
ed02a3d
 
 
 
 
 
 
 
03ab668
ed02a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03ab668
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

def initialise_vectorstore(pdf, progress=gr.Progress()):
    """Build a Chroma vectorstore from an uploaded PDF.

    Args:
        pdf: The upload delivered by the ``gr.File`` component. Either an
            object exposing the file path as ``.name`` or a plain path
            string is accepted.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            deliberate: it is the form Gradio documents for enabling
            progress tracking on an event handler, so it must not be
            replaced by ``None``.

    Returns:
        A ``(vectorstore, status)`` tuple: the populated Chroma vectorstore
        and a human-readable status string for the status textbox.

    Raises:
        gr.Error: If no PDF was uploaded or no text chunks could be
            produced from it.
    """
    if pdf is None:
        raise gr.Error("Please upload a PDF file before initialising the vectorstore.")

    progress(0, desc="Reading PDF")

    # Accept both file-like uploads (path in ``.name``) and plain path strings.
    pdf_path = getattr(pdf, "name", pdf)
    pages = PyPDFLoader(pdf_path).load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(pages)

    if not splits:
        # Fail with a readable message instead of a low-level vectorstore error.
        raise gr.Error("No text could be extracted from the uploaded PDF.")

    progress(0.5, desc="Initialising Vectorstore")

    vectorstore = Chroma.from_documents(
        splits,
        embedding=HuggingFaceEmbeddings()
    )

    progress(1, desc="Complete")

    # Return a status string (not the Progress object) because the second
    # output of this handler is rendered in a Textbox.
    return vectorstore, f"Complete: {len(splits)} chunks embedded"

def initialise_chain(llm, vectorstore, progress=gr.Progress()):
    """Build the retrieval-augmented generation chain.

    Args:
        llm: Hugging Face repository id of the model to query, as selected
            in the UI.
        vectorstore: The vectorstore produced by the vectorstore
            initialisation step; used as the retrieval source.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            deliberate: it is the form Gradio documents for enabling
            progress tracking on an event handler.

    Returns:
        A ``(rag_chain, status)`` tuple: the runnable chain (question string
        in, answer string out) and a human-readable status string for the
        status textbox.

    Raises:
        gr.Error: If no model is selected or the vectorstore has not been
            initialised yet.
    """
    if not llm:
        raise gr.Error("Please select a language model first.")
    if vectorstore is None:
        raise gr.Error("Please initialise the vectorstore before building the RAG chain.")

    progress(0, desc="Initialising LLM")

    # Use a separate name so the ``llm`` parameter (a repo id string) is not
    # rebound to the endpoint object.
    endpoint = HuggingFaceEndpoint(
        repo_id=llm,
        task="text-generation",
        max_new_tokens=512,
        do_sample=False,
        repetition_penalty=1.03
    )

    chat = ChatHuggingFace(
        llm=endpoint,
        verbose=True
    )

    progress(0.5, desc="Initialising RAG Chain")

    retriever = vectorstore.as_retriever()
    prompt = hub.pull("rlm/rag-prompt")
    parser = StrOutputParser()

    def _format_docs(docs):
        # Feed the prompt plain text rather than the repr of Document objects.
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | prompt
        | chat
        | parser
    )

    progress(1, desc="Complete")

    # Return a status string (not the Progress object) because the second
    # output of this handler is rendered in a Textbox.
    return rag_chain, f"Complete: RAG chain ready ({llm})"

def send(message, rag_chain, chat_history):
    """Answer a user message with the RAG chain and extend the chat history.

    Args:
        message: The text typed by the user.
        rag_chain: Object exposing ``invoke(message)`` that returns the
            answer; ``None`` when the chain has not been initialised yet.
        chat_history: List of ``(user_message, answer)`` pairs shown in the
            chatbot, or ``None`` when the conversation is empty.

    Returns:
        A ``("", chat_history)`` tuple: an empty string to clear the input
        box and the history including the new exchange.

    Raises:
        gr.Error: If a non-blank message is sent before the RAG chain has
            been initialised.
    """
    history = chat_history if chat_history is not None else []

    # Nothing to ask: clear the input box without calling the model.
    if not message or not message.strip():
        return "", history

    if rag_chain is None:
        raise gr.Error("Please initialise the RAG chain before sending a message.")

    response = rag_chain.invoke(message)
    history.append((message, response))
    return "", history

def restart():
    """Return the fixed status text ``"Restarting"``."""
    status = "Restarting"
    return status


# Gradio UI: three tabs (vectorstore setup, RAG chain setup, chat) sharing
# per-session state for the vectorstore and the chain.
with gr.Blocks() as demo:

    # Session state passed between event handlers.
    vectorstore = gr.State()
    rag_chain = gr.State()
    
    gr.Markdown("<H1>Talk to Documents</H1>")
    gr.Markdown("<H3>Upload and ask questions about your PDF files</H3>")
    gr.Markdown("<H6>Note: This project uses LangChain to perform RAG (Retrieval Augmented Generation) on PDF files, allowing users to ask any questions related to their contents. When a PDF file is uploaded, it is embedded and stored in an in-memory Chroma vectorstore, which the chatbot uses as a source of knowledge when answering user questions.</H6>")

    # Vectorstore Tab
    with gr.Tab("Vectorstore"):
        with gr.Row():
            input_pdf = gr.File()
        with gr.Row():
            # Empty side columns centre the button in the row.
            with gr.Column(scale=1, min_width=0):
                pass
            with gr.Column(scale=2, min_width=0):
                initialise_vectorstore_btn = gr.Button(
                    "Initialise Vectorstore",
                    variant='primary'
                )
            with gr.Column(scale=1, min_width=0):
                pass
        with gr.Row():
            vectorstore_initialisation_progress = gr.Textbox(value="None", label="Initialization")

    # RAG Chain
    with gr.Tab("RAG Chain"):
        with gr.Row():
            language_model = gr.Radio(["microsoft/Phi-3-mini-4k-instruct", "mistralai/Mistral-7B-Instruct-v0.2", "nvidia/Mistral-NeMo-Minitron-8B-Base"])
        with gr.Row():
            # Empty side columns centre the button in the row.
            with gr.Column(scale=1, min_width=0):
                pass
            with gr.Column(scale=2, min_width=0):
                initialise_chain_btn = gr.Button(
                    "Initialise RAG Chain",
                    variant='primary'
                )
            with gr.Column(scale=1, min_width=0):
                pass
        with gr.Row():
            chain_initialisation_progress = gr.Textbox(value="None", label="Initialization")

    # Chatbot Tab
    with gr.Tab("Chatbot"):
        with gr.Row():
            chatbot = gr.Chatbot()
        # NOTE(review): the reference boxes below are not connected to any
        # event handler in the visible code, so they stay empty.
        with gr.Accordion("Advanced - Document references", open=False):
            with gr.Row():
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                source2_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                source3_page = gr.Number(label="Page", scale=1)
        with gr.Row():
            message = gr.Textbox()
        with gr.Row():
            # ``variant`` takes a plain string, matching the buttons above.
            send_btn = gr.Button(
                "Send",
                variant="primary"
            )
            restart_btn = gr.Button(
                "Restart",
                variant="secondary"
            )

    # Event wiring: each handler's return values map positionally to ``outputs``.
    initialise_vectorstore_btn.click(fn=initialise_vectorstore, inputs=input_pdf, outputs=[vectorstore, vectorstore_initialisation_progress])
    initialise_chain_btn.click(fn=initialise_chain, inputs=[language_model, vectorstore], outputs=[rag_chain, chain_initialisation_progress])
    send_btn.click(fn=send, inputs=[message, rag_chain, chatbot], outputs=[message, chatbot])
    # NOTE(review): no outputs are attached, so the value returned by
    # ``restart`` is not displayed anywhere.
    restart_btn.click(fn=restart)

demo.launch()