Merge pull request #1 from seanpedrick-case/dev
Added Gemini and AWS Bedrock compatibility. Gemma model. Now document redaction QA.
- .dockerignore +14 -0
- .gitattributes +1 -0
- .gitignore +4 -1
- app.py +172 -165
- chatfuncs/auth.py +40 -13
- chatfuncs/chatfuncs.py +365 -270
- chatfuncs/config.py +250 -0
- chatfuncs/helper_functions.py +123 -61
- chatfuncs/ingest.py +23 -29
- chatfuncs/model_load.py +82 -0
- chatfuncs/prompts.py +8 -5
- faiss_embedding/faiss_embedding.zip +0 -0
- requirements.txt +2 -1
- requirements_gpu.txt +2 -1
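Note on the new chatfuncs/config.py (+250 lines, contents not shown in this diff): app.py now reads its settings (FEEDBACK_LOGS_FOLDER, COGNITO_AUTH, MAX_FILE_SIZE, MAX_QUEUE_SIZE, etc.) from this module instead of hard-coding them. The snippet below is only a minimal sketch of that pattern, assuming config.py follows the existing get_or_create_env_var style from chatfuncs/helper_functions.py; apart from COGNITO_AUTH (whose '0' default appears in the old app.py), the constant defaults here are illustrative guesses, not the real values.

import os

def get_or_create_env_var(var_name: str, default_value: str) -> str:
    # Return the environment variable if set, otherwise register and return the default
    value = os.environ.get(var_name)
    if value is None:
        os.environ[var_name] = default_value
        value = default_value
    return value

COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')                       # grounded default from old app.py
MAX_FILE_SIZE = get_or_create_env_var('MAX_FILE_SIZE', '100mb')                 # illustrative default
MAX_QUEUE_SIZE = get_or_create_env_var('MAX_QUEUE_SIZE', '5')                   # illustrative default
DEFAULT_CONCURRENCY_LIMIT = get_or_create_env_var('DEFAULT_CONCURRENCY_LIMIT', '3')  # illustrative default
OUTPUT_FOLDER = get_or_create_env_var('GRADIO_OUTPUT_FOLDER', 'output/')        # illustrative default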
.dockerignore
ADDED
@@ -0,0 +1,14 @@
+*.pyc
+*.ipynb
+*.pdf
+*.spec
+*.toc
+*.csv
+*.bin
+bootstrapper.py
+build/*
+dist/*
+test/*
+config/*
+output/*
+input/*
.gitattributes
ADDED
@@ -0,0 +1 @@
+*.zip filter=lfs diff=lfs merge=lfs -text
.gitignore
CHANGED
@@ -8,4 +8,7 @@
 bootstrapper.py
 build/*
 dist/*
-test/*
+test/*
+config/*
+output/*
+input/*
app.py
CHANGED
@@ -1,95 +1,122 @@
-# Load in packages
-import socket
-from langchain_huggingface.embeddings import HuggingFaceEmbeddings
-from transformers import AutoTokenizer
-import torch
-from chatfuncs.helper_functions import get_or_create_env_var
-from chatfuncs.helper_functions import …
-#from chatfuncs.llm_api_call import llm_query
-access_logs_data_folder = 'logs/' + today_rev + '/' + host_name + '/'
-feedback_data_folder = 'feedback/' + today_rev + '/' + host_name + '/'
-usage_data_folder = 'usage/' + today_rev + '/' + host_name + '/'
-#from chatfuncs.chatfuncs import *
-import chatfuncs.ingest as ing
-def get_faiss_store(faiss_vstore_folder,embeddings):
-    import zipfile
-    faiss_vstore = FAISS.load_local(folder_path=faiss_vstore_folder, embeddings=…
-    global vectorstore
-    vectorstore = faiss_vstore
-    return vectorstore
-chatf.embeddings = …
-chatf.vectorstore = …
-if gpu_config is None:
-    gpu_config = chatf.gpu_config
-if cpu_config is None:
-    cpu_config = chatf.cpu_config
-if torch_device is None:
-    torch_device = chatf.torch_device
@@ -99,86 +126,46 @@ def load_model(model_type, gpu_layers, gpu_config=None, cpu_config=None, torch_d
-print(vars(gpu_config))
-print(vars(cpu_config))
-repo_id=…
-filename=…
-print("GPU load failed")
-print(e)
-repo_id=…
-filename=…
-if model_type == …
-hf_checkpoint = …
-if "flan" in model_name:
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
-else:
-    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
-else:
-    if "flan" in model_name:
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, torch_dtype=torch.float16)
-    else:
-        model = AutoModelForCausalLM.from_pretrained(model_name)#, trust_remote_code=True)#, torch_dtype=torch.float16)
-tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = chatf.context_length)
-return model, tokenizer, model_type
-model, tokenizer, model_type = create_hf_model(model_name = hf_checkpoint)
-chatf.…
-return model_type, load_confirmation, model_type
-# Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
-#model_type = "Phi 3.5 Mini (larger, slow)"
-#load_model(model_type, chatf.gpu_layers, chatf.gpu_config, chatf.cpu_config, chatf.torch_device)
-model_type = "Qwen 2 0.5B (small, fast)"
-load_model(model_type, 0, chatf.gpu_config, chatf.cpu_config, chatf.torch_device)
-print(f"> Total split documents: {len(docs_out)}")
-print(docs_out)
-vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings)
-chatf.vectorstore = vectorstore_func
-out_message = "Document processing complete"
-return out_message, vectorstore_func
-# Gradio chat
@@ -188,24 +175,40 @@ def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
@@ -219,21 +222,18 @@ with app:
-gr.Markdown("Chat with PDF, web page or (new) csv/Excel documents. The default is a small model (…
-with gr.Accordion(label="Use Gemini or AWS Claude model", open=False, visible=False):
-    api_model_choice = gr.Dropdown(value = "None", choices = ["gemini-1.5-flash-002", "gemini-1.5-pro-002", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", "None"], label="LLM model to use", multiselect=False, interactive=True, visible=False)
-    in_api_key = gr.Textbox(value = "", label="Enter Gemini API key (only if using Google API models)", lines=1, type="password",interactive=True, visible=False)
-current_source = gr.Textbox(label="Current data source(s)", value=…
-chatbot = gr.Chatbot(avatar_images=('user.jfif', 'bot.jpg'), scale = 1, resizable=True, show_copy_all_button=True, show_copy_button=True, show_share_button=True, type='…
@@ -245,17 +245,12 @@ with app:
-stop = gr.Button(value="Stop generating", variant="…
-examples_set = gr.Radio(label="Examples for the Lambeth Borough Plan",
-    #value = "What were the five pillars of the previous borough plan?",
-    choices=["What were the five pillars of the previous borough plan?",
-        "What is the vision statement for Lambeth?",
-        "What are the commitments for Lambeth?",
-        "What are the 2030 outcomes for Lambeth?"])
-current_topic = gr.Textbox(label="Feature currently disabled - Keywords related to current conversation topic.", placeholder="Keywords related to the conversation topic will appear here", visible=False)
@@ -281,75 +276,91 @@ with app:
-model_choice = gr.Radio(label="Choose a chat model", value=…
-load_text = gr.Text(label="Load status")
-"<center>This app is …
-success(embed_faiss_save_to_zip, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
-success(embed_faiss_save_to_zip, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
-success(embed_faiss_save_to_zip, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
-response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False, api_name="retrieval").\
-    success(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
-    success(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state], outputs=chatbot)
-response_click.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
-    success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
-    success(lambda: chatf.restore_interactivity(), None, [message], queue=False)
-response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False).\
-    success(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
-    success(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state], chatbot)
-response_enter.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
-    success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
-    success(lambda: chatf.restore_interactivity(), None, [message], queue=False)
-# Stop box
-stop.click(fn=None, inputs=None, outputs=None, cancels=[response_click, response_enter])
-# Clear box
-clear.click(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic])
-clear.click(lambda: None, None, chatbot, queue=False)
-###
@@ -358,12 +369,8 @@ with app:
-# Launch the Gradio app
-COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
-print(f'The value of COGNITO_AUTH is {COGNITO_AUTH}')
-if …
-    app.queue().launch(show_error=True, auth=authenticate_user, max_file_size=…
-    app.queue().launch(show_error=True, inbrowser=True, max_file_size=…
New version:
 import os
 from typing import Type
+from langchain_huggingface.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 import gradio as gr
 import pandas as pd
+from torch import float16
 from llama_cpp import Llama
 from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
+import zipfile

 from chatfuncs.ingest import embed_faiss_save_to_zip

+from chatfuncs.helper_functions import get_connection_params, reveal_feedback_buttons, wipe_logs
 from chatfuncs.aws_functions import upload_file_to_s3
 from chatfuncs.auth import authenticate_user
+from chatfuncs.config import FEEDBACK_LOGS_FOLDER, ACCESS_LOGS_FOLDER, USAGE_LOGS_FOLDER, HOST_NAME, COGNITO_AUTH, INPUT_FOLDER, OUTPUT_FOLDER, MAX_QUEUE_SIZE, DEFAULT_CONCURRENCY_LIMIT, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, DEFAULT_EMBEDDINGS_LOCATION, EMBEDDINGS_MODEL_NAME, DEFAULT_DATA_SOURCE, HF_TOKEN, LARGE_MODEL_REPO_ID, LARGE_MODEL_GGUF_FILE, LARGE_MODEL_NAME, SMALL_MODEL_NAME, SMALL_MODEL_REPO_ID, DEFAULT_DATA_SOURCE_NAME, DEFAULT_EXAMPLES, DEFAULT_MODEL_CHOICES
+from chatfuncs.model_load import torch_device, gpu_config, cpu_config, context_length
+import chatfuncs.chatfuncs as chatf
+import chatfuncs.ingest as ing

 PandasDataFrame = Type[pd.DataFrame]

 from datetime import datetime
 today_rev = datetime.now().strftime("%Y%m%d")

+host_name = HOST_NAME
+access_logs_data_folder = ACCESS_LOGS_FOLDER
+feedback_data_folder = FEEDBACK_LOGS_FOLDER
+usage_data_folder = USAGE_LOGS_FOLDER

+if isinstance(DEFAULT_EXAMPLES, str): default_examples_set = eval(DEFAULT_EXAMPLES)
+if isinstance(DEFAULT_MODEL_CHOICES, str): default_model_choices = eval(DEFAULT_MODEL_CHOICES)

 # Disable cuda devices if necessary
 #os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

 ###
 # Load preset embeddings, vectorstore, and model
 ###

+def load_embeddings_model(embeddings_model = EMBEDDINGS_MODEL_NAME):
+    embeddings_func = HuggingFaceEmbeddings(model_name=embeddings_model)
+    #global embeddings
+    #embeddings = embeddings_func
+    return embeddings_func

+def get_faiss_store(faiss_vstore_folder:str, embeddings_model:object):
     with zipfile.ZipFile(faiss_vstore_folder + '/' + faiss_vstore_folder + '.zip', 'r') as zip_ref:
         zip_ref.extractall(faiss_vstore_folder)
+    faiss_vstore = FAISS.load_local(folder_path=faiss_vstore_folder, embeddings=embeddings_model, allow_dangerous_deserialization=True)
     os.remove(faiss_vstore_folder + "/index.faiss")
     os.remove(faiss_vstore_folder + "/index.pkl")
+    #global vectorstore
+    #vectorstore = faiss_vstore
+    return faiss_vstore #vectorstore

+# Load in default embeddings and embeddings model name
+embeddings_model = load_embeddings_model(EMBEDDINGS_MODEL_NAME)
+vectorstore = get_faiss_store(faiss_vstore_folder=DEFAULT_EMBEDDINGS_LOCATION,embeddings_model=embeddings_model)#globals()["embeddings"])

+chatf.embeddings = embeddings_model
+chatf.vectorstore = vectorstore

+def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings_model=embeddings_model):
+    print(f"> Total split documents: {len(docs_out)}")
+    print(docs_out)
+    vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings_model)
+    chatf.vectorstore = vectorstore_func
+    out_message = "Document processing complete"
+    return out_message, vectorstore_func

+def create_hf_model(model_name:str, hf_token=HF_TOKEN):
+    if torch_device == "cuda":
+        if "flan" in model_name:
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
+        else:
+            if hf_token:
+                model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", token=hf_token) # , torch_dtype=float16
+            else:
+                model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") # , torch_dtype=float16
+    else:
+        if "flan" in model_name:
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, torch_dtype=torch.float16)
+        else:
+            if hf_token:
+                model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token) # , torch_dtype=float16
+            else:
+                model = AutoModelForCausalLM.from_pretrained(model_name) # , torch_dtype=float16

+    if hf_token:
+        tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = context_length, token=hf_token)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = context_length)

+    return model, tokenizer

+def load_model(model_type:str, gpu_layers:int, gpu_config:dict=gpu_config, cpu_config:dict=cpu_config, torch_device:str=torch_device):
+    print("Loading model")

+    if model_type == LARGE_MODEL_NAME:
         if torch_device == "cuda":
             gpu_config.update_gpu(gpu_layers)
             print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU.")
…
             print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU.")

         try:
             model = Llama(
                 model_path=hf_hub_download(
+                    repo_id=LARGE_MODEL_REPO_ID,
+                    filename=LARGE_MODEL_GGUF_FILE
                 ),
                 **vars(gpu_config) # change n_gpu_layers if you have more or less VRAM
             )

         except Exception as e:
+            print("GPU load failed", e, "loading CPU version instead")
             model = Llama(
                 model_path=hf_hub_download(
+                    repo_id=LARGE_MODEL_REPO_ID,
+                    filename=LARGE_MODEL_GGUF_FILE
                 ),
                 **vars(cpu_config)
             )

         tokenizer = []

+    if model_type == SMALL_MODEL_NAME:
         # Huggingface chat model
+        hf_checkpoint = SMALL_MODEL_REPO_ID# 'declare-lab/flan-alpaca-large'#'declare-lab/flan-alpaca-base' # # # 'Qwen/Qwen1.5-0.5B-Chat' #
+        model, tokenizer = create_hf_model(model_name = hf_checkpoint)

+    else:
+        model = model_type
+        tokenizer = ""

+    chatf.model_object = model
     chatf.tokenizer = tokenizer
     chatf.model_type = model_type

     load_confirmation = "Finished loading model: " + model_type
     print(load_confirmation)

+    return model_type, load_confirmation, model_type#model, tokenizer, model_type

 ###
…
 app = gr.Blocks(theme = gr.themes.Base(), fill_width=True)#css=".gradio-container {background-color: black}")

 with app:
+    model_type = SMALL_MODEL_NAME
+    load_model(model_type, 0, gpu_config, cpu_config, torch_device) # chatf.model_object, chatf.tokenizer, chatf.model_type =

+    # Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
+    #model_type = "Phi 3.5 Mini (larger, slow)"
+    #load_model(model_type, gpu_layers, gpu_config, cpu_config, torch_device)

     ingest_text = gr.State()
     ingest_metadata = gr.State()
     ingest_docs = gr.State()

     model_type_state = gr.State(model_type)
+    gpu_config_state = gr.State(gpu_config)
+    cpu_config_state = gr.State(cpu_config)
+    torch_device_state = gr.State(torch_device)

+    # Embeddings related vars
+    embeddings_model_object_state = gr.State(embeddings_model)#globals()["embeddings"])
+    vectorstore_state = gr.State(vectorstore)#globals()["vectorstore"])
+    default_embeddings_store_text = gr.Textbox(value=DEFAULT_EMBEDDINGS_LOCATION, visible=False)

+    # Is the query relevant to the sources provided?
     relevant_query_state = gr.Checkbox(value=True, visible=False)

+    # Storing model objects in state doesn't seem to work, so we have to load in different models in roundabout ways
+    model_state = gr.State() # chatf.model_object (gives error)
     tokenizer_state = gr.State() # chatf.tokenizer (gives error)

     chat_history_state = gr.State()
     instruction_prompt_out = gr.State()

     session_hash_state = gr.State()
+    output_folder_textbox = gr.Textbox(value=OUTPUT_FOLDER, visible=False)
+    input_folder_textbox = gr.Textbox(value=INPUT_FOLDER, visible=False)

     session_hash_textbox = gr.Textbox(value="", visible=False)
     s3_logs_output_textbox = gr.Textbox(label="S3 logs", visible=False)
…
     gr.Markdown("<h1><center>Lightweight PDF / web page QA bot</center></h1>")

+    gr.Markdown(f"""Chat with PDF, web page or (new) csv/Excel documents. The default is a small model ({SMALL_MODEL_NAME}), that can only answer specific questions that are answered in the text. It cannot give overall impressions of, or summarise the document. The alternative ({LARGE_MODEL_NAME}), can reason a little better, but is much slower (See Advanced settings tab).\n\nBy default '[{DEFAULT_DATA_SOURCE_NAME}]({DEFAULT_DATA_SOURCE})' is loaded.If you want to talk about another document or web page, please select from the second tab. If switching topic, please click the 'Clear chat' button.\n\nCaution: This is a public app. Please ensure that the document you upload is not sensitive is any way as other users may see it! Also, please note that LLM chatbots may give incomplete or incorrect information, so please use with care.""")

     with gr.Row():
+        current_source = gr.Textbox(label="Current data source(s)", value=DEFAULT_DATA_SOURCE, scale = 10)
         current_model = gr.Textbox(label="Current model", value=model_type, scale = 3)

     with gr.Tab("Chatbot"):

         with gr.Row():
             #chat_height = 500
+            chatbot = gr.Chatbot(value=None, avatar_images=('user.jfif', 'bot.jpg'), scale = 1, resizable=True, show_copy_all_button=True, show_copy_button=True, show_share_button=True, type='messages') # , height=chat_height
             with gr.Accordion("Open this tab to see the source paragraphs used to generate the answer", open = True):
                 sources = gr.HTML(value = "Source paragraphs with the most relevant text will appear here") # , height=chat_height
…
         with gr.Row():
             submit = gr.Button(value="Send message", variant="primary", scale = 4)
             clear = gr.Button(value="Clear chat", variant="secondary", scale=1)
+            stop = gr.Button(value="Stop generating", variant="stop", scale=1)

+        examples_set = gr.Radio(label="Example questions",
+                                choices=default_examples_set)

+        current_topic = gr.Textbox(label="Feature currently disabled - Keywords related to current conversation topic.", placeholder="Keywords related to the conversation topic will appear here", visible=False)

     with gr.Tab("Load in a different file to chat with"):
         with gr.Accordion("PDF file", open = False):
…
         out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
         temp_slide = gr.Slider(minimum=0.1, value = 0.5, maximum=1, step=0.1, label="Choose temperature setting for response generation.")
         with gr.Row():
+            model_choice = gr.Radio(label="Choose a chat model", value=SMALL_MODEL_NAME, choices = default_model_choices)
+            in_api_key = gr.Textbox(value = "", label="Enter Gemini API key (only if using Google API models)", lines=1, type="password",interactive=True, visible=True)
             change_model_button = gr.Button(value="Load model", scale=0)
         with gr.Accordion("Choose number of model layers to send to GPU (WARNING: please don't modify unless you are sure you have a GPU).", open = False):
             gpu_layer_choice = gr.Slider(label="Choose number of model layers to send to GPU.", value=0, minimum=0, maximum=100, step = 1, visible=True)

+        load_text = gr.Text(label="Load status")

     gr.HTML(
+        "<center>This app is powered by Gradio, Transformers, and Llama.cpp.</center>"
     )

     examples_set.change(fn=chatf.update_message, inputs=[examples_set], outputs=[message])

+    ###
+    # CHAT PAGE
+    ###

+    # Click to send message
+    response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_model_object_state, model_type_state, out_passages, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False, api_name="retrieval").\
+        success(chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
+        success(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state, chat_history_state, in_api_key], outputs=chatbot)
+    response_click.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
+        success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
+        success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False)

+    # Press enter to send message
+    response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_model_object_state, model_type_state, out_passages, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False).\
+        success(chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
+        success(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state, chat_history_state, in_api_key], chatbot)
+    response_enter.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
+        success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
+        success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False)

+    # Stop box
+    stop.click(fn=None, inputs=None, outputs=None, cancels=[response_click, response_enter])

+    # Clear box
+    clear.click(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic])
+    clear.click(lambda: None, None, chatbot, queue=False)

+    # Thumbs up or thumbs down voting function
+    chatbot.like(chatf.vote, [chat_history_state, instruction_prompt_out, model_type_state], None)

+    ###
+    # LOAD NEW DATA PAGE
+    ###

     # Load in a pdf
     load_pdf_click = load_pdf.click(ing.parse_file, inputs=[in_pdf], outputs=[ingest_text, current_source]).\
         success(ing.text_to_docs, inputs=[ingest_text], outputs=[ingest_docs]).\
+        success(embed_faiss_save_to_zip, inputs=[ingest_docs, output_folder_textbox, embeddings_model_object_state], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
         success(chatf.hide_block, outputs = [examples_set])

     # Load in a webpage
     load_web_click = load_web.click(ing.parse_html, inputs=[in_web, in_div], outputs=[ingest_text, ingest_metadata, current_source]).\
         success(ing.html_text_to_docs, inputs=[ingest_text, ingest_metadata], outputs=[ingest_docs]).\
+        success(embed_faiss_save_to_zip, inputs=[ingest_docs, output_folder_textbox, embeddings_model_object_state], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
         success(chatf.hide_block, outputs = [examples_set])

     # Load in a csv/excel file
     load_csv_click = load_csv.click(ing.parse_csv_or_excel, inputs=[in_csv, in_text_column], outputs=[ingest_text, current_source]).\
         success(ing.csv_excel_text_to_docs, inputs=[ingest_text, in_text_column], outputs=[ingest_docs]).\
+        success(embed_faiss_save_to_zip, inputs=[ingest_docs, output_folder_textbox, embeddings_model_object_state], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
         success(chatf.hide_block, outputs = [examples_set])

+    ###
+    # LOAD MODEL PAGE
+    ###

+    change_model_button.click(fn=chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
+        success(fn=load_model, inputs=[model_choice, gpu_layer_choice], outputs = [model_type_state, load_text, current_model]).\
+        success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False).\
+        success(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic]).\
+        success(lambda: None, None, chatbot, queue=False)

     ###
     # LOGGING AND ON APP LOAD FUNCTIONS
+    ###
+    # Load in default model and embeddings for each user
+    app.load(get_connection_params, inputs=None, outputs=[session_hash_state, output_folder_textbox, session_hash_textbox, input_folder_textbox]).\
+        success(load_model, inputs=[model_type_state, gpu_layer_choice, gpu_config_state, cpu_config_state, torch_device_state], outputs=[model_type_state, load_text, current_model]).\
+        success(get_faiss_store, inputs=[default_embeddings_store_text, embeddings_model_object_state], outputs=[vectorstore_state])

     # Log usernames and times of access to file (to know who is using the app when running on AWS)
     access_callback = gr.CSVLogger()
…
     session_hash_textbox.change(lambda *args: access_callback.flag(list(args)), [session_hash_textbox], None, preprocess=False).\
         success(fn = upload_file_to_s3, inputs=[access_logs_state, access_s3_logs_loc_state], outputs=[s3_logs_output_textbox])

 if __name__ == "__main__":
+    if COGNITO_AUTH == "1":
+        app.queue(max_size=int(MAX_QUEUE_SIZE), default_concurrency_limit=int(DEFAULT_CONCURRENCY_LIMIT)).launch(show_error=True, inbrowser=True, auth=authenticate_user, max_file_size=MAX_FILE_SIZE, server_port=GRADIO_SERVER_PORT, root_path=ROOT_PATH)
     else:
+        app.queue(max_size=int(MAX_QUEUE_SIZE), default_concurrency_limit=int(DEFAULT_CONCURRENCY_LIMIT)).launch(show_error=True, inbrowser=True, max_file_size=MAX_FILE_SIZE, server_port=GRADIO_SERVER_PORT, root_path=ROOT_PATH)
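Note: get_faiss_store above expects a zip archive named after its folder (the repo ships faiss_embedding/faiss_embedding.zip) containing the index.faiss and index.pkl files that FAISS.save_local produces. The following is a minimal, self-contained sketch of that round trip, not the app's own embed_faiss_save_to_zip (whose exact signature is not shown in this diff); the embedding model name and the sample texts are illustrative.

import zipfile
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")  # illustrative model

# Build and save an index, then zip it the way get_faiss_store expects
store = FAISS.from_texts(["example passage one", "example passage two"], embedding=embeddings)
store.save_local("faiss_embedding")
with zipfile.ZipFile("faiss_embedding/faiss_embedding.zip", "w") as zf:
    zf.write("faiss_embedding/index.faiss", arcname="index.faiss")
    zf.write("faiss_embedding/index.pkl", arcname="index.pkl")

# Load it back; allow_dangerous_deserialization is required because index.pkl is a pickle
with zipfile.ZipFile("faiss_embedding/faiss_embedding.zip", "r") as zf:
    zf.extractall("faiss_embedding")
reloaded = FAISS.load_local(folder_path="faiss_embedding", embeddings=embeddings, allow_dangerous_deserialization=True)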
chatfuncs/auth.py
CHANGED
@@ -1,14 +1,22 @@
-def authenticate_user(username, password, user_pool_id=…
@@ -16,22 +24,39 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
-client = boto3.client('cognito-idp') # Cognito Identity Provider client
@@ -44,5 +69,7 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
New version:
+#import os
 import boto3
+#import gradio as gr
+import hmac
+import hashlib
+import base64
+from chatfuncs.config import AWS_CLIENT_ID, AWS_CLIENT_SECRET, AWS_USER_POOL_ID, AWS_REGION

+def calculate_secret_hash(client_id:str, client_secret:str, username:str):
+    message = username + client_id
+    dig = hmac.new(
+        str(client_secret).encode('utf-8'),
+        msg=str(message).encode('utf-8'),
+        digestmod=hashlib.sha256
+    ).digest()
+    secret_hash = base64.b64encode(dig).decode()
+    return secret_hash

+def authenticate_user(username:str, password:str, user_pool_id:str=AWS_USER_POOL_ID, client_id:str=AWS_CLIENT_ID, client_secret:str=AWS_CLIENT_SECRET):
     """Authenticates a user against an AWS Cognito user pool.

     Args:
…
         client_id (str): The ID of the Cognito user pool client.
         username (str): The username of the user.
         password (str): The password of the user.
+        client_secret (str): The client secret of the app client

     Returns:
         bool: True if the user is authenticated, False otherwise.
     """

+    client = boto3.client('cognito-idp', region_name=AWS_REGION) # Cognito Identity Provider client

+    # Compute the secret hash
+    secret_hash = calculate_secret_hash(client_id, client_secret, username)

     try:
+        if client_secret == '':
+            response = client.initiate_auth(
+                AuthFlow='USER_PASSWORD_AUTH',
+                AuthParameters={
+                    'USERNAME': username,
+                    'PASSWORD': password,
+                },
+                ClientId=client_id
+            )

+        else:
+            response = client.initiate_auth(
                 AuthFlow='USER_PASSWORD_AUTH',
                 AuthParameters={
                     'USERNAME': username,
                     'PASSWORD': password,
+                    'SECRET_HASH': secret_hash
                 },
                 ClientId=client_id
+            )

         # If successful, you'll receive an AuthenticationResult in the response
         if response.get('AuthenticationResult'):
…
     except client.exceptions.UserNotFoundException:
         return False
     except Exception as e:
+        out_message = f"An error occurred: {e}"
+        print(out_message)
+        raise Exception(out_message)
+        return False
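Note: the SECRET_HASH that Cognito requires when an app client has a secret is Base64(HMAC-SHA256(key=client_secret, message=username + client_id)), which is exactly what calculate_secret_hash computes above. A small usage sketch; the IDs, secret and username below are made up, and the live call is left commented out because it needs real AWS credentials and a real user pool.

from chatfuncs.auth import calculate_secret_hash, authenticate_user

client_id = "example-client-id"            # hypothetical
client_secret = "example-client-secret"    # hypothetical
username = "test.user@example.com"         # hypothetical

secret_hash = calculate_secret_hash(client_id, client_secret, username)
print(secret_hash)  # Base64-encoded HMAC-SHA256 digest, sent as 'SECRET_HASH' in initiate_auth

# authenticate_user wires this into boto3's cognito-idp initiate_auth call, and falls back to a
# plain USER_PASSWORD_AUTH request when no client secret is configured:
# is_valid = authenticate_user(username, "a-password", user_pool_id="eu-west-2_example",
#                              client_id=client_id, client_secret=client_secret)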
chatfuncs/chatfuncs.py
CHANGED
@@ -5,65 +5,60 @@ from typing import Type, Dict, List, Tuple
-import …
-import …
-from transformers import pipeline, TextIteratorStreamer
-# Alternative model sources
-#from dataclasses import asdict, dataclass
-# Langchain functions
-from langchain.prompts import PromptTemplate
-from langchain_community.vectorstores import FAISS
-from langchain_community.retrievers import SVMRetriever
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.docstore.document import Document
-# For keyword extraction (not currently used)
-#import nltk
-#nltk.download('wordnet')
-#from nltk.stem.snowball import SnowballStemmer
-model = [] # Define empty list for model functions to run
-tokenizer = [] # Define empty list for model functions to run
@@ -77,117 +72,53 @@ ner_model = []#SpanMarkerModel.from_pretrained("tomaarsen/span-marker-mbert-base
-# Currently set gpu_layers to 0 even with cuda due to persistent bugs in implementation with cuda
-if torch.cuda.is_available():
-    torch_device = "cuda"
-    gpu_layers = 100
-else:
-    torch_device = "cpu"
-    gpu_layers = 0
-print("Running on device:", torch_device)
-threads = 8 #torch.get_num_threads()
-print("CPU threads:", threads)
-# Qwen 2 0.5B (small, fast) Model parameters
-temperature: float = 0.1
-top_k: int = 3
-top_p: float = 1
-repetition_penalty: float = 1.15
-flan_alpaca_repetition_penalty: float = 1.3
-last_n_tokens: int = 64
-max_new_tokens: int = 1024
-seed: int = 42
-reset: bool = False
-stream: bool = True
-threads: int = threads
-batch_size:int = 256
-context_length:int = 2048
-sample = True
-class CtransInitConfig_gpu:
-    def __init__(self,
-                 last_n_tokens=last_n_tokens,
-                 seed=seed,
-                 n_threads=threads,
-                 n_batch=batch_size,
-                 n_ctx=4096,
-                 n_gpu_layers=gpu_layers):
-        self.last_n_tokens = last_n_tokens
-        self.seed = seed
-        self.n_threads = n_threads
-        self.n_batch = n_batch
-        self.n_ctx = n_ctx
-        self.n_gpu_layers = n_gpu_layers
-        # self.stop: list[str] = field(default_factory=lambda: [stop_string])
-    def update_gpu(self, new_value):
-        self.n_gpu_layers = new_value
-class CtransInitConfig_cpu(CtransInitConfig_gpu):
-    def __init__(self):
-        super().__init__()
-        self.n_gpu_layers = 0
-gpu_config = CtransInitConfig_gpu()
-cpu_config = CtransInitConfig_cpu()
-class CtransGenGenerationConfig:
-    def __init__(self, temperature=temperature,
-                 top_k=top_k,
-                 top_p=top_p,
-                 repeat_penalty=repetition_penalty,
-                 seed=seed,
-                 stream=stream,
-                 max_tokens=max_new_tokens
-                 ):
-        self.temperature = temperature
-        self.top_k = top_k
-        self.top_p = top_p
-        self.repeat_penalty = repeat_penalty
-        self.seed = seed
-        self.max_tokens=max_tokens
-        self.stream = stream
-    def update_temp(self, new_value):
-        self.temperature = new_value
-def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
-…
-def base_prompt_templates(model_type = …
@@ -201,24 +132,24 @@ def base_prompt_templates(model_type = "Qwen 2 0.5B (small, fast)"):
-if model_type == …
-elif model_type == …
-def write_out_metadata_as_string(metadata_in):
-def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag = True, out_passages = 2): # ,
-print("relevant_flag in generate_expanded_prompt:", relevant_flag)
@@ -234,8 +165,6 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
-print("Doc_df columns:", doc_df.columns)
@@ -243,7 +172,7 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
-docs_keep_as_doc, doc_df = get_expanded_passages(vectorstore, docs_keep_out, width=…
@@ -259,29 +188,28 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
-print('Final prompt is: ')
-print(instruction_prompt_out)
-def create_full_prompt(user_input, …
-if api_model_choice and api_model_choice != "None":
-    print("API model choice detected")
-    if api_key:
-        print("API key detected")
-        return history, "", None, relevant_flag
-    else:
-        return history, "", None, relevant_flag
@@ -291,38 +219,15 @@ def create_full_prompt(user_input, history, extracted_memory, vectorstore, embed
-print("User input:", user_input)
-history.append(user_input)
-print("Output history is:", history)
-print("Final prompt to model is:",instruction_prompt_out)
-# Chat functions
-import boto3
-import json
-from chatfuncs.helper_functions import get_or_create_env_var
-# ResponseObject class for AWS Bedrock calls
-class ResponseObject:
-    def __init__(self, text, usage_metadata):
-        self.text = text
-        self.usage_metadata = usage_metadata
-max_tokens = 4096
-AWS_DEFAULT_REGION = get_or_create_env_var('AWS_DEFAULT_REGION', 'eu-west-2')
-print(f'The value of AWS_DEFAULT_REGION is {AWS_DEFAULT_REGION}')
-bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_DEFAULT_REGION)
@@ -351,6 +256,8 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
@@ -376,16 +283,173 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
-def …
@@ -395,16 +459,18 @@ def produce_streaming_answer_chatbot(history,
-history.append(out_message)
-if model_type == …
@@ -422,9 +488,7 @@ def produce_streaming_answer_chatbot(history,
-t = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -432,12 +496,15 @@ def produce_streaming_answer_chatbot(history,
-history…
-if new_text…
@@ -450,7 +517,7 @@ def produce_streaming_answer_chatbot(history,
-elif model_type == …
@@ -463,15 +530,17 @@ def produce_streaming_answer_chatbot(history,
-output = …
-history…
-history[-1][…
@@ -481,36 +550,80 @@ def produce_streaming_answer_chatbot(history,
-print(f'Tokens per …
-elif …
-time.sleep(30)
-response = call_aws_claude(full_prompt, system_prompt, temperature, max_tokens, model_type)
-except Exception as e:
-    print(e)
-    return "", history
-#print("conversation_history:", conversation_history)
@@ -589,9 +702,6 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
-print("Docs from similarity search:")
-print(docs)
@@ -688,12 +798,8 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
-print("Type of embeddings object:", embeddings_type)
-print("embeddings:", embeddings)
-from langchain_huggingface import HuggingFaceEmbeddings
@@ -743,10 +849,6 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
-print("doc_df:",doc_df)
-print("docs_keep_as_doc:",docs_keep_as_doc)
-print("docs_keep_out:", docs_keep_out)
@@ -836,16 +938,16 @@ def get_expanded_passages(vectorstore, docs, width):
-def highlight_found_text(…
-Highlights occurrences of …
-- str: A string with occurrences of …
@@ -859,32 +961,27 @@ def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hl
-full_text = extract_text_from_input(full_text)
-search_text = extract_search_text_from_input(search_text)
-sections = text_splitter.split_text(…
-text_start_pos = …
@@ -907,20 +1004,22 @@ def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hl
-pos_tokens.append(…
-pos_tokens.append('<mark style="color:black;">' + …
-pos_tokens.append(…
-return …
-chat_history_state = …
-chat_message = …
@@ -1011,8 +1110,7 @@ def remove_q_stopwords(question): # Remove stopwords from question. Not used at
-result.append(word)
@@ -1021,9 +1119,6 @@ def remove_q_stopwords(question): # Remove stopwords from question. Not used at
@@ -1075,11 +1170,11 @@ def keybert_keywords(text, n, kw_model):
-def turn_off_interactivity(…
-return gr.…
-return gr.…
 import time
 from itertools import compress
 import pandas as pd
+import google.generativeai as ai
+import gradio as gr
+from gradio import Progress
+import boto3
+import json

 from nltk.corpus import stopwords
 from nltk.tokenize import RegexpTokenizer
 from nltk.stem import WordNetLemmatizer
 from keybert import KeyBERT

 # For Name Entity Recognition model
 #from span_marker import SpanMarkerModel # Not currently used

 # For BM25 retrieval
 import bm25s
 import Stemmer
+# Model packages
+import torch.cuda
+from threading import Thread
+from transformers import pipeline, TextIteratorStreamer
+# Langchain functions
+from langchain.prompts import PromptTemplate
+from langchain_community.vectorstores import FAISS
+from langchain_community.retrievers import SVMRetriever
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.docstore.document import Document

+from chatfuncs.prompts import instruction_prompt_template_alpaca, instruction_prompt_mistral_orca, instruction_prompt_phi3, instruction_prompt_llama3, instruction_prompt_qwen, instruction_prompt_template_orca, instruction_prompt_gemma
+from chatfuncs.model_load import temperature, max_new_tokens, sample, repetition_penalty, top_p, top_k, torch_device, CtransGenGenerationConfig, max_tokens
+from chatfuncs.config import GEMINI_API_KEY, AWS_DEFAULT_REGION, LARGE_MODEL_NAME, SMALL_MODEL_NAME

+model_object = [] # Define empty list for model functions to run
+tokenizer = [] # Define empty list for model functions to run

+# ResponseObject class for AWS Bedrock calls
+class ResponseObject:
+    def __init__(self, text, usage_metadata):
+        self.text = text
+        self.usage_metadata = usage_metadata

+bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_DEFAULT_REGION)
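A quick illustration of the ResponseObject wrapper defined above: it simply standardises Bedrock replies so the rest of the module can read .text and .usage_metadata the same way it does for Gemini responses. The values below are made up for illustration only.

resp = ResponseObject(
    text="The app redacts PDFs and spreadsheets.",
    usage_metadata={"input_tokens": 412, "output_tokens": 23},  # illustrative keys
)
print(resp.text)
print(resp.usage_metadata)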
 torch.cuda.empty_cache()

 PandasDataFrame = Type[pd.DataFrame]

 embeddings = None # global variable setup
+embeddings_model = None # global variable setup
 vectorstore = None # global variable setup
 model_type = None # global variable setup

 max_memory_length = 0 # How long should the memory of the conversation last?

+source_texts = "" # Define dummy source text (full text) just to enable highlight function to load

 ## Highlight text constants
 hlt_chunk_size = 12

 # Used to pull out keywords from chat history to add to user queries behind the scenes
 kw_model = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")
 # Vectorstore funcs

+# def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):

+#     print(f"> Total split documents: {len(docs_out)}")

+#     vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings)

+#     '''
+#     #with open("vectorstore.pkl", "wb") as f:
+#         #pickle.dump(vectorstore, f)
+#     '''

+#     #if Path(save_to).exists():
+#     #    vectorstore_func.save_local(folder_path=save_to)
+#     #else:
+#     #    os.mkdir(save_to)
+#     #    vectorstore_func.save_local(folder_path=save_to)

+#     global vectorstore

+#     vectorstore = vectorstore_func

+#     out_message = "Document processing complete"

+#     #print(out_message)
+#     #print(f"> Saved to: {save_to}")

+#     return out_message

+# def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings_model=embeddings_model):

+#     print(f"> Total split documents: {len(docs_out)}")

+#     print(docs_out)

+#     vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings_model)

+#     vectorstore = vectorstore_func

+#     out_message = "Document processing complete"

+#     return out_message, vectorstore_func

 # Prompt functions

+def base_prompt_templates(model_type:str = SMALL_MODEL_NAME):

     #EXAMPLE_PROMPT = PromptTemplate(
     #    template="\nCONTENT:\n\n{page_content}\n\nSOURCE: {source}\n\n",

     # The main prompt:

+    if model_type == SMALL_MODEL_NAME:
         INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_qwen, input_variables=['question', 'summaries'])
+    elif model_type == LARGE_MODEL_NAME:
         INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_phi3, input_variables=['question', 'summaries'])
+    else:
+        INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_template_orca, input_variables=['question', 'summaries'])

     return INSTRUCTION_PROMPT, CONTENT_PROMPT

+def write_out_metadata_as_string(metadata_in:str):
     metadata_string = [f"{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}" for d in metadata_in] # ['metadata']
     return metadata_string
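A small usage sketch of the two helpers above. The sample question and metadata dict are made up; the formatting behaviour follows directly from the code shown.

INSTRUCTION_PROMPT, CONTENT_PROMPT = base_prompt_templates(model_type=SMALL_MODEL_NAME)
prompt_text = INSTRUCTION_PROMPT.format(
    question="How can I make a custom deny list?",
    summaries="<retrieved passages go here>")

metadata_string = write_out_metadata_as_string(
    [{"source": "README.html", "page_section": "Intro", "title": "Doc redaction"}])
# -> ['source: README.html title: Doc redaction']  (page_section is dropped)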
+def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt:str, content_prompt:str, extracted_memory:list, vectorstore:object, embeddings:object, relevant_flag:bool = True, out_passages:int = 2, total_output_passage_chunks_size:int=5): # ,

     question = inputs["question"]
     chat_history = inputs["chat_history"]

     if relevant_flag == True:
         new_question_kworded = adapt_q_from_chat_history(question, chat_history, extracted_memory) # new_question_keywords,

         return sorry_prompt, "No relevant sources found.", new_question_kworded

     # Expand the found passages to the neighbouring context
     if 'meta_url' in doc_df.columns:
         file_type = determine_file_type(doc_df['meta_url'][0])
     else:

     # Only expand passages if not tabular data
     if (file_type != ".csv") & (file_type != ".xlsx"):
+        docs_keep_as_doc, doc_df = get_expanded_passages(vectorstore, docs_keep_out, width=total_output_passage_chunks_size)

     # Build up sources content to add to user display
     doc_df['meta_clean'] = write_out_metadata_as_string(doc_df["metadata"]) # [f"<b>{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}</b>" for d in doc_df['metadata']]

     sources_docs_content_string = '<br><br>'.join(doc_df['content_meta'])#.replace("  "," ")#.strip()

     instruction_prompt_out = instruction_prompt.format(question=new_question_kworded, summaries=docs_content_string)

     return instruction_prompt_out, sources_docs_content_string, new_question_kworded

+def create_full_prompt(user_input:str,
+                       history:list[dict],
+                       extracted_memory:str,
+                       vectorstore:object,
+                       embeddings:object,
+                       model_type:str,
+                       out_passages:list[str],
+                       api_key:str="",
+                       relevant_flag:bool=True):

+    if "gemini" in model_type and not GEMINI_API_KEY and not api_key:
+        raise Exception("Gemini model selected but no API key found. Please enter an API key on the Advanced settings page.")

     #if chain_agent is None:
     #    history.append((user_input, "Please click the button to submit the Huggingface API key before using the chatbot (top right)"))
     #    return history, history, "", ""
     print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")

     history = history or []

     # Create instruction prompt
     instruction_prompt, content_prompt = base_prompt_templates(model_type=model_type)

         relevant_flag = False
     else:
         relevant_flag = True

     instruction_prompt_out, docs_content_string, new_question_kworded =\
         generate_expanded_prompt({"question": user_input, "chat_history": history}, #vectorstore,
                                  instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag, out_passages)

+    history.append({"metadata":None, "options":None, "role": 'user', "content": user_input})

     return history, docs_content_string, instruction_prompt_out, relevant_flag
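For clarity, a sketch of the messages-format history that create_full_prompt expects and extends. Each turn is a dict with metadata, options, role and content keys, matching the appends in the code above; the values and arguments below are illustrative only.

history = [
    {"metadata": None, "options": None, "role": "user", "content": "What does the app do?"},
    {"metadata": None, "options": None, "role": "assistant", "content": "It redacts documents."},
]

history, sources, full_prompt, relevant_flag = create_full_prompt(
    user_input="How can I export my review files to Adobe?",
    history=history,
    extracted_memory="",
    vectorstore=vectorstore,
    embeddings=embeddings,
    model_type=SMALL_MODEL_NAME,
    out_passages=2,
)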
 def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice: str) -> ResponseObject:
     """
     This function sends a request to AWS Claude with the following parameters:

         ],
     }

+    print("prompt_config:", prompt_config)

     body = json.dumps(prompt_config)

     modelId = model_choice

     return response
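The middle of call_aws_claude is collapsed in this diff view. As a rough, assumed sketch (the request body layout follows the Anthropic Messages-on-Bedrock convention and is not copied from this PR), the invocation it builds up to could look like this:

# Assumed shape of the Bedrock call inside call_aws_claude
response = bedrock_runtime.invoke_model(
    body=body,                      # json.dumps(prompt_config) from above
    modelId=modelId,
    accept="application/json",
    contentType="application/json",
)
response_body = json.loads(response.get("body").read())
text = response_body.get("content", [{}])[0].get("text", "")
# wrapped for the rest of the app as ResponseObject(text, response_body.get("usage", {}))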
+def construct_gemini_generative_model(in_api_key: str, temperature: float, model_choice: str, system_prompt: str, max_tokens: int) -> Tuple[object, dict]:
+    """
+    Constructs a GenerativeModel for Gemini API calls.
+
+    Parameters:
+    - in_api_key (str): The API key for authentication.
+    - temperature (float): The temperature parameter for the model, controlling the randomness of the output.
+    - model_choice (str): The choice of model to use for generation.
+    - system_prompt (str): The system prompt to guide the generation.
+    - max_tokens (int): The maximum number of tokens to generate.
+
+    Returns:
+    - Tuple[object, dict]: A tuple containing the constructed GenerativeModel and its configuration.
+    """
+    # Construct a GenerativeModel
+    try:
+        if in_api_key:
+            #print("Getting API key from textbox")
+            api_key = in_api_key
+            ai.configure(api_key=api_key)
+        elif "GOOGLE_API_KEY" in os.environ:
+            #print("Searching for API key in environmental variables")
+            api_key = os.environ["GOOGLE_API_KEY"]
+            ai.configure(api_key=api_key)
+        else:
+            print("No API key found")
+            raise gr.Error("No API key found.")
+    except Exception as e:
+        print(e)
+
+    config = ai.GenerationConfig(temperature=temperature, max_output_tokens=max_tokens)
+
+    print("model_choice:", model_choice)
+
+    #model = ai.GenerativeModel.from_cached_content(cached_content=cache, generation_config=config)
+    model = ai.GenerativeModel(model_name=model_choice, system_instruction=system_prompt, generation_config=config)
+
+    return model, config
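A minimal usage sketch of construct_gemini_generative_model, assuming a valid key is available; the model name is one of the choices defined in chatfuncs/config.py and the prompt text is illustrative.

model, config = construct_gemini_generative_model(
    in_api_key=GEMINI_API_KEY,
    temperature=0.1,
    model_choice="gemini-2.0-flash-001",
    system_prompt="You are answering questions from the user based on source material.",
    max_tokens=1024,
)
response = model.generate_content(contents="What is in the source document?", generation_config=config)
print(response.text)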
+# Function to send a request and update history
+def send_request(prompt: str, conversation_history: List[dict], model: object, config: dict, model_choice: str, system_prompt: str, temperature: float, progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
+    """
+    This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
+    It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
+    If the model choice is specific to AWS Claude, it calls the `call_aws_claude` function; otherwise, it uses the `model.generate_content` method.
+    The function returns the response text and the updated conversation history.
+    """
+    # Constructing the full prompt from the conversation history
+    full_prompt = "Conversation history:\n"
+
+    for entry in conversation_history:
+        role = entry['role'].capitalize()  # Assuming the history is stored with 'role' and 'content'
+        message = ' '.join(entry['parts'])  # Combining all parts of the message
+        full_prompt += f"{role}: {message}\n"
+
+    # Adding the new user prompt
+    full_prompt += f"\nUser: {prompt}"
+
+    # Print the full prompt for debugging purposes
+    #print("full_prompt:", full_prompt)
+
+    # Generate the model's response
+    if "gemini" in model_choice:
+        try:
+            response = model.generate_content(contents=full_prompt, generation_config=config)
+        except Exception as e:
+            # If fails, try again after 10 seconds in case there is a throttle limit
+            print(e)
+            try:
+                print("Calling Gemini model")
+                out_message = "API limit hit - waiting 30 seconds to retry."
+                print(out_message)
+                progress(0.5, desc=out_message)
+                time.sleep(30)
+                response = model.generate_content(contents=full_prompt, generation_config=config)
+            except Exception as e:
+                print(e)
+                return "", conversation_history
+    elif "claude" in model_choice:
+        try:
+            print("Calling AWS Claude model")
+            print("prompt:", prompt)
+            print("system_prompt:", system_prompt)
+            response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
+        except Exception as e:
+            # If fails, try again after x seconds in case there is a throttle limit
+            print(e)
+            try:
+                out_message = "API limit hit - waiting 30 seconds to retry."
+                print(out_message)
+                progress(0.5, desc=out_message)
+                time.sleep(30)
+                response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
+
+            except Exception as e:
+                print(e)
+                return "", conversation_history
+    else:
+        raise Exception("Model not found")
+
+    # Update the conversation history with the new prompt and response
+    conversation_history.append({"metadata":None, "options":None, "role": 'user', 'parts': [prompt]})
+    conversation_history.append({"metadata":None, "options":None, "role": "assistant", 'parts': [response.text]})
+
+    # Print the updated conversation history
+    #print("conversation_history:", conversation_history)
+
+    return response, conversation_history
+
+def process_requests(prompts: List[str], system_prompt_with_table: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], model: object, config: dict, model_choice: str, temperature: float, batch_no:int = 1, master:bool = False) -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
+    """
+    Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
+
+    Args:
+        prompts (List[str]): A list of prompts to be processed.
+        system_prompt_with_table (str): The system prompt including a table.
+        conversation_history (List[dict]): The history of the conversation.
+        whole_conversation (List[str]): The complete conversation including prompts and responses.
+        whole_conversation_metadata (List[str]): Metadata about the whole conversation.
+        model (object): The model to use for processing the prompts.
+        config (dict): Configuration for the model.
+        model_choice (str): The choice of model to use.
+        temperature (float): The temperature parameter for the model.
+        batch_no (int): Batch number of the large language model request.
+        master (bool): Is this request for the master table.
+
+    Returns:
+        Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
+    """
+    responses = []
+    #for prompt in prompts:
+
+    response, conversation_history = send_request(prompts[0], conversation_history, model=model, config=config, model_choice=model_choice, system_prompt=system_prompt_with_table, temperature=temperature)
+
+    print(response.text)
+    #"Okay, I'm ready. What source are we discussing, and what's your question about it? Please provide as much context as possible so I can give you the best answer."]
+    print(response.usage_metadata)
+    responses.append(response)
+
+    # Create conversation txt object
+    whole_conversation.append(prompts[0])
+    whole_conversation.append(response.text)
+
+    # Create conversation metadata
+    if master == False:
+        whole_conversation_metadata.append(f"Query batch {batch_no} prompt {len(responses)} metadata:")
+    else:
+        whole_conversation_metadata.append(f"Query summary metadata:")
+
+    whole_conversation_metadata.append(str(response.usage_metadata))
+
+    return responses, conversation_history, whole_conversation, whole_conversation_metadata
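Note that send_request keeps its own conversation history in Gemini's 'parts' format rather than the Gradio 'content' format used elsewhere in this module. A short sketch, with illustrative values:

conversation_history = []  # list of {"role": ..., "parts": [...]} dicts
response, conversation_history = send_request(
    prompt="Summarise the source in one sentence.",
    conversation_history=conversation_history,
    model=model, config=config,
    model_choice="gemini-2.0-flash-001",
    system_prompt="Answer briefly from the source material.",
    temperature=0.1,
)
print(response.text)
print(conversation_history[-1]["parts"])  # the assistant reply appended by send_request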
+def produce_streaming_answer_chatbot(
+        history:list,
+        full_prompt:str,
+        model_type:str,
+        temperature:float=temperature,
+        relevant_query_bool:bool=True,
+        chat_history:list[dict]=[{"metadata":None, "options":None, "role": 'user', "content": ""}],
+        in_api_key:str=GEMINI_API_KEY,
+        max_new_tokens:int=max_new_tokens,
+        sample:bool=sample,
+        repetition_penalty:float=repetition_penalty,
+        top_p:float=top_p,
+        top_k:float=top_k,
+        max_tokens:int=max_tokens
     ):
     #print("Model type is: ", model_type)

     #     return history

+    history = chat_history
+
+    print("history at start of streaming function:", history)

     if relevant_query_bool == False:
+        history.append({"metadata":None, "options":None, "role": "assistant", "content": 'No relevant query found. Please retry your question'})

         yield history
         return

+    if model_type == SMALL_MODEL_NAME:
+
         # Get the model and tokenizer, and tokenize the user text.
         model_inputs = tokenizer(text=full_prompt, return_tensors="pt", return_attention_mask=False).to(torch_device)

             top_k=top_k
         )

+        t = Thread(target=model_object.generate, kwargs=generate_kwargs)
         t.start()

         # Pull the generated text from the streamer, and update the model output.

         NUM_TOKENS=0
         print('-'*4+'Start Generation'+'-'*4)

+        history.append({"metadata":None, "options":None, "role": "assistant", "content": ""})
+
         for new_text in streamer:
             try:
+                if new_text is None:
+                    new_text = ""
+                history[-1]['content'] += new_text
+                NUM_TOKENS += 1
+                history[-1]['content'] = history[-1]['content'].replace('<|im_end|>','')
                 yield history
             except Exception as e:
                 print(f"Error during text generation: {e}")

         print(f'Tokens per secound: {NUM_TOKENS/time_generate}')
         print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')

+    elif model_type == LARGE_MODEL_NAME:
         #tokens = model.tokenize(full_prompt)

         gen_config = CtransGenGenerationConfig()

         NUM_TOKENS=0
         print('-'*4+'Start Generation'+'-'*4)

+        output = model_object(
             full_prompt, **vars(gen_config))

+        history.append({"metadata":None, "options":None, "role": "assistant", "content": ""})
+
         for out in output:

             if "choices" in out and len(out["choices"]) > 0 and "text" in out["choices"][0]:
+                history[-1]['content'] += out["choices"][0]["text"]
                 NUM_TOKENS+=1
+                history[-1]['content'] = history[-1]['content'].replace('<|im_end|>','')
                 yield history
             else:
                 print(f"Unexpected output structure: {out}")

         print('-'*4+'End Generation'+'-'*4)
         print(f'Num of generated tokens: {NUM_TOKENS}')
         print(f'Time for complete generation: {time_generate}s')
+        print(f'Tokens per second: {NUM_TOKENS/time_generate}')
         print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')

+    elif "claude" in model_type:
         system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."

+        print("full_prompt:", full_prompt)
+
+        if isinstance(full_prompt, str):
+            full_prompt = [full_prompt]
+
+        model = model_type
+        config = {}
+
+        responses, summary_conversation_history, whole_summary_conversation, whole_conversation_metadata = process_requests(full_prompt, system_prompt, conversation_history=[], whole_conversation=[], whole_conversation_metadata=[], model=model, config = config, model_choice = model_type, temperature = temperature)
+
+        if isinstance(responses[-1], ResponseObject):
+            response_texts = [resp.text for resp in responses]
+        elif "choices" in responses[-1]:
+            response_texts = [resp["choices"][0]['text'] for resp in responses]
+        else:
+            response_texts = [resp.text for resp in responses]
+
+        latest_response_text = response_texts[-1]

         # Update the conversation history with the new prompt and response
+        clean_text = re.sub(r'[\n\t\r]', ' ', latest_response_text)  # Replace newlines, tabs, and carriage returns with a space
+        clean_response_text = re.sub(r'[^\x20-\x7E]', '', clean_text).strip()  # Remove all non-ASCII printable characters
+
+        history.append({"metadata":None, "options":None, "role": "assistant", "content": ''})
+
+        for char in clean_response_text:
+            time.sleep(0.005)
+            history[-1]['content'] += char
+            yield history
+
+    elif "gemini" in model_type:
+
+        if in_api_key: gemini_api_key = in_api_key
+        elif GEMINI_API_KEY: gemini_api_key = GEMINI_API_KEY
+        else: raise Exception("Gemini API key not found. Please enter a key on the Advanced settings page or select another model type")
+
+        print("Using Gemini model:", model_type)
+        print("full_prompt:", full_prompt)
+
+        if isinstance(full_prompt, str):
+            full_prompt = [full_prompt]
+
+        system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."
+
+        model, config = construct_gemini_generative_model(gemini_api_key, temperature, model_type, system_prompt, max_tokens)
+
+        responses, summary_conversation_history, whole_summary_conversation, whole_conversation_metadata = process_requests(full_prompt, system_prompt, conversation_history=[], whole_conversation=[], whole_conversation_metadata=[], model=model, config = config, model_choice = model_type, temperature = temperature)
+
+        if isinstance(responses[-1], ResponseObject):
+            response_texts = [resp.text for resp in responses]
+        elif "choices" in responses[-1]:
+            response_texts = [resp["choices"][0]['text'] for resp in responses]
+        else:
+            response_texts = [resp.text for resp in responses]
+
+        latest_response_text = response_texts[-1]
+
+        clean_text = re.sub(r'[\n\t\r]', ' ', latest_response_text)  # Replace newlines, tabs, and carriage returns with a space
+        clean_response_text = re.sub(r'[^\x20-\x7E]', '', clean_text).strip()  # Remove all non-ASCII printable characters

+        history.append({"metadata":None, "options":None, "role": "assistant", "content": ''})

+        for char in clean_response_text:
+            time.sleep(0.005)
+            history[-1]['content'] += char
+            yield history
+
+        print("history at end of function:", history)
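produce_streaming_answer_chatbot is a generator, so the Gradio app consumes it turn by turn. A rough standalone sketch, assuming a local model has already been loaded into model_object and tokenizer; the question is illustrative.

history, sources, full_prompt, relevant_flag = create_full_prompt(
    "How can I find page duplicates?", [], "", vectorstore, embeddings, SMALL_MODEL_NAME, 2)

for updated_history in produce_streaming_answer_chatbot(
        history, full_prompt, SMALL_MODEL_NAME,
        relevant_query_bool=relevant_flag, chat_history=history):
    pass  # each yield is the chat history with the partial answer appended

print(updated_history[-1]["content"])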
 # Chat helper functions

     docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)

     # Keep only documents with a certain score
     docs_len = [len(x[0].page_content) for x in docs]
     docs_scores = [x[1] for x in docs]

     # 3rd level check on retrieved docs with SVM retriever
     # Check the type of the embeddings object
     embeddings_type = type(embeddings)

     #hf_embeddings = HuggingFaceEmbeddings(**embeddings)
     hf_embeddings = embeddings

     # Make df of best options
     doc_df = create_doc_df(docs_keep_out)

     return docs_keep_as_doc, doc_df, docs_keep_out

 def get_expanded_passages(vectorstore, docs, width):
     return expanded_docs, doc_df

+def highlight_found_text(chat_history: list[dict], source_texts: list[dict], hlt_chunk_size:int=hlt_chunk_size, hlt_strat:List=hlt_strat, hlt_overlap:int=hlt_overlap) -> str:
     """
+    Highlights occurrences of chat_history within source_texts.

     Parameters:
+    - chat_history (str): The text to be searched for within source_texts.
+    - source_texts (str): The text within which chat_history occurrences will be highlighted.

     Returns:
+    - str: A string with occurrences of chat_history highlighted.

     Example:
     >>> highlight_found_text("world", "Hello, world! This is a test. Another world awaits.")

             return text[i][0].replace("  ", " ").strip()
         else:
             return ""
+
+    print("chat_history:", chat_history)
+
+    response_text = next(
+        (entry['content'] for entry in reversed(chat_history) if entry.get('role') == 'assistant'),
+        "")
+
+    source_texts = extract_text_from_input(source_texts)

     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=hlt_chunk_size,
         separators=hlt_strat,
         chunk_overlap=hlt_overlap,
     )
+    sections = text_splitter.split_text(response_text)

     found_positions = {}
     for x in sections:
         text_start_pos = 0
         while text_start_pos != -1:
+            text_start_pos = source_texts.find(x, text_start_pos)
             if text_start_pos != -1:
                 found_positions[text_start_pos] = text_start_pos + len(x)
                 text_start_pos += 1

     prev_end = 0
     for start, end in combined_positions:
         if end-start > 15: # Only combine if there is a significant amount of matched text. Avoids picking up single words like 'and' etc.
+            pos_tokens.append(source_texts[prev_end:start])
+            pos_tokens.append('<mark style="color:black;">' + source_texts[start:end] + '</mark>')
             prev_end = end
+    pos_tokens.append(source_texts[prev_end:])
+
+    out_pos_tokens = "".join(pos_tokens)

+    return out_pos_tokens


 # # Chat history functions

 def clear_chat(chat_history_state, sources, chat_message, current_topic):
+    chat_history_state = None
     sources = ''
+    chat_message = None
     current_topic = ''

     return chat_history_state, sources, chat_message, current_topic

     for word in tokens_without_sw:
         if word not in ordered_tokens:
             ordered_tokens.add(word)
+            result.append(word)

     new_question_keywords = ' '.join(result)

 def remove_q_ner_extractor(question):

     predict_out = ner_model.predict(question)
     predict_tokens = [' '.join(v for k, v in d.items() if k == 'span') for d in predict_out]

     # Remove duplicate words while preserving order

     return keywords_list

 # Gradio functions
+def turn_off_interactivity():
+    return gr.Textbox(interactive=False), gr.Button(interactive=False)

 def restore_interactivity():
+    return gr.Textbox(interactive=True), gr.Button(interactive=True)

 def update_message(dropdown_value):
     return gr.Textbox(value=dropdown_value)
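A small sketch of the reworked highlight_found_text, which now takes the whole messages-format chat history plus the source text string. The inputs below are illustrative; matched spans longer than 15 characters are wrapped in mark tags for display.

chat_history = [
    {"metadata": None, "options": None, "role": "user", "content": "What is the app for?"},
    {"metadata": None, "options": None, "role": "assistant", "content": "It redacts documents."},
]
source_texts = "This app redacts documents. It also supports review and export."

highlighted = highlight_found_text(chat_history, source_texts)
print(highlighted)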
chatfuncs/config.py
ADDED
@@ -0,0 +1,250 @@
+import os
+import tempfile
+import socket
+import logging
+from datetime import datetime
+from dotenv import load_dotenv
+
+today_rev = datetime.now().strftime("%Y%m%d")
+HOST_NAME = socket.gethostname()
+
+# Set or retrieve configuration variables for the redaction app
+
+def get_or_create_env_var(var_name:str, default_value:str, print_val:bool=False):
+    '''
+    Get an environmental variable, and set it to a default value if it doesn't exist
+    '''
+    # Get the environment variable if it exists
+    value = os.environ.get(var_name)
+
+    # If it doesn't exist, set the environment variable to the default value
+    if value is None:
+        os.environ[var_name] = default_value
+        value = default_value
+
+    if print_val == True:
+        print(f'The value of {var_name} is {value}')
+
+    return value
+
+def ensure_folder_exists(output_folder:str):
+    """Checks if the specified folder exists, creates it if not."""
+
+    if not os.path.exists(output_folder):
+        # Create the folder if it doesn't exist
+        os.makedirs(output_folder, exist_ok=True)
+        print(f"Created the {output_folder} folder.")
+    else:
+        print(f"The {output_folder} folder already exists.")
+
+def add_folder_to_path(folder_path: str):
+    '''
+    Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. Function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
+    '''
+
+    if os.path.exists(folder_path) and os.path.isdir(folder_path):
+        print(folder_path, "folder exists.")
+
+        # Resolve relative path to absolute path
+        absolute_path = os.path.abspath(folder_path)
+
+        current_path = os.environ['PATH']
+        if absolute_path not in current_path.split(os.pathsep):
+            full_path_extension = absolute_path + os.pathsep + current_path
+            os.environ['PATH'] = full_path_extension
+            #print(f"Updated PATH with: ", full_path_extension)
+        else:
+            print(f"Directory {folder_path} already exists in PATH.")
+    else:
+        print(f"Folder not found at {folder_path} - not added to PATH")
+
+ensure_folder_exists("config/")
+
+# If you have an aws_config env file in the config folder, you can load in app variables this way, e.g. 'config/app_config.env'
+APP_CONFIG_PATH = get_or_create_env_var('APP_CONFIG_PATH', 'config/app_config.env') # e.g. config/app_config.env
+
+if APP_CONFIG_PATH:
+    if os.path.exists(APP_CONFIG_PATH):
+        print(f"Loading app variables from config file {APP_CONFIG_PATH}")
+        load_dotenv(APP_CONFIG_PATH)
+    else: print("App config file not found at location:", APP_CONFIG_PATH)
+
+# Report logging to console?
+LOGGING = get_or_create_env_var('LOGGING', 'False')
+
+if LOGGING == 'True':
+    # Configure logging
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+###
+# AWS CONFIG
+###
+
+# If you have an aws_config env file in the config folder, you can load in AWS keys this way, e.g. 'env/aws_config.env'
+AWS_CONFIG_PATH = get_or_create_env_var('AWS_CONFIG_PATH', '') # e.g. config/aws_config.env
+
+if AWS_CONFIG_PATH:
+    if os.path.exists(AWS_CONFIG_PATH):
+        print(f"Loading AWS variables from config file {AWS_CONFIG_PATH}")
+        load_dotenv(AWS_CONFIG_PATH)
+    else: print("AWS config file not found at location:", AWS_CONFIG_PATH)
+
+RUN_AWS_FUNCTIONS = get_or_create_env_var("RUN_AWS_FUNCTIONS", "0")
+
+AWS_REGION = get_or_create_env_var('AWS_REGION', '')
+
+AWS_DEFAULT_REGION = get_or_create_env_var('AWS_DEFAULT_REGION', '')
+
+AWS_CLIENT_ID = get_or_create_env_var('AWS_CLIENT_ID', '')
+
+AWS_CLIENT_SECRET = get_or_create_env_var('AWS_CLIENT_SECRET', '')
+
+AWS_USER_POOL_ID = get_or_create_env_var('AWS_USER_POOL_ID', '')
+
+AWS_ACCESS_KEY = get_or_create_env_var('AWS_ACCESS_KEY', '')
+if AWS_ACCESS_KEY: print(f'AWS_ACCESS_KEY found in environment variables')
+
+AWS_SECRET_KEY = get_or_create_env_var('AWS_SECRET_KEY', '')
+if AWS_SECRET_KEY: print(f'AWS_SECRET_KEY found in environment variables')
+
+DOCUMENT_REDACTION_BUCKET = get_or_create_env_var('DOCUMENT_REDACTION_BUCKET', '')
+
+# Custom headers e.g. if routing traffic through Cloudfront
+# Retrieving or setting CUSTOM_HEADER
+CUSTOM_HEADER = get_or_create_env_var('CUSTOM_HEADER', '')
+#if CUSTOM_HEADER: print(f'CUSTOM_HEADER found')
+
+# Retrieving or setting CUSTOM_HEADER_VALUE
+CUSTOM_HEADER_VALUE = get_or_create_env_var('CUSTOM_HEADER_VALUE', '')
+#if CUSTOM_HEADER_VALUE: print(f'CUSTOM_HEADER_VALUE found')
+
+###
+# File I/O config
+###
+SESSION_OUTPUT_FOLDER = get_or_create_env_var('SESSION_OUTPUT_FOLDER', 'False') # i.e. do you want your input and output folders saved within a subfolder based on session hash value within output/input folders
+
+OUTPUT_FOLDER = get_or_create_env_var('GRADIO_OUTPUT_FOLDER', 'output/') # 'output/'
+INPUT_FOLDER = get_or_create_env_var('GRADIO_INPUT_FOLDER', 'input/') # 'input/'
+
+ensure_folder_exists(OUTPUT_FOLDER)
+ensure_folder_exists(INPUT_FOLDER)
+
+# Allow for files to be saved in a temporary folder for increased security in some instances
+if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
+    # Create a temporary directory
+    with tempfile.TemporaryDirectory() as temp_dir:
+        print(f'Temporary directory created at: {temp_dir}')
+
+        if OUTPUT_FOLDER == "TEMP": OUTPUT_FOLDER = temp_dir + "/"
+        if INPUT_FOLDER == "TEMP": INPUT_FOLDER = temp_dir + "/"
+
+# By default, logs are put into a subfolder of today's date and the host name of the instance running the app. This is to avoid at all possible the possibility of log files from one instance overwriting the logs of another instance on S3. If running the app on one system always, or just locally, it is not necessary to make the log folders so specific.
+# Another way to address this issue would be to write logs to another type of storage, e.g. database such as dynamodb. I may look into this in future.
+
+USE_LOG_SUBFOLDERS = get_or_create_env_var('USE_LOG_SUBFOLDERS', 'True')
+
+if USE_LOG_SUBFOLDERS == "True":
+    day_log_subfolder = today_rev + '/'
+    host_name_subfolder = HOST_NAME + '/'
+    full_log_subfolder = day_log_subfolder + host_name_subfolder
+else:
+    full_log_subfolder = ""
+
+FEEDBACK_LOGS_FOLDER = get_or_create_env_var('FEEDBACK_LOGS_FOLDER', 'feedback/' + full_log_subfolder)
+ACCESS_LOGS_FOLDER = get_or_create_env_var('ACCESS_LOGS_FOLDER', 'logs/' + full_log_subfolder)
+USAGE_LOGS_FOLDER = get_or_create_env_var('USAGE_LOGS_FOLDER', 'usage/' + full_log_subfolder)
+
+ensure_folder_exists(FEEDBACK_LOGS_FOLDER)
+ensure_folder_exists(ACCESS_LOGS_FOLDER)
+ensure_folder_exists(USAGE_LOGS_FOLDER)
+
+# Should the redacted file name be included in the logs? In some instances, the names of the files themselves could be sensitive, and should not be disclosed beyond the app. So, by default this is false.
+DISPLAY_FILE_NAMES_IN_LOGS = get_or_create_env_var('DISPLAY_FILE_NAMES_IN_LOGS', 'False')
+
+###
+# RUN CONFIG
+GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
+
+HF_TOKEN = get_or_create_env_var('HF_TOKEN', '')
+
+# Number of pages to loop through before breaking the function and restarting from the last finished page (not currently activated).
+PAGE_BREAK_VALUE = get_or_create_env_var('PAGE_BREAK_VALUE', '99999')
+
+MAX_TIME_VALUE = get_or_create_env_var('MAX_TIME_VALUE', '999999')
+
+###
+# APP RUN CONFIG
+###
+
+SMALL_MODEL_NAME = get_or_create_env_var("SMALL_MODEL_NAME", "Gemma 3 1B (small, fast)") # "Qwen 2 0.5B (small, fast)"
+
+SMALL_MODEL_REPO_ID = get_or_create_env_var("SMALL_MODEL_REPO_ID", 'google/gemma-3-1b-it') #'Qwen/Qwen2-0.5B-Instruct')
+
+LARGE_MODEL_NAME = get_or_create_env_var("LARGE_MODEL_NAME", "Phi 3.5 Mini (larger, slow)")
+
+LARGE_MODEL_REPO_ID = get_or_create_env_var("LARGE_MODEL_REPO_ID", "QuantFactory/Phi-3.5-mini-instruct-GGUF") # "QuantFactory/Phi-3-mini-128k-instruct-GGUF"), # "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF-v2"), #"microsoft/Phi-3-mini-4k-instruct-gguf"),#"TheBloke/Mistral-7B-OpenOrca-GGUF"),
+LARGE_MODEL_GGUF_FILE = get_or_create_env_var("LARGE_MODEL_GGUF_FILE", "Phi-3.5-mini-instruct.Q4_K_M.gguf") #"Phi-3-mini-128k-instruct.Q4_K_M.gguf") #"Meta-Llama-3-8B-Instruct-v2.Q6_K.gguf") #"Phi-3-mini-4k-instruct-q4.gguf")#"mistral-7b-openorca.Q4_K_M.gguf"),
+
+if RUN_AWS_FUNCTIONS == "1":
+    default_model_choices = f'["{SMALL_MODEL_NAME}", "{LARGE_MODEL_NAME}", "gemini-2.0-flash-001", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"]'
+else:
+    default_model_choices = f'["{SMALL_MODEL_NAME}", "{LARGE_MODEL_NAME}", "gemini-2.0-flash-001", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25"]'
+
+DEFAULT_MODEL_CHOICES = get_or_create_env_var("DEFAULT_MODEL_CHOICES", default_model_choices)
+
+EMBEDDINGS_MODEL_NAME = get_or_create_env_var('EMBEDDINGS_MODEL_NAME', "BAAI/bge-base-en-v1.5") #"mixedbread-ai/mxbai-embed-xsmall-v1"
+
+DEFAULT_EMBEDDINGS_LOCATION = get_or_create_env_var('DEFAULT_EMBEDDINGS_LOCATION', "faiss_embedding")
+
+DEFAULT_DATA_SOURCE_NAME = get_or_create_env_var('DEFAULT_DATA_SOURCE_NAME', "Document redaction app documentation")
+
+DEFAULT_DATA_SOURCE = get_or_create_env_var('DEFAULT_DATA_SOURCE', "https://seanpedrick-case.github.io/doc_redaction/README.html")
+
+DEFAULT_EXAMPLES = get_or_create_env_var('DEFAULT_EXAMPLES', '[ "How can I make a custom deny list?", "How can I find page duplicates?", "How can I review and modify existing redactions?", "How can I export my review files to Adobe?"]')
+#
+# ') # ["What were the five pillars of the previous borough plan?",
+#"What is the vision statement for Lambeth?",
+#"What are the commitments for Lambeth?",
+#"What are the 2030 outcomes for Lambeth?"]
+
+# Get some environment variables and Launch the Gradio app
+COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
+
+RUN_DIRECT_MODE = get_or_create_env_var('RUN_DIRECT_MODE', '0')
+
+MAX_QUEUE_SIZE = int(get_or_create_env_var('MAX_QUEUE_SIZE', '5'))
+
+MAX_FILE_SIZE = get_or_create_env_var('MAX_FILE_SIZE', '250mb')
+
+GRADIO_SERVER_PORT = int(get_or_create_env_var('GRADIO_SERVER_PORT', '7860'))
+
+ROOT_PATH = get_or_create_env_var('ROOT_PATH', '')
+
+DEFAULT_CONCURRENCY_LIMIT = get_or_create_env_var('DEFAULT_CONCURRENCY_LIMIT', '3')
+
+GET_DEFAULT_ALLOW_LIST = get_or_create_env_var('GET_DEFAULT_ALLOW_LIST', 'False')
+
+ALLOW_LIST_PATH = get_or_create_env_var('ALLOW_LIST_PATH', '') # config/default_allow_list.csv
+
+S3_ALLOW_LIST_PATH = get_or_create_env_var('S3_ALLOW_LIST_PATH', '') # default_allow_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
+
+if ALLOW_LIST_PATH: OUTPUT_ALLOW_LIST_PATH = ALLOW_LIST_PATH
+else: OUTPUT_ALLOW_LIST_PATH = 'config/default_allow_list.csv'
+
+SHOW_COSTS = get_or_create_env_var('SHOW_COSTS', 'False')
+
+GET_COST_CODES = get_or_create_env_var('GET_COST_CODES', 'False')
+
+DEFAULT_COST_CODE = get_or_create_env_var('DEFAULT_COST_CODE', '')
+
+COST_CODES_PATH = get_or_create_env_var('COST_CODES_PATH', '') # 'config/COST_CENTRES.csv' # file should be a csv file with a single table in it that has two columns with a header. First column should contain cost codes, second column should contain a name or description for the cost code
+
+S3_COST_CODES_PATH = get_or_create_env_var('S3_COST_CODES_PATH', '') # COST_CENTRES.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
+
+if COST_CODES_PATH: OUTPUT_COST_CODES_PATH = COST_CODES_PATH
+else: OUTPUT_COST_CODES_PATH = 'config/COST_CENTRES.csv'
+
+ENFORCE_COST_CODES = get_or_create_env_var('ENFORCE_COST_CODES', 'False') # If you have cost codes listed, is it compulsory to choose one before redacting?
+
+if ENFORCE_COST_CODES == 'True': GET_COST_CODES = 'True'
chatfuncs/helper_functions.py
CHANGED
@@ -1,6 +1,9 @@
 import os
 import gradio as gr
 import pandas as pd

 def get_or_create_env_var(var_name, default_value):
     # Get the environment variable if it exists

@@ -13,12 +16,6 @@ def get_or_create_env_var(var_name, default_value):

     return value

-# Retrieving or setting output folder
-env_var_name = 'GRADIO_OUTPUT_FOLDER'
-default_value = 'output/'
-
-output_folder = get_or_create_env_var(env_var_name, default_value)
-print(f'The value of {env_var_name} is {output_folder}')

 def get_file_path_with_extension(file_path):
     # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")

@@ -165,64 +162,129 @@ def wipe_logs(feedback_logs_loc, usage_logs_loc):

-async def get_connection_params(request: gr.Request):
-    (most of the removed function body is not rendered in this diff view)
     else:
-        if request.username:
-            out_session_hash = request.username
-            base_folder = "user-files/"
-            print("Request username found:", out_session_hash)

-            print("Cognito ID found:", out_session_hash)

-            # print("Cognito ID not found. Using session hash as save folder:", out_session_hash)

-        # print("S3 output folder is: " + "s3://" + bucket_name + "/" + output_folder)

-    else:
-        print("No session parameters found.")
-        return "",""
 import os
 import gradio as gr
 import pandas as pd
+import boto3
+from botocore.exceptions import ClientError
+from chatfuncs.config import CUSTOM_HEADER_VALUE, CUSTOM_HEADER, OUTPUT_FOLDER, INPUT_FOLDER, SESSION_OUTPUT_FOLDER, AWS_USER_POOL_ID

 def get_or_create_env_var(var_name, default_value):
     # Get the environment variable if it exists

     return value

 def get_file_path_with_extension(file_path):
     # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")

+# async def get_connection_params(request: gr.Request):
+#     base_folder = ""
+
+#     if request:
+#         #print("request user:", request.username)
+
+#         #request_data = await request.json() # Parse JSON body
+#         #print("All request data:", request_data)
+#         #context_value = request_data.get('context')
+#         #if 'context' in request_data:
+#         #    print("Request context dictionary:", request_data['context'])
+
+#         # print("Request headers dictionary:", request.headers)
+#         # print("All host elements", request.client)
+#         # print("IP address:", request.client.host)
+#         # print("Query parameters:", dict(request.query_params))
+#         # To get the underlying FastAPI items you would need to use await and some fancy @ stuff for a live query: https://fastapi.tiangolo.com/vi/reference/request/
+#         #print("Request dictionary to object:", request.request.body())
+#         print("Session hash:", request.session_hash)
+
+#         # Retrieving or setting CUSTOM_CLOUDFRONT_HEADER
+#         CUSTOM_CLOUDFRONT_HEADER_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER', '')
+#         #print(f'The value of CUSTOM_CLOUDFRONT_HEADER is {CUSTOM_CLOUDFRONT_HEADER_var}')
+
+#         # Retrieving or setting CUSTOM_CLOUDFRONT_HEADER_VALUE
+#         CUSTOM_CLOUDFRONT_HEADER_VALUE_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER_VALUE', '')
+#         #print(f'The value of CUSTOM_CLOUDFRONT_HEADER_VALUE_var is {CUSTOM_CLOUDFRONT_HEADER_VALUE_var}')
+
+#         if CUSTOM_CLOUDFRONT_HEADER_var and CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
+#             if CUSTOM_CLOUDFRONT_HEADER_var in request.headers:
+#                 supplied_cloudfront_custom_value = request.headers[CUSTOM_CLOUDFRONT_HEADER_var]
+#                 if supplied_cloudfront_custom_value == CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
+#                     print("Custom Cloudfront header found:", supplied_cloudfront_custom_value)
+#                 else:
+#                     raise(ValueError, "Custom Cloudfront header value does not match expected value.")
+
+#         # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
+
+#         if request.username:
+#             out_session_hash = request.username
+#             base_folder = "user-files/"
+#             print("Request username found:", out_session_hash)
+
+#         elif 'x-cognito-id' in request.headers:
+#             out_session_hash = request.headers['x-cognito-id']
+#             base_folder = "user-files/"
+#             print("Cognito ID found:", out_session_hash)
+
+#         else:
+#             out_session_hash = request.session_hash
+#             base_folder = "temp-files/"
+#             # print("Cognito ID not found. Using session hash as save folder:", out_session_hash)
+
+#         output_folder = base_folder + out_session_hash + "/"
+#         #if bucket_name:
+#         #    print("S3 output folder is: " + "s3://" + bucket_name + "/" + output_folder)
+
+#         return out_session_hash, output_folder, out_session_hash
+#     else:
+#         print("No session parameters found.")
+#         return "",""
+
+async def get_connection_params(request: gr.Request,
+                                output_folder_textbox:str=OUTPUT_FOLDER,
+                                input_folder_textbox:str=INPUT_FOLDER,
+                                session_output_folder:str=SESSION_OUTPUT_FOLDER):
+
+    #print("Session hash:", request.session_hash)
+
+    if CUSTOM_HEADER and CUSTOM_HEADER_VALUE:
+        if CUSTOM_HEADER in request.headers:
+            supplied_custom_header_value = request.headers[CUSTOM_HEADER]
+            if supplied_custom_header_value == CUSTOM_HEADER_VALUE:
+                print("Custom header supplied and matches CUSTOM_HEADER_VALUE")
             else:
+                print("Custom header value does not match expected value.")
+                raise ValueError("Custom header value does not match expected value.")
+        else:
+            print("Custom header value not found.")
+            raise ValueError("Custom header value not found.")
+
+    # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
+
+    if request.username:
+        out_session_hash = request.username
+        #print("Request username found:", out_session_hash)
+
+    elif 'x-cognito-id' in request.headers:
+        out_session_hash = request.headers['x-cognito-id']
+        #print("Cognito ID found:", out_session_hash)
+
+    elif 'x-amzn-oidc-identity' in request.headers:
+        out_session_hash = request.headers['x-amzn-oidc-identity']
+
+        # Fetch email address using Cognito client
+        cognito_client = boto3.client('cognito-idp')
+        try:
+            response = cognito_client.admin_get_user(
+                UserPoolId=AWS_USER_POOL_ID, # Replace with your User Pool ID
+                Username=out_session_hash
+            )
+            email = next(attr['Value'] for attr in response['UserAttributes'] if attr['Name'] == 'email')
+            #print("Email address found:", email)
+
+            out_session_hash = email
+        except ClientError as e:
+            print("Error fetching user details:", e)
+            email = None
+
+        print("Cognito ID found:", out_session_hash)

+    else:
+        out_session_hash = request.session_hash

+    if session_output_folder == 'True':
+        output_folder = output_folder_textbox + out_session_hash + "/"
+        input_folder = input_folder_textbox + out_session_hash + "/"

+    else:
+        output_folder = output_folder_textbox
+        input_folder = input_folder_textbox

+    if not os.path.exists(output_folder): os.mkdir(output_folder)
+    if not os.path.exists(input_folder): os.mkdir(input_folder)

+    return out_session_hash, output_folder, out_session_hash, input_folder
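get_connection_params is an async Gradio event handler; a minimal sketch of how it could be wired into a Blocks app on page load. The component names are assumptions for illustration, not taken from this PR.

import gradio as gr
from chatfuncs.helper_functions import get_connection_params

with gr.Blocks() as demo:
    session_hash_state = gr.State()
    output_folder_state = gr.State()
    username_state = gr.State()
    input_folder_state = gr.State()

    # Gradio injects the gr.Request argument automatically; the folder arguments use their defaults
    demo.load(get_connection_params, inputs=None,
              outputs=[session_hash_state, output_folder_state, username_state, input_folder_state])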
chatfuncs/ingest.py
CHANGED
@@ -7,13 +7,14 @@ import requests
 import pandas as pd
 import dateutil.parser
 from typing import Type, List
-import shutil

-from langchain_community.embeddings import HuggingFaceEmbeddings # HuggingFaceInstructEmbeddings,
 from langchain_community.vectorstores.faiss import FAISS
 #from langchain_community.vectorstores import Chroma
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.docstore.document import Document

 from bs4 import BeautifulSoup
 from docx import Document as Doc

@@ -557,31 +558,24 @@ def docs_elements_from_csv_save(docs_path="documents.csv"):

 # ## Create embeddings and save faiss vector store to the path specified in `save_to`

-def
-    # embeddings_func = HuggingFaceInstructEmbeddings(model_name=model_name,
-    #                                                 embed_instruction="Represent the paragraph for retrieval: ",
-    #                                                 query_instruction="Represent the question for retrieving supporting documents: "
-    #                                                 )

-    embeddings_func = HuggingFaceEmbeddings(model_name=model_name)

-def embed_faiss_save_to_zip(docs_out, save_to="output", model_name="BAAI/bge-base-en-v1.5"):
-    load_embeddings(model_name=model_name)

     print(f"> Total split documents: {len(docs_out)}")

-    vectorstore = FAISS.from_documents(documents=docs_out, embedding=
-    save_to_path = Path(save_to)
     save_to_path.mkdir(parents=True, exist_ok=True)

     vectorstore.save_local(folder_path=str(save_to_path))

@@ -619,20 +613,20 @@ def embed_faiss_save_to_zip(docs_out, save_to="output", model_name="BAAI/bge-bas

-def sim_search_local_saved_vec(query, k_val, save_to="faiss_lambeth_census_embedding"):
7 |
import pandas as pd
|
8 |
import dateutil.parser
|
9 |
from typing import Type, List
|
10 |
+
#import shutil
|
11 |
|
12 |
+
#from langchain_community.embeddings import HuggingFaceEmbeddings # HuggingFaceInstructEmbeddings,
|
13 |
from langchain_community.vectorstores.faiss import FAISS
|
14 |
#from langchain_community.vectorstores import Chroma
|
15 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
16 |
from langchain.docstore.document import Document
|
17 |
+
#from chatfuncs.config import EMBEDDINGS_MODEL_NAME
|
18 |
|
19 |
from bs4 import BeautifulSoup
|
20 |
from docx import Document as Doc
|
|
|
558 |
|
559 |
# ## Create embeddings and save faiss vector store to the path specified in `save_to`
|
560 |
|
561 |
+
# def load_embeddings_model(embeddings_model = EMBEDDINGS_MODEL_NAME):
|
562 |
|
563 |
+
# embeddings_func = HuggingFaceEmbeddings(model_name=embeddings_model)
|
|
|
|
|
|
|
|
|
564 |
|
565 |
+
# #global embeddings
|
|
|
566 |
|
567 |
+
# #embeddings = embeddings_func
|
568 |
|
569 |
+
# return embeddings_func
|
570 |
|
571 |
+
def embed_faiss_save_to_zip(docs_out, save_folder, embeddings_model_object, save_to="faiss_embeddings", model_name="BAAI/bge-base-en-v1.5"):
|
572 |
+
#load_embeddings(model_name=model_name)
|
|
|
|
|
573 |
|
574 |
print(f"> Total split documents: {len(docs_out)}")
|
575 |
|
576 |
+
vectorstore = FAISS.from_documents(documents=docs_out, embedding=embeddings_model_object)
|
577 |
|
578 |
+
save_to_path = Path(save_folder, save_to)
|
579 |
save_to_path.mkdir(parents=True, exist_ok=True)
|
580 |
|
581 |
vectorstore.save_local(folder_path=str(save_to_path))
|
|
|
613 |
|
614 |
|
615 |
|
616 |
+
# def sim_search_local_saved_vec(query, k_val, save_to="faiss_lambeth_census_embedding"):
|
617 |
|
618 |
+
# load_embeddings()
|
619 |
|
620 |
+
# docsearch = FAISS.load_local(folder_path=save_to, embeddings=embeddings)
|
621 |
|
622 |
|
623 |
+
# display(Markdown(question))
|
624 |
|
625 |
+
# search = docsearch.similarity_search_with_score(query, k=k_val)
|
626 |
|
627 |
+
# for item in search:
|
628 |
+
# print(item[0].page_content)
|
629 |
+
# print(f"Page: {item[0].metadata['source']}")
|
630 |
+
# print(f"Date: {item[0].metadata['date']}")
|
631 |
+
# print(f"Score: {item[1]}")
|
632 |
+
# print("---")
|
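
The reworked embed_faiss_save_to_zip now takes the embeddings object from the caller instead of loading it inside ingest.py. A minimal usage sketch under that assumption follows; the document content, paths and folder names are invented placeholders.

# Minimal sketch of the new calling convention; content and paths are placeholders.
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from chatfuncs.ingest import embed_faiss_save_to_zip

embeddings_model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")

# docs_out would normally come from the splitting functions earlier in ingest.py.
docs_out = [Document(page_content="Example passage about local services.",
                     metadata={"source": "example.txt", "page_section": 1})]

embed_faiss_save_to_zip(docs_out,
                        save_folder="output/",
                        embeddings_model_object=embeddings_model,
                        save_to="faiss_embeddings")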
chatfuncs/model_load.py
ADDED
@@ -0,0 +1,82 @@
+import torch
+
+# Currently set gpu_layers to 0 even with cuda due to persistent bugs in implementation with cuda
+if torch.cuda.is_available():
+    torch_device = "cuda"
+    gpu_layers = 100
+else:
+    torch_device = "cpu"
+    gpu_layers = 0
+
+print("Running on device:", torch_device)
+threads = 8 #torch.get_num_threads()
+print("CPU threads:", threads)
+
+# Qwen 2 0.5B (small, fast) Model parameters
+temperature: float = 0.1
+top_k: int = 3
+top_p: float = 1
+repetition_penalty: float = 1.15
+flan_alpaca_repetition_penalty: float = 1.3
+last_n_tokens: int = 64
+max_new_tokens: int = 1024
+seed: int = 42
+reset: bool = False
+stream: bool = True
+threads: int = threads
+batch_size:int = 256
+context_length:int = 2048
+sample = True
+
+# Bedrock parameters
+max_tokens = 4096
+
+
+class CtransInitConfig_gpu:
+    def __init__(self,
+                 last_n_tokens=last_n_tokens,
+                 seed=seed,
+                 n_threads=threads,
+                 n_batch=batch_size,
+                 n_ctx=max_tokens,
+                 n_gpu_layers=gpu_layers):
+
+        self.last_n_tokens = last_n_tokens
+        self.seed = seed
+        self.n_threads = n_threads
+        self.n_batch = n_batch
+        self.n_ctx = n_ctx
+        self.n_gpu_layers = n_gpu_layers
+        # self.stop: list[str] = field(default_factory=lambda: [stop_string])
+
+    def update_gpu(self, new_value):
+        self.n_gpu_layers = new_value
+
+class CtransInitConfig_cpu(CtransInitConfig_gpu):
+    def __init__(self):
+        super().__init__()
+        self.n_gpu_layers = 0
+
+gpu_config = CtransInitConfig_gpu()
+cpu_config = CtransInitConfig_cpu()
+
+
+class CtransGenGenerationConfig:
+    def __init__(self, temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repeat_penalty=repetition_penalty,
+                 seed=seed,
+                 stream=stream,
+                 max_tokens=max_new_tokens
+                 ):
+        self.temperature = temperature
+        self.top_k = top_k
+        self.top_p = top_p
+        self.repeat_penalty = repeat_penalty
+        self.seed = seed
+        self.max_tokens=max_tokens
+        self.stream = stream
+
+    def update_temp(self, new_value):
+        self.temperature = new_value
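
For context on how these objects are consumed, here is a hedged sketch of feeding gpu_config and CtransGenGenerationConfig into llama-cpp-python. The GGUF repo, filename and prompt are placeholders, not necessarily what the app actually downloads or sends.

# Illustrative only: repo_id, filename and prompt below are placeholders.
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from chatfuncs.model_load import gpu_config, CtransGenGenerationConfig

model_path = hf_hub_download(repo_id="Qwen/Qwen2-0.5B-Instruct-GGUF",      # placeholder repo
                             filename="qwen2-0_5b-instruct-q5_k_m.gguf")   # placeholder file

# Initialisation parameters come from the CtransInitConfig_gpu instance above.
llm = Llama(model_path=model_path,
            n_ctx=gpu_config.n_ctx,
            n_batch=gpu_config.n_batch,
            n_threads=gpu_config.n_threads,
            n_gpu_layers=gpu_config.n_gpu_layers,
            seed=gpu_config.seed)

# Generation parameters come from CtransGenGenerationConfig.
gen_config = CtransGenGenerationConfig()
response = llm("Question: What is this repository for?\nAnswer:",
               temperature=gen_config.temperature,
               top_k=gen_config.top_k,
               top_p=gen_config.top_p,
               repeat_penalty=gen_config.repeat_penalty,
               max_tokens=gen_config.max_tokens,
               stream=False)
print(response["choices"][0]["text"])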
chatfuncs/prompts.py
CHANGED
@@ -23,8 +23,7 @@ QUESTION - {question}
 """
 
 
-instruction_prompt_template_orca = """
-### System:
+instruction_prompt_template_orca = """### System:
 You are an AI assistant that follows instruction extremely well. Help as much as you can.
 ### User:
 Answer the QUESTION with a short response using information from the following CONTENT.
@@ -33,8 +32,7 @@ CONTENT: {summaries}
 
 ### Response:"""
 
-instruction_prompt_template_orca_quote = """
-### System:
+instruction_prompt_template_orca_quote = """### System:
 You are an AI assistant that follows instruction extremely well. Help as much as you can.
 ### User:
 Quote text from the CONTENT to answer the QUESTION below.
@@ -73,4 +71,9 @@ Answer the QUESTION using information from the following CONTENT. Respond with s
 CONTENT: {summaries}
 QUESTION: {question}\n
 Answer:<|im_end|>
-<|im_start|>assistant\n"""
+<|im_start|>assistant\n"""
+
+instruction_prompt_gemma = """Answer the QUESTION using information from the following CONTENT. Respond with short answers that directly answer the question.
+CONTENT: {summaries}
+QUESTION: {question}
+assistant:"""
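
A minimal sketch of filling the new Gemma template before passing it to a model; the context and question strings are invented for illustration.

# Invented example context/question, just to show the template fields being filled.
from chatfuncs.prompts import instruction_prompt_gemma

summaries = "1. The service is open Monday to Friday, 9am to 5pm."
question = "When is the service open?"

prompt = instruction_prompt_gemma.format(summaries=summaries, question=question)
print(prompt)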
faiss_embedding/faiss_embedding.zip
CHANGED
Binary files a/faiss_embedding/faiss_embedding.zip and b/faiss_embedding/faiss_embedding.zip differ
requirements.txt
CHANGED
@@ -4,7 +4,7 @@ langchain-community==0.3.22
 beautifulsoup4==4.13.4
 google-generativeai==0.8.5
 pandas==2.2.3
-transformers==4.
+transformers==4.51.3
 # For Windows https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.2/llama_cpp_python-0.3.2-cp311-cp311-win_amd64.whl -C cmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS"
 llama-cpp-python==0.3.2 --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu
 #-C cmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS"
@@ -22,3 +22,4 @@ PyStemmer==2.2.0.3
 scipy==1.15.2
 numpy==1.26.4
 boto3==1.38.0
+python-dotenv==1.1.0
requirements_gpu.txt
CHANGED
@@ -20,4 +20,5 @@ bm25s==0.2.12
 PyStemmer==2.2.0.3
 scipy==1.15.2
 numpy==1.26.4
-boto3==1.38.0
+boto3==1.38.0
+python-dotenv==1.1.0
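
python-dotenv is new in both requirements files; below is a minimal sketch of the kind of .env loading it enables. The variable names and file path are illustrative, not necessarily those read by chatfuncs/config.py.

# Illustrative .env loading with python-dotenv; names and path are placeholders.
import os
from dotenv import load_dotenv

load_dotenv("config/app_config.env")  # hypothetical env file path

OUTPUT_FOLDER = os.environ.get("OUTPUT_FOLDER", "output/")
INPUT_FOLDER = os.environ.get("INPUT_FOLDER", "input/")
print("Output folder:", OUTPUT_FOLDER, "| Input folder:", INPUT_FOLDER)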