# app.py — Unified ColPali + MCP Agent (indices-only search, agent receives images)

import os
import base64
import tempfile
from io import BytesIO
from urllib.request import urlretrieve
from typing import List, Tuple, Dict, Any

import gradio as gr
from gradio_pdf import PDF
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from colpali_engine.models import ColQwen2, ColQwen2Processor

# Optional (used by the streaming agent)
from openai import OpenAI

# =============================
# Globals & Config
# =============================
api_key_env = os.getenv("OPENAI_API_KEY", "").strip()

ds: List[torch.Tensor] = []       # page embeddings
images: List[Image.Image] = []    # PIL images in page order
current_pdf_path: str | None = None

device_map = (
    "cuda:0"
    if torch.cuda.is_available()
    else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
)

# =============================
# Load Model & Processor
# =============================
model = ColQwen2.from_pretrained(
    "vidore/colqwen2-v1.0",
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    attn_implementation="flash_attention_2",
).eval()
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")

# =============================
# Utilities
# =============================
def _ensure_model_device() -> str:
    dev = (
        "cuda:0"
        if torch.cuda.is_available()
        else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
    )
    if str(model.device) != dev:
        model.to(dev)
    return dev


def encode_image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image to base64 (JPEG)."""
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
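# Illustrative note: this helper returns only the raw base64 string (a JPEG payload
# typically begins with "/9j/..."); the "data:image/jpeg;base64," prefix is added
# later in get_pages() when the images are packaged for the agent.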
# =============================
# Indexing Helpers
# =============================
def convert_files(pdf_path: str) -> List[Image.Image]:
    """Convert a single PDF path into a list of PIL Images (pages)."""
    imgs = convert_from_path(pdf_path, thread_count=4)
    if len(imgs) >= 800:
        raise gr.Error("The number of images in the dataset should be less than 800.")
    return imgs


def index_gpu(imgs: List[Image.Image]) -> str:
    """Embed a list of images (pages) with ColQwen2 (ColPali) and store them in globals."""
    global ds, images
    device = _ensure_model_device()

    # Reset the previous dataset
    ds = []
    images = imgs

    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_images(x).to(model.device),
    )
    for batch_doc in tqdm(dataloader, desc="Indexing pages"):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Indexed {len(images)} pages successfully."


def index_from_path(pdf_path: str) -> str:
    imgs = convert_files(pdf_path)
    return index_gpu(imgs)


def index_from_url(url: str) -> Tuple[str, str]:
    """
    Download a PDF from a URL and index it.

    Returns:
        (status_message, saved_pdf_path)
    """
    tmp_dir = tempfile.mkdtemp(prefix="colpali_")
    local_path = os.path.join(tmp_dir, "document.pdf")
    urlretrieve(url, local_path)
    status = index_from_path(local_path)
    return status, local_path
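# Usage sketch (hypothetical URL, for illustration only):
#   status, path = index_from_url("https://example.com/file.pdf")
#   print(status)  # e.g. "Indexed 12 pages successfully."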
# =============================
# MCP Tools
# =============================
def search(query: str, k: int = 5) -> List[int]:
    """
    Search within an indexed PDF and return ONLY the indices of the most relevant pages (0-based).

    MCP tool description:
      - name: mcp_test_search
      - description: Search within the indexed PDF for the most relevant pages and return their 0-based indices only.
      - input_schema:
          type: object
          properties:
            query: {type: string, description: "User query in natural language."}
            k: {type: integer, minimum: 1, maximum: 50, default: 5, description: "Number of top pages to retrieve (before neighbor expansion)."}
          required: ["query"]

    Returns:
        List[int]: Sorted unique 0-based indices of pages to inspect (includes neighbor expansion).
    """
    global ds, images
    if not images or not ds:
        return []

    k = max(1, min(int(k), len(images)))
    device = _ensure_model_device()

    # Encode the query
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
        q_vecs = list(torch.unbind(embeddings_query.to("cpu")))

    # Score and select top-k pages
    scores = processor.score(q_vecs, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()
    print(query, top_k_indices)

    # Neighbor expansion for context
    base = set(top_k_indices)
    expanded = set(base)
    for i in base:
        expanded.add(i - 1)
        expanded.add(i + 1)
    expanded = {i for i in expanded if 0 <= i < len(images)}  # strict bounds
    return sorted(expanded)
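# Illustration of the neighbor expansion (hypothetical result): if the top-3 hits for a
# query are pages [4, 9, 10], the returned indices are [3, 4, 5, 8, 9, 10, 11],
# clamped to the document bounds.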
def get_pages(indices: List[int]) -> Dict[str, Any]:
    """
    Return page images (as data URLs) for the given 0-based indices.

    MCP tool description:
      - name: mcp_test_get_pages
      - description: Given 0-based indices from mcp_test_search, return the corresponding page images as data URLs for vision reasoning.
      - input_schema:
          type: object
          properties:
            indices: {type: array, items: {type: integer, minimum: 0}, description: "0-based page indices to fetch"}
          required: ["indices"]

    Returns:
        {"images": [{"index": int, "page": int, "image_url": str}], "count": int}
    """
    global images
    print("indices to get", indices)
    if not images:
        return {"images": [], "count": 0}

    uniq = sorted({i for i in indices if 0 <= i < len(images)})
    payload = []
    for idx in uniq:
        im = images[idx]
        b64 = encode_image_to_base64(im)
        payload.append({
            "index": idx,
            "page": idx + 1,
            "image_url": f"data:image/jpeg;base64,{b64}",
        })
    return {"images": payload, "count": len(payload)}
# =============================
# Gradio UI — Unified App
# =============================
SYSTEM = (
    """
You are a PDF research agent with two tools:
  • mcp_test_search(query: string, k: int) → returns ONLY 0-based page indices.
  • mcp_test_get_pages(indices: int[]) → returns the actual page images (as base64 images) for vision.

Policy & procedure:
  1) Break the user task into 1–4 targeted sub-queries (in English).
  2) For each sub-query, call mcp_test_search to get indices; THEN immediately call mcp_test_get_pages with those indices to obtain the page images.
  3) Continue reasoning using ONLY the provided images. If the information is insufficient, iterate: refine the sub-queries and call the tools again. You may make further tool calls later in the conversation as needed.

Grounding & citations:
  • Use ONLY information visible in the provided page images.
  • After any claim, cite it as (p.<page>).
  • If an answer is not present, say “Not found in the provided pages.”

Final deliverable:
  • Write a clear, standalone Markdown answer in the user's language. For lists of dates/items, include a concise table.
  • Do not refer to “the above” or “previous messages”.
"""
).strip()

DEFAULT_MCP_SERVER_URL = "https://manu-mcp-test.hf.space/gradio_api/mcp/"
DEFAULT_MCP_SERVER_LABEL = "colpali_rag"
DEFAULT_ALLOWED_TOOLS = "mcp_test_search,mcp_test_get_pages"
def stream_agent(question: str,
                 api_key: str,
                 model_name: str,
                 server_url: str,
                 server_label: str,
                 require_approval: str,
                 allowed_tools: str):
    """
    Streaming generator for the agent.

    NOTE: We rely on OpenAI's MCP tool routing. The mcp_test_search tool returns indices only;
    the agent is instructed to call mcp_test_get_pages next to receive the images and continue reasoning.
    """
    final_text = "Answer:"
    summary_text = "Reasoning:"
    log_lines = ["Log"]

    if not api_key:
        yield "⚠️ **Please provide your OpenAI API key.**", "", ""
        return

    client = OpenAI(api_key=api_key)
    tools = [{
        "type": "mcp",
        "server_label": server_label or DEFAULT_MCP_SERVER_LABEL,
        "server_url": server_url or DEFAULT_MCP_SERVER_URL,
        "allowed_tools": [t.strip() for t in (allowed_tools or DEFAULT_ALLOWED_TOOLS).split(",") if t.strip()],
        "require_approval": require_approval or "never",
    }]
    req_kwargs = dict(
        model=model_name,
        input=[
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": question},
        ],
        reasoning={"effort": "medium", "summary": "auto"},
        tools=tools,
    )

    try:
        with client.responses.stream(**req_kwargs) as stream:
            for event in stream:
                etype = getattr(event, "type", "")
                if etype == "response.output_text.delta":
                    final_text += event.delta
                    yield final_text, summary_text, "\n".join(log_lines[-400:])
                elif etype == "response.reasoning_summary_text.delta":
                    summary_text += event.delta
                    yield final_text, summary_text, "\n".join(log_lines[-400:])
                elif etype in ("response.function_call_arguments.delta", "response.tool_call_arguments.delta"):
                    # Show tool-call argument deltas in the log for transparency
                    log_lines.append(str(event.delta))
                elif etype == "response.error":
                    log_lines.append(f"[error] {getattr(event, 'error', '')}")
                    yield final_text, summary_text, "\n".join(log_lines[-400:])

            # Finalize
            _final = stream.get_final_response()
            yield final_text, summary_text, "\n".join(log_lines[-400:])
    except Exception as e:
        yield f"❌ {e}", summary_text, "\n".join(log_lines[-400:])
CUSTOM_CSS = """
:root {
  --bg: #0e1117;
  --panel: #111827;
  --accent: #7c3aed;
  --accent-2: #06b6d4;
  --text: #e5e7eb;
  --muted: #9ca3af;
  --border: #1f2937;
}
.gradio-container {max-width: 1180px !important; margin: 0 auto !important;}
body {background: radial-gradient(1200px 600px at 20% -10%, rgba(124,58,237,.25), transparent 60%),
      radial-gradient(1000px 500px at 120% 10%, rgba(6,182,212,.2), transparent 60%),
      var(--bg) !important;}
.app-header {
  display:flex; gap:16px; align-items:center; padding:20px 18px; margin:8px 0 12px;
  border:1px solid var(--border); border-radius:20px;
  background: linear-gradient(180deg, rgba(255,255,255,.02), rgba(255,255,255,.01));
  box-shadow: 0 10px 30px rgba(0,0,0,.25), inset 0 1px 0 rgba(255,255,255,.05);
}
.app-header .icon {
  width:48px; height:48px; display:grid; place-items:center; border-radius:14px;
  background: linear-gradient(135deg, var(--accent), var(--accent-2));
  color:white; font-size:26px;
}
.app-header h1 {font-size:22px; margin:0; color:var(--text); letter-spacing:.2px;}
.app-header p {margin:2px 0 0; color:var(--muted); font-size:14px;}
.card {
  border:1px solid var(--border); border-radius:18px; padding:14px 16px;
  background: linear-gradient(180deg, rgba(255,255,255,.02), rgba(255,255,255,.01));
  box-shadow: 0 12px 28px rgba(0,0,0,.18), inset 0 1px 0 rgba(255,255,255,.04);
}
.gr-button-primary {border-radius:12px !important; font-weight:600;}
.gradio-container .tabs {border-radius:16px; overflow:hidden; border:1px solid var(--border);}
.markdown-wrap {min-height: 260px;}
.summary-wrap {min-height: 180px;}
.gr-markdown, .gr-prose { color: var(--text) !important; }
.gr-markdown h1, .gr-markdown h2, .gr-markdown h3 {color: #f3f4f6;}
.gr-markdown a {color: var(--accent-2); text-decoration: none;}
.gr-markdown a:hover {text-decoration: underline;}
.gr-markdown table {width: 100%; border-collapse: collapse; margin: 10px 0 16px;}
.gr-markdown th, .gr-markdown td {border: 1px solid var(--border); padding: 8px 10px;}
.gr-markdown th {background: rgba(255,255,255,.03);}
.gr-markdown pre, .gr-markdown code { background: #0b1220; color: #eaeaf0; border-radius: 12px; border: 1px solid #172036; }
.gr-markdown pre {padding: 12px 14px; overflow:auto;}
.gr-markdown blockquote { border-left: 4px solid var(--accent); padding: 6px 12px; margin: 8px 0; color: #d1d5db; background: rgba(124,58,237,.06); border-radius: 8px; }
.log-box { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; white-space: pre-wrap; color: #d1d5db; background:#0b1220; border:1px solid #172036; border-radius:14px; padding:12px; max-height:280px; overflow:auto; }
"""
def build_ui():
    theme = gr.themes.Soft()
    with gr.Blocks(title="ColPali PDF RAG + MCP Agent (Indices-only)", theme=theme, css=CUSTOM_CSS) as demo:
        gr.HTML(
            """
            <div class="app-header">
              <div class="icon">📚</div>
              <div>
                <h1>ColPali PDF Search + Streaming Agent</h1>
                <p>Index PDFs with ColQwen2 (ColPali). The search tool returns page indices only; the agent fetches images and reasons visually.</p>
              </div>
            </div>
            """
        )

        with gr.Tab("1) Index & Preview"):
            with gr.Row():
                with gr.Column(scale=1):
                    pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
                    index_btn = gr.Button("📥 Index Uploaded PDF", variant="secondary")
                    url_box = gr.Textbox(
                        label="Or index from URL",
                        placeholder="https://example.com/file.pdf",
                        value="",
                    )
                    index_url_btn = gr.Button("🌐 Load From URL", variant="secondary")
                    status_box = gr.Textbox(label="Status", interactive=False)
                with gr.Column(scale=2):
                    pdf_view = PDF(label="PDF Preview")

            # Wiring
            def handle_upload(file):
                global current_pdf_path
                if file is None:
                    return "Please upload a PDF.", None
                path = getattr(file, "name", file)
                status = index_from_path(path)
                current_pdf_path = path
                return status, path

            def handle_url(url: str):
                global current_pdf_path
                if not url or not url.lower().endswith(".pdf"):
                    return "Please provide a direct PDF URL ending in .pdf", None
                status, path = index_from_url(url)
                current_pdf_path = path
                return status, path

            index_btn.click(handle_upload, inputs=[pdf_input], outputs=[status_box, pdf_view])
            index_url_btn.click(handle_url, inputs=[url_box], outputs=[status_box, pdf_view])

        with gr.Tab("2) Ask (Direct — returns indices)"):
            with gr.Row():
                with gr.Column(scale=1):
                    query_box = gr.Textbox(placeholder="Enter your question…", label="Query", lines=4)
                    k_slider = gr.Slider(minimum=1, maximum=50, step=1, label="Number of results (k)", value=5)
                    search_button = gr.Button("🔍 Search", variant="primary")
                with gr.Column(scale=2):
                    output_text = gr.Textbox(label="Indices (0-based)", lines=12, placeholder="[0, 1, 2, ...]")

            search_button.click(search, inputs=[query_box, k_slider], outputs=[output_text])

        with gr.Tab("3) Agent (Streaming)"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    with gr.Group():
                        question = gr.Textbox(
                            label="Your question",
                            placeholder="Enter your question…",
                            lines=8,
                            elem_classes=["card"],
                        )
                        run_btn = gr.Button("Run", variant="primary")
                    with gr.Accordion("Connection & Model", open=False, elem_classes=["card"]):
                        with gr.Row():
                            api_key_box = gr.Textbox(
                                label="OpenAI API Key",
                                placeholder="sk-...",
                                type="password",
                                value=api_key_env,
                            )
                            model_box = gr.Dropdown(
                                label="Model",
                                choices=["gpt-5", "gpt-4.1", "gpt-4o"],
                                value="gpt-5",
                            )
                        with gr.Row():
                            server_url_box = gr.Textbox(
                                label="MCP Server URL",
                                value=DEFAULT_MCP_SERVER_URL,
                            )
                            server_label_box = gr.Textbox(
                                label="MCP Server Label",
                                value=DEFAULT_MCP_SERVER_LABEL,
                            )
                        with gr.Row():
                            allowed_tools_box = gr.Textbox(
                                label="Allowed Tools (comma-separated)",
                                value=DEFAULT_ALLOWED_TOOLS,
                            )
                            require_approval_box = gr.Dropdown(
                                label="Require Approval",
                                choices=["never", "auto", "always"],
                                value="never",
                            )
                with gr.Column(scale=3):
                    with gr.Tab("Answer (Markdown)"):
                        final_md = gr.Markdown(value="", elem_classes=["card", "markdown-wrap"])
                    with gr.Tab("Live Summary (Markdown)"):
                        summary_md = gr.Markdown(value="", elem_classes=["card", "summary-wrap"])
                    with gr.Tab("Event Log"):
                        log_md = gr.Markdown(value="", elem_classes=["card", "log-box"])

            run_btn.click(
                stream_agent,
                inputs=[question, api_key_box, model_box, server_url_box, server_label_box, require_approval_box, allowed_tools_box],
                outputs=[final_md, summary_md, log_md],
            )

    return demo


if __name__ == "__main__":
    demo = build_ui()
    # mcp_server=True exposes this app's MCP endpoint at /gradio_api/mcp/
    demo.queue(max_size=5).launch(debug=True, mcp_server=True)