vkasyap committed on
Commit 7bf82e5 · verified · Parent: 0dfce95

Delete app.py

Files changed (1)
  1. app.py +0 -189
app.py DELETED
@@ -1,189 +0,0 @@
- import gradio as gr
- import os
- from langchain_community.document_loaders import PyPDFLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import Chroma
- from langchain.chains import ConversationalRetrievalChain
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.llms import HuggingFacePipeline
- from langchain.memory import ConversationBufferMemory
- from langchain_community.llms import HuggingFaceEndpoint
- from pathlib import Path
- import chromadb
- import googletrans
- from langdetect import detect
- from unidecode import unidecode
- from transformers import AutoTokenizer, AutoModelForMaskedLM
- import transformers
- import torch
- import tqdm
- import accelerate
- import re
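-
- # Runtime deps implied by the imports above (a best-effort list, not from the
- # original repo): gradio, langchain, langchain-community, chromadb, pypdf,
- # sentence-transformers, transformers, torch, unidecode, googletrans, langdetect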
-
- # Load the MuRIL tokenizer and model (note: loaded here but never used below)
- tokenizer = AutoTokenizer.from_pretrained("google/muril-base-cased")
- model = AutoModelForMaskedLM.from_pretrained("google/muril-base-cased")
-
- # default_persist_directory = './chroma_HF/'
- list_llm = ["mistralai/Mistral-7B-Instruct-v0.2"]
- list_llm_simple = [os.path.basename(llm) for llm in list_llm]
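- # With the single entry above: list_llm_simple == ["Mistral-7B-Instruct-v0.2"]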
-
- # Load PDF documents and create doc splits
- def load_doc(list_file_path, chunk_size, chunk_overlap):
-     loaders = [PyPDFLoader(x) for x in list_file_path]
-     pages = []
-     for loader in loaders:
-         pages.extend(loader.load())
-     text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=chunk_size,
-         chunk_overlap=chunk_overlap)
-     doc_splits = text_splitter.split_documents(pages)
-     return doc_splits
-
- # Create vector database
- def create_db(splits, collection_name):
-     embedding = HuggingFaceEmbeddings()
-     new_client = chromadb.EphemeralClient()
-     vectordb = Chroma.from_documents(
-         documents=splits,
-         embedding=embedding,
-         client=new_client,
-         collection_name=collection_name,
-         # persist_directory=default_persist_directory
-     )
-     return vectordb
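- # Note: chromadb.EphemeralClient() keeps the collection in memory only, so the
- # database is rebuilt on each run (on-disk persistence is commented out above).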
-
-
- # Load vector database (note: not currently called anywhere in this app)
- def load_db():
-     embedding = HuggingFaceEmbeddings()
-     vectordb = Chroma(
-         # persist_directory=default_persist_directory,
-         embedding_function=embedding)
-     return vectordb
-
-
- # Initialize langchain LLM chain
- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-     progress(0.1, desc="Initializing HF Hub endpoint...")
-     llm = HuggingFaceEndpoint(
-         repo_id=llm_model,
-         temperature=temperature,
-         max_new_tokens=max_tokens,
-         top_k=top_k,
-     )
-
-     progress(0.5, desc="Building conversational retrieval chain...")
-     # Build a retrieval chain over the vector store; the chatbot tab reads
-     # response["answer"] and response["source_documents"] from its output.
-     memory = ConversationBufferMemory(
-         memory_key="chat_history", output_key="answer", return_messages=True)
-     conversation_chain = ConversationalRetrievalChain.from_llm(
-         llm,
-         retriever=vector_db.as_retriever(),
-         memory=memory,
-         return_source_documents=True,
-     )
-
-     return conversation_chain
-
-
- # Initialize LLM
- def initialize_LLM(llm_model, temperature, max_tokens, top_k, vector_db):
-     # Map the short display name back to the full HF Hub repo id
-     llm_name = list_llm[list_llm_simple.index(llm_model)]
-     qa_chain = initialize_llmchain(llm_name, temperature, max_tokens, top_k, vector_db)
-     return qa_chain
-
-
- # Format chat history
- def format_chat_history(message, history):
-     formatted_chat_history = ""
-     for user_message, response in history:
-         formatted_chat_history += f"User: {user_message}\nAssistant: {response}\n\n"
-     formatted_chat_history += f"User: {message}\n"
-     return formatted_chat_history
-
-
- # Conversation function
- def conversation(qa_chain, message, history, language):
-     formatted_chat_history = format_chat_history(message, history)
-     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-     response_answer = response["answer"]
-     if response_answer.find("Helpful Answer:") != -1:
-         response_answer = response_answer.split("Helpful Answer:")[-1]
-     # Detect language of the question if selected
-     if language == "Detect Language":
-         language = detect(message)  # ISO 639-1 code, e.g. "en"
-     # Translate the answer unless it is (detected as) English
-     if language not in ("English", "en"):
-         translator = googletrans.Translator()
-         response_answer = translator.translate(response_answer, dest=language).text
-     response_sources = response["source_documents"]
-     response_source1 = response_sources[0].page_content.strip()
-     response_source2 = response_sources[1].page_content.strip()
-     response_source3 = response_sources[2].page_content.strip()
-     response_source1_page = response_sources[0].metadata["page"] + 1
-     response_source2_page = response_sources[1].metadata["page"] + 1
-     response_source3_page = response_sources[2].metadata["page"] + 1
-     return qa_chain, gr.update(value=""), history + [(message, response_answer)], response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
- # Create Gradio interface
- demo = gr.Blocks()
-
- with demo:
-     with gr.Tab("Step 1 - Upload Document"):
-         uploaded_file = gr.File(label="Upload Document", file_count="multiple")
-         upload_btn = gr.Button("Upload")
-         # Holds the list of document splits returned by load_doc
-         document = gr.State()
-
-     with gr.Tab("Step 2 - Create Database"):
-         slider_chunk_size = gr.Slider(label="Chunk Size", minimum=100, maximum=1000, value=600, step=100)
-         slider_chunk_overlap = gr.Slider(label="Chunk Overlap", minimum=0, maximum=500, value=50, step=50)
-         collection_name = gr.Textbox(label="Collection Name", lines=1, container=True)
-         db_btn = gr.Button("Create Database")
-         # The vector store is a Python object, so keep it in session state;
-         # gr.Progress is not a layout component and cannot be declared here
-         vector_db = gr.State()
-
-     with gr.Tab("Step 3 - Initialize LLM"):
-         llm_btn = gr.Dropdown(choices=list_llm_simple, value=list_llm_simple[0], label="LLM Model")
-         slider_temperature = gr.Slider(label="Temperature", minimum=0, maximum=1, value=0.7, step=0.1)
-         slider_maxtokens = gr.Slider(label="Max Tokens", minimum=10, maximum=500, value=200, step=50)
-         slider_topk = gr.Slider(label="Top K", minimum=1, maximum=10, value=5, step=1)
-         qachain_btn = gr.Button("Initialize LLM")
-         # The QA chain is likewise a Python object held in session state
-         qa_chain = gr.State()
-
-     with gr.Tab("Step 4 - Chatbot"):
-         chatbot = gr.Chatbot(height=300)
-         with gr.Accordion("Advanced - Document references", open=False):
-             with gr.Row():
-                 doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                 source1_page = gr.Number(label="Page", scale=1)
-             with gr.Row():
-                 doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                 source2_page = gr.Number(label="Page", scale=1)
-             with gr.Row():
-                 doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                 source3_page = gr.Number(label="Page", scale=1)
-         with gr.Row():
-             msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
-         with gr.Row():
-             submit_btn = gr.Button("Submit message")
-             clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
-             language = gr.Dropdown(choices=["English", "Detect Language"], value="English", label="Language")
-
-     # Preprocessing events
-     upload_btn.click(load_doc, inputs=[uploaded_file, slider_chunk_size, slider_chunk_overlap], outputs=[document])
-     db_btn.click(create_db, inputs=[document, collection_name], outputs=[vector_db])
-     qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain]).then(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-
-     # Chatbot events
-     msg.submit(conversation, inputs=[qa_chain, msg, chatbot, language], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-     submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot, language], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-     clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-
- if __name__ == "__main__":
-     demo.queue().launch(debug=True)
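
Note: the deleted file can be restored from the parent commit listed above with something like the following (assuming a local clone of this repo):

git checkout 0dfce95 -- app.py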