# Document processing utilities: PDF text extraction, word-chunking,
# sentence-transformer embeddings, and FAISS index build/query.
# (Removed non-Python upload-metadata residue that broke the file.)
import fitz # PyMuPDF
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import os
from typing import List, Tuple
class DocumentProcessor:
    """Ingest a PDF, split it into overlapping word chunks, embed the chunks,
    and build / query a FAISS similarity index.

    NOTE: ``chunk_size`` and ``chunk_overlap`` are measured in *words*
    (whitespace-split tokens), not model tokens.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)
        self.chunk_size = 500  # words per chunk (not model tokens)
        self.chunk_overlap = 50  # words shared between consecutive chunks

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract and concatenate the text of every page of *pdf_path*.

        The document handle is closed even if extraction raises — the
        original left the file open (resource leak).
        """
        doc = fitz.open(pdf_path)
        try:
            # "".join avoids quadratic += concatenation on large PDFs.
            return "".join(page.get_text() for page in doc)
        finally:
            doc.close()

    def chunk_text(self, text: str) -> List[str]:
        """Split *text* into overlapping chunks of ``self.chunk_size`` words.

        Fixes over the original:
        - a non-positive step (chunk_size <= chunk_overlap) no longer raises
          ``ValueError`` from ``range`` — the step is clamped to >= 1;
        - trailing chunks that are pure suffixes of the previous chunk
          (overlap residue past the end of the text) are not emitted.

        Returns [] for empty or whitespace-only input.
        """
        words = text.split()
        if not words:
            return []
        step = max(1, self.chunk_size - self.chunk_overlap)
        chunks: List[str] = []
        for start in range(0, len(words), step):
            chunks.append(" ".join(words[start:start + self.chunk_size]))
            if start + self.chunk_size >= len(words):
                break  # remaining words are already covered by this chunk
        return chunks

    def create_embeddings(self, chunks: List[str]) -> np.ndarray:
        """Encode *chunks* into a (len(chunks), dim) embedding matrix."""
        return self.model.encode(chunks)

    def build_faiss_index(self, embeddings: np.ndarray) -> faiss.Index:
        """Build a flat (exact) L2 index over *embeddings*.

        FAISS requires contiguous float32 input; convert defensively.
        """
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(np.ascontiguousarray(embeddings, dtype='float32'))
        return index

    def save_index(self, index: faiss.Index, path: str):
        """Persist a FAISS index to disk at *path*."""
        faiss.write_index(index, path)

    def load_index(self, path: str) -> faiss.Index:
        """Load a previously saved FAISS index from *path*."""
        return faiss.read_index(path)

    def process_document(self, pdf_path: str, index_path: str) -> Tuple[List[str], faiss.Index]:
        """End-to-end pipeline: extract -> chunk -> embed -> index -> save.

        Returns the text chunks and the built index so callers can query
        immediately without reloading from disk.
        """
        text = self.extract_text_from_pdf(pdf_path)
        chunks = self.chunk_text(text)
        embeddings = self.create_embeddings(chunks)
        index = self.build_faiss_index(embeddings)
        self.save_index(index, index_path)
        return chunks, index

    def retrieve_chunks(self, query: str, index: faiss.Index, chunks: List[str], k: int = 5) -> List[str]:
        """Return up to *k* chunks most similar to *query* by L2 distance.

        FAISS pads its result with id -1 when the index holds fewer than
        *k* vectors; the original fed -1 straight into ``chunks[i]`` and
        silently returned the *last* chunk as a fake match. Out-of-range
        ids are now skipped, so fewer than *k* results may be returned.
        """
        query_embedding = self.model.encode([query])
        _, indices = index.search(
            np.ascontiguousarray(query_embedding, dtype='float32'), k
        )
        return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]