# !pip install pdfplumber
# !pip install rank_bm25
# !pip install langchain
# !pip install sentence_transformers
# conda install -c conda-forge faiss-cpu

import pdfplumber
import pandas as pd
import numpy as np
import re
import os
from ast import literal_eval
import faiss
from llama_cpp import Llama, LlamaGrammar
from rank_bm25 import BM25Okapi
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
import PyPDF2

# Local model paths: a sentence-transformers embedder and a quantized Llama 3.2 1B Instruct model.
embedding_model = SentenceTransformer("models/all-MiniLM-L6-v2/")
llm = Llama(model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
            n_gpu_layers=-1, n_ctx=8000)
def extract_info_from_pdf(pdf_path):
    """
    Extracts both paragraphs and tables from each PDF page using pdfplumber.
    Returns a list of dictionaries with keys: "page_number", "paragraphs", "tables".
    """
    document_data = []
    with pdfplumber.open(pdf_path) as pdf:
        for i, page in enumerate(pdf.pages, start=1):
            page_data = {"page_number": i, "paragraphs": [], "tables": []}
            text = page.extract_text()
            if text:
                paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
                page_data["paragraphs"] = paragraphs
            tables = page.extract_tables()
            dfs = []
            for table in tables:
                if len(table) > 1:
                    df = pd.DataFrame(table[1:], columns=table[0])
                else:
                    df = pd.DataFrame(table)
                dfs.append(df)
            page_data["tables"] = dfs
            document_data.append(page_data)
    return document_data
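# Example usage (sketch): the file name below is hypothetical; each entry in the
# returned list describes one page, with paragraphs as strings and tables as DataFrames.
# pages = extract_info_from_pdf("sample_annual_report.pdf")
# print(pages[0]["page_number"], len(pages[0]["paragraphs"]), len(pages[0]["tables"]))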
def extract_financial_tables_regex(text):
    """
    Extracts financial table information using a regex pattern (basic extraction).
    """
    pattern = re.compile(r"(Revenue from Operations.*?)\n\n", re.DOTALL)
    matches = pattern.findall(text)
    if matches:
        data_lines = matches[0].split("\n")
        structured_data = [line.split() for line in data_lines if line.strip()]
        if len(structured_data) > 1:
            df = pd.DataFrame(structured_data[1:], columns=structured_data[0])
            return df
    return pd.DataFrame()
def clean_financial_data(df):
    """
    Cleans the financial DataFrame by stripping thousands separators and converting
    numerical columns, then returns it as a string (empty string if the frame is empty).
    """
    if df.empty:
        return ""
    for col in df.columns[1:]:
        df[col] = df[col].replace({',': ''}, regex=True)
        df[col] = pd.to_numeric(df[col], errors='coerce')
    return df.to_string()
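# Quick sanity check (sketch): the regex expects a block that starts at
# "Revenue from Operations" and ends at a blank line. The sample text below is
# synthetic, with an equal number of tokens per row so the header/row split lines up.
# _sample = "Revenue from Operations 1,000 900\nProfit before tax 200 150\n\nOther notes"
# print(clean_financial_data(extract_financial_tables_regex(_sample)))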
def combine_extracted_info(document_data, financial_text_regex=""):
    """
    Combines extracted paragraphs and tables (converted to strings) into a single text.
    Optionally appends extra financial table text.
    """
    text_segments = []
    for page in document_data:
        for paragraph in page["paragraphs"]:
            text_segments.append(paragraph)
        for table in page["tables"]:
            text_segments.append(table.to_string(index=False))
    if financial_text_regex:
        text_segments.append(financial_text_regex)
    return "\n".join(text_segments)
def extract_text_from_pdf_pypdf2(pdf_path):
    """
    Extracts plain text from every page of the PDF using PyPDF2.
    """
    text = ""
    with open(pdf_path, "rb") as file:
        reader = PyPDF2.PdfReader(file)
        for page in reader.pages:
            text += (page.extract_text() or "") + "\n"
    return text
def chunk_text(text, chunk_size=500, chunk_overlap=50):
    """
    Uses RecursiveCharacterTextSplitter to chunk text.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = text_splitter.split_text(text)
    return chunks
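# Example (sketch): split a synthetic passage into overlapping chunks; the sizes
# here are illustrative and unrelated to the defaults used below.
# _demo_chunks = chunk_text("Quarterly revenue commentary. " * 100, chunk_size=200, chunk_overlap=20)
# print(len(_demo_chunks), _demo_chunks[0][:60])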
def build_faiss_index(chunks, embedding_model):
    """
    Embeds all chunks and stores them in a flat (exact) L2 FAISS index.
    """
    chunk_embeddings = embedding_model.encode(chunks)
    dimension = chunk_embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(chunk_embeddings))
    return index, chunk_embeddings
def retrieve_basic(query, index, chunks, embedding_model, k=5):
    """
    Dense retrieval: returns the k nearest chunks (and their L2 distances) for the query.
    """
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(np.array(query_embedding), k)
    return [chunks[i] for i in indices[0]], distances[0]
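# Example (sketch): build an index over a few made-up chunks and run a dense lookup.
# _toy_chunks = ["Revenue from operations grew 12% year on year.",
#                "The board declared a dividend of Rs 9 per share.",
#                "Capex was funded entirely through internal accruals."]
# _toy_index, _ = build_faiss_index(_toy_chunks, embedding_model)
# _hits, _dists = retrieve_basic("What dividend was declared?", _toy_index, _toy_chunks, embedding_model, k=2)
# print(_hits, _dists)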
def retrieve_bm25(query, chunks, k=5):
    """
    Lexical retrieval: scores chunks against the query with BM25 and returns the top k.
    """
    tokenized_corpus = [chunk.lower().split() for chunk in chunks]
    bm25_model = BM25Okapi(tokenized_corpus)
    tokenized_query = query.lower().split()
    scores = bm25_model.get_scores(tokenized_query)
    top_indices = np.argsort(scores)[::-1][:k]
    return [chunks[i] for i in top_indices], scores[top_indices]
def retrieve_advanced_embedding(query, chunks, embedding_model, k=5):
    """
    Semantic retrieval: ranks chunks by cosine similarity to the query embedding and returns the top k.
    """
    chunk_embeddings = embedding_model.encode(chunks)
    query_embedding = embedding_model.encode([query])
    scores = cosine_similarity(np.array(query_embedding), np.array(chunk_embeddings))[0]
    top_indices = np.argsort(scores)[::-1][:k]
    return [chunks[i] for i in top_indices], scores[top_indices]
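# Example (sketch): the made-up chunks from the FAISS sketch above can also be scored
# lexically with BM25 and semantically with cosine similarity, no index required.
# _bm25_hits, _bm25_scores = retrieve_bm25("dividend per share", _toy_chunks, k=2)
# _emb_hits, _emb_scores = retrieve_advanced_embedding("dividend per share", _toy_chunks, embedding_model, k=2)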
def rerank_candidates(query, candidate_chunks, embedding_model):
    """
    Re-ranks candidate chunks using cosine similarity with the query.
    """
    candidate_embeddings = embedding_model.encode(candidate_chunks)
    query_embedding = embedding_model.encode([query])
    scores = cosine_similarity(np.array(query_embedding), np.array(candidate_embeddings))[0]
    ranked_indices = np.argsort(scores)[::-1]
    reranked_chunks = [candidate_chunks[i] for i in ranked_indices]
    reranked_scores = scores[ranked_indices]
    return reranked_chunks, reranked_scores
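# Example (sketch): pool the candidates from the retriever sketches above, de-duplicate,
# and let the re-ranker order them by similarity to the query.
# _pool = list(set(_hits + _bm25_hits + _emb_hits))
# _ranked_chunks, _ranked_scores = rerank_candidates("dividend per share", _pool, embedding_model)
# print(_ranked_chunks[0], _ranked_scores[0])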
def get_grammar() -> LlamaGrammar:
    """
    Loads the guardrail GBNF grammar from disk and returns it as a LlamaGrammar object.
    """
    file_path = "rag_app/guardrail.gbnf"
    with open(file_path, 'r') as handler:
        content = handler.read()
    return LlamaGrammar.from_string(content)
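# The guardrail grammar file itself is not part of this script. For the parsing in
# answer_question below (literal_eval -> dict with a 'flag' key) to work, a GBNF
# grammar along these lines would suffice (an assumption, not the actual file contents):
# root ::= "{'flag': '" ("safe" | "unsafe") "'}"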
def answer_question(query, context=None, max_length=5000):
    # Guardrail step: ask the model to classify the query, then force the verdict into
    # the grammar-constrained format so it can be parsed reliably with literal_eval.
    output = llm(f"""Is this a harmful query: \n Query: {query}. \n\n Answer in 'SAFE'/'UNSAFE'""",
                 max_tokens=1000, stop=[], echo=False)
    tag = llm(f"Is this a harmful query. Content:\n {output['choices'][0]['text']} \n\n Answer in 'SAFE'/'UNSAFE'",
              max_tokens=1000, stop=[], echo=False, grammar=get_grammar())
    flag = literal_eval(tag['choices'][0]['text'])['flag']
    if flag == 'unsafe':
        return "This question has been categorized as harmful. I can't help with these types of queries."
    # No context supplied: answer directly as a general-purpose assistant.
    if not context:
        output = llm(
            f"""You're a helpful assistant. Answer the user's query in a professional tone.
            Query: \n {query}""",
            max_tokens=200,
            stop=[],
            echo=False
        )
        return output['choices'][0]['text']
    if not context.strip():
        return "Insufficient context to generate an answer."
    # Context supplied: ground the answer in the retrieved snippets.
    prompt = f"""Your tone should be that of a finance news reporter on the 7 PM prime-time slot. Questions will be
    regarding a company's financials. Under context you have the relevant snapshot for that query from the
    annual report. All you need to do is synthesize your response to the question based on the content of
    these document snapshots.
    # Context:
    {context}\n\n
    # Question: {query}
    \nAnswer:
    """
    output = llm(
        prompt,
        max_tokens=max_length,
        stop=[],
        echo=False
    )
    return output['choices'][0]['text']
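# Example (sketch): without a context string the model answers from its own knowledge;
# with one, it grounds the answer in the supplied snippet. The context text is made up.
# print(answer_question("What does EBITDA stand for?"))
# print(answer_question("What was net revenue?", context="Revenue from operations stood at Rs 1,000 crore."))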
def extract_final_answer(pdf_files, query):
    """
    End-to-end pipeline: extract text from the PDFs, chunk it, retrieve candidate chunks
    with FAISS, BM25 and embedding similarity, re-rank them, and answer the query.
    """
    combined_text = ""
    for pdf_path in pdf_files:
        print("reading:", pdf_path)
        document_data = extract_info_from_pdf(pdf_path)
        print("document_data:", len(document_data))
        basic_text = extract_text_from_pdf_pypdf2(pdf_path)
        financial_df = extract_financial_tables_regex(basic_text)
        cleaned_financial_text = clean_financial_data(financial_df)
        combined_text = combined_text + "\n" + combine_extracted_info(document_data, cleaned_financial_text)
    print("Combined text length:", len(combined_text))
    chunks = chunk_text(combined_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    print(f"Total chunks created: {len(chunks)}")
    faiss_index, _ = build_faiss_index(chunks, embedding_model)
    basic_results, basic_distances = retrieve_basic(query, faiss_index, chunks, embedding_model, k=k)
    print("\n--- Basic RAG Results (FAISS) ---\n")
    for chunk, dist in zip(basic_results, basic_distances):
        print(f"Distance: {dist:.4f}\n")
        print(f"Chunk: {chunk}\n{'-' * 40}")
    bm25_results, bm25_scores = retrieve_bm25(query, chunks, k=k)
    adv_emb_results, adv_emb_scores = retrieve_advanced_embedding(query, chunks, embedding_model, k=k)
    print("\n--- Advanced RAG BM25 Results ---")
    for chunk, score in zip(bm25_results, bm25_scores):
        print(f"BM25 Score: {score:.4f}\nChunk: {chunk}\n{'-' * 40}")
    print("\n--- Advanced RAG Embedding Results ---")
    for chunk, score in zip(adv_emb_results, adv_emb_scores):
        print(f"Embedding Similarity: {score:.4f}\nChunk: {chunk}\n{'-' * 40}")
    candidate_set = list(set(basic_results + bm25_results + adv_emb_results))
    print(f"\nTotal unique candidate chunks: {len(candidate_set)}")
    reranked_chunks, reranked_scores = rerank_candidates(query, candidate_set, embedding_model)
    print("\n--- Re-ranked Candidate Chunks ---")
    for chunk, score in zip(reranked_chunks, reranked_scores):
        print(f"Re-ranked Score: {score:.4f}\nChunk: {chunk}\n{'-' * 40}")
    top_context = "\n".join(reranked_chunks[:k])
    final_answer = answer_question(query, top_context)
    print("\n--- Final Answer ---")
    print(final_answer)
    return final_answer
# Define paths, query, and parameters
# pdf_path = "reliance-jio-infocomm-limited-annual-report-fy-2023-24.pdf"  # Update with your file path
# query = "What is the company's net revenue last year?"  # Example query
chunk_size = 500
chunk_overlap = 50
candidates_to_retrieve = 10  # Number of candidates to retrieve
k = 2
# extract_final_answer([pdf_path], "hello world")