manu committed on
Commit 5697c10 · verified · 1 Parent(s): f4be1c6

Update app.py

Files changed (1)
  1. app.py +215 -120
app.py CHANGED
@@ -1,6 +1,8 @@
 import os
 import base64
+import tempfile
 from io import BytesIO
+from urllib.request import urlretrieve

 import gradio as gr
 from gradio_pdf import PDF
@@ -13,179 +15,272 @@ from tqdm import tqdm

 from colpali_engine.models import ColQwen2, ColQwen2Processor

+# -----------------------------
+# Globals
+# -----------------------------
+api_key = os.getenv("OPENAI_API_KEY", "")  # <- use env var
+ds = []                   # list of document embeddings (torch tensors)
+images = []               # list of PIL images (page-order)
+current_pdf_path = None   # last (indexed) pdf path for preview

+# -----------------------------
+# Model & processor
+# -----------------------------
+device_map = "cuda:0" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")

 model = ColQwen2.from_pretrained(
-    "vidore/colqwen2-v1.0",
-    torch_dtype=torch.bfloat16,
-    device_map="cuda:0",  # or "mps" if on Apple Silicon
-    attn_implementation="flash_attention_2"
-).eval()
+    "vidore/colqwen2-v1.0",
+    torch_dtype=torch.bfloat16,
+    device_map=device_map,
+    attn_implementation="flash_attention_2"
+).eval()
 processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")


-def encode_image_to_base64(image):
+# -----------------------------
+# Utilities
+# -----------------------------
+def encode_image_to_base64(image: Image.Image) -> str:
     """Encodes a PIL image to a base64 string."""
     buffered = BytesIO()
     image.save(buffered, format="JPEG")
     return base64.b64encode(buffered.getvalue()).decode("utf-8")


-def query_gpt4o_mini(query, images, api_key):
-    """Calls OpenAI's GPT-4o-mini with the query and image data."""
+def query_gpt(query: str, retrieved_images: list[tuple[Image.Image, str]]) -> str:
+    """Calls OpenAI's GPT model with the query and image data."""
     if api_key and api_key.startswith("sk"):
         try:
             from openai import OpenAI

-            base64_images = [encode_image_to_base64(image[0]) for image in images]
+            base64_images = [encode_image_to_base64(im_caption[0]) for im_caption in retrieved_images]
             client = OpenAI(api_key=api_key.strip())
             PROMPT = """
-            You are a smart assistant designed to answer questions about a PDF document.
-            You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers, etc).
-            If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
-            Give detailed and extensive answers, only containing info in the pages you are given.
-            You can answer using information contained in plots and figures if necessary.
-            Answer in the same language as the query.
-
-            Query: {query}
-            PDF pages:
-            """
+            You are a smart assistant designed to answer questions about a PDF document.
+            You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers, etc).
+            If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
+            Give detailed and extensive answers, only containing info in the pages you are given.
+            You can answer using information contained in plots and figures if necessary.
+            Answer in the same language as the query.
+
+            Query: {query}
+            PDF pages:
+            """.strip()

             response = client.chat.completions.create(
-                model="gpt-4o-mini",
-                messages=[
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": PROMPT.format(query=query)
-                            }] + [{
-                                "type": "image_url",
-                                "image_url": {
-                                    "url": f"data:image/jpeg;base64,{im}"
-                                },
-                            } for im in base64_images]
-                    }
-                ],
-                max_tokens=500,
+                model="gpt-5-mini",
+                messages=[
+                    {
+                        "role": "user",
+                        "content": (
+                            [{"type": "text", "text": PROMPT.format(query=query)}] +
+                            [{"type": "image_url",
+                              "image_url": {"url": f"data:image/jpeg;base64,{im}"}}
+                             for im in base64_images]
+                        )
+                    }
+                ],
+                max_tokens=500,
             )
             return response.choices[0].message.content
-        except Exception as e:
-            return "OpenAI API connection failure. Verify the provided key is correct (sk-***)."
-
-    return "Enter your OpenAI API key to get a custom response"
+        except Exception:
+            return "OpenAI API connection failure. Verify that OPENAI_API_KEY is set and valid (sk-***)."
+    return "Set OPENAI_API_KEY in your environment to get a custom response."


-def search(query: str, ds, images, k, api_key):
-    k = min(k, len(ds))
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    if device != model.device:
-        model.to(device)
-
-    qs = []
-    with torch.no_grad():
-        batch_query = processor.process_queries([query]).to(model.device)
-        embeddings_query = model(**batch_query)
-        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
-
-    scores = processor.score(qs, ds, device=device)
-
-    top_k_indices = scores[0].topk(k).indices.tolist()
-
-    results = []
-    for idx in top_k_indices:
-        results.append((images[idx], f"Page {idx}"))
-
-    # Generate response from GPT-4o-mini
-    ai_response = query_gpt4o_mini(query, results, api_key)
-
-    return results, ai_response
-
-
-def index(files, ds):
-    print("Converting files")
-    images = convert_files(files)
-    print(f"Files converted with {len(images)} images.")
-    return index_gpu(images, ds)
-
-
-def convert_files(files):
-    images = []
-    print(files)
-    files = [files]
-    for f in files:
-        images.extend(convert_from_path(f, thread_count=4))
-
-    if len(images) >= 500:
+def _ensure_model_device():
+    dev = "cuda:0" if torch.cuda.is_available() else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
+    if str(model.device) != dev:
+        model.to(dev)
+    return dev
+
+
+# -----------------------------
+# Indexing helpers
+# -----------------------------
+def convert_files(pdf_path: str) -> list[Image.Image]:
+    """Convert a single PDF path into a list of PIL Images (pages)."""
+    imgs = convert_from_path(pdf_path, thread_count=4)
+    if len(imgs) >= 500:
         raise gr.Error("The number of images in the dataset should be less than 500.")
-    return images
-
-
-def index_gpu(images, ds):
-    """Example script to run inference with ColPali (ColQwen2)"""
-
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
-    if device != model.device:
-        model.to(device)
-
-    # run inference - docs
+    return imgs
+
+
+def index_gpu(imgs: list[Image.Image]) -> str:
+    """Embed a list of images (pages) with ColPali and store in globals."""
+    global ds, images
+    device = _ensure_model_device()
+
+    # reset previous dataset
+    ds = []
+    images = imgs
+
     dataloader = DataLoader(
         images,
         batch_size=4,
-        # num_workers=4,
         shuffle=False,
         collate_fn=lambda x: processor.process_images(x).to(model.device),
     )

-    for batch_doc in tqdm(dataloader):
+    for batch_doc in tqdm(dataloader, desc="Indexing pages"):
         with torch.no_grad():
             batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
             embeddings_doc = model(**batch_doc)
         ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
-    return f"Uploaded and converted {len(images)} pages", ds, images
+    return f"Indexed {len(images)} pages successfully."
+
+
+def index_from_path(pdf_path: str) -> str:
+    """Public: index a local PDF file path."""
+    imgs = convert_files(pdf_path)
+    return index_gpu(imgs)
+
+
+def index_from_url(url: str) -> tuple[str, str]:
+    """
+    Download a PDF from URL and index it.
+
+    Returns:
+        status message, saved pdf path
+    """
+    tmp_dir = tempfile.mkdtemp(prefix="colpali_")
+    local_path = os.path.join(tmp_dir, "document.pdf")
+    urlretrieve(url, local_path)
+    status = index_from_path(local_path)
+    return status, local_path
+
+
+# -----------------------------
+# Search (MCP tool-friendly)
+# -----------------------------
+def search(query: str, k: int):
+    """
+    Search the currently indexed PDF pages for the most relevant content and
+    generate an answer grounded ONLY in those pages.
+
+    MCP tool description:
+    - name: search
+    - description: Retrieve top-k PDF pages relevant to a query and answer using only those pages.
+    - input_schema:
+        type: object
+        properties:
+          query: {type: string, description: "User query in natural language."}
+          k: {type: integer, minimum: 1, maximum: 10, description: "Number of top pages to retrieve."}
+        required: ["query"]
+
+    Args:
+        query (str): Natural-language question to search for.
+        k (int): Number of top results to return (1–10).
+
+    Returns:
+        tuple:
+            - results (list[tuple[PIL.Image.Image, str]]): List of (page_image, caption) pairs for a gallery.
+            - ai_response (str): Answer grounded only in retrieved pages, with citations (page numbers).
+
+    Notes:
+        • Requires that a PDF has been indexed first.
+        • Citations reference 1-based page numbers as shown in the gallery captions.
+    """
+    global ds, images
+
+    if not images or not ds:
+        return [], "No document indexed yet. Upload a PDF or load the sample, then run Search."
+
+    k = max(1, min(int(k), len(images)))
+    device = _ensure_model_device()
+
+    # Encode query
+    qs = []
+    with torch.no_grad():
+        batch_query = processor.process_queries([query]).to(model.device)
+        embeddings_query = model(**batch_query)
+        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
+
+    # Score and select top-k
+    scores = processor.score(qs, ds, device=device)
+    top_k_indices = scores[0].topk(k).indices.tolist()
+
+    # Build gallery results with 1-based page numbering
+    results = []
+    for idx in top_k_indices:
+        page_num = idx + 1
+        results.append((images[idx], f"Page {page_num}"))
+
+    # Generate grounded response
+    ai_response = query_gpt(query, results)
+    return results, ai_response
+
+
+# -----------------------------
+# Gradio UI callbacks
+# -----------------------------
+def handle_upload(file) -> tuple[str, str | None]:
+    """Index a user-uploaded PDF file."""
+    global current_pdf_path
+    if file is None:
+        return "Please upload a PDF.", None
+    path = getattr(file, "name", file)
+    status = index_from_path(path)
+    current_pdf_path = path
+    return status, path
+
+
+def handle_url(url: str) -> tuple[str, str | None]:
+    """Index a PDF from URL (e.g., a sample)."""
+    global current_pdf_path
+    if not url or not url.lower().endswith(".pdf"):
+        return "Please provide a direct PDF URL.", None
+    status, path = index_from_url(url)
+    current_pdf_path = path
+    return status, path
+
+
+# -----------------------------
+# Gradio App
+# -----------------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
-    gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents.
-    ColPali is model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
-
-    This demo allows you to upload PDF files and search for the most relevant pages based on your query.
-    Refresh the page if you change documents !
-
-    ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing english text. Performance is expected to drop for other page formats and languages.
-    Other models will be released with better robustness towards different languages and document formats !
-    """)
+    gr.Markdown(
+        """Demo to test ColQwen2 (ColPali) on PDF documents.
+        ColPali is implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
+
+        This demo lets you **upload a PDF or load a sample**, then **search** for the most relevant pages and get a grounded answer.
+
+        ⚠️ The model was trained on A4 portrait English PDFs; performance may drop on other formats/languages.
+        """
+    )

     with gr.Row():
         with gr.Column(scale=2):
-            gr.Markdown("## 1️⃣ Upload PDFs")
-            # file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDFs")
-            file = PDF(label="PDF Document")
-            print(file)
-
-            convert_button = gr.Button("🔄 Index documents")
-            message = gr.Textbox("Files not yet uploaded", label="Status")
-            api_key = gr.Textbox(placeholder="Enter your OpenAI KEY here (optional)", label="API key")
-            embeds = gr.State(value=[])
-            imgs = gr.State(value=[])
+            gr.Markdown("## 1️⃣ Load a PDF")
+            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
+            index_btn = gr.Button("📥 Index Uploaded PDF", variant="secondary")
+            url_box = gr.Textbox(
+                label="Or index from URL",
+                placeholder="https://example.com/file.pdf",
+                value="https://sist.sathyabama.ac.in/sist_coursematerial/uploads/SAR1614.pdf",
+            )
+            index_url_btn = gr.Button("🌐 Load Sample / From URL", variant="secondary")
+            status_box = gr.Textbox(label="Status", interactive=False)
+            pdf_view = PDF(label="PDF Preview")

         with gr.Column(scale=3):
             gr.Markdown("## 2️⃣ Search")
             query = gr.Textbox(placeholder="Enter your query here", label="Query")
-            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
-
-            # Define the actions
-            search_button = gr.Button("🔍 Search", variant="primary")
-            output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
-            output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")
-
-    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
-    search_button.click(search, inputs=[query, embeds, imgs, k, api_key], outputs=[output_gallery, output_text])
+            k_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
+            search_button = gr.Button("🔍 Search", variant="primary")
+            output_gallery = gr.Gallery(label="Retrieved Pages", height=600, show_label=True)
+            output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")

+    # Wiring
+    index_btn.click(handle_upload, inputs=[pdf_input], outputs=[status_box, pdf_view])
+    index_url_btn.click(handle_url, inputs=[url_box], outputs=[status_box, pdf_view])
+    search_button.click(search, inputs=[query, k_slider], outputs=[output_gallery, output_text])

 if __name__ == "__main__":
-    demo.queue(max_size=5).launch(debug=True, mcp_server=True)
+    # Optional: pre-load the default sample at startup.
+    # Comment these two lines if you prefer a "cold" start.
+    # msg, path = index_from_url("https://sist.sathyabama.ac.in/sist_coursematerial/uploads/SAR1614.pdf")
+    # print(msg, "->", path)
+
+    demo.queue(max_size=5).launch(debug=True, mcp_server=True)
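
Since this commit exposes search() through the app's MCP server (mcp_server=True in the launch call), one quick way to exercise the new endpoint is the Gradio Python client. The following is a minimal sketch, not part of the commit: it assumes the app is running locally on Gradio's default port and that the endpoint keeps the default name "/search" derived from the callback's function name.

    # Hypothetical client-side check; URL and endpoint name are assumptions.
    from gradio_client import Client

    client = Client("http://127.0.0.1:7860/")   # assumed local launch address
    pages, answer = client.predict(
        query="What is the main topic of the document?",
        k=5,
        api_name="/search",                     # Gradio's default name for the search() callback
    )
    print(answer)                               # grounded answer citing page numbers

A PDF must be indexed first (via the upload button or the URL box); otherwise search() returns the "No document indexed yet" message.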