File size: 4,080 Bytes
70b98b2
 
 
 
 
 
24a8aeb
 
 
 
 
 
 
70b98b2
 
 
 
 
 
24a8aeb
 
 
 
 
 
 
 
 
 
70b98b2
 
 
24a8aeb
13801ad
 
 
 
24a8aeb
 
70b98b2
24a8aeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70b98b2
24a8aeb
 
 
70b98b2
24a8aeb
 
 
 
 
 
70b98b2
24a8aeb
 
70b98b2
 
24a8aeb
 
70b98b2
 
 
 
631b608
 
 
24a8aeb
 
 
 
631b608
24a8aeb
631b608
 
 
24a8aeb
70b98b2
24a8aeb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from paddleocr import PaddleOCR
import json
from PIL import Image
import gradio as gr
import numpy as np
import cv2
from gliner import GLiNER

# Initialize GLiNER model
# Loaded once at import time so every inference call reuses the same weights
# (first run downloads the checkpoint from the Hugging Face hub).
gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")

# Entity labels including website
# Free-form label strings passed to GLiNER.predict_entities — the model is
# zero-shot, so these can be edited without retraining.
labels = ["person name", "company name", "job title", "phone", "email", "address", "website"]

def get_random_color():
    """Return a random RGB color as a tuple of three ints in [0, 255]."""
    r, g, b = np.random.randint(0, 256, 3).tolist()
    return (r, g, b)

def draw_ocr_bbox(image, boxes, colors):
    """Draw polygonal bounding boxes onto an image.

    Args:
        image: PIL image (or ndarray) to annotate.
        boxes: list of polygons (each a list of [x, y] points); entries may
            be empty lists (e.g. GLiNER entities carry no coordinates) and
            are skipped.
        colors: color tuples parallel to ``boxes``.

    Returns:
        np.ndarray: the annotated image. Always an ndarray, even when no
        valid box was drawn — the original returned the PIL input unchanged
        in that case, which crashed ``Image.fromarray`` in the caller.
    """
    # Convert exactly once; the original re-ran np.array(image) on every
    # loop iteration, copying the full image per box.
    canvas = np.array(image)
    for box, color in zip(boxes, colors):
        if len(box) == 0:  # box-less entries (entities) — nothing to draw
            continue
        pts = np.array(box).reshape(-1, 1, 2).astype(np.int64)
        canvas = cv2.polylines(canvas, [pts], True, color, 2)
    return canvas

def _run_ocr(img2np, lang, confidence):
    """Run PaddleOCR on an RGB ndarray; keep lines scoring above `confidence`.

    Each returned dict carries 'boxes' (polygon), 'txt', 'score' and a random
    draw color under '_c'.
    NOTE(review): a fresh PaddleOCR instance is built per call, matching the
    original behavior — caching one instance per language would be much
    faster; confirm memory budget before changing.
    """
    ocr = PaddleOCR(use_angle_cls=True, lang=lang, use_gpu=False,
                    det_model_dir=f'./models/det/{lang}',
                    cls_model_dir=f'./models/cls/{lang}',
                    rec_model_dir=f'./models/rec/{lang}')
    ocr_result = ocr.ocr(img2np, cls=True)[0]
    if not ocr_result:  # PaddleOCR yields None/empty when nothing is detected
        return []
    return [
        {'boxes': line[0], 'txt': line[1][0], 'score': line[1][1],
         '_c': get_random_color()}
        for line in ocr_result
        if line[1][1] > confidence
    ]


def _extract_entities(ocr_items):
    """Run GLiNER over the concatenated OCR text; return box-less result dicts."""
    combined_text = " ".join(item['txt'] for item in ocr_items)
    entities = gliner_model.predict_entities(combined_text, labels, threshold=0.3)
    return [
        {'boxes': [], 'txt': f"{ent['text']} ({ent['label']})", 'score': 1.0,
         '_c': get_random_color()}
        for ent in entities
    ]


def _detect_qr_codes(img2np):
    """Detect and decode QR codes; return result dicts with their polygons."""
    qr_items = []
    detector = cv2.QRCodeDetector()
    retval, decoded_info, points, _ = detector.detectAndDecodeMulti(img2np)
    if retval:
        for i, url in enumerate(decoded_info):
            if url:  # detectAndDecodeMulti emits '' for undecodable codes
                qr_items.append({
                    'boxes': points[i].reshape(-1, 2).tolist(),
                    'txt': url,
                    'score': 1.0,
                    '_c': get_random_color(),
                })
    return qr_items


def inference(img: Image.Image, lang, confidence):
    """Scan a business-card image: OCR text, named entities, and QR codes.

    Args:
        img: input PIL image.
        lang: PaddleOCR language code (e.g. 'en', 'fr').
        confidence: minimum OCR score; lower-scored lines are dropped.

    Returns:
        Tuple of (annotated PIL image, table rows of [bbox-json, score, text]).
    """
    img2np = np.array(img)

    ocr_items = _run_ocr(img2np, lang, confidence)
    gliner_items = _extract_entities(ocr_items)
    qr_items = _detect_qr_codes(img2np)
    final_result = ocr_items + gliner_items + qr_items

    image = img.convert('RGB')
    image_with_boxes = draw_ocr_bbox(image,
                                     [item['boxes'] for item in final_result],
                                     [item['_c'] for item in final_result])

    data = [
        [json.dumps(item['boxes']), round(item['score'], 3), item['txt']]
        for item in final_result
    ]

    # np.asarray guards the empty-result case, where draw_ocr_bbox may hand
    # back the PIL image unchanged; Image.fromarray requires an ndarray.
    return Image.fromarray(np.asarray(image_with_boxes)), data

# UI copy shown at the top of the Gradio page.
title = 'Enhanced Business Card Scanner'
description = 'Combines OCR, entity recognition, and QR scanning'

# Demo examples: [image path, OCR language code, confidence threshold].
# Paths are relative to the working directory — TODO confirm they exist.
examples = [
    ['example_imgs/example.jpg', 'en', 0.5],
    ['example_imgs/demo003.jpeg', 'en', 0.7],
]

# Force tall image panes so card text stays legible in the browser.
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"

if __name__ == '__main__':
    # Build the Gradio UI: image + language + confidence in,
    # annotated image + results table out.
    input_widgets = [
        gr.Image(type='pil', label='Input'),
        gr.Dropdown(choices=['en', 'fr', 'german', 'korean', 'japan'], value='en', label='Language'),
        gr.Slider(0.1, 1, 0.5, step=0.1, label='Confidence Threshold'),
    ]
    output_widgets = [
        gr.Image(type='pil', label='Output'),
        gr.Dataframe(headers=['bbox', 'score', 'text'], label='Results'),
    ]
    demo = gr.Interface(
        inference,
        input_widgets,
        output_widgets,
        title=title,
        description=description,
        examples=examples,
        css=css,
    )
    demo.launch()