from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
import torch
from qwen_vl_utils import process_vision_info
from PIL import Image
import re
import gradio as gr

# ColPali retriever (initialized here; not used by the OCR pipeline below)
rag = RAGMultiModalModel.from_pretrained("vidore/colpali")

# Qwen2-VL vision-language model used for text extraction
vlm = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.float32,
    trust_remote_code=True,
    device_map="auto",
)

# Processor must come from the same checkpoint as the model
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)


def extract_text(image, query):
    """Run one chat-style query over the image with the VLM and return the generated text."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
    inputs = inputs.to(vlm.device)  # keep inputs on the same device as the model
    with torch.no_grad():
        # temperature/top_p only take effect when do_sample=True; decoding is greedy by default
        generated_ids = vlm.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9)
    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]


def post_process_text(text):
    # Split the text on sentence boundaries
    lines = text.split('. ')
    processed_lines = []
    for line in lines:
        # Separate Hindi (Devanagari) and English text at a "label:" boundary
        parts = re.split(r'([^\u0900-\u097F\s]+:)', line, 1)
        if len(parts) > 1:
            processed_lines.append(f"{parts[0]}{parts[1]}\n {parts[2]}")
        else:
            processed_lines.append(line)
    # Join the lines with double line breaks
    text = '\n\n'.join(processed_lines)
    # Remove repeated phrases while preserving order
    unique_phrases = list(dict.fromkeys(text.split('\n\n')))
    text = '\n\n'.join(unique_phrases)
    return text


def ocr(image):
    queries = [
        "Extract and transcribe all the text visible in the image, including any small or partially visible text.",
        # "Look closely at the image and list any text you see, no matter how small or unclear.",
        # "What text can you identify in this image? Include everything, even if it's partially obscured or in the background."
    ]
    all_extracted_text = []
    for query in queries:
        extracted_text = extract_text(image, query)
        all_extracted_text.append(extracted_text)
    # Combine and deduplicate the results
    final_text = "\n".join(set(all_extracted_text))
    final_text = post_process_text(final_text)
    return final_text


def main_fun(image, keyword):
    ext_text = ocr(image)
    if keyword:
        # Wrap matches in <mark> so they are highlighted in the HTML output
        highlight_text = re.sub(f'({re.escape(keyword)})', r'<mark>\1</mark>', ext_text, flags=re.IGNORECASE)
    else:
        highlight_text = ext_text
    # Preserve line breaks when rendering as HTML
    highlight_text = highlight_text.replace('\n', '<br>')
    return ext_text, highlight_text


iface = gr.Interface(
    fn=main_fun,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Enter search term", placeholder="Search"),
    ],
    outputs=[
        gr.Textbox(label="Extracted Text"),
        gr.HTML(label="Search Results"),
    ],
    title="Document Search using OCR (English/Hindi)",
)

iface.launch()