Spaces:
Runtime error
Runtime error
File size: 7,301 Bytes
50931bf a31b6f8 5b18406 f4ba38e 7c308ee 6e4b218 7c308ee 50931bf 266b856 50931bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
# --- Runtime dependency installation --------------------------------------
# Hugging Face Spaces entry-point pattern: dependencies are installed with
# pip at startup instead of a requirements.txt.
import os
# Previously pinned versions, kept for reference:
# os.system("pip install -q gradio==4.10.0")
# os.system("pip install torch==2.1.0 torchvision torchaudio")
# os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'@v0.6'")
# os.system("pip install layoutparser==0.3.4 layoutparser[layoutmodels] layoutparser[ocr]")
# os.system("pip install requests==2.31.0")
# NOTE(review): these installs are unpinned, so builds are not reproducible;
# torch must be installed before detectron2 (detectron2's setup imports torch).
os.system("pip install torch")
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
os.system("pip install layoutparser layoutparser[layoutmodels] layoutparser[ocr]")
# Pillow pinned to 9.4.0 — presumably a newer Pillow broke this Space; TODO confirm.
os.system("pip install Pillow==9.4.0")
# Imports must come after the installs above.
import gradio as gr
import layoutparser as lp
from PIL import Image
from urllib.parse import urlparse
import requests
def get_RGB_image(image_or_path: str | Image.Image) -> Image.Image:
    """Load an image from a URL, a local path, or an existing PIL image.

    Args:
        image_or_path: An http(s) URL, a local file path, or a PIL Image.

    Returns:
        A PIL ``Image.Image`` converted to RGB mode.
        (BUG FIX: the original annotation claimed ``bytes``.)
    """
    if isinstance(image_or_path, str):
        if urlparse(image_or_path).scheme in ["http", "https"]:  # Online
            # stream=True lets PIL read straight from the response body;
            # timeout prevents the request from hanging indefinitely.
            image_or_path = Image.open(
                requests.get(image_or_path, stream=True, timeout=30).raw)
        else:  # Local
            image_or_path = Image.open(image_or_path)
    return image_or_path.convert("RGB")
def inference_factory(config_path: str, model_path: str, label_map: dict, color_map: dict, examples=None, launch=True):
    """Build a layout-detection inference function and, optionally, a Gradio UI.

    Args:
        config_path: Path to the Detectron2 config YAML.
        model_path: Path to the trained model weights (.pth).
        label_map: Mapping of class index -> label name.
        color_map: Mapping of label name -> box color (hex string).
        examples: Optional list of example image paths shown in the UI.
        launch: When False, return the raw inference function instead of a UI.

    Returns:
        The inference function ``fn`` when ``launch`` is False; otherwise the
        ``launch`` bound method of the assembled Gradio Blocks app.
    """
    import traceback
    # BUG FIX: avoid a mutable default argument; callers passing a list are
    # unaffected.
    examples = [] if examples is None else examples
    model = lp.Detectron2LayoutModel(
        config_path=config_path,
        model_path=model_path,
        # extra_config = ["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.8],
        label_map=label_map)
    default_threshold = 0.8
    # Cache of the last run so that moving only the threshold slider does not
    # re-run the (expensive) detector on the same image.
    cache = {
        'annotated_image': None,
        'message': None,
        'threshold': default_threshold,
        'image': None,
        'predicted': None
    }

    def truncate(f, n):
        # Truncate f to n decimal places without rounding.
        return int(f * 10 ** n) / 10 ** n

    def fn(image: Image.Image, threshold: float = default_threshold, just_image=True):
        """Detect layout blocks in ``image`` and draw boxes above ``threshold``.

        Returns the annotated image alone when ``just_image`` is True,
        otherwise ``[annotated_image, message, threshold]``. On failure the
        traceback text takes the annotated image's place.
        """
        try:
            nonlocal cache
            # Same image and threshold as last time: serve the cached result.
            if cache['image'] == image and cache['threshold'] == threshold and bool(cache['annotated_image']):
                return [cache['annotated_image'], cache['message'], cache['threshold']]
            # Re-detect only when the image changed; otherwise reuse the
            # cached prediction and just re-filter by threshold.
            layout_predicted = cache['predicted'] if cache['image'] == image else model.detect(
                image)
            # Clamp the threshold to the best detected score (truncated to one
            # decimal) so the top block can always be shown.
            threshold = truncate(
                min([max([block.score for block in layout_predicted] + [0])] + [threshold]), 1)
            # BUG FIX: the original annotation used undefined `List`.
            blocks: list = [block.set(
                id=f'{block.type}/{block.score:.2f}') for block in layout_predicted if block.score >= threshold]
            annotated_image = lp.draw_box(
                image,
                blocks,
                color_map=color_map,
                show_element_id=True,
                id_font_size=14,
                id_text_background_color='black',
                id_text_color='white')
            # BUG FIX: "boxesfor" -> "boxes for" in the empty-result message.
            message = \
                f'{len(blocks)} bounding boxes matched for {threshold} threshold, out of {len(layout_predicted)} total bounding boxes' if len(blocks) > 0 \
                else f'No bounding boxes for {threshold} threshold.'
            cache = {
                'annotated_image': annotated_image,
                'message': message,
                'threshold': threshold,
                'image': image,
                'predicted': layout_predicted
            }
            return annotated_image if just_image else [annotated_image, message, threshold]
        except Exception:
            error = traceback.format_exc()
            return error if just_image else [None, error, threshold]

    if not launch:
        return fn
    ###########################################################
    ################### Start of Gradio setup #################
    ###########################################################
    title = "Document Similarity Search using Detectron2"
    # BUG FIX: headings were never closed (originally "<h2>...<h2>").
    description = "<h2>Document Similarity Search using Detectron2</h2>"
    article = "<h4>More details, Links about this! - Document Similarity Search using Detectron2</h4>"
    # BUG FIX: valid CSS — the original used attribute syntax
    # (max-height="86vh") inside a rule, and the element selector is "img".
    css = '''
    img { max-height: 86vh !important; }
    .center { display: flex; flex: 1 1 auto; align-items: center; align-content: center; justify-content: center; justify-items: center; }
    '''

    def preview(image_url):
        # Load the URL into the "From Image" tab; on failure stay on the URL
        # tab and surface the traceback.
        try:
            return [gr.Tabs(selected=0), get_RGB_image(image_url), None]
        except Exception:  # BUG FIX: narrowed from a bare except
            error = traceback.format_exc()
            return [gr.Tabs(selected=1), None, gr.HTML(value=error, visible=True)]

    with gr.Blocks(title=title, css=css) as app:
        with gr.Row():
            gr.HTML(value=description, elem_classes=['center'])
        with gr.Row():
            with gr.Column():
                with gr.Tabs() as tabs:
                    with gr.Tab("From Image", id=0):
                        document_image = gr.Image(type="pil", label="Document Image")
                        submit = gr.Button(value="Submit", variant="primary")
                        if len(examples) > 0:
                            gr.Examples(
                                examples=examples,
                                inputs=document_image,
                                label='Select any of these test examples')
                    with gr.Tab("From URL", id=1):
                        image_url = gr.Textbox(
                            label="Document Image Link",
                            info="Paste a Link to Document Image",
                            placeholder="https://datasets-server.huggingface.co/assets/ds4sd/icdar2023-doclaynet/--/2023.01/validation/6/image/image.jpg")
                        error_message = gr.HTML(label="Error Message", visible=False)
                        preview_btn = gr.Button(value="Preview", variant="primary")
            with gr.Column():
                with gr.Group():
                    annotated_document_image = gr.Image(type="pil", label="Annotated Document Image")
                    message = gr.HTML(label="Message")
                    threshold = gr.Slider(0.0, 1.0, value=0.0, label="Threshold", info="Choose between 0.0 and 1.0")
        with gr.Row():
            gr.HTML(value=article, elem_classes=['center'])
        preview_btn.click(preview, [image_url], [tabs, document_image, error_message])
        submit.click(
            fn=lambda image: fn(image, just_image=False),
            inputs=document_image,
            outputs=[annotated_document_image, message, threshold])
        threshold.change(
            # BUG FIX: fn returns three values but this handler has only two
            # outputs — slice off the threshold to match the output count.
            fn=lambda image, threshold: fn(image, threshold, just_image=False)[:2],
            inputs=[document_image, threshold],
            outputs=[annotated_document_image, message])
    return app.launch
# --- Application configuration and launch ----------------------------------
# DocLayNet class names, indexed by model class id; one display color each.
_CLASS_NAMES = ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer',
                'Page-header', 'Picture', 'Section-header', 'Table', 'Text',
                'Title']
_CLASS_COLORS = ['#acc2d9', '#56ae57', '#b2996e', '#a8ff04', '#69d84f',
                 '#894585', '#70b23f', '#d4ffff', '#65ab7c', '#952e8f',
                 '#fcfc81']
label_map = dict(enumerate(_CLASS_NAMES))
color_map = dict(zip(_CLASS_NAMES, _CLASS_COLORS))
config_path = './config.yaml'
model_path = './model_final.pth'
examples = ['./example.1.jpg', './example.2.jpg', './example.3.jpg']
# Build the Gradio app and start it (infer is the app's launch method).
infer = inference_factory(config_path, model_path, label_map, color_map, examples=examples)
infer(debug=True)
|