Commit cfa9e5f0
Parent(s):
update space

Files changed:
- .gitattributes +36 -0
- README.md +13 -0
- app.py +54 -0
- data/.gitattributes +1 -0
- data/2400594-RR-Vol 1-E-A5.pdf +3 -0
- data/2400594-RR-Vol 2-E-A5.pdf +3 -0
- data/2400594-RR-Vol 3-E-A5.pdf +3 -0
- data/2400594-RR-Vol 4-E-A5.pdf +3 -0
- rag_agent.py +153 -0
- requirements.txt +9 -0
- utils/__init__.py +0 -0
- utils/generation.py +112 -0
- utils/retrieval.py +14 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.pdf filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: RegRag
emoji: 💬
colorFrom: yellow
colorTo: purple
sdk: gradio
sdk_version: 5.0.1
app_file: app.py
pinned: false
license: mit
---

An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py
ADDED
@@ -0,0 +1,54 @@
# app.py
import gradio as gr
from sentence_transformers import SentenceTransformer

from rag_agent import prepare_index_and_chunks, load_model
from utils.retrieval import retrieve_relevant_chunks
from utils.generation import generate_answer

# ——— FIXED CONFIGURATION ———
PDF_FOLDER = "./data"  # your folder with all PDFs
EMBEDDER = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
CHUNK_SIZE = 500
OVERLAP = 100
INDEX_TYPE = "innerproduct"
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
TOP_K = 5

# ——— PREPARE INDEX, EMBEDDER & MODEL ONCE ———
faiss_index_path, chunks_path = prepare_index_and_chunks(
    pdf_folder=PDF_FOLDER,
    chunk_size=CHUNK_SIZE,
    overlap=OVERLAP,
    index_type=INDEX_TYPE,
    embedder_name=EMBEDDER
)
embedder = SentenceTransformer(EMBEDDER)  # must be the same embedder that built the index
model, tokenizer = load_model(MODEL_NAME)

# ——— INFERENCE FUNCTION ———
def answer_query(query: str) -> str:
    if not query.strip():
        return "⚠️ Please enter a question."
    # Retrieve the top-K most relevant chunks.
    chunks = retrieve_relevant_chunks(
        query,
        embedder,
        k=TOP_K,
        index_path=faiss_index_path,
        chunks_path=chunks_path
    )
    # Generate the answer from the retrieved context.
    return generate_answer(query, chunks, model, tokenizer)

# ——— GRADIO UI ———
iface = gr.Interface(
    fn=answer_query,
    inputs=gr.Textbox(lines=2, placeholder="Type your telecom question here…", label="Question"),
    outputs=gr.Textbox(label="Answer"),
    title="📡 Telecom RAG Assistant",
    description="Ask questions over the preloaded telecom regulation PDFs."
)

if __name__ == "__main__":
    iface.launch()
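Note: for a quick smoke test outside the Gradio UI, the inference function can be called directly once the module-level setup has run. A minimal sketch (the question text is hypothetical, and importing app triggers the index build and model download on first use):

    # Minimal sketch: query the pipeline without launching the UI.
    from app import answer_query

    print(answer_query("Which frequency bands does Volume 1 allocate to the amateur service?"))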
data/.gitattributes
ADDED
@@ -0,0 +1 @@
*.pdf filter=lfs diff=lfs merge=lfs -text
data/2400594-RR-Vol 1-E-A5.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a19db60201342f44443143c01c528abe725b93cf133602631edaa2b6c9ba6a8f
size 2496298
data/2400594-RR-Vol 2-E-A5.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:49e283a5c22e57b44b5e0220eee94df4a7d2055418662bf83c125475d98f3abf
size 7200340
data/2400594-RR-Vol 3-E-A5.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a08159f6efe34c5ea71db5c698fc30bd27167a94b05dafdbe9a86196d6cdc492
size 5431837
data/2400594-RR-Vol 4-E-A5.pdf
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4f5f521c351c117c817257bcb90a748ffaf15a41717a6d82385c3a6be2d3271
size 8507819
rag_agent.py
ADDED
@@ -0,0 +1,153 @@
import os
import argparse
import pdfplumber
import numpy as np
import faiss
from langchain.text_splitter import RecursiveCharacterTextSplitter
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.retrieval import retrieve_relevant_chunks
from utils.generation import generate_answer
from pathlib import Path


DEFAULT_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
DEFAULT_TOP_K = 5
DEFAULT_CHUNK_SIZE = 500
DEFAULT_OVERLAP = 100
DEFAULT_INDEX_TYPE = "innerproduct"

def extract_text_from_pdfs(pdf_folder):
    """
    Extract text from all PDF files in a folder.
    """
    texts = []
    for pdf_file in os.listdir(pdf_folder):
        if pdf_file.endswith(".pdf"):
            with pdfplumber.open(os.path.join(pdf_folder, pdf_file)) as pdf:
                text = "\n".join(
                    [page.extract_text() for page in pdf.pages if page.extract_text()]
                )
                texts.append(text)
    return texts

def chunk_text(texts, chunk_size, overlap, folder_path):
    """
    Split text into overlapping chunks and save them to disk.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
    chunks = []
    for text in texts:
        chunks.extend(splitter.split_text(text))
    chunks_path = folder_path / f"chunks_{chunk_size}_{overlap}.npy"
    np.save(chunks_path, chunks)
    return chunks

def create_faiss_index(chunks, index_type, folder_path, embedder_name):
    """
    Create a FAISS index based on the selected type.
    """
    embedder = SentenceTransformer(embedder_name)
    embeddings = embedder.encode(chunks, convert_to_numpy=True)
    dimension = embeddings.shape[1]

    if index_type == "flatl2":
        index = faiss.IndexFlatL2(dimension)
    elif index_type == "innerproduct":
        index = faiss.IndexFlatIP(dimension)
    elif index_type == "hnsw":
        index = faiss.IndexHNSWFlat(dimension, 32)  # HNSW with 32 connections per node
    elif index_type == "ivfflat":
        nlist = 100
        quantizer = faiss.IndexFlatL2(dimension)
        index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)
        index.train(embeddings)
    elif index_type == "ivfpq":
        nlist = 100
        m = 8  # Number of subquantizers
        quantizer = faiss.IndexFlatL2(dimension)
        index = faiss.IndexIVFPQ(quantizer, dimension, nlist, m, 8)
        index.train(embeddings)
    elif index_type == "ivfsq":
        nlist = 100
        quantizer = faiss.IndexFlatL2(dimension)
        index = faiss.IndexIVFScalarQuantizer(quantizer, dimension, nlist, faiss.ScalarQuantizer.QT_fp16)
        index.train(embeddings)
    else:
        raise ValueError(f"Unsupported index type: {index_type}")

    index.add(embeddings)
    index_path = folder_path / f"index_{index_type}.idx"
    faiss.write_index(index, str(index_path))
    print(f"✅ FAISS Index ({index_type}) and text chunks saved successfully.")


def load_model(model_name):
    """
    Load the generative model and its tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    return model, tokenizer

def prepare_index_and_chunks(pdf_folder, chunk_size, overlap, index_type, embedder_name):
    """
    Prepare (or create if necessary) the FAISS index and text chunks from PDFs.
    The folder is named based on the parameters, similar to evaluate_rag.
    """
    folder_name = f"{embedder_name} ; {index_type}_chunk{chunk_size}_overlap{overlap}"
    folder_path = Path(folder_name)
    if folder_path.exists():
        faiss_index_path = str(folder_path / f"index_{index_type}.idx")
        chunks_path = folder_path / f"chunks_{chunk_size}_{overlap}.npy"
    else:
        folder_path.mkdir(parents=True, exist_ok=True)
        texts = extract_text_from_pdfs(pdf_folder)
        chunks = chunk_text(texts, chunk_size, overlap, folder_path)
        create_faiss_index(chunks, index_type, folder_path, embedder_name)
        faiss_index_path = str(folder_path / f"index_{index_type}.idx")
        chunks_path = folder_path / f"chunks_{chunk_size}_{overlap}.npy"
    return faiss_index_path, chunks_path


def rag_agent(pdf_folder, chunk_size, overlap, index_type, model_name, k):
    """
    Interactive RAG chatbot that creates the FAISS index and text chunks if they don't exist.
    """
    print("\n📡 Telecom Regulation RAG Agent (type 'exit' to quit)\n")

    # Use the same embedder as in evaluate_rag for consistency
    embedder_name = "all-MiniLM-L6-v2"
    embedder = SentenceTransformer(embedder_name)
    faiss_index, chunks_path = prepare_index_and_chunks(pdf_folder, chunk_size, overlap, index_type, embedder_name)
    model, tokenizer = load_model(model_name)

    while True:
        query = input("Ask a question: ")
        if query.lower() == "exit":
            print("Exiting...")
            break

        retrieved_chunks = retrieve_relevant_chunks(query, embedder, k, faiss_index, chunks_path)
        answer = generate_answer(query, retrieved_chunks, model, tokenizer)

        print("\n🔹 Question:\n", query, "\n")
        print("\n💡 Answer:\n", answer, "\n")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the interactive RAG agent with index creation from PDFs.")
    parser.add_argument("--pdf_folder", type=str, default="./data", help="Path to the folder containing PDF files.")
    parser.add_argument("--chunk_size", type=int, default=DEFAULT_CHUNK_SIZE, help="Text chunk size.")
    parser.add_argument("--overlap", type=int, default=DEFAULT_OVERLAP, help="Overlap size between chunks.")
    parser.add_argument("--index_type", type=str, choices=["flatl2", "innerproduct", "hnsw", "ivfflat", "ivfpq", "ivfsq"],
                        default=DEFAULT_INDEX_TYPE, help="Type of FAISS index to use.")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Hugging Face model name.")
    parser.add_argument("--top_k", type=int, default=DEFAULT_TOP_K, help="Number of retrieved text chunks to use.")

    args = parser.parse_args()
    rag_agent(args.pdf_folder, args.chunk_size, args.overlap, args.index_type, args.model_name, args.top_k)
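Note: besides the argparse entry point above, rag_agent() can also be driven programmatically. A minimal sketch using only names defined in rag_agent.py (all values are just the module defaults made explicit):

    from rag_agent import (
        rag_agent, DEFAULT_CHUNK_SIZE, DEFAULT_OVERLAP,
        DEFAULT_INDEX_TYPE, DEFAULT_MODEL, DEFAULT_TOP_K,
    )

    # Starts the interactive loop: builds (or reuses) the index, then prompts for questions.
    rag_agent(
        pdf_folder="./data",
        chunk_size=DEFAULT_CHUNK_SIZE,
        overlap=DEFAULT_OVERLAP,
        index_type=DEFAULT_INDEX_TYPE,
        model_name=DEFAULT_MODEL,
        k=DEFAULT_TOP_K,
    )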
requirements.txt
ADDED
@@ -0,0 +1,9 @@
huggingface_hub==0.25.2
gradio
pdfplumber
numpy
faiss-cpu
langchain
torch
sentence-transformers
transformers
utils/__init__.py
ADDED
File without changes
utils/generation.py
ADDED
@@ -0,0 +1,112 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


MODEL_NAME = "facebook/opt-1.3b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
)

def generate_answer_chat(query, options, retrieved_chunks, model=model, tokenizer=tokenizer):
    """
    Generates an answer using the retrieved context, formatted as a conversation
    to better suit chat-tuned models such as Llama 2 7B Chat.
    """
    # Format each retrieved chunk as a numbered paragraph.
    paragraphs = [f"Paragraph {idx+1}: {chunk}" for idx, chunk in enumerate(retrieved_chunks)]
    context = "\n\n".join(paragraphs)

    # Create a conversational prompt.
    system_message = (
        "System: You are a telecom regulations expert. Answer using the information provided "
        "in the context. Start directly by giving the best choice from the options."
    )
    context_message = f"Context:\n{context}"
    user_message = f"User: {query}\nOptions: " + " | ".join(options)
    assistant_cue = "Assistant: "

    prompt = "\n\n".join([system_message, context_message, user_message, assistant_cue])

    # Determine the model type: seq2seq or causal.
    model_type = "seq2seq" if getattr(model.config, "is_encoder_decoder", False) else "causal"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        num_return_sequences=1,
        no_repeat_ngram_size=2
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if model_type == "causal":
        # Attempt to extract only the assistant's response.
        answer_start = generated_text.find("Assistant:")
        if answer_start != -1:
            answer = generated_text[answer_start + len("Assistant:"):].strip()
        else:
            answer = generated_text[len(prompt):].strip()
        return answer
    else:
        return generated_text.strip()


def generate_answer(query, retrieved_chunks, model=model, tokenizer=tokenizer):
    """
    Generates an answer using the retrieved context.

    For causal models, the prompt is included in the output, so it must be removed.
    For seq2seq models, the output is directly the generated answer.
    """
    # Format each chunk as a separate paragraph with a numbered prefix.
    paragraphs = [f"Paragraph {idx+1}: {chunk}" for idx, chunk in enumerate(retrieved_chunks)]
    context = "\n\n".join(paragraphs)

    prompt = (f"You are a telecom regulations expert. Using the following context, answer the question:\n\n"
              f"Context:\n{context}\n\n"
              f"Question: {query}\nAnswer:")

    model_type = "seq2seq" if getattr(model.config, "is_encoder_decoder", False) else "causal"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,  # Maximum number of new tokens to generate.
        num_return_sequences=1,
        no_repeat_ngram_size=2
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if model_type == "causal":
        # Remove the prompt from the output for causal models.
        return generated_text[len(prompt):].strip()
    else:
        return generated_text.strip()


def generate_norag(query, model, tokenizer):
    """
    Generates an answer without additional context.
    """
    prompt = f"Answer the question:\n\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate output with a specified maximum number of new tokens.
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        num_return_sequences=1,
        no_repeat_ngram_size=2
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    model_type = "seq2seq" if getattr(model.config, "is_encoder_decoder", False) else "causal"

    if model_type == "causal":
        return generated_text[len(prompt):].strip()
    else:  # For seq2seq models
        return generated_text.strip()
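Note: a minimal usage sketch for generate_answer(). The retrieved chunks below are made-up placeholders standing in for real FAISS results, and the module-level facebook/opt-1.3b model and tokenizer are picked up through the default arguments:

    from utils.generation import generate_answer

    dummy_chunks = [  # hypothetical retrieved passages, for illustration only
        "Stations shall avoid causing harmful interference to services operating in other bands.",
        "Administrations shall notify frequency assignments for recording in the Master Register.",
    ]
    print(generate_answer("What must administrations notify for recording?", dummy_chunks))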
utils/retrieval.py
ADDED
@@ -0,0 +1,14 @@
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer


def retrieve_relevant_chunks(query, embedder, k=5, index_path="faiss_index.idx", chunks_path="text_chunks.npy"):
    """
    Embed the query, search the FAISS index, and return the top-k text chunks.
    """
    index = faiss.read_index(index_path)
    chunks = np.load(chunks_path, allow_pickle=True)
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding, k)

    retrieved_chunks = [chunks[i] for i in indices[0]]
    return retrieved_chunks
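Note: a minimal usage sketch, assuming the index and chunk files already exist at the default paths from the signature. The embedder must be the same sentence-transformers model that built the index (app.py uses sentence-transformers/multi-qa-mpnet-base-dot-v1), and the query string is hypothetical:

    from sentence_transformers import SentenceTransformer
    from utils.retrieval import retrieve_relevant_chunks

    embedder = SentenceTransformer("sentence-transformers/multi-qa-mpnet-base-dot-v1")
    chunks = retrieve_relevant_chunks(
        "What are the limits on out-of-band emissions?",
        embedder,
        k=5,
        index_path="faiss_index.idx",
        chunks_path="text_chunks.npy",
    )
    for chunk in chunks:
        print(chunk[:120])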