File size: 7,301 Bytes
50931bf
 
 
a31b6f8
5b18406
f4ba38e
7c308ee
 
6e4b218
 
7c308ee
50931bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266b856
50931bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149

import os
import subprocess
import sys

# Runtime dependency bootstrap: the packages below are installed with pip
# every time this script starts, before the third-party imports further down.
#
# A previously used, fully pinned set (kept for reference):
#   gradio==4.10.0
#   torch==2.1.0 torchvision torchaudio
#   git+https://github.com/facebookresearch/detectron2.git@v0.6
#   layoutparser==0.3.4 layoutparser[layoutmodels] layoutparser[ocr]
#   requests==2.31.0
#
# Each tuple is one `pip install` invocation; the groups run in the listed
# order.
# NOTE(review): torch is presumably installed first because the detectron2
# source build needs it - confirm before reordering.
_PIP_INSTALL_GROUPS = (
    ("torch",),
    ("git+https://github.com/facebookresearch/detectron2.git",),
    ("layoutparser", "layoutparser[layoutmodels]", "layoutparser[ocr]"),
    ("Pillow==9.4.0",),
)

for _pip_group in _PIP_INSTALL_GROUPS:
    # `sys.executable -m pip` installs into the interpreter running this
    # script, and the argument list avoids shell quoting/globbing of the
    # `[extras]` specifiers.
    _pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", *_pip_group],
        check=False,  # stay best-effort: a failed install must not abort start-up here
    )
    if _pip_result.returncode != 0:
        print(
            f"warning: 'pip install {' '.join(_pip_group)}' exited with "
            f"status {_pip_result.returncode}",
            file=sys.stderr,
        )

import gradio as gr
import layoutparser as lp
from PIL import Image
from urllib.parse import urlparse
import requests

def get_RGB_image(image_or_path: str | Image.Image) -> Image.Image:
    """Return the given image converted to RGB mode.

    Args:
        image_or_path: An already opened PIL image, or a string. A string
            whose URL scheme is ``http`` or ``https`` is downloaded; any
            other string is opened as a local file path.

    Returns:
        The result of ``convert("RGB")`` on the opened image.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
        Exception: Whatever ``Image.open`` raises for an unreadable or
            non-image source is propagated unchanged.
    """
    if not isinstance(image_or_path, str):
        return image_or_path.convert("RGB")

    if urlparse(image_or_path).scheme in ("http", "https"):  # Online
        # The timeout keeps an unresponsive host from blocking forever; the
        # context manager releases the connection once decoding is done.
        with requests.get(image_or_path, stream=True, timeout=30) as response:
            # Fail on 4xx/5xx instead of handing an error page to the decoder.
            response.raise_for_status()
            # Let urllib3 undo gzip/deflate transfer encoding on the raw stream.
            response.raw.decode_content = True
            return Image.open(response.raw).convert("RGB")

    # Local path: convert() yields a new image, so the file can be closed.
    with Image.open(image_or_path) as local_image:
        return local_image.convert("RGB")

def inference_factory(config_path: str, model_path: str, label_map: dict, color_map: dict, examples=None, launch=True):
    """Create a layout-detection function and, optionally, a Gradio UI around it.

    A ``lp.Detectron2LayoutModel`` is built from ``config_path``,
    ``model_path`` and ``label_map``. The inner inference function keeps the
    most recent image, prediction and rendering in a one-entry cache, so
    changing only the threshold does not re-run the model on the same image.

    Args:
        config_path: Passed to ``lp.Detectron2LayoutModel`` as ``config_path``.
        model_path: Passed to ``lp.Detectron2LayoutModel`` as ``model_path``.
        label_map: Passed to ``lp.Detectron2LayoutModel`` as ``label_map``.
        color_map: Passed to ``lp.draw_box`` as ``color_map``.
        examples: Optional example inputs shown under the image upload via
            ``gr.Examples``; ``None`` or an empty list shows none.
        launch: When false, no UI is built and the bare inference function
            is returned.

    Returns:
        If ``launch`` is false: ``fn(image, threshold=0.8, just_image=True)``,
        which returns the annotated image (or an error/notice string) when
        ``just_image`` is true, and ``[annotated_image, message, threshold]``
        otherwise. If ``launch`` is true: the bound ``launch`` method of the
        constructed ``gr.Blocks`` app; the app is not started here.
    """
    import traceback

    # A fresh list per call; a shared `[]` default would be reused across calls.
    examples = [] if examples is None else examples

    model = lp.Detectron2LayoutModel(
        config_path=config_path,
        model_path=model_path,
        # extra_config = ["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.8],
        label_map=label_map)
    default_threshold = 0.8
    # One-entry cache describing the most recent successful inference.
    cache = {
        'annotated_image': None,
        'message': None,
        'threshold': default_threshold,
        'image': None,
        'predicted': None
    }

    def truncate(f, n):
        # Cut `f` down to `n` decimal places (truncation toward zero, no rounding).
        return int(f * 10 ** n) / 10 ** n

    def fn(image: Image.Image, threshold: float = default_threshold, just_image=True):
        """Detect layout blocks in `image` and draw those scoring >= threshold."""
        nonlocal cache
        try:
            if image is None:
                # Nothing to analyse; answer in the same shape as an error.
                notice = 'No image provided.'
                return notice if just_image else [None, notice, threshold]

            same_image = cache['image'] == image
            if same_image and cache['threshold'] == threshold and cache['annotated_image'] is not None:
                # Cache hit: honour `just_image` exactly like the normal path.
                if just_image:
                    return cache['annotated_image']
                return [cache['annotated_image'], cache['message'], cache['threshold']]

            # Re-use the stored prediction when only the threshold changed.
            layout_predicted = cache['predicted'] if same_image else model.detect(image)

            # Clamp the threshold to the best available score (0 when there
            # are no blocks) and truncate it to one decimal place.
            top_score = max([0, *(block.score for block in layout_predicted)])
            threshold = truncate(min(top_score, threshold), 1)

            # Label every kept block with "<type>/<score>" for the overlay.
            blocks = [
                block.set(id=f'{block.type}/{block.score:.2f}')
                for block in layout_predicted
                if block.score >= threshold
            ]
            annotated_image = lp.draw_box(
                image,
                blocks,
                color_map=color_map,
                show_element_id=True,
                id_font_size=14,
                id_text_background_color='black',
                id_text_color='white')
            if blocks:
                message = (
                    f'{len(blocks)} bounding boxes matched for {threshold} threshold, '
                    f'out of {len(layout_predicted)} total bounding boxes')
            else:
                message = f'No bounding boxes for {threshold} threshold.'
            cache = {
                'annotated_image': annotated_image,
                'message': message,
                'threshold': threshold,
                'image': image,
                'predicted': layout_predicted
            }
            return annotated_image if just_image else [annotated_image, message, threshold]
        except Exception:
            # Report the failure to the caller/UI as text instead of raising.
            error = traceback.format_exc()
            return error if just_image else [None, error, threshold]

    if not launch:
        return fn

    ###########################################################
    ################### Start of Gradio setup #################
    ###########################################################
    title = "Document Similarity Search using Detectron2"
    description = "<h2>Document Similarity Search using Detectron2</h2>"
    article = "<h4>More details, Links about this! - Document Similarity Search using Detectron2</h4>"
    # NOTE(review): `max-height="86vh"` is not valid CSS (the usual form is
    # `max-height: 86vh`) and `image` is not an HTML tag name, so the first
    # rule is probably inert - confirm the intended selector before changing.
    css = '''
    image { max-height="86vh" !important; }
    .center { display: flex; flex: 1 1 auto; align-items: center; align-content: center; justify-content: center; justify-items: center; }
  '''

    def preview(image_url):
        """Load the image behind `image_url` into the "From Image" tab."""
        try:
            # Success: switch to tab 0, show the image, clear the error box.
            return [gr.Tabs(selected=0), get_RGB_image(image_url), None]
        except Exception:
            # Failure: stay on the URL tab and show the traceback there.
            error = traceback.format_exc()
            return [gr.Tabs(selected=1), None, gr.HTML(value=error, visible=True)]

    with gr.Blocks(title=title, css=css) as app:
        with gr.Row():
            gr.HTML(value=description, elem_classes=['center'])
        with gr.Row():
            with gr.Column():
                with gr.Tabs() as tabs:
                    with gr.Tab("From Image", id=0):
                        document_image = gr.Image(type="pil", label="Document Image")
                        submit = gr.Button(value="Submit", variant="primary")
                        if examples:
                            gr.Examples(
                                examples=examples,
                                inputs=document_image,
                                label='Select any of these test examples')
                    with gr.Tab("From URL", id=1):
                        image_url = gr.Textbox(
                            label="Document Image Link",
                            info="Paste a Link to Document Image",
                            placeholder="https://datasets-server.huggingface.co/assets/ds4sd/icdar2023-doclaynet/--/2023.01/validation/6/image/image.jpg")
                        error_message = gr.HTML(label="Error Message", visible=False)
                        preview_btn = gr.Button(value="Preview", variant="primary")
            with gr.Column():
                with gr.Group():
                    annotated_document_image = gr.Image(type="pil", label="Annotated Document Image")
                    message = gr.HTML(label="Message")
                    threshold = gr.Slider(0.0, 1.0, value=0.0, label="Threshold", info="Choose between 0.0 and 1.0")
        with gr.Row():
            gr.HTML(value=article, elem_classes=['center'])
        preview_btn.click(preview, [image_url], [tabs, document_image, error_message])
        # Submit runs with the default threshold and writes the threshold
        # that was actually applied back to the slider.
        submit.click(
            fn=lambda image: fn(image, just_image=False),
            inputs=document_image,
            outputs=[annotated_document_image, message, threshold])
        # `fn` yields three values but only two outputs are bound here, so the
        # trailing threshold is dropped to match the output count.
        threshold.change(
            fn=lambda image, value: fn(image, value, just_image=False)[:2],
            inputs=[document_image, threshold],
            outputs=[annotated_document_image, message])
    return app.launch

# Class index -> label name; handed to the layout model as its label map.
label_map = dict(enumerate((
    'Caption',
    'Footnote',
    'Formula',
    'List-item',
    'Page-footer',
    'Page-header',
    'Picture',
    'Section-header',
    'Table',
    'Text',
    'Title',
)))

# Label name -> hex colour; handed to the box-drawing call as its colour map.
color_map = {
    'Caption': '#acc2d9',
    'Footnote': '#56ae57',
    'Formula': '#b2996e',
    'List-item': '#a8ff04',
    'Page-footer': '#69d84f',
    'Page-header': '#894585',
    'Picture': '#70b23f',
    'Section-header': '#d4ffff',
    'Table': '#65ab7c',
    'Text': '#952e8f',
    'Title': '#fcfc81',
}

# Model configuration and weights, resolved relative to the working directory.
config_path = './config.yaml'
model_path = './model_final.pth'

# Sample documents offered in the UI: ./example.1.jpg ... ./example.3.jpg
examples = [f'./example.{number}.jpg' for number in (1, 2, 3)]

# With `launch` left at its default, the factory hands back the Gradio app's
# `launch` method, so calling `infer` starts the UI (here in debug mode).
infer = inference_factory(
    config_path=config_path,
    model_path=model_path,
    label_map=label_map,
    color_map=color_map,
    examples=examples,
)
infer(debug=True)