Spaces:
Sleeping
Sleeping
File size: 5,149 Bytes
70b98b2 6bba885 70b98b2 6bba885 24a8aeb 6bba885 24a8aeb 6bba885 70b98b2 6bba885 70b98b2 6bba885 70b98b2 6bba885 70b98b2 6bba885 24a8aeb 6bba885 70b98b2 6bba885 70b98b2 6bba885 70b98b2 6bba885 70b98b2 631b608 6bba885 631b608 6bba885 70b98b2 6bba885 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from paddleocr import PaddleOCR
from gliner import GLiNER
import json
from PIL import Image
import gradio as gr
import numpy as np
import cv2
import logging
import os
from pathlib import Path
import tempfile
import pandas as pd
import io
import re
import traceback
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the GLINER_HOME environment variable (adjust if needed)
os.environ['GLINER_HOME'] = './gliner_models'

# Load the GLiNER model once at import time (do not change the model).
# A load failure is logged together with its traceback and then re-raised
# unchanged, so the original exception and stack reach the caller intact.
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
except Exception:
    logger.exception("Failed to load GLiNER model")
    raise
# Get a random color (used for drawing bounding boxes, if needed)
def get_random_color():
    """Return a random colour as a tuple of three ints, each in [0, 255]."""
    channels = np.random.randint(0, 256, 3)
    return tuple(channels.tolist())
# Draw OCR bounding boxes (this function is kept for debugging/visualization purposes)
def draw_ocr_bbox(image, boxes, colors):
    """Outline each OCR box on a copy of ``image`` and return the copy.

    Args:
        image: Image to annotate; anything ``np.array`` accepts (for example
            a PIL image or an ndarray). The input itself is not modified.
        boxes: Sequence of polygons; each one is reshaped to ``(-1, 1, 2)``
            before being drawn.
        colors: Sequence of colours, ``colors[i]`` being used for ``boxes[i]``.

    Returns:
        A NumPy array holding ``image`` with every box drawn as a closed
        polyline of thickness 2. An array is returned even when ``boxes`` is
        empty, so callers can always treat the result as an ndarray.

    Raises:
        IndexError: If ``colors`` has fewer entries than ``boxes``.
    """
    # Copy once up front: this leaves the caller's image untouched and avoids
    # re-copying the whole image for every single box.
    canvas = np.array(image)
    for idx, box in enumerate(boxes):
        points = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        canvas = cv2.polylines(canvas, [points], True, colors[idx], 2)
    return canvas
# Scan for a QR code using OpenCV's QRCodeDetector
def scan_qr_code(image):
    """Try to decode a QR code from ``image`` with OpenCV's QRCodeDetector.

    ``image`` is used as-is when it is already a NumPy array and is converted
    with ``np.array`` otherwise. Returns the decoded text with surrounding
    whitespace stripped, or ``None`` when nothing was decoded or when any
    exception occurred (the exception is logged, never raised).
    """
    try:
        pixels = image if isinstance(image, np.ndarray) else np.array(image)
        decoded, _points, _straight = cv2.QRCodeDetector().detectAndDecode(pixels)
        return decoded.strip() if decoded else None
    except Exception as exc:
        logger.error("QR code scanning failed: " + str(exc))
        return None
# Main inference function
# Lazily created PaddleOCR engine shared by every call to inference().
_OCR_ENGINE = None


def _get_ocr_engine():
    """Return the shared English PaddleOCR engine, building it on first use."""
    global _OCR_ENGINE
    if _OCR_ENGINE is None:
        _OCR_ENGINE = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
                                det_model_dir='./models/det/en',
                                cls_model_dir='./models/cls/en',
                                rec_model_dir='./models/rec/en')
    return _OCR_ENGINE


def _write_csv(row):
    """Write ``row`` as a one-row CSV to a new temp file and return its path."""
    # delete=False: the file must outlive this function so it can be served.
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w",
                                     newline="", encoding="utf-8") as tmp_file:
        pd.DataFrame([row]).to_csv(tmp_file, index=False)
        return tmp_file.name


def inference(img: Image.Image, confidence):
    """Run OCR, entity extraction and QR scanning on a business-card image.

    Args:
        img: The uploaded image. ``None`` (nothing uploaded) is reported as
            an error instead of being passed to the OCR engine.
        confidence: Threshold forwarded to ``gliner_model.predict_entities``.

    Returns:
        A 4-tuple ``(ocr_text, entities, csv_path, error)``:
        ``ocr_text`` is all recognised text joined by single spaces;
        ``entities`` maps each title-cased label (plus ``"QR"`` when a QR
        code was decoded) to its matches joined with ``"; "``;
        ``csv_path`` is the path of a temp CSV holding the same mapping;
        ``error`` is ``""`` on success. On failure the tuple is
        ``("", {}, None, <error message>)``; this function does not raise.
    """
    if img is None:
        return "", {}, None, "Error: no image provided"
    try:
        # Reuse one engine across requests; building PaddleOCR per call
        # reloads its models every time.
        # NOTE(review): assumes requests are not run concurrently against the
        # shared engine - confirm against the Gradio queue settings.
        ocr = _get_ocr_engine()
        pages = ocr.ocr(np.array(img), cls=True)
        # The entry for an image can be None when no text was detected;
        # treat that as "no lines" rather than failing on iteration.
        lines = (pages[0] if pages else None) or []
        ocr_text = " ".join(line[1][0] for line in lines)

        # Extract entities using GLiNER (labels include 'website')
        labels = ["person name", "company name", "job title", "phone", "email", "address", "website"]
        results = {label.title(): [] for label in labels}
        # With no recognised text there is nothing to tag, so skip the model.
        if ocr_text.strip():
            entities = gliner_model.predict_entities(
                ocr_text, labels, threshold=confidence, flat_ner=True)
            for entity in entities:
                key = entity["label"].title()
                if key in results:
                    results[key].append(entity["text"])

        # Scan the original image for a QR code and add it if found
        qr_data = scan_qr_code(img)
        if qr_data:
            results["QR"] = [qr_data]

        # Flatten once; the same mapping feeds both the JSON view and the CSV.
        flat = {key: "; ".join(values) for key, values in results.items()}
        csv_path = _write_csv(flat)
        # Return tuple: (OCR text, JSON entities, CSV file path, error message)
        return ocr_text, flat, csv_path, ""
    except Exception as e:
        details = traceback.format_exc()
        logger.error("Processing failed: " + details)
        return "", {}, None, f"Error: {str(e)}\n{details}"
# Gradio Interface setup (output structure remains unchanged)
# Heading and blurb passed to the Gradio interface.
title = "Business Card Information Extractor"
description = (
    "Extracts text using PaddleOCR and entities using GLiNER "
    "(with added website label) along with QR code scanning."
)
# Example rows (image path, confidence value); can be updated accordingly.
examples = [
    ["example_imgs/example.jpg", 0.5],
    ["example_imgs/demo003.jpeg", 0.7],
]
# CSS rule for the `.output_image` / `.input_image` selectors.
css = (
    ".output_image, .input_image "
    "{height: 40rem !important; width: 100% !important;}"
)
if __name__ == '__main__':
    # Inputs: the uploaded card as a PIL image, plus the confidence slider
    # (0.1 to 1 in steps of 0.1, starting at 0.5).
    input_components = [
        gr.Image(type='pil', label='Upload Business Card'),
        gr.Slider(0.1, 1, 0.5, step=0.1, label='Confidence Threshold'),
    ]
    # Outputs: one component per element of the tuple inference() returns.
    output_components = [
        gr.Textbox(label="Extracted Text"),
        gr.JSON(label="Entities"),
        gr.File(label="Download CSV"),
        gr.Textbox(label="Error Details"),
    ]
    demo = gr.Interface(
        inference,
        input_components,
        output_components,
        title=title,
        description=description,
        examples=examples,
        css=css,
        cache_examples=True,
    )
    # Queue requests (at most 10 waiting) before starting the app.
    demo.queue(max_size=10)
    demo.launch()
|