# paddle-ocr-demo / app.py
# Author: codic — "First update" (commit 24a8aeb, verified); file size 4.08 kB.
# (The lines above were Hugging Face page chrome captured by the scrape;
#  kept here as a comment so the module remains valid Python.)
from paddleocr import PaddleOCR
import json
from PIL import Image
import gradio as gr
import numpy as np
import cv2
from gliner import GLiNER
# Load the GLiNER named-entity model once at import time (downloads/reads
# weights, so this is deliberately module-level rather than per request).
gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
# Entity labels to extract from the OCR'd business-card text, including website.
labels = ["person name", "company name", "job title", "phone", "email", "address", "website"]
def get_random_color():
    """Return a random color as a 3-tuple of plain ints in [0, 255]."""
    r, g, b = np.random.randint(0, 256, size=3)
    return (int(r), int(g), int(b))
def draw_ocr_bbox(image, boxes, colors):
    """Draw closed polygons (OCR/QR bounding boxes) onto an image.

    Args:
        image: PIL image or ndarray to draw on (not modified in place).
        boxes: list of polygons, each a sequence of [x, y] points; empty
            entries (e.g. GLiNER items without coordinates) are skipped.
        colors: list of BGR color tuples, parallel to ``boxes``.

    Returns:
        np.ndarray copy of the image with polylines drawn. Always an
        ndarray, even when no box was drawn, so callers can safely pass
        the result to ``Image.fromarray`` (the original returned the
        untouched PIL image in that case, which crashed downstream).
    """
    # Convert to ndarray once; the original re-copied the whole canvas
    # with np.array(image) on every box.
    canvas = np.array(image)
    for box, color in zip(boxes, colors):
        if len(box) == 0:
            # No coordinates for this item (e.g. a text-only entity).
            continue
        pts = np.array(box).reshape(-1, 1, 2).astype(np.int64)
        cv2.polylines(canvas, [pts], True, color, 2)  # draws in place
    return canvas
# Cache of PaddleOCR instances keyed by language code. Constructing
# PaddleOCR loads the det/cls/rec model weights and is expensive; the
# original rebuilt it on every request.
_OCR_CACHE = {}

def _get_ocr(lang):
    """Return a cached PaddleOCR instance configured for *lang*."""
    if lang not in _OCR_CACHE:
        _OCR_CACHE[lang] = PaddleOCR(use_angle_cls=True, lang=lang, use_gpu=False,
                                     det_model_dir=f'./models/det/{lang}',
                                     cls_model_dir=f'./models/cls/{lang}',
                                     rec_model_dir=f'./models/rec/{lang}')
    return _OCR_CACHE[lang]

def _ocr_items_from(img2np, lang, confidence):
    """Run OCR and keep lines whose recognition score exceeds *confidence*."""
    ocr_result = _get_ocr(lang).ocr(img2np, cls=True)[0]
    items = []
    if ocr_result:
        for line in ocr_result:
            box, (txt, score) = line[0], line[1]
            if score > confidence:
                items.append({'boxes': box, 'txt': txt, 'score': score,
                              '_c': get_random_color()})
    return items

def _entity_items_from(ocr_items):
    """Extract labeled entities (GLiNER) from the concatenated OCR text."""
    combined_text = " ".join(item['txt'] for item in ocr_items)
    entities = gliner_model.predict_entities(combined_text, labels, threshold=0.3)
    # Entities carry no pixel coordinates, so 'boxes' stays empty.
    return [{'boxes': [], 'txt': f"{ent['text']} ({ent['label']})",
             'score': 1.0, '_c': get_random_color()}
            for ent in entities]

def _qr_items_from(img2np):
    """Detect and decode QR codes; each hit keeps its corner quad as the box."""
    items = []
    detector = cv2.QRCodeDetector()
    retval, decoded_info, points, _ = detector.detectAndDecodeMulti(img2np)
    if retval:
        for url, quad in zip(decoded_info, points):
            if url:  # detectAndDecodeMulti may yield empty strings
                items.append({'boxes': quad.reshape(-1, 2).tolist(),
                              'txt': url, 'score': 1.0,
                              '_c': get_random_color()})
    return items

def inference(img: Image.Image, lang, confidence):
    """Scan a business-card image: OCR text, extract entities, decode QR codes.

    Args:
        img: input image (PIL).
        lang: PaddleOCR language code (e.g. 'en', 'fr').
        confidence: minimum OCR recognition score to keep a text line.

    Returns:
        (annotated PIL image, rows of [bbox JSON, score, text]) for Gradio.
    """
    img2np = np.array(img)
    ocr_items = _ocr_items_from(img2np, lang, confidence)
    final_result = ocr_items + _entity_items_from(ocr_items) + _qr_items_from(img2np)
    image = img.convert('RGB')
    image_with_boxes = draw_ocr_bbox(image,
                                     [item['boxes'] for item in final_result],
                                     [item['_c'] for item in final_result])
    data = [
        [json.dumps(item['boxes']), round(item['score'], 3), item['txt']]
        for item in final_result
    ]
    # draw_ocr_bbox may return the untouched PIL image when nothing was
    # drawn; np.asarray makes Image.fromarray safe in both cases.
    return Image.fromarray(np.asarray(image_with_boxes)), data
# Gradio UI text and example inputs.
title = 'Enhanced Business Card Scanner'
description = 'Combines OCR, entity recognition, and QR scanning'
# Each example row mirrors the interface inputs:
# [image path, OCR language code, confidence threshold].
examples = [
    ['example_imgs/example.jpg', 'en', 0.5],
    ['example_imgs/demo003.jpeg', 'en', 0.7],
]
# Enlarge the input/output image panes so detected boxes are inspectable.
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
if __name__ == '__main__':
    # Assemble the Gradio interface around `inference` and serve it.
    input_components = [
        gr.Image(type='pil', label='Input'),
        gr.Dropdown(choices=['en', 'fr', 'german', 'korean', 'japan'],
                    value='en', label='Language'),
        gr.Slider(0.1, 1, 0.5, step=0.1, label='Confidence Threshold'),
    ]
    output_components = [
        gr.Image(type='pil', label='Output'),
        gr.Dataframe(headers=['bbox', 'score', 'text'], label='Results'),
    ]
    demo = gr.Interface(inference,
                        input_components,
                        output_components,
                        title=title,
                        description=description,
                        examples=examples,
                        css=css)
    demo.launch()