File size: 6,355 Bytes
cfa9e5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
import argparse
import pdfplumber
import numpy as np
import faiss
from langchain.text_splitter import RecursiveCharacterTextSplitter
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.retrieval import retrieve_relevant_chunks
from utils.generation import generate_answer
from pathlib import Path
# Hugging Face model id used when --model_name is not given on the command line.
DEFAULT_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# Default for --top_k: the number of retrieved chunks requested per question.
DEFAULT_TOP_K = 5
# Default for --chunk_size; forwarded to RecursiveCharacterTextSplitter as chunk_size.
DEFAULT_CHUNK_SIZE = 500
# Default for --overlap; forwarded to RecursiveCharacterTextSplitter as chunk_overlap.
DEFAULT_OVERLAP = 100
# Default for --index_type; "innerproduct" selects faiss.IndexFlatIP in create_faiss_index.
DEFAULT_INDEX_TYPE = "innerproduct"
def extract_text_from_pdfs(pdf_folder):
    """
    Extract text from all PDF files in a folder.

    Only the direct entries of ``pdf_folder`` are scanned (no recursion).
    Files are matched on a case-insensitive ``.pdf`` suffix and processed in
    sorted filename order so that repeated runs yield the same text order.

    Args:
        pdf_folder: Path of the directory that contains the PDF files.

    Returns:
        A list with one string per PDF: the non-empty page texts of that
        file joined with newlines. A PDF without extractable text
        contributes an empty string.
    """
    texts = []
    # os.listdir() gives no ordering guarantee; sorting keeps the resulting
    # chunk order (and therefore the index ids built from it) reproducible.
    for pdf_file in sorted(os.listdir(pdf_folder)):
        if not pdf_file.lower().endswith(".pdf"):
            continue
        with pdfplumber.open(os.path.join(pdf_folder, pdf_file)) as pdf:
            # Extract each page exactly once; the extraction is the expensive
            # step, so do not call it a second time just to test for emptiness.
            page_texts = (page.extract_text() for page in pdf.pages)
            text = "\n".join(page_text for page_text in page_texts if page_text)
        texts.append(text)
    return texts
def chunk_text(texts, chunk_size, overlap, folder_path):
    """
    Split every text into overlapping chunks and persist the result.

    The flat list of chunks (in input order) is written with ``np.save`` to
    ``folder_path / "chunks_{chunk_size}_{overlap}.npy"`` and also returned.
    ``folder_path`` is combined with ``/``, so it must be a path object.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=overlap,
    )
    # Flatten the per-document chunk lists into one list, keeping order.
    pieces = [
        piece
        for document in texts
        for piece in text_splitter.split_text(document)
    ]
    np.save(folder_path / f"chunks_{chunk_size}_{overlap}.npy", pieces)
    return pieces
def create_faiss_index(chunks, index_type, folder_path, embedder_name):
    """
    Create a FAISS index based on the selected type and write it to disk.

    The chunks are embedded with ``SentenceTransformer(embedder_name)``, added
    to a freshly built index and saved as
    ``folder_path / "index_{index_type}.idx"``.

    Args:
        chunks: Sequence of text chunks to embed and index.
        index_type: One of "flatl2", "innerproduct", "hnsw", "ivfflat",
            "ivfpq" or "ivfsq".
        folder_path: Path object of the directory receiving the index file.
        embedder_name: Model name handed to ``SentenceTransformer``.

    Raises:
        ValueError: If ``index_type`` is not supported or ``chunks`` is empty.
    """
    supported_types = ("flatl2", "innerproduct", "hnsw", "ivfflat", "ivfpq", "ivfsq")
    # Validate cheap preconditions first so a typo does not cost a model load
    # and a full embedding pass before being reported.
    if index_type not in supported_types:
        raise ValueError(f"Unsupported index type: {index_type}")
    if len(chunks) == 0:
        raise ValueError("Cannot build a FAISS index from an empty list of chunks.")

    embedder = SentenceTransformer(embedder_name)
    embeddings = embedder.encode(chunks, convert_to_numpy=True)
    # FAISS consumes C-contiguous float32 matrices; normalise the layout once.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    dimension = embeddings.shape[1]

    # IVF training cannot create more clusters than there are vectors, so cap
    # the usual 100 cells for small corpora.
    # NOTE(review): "ivfpq" additionally trains 8-bit product-quantizer
    # codebooks and presumably still needs a few hundred vectors - confirm.
    nlist = min(100, embeddings.shape[0])

    if index_type == "flatl2":
        index = faiss.IndexFlatL2(dimension)
    elif index_type == "innerproduct":
        index = faiss.IndexFlatIP(dimension)
    elif index_type == "hnsw":
        index = faiss.IndexHNSWFlat(dimension, 32)  # HNSW with 32 connections per node
    elif index_type == "ivfflat":
        quantizer = faiss.IndexFlatL2(dimension)
        index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)
        index.train(embeddings)
    elif index_type == "ivfpq":
        m = 8  # Number of subquantizers
        quantizer = faiss.IndexFlatL2(dimension)
        index = faiss.IndexIVFPQ(quantizer, dimension, nlist, m, 8)
        index.train(embeddings)
    else:  # "ivfsq" - the only remaining supported type after validation
        quantizer = faiss.IndexFlatL2(dimension)
        index = faiss.IndexIVFScalarQuantizer(
            quantizer, dimension, nlist, faiss.ScalarQuantizer.QT_fp16
        )
        index.train(embeddings)

    index.add(embeddings)
    index_path = folder_path / f"index_{index_type}.idx"
    faiss.write_index(index, str(index_path))
    print(f"✅ FAISS Index ({index_type}) and text chunks saved successfully.")
def load_model(model_name):
    """
    Load the tokenizer and the causal language model for ``model_name``.

    The model is requested in float16 with ``device_map="auto"``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    load_options = {"torch_dtype": torch.float16, "device_map": "auto"}
    lm = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    return lm, tok
def prepare_index_and_chunks(pdf_folder, chunk_size, overlap, index_type, embedder_name):
    """
    Prepare (or create if necessary) the FAISS index and text chunks from PDFs.

    The artifacts live in a folder named after the parameters
    (``"{embedder_name} ; {index_type}_chunk{chunk_size}_overlap{overlap}"``),
    similar to evaluate_rag. They are reused only when both the index file
    and the chunks file are present; otherwise they are rebuilt from the PDFs.

    Args:
        pdf_folder: Directory holding the source PDFs (used only on rebuild).
        chunk_size: Chunk size used for splitting and for the file name.
        overlap: Chunk overlap used for splitting and for the file name.
        index_type: FAISS index type used for building and for the file name.
        embedder_name: Embedding model name used for building and folder name.

    Returns:
        ``(faiss_index_path, chunks_path)`` where the index path is a ``str``
        and the chunks path is a ``pathlib.Path``.
    """
    folder_name = f"{embedder_name} ; {index_type}_chunk{chunk_size}_overlap{overlap}"
    folder_path = Path(folder_name)
    index_file = folder_path / f"index_{index_type}.idx"
    chunks_path = folder_path / f"chunks_{chunk_size}_{overlap}.npy"

    # Check the artifacts themselves, not just the folder: a build that was
    # interrupted or failed leaves the folder behind, and trusting it would
    # hand back paths to files that were never written.
    if not (index_file.is_file() and chunks_path.is_file()):
        folder_path.mkdir(parents=True, exist_ok=True)
        texts = extract_text_from_pdfs(pdf_folder)
        chunks = chunk_text(texts, chunk_size, overlap, folder_path)
        create_faiss_index(chunks, index_type, folder_path, embedder_name)

    return str(index_file), chunks_path
def rag_agent(pdf_folder, chunk_size, overlap, index_type, model_name, k):
    """
    Interactive RAG chatbot that creates the FAISS index and text chunks if they don't exist.

    Loads the embedder and the generative model once, then loops: read a
    question, retrieve ``k`` chunks for it, generate an answer and print it.
    The loop ends when the user types ``exit`` (any case, surrounding
    whitespace ignored) or closes/interrupts the prompt (Ctrl-D / Ctrl-C).

    Args:
        pdf_folder: Directory with the source PDFs.
        chunk_size: Text chunk size used for indexing.
        overlap: Overlap between consecutive chunks.
        index_type: FAISS index type to build or load.
        model_name: Hugging Face name of the generative model.
        k: Number of chunks to retrieve per question.
    """
    print("\n📡 Telecom Regulation RAG Agent (type 'exit' to quit)\n")
    # Use the same embedder as in evaluate_rag for consistency
    embedder_name = "all-MiniLM-L6-v2"
    embedder = SentenceTransformer(embedder_name)
    # prepare_index_and_chunks returns file paths, not loaded objects.
    faiss_index_path, chunks_path = prepare_index_and_chunks(
        pdf_folder, chunk_size, overlap, index_type, embedder_name
    )
    model, tokenizer = load_model(model_name)

    while True:
        try:
            query = input("Ask a question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # A closed stdin or Ctrl-C at the prompt is a normal way to quit;
            # leave cleanly instead of dumping a traceback.
            print("\nExiting...")
            break
        if not query:
            # Nothing to retrieve for a blank line; just ask again.
            continue
        if query.lower() == "exit":
            print("Exiting...")
            break
        retrieved_chunks = retrieve_relevant_chunks(query, embedder, k, faiss_index_path, chunks_path)
        answer = generate_answer(query, retrieved_chunks, model, tokenizer)
        print("\n🔹 Question:\n", query, "\n")
        print("\n💡 Answer:\n", answer, "\n")
if __name__ == "__main__":
    # Command-line entry point: parse the options and start the chat loop.
    cli = argparse.ArgumentParser(
        description="Run the interactive RAG agent with index creation from PDFs."
    )
    cli.add_argument(
        "--pdf_folder",
        type=str,
        default="./data",
        help="Path to the folder containing PDF files.",
    )
    cli.add_argument(
        "--chunk_size",
        type=int,
        default=DEFAULT_CHUNK_SIZE,
        help="Text chunk size.",
    )
    cli.add_argument(
        "--overlap",
        type=int,
        default=DEFAULT_OVERLAP,
        help="Overlap size between chunks.",
    )
    cli.add_argument(
        "--index_type",
        type=str,
        choices=["flatl2", "innerproduct", "hnsw", "ivfflat", "ivfpq", "ivfsq"],
        default=DEFAULT_INDEX_TYPE,
        help="Type of FAISS index to use.",
    )
    cli.add_argument(
        "--model_name",
        type=str,
        default=DEFAULT_MODEL,
        help="Hugging Face model name.",
    )
    cli.add_argument(
        "--top_k",
        type=int,
        default=DEFAULT_TOP_K,
        help="Number of retrieved text chunks to use.",
    )
    options = cli.parse_args()
    rag_agent(
        pdf_folder=options.pdf_folder,
        chunk_size=options.chunk_size,
        overlap=options.overlap,
        index_type=options.index_type,
        model_name=options.model_name,
        k=options.top_k,
    )
|