|
import fitz |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
import faiss |
|
import os |
|
from typing import List, Tuple |
|
|
|
class DocumentProcessor:
    """Chunk PDF documents, embed the chunks, and search them with FAISS.

    Pipeline: PDF text extraction (PyMuPDF) -> word-based overlapping
    chunks -> sentence-transformer embeddings -> FAISS flat-L2 index ->
    nearest-neighbour chunk retrieval for a query string.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2",
                 chunk_size: int = 500, chunk_overlap: int = 50):
        """Load the embedding model and configure chunking.

        Args:
            model_name: Any sentence-transformers model id.
            chunk_size: Maximum words per chunk (was hard-coded to 500).
            chunk_overlap: Words shared between consecutive chunks
                (was hard-coded to 50).

        Raises:
            ValueError: If ``chunk_overlap >= chunk_size`` — that would
                make the chunking stride non-positive and loop forever.
        """
        if chunk_overlap >= chunk_size:
            raise ValueError("chunk_overlap must be smaller than chunk_size")
        self.model = SentenceTransformer(model_name)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Return the concatenated text of every page in *pdf_path*."""
        # Context manager closes the document handle even on error
        # (the previous version leaked it); "".join avoids quadratic +=.
        with fitz.open(pdf_path) as doc:
            return "".join(page.get_text() for page in doc)

    def chunk_text(self, text: str) -> List[str]:
        """Split *text* into overlapping chunks of up to chunk_size words.

        Consecutive chunks share ``chunk_overlap`` words. Returns an
        empty list for whitespace-only input.
        """
        words = text.split()
        step = self.chunk_size - self.chunk_overlap  # > 0, checked in __init__
        return [
            " ".join(words[i:i + self.chunk_size])
            for i in range(0, len(words), step)
        ]

    def create_embeddings(self, chunks: List[str]) -> np.ndarray:
        """Encode *chunks* into a (len(chunks), dim) embedding matrix."""
        return self.model.encode(chunks)

    def build_faiss_index(self, embeddings: np.ndarray) -> faiss.Index:
        """Build an exact (flat) L2 index over *embeddings*."""
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        # FAISS requires float32 input; sentence-transformers may emit float32
        # already, in which case astype is a cheap no-op copy.
        index.add(embeddings.astype('float32'))
        return index

    def save_index(self, index: faiss.Index, path: str):
        """Persist *index* to *path* on disk."""
        faiss.write_index(index, path)

    def load_index(self, path: str) -> faiss.Index:
        """Load a previously saved FAISS index from *path*."""
        return faiss.read_index(path)

    def process_document(self, pdf_path: str, index_path: str) -> Tuple[List[str], faiss.Index]:
        """Extract, chunk, embed, and index a PDF; persist the index.

        Args:
            pdf_path: PDF file to ingest.
            index_path: Where the FAISS index is written.

        Returns:
            ``(chunks, index)`` — the text chunks and the built index.

        Raises:
            ValueError: If the PDF yields no extractable text (the old
                code died later with an opaque IndexError on shape[1]).
        """
        text = self.extract_text_from_pdf(pdf_path)
        chunks = self.chunk_text(text)
        if not chunks:
            raise ValueError(f"No extractable text found in {pdf_path!r}")

        embeddings = self.create_embeddings(chunks)
        index = self.build_faiss_index(embeddings)
        self.save_index(index, index_path)
        return chunks, index

    def retrieve_chunks(self, query: str, index: faiss.Index, chunks: List[str], k: int = 5) -> List[str]:
        """Return up to *k* chunks most similar to *query* (L2 distance).

        May return fewer than *k* results when the index holds fewer
        than *k* vectors.
        """
        query_embedding = self.model.encode([query]).astype('float32')
        _distances, indices = index.search(query_embedding, k)
        # FAISS pads the result with -1 when k exceeds the number of
        # indexed vectors; without this filter, chunks[-1] would silently
        # return the last chunk as a bogus match.
        return [chunks[i] for i in indices[0] if i != -1]