Futuresony committed (verified)
Commit 94b5e7d · 1 Parent(s): 9497f27

Update app.py

Files changed (1)
  1. app.py +648 -35
app.py CHANGED
@@ -1,52 +1,665 @@
 
 
  import gradio as gr
- from huggingface_hub import InferenceClient
  import os

- # 🔹 Hugging Face Credentials
- HF_REPO = "Futuresony/future_ai_12_10_2024.gguf"
- HF_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
-
- client = InferenceClient(HF_REPO, token=HF_TOKEN)

- def format_alpaca_prompt(user_input, system_prompt, history):
-     """Formats input in Alpaca/LLaMA style"""
-     history_str = "\n".join([f"### Instruction:\n{h[0]}\n### Response:\n{h[1]}" for h in history])
-     prompt = f"""{system_prompt}
- {history_str}
-
- ### Instruction:
- {user_input}
-
- ### Response:
  """
-     return prompt

- def respond(message, history, system_message, max_tokens, temperature, top_p):
-     formatted_prompt = format_alpaca_prompt(message, system_message, history)

-     response = client.text_generation(
-         formatted_prompt,
-         max_new_tokens=max_tokens,
-         temperature=temperature,
-         top_p=top_p,
      )

-     # ✅ Extract only the response
-     cleaned_response = response.split("### Response:")[-1].strip()

-     history.append((message, cleaned_response))  # Update history with the new message and response

-     yield cleaned_response  # Output only the answer

- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=250, value=128, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.9, step=0.1, label="Temperature"),
-         gr.Slider(minimum=0.1, maximum=1.0, value=0.99, step=0.01, label="Top-p (nucleus sampling)"),
-     ],
- )

  if __name__ == "__main__":
      demo.launch()
 
+ # NOTE: swap in my own model here rather than the Llama model used below
+
  import gradio as gr
  import os
+ import PyPDF2
+ import logging
+ import torch
+ import threading
+ import time
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+     StoppingCriteria,
+     StoppingCriteriaList,
+ )
+ from transformers import logging as hf_logging
+ import spaces
+ from llama_index.core import (
+     StorageContext,
+     VectorStoreIndex,
+     load_index_from_storage,
+     Document as LlamaDocument,
+ )
+ from llama_index.core import Settings
+ from llama_index.core.node_parser import (
+     HierarchicalNodeParser,
+     get_leaf_nodes,
+     get_root_nodes,
+ )
+ from llama_index.core.retrievers import AutoMergingRetriever
+ from llama_index.core.storage.docstore import SimpleDocumentStore
+ from llama_index.llms.huggingface import HuggingFaceLLM
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from tqdm import tqdm
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ hf_logging.set_verbosity_error()
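+ # Model configuration: MODEL is the chat LLM, EMBEDDING_MODEL the sentence-transformers
+ # model used for retrieval embeddings. HF_TOKEN must be provided as an environment
+ # variable (e.g. a Space secret), otherwise startup fails fast below.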
 
+ MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+ EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ if not HF_TOKEN:
+     raise ValueError("HF_TOKEN not found in environment variables")
+
+ # --- UI Settings ---
+ TITLE = "<h1 style='text-align:center; margin-bottom: 20px;'>Local Thinking RAG: Llama 3.1 8B</h1>"
+ DISCORD_BADGE = """<p style="text-align:center; margin-top: -10px;">
+   <a href="https://discord.gg/openfreeai" target="_blank">
+     <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="badge">
+   </a>
+ </p>
+ """
+
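+ # Custom CSS: styles the upload panel, chat area, and a responsive two-column
+ # layout that kicks in at viewport widths of 768px and above.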
+ CSS = """
+ .upload-section {
+     max-width: 400px;
+     margin: 0 auto;
+     padding: 10px;
+     border: 2px dashed #ccc;
+     border-radius: 10px;
+ }
+ .upload-button {
+     background: #34c759 !important;
+     color: white !important;
+     border-radius: 25px !important;
+ }
+ .chatbot-container {
+     margin-top: 20px;
+ }
+ .status-output {
+     margin-top: 10px;
+     font-size: 14px;
+ }
+ .processing-info {
+     margin-top: 5px;
+     font-size: 12px;
+     color: #666;
+ }
+ .info-container {
+     margin-top: 10px;
+     padding: 10px;
+     border-radius: 5px;
+ }
+ .file-list {
+     margin-top: 0;
+     max-height: 200px;
+     overflow-y: auto;
+     padding: 5px;
+     border: 1px solid #eee;
+     border-radius: 5px;
+ }
+ .stats-box {
+     margin-top: 10px;
+     padding: 10px;
+     border-radius: 5px;
+     font-size: 12px;
+ }
+ .submit-btn {
+     background: #1a73e8 !important;
+     color: white !important;
+     border-radius: 25px !important;
+     margin-left: 10px;
+     padding: 5px 10px;
+     font-size: 16px;
+ }
+ .input-row {
+     display: flex;
+     align-items: center;
+ }
+ @media (min-width: 768px) {
+     .main-container {
+         display: flex;
+         justify-content: space-between;
+         gap: 20px;
+     }
+     .upload-section {
+         flex: 1;
+         max-width: 300px;
+     }
+     .chatbot-container {
+         flex: 2;
+         margin-top: 0;
+     }
+ }
  """
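+ # The model and tokenizer are loaded once and cached in module-level globals so that
+ # repeated GPU-decorated calls reuse the same weights; global_file_info tracks uploads.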
+ global_model = None
+ global_tokenizer = None
+ global_file_info = {}
+
+ def initialize_model_and_tokenizer():
+     global global_model, global_tokenizer
+     if global_model is None or global_tokenizer is None:
+         logger.info("Initializing model and tokenizer...")
+         global_tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
+         global_model = AutoModelForCausalLM.from_pretrained(
+             MODEL,
+             device_map="auto",
+             trust_remote_code=True,
+             token=HF_TOKEN,
+             torch_dtype=torch.float16
+         )
+         logger.info("Model and tokenizer initialized successfully")
+
+ def get_llm(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
+     global global_model, global_tokenizer
+     if global_model is None or global_tokenizer is None:
+         initialize_model_and_tokenizer()
+
+     return HuggingFaceLLM(
+         context_window=4096,
+         max_new_tokens=max_new_tokens,
+         tokenizer=global_tokenizer,
+         model=global_model,
+         generate_kwargs={
+             "do_sample": True,
+             "temperature": temperature,
+             "top_k": top_k,
+             "top_p": top_p
+         }
      )
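+ # Pull raw text out of an uploaded .txt or .pdf file. Returns a
+ # (text, word_count, error) tuple so callers can report per-file failures.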
+ def extract_text_from_document(file):
+     file_name = file.name
+     file_extension = os.path.splitext(file_name)[1].lower()
+
+     if file_extension == '.txt':
+         text = file.read().decode('utf-8')
+         return text, len(text.split()), None
+     elif file_extension == '.pdf':
+         pdf_reader = PyPDF2.PdfReader(file)
+         text = "\n\n".join(page.extract_text() for page in pdf_reader.pages)
+         return text, len(text.split()), None
+     else:
+         return None, 0, ValueError(f"Unsupported file format: {file_extension}")
+
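+ # Build or update this session's vector index from the uploaded files. The
+ # @spaces.GPU() decorator requests a GPU for the call on ZeroGPU Spaces; the index
+ # is keyed by the Gradio session hash and persisted to a per-user directory on disk.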
+ @spaces.GPU()
+ def create_or_update_index(files, request: gr.Request):
+     global global_file_info
+
+     if not files:
+         return "Please provide files.", ""
+
+     start_time = time.time()
+     user_id = request.session_hash
+     save_dir = f"./{user_id}_index"
+     # Initialize LlamaIndex modules
+     llm = get_llm()
+     embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
+     Settings.llm = llm
+     Settings.embed_model = embed_model
+     file_stats = []
+     new_documents = []
+
+     for file in tqdm(files, desc="Processing files"):
+         file_basename = os.path.basename(file.name)
+         text, word_count, error = extract_text_from_document(file)
+         if error:
+             logger.error(f"Error processing file {file_basename}: {str(error)}")
+             file_stats.append({
+                 "name": file_basename,
+                 "words": 0,
+                 "status": f"error: {str(error)}"
+             })
+             continue
+
+         doc = LlamaDocument(
+             text=text,
+             metadata={
+                 "file_name": file_basename,
+                 "word_count": word_count,
+                 "source": "user_upload"
+             }
+         )
+         new_documents.append(doc)
+
+         file_stats.append({
+             "name": file_basename,
+             "words": word_count,
+             "status": "processed"
+         })
+
+         global_file_info[file_basename] = {
+             "word_count": word_count,
+             "processed_at": time.time()
+         }
+
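+     # Hierarchical chunking: each document is split into 2048/512/128-token levels.
+     # Only the leaf chunks are embedded into the vector index; parent chunks live in
+     # the docstore so the AutoMergingRetriever can replace groups of sibling leaves
+     # with their parent chunk at query time.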
+     node_parser = HierarchicalNodeParser.from_defaults(
+         chunk_sizes=[2048, 512, 128],
+         chunk_overlap=20
+     )
+     logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes")
+     new_nodes = node_parser.get_nodes_from_documents(new_documents)
+     new_leaf_nodes = get_leaf_nodes(new_nodes)
+     new_root_nodes = get_root_nodes(new_nodes)
+     logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")
+
+     if os.path.exists(save_dir):
+         logger.info(f"Loading existing index from {save_dir}")
+         storage_context = StorageContext.from_defaults(persist_dir=save_dir)
+         index = load_index_from_storage(storage_context, settings=Settings)
+         docstore = storage_context.docstore
+
+         docstore.add_documents(new_nodes)
+         for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"):
+             index.insert_nodes([node])
+
+         total_docs = len(docstore.docs)
+         logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files")
+     else:
+         logger.info("Creating new index")
+         docstore = SimpleDocumentStore()
+         storage_context = StorageContext.from_defaults(docstore=docstore)
+         docstore.add_documents(new_nodes)
+
+         index = VectorStoreIndex(
+             new_leaf_nodes,
+             storage_context=storage_context,
+             settings=Settings
+         )
+         total_docs = len(new_documents)
+         logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files")
+
+     index.storage_context.persist(persist_dir=save_dir)
+     # Build the HTML file list and processing stats shown in the UI
+     file_list_html = "<div class='file-list'>"
+     for stat in file_stats:
+         status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336"
+         file_list_html += f"<div><span style='color:{status_color}'>●</span> {stat['name']} - {stat['words']} words</div>"
+     file_list_html += "</div>"
+     processing_time = time.time() - start_time
+     stats_output = "<div class='stats-box'>"
+     stats_output += f"✓ Processed {len(files)} files in {processing_time:.2f} seconds<br>"
+     stats_output += f"✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)<br>"
+     stats_output += f"✓ Total documents in index: {total_docs}<br>"
+     stats_output += f"✓ Index saved to: {save_dir}<br>"
+     stats_output += "</div>"
+     output_container = "<div class='info-container'>"
+     output_container += file_list_html
+     output_container += stats_output
+     output_container += "</div>"
+     return f"Successfully indexed {len(files)} files.", output_container
 
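+ # RAG chat endpoint: retrieves context for the question with an auto-merging
+ # retriever over the session index, folds it into the system prompt, and streams
+ # the model's reply token by token into the Chatbot history.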
+ @spaces.GPU()
+ def stream_chat(
+     message: str,
+     history: list,
+     system_prompt: str,
+     temperature: float,
+     max_new_tokens: int,
+     top_p: float,
+     top_k: int,
+     penalty: float,
+     retriever_k: int,
+     merge_threshold: float,
+     request: gr.Request
+ ):
+     if not request:
+         yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}]
+         return
+     user_id = request.session_hash
+     index_dir = f"./{user_id}_index"
+     if not os.path.exists(index_dir):
+         yield history + [{"role": "assistant", "content": "Please upload documents first."}]
+         return
+
+     max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 1024
+     temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.9
+     top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95
+     top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
+     penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2
+     retriever_k = int(retriever_k) if isinstance(retriever_k, (int, float)) else 15
+     merge_threshold = float(merge_threshold) if isinstance(merge_threshold, (int, float)) else 0.5
+     llm = get_llm(temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k)
+     embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
+     Settings.llm = llm
+     Settings.embed_model = embed_model
+     storage_context = StorageContext.from_defaults(persist_dir=index_dir)
+     index = load_index_from_storage(storage_context, settings=Settings)
+     base_retriever = index.as_retriever(similarity_top_k=retriever_k)
+     auto_merging_retriever = AutoMergingRetriever(
+         base_retriever,
+         storage_context=storage_context,
+         simple_ratio_thresh=merge_threshold,
+         verbose=True
+     )
+     logger.info(f"Query: {message}")
+     retrieval_start = time.time()
+     base_nodes = base_retriever.retrieve(message)
+     logger.info(f"Retrieved {len(base_nodes)} base nodes in {time.time() - retrieval_start:.2f}s")
+     base_file_sources = {}
+     for node in base_nodes:
+         if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
+             file_name = node.node.metadata['file_name']
+             if file_name not in base_file_sources:
+                 base_file_sources[file_name] = 0
+             base_file_sources[file_name] += 1
+     logger.info(f"Base retrieval file distribution: {base_file_sources}")
+     merging_start = time.time()
+     merged_nodes = auto_merging_retriever.retrieve(message)
+     logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - merging_start:.2f}s")
+     merged_file_sources = {}
+     for node in merged_nodes:
+         if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
+             file_name = node.node.metadata['file_name']
+             if file_name not in merged_file_sources:
+                 merged_file_sources[file_name] = 0
+             merged_file_sources[file_name] += 1
+     logger.info(f"Merged retrieval file distribution: {merged_file_sources}")
+     context = "\n\n".join([n.node.text for n in merged_nodes])
+     source_info = ""
+     if merged_file_sources:
+         source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys())
+     formatted_system_prompt = f"{system_prompt}\n\nDocument Context:\n{context}{source_info}"
+     messages = [{"role": "system", "content": formatted_system_prompt}]
+     for entry in history:
+         messages.append(entry)
+     messages.append({"role": "user", "content": message})
+     prompt = global_tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True
+     )
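+     # Generation runs in a background thread while TextIteratorStreamer yields
+     # decoded tokens to this generator; the StoppingCriteria hook lets the stream
+     # be cancelled cleanly if the client disconnects (GeneratorExit).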
+     stop_event = threading.Event()
+
+     class StopOnEvent(StoppingCriteria):
+         def __init__(self, stop_event):
+             super().__init__()
+             self.stop_event = stop_event
+
+         def __call__(self, input_ids, scores, **kwargs):
+             return self.stop_event.is_set()
+
+     stopping_criteria = StoppingCriteriaList([StopOnEvent(stop_event)])
+     streamer = TextIteratorStreamer(
+         global_tokenizer,
+         skip_prompt=True,
+         skip_special_tokens=True
+     )
+     inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
+     generation_kwargs = dict(
+         inputs,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         repetition_penalty=penalty,
+         do_sample=True,
+         stopping_criteria=stopping_criteria
+     )
+     thread = threading.Thread(target=global_model.generate, kwargs=generation_kwargs)
+     thread.start()
+     updated_history = history + [
+         {"role": "user", "content": message},
+         {"role": "assistant", "content": ""}
+     ]
+     yield updated_history
+     partial_response = ""
+     try:
+         for new_text in streamer:
+             partial_response += new_text
+             updated_history[-1]["content"] = partial_response
+             yield updated_history
+         yield updated_history
+     except GeneratorExit:
+         stop_event.set()
+         thread.join()
+         raise
+
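+ # Assemble the Gradio Blocks UI: an upload/index panel on the left and the chat
+ # panel with generation/retrieval controls on the right.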
+ def create_demo():
+     with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
+         # Title
+         gr.HTML(TITLE)
+         # Discord badge immediately under the title
+         gr.HTML(DISCORD_BADGE)
+
+         with gr.Row(elem_classes="main-container"):
+             with gr.Column(elem_classes="upload-section"):
+                 file_upload = gr.File(
+                     file_count="multiple",
+                     label="Drag & Drop PDF/TXT Files Here",
+                     file_types=[".pdf", ".txt"],
+                     elem_id="file-upload"
+                 )
+                 upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
+                 status_output = gr.Textbox(
+                     label="Status",
+                     placeholder="Upload files to start...",
+                     interactive=False
+                 )
+                 file_info_output = gr.HTML(
+                     label="File Information",
+                     elem_classes="processing-info"
+                 )
+                 upload_button.click(
+                     fn=create_or_update_index,
+                     inputs=[file_upload],
+                     outputs=[status_output, file_info_output]
+                 )
+
+             with gr.Column(elem_classes="chatbot-container"):
+                 chatbot = gr.Chatbot(
+                     height=500,
+                     placeholder="Chat with your documents...",
+                     show_label=False,
+                     type="messages"
+                 )
+                 with gr.Row(elem_classes="input-row"):
+                     message_input = gr.Textbox(
+                         placeholder="Type your question here...",
+                         show_label=False,
+                         container=False,
+                         lines=1,
+                         scale=8
+                     )
+                     submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
+
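+                 # Advanced settings: the system prompt plus the sampling and retrieval
+                 # parameters that are passed through to stream_chat.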
+                 with gr.Accordion("Advanced Settings", open=False):
+                     system_prompt = gr.Textbox(
+                         value="You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. As a knowledgeable assistant, provide detailed answers using the relevant information from all uploaded documents.",
+                         label="System Prompt",
+                         lines=3
+                     )
+
+                     with gr.Tab("Generation Parameters"):
+                         temperature = gr.Slider(
+                             minimum=0,
+                             maximum=1,
+                             step=0.1,
+                             value=0.9,
+                             label="Temperature"
+                         )
+                         max_new_tokens = gr.Slider(
+                             minimum=128,
+                             maximum=8192,
+                             step=64,
+                             value=1024,
+                             label="Max New Tokens",
+                         )
+                         top_p = gr.Slider(
+                             minimum=0.0,
+                             maximum=1.0,
+                             step=0.1,
+                             value=0.95,
+                             label="Top P"
+                         )
+                         top_k = gr.Slider(
+                             minimum=1,
+                             maximum=100,
+                             step=1,
+                             value=50,
+                             label="Top K"
+                         )
+                         penalty = gr.Slider(
+                             minimum=0.0,
+                             maximum=2.0,
+                             step=0.1,
+                             value=1.2,
+                             label="Repetition Penalty"
+                         )
+
+                     with gr.Tab("Retrieval Parameters"):
+                         retriever_k = gr.Slider(
+                             minimum=5,
+                             maximum=30,
+                             step=1,
+                             value=15,
+                             label="Initial Retrieval Size (Top K)"
+                         )
+                         merge_threshold = gr.Slider(
+                             minimum=0.1,
+                             maximum=0.9,
+                             step=0.1,
+                             value=0.5,
+                             label="Merge Threshold (lower = more merging)"
+                         )
+
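+                 # Both the send button and pressing Enter in the textbox trigger the
+                 # same streaming chat handler.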
+                 submit_button.click(
+                     fn=stream_chat,
+                     inputs=[
+                         message_input,
+                         chatbot,
+                         system_prompt,
+                         temperature,
+                         max_new_tokens,
+                         top_p,
+                         top_k,
+                         penalty,
+                         retriever_k,
+                         merge_threshold
+                     ],
+                     outputs=chatbot
+                 )
+
+                 message_input.submit(
+                     fn=stream_chat,
+                     inputs=[
+                         message_input,
+                         chatbot,
+                         system_prompt,
+                         temperature,
+                         max_new_tokens,
+                         top_p,
+                         top_k,
+                         penalty,
+                         retriever_k,
+                         merge_threshold
+                     ],
+                     outputs=chatbot
+                 )
+     return demo
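+ # Eagerly load the model on startup, then build and launch the Gradio app.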
  if __name__ == "__main__":
+     initialize_model_and_tokenizer()
+     demo = create_demo()
      demo.launch()