File size: 3,250 Bytes
4902aa0
 
3bc9acc
4902aa0
 
0f8e862
da194cb
3bc9acc
4902aa0
 
 
da194cb
4902aa0
 
 
3bc9acc
4902aa0
3bc9acc
4902aa0
 
 
 
 
 
 
 
 
 
da194cb
4902aa0
da194cb
4902aa0
 
96376c2
3bc9acc
da194cb
 
 
3bc9acc
da194cb
 
 
86115e8
da194cb
 
3bc9acc
da194cb
 
 
3bc9acc
da194cb
 
 
 
 
 
 
 
 
 
 
1fb71cb
96376c2
 
 
da194cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import html
import re

import gradio as gr
import torch
from byaldi import RAGMultiModalModel
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor

# Single source of truth for the vision-language checkpoint, so the model and
# its processor (tokenizer + image preprocessor) always come from the same repo.
VLM_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"

# ColPali retriever.
# NOTE(review): `rag` is not referenced anywhere else in this file, yet it is
# loaded at import time (download + memory) — confirm whether it is needed.
rag = RAGMultiModalModel.from_pretrained("vidore/colpali")

# Qwen2-VL model used by extract_text(). device_map="auto" lets Transformers
# place it on whatever hardware is available, so callers must send inputs to
# `vlm.device` rather than assuming a fixed device.
vlm = Qwen2VLForConditionalGeneration.from_pretrained(
    VLM_MODEL_ID,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map="auto",
)

# Processor loaded from the same checkpoint as the model above, so the chat
# template, tokenizer and image preprocessing match what the model expects.
processor = AutoProcessor.from_pretrained(VLM_MODEL_ID, trust_remote_code=True)

def extract_text(image, query):
    """Run the Qwen2-VL model on one image with a text prompt and return its reply.

    Args:
        image: The image to read, as delivered by the ``gr.Image(type="pil")``
            input (presumably a ``PIL.Image.Image`` — confirm for other callers).
        query: The text prompt sent to the model alongside the image.

    Returns:
        The decoded model response (prompt tokens removed, special tokens
        skipped) for the single image/prompt pair.
    """
    # Single-turn chat: one user message containing the image followed by the prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]

    # Render the chat template to a string; it is tokenized together with the
    # vision inputs by the processor call below.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
    # The model is loaded with device_map="auto" and may sit on a GPU; send the
    # inputs to the model's own device instead of forcing CPU, which would raise
    # a device-mismatch error inside generate().
    inputs = inputs.to(vlm.device)

    generated_ids = vlm.generate(**inputs, max_new_tokens=200, temperature=0.7, top_p=0.9)
    # generate() returns prompt + completion per row; slice off the prompt tokens
    # so only the newly generated answer is decoded.
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

def search_text(text, query):
    """Highlight every case-insensitive occurrence of *query* in *text* as HTML.

    The result is rendered by a ``gr.HTML`` component, so every piece of text
    is HTML-escaped; otherwise markup read off an uploaded image (or typed as
    the query) would be injected into the page verbatim. Matching is done on
    the raw text before escaping, so a query never matches inside an entity
    such as ``&lt;``.

    Args:
        text: Previously extracted text held in the extraction state. May be
            ``None`` when "Search!" is clicked before any document was read.
        query: Literal search term; regex metacharacters match literally.

    Returns:
        An HTML string with each match wrapped in a yellow-highlighted
        ``<span>``; the escaped text unchanged when *query* is empty; ``""``
        when there is no text yet.
    """
    if not text:
        return ""
    if not query:
        return html.escape(text)

    # With exactly one capturing group, re.split returns alternating
    # [non-match, match, non-match, ...]: odd indices are the matches.
    pieces = re.split(f"({re.escape(query)})", text, flags=re.IGNORECASE)
    highlighted = []
    for index, piece in enumerate(pieces):
        escaped = html.escape(piece)
        if index % 2:
            highlighted.append(f'<span style="background-color: yellow;">{escaped}</span>')
        else:
            highlighted.append(escaped)
    return "".join(highlighted)

def extraction(image, query):
    """Gradio handler for the "Read Doc!" button.

    Args:
        image: Value of the image input; ``None`` when nothing was uploaded.
        query: Prompt text forwarded to the model.

    Returns:
        The extracted text twice: once for the visible output textbox and once
        for the ``gr.State`` that the search handler reads later.

    Raises:
        gr.Error: If no image has been uploaded, so the UI shows a readable
            message instead of passing ``None`` on to the model pipeline.
    """
    if image is None:
        raise gr.Error("Please upload an image before reading the document.")
    extracted_text = extract_text(image, query)
    return extracted_text, extracted_text


"""
    Main App
"""
with gr.Blocks() as main_app:
    gr.Markdown("# Document Reader using OCR(English/Hindi)")
    
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload an Image")
            query_input = gr.Textbox(label="Enter query for retrieval", placeholder="Query/Prompt")
            gr.Markdown("""
                        ### Please use this prompt for text extraction 
                        **What text can you identify in this image? Include everything, even if it's partially obscured or in the background.**
                        """)
            search_input = gr.Textbox(label="Enter search term", placeholder="Search")
            extract_button = gr.Button("Read Doc!")
            search_button = gr.Button("Search!")
            
        with gr.Column():
            extracted_text_op = gr.Textbox(label="Output")
            search_text_op = gr.HTML(label="Search Results")
        
        extracted_text_state = gr.State()
        extract_button.click(
            extraction,
            inputs=[img_input, query_input],
            outputs=[extracted_text_op, extracted_text_state]
        )
        
        search_button.click(
            search_text,
            inputs=[extracted_text_state, search_input],
            outputs=[search_text_op]
        )
        
main_app.launch()