File size: 5,149 Bytes
70b98b2
6bba885
70b98b2
 
 
 
 
6bba885
 
 
 
 
 
 
 
 
 
 
 
24a8aeb
6bba885
 
24a8aeb
6bba885
 
 
 
 
 
 
70b98b2
6bba885
70b98b2
 
 
 
6bba885
70b98b2
6bba885
 
 
70b98b2
 
6bba885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24a8aeb
6bba885
 
 
 
 
70b98b2
6bba885
 
 
70b98b2
6bba885
70b98b2
6bba885
 
70b98b2
 
 
 
631b608
 
 
6bba885
 
 
 
 
 
631b608
 
 
6bba885
 
70b98b2
6bba885
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from paddleocr import PaddleOCR
from gliner import GLiNER
import json
from PIL import Image
import gradio as gr
import numpy as np
import cv2
import logging
import os
from pathlib import Path
import tempfile
import pandas as pd
import io
import re
import traceback

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up GLiNER environment variables (adjust if needed)
# NOTE(review): this runs after `from gliner import GLiNER`; confirm the
# library reads GLINER_HOME lazily rather than at import time.
os.environ['GLINER_HOME'] = './gliner_models'

# Load the GLiNER model once at import time so every request reuses the same
# instance (do not change the model).
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
except Exception:
    # logger.exception records the traceback alongside the message, and the
    # bare `raise` re-raises the original exception so startup still fails.
    logger.exception("Failed to load GLiNER model")
    raise

# Get a random color (used for drawing bounding boxes, if needed)
def get_random_color():
    """Return a random color as a 3-tuple of Python ints.

    Each of the three channels is drawn with ``np.random.randint(0, 256, 3)``,
    so every value lies in the half-open range [0, 256).
    """
    channels = np.random.randint(0, 256, 3)
    return tuple(channels.tolist())

# Draw OCR bounding boxes (this function is kept for debugging/visualization purposes)
def draw_ocr_bbox(image, boxes, colors):
    """Outline each OCR box on the image and return the drawn result.

    Args:
        image: The image to draw on; it is converted with ``np.array`` before
            each polygon is drawn.
        boxes: Sequence of boxes, each reshaped to ``(-1, 1, 2)`` points.
        colors: Sequence of colors indexed in step with ``boxes``.

    Returns:
        The value returned by the last ``cv2.polylines`` call, or ``image``
        itself, unchanged, when ``boxes`` is empty.
    """
    for idx, raw_box in enumerate(boxes):
        polygon = np.array(raw_box).reshape([-1, 1, 2]).astype(np.int64)
        canvas = np.array(image)
        # Closed polyline (isClosed=True) with a line thickness of 2.
        image = cv2.polylines(canvas, [polygon], True, colors[idx], 2)
    return image

# Scan for a QR code using OpenCV's QRCodeDetector
def scan_qr_code(image):
    """Try to decode a QR code from the image with OpenCV's QRCodeDetector.

    Args:
        image: A numpy array, or any object ``np.array`` can convert.

    Returns:
        The decoded text with surrounding whitespace stripped, or ``None``
        when nothing was decoded or an exception occurred. Failures are
        logged and never raised.
    """
    try:
        # The detector is fed a numpy array; convert anything else first.
        if isinstance(image, np.ndarray):
            pixels = image
        else:
            pixels = np.array(image)
        detector = cv2.QRCodeDetector()
        decoded, _points, _straight = detector.detectAndDecode(pixels)
        return decoded.strip() if decoded else None
    except Exception as exc:
        logger.error("QR code scanning failed: " + str(exc))
        return None

# Main inference function
def inference(img: Image.Image, confidence):
    """Run OCR, entity extraction and QR scanning on a business-card image.

    Args:
        img: The uploaded image as a PIL image, or ``None`` when nothing was
            uploaded.
        confidence: Score threshold forwarded to
            ``gliner_model.predict_entities``.

    Returns:
        A 4-tuple ``(ocr_text, entities, csv_path, error)``:

        * ``ocr_text`` -- all recognized text lines joined with spaces.
        * ``entities`` -- dict mapping each title-cased label (plus ``"QR"``
          when a QR code was decoded) to its values joined with ``"; "``.
        * ``csv_path`` -- path of a temporary one-row CSV holding the same
          dict; the file is not deleted by this function.
        * ``error`` -- empty string on success.

        On failure the tuple is ``("", {}, None, message)``; exceptions are
        logged and reported through ``message`` rather than raised.
    """
    if img is None:
        # Nothing was uploaded; report it plainly instead of letting the OCR
        # engine fail with an opaque traceback.
        return "", {}, None, "Error: no image provided"

    try:
        # Initialize PaddleOCR for English only (removed other languages).
        # NOTE(review): the engine is rebuilt on every call; consider caching
        # it once sharing an instance across requests is confirmed safe.
        ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
                        det_model_dir='./models/det/en',
                        cls_model_dir='./models/cls/en',
                        rec_model_dir='./models/rec/en')
        pages = ocr.ocr(np.array(img), cls=True)
        # The entry for the image can be None when no text is detected;
        # normalize that (and an empty result list) to "no lines".
        lines = (pages[0] if pages else None) or []

        # Each line is (box, (text, score)); concatenate the recognized texts.
        # For debugging, draw_ocr_bbox can visualize the boxes in `lines`.
        ocr_text = " ".join(line[1][0] for line in lines)

        # Extract entities using GLiNER with updated labels (adding 'website')
        labels = ["person name", "company name", "job title", "phone", "email", "address", "website"]
        if ocr_text.strip():
            entities = gliner_model.predict_entities(
                ocr_text, labels, threshold=confidence, flat_ner=True)
        else:
            # No text means nothing to label; skip the model call entirely.
            entities = []

        results = {label.title(): [] for label in labels}
        for entity in entities:
            key = entity["label"].title()
            if key in results:
                results[key].append(entity["text"])

        # Scan the original image for a QR code and add it if found
        qr_data = scan_qr_code(img)
        if qr_data:
            results["QR"] = [qr_data]

        # One flattened view shared by the JSON output and the CSV row.
        flattened = {key: "; ".join(values) for key, values in results.items()}

        # Write the CSV straight into the temp file. newline="" lets pandas
        # control line endings; delete=False keeps the file for download.
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w",
                                         newline="", encoding="utf-8") as tmp_file:
            pd.DataFrame([flattened]).to_csv(tmp_file, index=False)
            csv_path = tmp_file.name

        # Return tuple: (OCR text, JSON entities, CSV file path, error message)
        return ocr_text, flattened, csv_path, ""
    except Exception as e:
        # Top-level boundary for the UI: log with traceback and surface the
        # details in the "Error Details" output instead of raising.
        logger.exception("Processing failed")
        return "", {}, None, f"Error: {str(e)}\n{traceback.format_exc()}"

# Gradio Interface setup (output structure remains unchanged)
# Text shown at the top of the Gradio page.
title = "Business Card Information Extractor"
description = (
    "Extracts text using PaddleOCR and entities using GLiNER "
    "(with added website label) along with QR code scanning."
)

# Each example row is [image path, confidence threshold].
examples = [
    ["example_imgs/example.jpg", 0.5],
    ["example_imgs/demo003.jpeg", 0.7],
]

css = (
    ".output_image, .input_image "
    "{height: 40rem !important; width: 100% !important;}"
)

if __name__ == "__main__":
    # Inputs map positionally onto inference(img, confidence).
    input_components = [
        gr.Image(type="pil", label="Upload Business Card"),
        gr.Slider(0.1, 1, 0.5, step=0.1, label="Confidence Threshold"),
    ]
    # Outputs map positionally onto inference's 4-tuple return value.
    output_components = [
        gr.Textbox(label="Extracted Text"),
        gr.JSON(label="Entities"),
        gr.File(label="Download CSV"),
        gr.Textbox(label="Error Details"),
    ]
    demo = gr.Interface(
        fn=inference,
        inputs=input_components,
        outputs=output_components,
        title=title,
        description=description,
        examples=examples,
        css=css,
        cache_examples=True,
    )
    demo.queue(max_size=10)
    demo.launch()