Rohit1412 committed (verified)
Commit d7f5ad7 · Parent(s): 959cfe7

Create app.py

Files changed (1):
  1. app.py +161 -0
app.py ADDED
@@ -0,0 +1,161 @@
+ import os
+ import gradio as gr
+ import faiss
+ import numpy as np
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from sentence_transformers import SentenceTransformer
+
+ # ---------------------------
+ # Load Models (cached on first run)
+ # ---------------------------
+ def load_models():
+     hf_token = os.getenv("HF_TOKEN")  # Set this secret in your HF Space settings
+     embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # For embeddings
+     tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it", token=hf_token)
+     model = AutoModelForCausalLM.from_pretrained(
+         "google/gemma-3-4b-it",
+         device_map="auto",       # Requires the `accelerate` package
+         low_cpu_mem_usage=True,
+         token=hf_token           # Auth for the gated Gemma weights
+     )
+     return embed_model, tokenizer, model
+
+ embed_model, tokenizer, model = load_models()
+
+ # ---------------------------
+ # Global state for the FAISS index and document chunks,
+ # held in a module-level dictionary.
+ # ---------------------------
+ state = {
+     "faiss_index": None,
+     "doc_chunks": []
+ }
+
+ # ---------------------------
+ # Document Processing Function
+ # ---------------------------
+ def process_document(file, chunk_size, chunk_overlap):
+     """
+     Reads the uploaded file (PDF or text) from its path, extracts text,
+     splits it into overlapping chunks, computes embeddings, and builds
+     a FAISS index.
+     """
+     if file is None:
+         return "No file uploaded."
+
+     # `file` is a path string because the gr.File component below uses
+     # type="filepath".
+     file_name = os.path.basename(file)
+     text = ""
+
+     if file_name.lower().endswith(".pdf"):
+         try:
+             from PyPDF2 import PdfReader
+         except ImportError:
+             return "Error: PyPDF2 is required for PDF extraction."
+         reader = PdfReader(file)
+         for page in reader.pages:
+             text += page.extract_text() or ""
+     else:
+         # Assume it's a text file
+         with open(file, "rb") as f:
+             text = f.read().decode("utf-8", errors="ignore")
+
+     if text.strip() == "":
+         return "No text found in the document."
+
+     # Split text into overlapping chunks. Guard against a non-positive
+     # step, which would make range() raise a ValueError.
+     chunk_size = max(1, int(chunk_size))
+     chunk_overlap = int(chunk_overlap)
+     if chunk_overlap >= chunk_size:
+         return "Chunk overlap must be smaller than chunk size."
+     chunks = []
+     for start in range(0, len(text), chunk_size - chunk_overlap):
+         chunks.append(text[start: start + chunk_size])
+
+     # Compute embeddings for each chunk using the embedding model.
+     embeddings = embed_model.encode(chunks, normalize_embeddings=True).astype("float32")
+     dim = embeddings.shape[1]
+
+     # Build a FAISS index for cosine similarity (normalized vectors -> inner product).
+     index = faiss.IndexFlatIP(dim)
+     index.add(embeddings)
+
+     # Update global state.
+     state["faiss_index"] = index
+     state["doc_chunks"] = chunks
+
+     # Return a status message and a preview (first 500 characters of the first chunk).
+     preview = chunks[0][:500] if chunks else "No content"
+     return f"Indexed {len(chunks)} chunks.\n\n**Document Preview:**\n{preview}"
+
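+ # A quick sanity check of the chunking arithmetic above (hypothetical values):
+ # chunk_size=1000 with chunk_overlap=100 gives window starts 0, 900, 1800, ...,
+ # so consecutive chunks share 100 characters of context.
+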
+ # ---------------------------
+ # Question Answering Function
+ # ---------------------------
+ def answer_question(query, top_k):
+     """
+     Retrieves the top_k chunks most relevant to the query using the FAISS index,
+     builds a prompt with the retrieved context, and generates an answer with the
+     Gemma model.
+     """
+     index = state.get("faiss_index")
+     chunks = state.get("doc_chunks")
+     if index is None or len(chunks) == 0:
+         return "No document processed. Please upload a document first."
+
+     # Encode the query with the same embedding model and clamp top_k to the
+     # number of indexed chunks.
+     top_k = min(max(1, int(top_k)), len(chunks))
+     query_vec = embed_model.encode([query], normalize_embeddings=True).astype("float32")
+     _, indices = index.search(query_vec, top_k)
+
+     # Concatenate retrieved chunks as context.
+     retrieved_text = ""
+     for idx in indices[0]:
+         retrieved_text += chunks[idx] + "\n"
+
+     # Formulate the prompt for the generative model.
+     prompt = f"Context:\n{retrieved_text}\nQuestion: {query}\nAnswer:"
+
+     # Tokenize and generate the answer. do_sample=True is required for the
+     # temperature to take effect; otherwise generation is greedy.
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+     output_ids = model.generate(input_ids, max_new_tokens=200, do_sample=True, temperature=0.2)
+     answer = tokenizer.decode(output_ids[0][input_ids.size(1):], skip_special_tokens=True)
+     return answer.strip()
+
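+ # Note: gemma-3-4b-it is instruction-tuned, so a chat-formatted prompt built
+ # with tokenizer.apply_chat_template() may match its expected input format
+ # more closely than the raw "Context/Question/Answer" prompt used above.
+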
+ # ---------------------------
+ # Gradio Interface
+ # ---------------------------
+ with gr.Blocks(title="RAG System with Gemma-3-4B-it") as demo:
+     gr.Markdown(
+         """
+         # RAG System with Gemma-3-4B-it
+         Upload a document (PDF or TXT) below. The system extracts its text, splits it into chunks,
+         builds a FAISS vector index, and then lets you ask questions about the document.
+         """
+     )
+
+     with gr.Tab("Document Upload & Processing"):
+         with gr.Row():
+             file_input = gr.File(label="Upload Document (PDF or TXT)", file_count="single", type="filepath")
+         with gr.Row():
+             chunk_size_input = gr.Number(label="Chunk Size (characters)", value=1000, precision=0)
+             chunk_overlap_input = gr.Number(label="Chunk Overlap (characters)", value=100, precision=0)
+         process_btn = gr.Button("Process Document")
+         process_output = gr.Markdown()
+
+     with gr.Tab("Ask a Question"):
+         query_input = gr.Textbox(label="Enter your question", placeholder="Type your question here...")
+         top_k_input = gr.Number(label="Number of Chunks to Retrieve", value=3, precision=0)
+         answer_btn = gr.Button("Get Answer")
+         answer_output = gr.Markdown(label="Answer")
+
+     # Wire up the buttons.
+     process_btn.click(
+         fn=process_document,
+         inputs=[file_input, chunk_size_input, chunk_overlap_input],
+         outputs=process_output
+     )
+     answer_btn.click(
+         fn=answer_question,
+         inputs=[query_input, top_k_input],
+         outputs=answer_output
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
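
For this Space to run, a requirements.txt next to app.py would need roughly the following. This is a sketch, not part of the commit; exact choices (e.g., faiss-cpu vs. faiss-gpu, pinned versions) depend on the Space hardware:

    gradio
    transformers
    accelerate
    torch
    sentence-transformers
    faiss-cpu
    PyPDF2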