import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import PyPDF2
import os
import time
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load models
retriever_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
gen_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
gen_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
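# Both models load once at startup and run on CPU; all-MiniLM-L6-v2 and
# flan-t5-small are small (roughly 90 MB and 300 MB on disk, respectively),
# which is what makes the 2 vCPU / 16 GB target mentioned in the UI realistic.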

# Cache for document embeddings
embedding_cache = {}
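# Note: the cache is keyed by the full chunk text and is never evicted, so it
# grows across requests. That is fine for a demo; a long-running deployment
# might prefer a bounded cache (e.g. an LRU) instead.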

def extract_text_from_pdf(pdf_file):
    """Extract text from a PDF file, returning a list of page texts."""
    # Depending on the Gradio version, uploads arrive either as a plain file
    # path string or as a tempfile-like object with a .name attribute.
    path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
    pages = []
    try:
        with open(path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            for page in reader.pages:
                text = page.extract_text()
                if text:
                    pages.append(text.strip())
    except Exception as e:
        logger.error(f"Error reading PDF {path}: {str(e)}")
        pages.append(f"Error reading PDF: {str(e)}")
    return pages
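
# Compatibility note: PyPDF2 is in maintenance mode; its maintained successor,
# pypdf, exposes the same PdfReader API, so it can be swapped in by changing
# only the import if desired.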

def chunk_text(text, chunk_size=1500):
    """Split text into chunks of approximately chunk_size characters."""
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0
    for word in words:
        if current_length + len(word) > chunk_size and current_chunk:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(word)
        current_length += len(word) + 1  # +1 for space
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
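
# Illustrative example (not executed): chunk_text("one two three", chunk_size=8)
# returns ["one two", "three"]. Words are kept whole, so a chunk can land
# slightly under the character budget (or over it, for a single very long word).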

def get_document_embeddings(documents):
    """Compute embeddings for documents, using cache if available, and return a stacked tensor."""
    embeddings = []
    for doc in documents:
        if doc in embedding_cache:
            embeddings.append(embedding_cache[doc])
        else:
            emb = retriever_model.encode(doc, convert_to_tensor=True)
            embedding_cache[doc] = emb
            embeddings.append(emb)
    return torch.stack(embeddings)

def rag_pipeline(question, pdf_files):
    """Optimized RAG pipeline with improved prompting and fallback."""
    start_time = time.time()
    documents = []

    # Process PDFs if provided
    if pdf_files:
        for pdf in pdf_files:
            pages = extract_text_from_pdf(pdf)
            for page in pages:
                chunks = chunk_text(page)
                documents.extend(chunks)
    else:
        # Default documents relevant to AI and Data Science
        documents = [
            "Artificial Intelligence (AI) is the simulation of human intelligence in machines.",
            "Data Science involves extracting insights from structured and unstructured data using statistical methods.",
            "AI and Data Science often work together to build predictive models and automate decision-making.",
            "Machine learning, a subset of AI, is widely used in Data Science for pattern recognition.",
        ]

    if not documents:
        return "No valid text could be extracted from the PDFs."

    # Compute embeddings with caching
    doc_embeddings = get_document_embeddings(documents)

    # Embed the query
    query_embedding = retriever_model.encode(question, convert_to_tensor=True)

    # Retrieve the top matching chunks (up to 5) by cosine similarity
    cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
    top_results = torch.topk(cos_scores, k=min(5, len(documents)))
    retrieved_context = ""
    for score, idx in zip(top_results.values, top_results.indices):
        retrieved_context += f"- {documents[idx]} (score: {score:.2f})\n"
    
    # Log retrieved context for debugging
    logger.info(f"Retrieved context:\n{retrieved_context}")

    # Improved prompt with fallback
    if retrieved_context.strip():
        prompt = (
            f"Based on the following context, provide a concise and accurate answer to the question.\n\n"
            f"Context:\n{retrieved_context}\n\n"
            f"Question: {question}\n\n"
            f"Answer:"
        )
    else:
        prompt = (
            f"No relevant context found. Provide a general answer to the question based on your knowledge.\n\n"
            f"Question: {question}\n\n"
            f"Answer:"
        )

    # Generate the answer; max_new_tokens caps the response length and
    # num_beams=2 trades a little speed for output quality
    inputs = gen_tokenizer(prompt, return_tensors="pt")
    outputs = gen_model.generate(**inputs, max_new_tokens=1500, num_beams=2)
    answer = gen_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Log processing time
    logger.info(f"Processing time: {time.time() - start_time:.2f} seconds")
    return answer if answer else "Unable to generate a meaningful response."

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Improved Lightweight Local RAG Pipeline with PDF Input")
    gr.Markdown(
        "Upload one or more PDF files (or leave blank to use the built-in AI/Data Science documents), "
        "enter your question, and get an answer produced by a lightweight retrieval step "
        "(all-MiniLM-L6-v2) followed by a small generator model (flan-t5-small). "
        "Designed to run on 2 vCPUs and 16 GB of RAM."
    )
    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(label="Your Question", placeholder="e.g., What is AI and Data Science?", lines=3)
            pdf_input = gr.File(label="Upload PDF(s) (optional)", file_types=[".pdf"], file_count="multiple")
            submit_button = gr.Button("Submit")
        with gr.Column():
            response_output = gr.Textbox(label="Response", placeholder="The answer will appear here...", lines=10)
    
    submit_button.click(fn=rag_pipeline, inputs=[question_input, pdf_input], outputs=response_output)

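# launch() serves the app locally; passing share=True would additionally create
# a temporary public Gradio link if remote access is needed.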
demo.launch()