gemma3-27b-RAG / app.py
import os
import gradio as gr
import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
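# Assumed runtime dependencies for this Space (a sketch, not pinned here):
# gradio, transformers, sentence-transformers, faiss-cpu, torch,
# accelerate (for device_map="auto"), and PyPDF2 (for PDF extraction).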
# ---------------------------
# Load Models (cached on first run)
# ---------------------------
def load_models():
    hf_token = os.getenv("HF_TOKEN")  # Set this secret in your HF Space settings
    embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')  # For embeddings
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it", token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-3-4b-it",
        device_map="auto",
        low_cpu_mem_usage=True,
        token=hf_token
    )
    return embed_model, tokenizer, model

embed_model, tokenizer, model = load_models()
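# Note: the models are loaded once at import time, so the first startup of the
# Space can be slow; subsequent requests reuse the already-loaded weights.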
# ---------------------------
# Global state for FAISS index and document chunks.
# Using a dictionary to hold state.
state = {
    "faiss_index": None,
    "doc_chunks": []
}
# ---------------------------
# Document Processing Function
# ---------------------------
def process_document(file, chunk_size, chunk_overlap):
    """
    Reads the uploaded file (PDF or text), extracts the text, splits it into chunks,
    computes embeddings, and builds a FAISS index.
    """
    if file is None:
        return "No file uploaded."
    # Depending on the Gradio version, gr.File passes either a filepath string or a file-like object
    file_path = file if isinstance(file, str) else file.name
    file_name = os.path.basename(file_path)
    with open(file_path, "rb") as f:
        file_bytes = f.read()
    text = ""
    if file_name.lower().endswith(".pdf"):
        try:
            from PyPDF2 import PdfReader
        except ImportError:
            return "Error: PyPDF2 is required for PDF extraction."
        # Save the file to a temporary path so PdfReader can open it
        temp_path = os.path.join("temp", file_name)
        os.makedirs("temp", exist_ok=True)
        with open(temp_path, "wb") as f:
            f.write(file_bytes)
        reader = PdfReader(temp_path)
        for page in reader.pages:
            text += page.extract_text() or ""
    else:
        # Assume it's a plain text file
        text = file_bytes.decode("utf-8", errors="ignore")
    if text.strip() == "":
        return "No text found in the document."
    # Split the text into overlapping chunks (sliding window)
    chunk_size, chunk_overlap = int(chunk_size), int(chunk_overlap)
    step = max(chunk_size - chunk_overlap, 1)  # guard against a non-positive stride
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start: start + chunk_size])
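    # With the UI defaults (chunk_size=1000, chunk_overlap=100) the stride is 900,
    # so chunks start at characters 0, 900, 1800, ... and adjacent chunks share
    # 100 characters of overlapping context.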
    # Compute embeddings for each chunk using the embedding model
    embeddings = embed_model.encode(chunks, normalize_embeddings=True).astype('float32')
    dim = embeddings.shape[1]
    # Build FAISS index using cosine similarity (normalized vectors -> inner product)
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
    # Update global state
    state["faiss_index"] = index
    state["doc_chunks"] = chunks
    # Return a preview (first 500 characters of the first chunk) and a status message
    preview = chunks[0][:500] if chunks else "No content"
    return f"Indexed {len(chunks)} chunks.\n\n**Document Preview:**\n{preview}"
# ---------------------------
# Question Answering Function
# ---------------------------
def answer_question(query, top_k):
    """
    Retrieves the top_k chunks most relevant to the query using the FAISS index,
    builds a prompt with the retrieved context, and generates an answer using the Gemma model.
    """
    index = state.get("faiss_index")
    chunks = state.get("doc_chunks")
    if index is None or len(chunks) == 0:
        return "No document processed. Please upload a document first."
    # Encode the query with the same embedding model used for the chunks
    query_vec = embed_model.encode([query], normalize_embeddings=True).astype('float32')
    # D holds similarity scores, I holds the indices of the top_k most similar chunks
    D, I = index.search(query_vec, int(top_k))
    # Concatenate retrieved chunks as context
    retrieved_text = ""
    for idx in I[0]:
        retrieved_text += chunks[idx] + "\n"
    # Formulate the prompt for the generative model
    prompt = f"Context:\n{retrieved_text}\nQuestion: {query}\nAnswer:"
    # Tokenize and generate an answer (temperature only takes effect with sampling enabled)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.2)
    answer = tokenizer.decode(output_ids[0][inputs["input_ids"].size(1):], skip_special_tokens=True)
    return answer.strip()
# ---------------------------
# Gradio Interface
# ---------------------------
with gr.Blocks(title="RAG System with Gemma-3-4B-it") as demo:
    gr.Markdown(
        """
        # RAG System with Gemma-3-4B-it
        Upload a document (PDF or TXT) below. The system will extract text, split it into chunks,
        build a vector index using FAISS, and then allow you to ask questions based on the document.
        """
    )
    with gr.Tab("Document Upload & Processing"):
        with gr.Row():
            file_input = gr.File(label="Upload Document (PDF or TXT)", file_count="single")
        with gr.Row():
            chunk_size_input = gr.Number(label="Chunk Size (characters)", value=1000, precision=0)
            chunk_overlap_input = gr.Number(label="Chunk Overlap (characters)", value=100, precision=0)
        process_btn = gr.Button("Process Document")
        process_output = gr.Markdown()
    with gr.Tab("Ask a Question"):
        query_input = gr.Textbox(label="Enter your question", placeholder="Type your question here...")
        top_k_input = gr.Number(label="Number of Chunks to Retrieve", value=3, precision=0)
        answer_btn = gr.Button("Get Answer")
        answer_output = gr.Markdown(label="Answer")

    # Set up actions
    process_btn.click(
        fn=process_document,
        inputs=[file_input, chunk_size_input, chunk_overlap_input],
        outputs=process_output
    )
    answer_btn.click(
        fn=answer_question,
        inputs=[query_input, top_k_input],
        outputs=answer_output
    )
if __name__ == "__main__":
    demo.launch()