File size: 9,237 Bytes
89de2d0
 
 
 
c5ff753
 
7a84512
c5ff753
 
 
 
 
7058c3b
 
c5ff753
 
 
 
 
7a84512
5d77312
c5ff753
d5b11f8
7058c3b
 
c5ff753
 
7058c3b
c5ff753
7058c3b
c5ff753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5b11f8
c5ff753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b3b7bc
 
 
 
c5ff753
 
 
 
 
 
 
 
 
 
 
 
 
 
62fbae0
d7f0680
c5ff753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55b884c
 
 
c5ff753
 
 
 
 
 
 
 
 
 
89de2d0
c5ff753
aa58f2a
c5ff753
 
aa58f2a
 
c5ff753
aa58f2a
 
 
c5ff753
 
 
 
 
 
 
 
 
326169e
c5ff753
1daad19
c5ff753
 
 
 
326169e
 
 
 
 
c5ff753
 
 
 
 
 
 
 
 
cb6c88c
c5ff753
 
326169e
c5ff753
 
 
326169e
 
c5ff753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
# Importing the requirements
import warnings
warnings.filterwarnings("ignore")

import os
import base64
import subprocess
from io import BytesIO
from tqdm import tqdm
from pdf2image import convert_from_path
import torch
from torch.utils.data import DataLoader
from transformers.utils.import_utils import is_flash_attn_2_available
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from openai import OpenAI
import spaces
import gradio as gr


# Optional speed-up: installing flash-attn (for example
# `pip install flash-attn --no-build-isolation` with
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE) lets the loader below pick
# flash_attention_2 automatically; otherwise the default attention is used.
_MODEL_ID = "vidore/colqwen2.5-v0.2"
_ATTN_IMPLEMENTATION = "flash_attention_2" if is_flash_attn_2_available() else None

# Visual document retrieval model (bfloat16 weights, placed on cuda:0, switched
# to eval mode) and the processor that prepares its image/query inputs.
model = ColQwen2_5.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    attn_implementation=_ATTN_IMPLEMENTATION,
).eval()
processor = ColQwen2_5_Processor.from_pretrained(_MODEL_ID)


################################################
# Helper functions
################################################
def encode_image_to_base64(image):
    """Encode a PIL image as a base64 JPEG string.

    JPEG cannot store an alpha channel or a palette, so Pillow refuses to save
    e.g. RGBA or P images as JPEG. In that case the image is converted to RGB
    and saved again; images that already save fine are encoded unchanged.

    Args:
        image: A PIL image.

    Returns:
        The JPEG bytes, base64-encoded and decoded to an ASCII ``str``.
    """
    buffered = BytesIO()
    try:
        image.save(buffered, format="JPEG")
    except OSError:
        # Pillow raises OSError("cannot write mode ... as JPEG") for modes the
        # JPEG encoder does not support; retry on a fresh buffer as RGB.
        buffered = BytesIO()
        image.convert("RGB").save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def convert_files(files, max_pages=150):
    """Render every page of the given PDF files to images.

    Args:
        files: PDF file paths, passed one by one to
            ``pdf2image.convert_from_path``.
        max_pages: Exclusive upper bound on the total number of pages
            accepted across all files.

    Returns:
        A list of page images, ordered by file and then by page.

    Raises:
        gr.Error: If no files were provided, or if the total number of pages
            reaches ``max_pages``.
    """
    # The upload widget yields nothing when Index is clicked before uploading;
    # fail with a readable message instead of a TypeError / an empty index.
    if not files:
        raise gr.Error("Please upload at least one PDF file before indexing.")

    images = []
    for f in files:
        images.extend(convert_from_path(f, thread_count=4))
        # Check after each file so we stop before rendering the remaining PDFs.
        if len(images) >= max_pages:
            raise gr.Error(
                f"The number of images in the dataset should be less than {max_pages}."
            )
    return images


################################################
# Model Inference with ColPali and Gemini
################################################
@spaces.GPU
def index_gpu(images, ds):
    """Embed page images with the visual document retrieval model.

    Runs under ``@spaces.GPU`` and uses ``cuda:0`` when CUDA is available,
    otherwise the CPU.

    Args:
        images: Page images to index, e.g. the output of ``convert_files``.
        ds: Previous embedding state. Accepted for interface compatibility but
            neither read nor mutated. The caller stores the returned ``images``
            as the new image state, so the returned embeddings must match
            exactly those images one-to-one. Extending ``ds`` left stale
            embeddings from an earlier indexing run misaligned with the images.

    Returns:
        ``(status_message, embeddings, images)`` where ``embeddings[i]`` is the
        CPU embedding tensor for ``images[i]``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # model.device is a torch.device; compare it against a torch.device rather
    # than a plain string.
    if model.device != torch.device(device):
        model.to(device)

    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        # Preprocess only; tensors are moved to the device once, in the loop.
        collate_fn=processor.process_images,
    )

    embeddings = []
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {key: value.to(device) for key, value in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        # Split the batch into one CPU tensor per page.
        embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", embeddings, images


def query_gemini(
    query,
    images,
    api_key,
    model_name="gemini-2.5-flash-preview-04-17",
    max_tokens=500,
):
    """Ask Gemini to answer ``query`` from the retrieved page images.

    Gemini is reached through its OpenAI-compatible chat completions endpoint.

    Args:
        query: The user's question.
        images: Retrieved pages as ``(image, label)`` pairs (as built by
            ``search``). Each label is sent as text right before its image so
            the model can cite pages, as the prompt asks it to.
        api_key: Gemini API key; surrounding whitespace is ignored.
        model_name: Model identifier sent to the endpoint.
        max_tokens: Upper bound on the length of the generated answer.

    Returns:
        The model's answer, or a human-readable message when no key was given
        or the request failed. Errors are logged, not raised.
    """
    import logging

    # A missing or whitespace-only key can never authenticate: treat both as
    # "no key" instead of reporting a misleading connection error.
    api_key = (api_key or "").strip()
    if not api_key:
        return "Enter your Gemini API key to get a custom response."

    # NOTE(review): the prompt asks both for a "short response" and for
    # "detailed and extensive answers"; consider settling on one.
    PROMPT = """
            You are a smart assistant designed to answer questions about a PDF document.
            You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers, etc).
            If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
            Give detailed and extensive answers, only containing info in the pages you are given.
            You can answer using information contained in plots and figures if necessary.
            Answer in the same language as the query.
            
            Query: {query}
            PDF pages:
            """

    try:
        content = [{"type": "text", "text": PROMPT.format(query=query)}]
        for image, label in images:
            # Label each page so citations in the answer can refer to it.
            content.append({"type": "text", "text": str(label)})
            content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{encode_image_to_base64(image)}"
                    },
                }
            )

        client = OpenAI(
            api_key=api_key,
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        )
        response = client.chat.completions.create(
            model=model_name,
            reasoning_effort="none",
            messages=[{"role": "user", "content": content}],
            max_tokens=max_tokens,
        )
        return response.choices[0].message.content

    # Best effort: keep the UI responsive, but record why the call failed
    # instead of silently discarding the exception.
    except Exception:
        logging.getLogger(__name__).exception("Gemini request failed")
        return "API connection error! Please check your API key and try again."


################################################
# Document Indexing and Search
################################################
def index(files, ds):
    """Render the uploaded PDFs to page images, then embed those pages.

    PDF rendering happens here, outside the ``@spaces.GPU``-decorated
    ``index_gpu``, so only the embedding step runs under that decorator.
    The result of ``index_gpu`` is returned unchanged.
    """
    return index_gpu(convert_files(files), ds)


@spaces.GPU
def search(query: str, ds, images, k, api_key):
    """Retrieve the pages most relevant to ``query`` and ask Gemini about them.

    Args:
        query: Free-text user question.
        ds: Per-page embeddings from indexing; ``ds[i]`` belongs to ``images[i]``.
        images: The indexed page images.
        k: Number of pages to retrieve, clamped to the number of indexed pages.
        api_key: Gemini API key forwarded to ``query_gemini``.

    Returns:
        ``(results, ai_response)``. ``results`` is a list of
        ``(image_copy, "Page N")`` pairs for the gallery, where N is the
        1-based position of the page among all indexed pages.

    Raises:
        gr.Error: If nothing has been indexed yet.
    """
    # Scoring against an empty index would fail deep inside the model code.
    if not ds:
        raise gr.Error("Please upload and index documents before searching.")
    # topk() requires an int; coerce the slider value defensively.
    k = min(int(k), len(ds))

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # model.device is a torch.device; compare it against a torch.device.
    if model.device != torch.device(device):
        model.to(device)

    # Embed the single query and keep its embedding on the CPU, like ds.
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(device)
        embeddings_query = model(**batch_query)
        qs = list(torch.unbind(embeddings_query.to("cpu")))

    # scores[0] holds the query's score against each indexed page.
    scores = processor.score(qs, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()

    # Copies go to the gallery; captions are human-friendly 1-based positions.
    results = [(images[idx].copy(), f"Page {idx + 1}") for idx in top_k_indices]

    ai_response = query_gemini(query, results, api_key)
    return results, ai_response


################################################
# Gradio UI
################################################
# Gradio interface. Components are laid out in creation order: a title and
# description, then a row with the upload/index column (left) and the search
# column (right), followed by the retrieved-pages gallery and the Gemini answer.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        "# Multimodal RAG with ColPali & Gemini 📚"
    )
    # NOTE(review): the format/language warning below may be inherited from an
    # earlier ColPali demo; confirm it still applies to colqwen2.5-v0.2.
    gr.Markdown(
        """Demo to test ColQwen2.5 (ColPali) on PDF documents. 
    ColPali is a model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
    This demo allows you to upload PDF files and search for the most relevant pages based on your query.
    Refresh the page if you change documents!
    ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
    Other models will be released with better robustness towards different languages and document formats!
    """
    )
    with gr.Row():
        # Left column: choose PDFs and index them.
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(
                file_types=[".pdf"], file_count="multiple", label="Upload PDFs"
            )

            gr.Markdown("## 2️⃣ Index the PDFs")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            convert_button = gr.Button("🔄 Index documents")
            # Per-session state written by `index` (outputs below) and read by
            # `search`: page embeddings and the matching page images.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])

        # Right column: API key, query text and number of pages to retrieve.
        with gr.Column(scale=3):
            gr.Markdown("## 3️⃣ Search")
            api_key = gr.Textbox(
                placeholder="Enter your Gemini API key here (must be valid)",
                label="API key",
            )
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                label="Number of results",
                value=3,
                info="Number of pages to retrieve",
            )
            search_button = gr.Button("🔍 Search", variant="primary")

    # Define the output components
    # Gallery receives the (image, caption) pairs returned by `search`.
    gr.Markdown("## 4️⃣ Retrieved Image")
    output_gallery = gr.Gallery(
        label="Retrieved Documents", height=600, show_label=True
    )

    # Textbox receives the Gemini answer (or status message) from `search`.
    gr.Markdown("## 5️⃣ Gemini Response")
    output_text = gr.Textbox(
        label="AI Response",
        placeholder="Generated response based on retrieved documents",
        show_copy_button=True,
    )

    # Define the button actions
    # Indexing updates the status text and both session states.
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    # Searching reads the session states and fills the gallery and answer box.
    search_button.click(
        search,
        inputs=[query, embeds, imgs, k, api_key],
        outputs=[output_gallery, output_text],
    )


# Script entry point: enable request queuing (capped via max_size=10) on the
# Blocks app, then start the Gradio server.
if __name__ == "__main__":
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch()