# Importing the requirements
import warnings
warnings.filterwarnings("ignore")
import os
import base64
import subprocess
from io import BytesIO
from tqdm import tqdm
from pdf2image import convert_from_path
import torch
from torch.utils.data import DataLoader
from transformers.utils.import_utils import is_flash_attn_2_available
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from openai import OpenAI
import spaces
import gradio as gr
# Optional: install flash attention at startup (kept disabled; the model falls back
# to the default attention implementation when flash-attn is unavailable)
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# Load the visual document retrieval model
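# ColQwen2.5 is a late-interaction retriever from the ColPali family: each page
# image is encoded as many patch-level vectors rather than one pooled embedding.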
model = ColQwen2_5.from_pretrained(
"vidore/colqwen2.5-v0.2",
torch_dtype=torch.bfloat16,
device_map="cuda:0",
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
).eval()
processor = ColQwen2_5_Processor.from_pretrained("vidore/colqwen2.5-v0.2")
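# The processor handles both page-image preprocessing (process_images) and query
# tokenization (process_queries) for the model.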
################################################
# Helper functions
################################################
def encode_image_to_base64(image):
"""Encodes a PIL image to a base64 string."""
buffered = BytesIO()
image.save(buffered, format="JPEG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def convert_files(files):
"""Converts a list of PDF files to a list of images."""
images = []
for f in files:
images.extend(convert_from_path(f, thread_count=4))
    # Reject datasets of 150 pages or more
    if len(images) >= 150:
        raise gr.Error("The number of pages in the dataset must be less than 150.")
return images
################################################
# Model Inference with ColPali and Gemini
################################################
@spaces.GPU
def index_gpu(images, ds):
"""Runs inference on the GPU for the given images with the visual document retrieval model."""
# Specify the device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # model.device is a torch.device, so normalize the string before comparing
    if model.device != torch.device(device):
        model.to(device)
# Create a DataLoader for the images
dataloader = DataLoader(
images,
batch_size=4,
# num_workers=4,
shuffle=False,
collate_fn=lambda x: processor.process_images(x).to(model.device),
)
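    # The collate_fn runs the ColPali image processor on each batch and moves the
    # tensors to the model's device before the forward pass.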
# Store the document embeddings
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
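            # One multi-vector embedding per page (a vector per image patch token);
            # offload to CPU so GPU memory is freed between batches.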
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return f"Uploaded and converted {len(images)} pages", ds, images
def query_gemini(query, images, api_key):
"""Calls Google's Gemini model with the query and image data."""
if api_key:
try:
            # Each item in `images` is a (PIL image, caption) tuple from the gallery
            base64_images = [encode_image_to_base64(image[0]) for image in images]
# Initialize the OpenAI client with the Gemini API key
client = OpenAI(
api_key=api_key.strip(),
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
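            # Gemini exposes an OpenAI-compatible endpoint, so the standard OpenAI
            # client works once base_url points at it.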
PROMPT = """
You are a smart assistant designed to answer questions about a PDF document.
You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers, etc).
If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
Give detailed and extensive answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary.
Answer in the same language as the query.
Query: {query}
PDF pages:
"""
# Get the response from the Gemini API
response = client.chat.completions.create(
model="gemini-2.5-flash-preview-04-17",
reasoning_effort="none",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": PROMPT.format(query=query)}
]
+ [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{im}"},
}
for im in base64_images
],
}
],
max_tokens=500,
)
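            # reasoning_effort="none" asks Gemini 2.5 to skip its "thinking" phase,
            # trading reasoning depth for latency and cost.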
# Return the content of the response
return response.choices[0].message.content
        # Surface API errors to the UI instead of crashing
        except Exception as e:
            return f"API connection error ({e})! Please check your API key and try again."
# If no API key is provided, return a message indicating that the user should enter their key
return "Enter your Gemini API key to get a custom response."
################################################
# Document Indexing and Search
################################################
def index(files, ds):
"""Convert files to images and index them."""
images = convert_files(files)
return index_gpu(images, ds)
@spaces.GPU
def search(query: str, ds, images, k, api_key):
"""Search for the most relevant pages based on the query."""
k = min(k, len(ds))
# Specify the device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Same device normalization as in index_gpu
    if model.device != torch.device(device):
        model.to(device)
# Store the query embeddings
qs = []
with torch.no_grad():
batch_query = processor.process_queries([query]).to(model.device)
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
# Compute scores
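    # processor.score performs late-interaction (MaxSim) scoring: each query token
    # is matched to its most similar page patch, and similarities are summed.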
scores = processor.score(qs, ds, device=device)
top_k_indices = scores[0].topk(k).indices.tolist()
# Get the top k images
results = []
for idx in top_k_indices:
        img = images[idx]
        img_copy = img.copy()
        # Label pages 1-based so the captions (and Gemini's citations) match the PDF
        results.append((img_copy, f"Page {idx + 1}"))
# Generate response from Gemini
ai_response = query_gemini(query, results, api_key)
return results, ai_response
################################################
# Gradio UI
################################################
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
gr.Markdown(
"# Multimodal RAG with ColPali & Gemini 📚"
)
gr.Markdown(
"""Demo to test ColQwen2.5 (ColPali) on PDF documents.
ColPali is a model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
This demo allows you to upload PDF files and search for the most relevant pages based on your query.
Refresh the page if you change documents!
⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
Other models will be released with better robustness towards different languages and document formats!
"""
)
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("## 1️⃣ Upload PDFs")
file = gr.File(
file_types=[".pdf"], file_count="multiple", label="Upload PDFs"
)
gr.Markdown("## 2️⃣ Index the PDFs")
message = gr.Textbox("Files not yet uploaded", label="Status")
convert_button = gr.Button("🔄 Index documents")
embeds = gr.State(value=[])
imgs = gr.State(value=[])
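            # Session state: document embeddings and page images persist across
            # button clicks without being recomputed.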
with gr.Column(scale=3):
gr.Markdown("## 3️⃣ Search")
api_key = gr.Textbox(
placeholder="Enter your Gemini API key here (must be valid)",
label="API key",
)
query = gr.Textbox(placeholder="Enter your query here", label="Query")
k = gr.Slider(
minimum=1,
maximum=10,
step=1,
label="Number of results",
value=3,
info="Number of pages to retrieve",
)
search_button = gr.Button("🔍 Search", variant="primary")
# Define the output components
gr.Markdown("## 4️⃣ Retrieved Image")
output_gallery = gr.Gallery(
label="Retrieved Documents", height=600, show_label=True
)
gr.Markdown("## 5️⃣ Gemini Response")
output_text = gr.Textbox(
label="AI Response",
placeholder="Generated response based on retrieved documents",
show_copy_button=True,
)
# Define the button actions
convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
search_button.click(
search,
inputs=[query, embeds, imgs, k, api_key],
outputs=[output_gallery, output_text],
)
# Launch the gradio app
if __name__ == "__main__":
demo.queue(max_size=10).launch()
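
# Note: the `spaces` decorators target Hugging Face ZeroGPU; running this script
# elsewhere would also require poppler (for pdf2image) and a CUDA-capable GPU.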