Spaces:
Sleeping
Sleeping
from paddleocr import PaddleOCR | |
import json | |
from PIL import Image | |
import gradio as gr | |
import numpy as np | |
import cv2 | |
from gliner import GLiNER | |
# Load the pretrained GLiNER model once, when this module is imported,
# so every request shares the same instance.
gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")

# Entity types requested from GLiNER for the recognised text.
labels = [
    "person name",
    "company name",
    "job title",
    "phone",
    "email",
    "address",
    "website",
]
def get_random_color():
    """Return a random colour as a 3-tuple of Python ints, each in [0, 256)."""
    channels = np.random.randint(0, 256, 3)
    return tuple(channels.tolist())
def draw_ocr_bbox(image, boxes, colors):
    """Outline every non-empty polygon in ``boxes`` on a copy of ``image``.

    Args:
        image: Anything ``np.array`` can convert to an image array (the
            caller in this file passes an RGB ``PIL.Image``). The input
            itself is never drawn on.
        boxes: One polygon per entry, as a sequence of (x, y) points.
            Entries that are ``None`` or empty mean "nothing to draw" and
            are skipped.
        colors: One colour per entry of ``boxes``, paired up by position.

    Returns:
        A NumPy array of the image with a closed, 2-pixel-thick outline for
        each drawable polygon. An array is returned even when nothing is
        drawn, so callers always receive the same type.
    """
    # Convert/copy once up front: the caller's image stays untouched and the
    # pixel buffer is not duplicated again for every box.
    canvas = np.array(image)
    for box, color in zip(boxes, colors):
        if box is None or len(box) == 0:
            continue
        # cv2.polylines expects integer points shaped (num_points, 1, 2).
        polygon = np.array(box).reshape(-1, 1, 2).astype(np.int64)
        canvas = cv2.polylines(canvas, [polygon], True, color, 2)
    return canvas
# PaddleOCR engines keyed by language code. The engine is configured with
# separate detection, classification and recognition model directories, so it
# is built once per language and reused instead of being rebuilt per request.
_OCR_ENGINES = {}


def _get_ocr_engine(lang):
    """Return the PaddleOCR engine for ``lang``, creating it on first use."""
    engine = _OCR_ENGINES.get(lang)
    if engine is None:
        engine = PaddleOCR(use_angle_cls=True, lang=lang, use_gpu=False,
                           det_model_dir=f'./models/det/{lang}',
                           cls_model_dir=f'./models/cls/{lang}',
                           rec_model_dir=f'./models/rec/{lang}')
        _OCR_ENGINES[lang] = engine
    return engine


def _run_ocr(pixels, lang, confidence):
    """Return OCR result items whose score is strictly above ``confidence``."""
    pages = _get_ocr_engine(lang).ocr(pixels, cls=True)
    # Only the first page is used; it can be missing or None when no text
    # was found.
    lines = pages[0] if pages else None
    if not lines:
        return []
    items = []
    for line in lines:
        # Each line is indexed as [box, (text, score)].
        txt, score = line[1][0], line[1][1]
        if score > confidence:
            items.append({'boxes': line[0], 'txt': txt, 'score': score,
                          '_c': get_random_color()})
    return items


def _extract_entities(text):
    """Return one box-less result item per GLiNER entity found in ``text``."""
    if not text.strip():
        # Nothing was recognised, so there is nothing to send to the model.
        return []
    entities = gliner_model.predict_entities(text, labels, threshold=0.3)
    return [
        {'boxes': [], 'txt': f"{ent['text']} ({ent['label']})", 'score': 1.0,
         '_c': get_random_color()}
        for ent in entities
    ]


def _detect_qr_codes(pixels):
    """Return one result item per QR code that decoded to a non-empty string."""
    detector = cv2.QRCodeDetector()
    retval, decoded_info, points, _ = detector.detectAndDecodeMulti(pixels)
    if not retval:
        return []
    return [
        {'boxes': corners.reshape(-1, 2).tolist(), 'txt': url, 'score': 1.0,
         '_c': get_random_color()}
        for url, corners in zip(decoded_info, points)
        if url
    ]


def inference(img: Image.Image, lang, confidence):
    """Run OCR, entity extraction and QR detection on ``img``.

    Args:
        img: The input picture, or ``None`` when no image was supplied.
        lang: PaddleOCR language code; it also selects the model directories
            under ``./models``.
        confidence: OCR lines scoring at or below this value are dropped.

    Returns:
        A pair ``(annotated_image, rows)``. ``annotated_image`` is an RGB
        ``PIL.Image`` with every boxed result outlined, and ``rows`` is a
        list of ``[bbox_json, score, text]`` lists ordered as OCR lines,
        then entities, then QR codes. Returns ``(None, [])`` when ``img`` is
        ``None``.
    """
    if img is None:
        # No upload: hand back an empty result instead of crashing in OCR.
        return None, []

    pixels = np.array(img)

    ocr_items = _run_ocr(pixels, lang, confidence)
    # Entities are extracted from all accepted OCR text joined by spaces, so
    # they cannot be tied back to individual boxes.
    combined_text = " ".join(item['txt'] for item in ocr_items)
    gliner_items = _extract_entities(combined_text)
    qr_items = _detect_qr_codes(pixels)

    final_result = ocr_items + gliner_items + qr_items

    image = img.convert('RGB')
    image_with_boxes = draw_ocr_bbox(image,
                                     [item['boxes'] for item in final_result],
                                     [item['_c'] for item in final_result])
    data = [
        [json.dumps(item['boxes']), round(item['score'], 3), item['txt']]
        for item in final_result
    ]
    # np.asarray makes this independent of whether the drawing helper
    # returned an array or handed the PIL image back unchanged.
    return Image.fromarray(np.asarray(image_with_boxes)), data
# Static text shown on the Gradio page.
title = 'Enhanced Business Card Scanner'
description = 'Combines OCR, entity recognition, and QR scanning'

# One row per example, passed to the interface as its `examples`.
examples = [
    ['example_imgs/example.jpg', 'en', 0.5],
    ['example_imgs/demo003.jpeg', 'en', 0.7],
]

# Custom CSS handed to the interface; it sizes the image panels.
css = (
    ".output_image, .input_image "
    "{height: 40rem !important; width: 100% !important;}"
)
if __name__ == '__main__':
    # Input widgets, listed in the order inference() takes its arguments.
    input_widgets = [
        gr.Image(type='pil', label='Input'),
        gr.Dropdown(choices=['en', 'fr', 'german', 'korean', 'japan'], value='en', label='Language'),
        gr.Slider(0.1, 1, 0.5, step=0.1, label='Confidence Threshold'),
    ]
    # Output widgets, matching the (image, rows) pair inference() returns.
    output_widgets = [
        gr.Image(type='pil', label='Output'),
        gr.Dataframe(headers=['bbox', 'score', 'text'], label='Results'),
    ]

    demo = gr.Interface(
        fn=inference,
        inputs=input_widgets,
        outputs=output_widgets,
        title=title,
        description=description,
        examples=examples,
        css=css,
    )
    demo.launch()