# Hugging Face Space status at capture time: Sleeping
import html
import os
import re
import tempfile
import traceback

import gradio as gr
import torch
from byaldi import RAGMultiModalModel
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
# Hugging Face checkpoint identifiers, defined once so the Qwen2-VL weights and
# its processor are always loaded from the same checkpoint.
COLPALI_MODEL_ID = "vidore/colpali"
QWEN_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# Load models once at import time; every request reuses these objects.
# Byaldi wrapper around the ColPali retriever, used to index the uploaded image.
rag_model = RAGMultiModalModel.from_pretrained(COLPALI_MODEL_ID)
# Qwen2-VL vision-language model that produces the extracted text.
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    QWEN_MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(QWEN_MODEL_ID, trust_remote_code=True)

# Text produced by the most recent OCR run; read by search_keywords().
# NOTE(review): this is process-wide state, shared by every Gradio session.
extracted_text = ""
# Instruction sent to Qwen2-VL when the caller supplies no query. Without it the
# model would receive a `None` text prompt, and Byaldi would be asked to search
# for `None` (the UI never passes a query).
DEFAULT_OCR_PROMPT = (
    "Extract all of the text in this image exactly as written, "
    "keeping both the Hindi and the English text."
)


def _index_and_search(image_path, query):
    """Index the image at *image_path* with Byaldi and return the top-1 hit for *query*."""
    rag_model.index(
        input_path=image_path,
        index_name="image_index",
        store_collection_with_index=False,
        overwrite=True,  # replace the index built for the previous upload
    )
    # NOTE(review): only one image is ever indexed, so the top hit is always that
    # image; the OCR step does not consume this result.
    return rag_model.search(query, k=1)


def _generate_text(image, prompt, max_new_tokens):
    """Ask Qwen2-VL *prompt* about the PIL *image*; return only the newly generated text."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text_input = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(messages)
    inputs = processor(
        text=[text_input],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Prefer the GPU, but keep working on CPU-only hosts instead of crashing on "cuda".
    device = "cuda" if torch.cuda.is_available() else "cpu"
    qwen_model.to(device)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    generated_ids = qwen_model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion. Drop the prompt tokens so the chat
    # template and the instruction are not reported as "extracted" text, and
    # cannot produce false keyword hits.
    prompt_length = inputs["input_ids"].shape[1]
    output_text = processor.batch_decode(
        generated_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0]


def ocr_and_extract(image, text_query=None, max_new_tokens=512):
    """Extract text from *image* with Byaldi + Qwen2-VL and remember it for keyword search.

    Args:
        image: PIL image from the Gradio input, or None when the input is empty.
        text_query: Instruction for the model. Falls back to DEFAULT_OCR_PROMPT
            when empty or None.
        max_new_tokens: Upper bound on generated tokens. A page of OCR output
            needs far more than a short answer, which the old limit of 50 truncated.

    Returns:
        The generated text. Returns "Please upload an image." when *image* is
        None, or "Error: <message>" if any step fails.

    Side effects:
        Sets the module-level ``extracted_text``. It is reset to "" first, so a
        failed run never leaves a previous image's text behind. Also rebuilds
        the Byaldi index named "image_index".
    """
    global extracted_text
    extracted_text = ""
    if image is None:
        return "Please upload an image."

    prompt = text_query or DEFAULT_OCR_PROMPT
    temp_image_path = None
    try:
        # JPEG cannot store an alpha channel, so RGBA/P uploads (e.g. PNGs) must be
        # converted before saving. The RGB copy is also what Qwen2-VL receives.
        rgb_image = image.convert("RGB")
        # A unique file per call, so concurrent requests don't overwrite each other.
        fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        rgb_image.save(temp_image_path, format="JPEG")

        _index_and_search(temp_image_path, prompt)
        extracted_text = _generate_text(rgb_image, prompt, max_new_tokens)
        return extracted_text
    except Exception as e:  # top-level UI handler: report the error instead of crashing
        traceback.print_exc()
        return f"Error: {e}"
    finally:
        # Always clean up the temp file, even when indexing or generation raised.
        if temp_image_path is not None and os.path.exists(temp_image_path):
            os.remove(temp_image_path)
# Characters that count as part of a word when matching keywords. Python's `\w`
# excludes Devanagari vowel signs and viramas (combining marks), so a plain `\b`
# boundary splits Hindi words at those marks. For example, "काला" ends in a vowel
# sign and would never match as a whole word.
_WORD_CHARS = r"\w\u0900-\u097F"


def _keyword_pattern(keyword):
    """Compile a case-insensitive, whole-word regex for the literal *keyword*."""
    return re.compile(
        rf"(?<![{_WORD_CHARS}]){re.escape(keyword)}(?![{_WORD_CHARS}])",
        re.IGNORECASE,
    )


def _highlight_matches(text, pattern):
    """Return *text* HTML-escaped, with every *pattern* match wrapped in <mark>."""
    pieces = []
    last_end = 0
    for match in pattern.finditer(text):
        # Escape the OCR text: it is rendered as HTML and may contain '<' or '&'.
        pieces.append(html.escape(text[last_end:match.start()]))
        pieces.append(f"<mark>{html.escape(match.group(0))}</mark>")
        last_end = match.end()
    pieces.append(html.escape(text[last_end:]))
    return "".join(pieces)


def search_keywords(keyword):
    """Search the last OCR result for *keyword* as a whole word, case-insensitively.

    Args:
        keyword: Text typed by the user. Surrounding whitespace is ignored;
            None or blank input is rejected.

    Returns:
        An HTML snippet: "Keyword found! ..." with every match wrapped in
        <mark>, or a plain status message when there is nothing to search or
        no match.
    """
    if not extracted_text:
        return "No text extracted yet. Please upload an image."
    keyword = (keyword or "").strip()
    if not keyword:
        # An empty pattern would "match" at every position and mark the whole text.
        return "Please enter a keyword to search."
    # The same pattern drives the check and the highlighting, so they always agree.
    pattern = _keyword_pattern(keyword)
    if pattern.search(extracted_text) is None:
        return "Keyword not found in the extracted text."
    return f"Keyword found! {_highlight_matches(extracted_text, pattern)}"
# Gradio UI.
# Built with gr.Blocks so the expensive OCR pipeline (ColPali indexing plus a
# 7B-parameter generation) runs only when "Extract Text" is clicked. The previous
# live=True Interface re-ran the whole pipeline on every edit of the keyword box.


def combined_interface(image, keyword):
    """Run OCR on *image*, then search the freshly extracted text for *keyword*.

    Returns an ``(ocr_text, search_result_html)`` tuple. It is not wired into
    the UI below; it is kept so existing callers keep working.
    """
    ocr_text = ocr_and_extract(image)
    search_result = search_keywords(keyword)
    return ocr_text, search_result


with gr.Blocks(title="Image OCR & Keyword Search") as combined_iface:
    gr.Markdown(
        "## Image OCR & Keyword Search\n"
        "Upload an image containing Hindi and English text and click "
        "**Extract Text**. Then search the extracted text for specific keywords."
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            # The extract button sits directly under the image, above the text output.
            extract_button = gr.Button("Extract Text")
        with gr.Column():
            # NOTE(review): edits made in this box are not searched;
            # search_keywords() reads the module-level OCR result instead.
            text_output = gr.Textbox(label="Extracted Text", interactive=True)
            keyword_search = gr.Textbox(label="Enter keywords to search")
            search_button = gr.Button("Search Keywords")
            search_output = gr.HTML()

    extract_button.click(ocr_and_extract, inputs=image_input, outputs=text_output)
    # Search on button click, or when Enter is pressed in the keyword box.
    search_button.click(search_keywords, inputs=keyword_search, outputs=search_output)
    keyword_search.submit(search_keywords, inputs=keyword_search, outputs=search_output)


# Launch the app
if __name__ == "__main__":
    combined_iface.launch()