Sean Pedrick-Case committed
Commit a1ae0af · unverified · 2 Parent(s): 0c818aa 03afd76

Merge pull request #1 from seanpedrick-case/dev

Added Gemini and AWS Bedrock compatibility. Gemma model. Now document redaction QA.
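For context: the AWS Bedrock side of this change goes through boto3's bedrock-runtime client (see call_aws_claude in chatfuncs/chatfuncs.py below), while the Gemini path takes an API key entered in the UI. A minimal standalone sketch of that Bedrock call pattern, with the region, model ID and token limit as illustrative assumptions rather than values taken from this commit:

```python
import json
import boto3

# Illustrative region/model; the app reads AWS_DEFAULT_REGION and takes the
# model ID from a dropdown rather than hard-coding them.
bedrock_runtime = boto3.client("bedrock-runtime", region_name="eu-west-2")

def ask_claude(prompt: str, system_prompt: str,
               model_id: str = "anthropic.claude-3-haiku-20240307-v1:0") -> str:
    # Claude-on-Bedrock "messages"-style request body
    body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 512,
        "temperature": 0.5,
        "system": system_prompt,
        "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
    })
    response = bedrock_runtime.invoke_model(
        body=body, modelId=model_id,
        accept="application/json", contentType="application/json",
    )
    return json.loads(response["body"].read())["content"][0]["text"]
```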

.dockerignore ADDED
@@ -0,0 +1,14 @@
+ *.pyc
+ *.ipynb
+ *.pdf
+ *.spec
+ *.toc
+ *.csv
+ *.bin
+ bootstrapper.py
+ build/*
+ dist/*
+ test/*
+ config/*
+ output/*
+ input/*
.gitattributes ADDED
@@ -0,0 +1 @@
+ *.zip filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -8,4 +8,7 @@
  bootstrapper.py
  build/*
  dist/*
- test/*
+ test/*
+ config/*
+ output/*
+ input/*
app.py CHANGED
@@ -1,95 +1,122 @@
1
- # Load in packages
2
-
3
  import os
4
- import socket
5
-
6
  from typing import Type
7
- from langchain_huggingface.embeddings import HuggingFaceEmbeddings#, HuggingFaceInstructEmbeddings
8
  from langchain_community.vectorstores import FAISS
9
  import gradio as gr
10
  import pandas as pd
11
-
12
- from transformers import AutoTokenizer
13
- import torch
14
-
15
  from llama_cpp import Llama
16
  from huggingface_hub import hf_hub_download
 
 
 
17
  from chatfuncs.ingest import embed_faiss_save_to_zip
18
- from chatfuncs.helper_functions import get_or_create_env_var
19
 
20
- from chatfuncs.helper_functions import ensure_output_folder_exists, get_connection_params, output_folder, get_or_create_env_var, reveal_feedback_buttons, wipe_logs
21
  from chatfuncs.aws_functions import upload_file_to_s3
22
- #from chatfuncs.llm_api_call import llm_query
23
  from chatfuncs.auth import authenticate_user
 
 
 
 
24
 
25
  PandasDataFrame = Type[pd.DataFrame]
26
 
27
  from datetime import datetime
28
  today_rev = datetime.now().strftime("%Y%m%d")
29
 
30
- ensure_output_folder_exists()
 
 
 
31
 
32
- host_name = socket.gethostname()
33
-
34
- access_logs_data_folder = 'logs/' + today_rev + '/' + host_name + '/'
35
- feedback_data_folder = 'feedback/' + today_rev + '/' + host_name + '/'
36
- usage_data_folder = 'usage/' + today_rev + '/' + host_name + '/'
37
 
38
  # Disable cuda devices if necessary
39
  #os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
40
 
41
- #from chatfuncs.chatfuncs import *
42
- import chatfuncs.ingest as ing
43
 
44
  ###
45
  # Load preset embeddings, vectorstore, and model
46
  ###
47
 
48
- embeddings_name = "BAAI/bge-base-en-v1.5" #"mixedbread-ai/mxbai-embed-xsmall-v1"
49
 
50
- def load_embeddings(embeddings_name = embeddings_name):
51
 
52
- embeddings_func = HuggingFaceEmbeddings(model_name=embeddings_name)
53
 
54
- global embeddings
55
 
56
- embeddings = embeddings_func
57
 
58
- return embeddings
59
 
60
- def get_faiss_store(faiss_vstore_folder,embeddings):
61
- import zipfile
62
  with zipfile.ZipFile(faiss_vstore_folder + '/' + faiss_vstore_folder + '.zip', 'r') as zip_ref:
63
  zip_ref.extractall(faiss_vstore_folder)
64
 
65
- faiss_vstore = FAISS.load_local(folder_path=faiss_vstore_folder, embeddings=embeddings, allow_dangerous_deserialization=True)
66
  os.remove(faiss_vstore_folder + "/index.faiss")
67
  os.remove(faiss_vstore_folder + "/index.pkl")
68
 
69
- global vectorstore
70
 
71
- vectorstore = faiss_vstore
72
 
73
- return vectorstore
74
 
75
- import chatfuncs.chatfuncs as chatf
 
 
76
 
77
- chatf.embeddings = load_embeddings(embeddings_name)
78
- chatf.vectorstore = get_faiss_store(faiss_vstore_folder="faiss_embedding",embeddings=globals()["embeddings"])
79
 
 
80
 
81
- def load_model(model_type, gpu_layers, gpu_config=None, cpu_config=None, torch_device=None):
82
- print("Loading model")
 
83
 
84
- # Default values inside the function
85
- if gpu_config is None:
86
- gpu_config = chatf.gpu_config
87
- if cpu_config is None:
88
- cpu_config = chatf.cpu_config
89
- if torch_device is None:
90
- torch_device = chatf.torch_device
91
 
92
- if model_type == "Phi 3.5 Mini (larger, slow)":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  if torch_device == "cuda":
94
  gpu_config.update_gpu(gpu_layers)
95
  print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU.")
@@ -99,86 +126,46 @@ def load_model(model_type, gpu_layers, gpu_config=None, cpu_config=None, torch_d
99
 
100
  print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU.")
101
 
102
- print(vars(gpu_config))
103
- print(vars(cpu_config))
104
-
105
  try:
106
  model = Llama(
107
  model_path=hf_hub_download(
108
- repo_id=os.environ.get("REPO_ID", "QuantFactory/Phi-3.5-mini-instruct-GGUF"),# "QuantFactory/Phi-3-mini-128k-instruct-GGUF"), # "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF-v2"), #"microsoft/Phi-3-mini-4k-instruct-gguf"),#"TheBloke/Mistral-7B-OpenOrca-GGUF"),
109
- filename=os.environ.get("MODEL_FILE", "Phi-3.5-mini-instruct.Q4_K_M.gguf") #"Phi-3-mini-128k-instruct.Q4_K_M.gguf") #"Meta-Llama-3-8B-Instruct-v2.Q6_K.gguf") #"Phi-3-mini-4k-instruct-q4.gguf")#"mistral-7b-openorca.Q4_K_M.gguf"),
110
  ),
111
  **vars(gpu_config) # change n_gpu_layers if you have more or less VRAM
112
  )
113
 
114
  except Exception as e:
115
- print("GPU load failed")
116
- print(e)
117
  model = Llama(
118
  model_path=hf_hub_download(
119
- repo_id=os.environ.get("REPO_ID", "QuantFactory/Phi-3.5-mini-instruct-GGUF"), #"QuantFactory/Phi-3-mini-128k-instruct-GGUF"), #, "microsoft/Phi-3-mini-4k-instruct-gguf"),#"QuantFactory/Meta-Llama-3-8B-Instruct-GGUF-v2"), #"microsoft/Phi-3-mini-4k-instruct-gguf"),#"TheBloke/Mistral-7B-OpenOrca-GGUF"),
120
- filename=os.environ.get("MODEL_FILE", "Phi-3.5-mini-instruct.Q4_K_M.gguf"), # "Phi-3-mini-128k-instruct.Q4_K_M.gguf") # , #"Meta-Llama-3-8B-Instruct-v2.Q6_K.gguf") #"Phi-3-mini-4k-instruct-q4.gguf"),#"mistral-7b-openorca.Q4_K_M.gguf"),
121
  ),
122
  **vars(cpu_config)
123
  )
124
 
125
  tokenizer = []
126
 
127
- if model_type == "Qwen 2 0.5B (small, fast)":
128
  # Huggingface chat model
129
- hf_checkpoint = 'Qwen/Qwen2-0.5B-Instruct'# 'declare-lab/flan-alpaca-large'#'declare-lab/flan-alpaca-base' # # # 'Qwen/Qwen1.5-0.5B-Chat' #
130
 
131
- def create_hf_model(model_name):
132
 
133
- from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
134
-
135
- if torch_device == "cuda":
136
- if "flan" in model_name:
137
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
138
- else:
139
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
140
- else:
141
- if "flan" in model_name:
142
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, torch_dtype=torch.float16)
143
- else:
144
- model = AutoModelForCausalLM.from_pretrained(model_name)#, trust_remote_code=True)#, torch_dtype=torch.float16)
145
-
146
- tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = chatf.context_length)
147
-
148
- return model, tokenizer, model_type
149
-
150
- model, tokenizer, model_type = create_hf_model(model_name = hf_checkpoint)
151
 
152
- chatf.model = model
153
  chatf.tokenizer = tokenizer
154
  chatf.model_type = model_type
155
 
156
  load_confirmation = "Finished loading model: " + model_type
157
 
158
  print(load_confirmation)
159
- return model_type, load_confirmation, model_type
160
-
161
- # Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
162
- #model_type = "Phi 3.5 Mini (larger, slow)"
163
- #load_model(model_type, chatf.gpu_layers, chatf.gpu_config, chatf.cpu_config, chatf.torch_device)
164
-
165
- model_type = "Qwen 2 0.5B (small, fast)"
166
- load_model(model_type, 0, chatf.gpu_config, chatf.cpu_config, chatf.torch_device)
167
 
168
- def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
169
-
170
- print(f"> Total split documents: {len(docs_out)}")
171
-
172
- print(docs_out)
173
-
174
- vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings)
175
-
176
- chatf.vectorstore = vectorstore_func
177
-
178
- out_message = "Document processing complete"
179
-
180
- return out_message, vectorstore_func
181
- # Gradio chat
182
 
183
 
184
  ###
@@ -188,24 +175,40 @@ def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
188
  app = gr.Blocks(theme = gr.themes.Base(), fill_width=True)#css=".gradio-container {background-color: black}")
189
 
190
  with app:
 
 
 
 
 
 
 
191
  ingest_text = gr.State()
192
  ingest_metadata = gr.State()
193
  ingest_docs = gr.State()
194
 
195
  model_type_state = gr.State(model_type)
196
- embeddings_state = gr.State(chatf.embeddings)#globals()["embeddings"])
197
- vectorstore_state = gr.State(chatf.vectorstore)#globals()["vectorstore"])
 
198
 
 
 
 
 
 
 
199
  relevant_query_state = gr.Checkbox(value=True, visible=False)
200
 
201
- model_state = gr.State() # chatf.model (gives error)
 
202
  tokenizer_state = gr.State() # chatf.tokenizer (gives error)
203
 
204
  chat_history_state = gr.State()
205
  instruction_prompt_out = gr.State()
206
 
207
  session_hash_state = gr.State()
208
- s3_output_folder_state = gr.State()
 
209
 
210
  session_hash_textbox = gr.Textbox(value="", visible=False)
211
  s3_logs_output_textbox = gr.Textbox(label="S3 logs", visible=False)
@@ -219,21 +222,18 @@ with app:
219
 
220
  gr.Markdown("<h1><center>Lightweight PDF / web page QA bot</center></h1>")
221
 
222
- gr.Markdown("Chat with PDF, web page or (new) csv/Excel documents. The default is a small model (Qwen 2 0.5B), that can only answer specific questions that are answered in the text. It cannot give overall impressions of, or summarise the document. The alternative (Phi 3.5 Mini (larger, slow)), can reason a little better, but is much slower (See Advanced tab).\n\nBy default the Lambeth Borough Plan '[Lambeth 2030 : Our Future, Our Lambeth](https://www.lambeth.gov.uk/better-fairer-lambeth/projects/lambeth-2030-our-future-our-lambeth)' is loaded. If you want to talk about another document or web page, please select from the second tab. If switching topic, please click the 'Clear chat' button.\n\nCaution: This is a public app. Please ensure that the document you upload is not sensitive is any way as other users may see it! Also, please note that LLM chatbots may give incomplete or incorrect information, so please use with care.")
223
-
224
- with gr.Accordion(label="Use Gemini or AWS Claude model", open=False, visible=False):
225
- api_model_choice = gr.Dropdown(value = "None", choices = ["gemini-1.5-flash-002", "gemini-1.5-pro-002", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", "None"], label="LLM model to use", multiselect=False, interactive=True, visible=False)
226
- in_api_key = gr.Textbox(value = "", label="Enter Gemini API key (only if using Google API models)", lines=1, type="password",interactive=True, visible=False)
227
 
228
  with gr.Row():
229
- current_source = gr.Textbox(label="Current data source(s)", value="Lambeth_2030-Our_Future_Our_Lambeth.pdf", scale = 10)
230
  current_model = gr.Textbox(label="Current model", value=model_type, scale = 3)
231
 
232
  with gr.Tab("Chatbot"):
233
 
234
  with gr.Row():
235
  #chat_height = 500
236
- chatbot = gr.Chatbot(avatar_images=('user.jfif', 'bot.jpg'), scale = 1, resizable=True, show_copy_all_button=True, show_copy_button=True, show_share_button=True, type='tuples') # , height=chat_height
237
  with gr.Accordion("Open this tab to see the source paragraphs used to generate the answer", open = True):
238
  sources = gr.HTML(value = "Source paragraphs with the most relevant text will appear here") # , height=chat_height
239
 
@@ -245,17 +245,12 @@ with app:
245
  with gr.Row():
246
  submit = gr.Button(value="Send message", variant="primary", scale = 4)
247
  clear = gr.Button(value="Clear chat", variant="secondary", scale=1)
248
- stop = gr.Button(value="Stop generating", variant="secondary", scale=1)
249
-
250
- examples_set = gr.Radio(label="Examples for the Lambeth Borough Plan",
251
- #value = "What were the five pillars of the previous borough plan?",
252
- choices=["What were the five pillars of the previous borough plan?",
253
- "What is the vision statement for Lambeth?",
254
- "What are the commitments for Lambeth?",
255
- "What are the 2030 outcomes for Lambeth?"])
256
-
257
- current_topic = gr.Textbox(label="Feature currently disabled - Keywords related to current conversation topic.", placeholder="Keywords related to the conversation topic will appear here", visible=False)
258
 
 
 
 
 
259
 
260
  with gr.Tab("Load in a different file to chat with"):
261
  with gr.Accordion("PDF file", open = False):
@@ -281,75 +276,91 @@ with app:
281
  out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
282
  temp_slide = gr.Slider(minimum=0.1, value = 0.5, maximum=1, step=0.1, label="Choose temperature setting for response generation.")
283
  with gr.Row():
284
- model_choice = gr.Radio(label="Choose a chat model", value="Qwen 2 0.5B (small, fast)", choices = ["Qwen 2 0.5B (small, fast)", "Phi 3.5 Mini (larger, slow)"])
 
285
  change_model_button = gr.Button(value="Load model", scale=0)
286
  with gr.Accordion("Choose number of model layers to send to GPU (WARNING: please don't modify unless you are sure you have a GPU).", open = False):
287
  gpu_layer_choice = gr.Slider(label="Choose number of model layers to send to GPU.", value=0, minimum=0, maximum=100, step = 1, visible=True)
288
 
289
- load_text = gr.Text(label="Load status")
290
-
291
 
292
  gr.HTML(
293
- "<center>This app is based on the models Qwen 2 0.5B and Phi 3.5 Mini. It powered by Gradio, Transformers, and Llama.cpp.</a></center>"
294
  )
295
 
296
  examples_set.change(fn=chatf.update_message, inputs=[examples_set], outputs=[message])
297
 
298
- change_model_button.click(fn=chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
299
- success(fn=load_model, inputs=[model_choice, gpu_layer_choice], outputs = [model_type_state, load_text, current_model]).\
300
- success(lambda: chatf.restore_interactivity(), None, [message], queue=False).\
301
- success(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic]).\
302
- success(lambda: None, None, chatbot, queue=False)
303
 
304
  # Load in a pdf
305
  load_pdf_click = load_pdf.click(ing.parse_file, inputs=[in_pdf], outputs=[ingest_text, current_source]).\
306
  success(ing.text_to_docs, inputs=[ingest_text], outputs=[ingest_docs]).\
307
- success(embed_faiss_save_to_zip, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
308
  success(chatf.hide_block, outputs = [examples_set])
309
 
310
  # Load in a webpage
311
  load_web_click = load_web.click(ing.parse_html, inputs=[in_web, in_div], outputs=[ingest_text, ingest_metadata, current_source]).\
312
  success(ing.html_text_to_docs, inputs=[ingest_text, ingest_metadata], outputs=[ingest_docs]).\
313
- success(embed_faiss_save_to_zip, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
314
  success(chatf.hide_block, outputs = [examples_set])
315
 
316
  # Load in a csv/excel file
317
  load_csv_click = load_csv.click(ing.parse_csv_or_excel, inputs=[in_csv, in_text_column], outputs=[ingest_text, current_source]).\
318
  success(ing.csv_excel_text_to_docs, inputs=[ingest_text, in_text_column], outputs=[ingest_docs]).\
319
- success(embed_faiss_save_to_zip, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
320
  success(chatf.hide_block, outputs = [examples_set])
 
321
 
322
- # Load in a webpage
323
-
324
- # Click/enter to send message action
325
- response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False, api_name="retrieval").\
326
- success(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
327
- success(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state], outputs=chatbot)
328
- response_click.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
329
- success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
330
- success(lambda: chatf.restore_interactivity(), None, [message], queue=False)
331
-
332
- response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False).\
333
- success(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
334
- success(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state], chatbot)
335
- response_enter.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
336
- success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
337
- success(lambda: chatf.restore_interactivity(), None, [message], queue=False)
338
-
339
- # Stop box
340
- stop.click(fn=None, inputs=None, outputs=None, cancels=[response_click, response_enter])
341
-
342
- # Clear box
343
- clear.click(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic])
344
- clear.click(lambda: None, None, chatbot, queue=False)
345
 
346
- # Thumbs up or thumbs down voting function
347
- chatbot.like(chatf.vote, [chat_history_state, instruction_prompt_out, model_type_state], None)
 
 
 
348
 
349
  ###
350
  # LOGGING AND ON APP LOAD FUNCTIONS
351
- ###
352
- app.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state, session_hash_textbox])
 
 
 
353
 
354
  # Log usernames and times of access to file (to know who is using the app when running on AWS)
355
  access_callback = gr.CSVLogger()
@@ -358,12 +369,8 @@ with app:
358
  session_hash_textbox.change(lambda *args: access_callback.flag(list(args)), [session_hash_textbox], None, preprocess=False).\
359
  success(fn = upload_file_to_s3, inputs=[access_logs_state, access_s3_logs_loc_state], outputs=[s3_logs_output_textbox])
360
 
361
- # Launch the Gradio app
362
- COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
363
- print(f'The value of COGNITO_AUTH is {COGNITO_AUTH}')
364
-
365
  if __name__ == "__main__":
366
- if os.environ['COGNITO_AUTH'] == "1":
367
- app.queue().launch(show_error=True, auth=authenticate_user, max_file_size='50mb')
368
  else:
369
- app.queue().launch(show_error=True, inbrowser=True, max_file_size='50mb')
 
 
 
1
  import os
 
 
2
  from typing import Type
3
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
  import gradio as gr
6
  import pandas as pd
7
+ from torch import float16
 
 
 
8
  from llama_cpp import Llama
9
  from huggingface_hub import hf_hub_download
10
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
11
+ import zipfile
12
+
13
  from chatfuncs.ingest import embed_faiss_save_to_zip
 
14
 
15
+ from chatfuncs.helper_functions import get_connection_params, reveal_feedback_buttons, wipe_logs
16
  from chatfuncs.aws_functions import upload_file_to_s3
 
17
  from chatfuncs.auth import authenticate_user
18
+ from chatfuncs.config import FEEDBACK_LOGS_FOLDER, ACCESS_LOGS_FOLDER, USAGE_LOGS_FOLDER, HOST_NAME, COGNITO_AUTH, INPUT_FOLDER, OUTPUT_FOLDER, MAX_QUEUE_SIZE, DEFAULT_CONCURRENCY_LIMIT, MAX_FILE_SIZE, GRADIO_SERVER_PORT, ROOT_PATH, DEFAULT_EMBEDDINGS_LOCATION, EMBEDDINGS_MODEL_NAME, DEFAULT_DATA_SOURCE, HF_TOKEN, LARGE_MODEL_REPO_ID, LARGE_MODEL_GGUF_FILE, LARGE_MODEL_NAME, SMALL_MODEL_NAME, SMALL_MODEL_REPO_ID, DEFAULT_DATA_SOURCE_NAME, DEFAULT_EXAMPLES, DEFAULT_MODEL_CHOICES
19
+ from chatfuncs.model_load import torch_device, gpu_config, cpu_config, context_length
20
+ import chatfuncs.chatfuncs as chatf
21
+ import chatfuncs.ingest as ing
22
 
23
  PandasDataFrame = Type[pd.DataFrame]
24
 
25
  from datetime import datetime
26
  today_rev = datetime.now().strftime("%Y%m%d")
27
 
28
+ host_name = HOST_NAME
29
+ access_logs_data_folder = ACCESS_LOGS_FOLDER
30
+ feedback_data_folder = FEEDBACK_LOGS_FOLDER
31
+ usage_data_folder = USAGE_LOGS_FOLDER
32
 
33
+ if isinstance(DEFAULT_EXAMPLES, str): default_examples_set = eval(DEFAULT_EXAMPLES)
34
+ if isinstance(DEFAULT_MODEL_CHOICES, str): default_model_choices = eval(DEFAULT_MODEL_CHOICES)
 
 
 
35
 
36
  # Disable cuda devices if necessary
37
  #os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
38
 
 
 
39
 
40
  ###
41
  # Load preset embeddings, vectorstore, and model
42
  ###
43
 
44
+ def load_embeddings_model(embeddings_model = EMBEDDINGS_MODEL_NAME):
45
 
46
+ embeddings_func = HuggingFaceEmbeddings(model_name=embeddings_model)
47
 
48
+ #global embeddings
49
 
50
+ #embeddings = embeddings_func
51
 
52
+ return embeddings_func
53
 
54
+ def get_faiss_store(faiss_vstore_folder:str, embeddings_model:object):
55
 
 
 
56
  with zipfile.ZipFile(faiss_vstore_folder + '/' + faiss_vstore_folder + '.zip', 'r') as zip_ref:
57
  zip_ref.extractall(faiss_vstore_folder)
58
 
59
+ faiss_vstore = FAISS.load_local(folder_path=faiss_vstore_folder, embeddings=embeddings_model, allow_dangerous_deserialization=True)
60
  os.remove(faiss_vstore_folder + "/index.faiss")
61
  os.remove(faiss_vstore_folder + "/index.pkl")
62
 
63
+ #global vectorstore
64
 
65
+ #vectorstore = faiss_vstore
66
 
67
+ return faiss_vstore #vectorstore
68
 
69
+ # Load in default embeddings and embeddings model name
70
+ embeddings_model = load_embeddings_model(EMBEDDINGS_MODEL_NAME)
71
+ vectorstore = get_faiss_store(faiss_vstore_folder=DEFAULT_EMBEDDINGS_LOCATION,embeddings_model=embeddings_model)#globals()["embeddings"])
72
 
73
+ chatf.embeddings = embeddings_model
74
+ chatf.vectorstore = vectorstore
75
 
76
+ def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings_model=embeddings_model):
77
 
78
+ print(f"> Total split documents: {len(docs_out)}")
79
+
80
+ print(docs_out)
81
 
82
+ vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings_model)
 
 
 
 
 
 
83
 
84
+ chatf.vectorstore = vectorstore_func
85
+
86
+ out_message = "Document processing complete"
87
+
88
+ return out_message, vectorstore_func
89
+
90
+
91
+ def create_hf_model(model_name:str, hf_token=HF_TOKEN):
92
+ if torch_device == "cuda":
93
+ if "flan" in model_name:
94
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
95
+ else:
96
+ if hf_token:
97
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", token=hf_token) # , torch_dtype=float16
98
+ else:
99
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") # , torch_dtype=float16
100
+ else:
101
+ if "flan" in model_name:
102
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, torch_dtype=torch.float16)
103
+ else:
104
+ if hf_token:
105
+ model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token) # , torch_dtype=float16
106
+ else:
107
+ model = AutoModelForCausalLM.from_pretrained(model_name) # , torch_dtype=float16
108
+
109
+ if hf_token:
110
+ tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = context_length, token=hf_token)
111
+ else:
112
+ tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = context_length)
113
+
114
+ return model, tokenizer
115
+
116
+ def load_model(model_type:str, gpu_layers:int, gpu_config:dict=gpu_config, cpu_config:dict=cpu_config, torch_device:str=torch_device):
117
+ print("Loading model")
118
+
119
+ if model_type == LARGE_MODEL_NAME:
120
  if torch_device == "cuda":
121
  gpu_config.update_gpu(gpu_layers)
122
  print("Loading with", gpu_config.n_gpu_layers, "model layers sent to GPU.")
 
126
 
127
  print("Loading with", cpu_config.n_gpu_layers, "model layers sent to GPU.")
128
 
 
 
 
129
  try:
130
  model = Llama(
131
  model_path=hf_hub_download(
132
+ repo_id=LARGE_MODEL_REPO_ID,
133
+ filename=LARGE_MODEL_GGUF_FILE
134
  ),
135
  **vars(gpu_config) # change n_gpu_layers if you have more or less VRAM
136
  )
137
 
138
  except Exception as e:
139
+ print("GPU load failed", e, "loading CPU version instead")
 
140
  model = Llama(
141
  model_path=hf_hub_download(
142
+ repo_id=LARGE_MODEL_REPO_ID,
143
+ filename=LARGE_MODEL_GGUF_FILE
144
  ),
145
  **vars(cpu_config)
146
  )
147
 
148
  tokenizer = []
149
 
150
+ if model_type == SMALL_MODEL_NAME:
151
  # Huggingface chat model
152
+ hf_checkpoint = SMALL_MODEL_REPO_ID# 'declare-lab/flan-alpaca-large'#'declare-lab/flan-alpaca-base' # # # 'Qwen/Qwen1.5-0.5B-Chat' #
153
 
154
+ model, tokenizer = create_hf_model(model_name = hf_checkpoint)
155
 
156
+ else:
157
+ model = model_type
158
+ tokenizer = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ chatf.model_object = model
161
  chatf.tokenizer = tokenizer
162
  chatf.model_type = model_type
163
 
164
  load_confirmation = "Finished loading model: " + model_type
165
 
166
  print(load_confirmation)
 
 
 
 
 
 
 
 
167
 
168
+ return model_type, load_confirmation, model_type#model, tokenizer, model_type
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
 
171
  ###
 
175
  app = gr.Blocks(theme = gr.themes.Base(), fill_width=True)#css=".gradio-container {background-color: black}")
176
 
177
  with app:
178
+ model_type = SMALL_MODEL_NAME
179
+ load_model(model_type, 0, gpu_config, cpu_config, torch_device) # chatf.model_object, chatf.tokenizer, chatf.model_type =
180
+
181
+ # Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
182
+ #model_type = "Phi 3.5 Mini (larger, slow)"
183
+ #load_model(model_type, gpu_layers, gpu_config, cpu_config, torch_device)
184
+
185
  ingest_text = gr.State()
186
  ingest_metadata = gr.State()
187
  ingest_docs = gr.State()
188
 
189
  model_type_state = gr.State(model_type)
190
+ gpu_config_state = gr.State(gpu_config)
191
+ cpu_config_state = gr.State(cpu_config)
192
+ torch_device_state = gr.State(torch_device)
193
 
194
+ # Embeddings related vars
195
+ embeddings_model_object_state = gr.State(embeddings_model)#globals()["embeddings"])
196
+ vectorstore_state = gr.State(vectorstore)#globals()["vectorstore"])
197
+ default_embeddings_store_text = gr.Textbox(value=DEFAULT_EMBEDDINGS_LOCATION, visible=False)
198
+
199
+ # Is the query relevant to the sources provided?
200
  relevant_query_state = gr.Checkbox(value=True, visible=False)
201
 
202
+ # Storing model objects in state doesn't seem to work, so we have to load in different models in roundabout ways
203
+ model_state = gr.State() # chatf.model_object (gives error)
204
  tokenizer_state = gr.State() # chatf.tokenizer (gives error)
205
 
206
  chat_history_state = gr.State()
207
  instruction_prompt_out = gr.State()
208
 
209
  session_hash_state = gr.State()
210
+ output_folder_textbox = gr.Textbox(value=OUTPUT_FOLDER, visible=False)
211
+ input_folder_textbox = gr.Textbox(value=INPUT_FOLDER, visible=False)
212
 
213
  session_hash_textbox = gr.Textbox(value="", visible=False)
214
  s3_logs_output_textbox = gr.Textbox(label="S3 logs", visible=False)
 
222
 
223
  gr.Markdown("<h1><center>Lightweight PDF / web page QA bot</center></h1>")
224
 
225
+ gr.Markdown(f"""Chat with PDF, web page or (new) csv/Excel documents. The default is a small model ({SMALL_MODEL_NAME}), that can only answer specific questions that are answered in the text. It cannot give overall impressions of, or summarise the document. The alternative ({LARGE_MODEL_NAME}), can reason a little better, but is much slower (See Advanced settings tab).\n\nBy default '[{DEFAULT_DATA_SOURCE_NAME}]({DEFAULT_DATA_SOURCE})' is loaded.If you want to talk about another document or web page, please select from the second tab. If switching topic, please click the 'Clear chat' button.\n\nCaution: This is a public app. Please ensure that the document you upload is not sensitive is any way as other users may see it! Also, please note that LLM chatbots may give incomplete or incorrect information, so please use with care.""")
226
+
 
 
 
227
 
228
  with gr.Row():
229
+ current_source = gr.Textbox(label="Current data source(s)", value=DEFAULT_DATA_SOURCE, scale = 10)
230
  current_model = gr.Textbox(label="Current model", value=model_type, scale = 3)
231
 
232
  with gr.Tab("Chatbot"):
233
 
234
  with gr.Row():
235
  #chat_height = 500
236
+ chatbot = gr.Chatbot(value=None, avatar_images=('user.jfif', 'bot.jpg'), scale = 1, resizable=True, show_copy_all_button=True, show_copy_button=True, show_share_button=True, type='messages') # , height=chat_height
237
  with gr.Accordion("Open this tab to see the source paragraphs used to generate the answer", open = True):
238
  sources = gr.HTML(value = "Source paragraphs with the most relevant text will appear here") # , height=chat_height
239
 
 
245
  with gr.Row():
246
  submit = gr.Button(value="Send message", variant="primary", scale = 4)
247
  clear = gr.Button(value="Clear chat", variant="secondary", scale=1)
248
+ stop = gr.Button(value="Stop generating", variant="stop", scale=1)
 
 
 
 
 
 
 
 
 
249
 
250
+ examples_set = gr.Radio(label="Example questions",
251
+ choices=default_examples_set)
252
+
253
+ current_topic = gr.Textbox(label="Feature currently disabled - Keywords related to current conversation topic.", placeholder="Keywords related to the conversation topic will appear here", visible=False)
254
 
255
  with gr.Tab("Load in a different file to chat with"):
256
  with gr.Accordion("PDF file", open = False):
 
276
  out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
277
  temp_slide = gr.Slider(minimum=0.1, value = 0.5, maximum=1, step=0.1, label="Choose temperature setting for response generation.")
278
  with gr.Row():
279
+ model_choice = gr.Radio(label="Choose a chat model", value=SMALL_MODEL_NAME, choices = default_model_choices)
280
+ in_api_key = gr.Textbox(value = "", label="Enter Gemini API key (only if using Google API models)", lines=1, type="password",interactive=True, visible=True)
281
  change_model_button = gr.Button(value="Load model", scale=0)
282
  with gr.Accordion("Choose number of model layers to send to GPU (WARNING: please don't modify unless you are sure you have a GPU).", open = False):
283
  gpu_layer_choice = gr.Slider(label="Choose number of model layers to send to GPU.", value=0, minimum=0, maximum=100, step = 1, visible=True)
284
 
285
+ load_text = gr.Text(label="Load status")
 
286
 
287
  gr.HTML(
288
+ "<center>This app is powered by Gradio, Transformers, and Llama.cpp.</center>"
289
  )
290
 
291
  examples_set.change(fn=chatf.update_message, inputs=[examples_set], outputs=[message])
292
 
293
+ ###
294
+ # CHAT PAGE
295
+ ###
296
+
297
+ # Click to send message
298
+ response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_model_object_state, model_type_state, out_passages, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False, api_name="retrieval").\
299
+ success(chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
300
+ success(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state, chat_history_state, in_api_key], outputs=chatbot)
301
+ response_click.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
302
+ success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
303
+ success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False)
304
+
305
+ # Press enter to send message
306
+ response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_model_object_state, model_type_state, out_passages, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False).\
307
+ success(chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
308
+ success(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state, chat_history_state, in_api_key], chatbot)
309
+ response_enter.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
310
+ success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
311
+ success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False)
312
+
313
+ # Stop box
314
+ stop.click(fn=None, inputs=None, outputs=None, cancels=[response_click, response_enter])
315
+
316
+ # Clear box
317
+ clear.click(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic])
318
+ clear.click(lambda: None, None, chatbot, queue=False)
319
+
320
+ # Thumbs up or thumbs down voting function
321
+ chatbot.like(chatf.vote, [chat_history_state, instruction_prompt_out, model_type_state], None)
322
+
323
+
324
+ ###
325
+ # LOAD NEW DATA PAGE
326
+ ###
327
 
328
  # Load in a pdf
329
  load_pdf_click = load_pdf.click(ing.parse_file, inputs=[in_pdf], outputs=[ingest_text, current_source]).\
330
  success(ing.text_to_docs, inputs=[ingest_text], outputs=[ingest_docs]).\
331
+ success(embed_faiss_save_to_zip, inputs=[ingest_docs, output_folder_textbox, embeddings_model_object_state], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
332
  success(chatf.hide_block, outputs = [examples_set])
333
 
334
  # Load in a webpage
335
  load_web_click = load_web.click(ing.parse_html, inputs=[in_web, in_div], outputs=[ingest_text, ingest_metadata, current_source]).\
336
  success(ing.html_text_to_docs, inputs=[ingest_text, ingest_metadata], outputs=[ingest_docs]).\
337
+ success(embed_faiss_save_to_zip, inputs=[ingest_docs, output_folder_textbox, embeddings_model_object_state], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
338
  success(chatf.hide_block, outputs = [examples_set])
339
 
340
  # Load in a csv/excel file
341
  load_csv_click = load_csv.click(ing.parse_csv_or_excel, inputs=[in_csv, in_text_column], outputs=[ingest_text, current_source]).\
342
  success(ing.csv_excel_text_to_docs, inputs=[ingest_text, in_text_column], outputs=[ingest_docs]).\
343
+ success(embed_faiss_save_to_zip, inputs=[ingest_docs, output_folder_textbox, embeddings_model_object_state], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
344
  success(chatf.hide_block, outputs = [examples_set])
345
+
346
 
347
+ ###
348
+ # LOAD MODEL PAGE
349
+ ###
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
+ change_model_button.click(fn=chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
352
+ success(fn=load_model, inputs=[model_choice, gpu_layer_choice], outputs = [model_type_state, load_text, current_model]).\
353
+ success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False).\
354
+ success(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic]).\
355
+ success(lambda: None, None, chatbot, queue=False)
356
 
357
  ###
358
  # LOGGING AND ON APP LOAD FUNCTIONS
359
+ ###
360
+ # Load in default model and embeddings for each user
361
+ app.load(get_connection_params, inputs=None, outputs=[session_hash_state, output_folder_textbox, session_hash_textbox, input_folder_textbox]).\
362
+ success(load_model, inputs=[model_type_state, gpu_layer_choice, gpu_config_state, cpu_config_state, torch_device_state], outputs=[model_type_state, load_text, current_model]).\
363
+ success(get_faiss_store, inputs=[default_embeddings_store_text, embeddings_model_object_state], outputs=[vectorstore_state])
364
 
365
  # Log usernames and times of access to file (to know who is using the app when running on AWS)
366
  access_callback = gr.CSVLogger()
 
369
  session_hash_textbox.change(lambda *args: access_callback.flag(list(args)), [session_hash_textbox], None, preprocess=False).\
370
  success(fn = upload_file_to_s3, inputs=[access_logs_state, access_s3_logs_loc_state], outputs=[s3_logs_output_textbox])
371
 
 
 
 
 
372
  if __name__ == "__main__":
373
+ if COGNITO_AUTH == "1":
374
+ app.queue(max_size=int(MAX_QUEUE_SIZE), default_concurrency_limit=int(DEFAULT_CONCURRENCY_LIMIT)).launch(show_error=True, inbrowser=True, auth=authenticate_user, max_file_size=MAX_FILE_SIZE, server_port=GRADIO_SERVER_PORT, root_path=ROOT_PATH)
375
  else:
376
+ app.queue(max_size=int(MAX_QUEUE_SIZE), default_concurrency_limit=int(DEFAULT_CONCURRENCY_LIMIT)).launch(show_error=True, inbrowser=True, max_file_size=MAX_FILE_SIZE, server_port=GRADIO_SERVER_PORT, root_path=ROOT_PATH)
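The chatfuncs/config.py module that app.py now imports its settings from is not part of this diff view. A rough sketch of the pattern it presumably follows (environment variables with defaults, as the old get_or_create_env_var helper did); all names and default values below are assumptions for illustration:

```python
import os

def get_or_create_env_var(var_name: str, default_value: str) -> str:
    # Return the environment variable if set; otherwise store and return the default.
    value = os.environ.get(var_name)
    if value is None:
        os.environ[var_name] = default_value
        value = default_value
    return value

# Example constants mirroring names imported in app.py; the real defaults live
# in chatfuncs/config.py and may differ.
COGNITO_AUTH = get_or_create_env_var("COGNITO_AUTH", "0")
MAX_QUEUE_SIZE = get_or_create_env_var("MAX_QUEUE_SIZE", "5")
DEFAULT_CONCURRENCY_LIMIT = get_or_create_env_var("DEFAULT_CONCURRENCY_LIMIT", "1")
MAX_FILE_SIZE = get_or_create_env_var("MAX_FILE_SIZE", "50mb")
GRADIO_SERVER_PORT = int(get_or_create_env_var("GRADIO_SERVER_PORT", "7860"))
ROOT_PATH = get_or_create_env_var("ROOT_PATH", "")
```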
chatfuncs/auth.py CHANGED
@@ -1,14 +1,22 @@
1
-
2
  import boto3
3
- from chatfuncs.helper_functions import get_or_create_env_var
4
-
5
- client_id = get_or_create_env_var('AWS_CLIENT_ID', '') # This client id is borrowed from async gradio app client
6
- print(f'The value of AWS_CLIENT_ID is {client_id}')
 
7
 
8
- user_pool_id = get_or_create_env_var('AWS_USER_POOL_ID', '')
9
- print(f'The value of AWS_USER_POOL_ID is {user_pool_id}')
 
 
 
 
 
 
 
10
 
11
- def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=client_id):
12
  """Authenticates a user against an AWS Cognito user pool.
13
 
14
  Args:
@@ -16,22 +24,39 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
16
  client_id (str): The ID of the Cognito user pool client.
17
  username (str): The username of the user.
18
  password (str): The password of the user.
 
19
 
20
  Returns:
21
  bool: True if the user is authenticated, False otherwise.
22
  """
23
 
24
- client = boto3.client('cognito-idp') # Cognito Identity Provider client
 
 
 
25
 
26
  try:
27
- response = client.initiate_auth(
 
 
28
  AuthFlow='USER_PASSWORD_AUTH',
29
  AuthParameters={
30
  'USERNAME': username,
31
  'PASSWORD': password,
 
32
  },
33
  ClientId=client_id
34
- )
35
 
36
  # If successful, you'll receive an AuthenticationResult in the response
37
  if response.get('AuthenticationResult'):
@@ -44,5 +69,7 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
44
  except client.exceptions.UserNotFoundException:
45
  return False
46
  except Exception as e:
47
- print(f"An error occurred: {e}")
48
- return False
 
 
 
1
+ #import os
2
  import boto3
3
+ #import gradio as gr
4
+ import hmac
5
+ import hashlib
6
+ import base64
7
+ from chatfuncs.config import AWS_CLIENT_ID, AWS_CLIENT_SECRET, AWS_USER_POOL_ID, AWS_REGION
8
 
9
+ def calculate_secret_hash(client_id:str, client_secret:str, username:str):
10
+ message = username + client_id
11
+ dig = hmac.new(
12
+ str(client_secret).encode('utf-8'),
13
+ msg=str(message).encode('utf-8'),
14
+ digestmod=hashlib.sha256
15
+ ).digest()
16
+ secret_hash = base64.b64encode(dig).decode()
17
+ return secret_hash
18
 
19
+ def authenticate_user(username:str, password:str, user_pool_id:str=AWS_USER_POOL_ID, client_id:str=AWS_CLIENT_ID, client_secret:str=AWS_CLIENT_SECRET):
20
  """Authenticates a user against an AWS Cognito user pool.
21
 
22
  Args:
 
24
  client_id (str): The ID of the Cognito user pool client.
25
  username (str): The username of the user.
26
  password (str): The password of the user.
27
+ client_secret (str): The client secret of the app client
28
 
29
  Returns:
30
  bool: True if the user is authenticated, False otherwise.
31
  """
32
 
33
+ client = boto3.client('cognito-idp', region_name=AWS_REGION) # Cognito Identity Provider client
34
+
35
+ # Compute the secret hash
36
+ secret_hash = calculate_secret_hash(client_id, client_secret, username)
37
 
38
  try:
39
+
40
+ if client_secret == '':
41
+ response = client.initiate_auth(
42
+ AuthFlow='USER_PASSWORD_AUTH',
43
+ AuthParameters={
44
+ 'USERNAME': username,
45
+ 'PASSWORD': password,
46
+ },
47
+ ClientId=client_id
48
+ )
49
+
50
+ else:
51
+ response = client.initiate_auth(
52
  AuthFlow='USER_PASSWORD_AUTH',
53
  AuthParameters={
54
  'USERNAME': username,
55
  'PASSWORD': password,
56
+ 'SECRET_HASH': secret_hash
57
  },
58
  ClientId=client_id
59
+ )
60
 
61
  # If successful, you'll receive an AuthenticationResult in the response
62
  if response.get('AuthenticationResult'):
 
69
  except client.exceptions.UserNotFoundException:
70
  return False
71
  except Exception as e:
72
+ out_message = f"An error occurred: {e}"
73
+ print(out_message)
74
+ raise Exception(out_message)
75
+ return False
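The SECRET_HASH that Cognito expects for app clients with a secret is the Base64-encoded HMAC-SHA256 of username + client_id, keyed with the client secret, which is what calculate_secret_hash above computes. A quick standalone check of that calculation (credentials below are made up):

```python
import base64
import hashlib
import hmac

def calculate_secret_hash(client_id: str, client_secret: str, username: str) -> str:
    # Base64(HMAC-SHA256(key=client_secret, msg=username + client_id))
    dig = hmac.new(
        client_secret.encode("utf-8"),
        msg=(username + client_id).encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    return base64.b64encode(dig).decode()

# Made-up values purely to show the call shape.
print(calculate_secret_hash("example-client-id", "example-client-secret", "alice"))
```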
chatfuncs/chatfuncs.py CHANGED
@@ -5,65 +5,60 @@ from typing import Type, Dict, List, Tuple
5
  import time
6
  from itertools import compress
7
  import pandas as pd
8
- import numpy as np
9
-
10
- # Model packages
11
- import torch.cuda
12
- from threading import Thread
13
- from transformers import pipeline, TextIteratorStreamer
14
-
15
- # Alternative model sources
16
- #from dataclasses import asdict, dataclass
17
-
18
- # Langchain functions
19
- from langchain.prompts import PromptTemplate
20
- from langchain_community.vectorstores import FAISS
21
- from langchain_community.retrievers import SVMRetriever
22
- from langchain.text_splitter import RecursiveCharacterTextSplitter
23
- from langchain.docstore.document import Document
24
-
25
- # For keyword extraction (not currently used)
26
- #import nltk
27
- #nltk.download('wordnet')
28
  from nltk.corpus import stopwords
29
  from nltk.tokenize import RegexpTokenizer
30
  from nltk.stem import WordNetLemmatizer
31
- #from nltk.stem.snowball import SnowballStemmer
32
  from keybert import KeyBERT
33
 
34
  # For Name Entity Recognition model
35
  #from span_marker import SpanMarkerModel # Not currently used
36
 
37
-
38
  # For BM25 retrieval
39
  import bm25s
40
  import Stemmer
 
 
 
 
 
 
 
 
 
 
41
 
42
- #from gensim.corpora import Dictionary
43
- #from gensim.models import TfidfModel, OkapiBM25Model
44
- #from gensim.similarities import SparseMatrixSimilarity
45
 
46
- from llama_cpp import Llama
47
- from huggingface_hub import hf_hub_download
48
 
49
- from chatfuncs.prompts import instruction_prompt_template_alpaca, instruction_prompt_mistral_orca, instruction_prompt_phi3, instruction_prompt_llama3, instruction_prompt_qwen
 
 
 
 
50
 
51
- import gradio as gr
52
 
53
  torch.cuda.empty_cache()
54
 
55
  PandasDataFrame = Type[pd.DataFrame]
56
 
57
  embeddings = None # global variable setup
 
58
  vectorstore = None # global variable setup
59
  model_type = None # global variable setup
60
 
61
  max_memory_length = 0 # How long should the memory of the conversation last?
62
 
63
- full_text = "" # Define dummy source text (full text) just to enable highlight function to load
64
-
65
- model = [] # Define empty list for model functions to run
66
- tokenizer = [] # Define empty list for model functions to run
67
 
68
  ## Highlight text constants
69
  hlt_chunk_size = 12
@@ -77,117 +72,53 @@ ner_model = []#SpanMarkerModel.from_pretrained("tomaarsen/span-marker-mbert-base
77
  # Used to pull out keywords from chat history to add to user queries behind the scenes
78
  kw_model = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")
79
 
80
- # Currently set gpu_layers to 0 even with cuda due to persistent bugs in implementation with cuda
81
- if torch.cuda.is_available():
82
- torch_device = "cuda"
83
- gpu_layers = 100
84
- else:
85
- torch_device = "cpu"
86
- gpu_layers = 0
87
-
88
- print("Running on device:", torch_device)
89
- threads = 8 #torch.get_num_threads()
90
- print("CPU threads:", threads)
91
-
92
- # Qwen 2 0.5B (small, fast) Model parameters
93
- temperature: float = 0.1
94
- top_k: int = 3
95
- top_p: float = 1
96
- repetition_penalty: float = 1.15
97
- flan_alpaca_repetition_penalty: float = 1.3
98
- last_n_tokens: int = 64
99
- max_new_tokens: int = 1024
100
- seed: int = 42
101
- reset: bool = False
102
- stream: bool = True
103
- threads: int = threads
104
- batch_size:int = 256
105
- context_length:int = 2048
106
- sample = True
107
-
108
-
109
- class CtransInitConfig_gpu:
110
- def __init__(self,
111
- last_n_tokens=last_n_tokens,
112
- seed=seed,
113
- n_threads=threads,
114
- n_batch=batch_size,
115
- n_ctx=4096,
116
- n_gpu_layers=gpu_layers):
117
-
118
- self.last_n_tokens = last_n_tokens
119
- self.seed = seed
120
- self.n_threads = n_threads
121
- self.n_batch = n_batch
122
- self.n_ctx = n_ctx
123
- self.n_gpu_layers = n_gpu_layers
124
- # self.stop: list[str] = field(default_factory=lambda: [stop_string])
125
-
126
- def update_gpu(self, new_value):
127
- self.n_gpu_layers = new_value
128
-
129
- class CtransInitConfig_cpu(CtransInitConfig_gpu):
130
- def __init__(self):
131
- super().__init__()
132
- self.n_gpu_layers = 0
133
-
134
- gpu_config = CtransInitConfig_gpu()
135
- cpu_config = CtransInitConfig_cpu()
136
-
137
-
138
- class CtransGenGenerationConfig:
139
- def __init__(self, temperature=temperature,
140
- top_k=top_k,
141
- top_p=top_p,
142
- repeat_penalty=repetition_penalty,
143
- seed=seed,
144
- stream=stream,
145
- max_tokens=max_new_tokens
146
- ):
147
- self.temperature = temperature
148
- self.top_k = top_k
149
- self.top_p = top_p
150
- self.repeat_penalty = repeat_penalty
151
- self.seed = seed
152
- self.max_tokens=max_tokens
153
- self.stream = stream
154
-
155
- def update_temp(self, new_value):
156
- self.temperature = new_value
157
-
158
  # Vectorstore funcs
159
 
160
- def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
161
 
162
- print(f"> Total split documents: {len(docs_out)}")
163
 
164
- vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings)
165
 
166
- '''
167
- #with open("vectorstore.pkl", "wb") as f:
168
- #pickle.dump(vectorstore, f)
169
- '''
170
 
171
- #if Path(save_to).exists():
172
- # vectorstore_func.save_local(folder_path=save_to)
173
- #else:
174
- # os.mkdir(save_to)
175
- # vectorstore_func.save_local(folder_path=save_to)
176
 
177
- global vectorstore
178
 
179
- vectorstore = vectorstore_func
180
 
181
- out_message = "Document processing complete"
182
 
183
- #print(out_message)
184
- #print(f"> Saved to: {save_to}")
185
 
186
- return out_message
 
 
 
 
187
 
188
  # Prompt functions
189
 
190
- def base_prompt_templates(model_type = "Qwen 2 0.5B (small, fast)"):
191
 
192
  #EXAMPLE_PROMPT = PromptTemplate(
193
  # template="\nCONTENT:\n\n{page_content}\n\nSOURCE: {source}\n\n",
@@ -201,24 +132,24 @@ def base_prompt_templates(model_type = "Qwen 2 0.5B (small, fast)"):
201
 
202
  # The main prompt:
203
 
204
- if model_type == "Qwen 2 0.5B (small, fast)":
205
  INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_qwen, input_variables=['question', 'summaries'])
206
- elif model_type == "Phi 3.5 Mini (larger, slow)":
207
  INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_phi3, input_variables=['question', 'summaries'])
 
 
 
208
 
209
  return INSTRUCTION_PROMPT, CONTENT_PROMPT
210
 
211
- def write_out_metadata_as_string(metadata_in):
212
  metadata_string = [f"{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}" for d in metadata_in] # ['metadata']
213
  return metadata_string
214
 
215
- def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag = True, out_passages = 2): # ,
216
 
217
  question = inputs["question"]
218
  chat_history = inputs["chat_history"]
219
-
220
- print("relevant_flag in generate_expanded_prompt:", relevant_flag)
221
-
222
 
223
  if relevant_flag == True:
224
  new_question_kworded = adapt_q_from_chat_history(question, chat_history, extracted_memory) # new_question_keywords,
@@ -234,8 +165,6 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
234
  return sorry_prompt, "No relevant sources found.", new_question_kworded
235
 
236
  # Expand the found passages to the neighbouring context
237
- print("Doc_df columns:", doc_df.columns)
238
-
239
  if 'meta_url' in doc_df.columns:
240
  file_type = determine_file_type(doc_df['meta_url'][0])
241
  else:
@@ -243,7 +172,7 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
243
 
244
  # Only expand passages if not tabular data
245
  if (file_type != ".csv") & (file_type != ".xlsx"):
246
- docs_keep_as_doc, doc_df = get_expanded_passages(vectorstore, docs_keep_out, width=3)
247
 
248
  # Build up sources content to add to user display
249
  doc_df['meta_clean'] = write_out_metadata_as_string(doc_df["metadata"]) # [f"<b>{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}</b>" for d in doc_df['metadata']]
@@ -259,29 +188,28 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
259
  sources_docs_content_string = '<br><br>'.join(doc_df['content_meta'])#.replace(" "," ")#.strip()
260
 
261
  instruction_prompt_out = instruction_prompt.format(question=new_question_kworded, summaries=docs_content_string)
262
-
263
- print('Final prompt is: ')
264
- print(instruction_prompt_out)
265
 
266
  return instruction_prompt_out, sources_docs_content_string, new_question_kworded
267
 
268
- def create_full_prompt(user_input, history, extracted_memory, vectorstore, embeddings, model_type, out_passages, api_model_choice=None, api_key=None, relevant_flag = True):
 
 
 
 
 
 
 
 
 
 
 
269
 
270
  #if chain_agent is None:
271
  # history.append((user_input, "Please click the button to submit the Huggingface API key before using the chatbot (top right)"))
272
  # return history, history, "", ""
273
  print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
274
-
275
-
276
  history = history or []
277
-
278
- if api_model_choice and api_model_choice != "None":
279
- print("API model choice detected")
280
- if api_key:
281
- print("API key detected")
282
- return history, "", None, relevant_flag
283
- else:
284
- return history, "", None, relevant_flag
285
 
286
  # Create instruction prompt
287
  instruction_prompt, content_prompt = base_prompt_templates(model_type=model_type)
@@ -291,38 +219,15 @@ def create_full_prompt(user_input, history, extracted_memory, vectorstore, embed
291
  relevant_flag = False
292
  else:
293
  relevant_flag = True
294
-
295
- print("User input:", user_input)
296
 
297
  instruction_prompt_out, docs_content_string, new_question_kworded =\
298
  generate_expanded_prompt({"question": user_input, "chat_history": history}, #vectorstore,
299
  instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag, out_passages)
300
 
301
- history.append(user_input)
302
-
303
- print("Output history is:", history)
304
- print("Final prompt to model is:",instruction_prompt_out)
305
 
306
  return history, docs_content_string, instruction_prompt_out, relevant_flag
307
 
308
- # Chat functions
309
- import boto3
310
- import json
311
- from chatfuncs.helper_functions import get_or_create_env_var
312
-
313
- # ResponseObject class for AWS Bedrock calls
314
- class ResponseObject:
315
- def __init__(self, text, usage_metadata):
316
- self.text = text
317
- self.usage_metadata = usage_metadata
318
-
319
- max_tokens = 4096
320
-
321
- AWS_DEFAULT_REGION = get_or_create_env_var('AWS_DEFAULT_REGION', 'eu-west-2')
322
- print(f'The value of AWS_DEFAULT_REGION is {AWS_DEFAULT_REGION}')
323
-
324
- bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_DEFAULT_REGION)
325
-
326
  def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice: str) -> ResponseObject:
327
  """
328
  This function sends a request to AWS Claude with the following parameters:
@@ -351,6 +256,8 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
351
  ],
352
  }
353
 
 
 
354
  body = json.dumps(prompt_config)
355
 
356
  modelId = model_choice
@@ -376,16 +283,173 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
376
 
377
  return response
378
 
379
- def produce_streaming_answer_chatbot(history,
380
- full_prompt,
381
- model_type,
382
- temperature=temperature,
383
- relevant_query_bool=True,
384
- max_new_tokens=max_new_tokens,
385
- sample=sample,
386
- repetition_penalty=repetition_penalty,
387
- top_p=top_p,
388
- top_k=top_k
 
 
389
  ):
390
  #print("Model type is: ", model_type)
391
 
@@ -395,16 +459,18 @@ def produce_streaming_answer_chatbot(history,
395
 
396
  # return history
397
 
398
-
 
 
399
 
400
  if relevant_query_bool == False:
401
- out_message = [("","No relevant query found. Please retry your question")]
402
- history.append(out_message)
403
 
404
  yield history
405
  return
406
 
407
- if model_type == "Qwen 2 0.5B (small, fast)":
 
408
  # Get the model and tokenizer, and tokenize the user text.
409
  model_inputs = tokenizer(text=full_prompt, return_tensors="pt", return_attention_mask=False).to(torch_device)
410
 
@@ -422,9 +488,7 @@ def produce_streaming_answer_chatbot(history,
422
  top_k=top_k
423
  )
424
 
425
- #print(generate_kwargs)
426
-
427
- t = Thread(target=model.generate, kwargs=generate_kwargs)
428
  t.start()
429
 
430
  # Pull the generated text from the streamer, and update the model output.
@@ -432,12 +496,15 @@ def produce_streaming_answer_chatbot(history,
432
  NUM_TOKENS=0
433
  print('-'*4+'Start Generation'+'-'*4)
434
 
435
- history[-1][1] = ""
 
436
  for new_text in streamer:
437
  try:
438
- if new_text == None: new_text = ""
439
- history[-1][1] += new_text
440
- NUM_TOKENS+=1
 
 
441
  yield history
442
  except Exception as e:
443
  print(f"Error during text generation: {e}")
@@ -450,7 +517,7 @@ def produce_streaming_answer_chatbot(history,
450
  print(f'Tokens per secound: {NUM_TOKENS/time_generate}')
451
  print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')
452
 
453
- elif model_type == "Phi 3.5 Mini (larger, slow)":
454
  #tokens = model.tokenize(full_prompt)
455
 
456
  gen_config = CtransGenGenerationConfig()
@@ -463,15 +530,17 @@ def produce_streaming_answer_chatbot(history,
463
  NUM_TOKENS=0
464
  print('-'*4+'Start Generation'+'-'*4)
465
 
466
- output = model(
467
  full_prompt, **vars(gen_config))
468
 
469
- history[-1][1] = ""
 
470
  for out in output:
471
 
472
  if "choices" in out and len(out["choices"]) > 0 and "text" in out["choices"][0]:
473
- history[-1][1] += out["choices"][0]["text"]
474
  NUM_TOKENS+=1
 
475
  yield history
476
  else:
477
  print(f"Unexpected output structure: {out}")
@@ -481,36 +550,80 @@ def produce_streaming_answer_chatbot(history,
481
  print('-'*4+'End Generation'+'-'*4)
482
  print(f'Num of generated tokens: {NUM_TOKENS}')
483
  print(f'Time for complete generation: {time_generate}s')
484
- print(f'Tokens per secound: {NUM_TOKENS/time_generate}')
485
  print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')
486
 
487
- elif model_type == "anthropic.claude-3-haiku-20240307-v1:0" or model_type == "anthropic.claude-3-sonnet-20240229-v1:0":
488
  system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."
489
 
490
- try:
491
- print("Calling AWS Claude model")
492
- response = call_aws_claude(full_prompt, system_prompt, temperature, max_tokens, model_type)
493
- except Exception as e:
494
- # If fails, try again after 10 seconds in case there is a throttle limit
495
- print(e)
496
- try:
497
- out_message = "API limit hit - waiting 30 seconds to retry."
498
- print(out_message)
 
 
499
 
500
- time.sleep(30)
501
- response = call_aws_claude(full_prompt, system_prompt, temperature, max_tokens, model_type)
502
-
503
- except Exception as e:
504
- print(e)
505
- return "", history
506
  # Update the conversation history with the new prompt and response
507
- history.append({'role': 'user', 'parts': [full_prompt]})
508
- history.append({'role': 'assistant', 'parts': [response.text]})
 
 
509
 
510
- # Print the updated conversation history
511
- #print("conversation_history:", conversation_history)
512
 
513
- return response, history
 
 
 
 
 
514
 
515
  # Chat helper functions
516
 
@@ -589,9 +702,6 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
589
 
590
  docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
591
 
592
- print("Docs from similarity search:")
593
- print(docs)
594
-
595
  # Keep only documents with a certain score
596
  docs_len = [len(x[0].page_content) for x in docs]
597
  docs_scores = [x[1] for x in docs]
@@ -688,12 +798,8 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
688
  # 3rd level check on retrieved docs with SVM retriever
689
  # Check the type of the embeddings object
690
  embeddings_type = type(embeddings)
691
- print("Type of embeddings object:", embeddings_type)
692
 
693
 
694
- print("embeddings:", embeddings)
695
-
696
- from langchain_huggingface import HuggingFaceEmbeddings
697
  #hf_embeddings = HuggingFaceEmbeddings(**embeddings)
698
  hf_embeddings = embeddings
699
 
@@ -743,10 +849,6 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
743
  # Make df of best options
744
  doc_df = create_doc_df(docs_keep_out)
745
 
746
- print("doc_df:",doc_df)
747
- print("docs_keep_as_doc:",docs_keep_as_doc)
748
- print("docs_keep_out:", docs_keep_out)
749
-
750
  return docs_keep_as_doc, doc_df, docs_keep_out
751
 
752
  def get_expanded_passages(vectorstore, docs, width):
@@ -836,16 +938,16 @@ def get_expanded_passages(vectorstore, docs, width):
836
 
837
  return expanded_docs, doc_df
838
 
839
- def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hlt_chunk_size, hlt_strat:List=hlt_strat, hlt_overlap:int=hlt_overlap) -> str:
840
  """
841
- Highlights occurrences of search_text within full_text.
842
 
843
  Parameters:
844
- - search_text (str): The text to be searched for within full_text.
845
- - full_text (str): The text within which search_text occurrences will be highlighted.
846
 
847
  Returns:
848
- - str: A string with occurrences of search_text highlighted.
849
 
850
  Example:
851
  >>> highlight_found_text("world", "Hello, world! This is a test. Another world awaits.")
@@ -859,32 +961,27 @@ def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hl
859
  return text[i][0].replace(" ", " ").strip()
860
  else:
861
  return ""
862
-
863
- def extract_search_text_from_input(text):
864
- if isinstance(text, str):
865
- return text.replace(" ", " ").strip()
866
- elif isinstance(text, list):
867
- return text[-1][1].replace(" ", " ").strip()
868
- else:
869
- return ""
870
-
871
- full_text = extract_text_from_input(full_text)
872
- search_text = extract_search_text_from_input(search_text)
873
-
874
-
875
 
876
  text_splitter = RecursiveCharacterTextSplitter(
877
  chunk_size=hlt_chunk_size,
878
  separators=hlt_strat,
879
  chunk_overlap=hlt_overlap,
880
  )
881
- sections = text_splitter.split_text(search_text)
882
 
883
  found_positions = {}
884
  for x in sections:
885
  text_start_pos = 0
886
  while text_start_pos != -1:
887
- text_start_pos = full_text.find(x, text_start_pos)
888
  if text_start_pos != -1:
889
  found_positions[text_start_pos] = text_start_pos + len(x)
890
  text_start_pos += 1
@@ -907,20 +1004,22 @@ def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hl
907
  prev_end = 0
908
  for start, end in combined_positions:
909
  if end-start > 15: # Only combine if there is a significant amount of matched text. Avoids picking up single words like 'and' etc.
910
- pos_tokens.append(full_text[prev_end:start])
911
- pos_tokens.append('<mark style="color:black;">' + full_text[start:end] + '</mark>')
912
  prev_end = end
913
- pos_tokens.append(full_text[prev_end:])
 
 
914
 
915
- return "".join(pos_tokens)
916
 
917
 
918
  # # Chat history functions
919
 
920
  def clear_chat(chat_history_state, sources, chat_message, current_topic):
921
- chat_history_state = []
922
  sources = ''
923
- chat_message = ''
924
  current_topic = ''
925
 
926
  return chat_history_state, sources, chat_message, current_topic
@@ -1011,8 +1110,7 @@ def remove_q_stopwords(question): # Remove stopwords from question. Not used at
1011
  for word in tokens_without_sw:
1012
  if word not in ordered_tokens:
1013
  ordered_tokens.add(word)
1014
- result.append(word)
1015
-
1016
 
1017
 
1018
  new_question_keywords = ' '.join(result)
@@ -1021,9 +1119,6 @@ def remove_q_stopwords(question): # Remove stopwords from question. Not used at
1021
  def remove_q_ner_extractor(question):
1022
 
1023
  predict_out = ner_model.predict(question)
1024
-
1025
-
1026
-
1027
  predict_tokens = [' '.join(v for k, v in d.items() if k == 'span') for d in predict_out]
1028
 
1029
  # Remove duplicate words while preserving order
@@ -1075,11 +1170,11 @@ def keybert_keywords(text, n, kw_model):
1075
  return keywords_list
1076
 
1077
  # Gradio functions
1078
- def turn_off_interactivity(user_message, history):
1079
- return gr.update(value="", interactive=False), history + [[user_message, None]]
1080
 
1081
  def restore_interactivity():
1082
- return gr.update(interactive=True)
1083
 
1084
  def update_message(dropdown_value):
1085
  return gr.Textbox(value=dropdown_value)
 
5
  import time
6
  from itertools import compress
7
  import pandas as pd
8
+ import google.generativeai as ai
9
+ import gradio as gr
10
+ from gradio import Progress
11
+ import boto3
12
+ import json
 
 
13
  from nltk.corpus import stopwords
14
  from nltk.tokenize import RegexpTokenizer
15
  from nltk.stem import WordNetLemmatizer
 
16
  from keybert import KeyBERT
17
 
18
  # For Name Entity Recognition model
19
  #from span_marker import SpanMarkerModel # Not currently used
20
 
 
21
  # For BM25 retrieval
22
  import bm25s
23
  import Stemmer
24
+ # Model packages
25
+ import torch.cuda
26
+ from threading import Thread
27
+ from transformers import pipeline, TextIteratorStreamer
28
+ # Langchain functions
29
+ from langchain.prompts import PromptTemplate
30
+ from langchain_community.vectorstores import FAISS
31
+ from langchain_community.retrievers import SVMRetriever
32
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
33
+ from langchain.docstore.document import Document
34
 
35
+ from chatfuncs.prompts import instruction_prompt_template_alpaca, instruction_prompt_mistral_orca, instruction_prompt_phi3, instruction_prompt_llama3, instruction_prompt_qwen, instruction_prompt_template_orca, instruction_prompt_gemma
36
+ from chatfuncs.model_load import temperature, max_new_tokens, sample, repetition_penalty, top_p, top_k, torch_device, CtransGenGenerationConfig, max_tokens
37
+ from chatfuncs.config import GEMINI_API_KEY, AWS_DEFAULT_REGION, LARGE_MODEL_NAME, SMALL_MODEL_NAME
38
 
39
+ model_object = [] # Define empty list for model functions to run
40
+ tokenizer = [] # Define empty list for model functions to run
41
 
42
+ # ResponseObject class for AWS Bedrock calls
43
+ class ResponseObject:
44
+ def __init__(self, text, usage_metadata):
45
+ self.text = text
46
+ self.usage_metadata = usage_metadata
47
 
48
+ bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_DEFAULT_REGION)
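A minimal sketch of how a Bedrock invoke_model result could be wrapped in the ResponseObject above; the request body and the response field names ("content", "usage") are assumptions based on the Claude messages format, not taken from this commit.

# Illustrative only - body is assumed to be a JSON-encoded Claude messages payload
raw = bedrock_runtime.invoke_model(modelId="anthropic.claude-3-haiku-20240307-v1:0", body=body)
parsed = json.loads(raw["body"].read())  # the response body is a streaming object
response = ResponseObject(text=parsed["content"][0]["text"], usage_metadata=parsed.get("usage", {}))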
49
 
50
  torch.cuda.empty_cache()
51
 
52
  PandasDataFrame = Type[pd.DataFrame]
53
 
54
  embeddings = None # global variable setup
55
+ embeddings_model = None # global variable setup
56
  vectorstore = None # global variable setup
57
  model_type = None # global variable setup
58
 
59
  max_memory_length = 0 # How long should the memory of the conversation last?
60
 
61
+ source_texts = "" # Define dummy source text (full text) just to enable highlight function to load
 
 
 
62
 
63
  ## Highlight text constants
64
  hlt_chunk_size = 12
 
72
  # Used to pull out keywords from chat history to add to user queries behind the scenes
73
  kw_model = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")
74
 
 
 
75
  # Vectorstore funcs
76
 
77
+ # def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
78
 
79
+ # print(f"> Total split documents: {len(docs_out)}")
80
 
81
+ # vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings)
82
 
83
+ # '''
84
+ # #with open("vectorstore.pkl", "wb") as f:
85
+ # #pickle.dump(vectorstore, f)
86
+ # '''
87
 
88
+ # #if Path(save_to).exists():
89
+ # # vectorstore_func.save_local(folder_path=save_to)
90
+ # #else:
91
+ # # os.mkdir(save_to)
92
+ # # vectorstore_func.save_local(folder_path=save_to)
93
 
94
+ # global vectorstore
95
 
96
+ # vectorstore = vectorstore_func
97
 
98
+ # out_message = "Document processing complete"
99
 
100
+ # #print(out_message)
101
+ # #print(f"> Saved to: {save_to}")
102
 
103
+ # return out_message
104
+
105
+ # def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings_model=embeddings_model):
106
+
107
+ # print(f"> Total split documents: {len(docs_out)}")
108
+
109
+ # print(docs_out)
110
+
111
+ # vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings_model)
112
+
113
+ # vectorstore = vectorstore_func
114
+
115
+ # out_message = "Document processing complete"
116
+
117
+ # return out_message, vectorstore_func
118
 
119
  # Prompt functions
120
 
121
+ def base_prompt_templates(model_type:str = SMALL_MODEL_NAME):
122
 
123
  #EXAMPLE_PROMPT = PromptTemplate(
124
  # template="\nCONTENT:\n\n{page_content}\n\nSOURCE: {source}\n\n",
 
132
 
133
  # The main prompt:
134
 
135
+ if model_type == SMALL_MODEL_NAME:
136
  INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_qwen, input_variables=['question', 'summaries'])
137
+ elif model_type == LARGE_MODEL_NAME:
138
  INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_phi3, input_variables=['question', 'summaries'])
139
+ else:
140
+ INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_template_orca, input_variables=['question', 'summaries'])
141
+
142
 
143
  return INSTRUCTION_PROMPT, CONTENT_PROMPT
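For illustration, the selected template is later filled with the user question and the retrieved passages; the question and summaries below are placeholders.

instruction_prompt, content_prompt = base_prompt_templates(model_type=SMALL_MODEL_NAME)
filled_prompt = instruction_prompt.format(question="How can I make a custom deny list?",
                                          summaries="...retrieved source passages...")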
144
 
145
+ def write_out_metadata_as_string(metadata_in:str):
146
  metadata_string = [f"{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}" for d in metadata_in] # ['metadata']
147
  return metadata_string
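For example (illustrative metadata only), each metadata dictionary is flattened to a single string with the page_section key dropped:

write_out_metadata_as_string([{"source": "README.html", "page": 2, "page_section": "2.1"}])
# -> ['source: README.html page: 2']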
148
 
149
+ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt:str, content_prompt:str, extracted_memory:list, vectorstore:object, embeddings:object, relevant_flag:bool = True, out_passages:int = 2, total_output_passage_chunks_size:int=5): # ,
150
 
151
  question = inputs["question"]
152
  chat_history = inputs["chat_history"]
 
 
 
153
 
154
  if relevant_flag == True:
155
  new_question_kworded = adapt_q_from_chat_history(question, chat_history, extracted_memory) # new_question_keywords,
 
165
  return sorry_prompt, "No relevant sources found.", new_question_kworded
166
 
167
  # Expand the found passages to the neighbouring context
 
 
168
  if 'meta_url' in doc_df.columns:
169
  file_type = determine_file_type(doc_df['meta_url'][0])
170
  else:
 
172
 
173
  # Only expand passages if not tabular data
174
  if (file_type != ".csv") & (file_type != ".xlsx"):
175
+ docs_keep_as_doc, doc_df = get_expanded_passages(vectorstore, docs_keep_out, width=total_output_passage_chunks_size)
176
 
177
  # Build up sources content to add to user display
178
  doc_df['meta_clean'] = write_out_metadata_as_string(doc_df["metadata"]) # [f"<b>{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}</b>" for d in doc_df['metadata']]
 
188
  sources_docs_content_string = '<br><br>'.join(doc_df['content_meta'])#.replace(" "," ")#.strip()
189
 
190
  instruction_prompt_out = instruction_prompt.format(question=new_question_kworded, summaries=docs_content_string)
 
 
 
191
 
192
  return instruction_prompt_out, sources_docs_content_string, new_question_kworded
193
 
194
+ def create_full_prompt(user_input:str,
195
+ history:list[dict],
196
+ extracted_memory:str,
197
+ vectorstore:object,
198
+ embeddings:object,
199
+ model_type:str,
200
+ out_passages:list[str],
201
+ api_key:str="",
202
+ relevant_flag:bool=True):
203
+
204
+ if "gemini" in model_type and not GEMINI_API_KEY and not api_key:
205
+ raise Exception("Gemini model selected but no API key found. Please enter an API key on the Advanced settings page.")
206
 
207
  #if chain_agent is None:
208
  # history.append((user_input, "Please click the button to submit the Huggingface API key before using the chatbot (top right)"))
209
  # return history, history, "", ""
210
  print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
211
+
 
212
  history = history or []
 
213
 
214
  # Create instruction prompt
215
  instruction_prompt, content_prompt = base_prompt_templates(model_type=model_type)
 
219
  relevant_flag = False
220
  else:
221
  relevant_flag = True
 
 
222
 
223
  instruction_prompt_out, docs_content_string, new_question_kworded =\
224
  generate_expanded_prompt({"question": user_input, "chat_history": history}, #vectorstore,
225
  instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag, out_passages)
226
 
227
+ history.append({"metadata":None, "options":None, "role": 'user', "content": user_input})
 
 
 
228
 
229
  return history, docs_content_string, instruction_prompt_out, relevant_flag
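A usage sketch, assuming a loaded vectorstore and embeddings object; the argument values are placeholders, not from this commit.

history, sources, full_prompt, relevant_flag = create_full_prompt(
    user_input="How can I make a custom deny list?",
    history=[], extracted_memory="", vectorstore=vectorstore, embeddings=embeddings,
    model_type=SMALL_MODEL_NAME, out_passages=2)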
230
 
 
 
231
  def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice: str) -> ResponseObject:
232
  """
233
  This function sends a request to AWS Claude with the following parameters:
 
256
  ],
257
  }
258
 
259
+ print("prompt_config:", prompt_config)
260
+
261
  body = json.dumps(prompt_config)
262
 
263
  modelId = model_choice
 
283
 
284
  return response
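A call sketch (the temperature and model choice below are placeholders); the returned ResponseObject exposes the generated text and usage metadata.

response = call_aws_claude(prompt=full_prompt, system_prompt=system_prompt,
                           temperature=0.1, max_tokens=max_tokens,
                           model_choice="anthropic.claude-3-haiku-20240307-v1:0")
print(response.text, response.usage_metadata)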
285
 
286
+ def construct_gemini_generative_model(in_api_key: str, temperature: float, model_choice: str, system_prompt: str, max_tokens: int) -> Tuple[object, dict]:
287
+ """
288
+ Constructs a GenerativeModel for Gemini API calls.
289
+
290
+ Parameters:
291
+ - in_api_key (str): The API key for authentication.
292
+ - temperature (float): The temperature parameter for the model, controlling the randomness of the output.
293
+ - model_choice (str): The choice of model to use for generation.
294
+ - system_prompt (str): The system prompt to guide the generation.
295
+ - max_tokens (int): The maximum number of tokens to generate.
296
+
297
+ Returns:
298
+ - Tuple[object, dict]: A tuple containing the constructed GenerativeModel and its configuration.
299
+ """
300
+ # Construct a GenerativeModel
301
+ try:
302
+ if in_api_key:
303
+ #print("Getting API key from textbox")
304
+ api_key = in_api_key
305
+ ai.configure(api_key=api_key)
306
+ elif "GOOGLE_API_KEY" in os.environ:
307
+ #print("Searching for API key in environmental variables")
308
+ api_key = os.environ["GOOGLE_API_KEY"]
309
+ ai.configure(api_key=api_key)
310
+ else:
311
+ print("No API key foound")
312
+ raise gr.Error("No API key found.")
313
+ except Exception as e:
314
+ print(e)
315
+
316
+ config = ai.GenerationConfig(temperature=temperature, max_output_tokens=max_tokens)
317
+
318
+ print("model_choice:", model_choice)
319
+
320
+ #model = ai.GenerativeModel.from_cached_content(cached_content=cache, generation_config=config)
321
+ model = ai.GenerativeModel(model_name=model_choice, system_instruction=system_prompt, generation_config=config)
322
+
323
+ return model, config
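A usage sketch, assuming a valid key is held in gemini_api_key; the model choice and prompts are placeholders.

model, config = construct_gemini_generative_model(
    in_api_key=gemini_api_key, temperature=0.1, model_choice="gemini-2.0-flash-001",
    system_prompt="Answer briefly and factually from the supplied sources.", max_tokens=1024)
response = model.generate_content(contents="Example question", generation_config=config)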
324
+
325
+ # Function to send a request and update history
326
+ def send_request(prompt: str, conversation_history: List[dict], model: object, config: dict, model_choice: str, system_prompt: str, temperature: float, progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
327
+ """
328
+ This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
329
+ It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
330
+ If the model choice is specific to AWS Claude, it calls the `call_aws_claude` function; otherwise, it uses the `model.generate_content` method.
331
+ The function returns the response text and the updated conversation history.
332
+ """
333
+ # Constructing the full prompt from the conversation history
334
+ full_prompt = "Conversation history:\n"
335
+
336
+ for entry in conversation_history:
337
+ role = entry['role'].capitalize() # Assuming the history is stored with 'role' and 'content'
338
+ message = ' '.join(entry['parts']) # Combining all parts of the message
339
+ full_prompt += f"{role}: {message}\n"
340
+
341
+ # Adding the new user prompt
342
+ full_prompt += f"\nUser: {prompt}"
343
+
344
+ # Print the full prompt for debugging purposes
345
+ #print("full_prompt:", full_prompt)
346
+
347
+ # Generate the model's response
348
+ if "gemini" in model_choice:
349
+ try:
350
+ response = model.generate_content(contents=full_prompt, generation_config=config)
351
+ except Exception as e:
352
+ # If fails, try again after 30 seconds in case there is a throttle limit
353
+ print(e)
354
+ try:
355
+ print("Calling Gemini model")
356
+ out_message = "API limit hit - waiting 30 seconds to retry."
357
+ print(out_message)
358
+ progress(0.5, desc=out_message)
359
+ time.sleep(30)
360
+ response = model.generate_content(contents=full_prompt, generation_config=config)
361
+ except Exception as e:
362
+ print(e)
363
+ return "", conversation_history
364
+ elif "claude" in model_choice:
365
+ try:
366
+ print("Calling AWS Claude model")
367
+ print("prompt:", prompt)
368
+ print("system_prompt:", system_prompt)
369
+ response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
370
+ except Exception as e:
371
+ # If fails, try again after 30 seconds in case there is a throttle limit
372
+ print(e)
373
+ try:
374
+ out_message = "API limit hit - waiting 30 seconds to retry."
375
+ print(out_message)
376
+ progress(0.5, desc=out_message)
377
+ time.sleep(30)
378
+ response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
379
+
380
+ except Exception as e:
381
+ print(e)
382
+ return "", conversation_history
383
+ else:
384
+ raise Exception("Model not found")
385
+
386
+ # Update the conversation history with the new prompt and response
387
+ conversation_history.append({"metadata":None, "options":None, "role": 'user', 'parts': [prompt]})
388
+ conversation_history.append({"metadata":None, "options":None, "role": "assistant", 'parts': [response.text]})
389
+
390
+ # Print the updated conversation history
391
+ #print("conversation_history:", conversation_history)
392
+
393
+ return response, conversation_history
394
+
395
+ def process_requests(prompts: List[str], system_prompt_with_table: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], model: object, config: dict, model_choice: str, temperature: float, batch_no:int = 1, master:bool = False) -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
396
+ """
397
+ Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
398
+
399
+ Args:
400
+ prompts (List[str]): A list of prompts to be processed.
401
+ system_prompt_with_table (str): The system prompt including a table.
402
+ conversation_history (List[dict]): The history of the conversation.
403
+ whole_conversation (List[str]): The complete conversation including prompts and responses.
404
+ whole_conversation_metadata (List[str]): Metadata about the whole conversation.
405
+ model (object): The model to use for processing the prompts.
406
+ config (dict): Configuration for the model.
407
+ model_choice (str): The choice of model to use.
408
+ temperature (float): The temperature parameter for the model.
409
+ batch_no (int): Batch number of the large language model request.
410
+ master (bool): Is this request for the master table.
411
+
412
+ Returns:
413
+ Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
414
+ """
415
+ responses = []
416
+ #for prompt in prompts:
417
+
418
+ response, conversation_history = send_request(prompts[0], conversation_history, model=model, config=config, model_choice=model_choice, system_prompt=system_prompt_with_table, temperature=temperature)
419
+
420
+ print(response.text)
421
+ #"Okay, I'm ready. What source are we discussing, and what's your question about it? Please provide as much context as possible so I can give you the best answer."]
422
+ print(response.usage_metadata)
423
+ responses.append(response)
424
+
425
+ # Create conversation txt object
426
+ whole_conversation.append(prompts[0])
427
+ whole_conversation.append(response.text)
428
+
429
+ # Create conversation metadata
430
+ if master == False:
431
+ whole_conversation_metadata.append(f"Query batch {batch_no} prompt {len(responses)} metadata:")
432
+ else:
433
+ whole_conversation_metadata.append(f"Query summary metadata:")
434
+
435
+ whole_conversation_metadata.append(str(response.usage_metadata))
436
+
437
+ return responses, conversation_history, whole_conversation, whole_conversation_metadata
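A single-batch usage sketch (all argument values below are placeholders):

responses, conv_history, whole_conv, conv_metadata = process_requests(
    prompts=[full_prompt], system_prompt_with_table=system_prompt,
    conversation_history=[], whole_conversation=[], whole_conversation_metadata=[],
    model=model, config=config, model_choice="gemini-2.0-flash-001", temperature=0.1)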
438
+
439
+ def produce_streaming_answer_chatbot(
440
+ history:list,
441
+ full_prompt:str,
442
+ model_type:str,
443
+ temperature:float=temperature,
444
+ relevant_query_bool:bool=True,
445
+ chat_history:list[dict]=[{"metadata":None, "options":None, "role": 'user', "content": ""}],
446
+ in_api_key:str=GEMINI_API_KEY,
447
+ max_new_tokens:int=max_new_tokens,
448
+ sample:bool=sample,
449
+ repetition_penalty:float=repetition_penalty,
450
+ top_p:float=top_p,
451
+ top_k:float=top_k,
452
+ max_tokens:int=max_tokens
453
  ):
454
  #print("Model type is: ", model_type)
455
 
 
459
 
460
  # return history
461
 
462
+ history = chat_history
463
+
464
+ print("history at start of streaming function:", history)
465
 
466
  if relevant_query_bool == False:
467
+ history.append({"metadata":None, "options":None, "role": "assistant", "content": 'No relevant query found. Please retry your question'})
 
468
 
469
  yield history
470
  return
471
 
472
+ if model_type == SMALL_MODEL_NAME:
473
+
474
  # Get the model and tokenizer, and tokenize the user text.
475
  model_inputs = tokenizer(text=full_prompt, return_tensors="pt", return_attention_mask=False).to(torch_device)
476
 
 
488
  top_k=top_k
489
  )
490
 
491
+ t = Thread(target=model_object.generate, kwargs=generate_kwargs)
 
 
492
  t.start()
493
 
494
  # Pull the generated text from the streamer, and update the model output.
 
496
  NUM_TOKENS=0
497
  print('-'*4+'Start Generation'+'-'*4)
498
 
499
+ history.append({"metadata":None, "options":None, "role": "assistant", "content": ""})
500
+
501
  for new_text in streamer:
502
  try:
503
+ if new_text is None:
504
+ new_text = ""
505
+ history[-1]['content'] += new_text
506
+ NUM_TOKENS += 1
507
+ history[-1]['content'] = history[-1]['content'].replace('<|im_end|>','')
508
  yield history
509
  except Exception as e:
510
  print(f"Error during text generation: {e}")
 
517
  print(f'Tokens per secound: {NUM_TOKENS/time_generate}')
518
  print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')
519
 
520
+ elif model_type == LARGE_MODEL_NAME:
521
  #tokens = model.tokenize(full_prompt)
522
 
523
  gen_config = CtransGenGenerationConfig()
 
530
  NUM_TOKENS=0
531
  print('-'*4+'Start Generation'+'-'*4)
532
 
533
+ output = model_object(
534
  full_prompt, **vars(gen_config))
535
 
536
+ history.append({"metadata":None, "options":None, "role": "assistant", "content": ""})
537
+
538
  for out in output:
539
 
540
  if "choices" in out and len(out["choices"]) > 0 and "text" in out["choices"][0]:
541
+ history[-1]['content'] += out["choices"][0]["text"]
542
  NUM_TOKENS+=1
543
+ history[-1]['content'] = history[-1]['content'].replace('<|im_end|>','')
544
  yield history
545
  else:
546
  print(f"Unexpected output structure: {out}")
 
550
  print('-'*4+'End Generation'+'-'*4)
551
  print(f'Num of generated tokens: {NUM_TOKENS}')
552
  print(f'Time for complete generation: {time_generate}s')
553
+ print(f'Tokens per second: {NUM_TOKENS/time_generate}')
554
  print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')
555
 
556
+ elif "claude" in model_type:
557
  system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."
558
 
559
+ print("full_prompt:", full_prompt)
560
+
561
+ if isinstance(full_prompt, str):
562
+ full_prompt = [full_prompt]
563
+
564
+ model = model_type
565
+ config = {}
566
+
567
+ responses, summary_conversation_history, whole_summary_conversation, whole_conversation_metadata = process_requests(full_prompt, system_prompt, conversation_history=[], whole_conversation=[], whole_conversation_metadata=[], model=model, config = config, model_choice = model_type, temperature = temperature)
568
+
569
+ if isinstance(responses[-1], ResponseObject):
570
+ response_texts = [resp.text for resp in responses]
571
+ elif "choices" in responses[-1]:
572
+ response_texts = [resp["choices"][0]['text'] for resp in responses]
573
+ else:
574
+ response_texts = [resp.text for resp in responses]
575
+
576
+ latest_response_text = response_texts[-1]
577
 
 
 
 
 
 
 
578
  # Update the conversation history with the new prompt and response
579
+ clean_text = re.sub(r'[\n\t\r]', ' ', latest_response_text) # Replace newlines, tabs, and carriage returns with a space
580
+ clean_response_text = re.sub(r'[^\x20-\x7E]', '', clean_text).strip() # Remove all non-ASCII printable characters
581
+
582
+ history.append({"metadata":None, "options":None, "role": "assistant", "content": ''})
583
+
584
+ for char in clean_response_text:
585
+ time.sleep(0.005)
586
+ history[-1]['content'] += char
587
+ yield history
588
+
589
+ elif "gemini" in model_type:
590
+
591
+ if in_api_key: gemini_api_key = in_api_key
592
+ elif GEMINI_API_KEY: gemini_api_key = GEMINI_API_KEY
593
+ else: raise Exception("Gemini API key not found. Please enter a key on the Advanced settings page or select another model type")
594
+
595
+ print("Using Gemini model:", model_type)
596
+ print("full_prompt:", full_prompt)
597
+
598
+ if isinstance(full_prompt, str):
599
+ full_prompt = [full_prompt]
600
+
601
+ system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."
602
+
603
+ model, config = construct_gemini_generative_model(gemini_api_key, temperature, model_type, system_prompt, max_tokens)
604
+
605
+ responses, summary_conversation_history, whole_summary_conversation, whole_conversation_metadata = process_requests(full_prompt, system_prompt, conversation_history=[], whole_conversation=[], whole_conversation_metadata=[], model=model, config = config, model_choice = model_type, temperature = temperature)
606
+
607
+ if isinstance(responses[-1], ResponseObject):
608
+ response_texts = [resp.text for resp in responses]
609
+ elif "choices" in responses[-1]:
610
+ response_texts = [resp["choices"][0]['text'] for resp in responses]
611
+ else:
612
+ response_texts = [resp.text for resp in responses]
613
+
614
+ latest_response_text = response_texts[-1]
615
+
616
+ clean_text = re.sub(r'[\n\t\r]', ' ', latest_response_text) # Replace newlines, tabs, and carriage returns with a space
617
+ clean_response_text = re.sub(r'[^\x20-\x7E]', '', clean_text).strip() # Remove all non-ASCII printable characters
618
 
619
+ history.append({"metadata":None, "options":None, "role": "assistant", "content": ''})
 
620
 
621
+ for char in clean_response_text:
622
+ time.sleep(0.005)
623
+ history[-1]['content'] += char
624
+ yield history
625
+
626
+ print("history at end of function:", history)
627
 
628
  # Chat helper functions
629
 
 
702
 
703
  docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
704
 
 
 
 
705
  # Keep only documents with a certain score
706
  docs_len = [len(x[0].page_content) for x in docs]
707
  docs_scores = [x[1] for x in docs]
 
798
  # 3rd level check on retrieved docs with SVM retriever
799
  # Check the type of the embeddings object
800
  embeddings_type = type(embeddings)
 
801
 
802
 
 
 
 
803
  #hf_embeddings = HuggingFaceEmbeddings(**embeddings)
804
  hf_embeddings = embeddings
805
 
 
849
  # Make df of best options
850
  doc_df = create_doc_df(docs_keep_out)
851
 
 
 
 
 
852
  return docs_keep_as_doc, doc_df, docs_keep_out
853
 
854
  def get_expanded_passages(vectorstore, docs, width):
 
938
 
939
  return expanded_docs, doc_df
940
 
941
+ def highlight_found_text(chat_history: list[dict], source_texts: list[dict], hlt_chunk_size:int=hlt_chunk_size, hlt_strat:List=hlt_strat, hlt_overlap:int=hlt_overlap) -> str:
942
  """
943
+ Highlights occurrences of chat_history within source_texts.
944
 
945
  Parameters:
946
+ - chat_history (str): The text to be searched for within source_texts.
947
+ - source_texts (str): The text within which chat_history occurrences will be highlighted.
948
 
949
  Returns:
950
+ - str: A string with occurrences of chat_history highlighted.
951
 
952
  Example:
953
  >>> highlight_found_text("world", "Hello, world! This is a test. Another world awaits.")
 
961
  return text[i][0].replace(" ", " ").strip()
962
  else:
963
  return ""
964
+
965
+ print("chat_history:", chat_history)
966
+
967
+ response_text = next(
968
+ (entry['content'] for entry in reversed(chat_history) if entry.get('role') == 'assistant'),
969
+ "")
970
+
971
+ source_texts = extract_text_from_input(source_texts)
 
 
 
 
 
972
 
973
  text_splitter = RecursiveCharacterTextSplitter(
974
  chunk_size=hlt_chunk_size,
975
  separators=hlt_strat,
976
  chunk_overlap=hlt_overlap,
977
  )
978
+ sections = text_splitter.split_text(response_text)
979
 
980
  found_positions = {}
981
  for x in sections:
982
  text_start_pos = 0
983
  while text_start_pos != -1:
984
+ text_start_pos = source_texts.find(x, text_start_pos)
985
  if text_start_pos != -1:
986
  found_positions[text_start_pos] = text_start_pos + len(x)
987
  text_start_pos += 1
 
1004
  prev_end = 0
1005
  for start, end in combined_positions:
1006
  if end-start > 15: # Only combine if there is a significant amount of matched text. Avoids picking up single words like 'and' etc.
1007
+ pos_tokens.append(source_texts[prev_end:start])
1008
+ pos_tokens.append('<mark style="color:black;">' + source_texts[start:end] + '</mark>')
1009
  prev_end = end
1010
+ pos_tokens.append(source_texts[prev_end:])
1011
+
1012
+ out_pos_tokens = "".join(pos_tokens)
1013
 
1014
+ return out_pos_tokens
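A usage sketch: the last assistant reply in the chat history is located and highlighted inside the retrieved source text (variable names are placeholders).

highlighted_sources_html = highlight_found_text(chat_history, source_texts)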
1015
 
1016
 
1017
  # # Chat history functions
1018
 
1019
  def clear_chat(chat_history_state, sources, chat_message, current_topic):
1020
+ chat_history_state = None
1021
  sources = ''
1022
+ chat_message = None
1023
  current_topic = ''
1024
 
1025
  return chat_history_state, sources, chat_message, current_topic
 
1110
  for word in tokens_without_sw:
1111
  if word not in ordered_tokens:
1112
  ordered_tokens.add(word)
1113
+ result.append(word)
 
1114
 
1115
 
1116
  new_question_keywords = ' '.join(result)
 
1119
  def remove_q_ner_extractor(question):
1120
 
1121
  predict_out = ner_model.predict(question)
 
 
 
1122
  predict_tokens = [' '.join(v for k, v in d.items() if k == 'span') for d in predict_out]
1123
 
1124
  # Remove duplicate words while preserving order
 
1170
  return keywords_list
1171
 
1172
  # Gradio functions
1173
+ def turn_off_interactivity():
1174
+ return gr.Textbox(interactive=False), gr.Button(interactive=False)
1175
 
1176
  def restore_interactivity():
1177
+ return gr.Textbox(interactive=True), gr.Button(interactive=True)
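These two helpers are presumably chained around the model call in the Gradio event wiring; a minimal sketch with placeholder component names (the real app would insert the streaming step between the two):

with gr.Blocks() as demo:
    msg = gr.Textbox()
    send = gr.Button("Send")
    # Lock the inputs, (model call would go here), then unlock them again
    send.click(turn_off_interactivity, outputs=[msg, send]).then(
        restore_interactivity, outputs=[msg, send])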
1178
 
1179
  def update_message(dropdown_value):
1180
  return gr.Textbox(value=dropdown_value)
chatfuncs/config.py ADDED
@@ -0,0 +1,250 @@
 
 
1
+ import os
2
+ import tempfile
3
+ import socket
4
+ import logging
5
+ from datetime import datetime
6
+ from dotenv import load_dotenv
7
+
8
+ today_rev = datetime.now().strftime("%Y%m%d")
9
+ HOST_NAME = socket.gethostname()
10
+
11
+ # Set or retrieve configuration variables for the redaction app
12
+
13
+ def get_or_create_env_var(var_name:str, default_value:str, print_val:bool=False):
14
+ '''
15
+ Get an environmental variable, and set it to a default value if it doesn't exist
16
+ '''
17
+ # Get the environment variable if it exists
18
+ value = os.environ.get(var_name)
19
+
20
+ # If it doesn't exist, set the environment variable to the default value
21
+ if value is None:
22
+ os.environ[var_name] = default_value
23
+ value = default_value
24
+
25
+ if print_val == True:
26
+ print(f'The value of {var_name} is {value}')
27
+
28
+ return value
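A quick usage sketch (the variable name is just an example):

EXAMPLE_SETTING = get_or_create_env_var('EXAMPLE_SETTING', 'default_value', print_val=True)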
29
+
30
+ def ensure_folder_exists(output_folder:str):
31
+ """Checks if the specified folder exists, creates it if not."""
32
+
33
+ if not os.path.exists(output_folder):
34
+ # Create the folder if it doesn't exist
35
+ os.makedirs(output_folder, exist_ok=True)
36
+ print(f"Created the {output_folder} folder.")
37
+ else:
38
+ print(f"The {output_folder} folder already exists.")
39
+
40
+ def add_folder_to_path(folder_path: str):
41
+ '''
42
+ Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. Function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
43
+ '''
44
+
45
+ if os.path.exists(folder_path) and os.path.isdir(folder_path):
46
+ print(folder_path, "folder exists.")
47
+
48
+ # Resolve relative path to absolute path
49
+ absolute_path = os.path.abspath(folder_path)
50
+
51
+ current_path = os.environ['PATH']
52
+ if absolute_path not in current_path.split(os.pathsep):
53
+ full_path_extension = absolute_path + os.pathsep + current_path
54
+ os.environ['PATH'] = full_path_extension
55
+ #print(f"Updated PATH with: ", full_path_extension)
56
+ else:
57
+ print(f"Directory {folder_path} already exists in PATH.")
58
+ else:
59
+ print(f"Folder not found at {folder_path} - not added to PATH")
60
+
61
+ ensure_folder_exists("config/")
62
+
63
+ # If you have an aws_config env file in the config folder, you can load in app variables this way, e.g. 'config/app_config.env'
64
+ APP_CONFIG_PATH = get_or_create_env_var('APP_CONFIG_PATH', 'config/app_config.env') # e.g. config/app_config.env
65
+
66
+ if APP_CONFIG_PATH:
67
+ if os.path.exists(APP_CONFIG_PATH):
68
+ print(f"Loading app variables from config file {APP_CONFIG_PATH}")
69
+ load_dotenv(APP_CONFIG_PATH)
70
+ else: print("App config file not found at location:", APP_CONFIG_PATH)
71
+
72
+ # Report logging to console?
73
+ LOGGING = get_or_create_env_var('LOGGING', 'False')
74
+
75
+ if LOGGING == 'True':
76
+ # Configure logging
77
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
78
+
79
+ ###
80
+ # AWS CONFIG
81
+ ###
82
+
83
+ # If you have an aws_config env file in the config folder, you can load in AWS keys this way, e.g. 'env/aws_config.env'
84
+ AWS_CONFIG_PATH = get_or_create_env_var('AWS_CONFIG_PATH', '') # e.g. config/aws_config.env
85
+
86
+ if AWS_CONFIG_PATH:
87
+ if os.path.exists(AWS_CONFIG_PATH):
88
+ print(f"Loading AWS variables from config file {AWS_CONFIG_PATH}")
89
+ load_dotenv(AWS_CONFIG_PATH)
90
+ else: print("AWS config file not found at location:", AWS_CONFIG_PATH)
91
+
92
+ RUN_AWS_FUNCTIONS = get_or_create_env_var("RUN_AWS_FUNCTIONS", "0")
93
+
94
+ AWS_REGION = get_or_create_env_var('AWS_REGION', '')
95
+
96
+ AWS_DEFAULT_REGION = get_or_create_env_var('AWS_DEFAULT_REGION', '')
97
+
98
+ AWS_CLIENT_ID = get_or_create_env_var('AWS_CLIENT_ID', '')
99
+
100
+ AWS_CLIENT_SECRET = get_or_create_env_var('AWS_CLIENT_SECRET', '')
101
+
102
+ AWS_USER_POOL_ID = get_or_create_env_var('AWS_USER_POOL_ID', '')
103
+
104
+ AWS_ACCESS_KEY = get_or_create_env_var('AWS_ACCESS_KEY', '')
105
+ if AWS_ACCESS_KEY: print(f'AWS_ACCESS_KEY found in environment variables')
106
+
107
+ AWS_SECRET_KEY = get_or_create_env_var('AWS_SECRET_KEY', '')
108
+ if AWS_SECRET_KEY: print(f'AWS_SECRET_KEY found in environment variables')
109
+
110
+ DOCUMENT_REDACTION_BUCKET = get_or_create_env_var('DOCUMENT_REDACTION_BUCKET', '')
111
+
112
+ # Custom headers e.g. if routing traffic through Cloudfront
113
+ # Retrieving or setting CUSTOM_HEADER
114
+ CUSTOM_HEADER = get_or_create_env_var('CUSTOM_HEADER', '')
115
+ #if CUSTOM_HEADER: print(f'CUSTOM_HEADER found')
116
+
117
+ # Retrieving or setting CUSTOM_HEADER_VALUE
118
+ CUSTOM_HEADER_VALUE = get_or_create_env_var('CUSTOM_HEADER_VALUE', '')
119
+ #if CUSTOM_HEADER_VALUE: print(f'CUSTOM_HEADER_VALUE found')
120
+
121
+ ###
122
+ # File I/O config
123
+ ###
124
+ SESSION_OUTPUT_FOLDER = get_or_create_env_var('SESSION_OUTPUT_FOLDER', 'False') # i.e. do you want your input and output folders saved within a subfolder based on session hash value within output/input folders
125
+
126
+ OUTPUT_FOLDER = get_or_create_env_var('GRADIO_OUTPUT_FOLDER', 'output/') # 'output/'
127
+ INPUT_FOLDER = get_or_create_env_var('GRADIO_INPUT_FOLDER', 'input/') # 'input/'
128
+
129
+ ensure_folder_exists(OUTPUT_FOLDER)
130
+ ensure_folder_exists(INPUT_FOLDER)
131
+
132
+ # Allow for files to be saved in a temporary folder for increased security in some instances
133
+ if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
134
+ # Create a temporary directory
135
+ with tempfile.TemporaryDirectory() as temp_dir:
136
+ print(f'Temporary directory created at: {temp_dir}')
137
+
138
+ if OUTPUT_FOLDER == "TEMP": OUTPUT_FOLDER = temp_dir + "/"
139
+ if INPUT_FOLDER == "TEMP": INPUT_FOLDER = temp_dir + "/"
140
+
141
+ # By default, logs are put into a subfolder of today's date and the host name of the instance running the app. This is to avoid at all possible the possibility of log files from one instance overwriting the logs of another instance on S3. If running the app on one system always, or just locally, it is not necessary to make the log folders so specific.
142
+ # Another way to address this issue would be to write logs to another type of storage, e.g. database such as dynamodb. I may look into this in future.
143
+
144
+ USE_LOG_SUBFOLDERS = get_or_create_env_var('USE_LOG_SUBFOLDERS', 'True')
145
+
146
+ if USE_LOG_SUBFOLDERS == "True":
147
+ day_log_subfolder = today_rev + '/'
148
+ host_name_subfolder = HOST_NAME + '/'
149
+ full_log_subfolder = day_log_subfolder + host_name_subfolder
150
+ else:
151
+ full_log_subfolder = ""
152
+
153
+ FEEDBACK_LOGS_FOLDER = get_or_create_env_var('FEEDBACK_LOGS_FOLDER', 'feedback/' + full_log_subfolder)
154
+ ACCESS_LOGS_FOLDER = get_or_create_env_var('ACCESS_LOGS_FOLDER', 'logs/' + full_log_subfolder)
155
+ USAGE_LOGS_FOLDER = get_or_create_env_var('USAGE_LOGS_FOLDER', 'usage/' + full_log_subfolder)
156
+
157
+ ensure_folder_exists(FEEDBACK_LOGS_FOLDER)
158
+ ensure_folder_exists(ACCESS_LOGS_FOLDER)
159
+ ensure_folder_exists(USAGE_LOGS_FOLDER)
160
+
161
+ # Should the redacted file name be included in the logs? In some instances, the names of the files themselves could be sensitive, and should not be disclosed beyond the app. So, by default this is false.
162
+ DISPLAY_FILE_NAMES_IN_LOGS = get_or_create_env_var('DISPLAY_FILE_NAMES_IN_LOGS', 'False')
163
+
164
+ ###
165
+ # RUN CONFIG
166
+ GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
167
+
168
+ HF_TOKEN = get_or_create_env_var('HF_TOKEN', '')
169
+
170
+
171
+ # Number of pages to loop through before breaking the function and restarting from the last finished page (not currently activated).
172
+ PAGE_BREAK_VALUE = get_or_create_env_var('PAGE_BREAK_VALUE', '99999')
173
+
174
+ MAX_TIME_VALUE = get_or_create_env_var('MAX_TIME_VALUE', '999999')
175
+
176
+ ###
177
+ # APP RUN CONFIG
178
+ ###
179
+
180
+ SMALL_MODEL_NAME = get_or_create_env_var("SMALL_MODEL_NAME", "Gemma 3 1B (small, fast)") # "Qwen 2 0.5B (small, fast)"
181
+
182
+ SMALL_MODEL_REPO_ID = get_or_create_env_var("SMALL_MODEL_REPO_ID", 'google/gemma-3-1b-it') #'Qwen/Qwen2-0.5B-Instruct')
183
+
184
+ LARGE_MODEL_NAME = get_or_create_env_var("LARGE_MODEL_NAME", "Phi 3.5 Mini (larger, slow)")
185
+
186
+ LARGE_MODEL_REPO_ID = get_or_create_env_var("LARGE_MODEL_REPO_ID", "QuantFactory/Phi-3.5-mini-instruct-GGUF") # "QuantFactory/Phi-3-mini-128k-instruct-GGUF"), # "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF-v2"), #"microsoft/Phi-3-mini-4k-instruct-gguf"),#"TheBloke/Mistral-7B-OpenOrca-GGUF"),
187
+ LARGE_MODEL_GGUF_FILE = get_or_create_env_var("LARGE_MODEL_GGUF_FILE", "Phi-3.5-mini-instruct.Q4_K_M.gguf") #"Phi-3-mini-128k-instruct.Q4_K_M.gguf") #"Meta-Llama-3-8B-Instruct-v2.Q6_K.gguf") #"Phi-3-mini-4k-instruct-q4.gguf")#"mistral-7b-openorca.Q4_K_M.gguf"),
188
+
189
+ if RUN_AWS_FUNCTIONS == "1":
190
+ default_model_choices = f'["{SMALL_MODEL_NAME}", "{LARGE_MODEL_NAME}", "gemini-2.0-flash-001", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"]'
191
+ else:
192
+ default_model_choices = f'["{SMALL_MODEL_NAME}", "{LARGE_MODEL_NAME}", "gemini-2.0-flash-001", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25"]'
193
+
194
+ DEFAULT_MODEL_CHOICES = get_or_create_env_var("DEFAULT_MODEL_CHOICES", default_model_choices)
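Because these are plain environment variables, they can be overridden before this module is imported; a minimal sketch (values are examples only):

import os
os.environ["SMALL_MODEL_NAME"] = "Gemma 3 1B (small, fast)"
os.environ["DEFAULT_MODEL_CHOICES"] = '["Gemma 3 1B (small, fast)", "gemini-2.0-flash-001"]'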
195
+
196
+ EMBEDDINGS_MODEL_NAME = get_or_create_env_var('EMBEDDINGS_MODEL_NAME', "BAAI/bge-base-en-v1.5") #"mixedbread-ai/mxbai-embed-xsmall-v1"
197
+
198
+ DEFAULT_EMBEDDINGS_LOCATION = get_or_create_env_var('DEFAULT_EMBEDDINGS_LOCATION', "faiss_embedding")
199
+
200
+ DEFAULT_DATA_SOURCE_NAME = get_or_create_env_var('DEFAULT_DATA_SOURCE_NAME', "Document redaction app documentation")
201
+
202
+ DEFAULT_DATA_SOURCE = get_or_create_env_var('DEFAULT_DATA_SOURCE', "https://seanpedrick-case.github.io/doc_redaction/README.html")
203
+
204
+ DEFAULT_EXAMPLES = get_or_create_env_var('DEFAULT_EXAMPLES', '[ "How can I make a custom deny list?", "How can I find page duplicates?", "How can I review and modify existing redactions?", "How can I export my review files to Adobe?"]')
205
+ #
206
+ # ') # ["What were the five pillars of the previous borough plan?",
207
+ #"What is the vision statement for Lambeth?",
208
+ #"What are the commitments for Lambeth?",
209
+ #"What are the 2030 outcomes for Lambeth?"]
210
+
211
+ # Get some environment variables and Launch the Gradio app
212
+ COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
213
+
214
+ RUN_DIRECT_MODE = get_or_create_env_var('RUN_DIRECT_MODE', '0')
215
+
216
+ MAX_QUEUE_SIZE = int(get_or_create_env_var('MAX_QUEUE_SIZE', '5'))
217
+
218
+ MAX_FILE_SIZE = get_or_create_env_var('MAX_FILE_SIZE', '250mb')
219
+
220
+ GRADIO_SERVER_PORT = int(get_or_create_env_var('GRADIO_SERVER_PORT', '7860'))
221
+
222
+ ROOT_PATH = get_or_create_env_var('ROOT_PATH', '')
223
+
224
+ DEFAULT_CONCURRENCY_LIMIT = get_or_create_env_var('DEFAULT_CONCURRENCY_LIMIT', '3')
225
+
226
+ GET_DEFAULT_ALLOW_LIST = get_or_create_env_var('GET_DEFAULT_ALLOW_LIST', 'False')
227
+
228
+ ALLOW_LIST_PATH = get_or_create_env_var('ALLOW_LIST_PATH', '') # config/default_allow_list.csv
229
+
230
+ S3_ALLOW_LIST_PATH = get_or_create_env_var('S3_ALLOW_LIST_PATH', '') # default_allow_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
231
+
232
+ if ALLOW_LIST_PATH: OUTPUT_ALLOW_LIST_PATH = ALLOW_LIST_PATH
233
+ else: OUTPUT_ALLOW_LIST_PATH = 'config/default_allow_list.csv'
234
+
235
+ SHOW_COSTS = get_or_create_env_var('SHOW_COSTS', 'False')
236
+
237
+ GET_COST_CODES = get_or_create_env_var('GET_COST_CODES', 'False')
238
+
239
+ DEFAULT_COST_CODE = get_or_create_env_var('DEFAULT_COST_CODE', '')
240
+
241
+ COST_CODES_PATH = get_or_create_env_var('COST_CODES_PATH', '') # 'config/COST_CENTRES.csv' # file should be a csv file with a single table in it that has two columns with a header. First column should contain cost codes, second column should contain a name or description for the cost code
242
+
243
+ S3_COST_CODES_PATH = get_or_create_env_var('S3_COST_CODES_PATH', '') # COST_CENTRES.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
244
+
245
+ if COST_CODES_PATH: OUTPUT_COST_CODES_PATH = COST_CODES_PATH
246
+ else: OUTPUT_COST_CODES_PATH = 'config/COST_CENTRES.csv'
247
+
248
+ ENFORCE_COST_CODES = get_or_create_env_var('ENFORCE_COST_CODES', 'False') # If you have cost codes listed, is it compulsory to choose one before redacting?
249
+
250
+ if ENFORCE_COST_CODES == 'True': GET_COST_CODES = 'True'
chatfuncs/helper_functions.py CHANGED
@@ -1,6 +1,9 @@
1
  import os
2
  import gradio as gr
3
  import pandas as pd
 
 
 
4
 
5
  def get_or_create_env_var(var_name, default_value):
6
  # Get the environment variable if it exists
@@ -13,12 +16,6 @@ def get_or_create_env_var(var_name, default_value):
13
 
14
  return value
15
 
16
- # Retrieving or setting output folder
17
- env_var_name = 'GRADIO_OUTPUT_FOLDER'
18
- default_value = 'output/'
19
-
20
- output_folder = get_or_create_env_var(env_var_name, default_value)
21
- print(f'The value of {env_var_name} is {output_folder}')
22
 
23
  def get_file_path_with_extension(file_path):
24
  # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
@@ -165,64 +162,129 @@ def wipe_logs(feedback_logs_loc, usage_logs_loc):
165
 
166
 
167
 
168
- async def get_connection_params(request: gr.Request):
169
- base_folder = ""
170
-
171
- if request:
172
- #print("request user:", request.username)
173
-
174
- #request_data = await request.json() # Parse JSON body
175
- #print("All request data:", request_data)
176
- #context_value = request_data.get('context')
177
- #if 'context' in request_data:
178
- # print("Request context dictionary:", request_data['context'])
179
-
180
- # print("Request headers dictionary:", request.headers)
181
- # print("All host elements", request.client)
182
- # print("IP address:", request.client.host)
183
- # print("Query parameters:", dict(request.query_params))
184
- # To get the underlying FastAPI items you would need to use await and some fancy @ stuff for a live query: https://fastapi.tiangolo.com/vi/reference/request/
185
- #print("Request dictionary to object:", request.request.body())
186
- print("Session hash:", request.session_hash)
187
-
188
- # Retrieving or setting CUSTOM_CLOUDFRONT_HEADER
189
- CUSTOM_CLOUDFRONT_HEADER_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER', '')
190
- #print(f'The value of CUSTOM_CLOUDFRONT_HEADER is {CUSTOM_CLOUDFRONT_HEADER_var}')
191
-
192
- # Retrieving or setting CUSTOM_CLOUDFRONT_HEADER_VALUE
193
- CUSTOM_CLOUDFRONT_HEADER_VALUE_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER_VALUE', '')
194
- #print(f'The value of CUSTOM_CLOUDFRONT_HEADER_VALUE_var is {CUSTOM_CLOUDFRONT_HEADER_VALUE_var}')
195
-
196
- if CUSTOM_CLOUDFRONT_HEADER_var and CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
197
- if CUSTOM_CLOUDFRONT_HEADER_var in request.headers:
198
- supplied_cloudfront_custom_value = request.headers[CUSTOM_CLOUDFRONT_HEADER_var]
199
- if supplied_cloudfront_custom_value == CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
200
- print("Custom Cloudfront header found:", supplied_cloudfront_custom_value)
 
 
201
  else:
202
- raise(ValueError, "Custom Cloudfront header value does not match expected value.")
 
 
203
 
204
- # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
205
-
206
- if request.username:
207
- out_session_hash = request.username
208
- base_folder = "user-files/"
209
- print("Request username found:", out_session_hash)
210
 
211
- elif 'x-cognito-id' in request.headers:
212
- out_session_hash = request.headers['x-cognito-id']
213
- base_folder = "user-files/"
214
- print("Cognito ID found:", out_session_hash)
215
 
216
- else:
217
- out_session_hash = request.session_hash
218
- base_folder = "temp-files/"
219
- # print("Cognito ID not found. Using session hash as save folder:", out_session_hash)
220
 
221
- output_folder = base_folder + out_session_hash + "/"
222
- #if bucket_name:
223
- # print("S3 output folder is: " + "s3://" + bucket_name + "/" + output_folder)
224
 
225
- return out_session_hash, output_folder, out_session_hash
226
- else:
227
- print("No session parameters found.")
228
- return "",""
 
1
  import os
2
  import gradio as gr
3
  import pandas as pd
4
+ import boto3
5
+ from botocore.exceptions import ClientError
6
+ from chatfuncs.config import CUSTOM_HEADER_VALUE, CUSTOM_HEADER, OUTPUT_FOLDER, INPUT_FOLDER, SESSION_OUTPUT_FOLDER, AWS_USER_POOL_ID
7
 
8
  def get_or_create_env_var(var_name, default_value):
9
  # Get the environment variable if it exists
 
16
 
17
  return value
18
 
 
 
 
 
 
 
19
 
20
  def get_file_path_with_extension(file_path):
21
  # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
 
162
 
163
 
164
 
165
+ # async def get_connection_params(request: gr.Request):
166
+ # base_folder = ""
167
+
168
+ # if request:
169
+ # #print("request user:", request.username)
170
+
171
+ # #request_data = await request.json() # Parse JSON body
172
+ # #print("All request data:", request_data)
173
+ # #context_value = request_data.get('context')
174
+ # #if 'context' in request_data:
175
+ # # print("Request context dictionary:", request_data['context'])
176
+
177
+ # # print("Request headers dictionary:", request.headers)
178
+ # # print("All host elements", request.client)
179
+ # # print("IP address:", request.client.host)
180
+ # # print("Query parameters:", dict(request.query_params))
181
+ # # To get the underlying FastAPI items you would need to use await and some fancy @ stuff for a live query: https://fastapi.tiangolo.com/vi/reference/request/
182
+ # #print("Request dictionary to object:", request.request.body())
183
+ # print("Session hash:", request.session_hash)
184
+
185
+ # # Retrieving or setting CUSTOM_CLOUDFRONT_HEADER
186
+ # CUSTOM_CLOUDFRONT_HEADER_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER', '')
187
+ # #print(f'The value of CUSTOM_CLOUDFRONT_HEADER is {CUSTOM_CLOUDFRONT_HEADER_var}')
188
+
189
+ # # Retrieving or setting CUSTOM_CLOUDFRONT_HEADER_VALUE
190
+ # CUSTOM_CLOUDFRONT_HEADER_VALUE_var = get_or_create_env_var('CUSTOM_CLOUDFRONT_HEADER_VALUE', '')
191
+ # #print(f'The value of CUSTOM_CLOUDFRONT_HEADER_VALUE_var is {CUSTOM_CLOUDFRONT_HEADER_VALUE_var}')
192
+
193
+ # if CUSTOM_CLOUDFRONT_HEADER_var and CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
194
+ # if CUSTOM_CLOUDFRONT_HEADER_var in request.headers:
195
+ # supplied_cloudfront_custom_value = request.headers[CUSTOM_CLOUDFRONT_HEADER_var]
196
+ # if supplied_cloudfront_custom_value == CUSTOM_CLOUDFRONT_HEADER_VALUE_var:
197
+ # print("Custom Cloudfront header found:", supplied_cloudfront_custom_value)
198
+ # else:
199
+ # raise(ValueError, "Custom Cloudfront header value does not match expected value.")
200
+
201
+ # # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
202
+
203
+ # if request.username:
204
+ # out_session_hash = request.username
205
+ # base_folder = "user-files/"
206
+ # print("Request username found:", out_session_hash)
207
+
208
+ # elif 'x-cognito-id' in request.headers:
209
+ # out_session_hash = request.headers['x-cognito-id']
210
+ # base_folder = "user-files/"
211
+ # print("Cognito ID found:", out_session_hash)
212
+
213
+ # else:
214
+ # out_session_hash = request.session_hash
215
+ # base_folder = "temp-files/"
216
+ # # print("Cognito ID not found. Using session hash as save folder:", out_session_hash)
217
+
218
+ # output_folder = base_folder + out_session_hash + "/"
219
+ # #if bucket_name:
220
+ # # print("S3 output folder is: " + "s3://" + bucket_name + "/" + output_folder)
221
+
222
+ # return out_session_hash, output_folder, out_session_hash
223
+ # else:
224
+ # print("No session parameters found.")
225
+ # return "",""
226
+
227
+ async def get_connection_params(request: gr.Request,
228
+ output_folder_textbox:str=OUTPUT_FOLDER,
229
+ input_folder_textbox:str=INPUT_FOLDER,
230
+ session_output_folder:str=SESSION_OUTPUT_FOLDER):
231
+
232
+ #print("Session hash:", request.session_hash)
233
+
234
+ if CUSTOM_HEADER and CUSTOM_HEADER_VALUE:
235
+ if CUSTOM_HEADER in request.headers:
236
+ supplied_custom_header_value = request.headers[CUSTOM_HEADER]
237
+ if supplied_custom_header_value == CUSTOM_HEADER_VALUE:
238
+ print("Custom header supplied and matches CUSTOM_HEADER_VALUE")
239
  else:
240
+ print("Custom header value does not match expected value.")
241
+ raise ValueError("Custom header value does not match expected value.")
242
+ else:
243
+ print("Custom header value not found.")
244
+ raise ValueError("Custom header value not found.")
245
+
246
+ # Get output save folder from 1 - username passed in from direct Cognito login, 2 - Cognito ID header passed through a Lambda authenticator, 3 - the session hash.
247
+
248
+ if request.username:
249
+ out_session_hash = request.username
250
+ #print("Request username found:", out_session_hash)
251
+
252
+ elif 'x-cognito-id' in request.headers:
253
+ out_session_hash = request.headers['x-cognito-id']
254
+ #print("Cognito ID found:", out_session_hash)
255
+
256
+ elif 'x-amzn-oidc-identity' in request.headers:
257
+ out_session_hash = request.headers['x-amzn-oidc-identity']
258
+
259
+ # Fetch email address using Cognito client
260
+ cognito_client = boto3.client('cognito-idp')
261
+ try:
262
+ response = cognito_client.admin_get_user(
263
+ UserPoolId=AWS_USER_POOL_ID, # Replace with your User Pool ID
264
+ Username=out_session_hash
265
+ )
266
+ email = next(attr['Value'] for attr in response['UserAttributes'] if attr['Name'] == 'email')
267
+ #print("Email address found:", email)
268
+
269
+ out_session_hash = email
270
+ except ClientError as e:
271
+ print("Error fetching user details:", e)
272
+ email = None
273
+
274
+ print("Cognito ID found:", out_session_hash)
275
 
276
+ else:
277
+ out_session_hash = request.session_hash
 
 
 
 
278
 
279
+ if session_output_folder == 'True':
280
+ output_folder = output_folder_textbox + out_session_hash + "/"
281
+ input_folder = input_folder_textbox + out_session_hash + "/"
 
282
 
283
+ else:
284
+ output_folder = output_folder_textbox
285
+ input_folder = input_folder_textbox
 
286
 
287
+ if not os.path.exists(output_folder): os.mkdir(output_folder)
288
+ if not os.path.exists(input_folder): os.mkdir(input_folder)
 
289
 
290
+ return out_session_hash, output_folder, out_session_hash, input_folder
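A hedged wiring sketch (component names are assumptions): Gradio injects the gr.Request argument automatically, so only the folder inputs and outputs need to be wired, for example on app load.

# app.load(get_connection_params,
#          inputs=[output_folder_textbox, input_folder_textbox, session_output_folder_state],
#          outputs=[session_hash_state, output_folder_state, session_hash_textbox, input_folder_state])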
 
 
 
chatfuncs/ingest.py CHANGED
@@ -7,13 +7,14 @@ import requests
  import pandas as pd
  import dateutil.parser
  from typing import Type, List
- import shutil
+ #import shutil

- from langchain_community.embeddings import HuggingFaceEmbeddings # HuggingFaceInstructEmbeddings,
+ #from langchain_community.embeddings import HuggingFaceEmbeddings # HuggingFaceInstructEmbeddings,
  from langchain_community.vectorstores.faiss import FAISS
  #from langchain_community.vectorstores import Chroma
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.docstore.document import Document
+ #from chatfuncs.config import EMBEDDINGS_MODEL_NAME

  from bs4 import BeautifulSoup
  from docx import Document as Doc
@@ -557,31 +558,24 @@ def docs_elements_from_csv_save(docs_path="documents.csv"):

  # ## Create embeddings and save faiss vector store to the path specified in `save_to`

- def load_embeddings(model_name = "BAAI/bge-base-en-v1.5"):
-
-     #if model_name == "hkunlp/instructor-large":
-     #    embeddings_func = HuggingFaceInstructEmbeddings(model_name=model_name,
-     #        embed_instruction="Represent the paragraph for retrieval: ",
-     #        query_instruction="Represent the question for retrieving supporting documents: "
-     #    )
-
-     #else:
-     embeddings_func = HuggingFaceEmbeddings(model_name=model_name)
-
-     global embeddings
-
-     embeddings = embeddings_func
-
-     return embeddings_func
-
- def embed_faiss_save_to_zip(docs_out, save_to="output", model_name="BAAI/bge-base-en-v1.5"):
-     load_embeddings(model_name=model_name)
+ # def load_embeddings_model(embeddings_model = EMBEDDINGS_MODEL_NAME):

+ #     embeddings_func = HuggingFaceEmbeddings(model_name=embeddings_model)

+ #     #global embeddings

+ #     #embeddings = embeddings_func

+ #     return embeddings_func

+ def embed_faiss_save_to_zip(docs_out, save_folder, embeddings_model_object, save_to="faiss_embeddings", model_name="BAAI/bge-base-en-v1.5"):
+     #load_embeddings(model_name=model_name)

      print(f"> Total split documents: {len(docs_out)}")

-     vectorstore = FAISS.from_documents(documents=docs_out, embedding=embeddings)
+     vectorstore = FAISS.from_documents(documents=docs_out, embedding=embeddings_model_object)

-     save_to_path = Path(save_to)
+     save_to_path = Path(save_folder, save_to)
      save_to_path.mkdir(parents=True, exist_ok=True)

      vectorstore.save_local(folder_path=str(save_to_path))
@@ -619,20 +613,20 @@ def embed_faiss_save_to_zip(docs_out, save_to="output", model_name="BAAI/bge-bas

- def sim_search_local_saved_vec(query, k_val, save_to="faiss_lambeth_census_embedding"):
-
-     load_embeddings()
-
-     docsearch = FAISS.load_local(folder_path=save_to, embeddings=embeddings)
-
-     display(Markdown(question))
-
-     search = docsearch.similarity_search_with_score(query, k=k_val)
-
-     for item in search:
-         print(item[0].page_content)
-         print(f"Page: {item[0].metadata['source']}")
-         print(f"Date: {item[0].metadata['date']}")
-         print(f"Score: {item[1]}")
-         print("---")
+ # def sim_search_local_saved_vec(query, k_val, save_to="faiss_lambeth_census_embedding"):

+ #     load_embeddings()

+ #     docsearch = FAISS.load_local(folder_path=save_to, embeddings=embeddings)

+ #     display(Markdown(question))

+ #     search = docsearch.similarity_search_with_score(query, k=k_val)

+ #     for item in search:
+ #         print(item[0].page_content)
+ #         print(f"Page: {item[0].metadata['source']}")
+ #         print(f"Date: {item[0].metadata['date']}")
+ #         print(f"Score: {item[1]}")
+ #         print("---")
chatfuncs/model_load.py ADDED
@@ -0,0 +1,82 @@
+ import torch
+
+ # Currently set gpu_layers to 0 even with cuda due to persistent bugs in implementation with cuda
+ if torch.cuda.is_available():
+     torch_device = "cuda"
+     gpu_layers = 100
+ else:
+     torch_device = "cpu"
+     gpu_layers = 0
+
+ print("Running on device:", torch_device)
+ threads = 8 #torch.get_num_threads()
+ print("CPU threads:", threads)
+
+ # Qwen 2 0.5B (small, fast) Model parameters
+ temperature: float = 0.1
+ top_k: int = 3
+ top_p: float = 1
+ repetition_penalty: float = 1.15
+ flan_alpaca_repetition_penalty: float = 1.3
+ last_n_tokens: int = 64
+ max_new_tokens: int = 1024
+ seed: int = 42
+ reset: bool = False
+ stream: bool = True
+ threads: int = threads
+ batch_size:int = 256
+ context_length:int = 2048
+ sample = True
+
+ # Bedrock parameters
+ max_tokens = 4096
+
+
+ class CtransInitConfig_gpu:
+     def __init__(self,
+                  last_n_tokens=last_n_tokens,
+                  seed=seed,
+                  n_threads=threads,
+                  n_batch=batch_size,
+                  n_ctx=max_tokens,
+                  n_gpu_layers=gpu_layers):
+
+         self.last_n_tokens = last_n_tokens
+         self.seed = seed
+         self.n_threads = n_threads
+         self.n_batch = n_batch
+         self.n_ctx = n_ctx
+         self.n_gpu_layers = n_gpu_layers
+         # self.stop: list[str] = field(default_factory=lambda: [stop_string])
+
+     def update_gpu(self, new_value):
+         self.n_gpu_layers = new_value
+
+ class CtransInitConfig_cpu(CtransInitConfig_gpu):
+     def __init__(self):
+         super().__init__()
+         self.n_gpu_layers = 0
+
+ gpu_config = CtransInitConfig_gpu()
+ cpu_config = CtransInitConfig_cpu()
+
+
+ class CtransGenGenerationConfig:
+     def __init__(self, temperature=temperature,
+                  top_k=top_k,
+                  top_p=top_p,
+                  repeat_penalty=repetition_penalty,
+                  seed=seed,
+                  stream=stream,
+                  max_tokens=max_new_tokens
+                  ):
+         self.temperature = temperature
+         self.top_k = top_k
+         self.top_p = top_p
+         self.repeat_penalty = repeat_penalty
+         self.seed = seed
+         self.max_tokens=max_tokens
+         self.stream = stream
+
+     def update_temp(self, new_value):
+         self.temperature = new_value
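
A short, assumed sketch of how these config objects might drive llama-cpp-python. The GGUF repo id and filename are placeholders; the actual model choices live in app.py, outside this hunk.

    from huggingface_hub import hf_hub_download
    from llama_cpp import Llama
    from chatfuncs.model_load import gpu_config, CtransGenGenerationConfig

    model_path = hf_hub_download(repo_id="<gguf-repo-id>", filename="<model>.gguf")

    llm = Llama(
        model_path=model_path,
        n_ctx=gpu_config.n_ctx,
        n_batch=gpu_config.n_batch,
        n_threads=gpu_config.n_threads,
        n_gpu_layers=gpu_config.n_gpu_layers,
        seed=gpu_config.seed,
    )

    gen_config = CtransGenGenerationConfig()
    gen_config.update_temp(0.3)  # raise the temperature slightly for a more varied answer

    # The attribute names line up with Llama.__call__ keyword arguments, so the config
    # can be splatted straight in; with stream=True the call yields completion chunks.
    for chunk in llm("Question: What does this app do?\nAnswer:", **vars(gen_config)):
        print(chunk["choices"][0]["text"], end="")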
chatfuncs/prompts.py CHANGED
@@ -23,8 +23,7 @@ QUESTION - {question}
  """

- instruction_prompt_template_orca = """
- ### System:
+ instruction_prompt_template_orca = """### System:
  You are an AI assistant that follows instruction extremely well. Help as much as you can.
  ### User:
  Answer the QUESTION with a short response using information from the following CONTENT.
@@ -33,8 +32,7 @@ CONTENT: {summaries}

  ### Response:"""

- instruction_prompt_template_orca_quote = """
- ### System:
+ instruction_prompt_template_orca_quote = """### System:
  You are an AI assistant that follows instruction extremely well. Help as much as you can.
  ### User:
  Quote text from the CONTENT to answer the QUESTION below.
@@ -73,4 +71,9 @@ Answer the QUESTION using information from the following CONTENT. Respond with s
  CONTENT: {summaries}
  QUESTION: {question}\n
  Answer:<|im_end|>
- <|im_start|>assistant\n"""
+ <|im_start|>assistant\n"""
+
+ instruction_prompt_gemma = """Answer the QUESTION using information from the following CONTENT. Respond with short answers that directly answer the question.
+ CONTENT: {summaries}
+ QUESTION: {question}
+ assistant:"""
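
A quick illustration of filling the new Gemma template. The context string here is invented; in the app, {summaries} would normally come from the retrieved documents.

    from chatfuncs.prompts import instruction_prompt_gemma

    prompt = instruction_prompt_gemma.format(
        summaries="The library opens at 9am on weekdays and 10am on Saturdays.",
        question="When does the library open on Saturday?",
    )
    print(prompt)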
faiss_embedding/faiss_embedding.zip CHANGED
Binary files a/faiss_embedding/faiss_embedding.zip and b/faiss_embedding/faiss_embedding.zip differ
 
requirements.txt CHANGED
@@ -4,7 +4,7 @@ langchain-community==0.3.22
  beautifulsoup4==4.13.4
  google-generativeai==0.8.5
  pandas==2.2.3
- transformers==4.41.2
+ transformers==4.51.3
  # For Windows https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.2/llama_cpp_python-0.3.2-cp311-cp311-win_amd64.whl -C cmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS"
  llama-cpp-python==0.3.2 --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu
  #-C cmake.args="-DGGML_BLAS=ON;-DGGML_BLAS_VENDOR=OpenBLAS"
@@ -22,3 +22,4 @@ PyStemmer==2.2.0.3
  scipy==1.15.2
  numpy==1.26.4
  boto3==1.38.0
+ python-dotenv==1.1.0
requirements_gpu.txt CHANGED
@@ -20,4 +20,5 @@ bm25s==0.2.12
  PyStemmer==2.2.0.3
  scipy==1.15.2
  numpy==1.26.4
- boto3==1.38.0
+ boto3==1.38.0
+ python-dotenv==1.1.0
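
python-dotenv is new in both requirements files. A minimal, assumed usage pattern follows; the actual chatfuncs/config.py is not part of this diff, so the variable name below is only an example.

    import os
    from dotenv import load_dotenv

    load_dotenv()  # read key=value pairs from a local .env file into the environment
    EMBEDDINGS_MODEL_NAME = os.environ.get("EMBEDDINGS_MODEL_NAME", "BAAI/bge-base-en-v1.5")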