import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch
from qwen_vl_utils import process_vision_info
from PIL import Image
import os
import re

# Load the ColPali retriever used for multimodal document indexing and search
rag = RAGMultiModalModel.from_pretrained("vidore/colpali")

# Load the Qwen2-VL vision-language model used for text extraction
vlm = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map="auto",
)

# The processor must match the model checkpoint loaded above
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)

def extract_text(image, query):
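    """Run a single vision-language query over the image and return the decoded response."""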
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]
    
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
    # Move the inputs to the same device the model was dispatched to
    inputs = inputs.to(vlm.device)
    with torch.no_grad():
        # Sampling must be enabled for temperature/top_p to take effect
        generated_ids = vlm.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9)
        # Strip the prompt tokens so only newly generated text is decoded
        generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    
def post_process_text(text):
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    # Remove repeated phrases (which sometimes occur in multi-pass extraction)
    phrases = text.split('. ')
    unique_phrases = list(dict.fromkeys(phrases))
    text = '. '.join(unique_phrases)
    return text

def ocr(image):
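    """Query the VLM with several prompts and merge the results into a single transcript."""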
    queries = [
        "Extract and transcribe all the text visible in the image, including any small or partially visible text.",
        "Look closely at the image and list any text you see, no matter how small or unclear.",
        "What text can you identify in this image? Include everything, even if it's partially obscured or in the background."
    ]

    all_extracted_text = []
    for query in queries:
        extracted_text = extract_text(image, query)
        all_extracted_text.append(extracted_text)

    # Combine and deduplicate the results
    final_text = "\n".join(set(all_extracted_text))

    final_text = post_process_text(final_text)
    return final_text
    
    
def main_fun(image, keyword):
    ext_text = ocr(image)

    # Fall back to the plain transcript when no search term is given
    highlight_text = ext_text
    if keyword:
        highlight_text = re.sub(f'({re.escape(keyword)})', r'<span style="background-color: yellow;">\1</span>', ext_text, flags=re.IGNORECASE)

    return ext_text, highlight_text

iface = gr.Interface(
    fn=main_fun,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Enter search term")
    ],
    outputs=[
        gr.Textbox(label="Extracted Text"),
        gr.HTML(label="Search Results")
    ],
    title="Document Search using OCR (English/Hindi)"
)

iface.launch()