MuntasirHossain committed (verified)
Commit a7d217d · 1 Parent(s): 40e1456

Update app.py

Files changed (1): app.py (+38, -72)
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
 import os
 api_token = os.getenv("HF_TOKEN")
 
-
 from langchain_community.vectorstores import FAISS
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -18,11 +17,7 @@ import torch
 list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
-# Load and split PDF document
 def load_doc(list_file_path):
-    # Processing for one document only
-    # loader = PyPDFLoader(file_path)
-    # pages = loader.load()
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
@@ -34,14 +29,11 @@ def load_doc(list_file_path):
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
-# Create vector database
 def create_db(splits):
     embeddings = HuggingFaceEmbeddings()
     vectordb = FAISS.from_documents(splits, embeddings)
     return vectordb
 
-
-# Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     if llm_model == "meta-llama/Meta-Llama-3-8B-Instruct":
         llm = HuggingFaceEndpoint(
@@ -77,36 +69,27 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     )
     return qa_chain
 
-# Initialize database
 def initialize_database(list_file_obj, progress=gr.Progress()):
-    # Create a list of documents (when valid)
     list_file_path = [x.name for x in list_file_obj if x is not None]
-    # Load document and create splits
     doc_splits = load_doc(list_file_path)
-    # Create or load vector database
     vector_db = create_db(doc_splits)
     return vector_db, "Database created!"
 
-# Initialize LLM
 def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    # print("llm_option",llm_option)
     llm_name = list_llm[llm_option]
     print("llm_name: ",llm_name)
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
     return qa_chain, "QA chain initialized. Chatbot is ready!"
 
-
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
         formatted_chat_history.append(f"User: {user_message}")
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
-
 
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
-    # Generate response using QA chain
     response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
     response_answer = response["answer"]
     if response_answer.find("Helpful Answer:") != -1:
@@ -115,14 +98,11 @@ def conversation(qa_chain, message, history):
     response_source1 = response_sources[0].page_content.strip()
     response_source2 = response_sources[1].page_content.strip()
     response_source3 = response_sources[2].page_content.strip()
-    # Langchain sources are zero-based
     response_source1_page = response_sources[0].metadata["page"] + 1
     response_source2_page = response_sources[1].metadata["page"] + 1
     response_source3_page = response_sources[2].metadata["page"] + 1
-    # Append user message and response to chat history
     new_history = history + [(message, response_answer)]
     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
 
 def upload_file(file_obj):
     list_file_path = []
@@ -133,24 +113,23 @@ def upload_file(file_obj):
 
 def demo():
     custom_css = """
-    .horizontal-container {
-        display: flex !important;
-        flex-direction: row !important;
-        flex-wrap: nowrap !important;
-        width: 100% !important;
+    #column-container {
+        display: flex;
+        flex-direction: row;
+        flex-wrap: nowrap;
     }
-    .column-1 {
-        min-width: 300px !important;
-        max-width: 35% !important;
-        flex: 1 !important;
+    #column-left {
+        min-width: 350px;
+        max-width: 35%;
+        margin-right: 20px;
     }
-    .column-2 {
-        min-width: 500px !important;
-        flex: 2 !important;
+    #column-right {
+        min-width: 500px;
+        flex-grow: 1;
     }
-    @media (max-width: 900px) {
-        .column-1 { max-width: 40% !important; }
-        .column-2 { min-width: 400px !important; }
+    @media (max-width: 1200px) {
+        #column-left { min-width: 300px; }
+        #column-right { min-width: 400px; }
     }
     """
 
@@ -158,55 +137,41 @@ def demo():
         vector_db = gr.State()
         qa_chain = gr.State()
         gr.HTML("<center><h1>RAG PDF chatbot</h1><center>")
-        gr.Markdown("""<b>Query your PDF documents!</b> This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents. The app is hosted on Hugging Face Hub for the sole purpose of demonstration. \
-        <b>Please do not upload confidential documents.</b>
-        """)
+        gr.Markdown("""<b>Query your PDF documents!</b> This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents.""")
 
-        with gr.Row(elem_classes="horizontal-container"):
-            with gr.Column(elem_classes="column-1"):
+        with gr.Row(elem_id="column-container"):
+            with gr.Column(elem_id="column-left"):
                 gr.Markdown("<b>Step 1 - Upload PDF documents and Initialize RAG pipeline</b>")
-                with gr.Row():
-                    document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True, label="Upload PDF documents")
-                with gr.Row():
-                    db_btn = gr.Button("Create vector database")
-                with gr.Row():
-                    db_progress = gr.Textbox(value="Not initialized", show_label=False)
-                gr.Markdown("<style>body { font-size: 16px; }</style><b>Select Large Language Model (LLM) and input parameters</b>")
-                with gr.Row():
-                    llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
-                with gr.Row():
-                    with gr.Accordion("LLM input parameters", open=False):
-                        with gr.Row():
-                            slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
-                        with gr.Row():
-                            slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max New Tokens", info="Maximum number of tokens to be generated", interactive=True)
-                        with gr.Row():
-                            slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k", info="Number of tokens to select the next token from", interactive=True)
-                with gr.Row():
-                    qachain_btn = gr.Button("Initialize Question Answering Chatbot")
-                with gr.Row():
-                    llm_progress = gr.Textbox(value="Not initialized", show_label=False)
-
-            with gr.Column(elem_classes="column-2"):
+                document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True, label="Upload PDF documents")
+                db_btn = gr.Button("Create vector database")
+                db_progress = gr.Textbox(value="Not initialized", show_label=False)
+                gr.Markdown("<b>Select Large Language Model (LLM) and input parameters</b>")
+                llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
+                with gr.Accordion("LLM input parameters", open=False):
+                    slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Temperature")
+                    slider_maxtokens = gr.Slider(128, 9192, value=4096, step=128, label="Max New Tokens")
+                    slider_topk = gr.Slider(1, 10, value=3, step=1, label="top-k")
+                qachain_btn = gr.Button("Initialize Question Answering Chatbot")
+                llm_progress = gr.Textbox(value="Not initialized", show_label=False)
+
+            with gr.Column(elem_id="column-right"):
                 gr.Markdown("<b>Step 2 - Chat with your Document</b>")
                 chatbot = gr.Chatbot(height=505)
                 with gr.Accordion("Relevant context from the source document", open=False):
                     with gr.Row():
-                        doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                        doc_source1 = gr.Textbox(label="Reference 1", lines=2, scale=20)
                         source1_page = gr.Number(label="Page", scale=1)
                     with gr.Row():
-                        doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                        doc_source2 = gr.Textbox(label="Reference 2", lines=2, scale=20)
                        source2_page = gr.Number(label="Page", scale=1)
                     with gr.Row():
-                        doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                        doc_source3 = gr.Textbox(label="Reference 3", lines=2, scale=20)
                         source3_page = gr.Number(label="Page", scale=1)
-                with gr.Row():
-                    msg = gr.Textbox(placeholder="Ask a question", container=True)
+                msg = gr.Textbox(placeholder="Ask a question")
                 with gr.Row():
                     submit_btn = gr.Button("Submit")
                     clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
-
-        # Rest of your event handlers remain the same...
+
         db_btn.click(initialize_database,
                      inputs=[document],
                      outputs=[vector_db, db_progress])
@@ -216,7 +181,6 @@ def demo():
                      inputs=None,
                      outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
                      queue=False)
-
         msg.submit(conversation,
                    inputs=[qa_chain, msg, chatbot],
                    outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
@@ -231,4 +195,6 @@ def demo():
                   queue=False)
 
     demo.queue().launch(debug=True)
-
+
+if __name__ == "__main__":
+    demo()
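Note on the layout change: the diff drops the per-component gr.Row() wrappers and elem_classes hooks in favor of two elem_id-addressed columns styled through custom_css. A minimal sketch of that Gradio pattern follows; the ids and components here are illustrative stand-ins, not the app's own:

# Sketch: elem_id + custom CSS layout in Gradio Blocks.
# Illustrative names only; assumes the same pattern as this commit.
import gradio as gr

css = """
#left-col  { min-width: 300px; max-width: 35%; }
#right-col { min-width: 500px; flex-grow: 1; }
"""

with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column(elem_id="left-col"):    # matched by #left-col above
            gr.Markdown("Controls go here")
        with gr.Column(elem_id="right-col"):   # matched by #right-col above
            gr.Markdown("Chat goes here")

if __name__ == "__main__":
    demo.launch()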
 
 
 
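The event wiring kept by this commit follows Gradio's gr.State pattern: a click handler returns a new state value plus a status string, and the returned tuple maps positionally onto outputs. A minimal sketch under the same assumptions, where build_db is a hypothetical stand-in for the app's initialize_database:

# Sketch: routing a handler's return values into gr.State and a status box.
import gradio as gr

def build_db(files):
    # Stand-in for the real PDF -> splits -> FAISS pipeline.
    db = {"num_files": len(files or [])}
    return db, "Database created!"

with gr.Blocks() as demo:
    vector_db = gr.State()                      # per-session database handle
    docs = gr.Files(file_types=[".pdf"])
    btn = gr.Button("Create vector database")
    status = gr.Textbox(value="Not initialized", show_label=False)
    # Returned tuple maps onto outputs: (state value, status text).
    btn.click(build_db, inputs=[docs], outputs=[vector_db, status])

if __name__ == "__main__":
    demo.queue().launch()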