import os
import gradio as gr
import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
# ---------------------------
# Load models (once, at Space startup)
# ---------------------------
def load_models():
    hf_token = os.getenv("HF_TOKEN")  # Set this secret in your HF Space settings (Gemma is a gated model)
    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # For embeddings
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it", token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-3-4b-it",
        device_map="auto",
        low_cpu_mem_usage=True,
        token=hf_token
    )
    return embed_model, tokenizer, model

embed_model, tokenizer, model = load_models()
# ---------------------------
# Global state for the FAISS index and document chunks.
# Using a dictionary to hold state.
# ---------------------------
state = {
    "faiss_index": None,
    "doc_chunks": []
}
# ---------------------------
# Document Processing Function
# ---------------------------
def process_document(file, chunk_size, chunk_overlap):
    """
    Reads the uploaded file (PDF or text), extracts text, splits it into chunks,
    computes embeddings, and builds a FAISS index.
    """
    if file is None:
        return "No file uploaded."
    # Gradio passes the upload as a temp-file path (or an object exposing .name).
    file_path = file if isinstance(file, str) else file.name
    text = ""
    if file_path.lower().endswith(".pdf"):
        try:
            from PyPDF2 import PdfReader
        except ImportError:
            return "Error: PyPDF2 is required for PDF extraction."
        reader = PdfReader(file_path)
        for page in reader.pages:
            text += page.extract_text() or ""
    else:
        # Assume it's a text file
        with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
            text = f.read()
    if text.strip() == "":
        return "No text found in the document."
    # Split text into overlapping chunks
    chunk_size = int(chunk_size)
    chunk_overlap = int(chunk_overlap)
    step = max(1, chunk_size - chunk_overlap)  # guard against overlap >= chunk size
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start: start + chunk_size])
    # Compute embeddings for each chunk using the embedding model.
    embeddings = embed_model.encode(chunks, normalize_embeddings=True).astype("float32")
    dim = embeddings.shape[1]
    # Build FAISS index using cosine similarity (normalized vectors -> inner product)
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
    # Update global state
    state["faiss_index"] = index
    state["doc_chunks"] = chunks
    # Return a preview (first 500 characters of the first chunk) and a status line.
    preview = chunks[0][:500] if chunks else "No content"
    return f"Indexed {len(chunks)} chunks.\n\n**Document Preview:**\n{preview}"
# ---------------------------
# Question Answering Function
# ---------------------------
def answer_question(query, top_k):
    """
    Retrieves the top_k chunks most relevant to the query using the FAISS index,
    builds a prompt with the retrieved context, and generates an answer with the Gemma model.
    """
    index = state.get("faiss_index")
    chunks = state.get("doc_chunks")
    if index is None or len(chunks) == 0:
        return "No document processed. Please upload a document first."
    # Encode the query with the same embedding model
    query_vec = embed_model.encode([query], normalize_embeddings=True).astype("float32")
    top_k = min(int(top_k), len(chunks))  # never ask FAISS for more chunks than exist
    scores, indices = index.search(query_vec, top_k)
    # Concatenate retrieved chunks as context
    retrieved_text = ""
    for idx in indices[0]:
        retrieved_text += chunks[idx] + "\n"
    # Formulate the prompt for the generative model
    prompt = f"Context:\n{retrieved_text}\nQuestion: {query}\nAnswer:"
    # Tokenize and generate the answer (do_sample is required for temperature to take effect)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=200, do_sample=True, temperature=0.2)
    # Decode only the newly generated tokens, skipping the prompt
    answer = tokenizer.decode(output_ids[0][input_ids.size(1):], skip_special_tokens=True)
    return answer.strip()
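# Gemma's instruction-tuned checkpoints are normally prompted through the chat
# template rather than a raw string. A minimal sketch of that alternative
# (assuming the tokenizer ships a chat template), kept here as a comment so the
# plain-prompt path above stays the active one:
#
#   messages = [{"role": "user", "content": prompt}]
#   input_ids = tokenizer.apply_chat_template(
#       messages, add_generation_prompt=True, return_tensors="pt"
#   ).to(model.device)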
# ---------------------------
# Gradio Interface
# ---------------------------
with gr.Blocks(title="RAG System with Gemma‑3‑4B‑it") as demo:
    gr.Markdown(
        """
        # RAG System with Gemma‑3‑4B‑it
        Upload a document (PDF or TXT) below. The system will extract text, split it into chunks,
        build a vector index using FAISS, and then allow you to ask questions based on the document.
        """
    )
    with gr.Tab("Document Upload & Processing"):
        with gr.Row():
            file_input = gr.File(label="Upload Document (PDF or TXT)", file_count="single")
        with gr.Row():
            chunk_size_input = gr.Number(label="Chunk Size (characters)", value=1000, precision=0)
            chunk_overlap_input = gr.Number(label="Chunk Overlap (characters)", value=100, precision=0)
        process_btn = gr.Button("Process Document")
        process_output = gr.Markdown()
    with gr.Tab("Ask a Question"):
        query_input = gr.Textbox(label="Enter your question", placeholder="Type your question here...")
        top_k_input = gr.Number(label="Number of Chunks to Retrieve", value=3, precision=0)
        answer_btn = gr.Button("Get Answer")
        answer_output = gr.Markdown(label="Answer")

    # Set up actions
    process_btn.click(
        fn=process_document,
        inputs=[file_input, chunk_size_input, chunk_overlap_input],
        outputs=process_output
    )
    answer_btn.click(
        fn=answer_question,
        inputs=[query_input, top_k_input],
        outputs=answer_output
    )

if __name__ == "__main__":
    demo.launch()
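# For the Space to build, a requirements.txt along these lines is assumed
# (package names only; exact version pins are not part of the original source):
#
#   gradio
#   torch
#   transformers
#   accelerate            # needed for device_map="auto"
#   sentence-transformers
#   faiss-cpu
#   numpy
#   PyPDF2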