mckabue commited on
Commit
50931bf
Β·
1 Parent(s): d21cd54

2023-11-27-03-48-27

Browse files
Files changed (1) hide show
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+
4
+ os.system("pip install torch")
5
+ os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
6
+ os.system("pip install layoutparser")
7
+ os.system("pip install layoutparser[layoutmodels]")
8
+ os.system("pip install layoutparser[ocr]")
9
+ os.system("pip install Pillow==9.4.0")
10
+ os.system("pip install requests")
11
+
12
+ import gradio as gr
13
+ import layoutparser as lp
14
+ from PIL import Image
15
+ from urllib.parse import urlparse
16
+ import requests
17
+
18
+ def get_RGB_image(image_or_path: str | Image.Image) -> bytes:
19
+ if isinstance(image_or_path, str):
20
+ if urlparse(image_or_path).scheme in ["http", "https"]: # Online
21
+ image_or_path = Image.open(
22
+ requests.get(image_or_path, stream=True).raw)
23
+ else: # Local
24
+ image_or_path = Image.open(image_or_path)
25
+ return image_or_path.convert("RGB")
26
+
27
+ def inference_factory(config_path: str, model_path: str, label_map: dict, color_map: dict, examples=[], launch=True):
28
+ import traceback
29
+ model: lp.elements.layout.Layout = lp.Detectron2LayoutModel(
30
+ config_path=config_path,
31
+ model_path=model_path,
32
+ # extra_config = ["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.8],
33
+ label_map=label_map)
34
+ default_threshold = 0.8
35
+ cache = {
36
+ 'annotated_image': None,
37
+ 'message': None,
38
+ 'threshold': default_threshold,
39
+ 'image': None,
40
+ 'predicted': None
41
+ }
42
+
43
+ def truncate(f, n):
44
+ return int(f * 10 ** n) / 10 ** n
45
+
46
+ def fn(image: Image.Image, threshold: float = default_threshold, just_image=True):
47
+ try:
48
+ nonlocal cache
49
+ if cache['image'] == image and cache['threshold'] == threshold and bool(cache['annotated_image']):
50
+ return [cache['annotated_image'], cache['message'], cache['threshold']]
51
+ layout_predicted = cache['predicted'] if cache['image'] == image else model.detect(
52
+ image)
53
+ threshold = truncate(
54
+ min([max([block.score for block in layout_predicted] + [0])] + [threshold]), 1)
55
+ blocks: List[lp.elements.layout_elements.TextBlock] = [block.set(
56
+ id=f'{block.type}/{block.score:.2f}') for block in layout_predicted if block.score >= threshold]
57
+ annotated_image = lp.draw_box(
58
+ image,
59
+ blocks,
60
+ color_map=color_map,
61
+ show_element_id=True,
62
+ id_font_size=14,
63
+ id_text_background_color='black',
64
+ id_text_color='white')
65
+ message = \
66
+ f'{len(blocks)} bounding boxes matched for {threshold} threshold, out of {len(layout_predicted)} total bounding boxes' if len(blocks) > 0 \
67
+ else f'No bounding boxesfor {threshold} threshold.'
68
+ cache = {
69
+ 'annotated_image': annotated_image,
70
+ 'message': message,
71
+ 'threshold': threshold,
72
+ 'image': image,
73
+ 'predicted': layout_predicted
74
+ }
75
+ return annotated_image if just_image else [annotated_image, message, threshold]
76
+ except Exception as e:
77
+ error = traceback.format_exc()
78
+ return error if just_image else [None, error, threshold]
79
+ if not launch:
80
+ return fn
81
+
82
+ ###########################################################
83
+ ################### Start of Gradio setup #################
84
+ ###########################################################
85
+ title = "Document Similarity Search using Detectron2"
86
+ description = "<h2>Document Similarity Search using Detectron2<h2>"
87
+ article = "<h4>More details, Links about this! - Document Similarity Search using Detectron2<h4>"
88
+ css = '''
89
+ image { max-height="86vh" !important; }
90
+ .center { display: flex; flex: 1 1 auto; align-items: center; align-content: center; justify-content: center; justify-items: center; }
91
+ '''
92
+
93
+ def preview(image_url):
94
+ try:
95
+ return [gr.Tabs(selected=0), get_RGB_image(image_url), None]
96
+ except:
97
+ error = traceback.format_exc()
98
+ return [gr.Tabs(selected=1), None, gr.HTML(value=error, visible=True)]
99
+
100
+ with gr.Blocks(title=title, css=css) as app:
101
+ with gr.Row():
102
+ gr.HTML(value=description, elem_classes=['center'])
103
+ with gr.Row():
104
+ with gr.Column():
105
+ with gr.Tabs() as tabs:
106
+ with gr.Tab("From Image", id=0):
107
+ document_image = gr.Image(type="pil", label="Document Image")
108
+ submit = gr.Button(value="Submit", variant="primary")
109
+ if len(examples) > 0:
110
+ gr.Examples(
111
+ examples=examples,
112
+ inputs=document_image,
113
+ label='Select any of these test examples')
114
+ with gr.Tab("From URL", id=1):
115
+ image_url = gr.Textbox(
116
+ label="Document Image Link",
117
+ info="Paste a Link to Document Image",
118
+ placeholder="https://datasets-server.huggingface.co/assets/ds4sd/icdar2023-doclaynet/--/2023.01/validation/6/image/image.jpg")
119
+ error_message = gr.HTML(label="Error Message", visible=False)
120
+ preview_btn = gr.Button(value="Preview", variant="primary")
121
+ with gr.Column():
122
+ with gr.Group():
123
+ annotated_document_image = gr.Image(type="pil", label="Annotated Document Image")
124
+ message = gr.HTML(label="Message")
125
+ threshold = gr.Slider(0.0, 1.0, value=0.0, label="Threshold", info="Choose between 0.0 and 1.0")
126
+ with gr.Row():
127
+ gr.HTML(value=article, elem_classes=['center'])
128
+ preview_btn.click(preview, [image_url], [tabs, document_image, error_message])
129
+ submit.click(
130
+ fn=lambda image: fn(image, just_image=False),
131
+ inputs=document_image,
132
+ outputs=[annotated_document_image, message, threshold])
133
+ threshold.change(
134
+ fn=lambda image, threshold: fn(image, threshold, just_image=False),
135
+ inputs=[document_image, threshold],
136
+ outputs=[annotated_document_image, message])
137
+ return app.launch
138
+
139
+ label_map = {0: 'Caption', 1: 'Footnote', 2: 'Formula', 3: 'List-item', 4: 'Page-footer', 5: 'Page-header', 6: 'Picture', 7: 'Section-header', 8: 'Table', 9: 'Text', 10: 'Title'}
140
+ color_map = {'Caption': '#acc2d9', 'Footnote': '#56ae57', 'Formula': '#b2996e', 'List-item': '#a8ff04', 'Page-footer': '#69d84f', 'Page-header': '#894585', 'Picture': '#70b23f', 'Section-header': '#d4ffff', 'Table': '#65ab7c', 'Text': '#952e8f', 'Title': '#fcfc81'}
141
+ config_path = './config.yaml'
142
+ model_path = './model_final.pth'
143
+ examples = ['./example.1.jpg', './example.2.jpg', './example.3.jpg']
144
+
145
+ infer = inference_factory(config_path, model_path, label_map, color_map, examples = examples)
146
+ infer(debug=True)