# Spaces: Sleeping (page-header residue captured with the source; not part of the program)
# Standard library
import functools
import io
import json
import logging
import os
import re
import tempfile
import traceback
from pathlib import Path

# Third-party (kept in the original relative order in case import order matters)
from paddleocr import PaddleOCR
from gliner import GLiNER
from PIL import Image
import gradio as gr
import numpy as np
import cv2
import pandas as pd
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up GLiNER environment variables (adjust if needed)
# NOTE(review): this assignment runs after `gliner` has already been imported;
# confirm the library reads GLINER_HOME lazily rather than at import time.
os.environ['GLINER_HOME'] = './gliner_models'

# Load GLiNER model (do not change the model)
# The model is loaded once at import time and shared by every request.
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
except Exception:
    logger.error("Failed to load GLiNER model")
    # Bare `raise` re-raises the active exception with its traceback untouched;
    # the app cannot work without the model, so start-up is aborted.
    raise
# Get a random color (used for drawing bounding boxes, if needed)
def get_random_color():
    """Return a random colour as a tuple of three Python ints in 0..255.

    The values come from NumPy's global random state, so seeding
    ``np.random`` makes the result reproducible.
    """
    # .tolist() converts NumPy integers to plain ints before building the tuple.
    return tuple(np.random.randint(0, 256, 3).tolist())
# Draw OCR bounding boxes (this function is kept for debugging/visualization purposes)
def draw_ocr_bbox(image, boxes, colors):
    """Draw one closed polygon per OCR box and return the annotated image.

    Args:
        image: Anything ``np.array`` accepts (e.g. a PIL image or an ndarray).
            The input itself is never modified; drawing happens on a copy.
        boxes: Sequence of boxes, each a sequence of (x, y) corner points.
        colors: Sequence of colours indexed in parallel with ``boxes``.

    Returns:
        A NumPy array holding the image with every box outlined (line
        thickness 2). With no boxes this is simply an array copy of ``image``,
        so the return type is the same whether or not anything was drawn.

    Raises:
        IndexError: If ``colors`` has fewer entries than ``boxes``.
    """
    # Convert (and copy) once instead of re-copying the image for every box.
    canvas = np.array(image)
    for index, corners in enumerate(boxes):
        # cv2.polylines expects each polygon as an (N, 1, 2) integer array.
        polygon = np.reshape(np.array(corners), [-1, 1, 2]).astype(np.int64)
        canvas = cv2.polylines(canvas, [polygon], True, colors[index], 2)
    return canvas
# Scan for a QR code using OpenCV's QRCodeDetector
def scan_qr_code(image):
    """Decode a QR code from ``image`` on a best-effort basis.

    Args:
        image: A NumPy array, or anything ``np.array`` can convert (e.g. a PIL
            image). ``None`` is accepted and simply yields ``None``.

    Returns:
        The decoded payload with surrounding whitespace stripped, or ``None``
        when there is no image, no non-blank payload, or scanning failed.
    """
    if image is None:
        # Nothing to scan; skip the detector instead of relying on it to fail.
        return None

    try:
        # Ensure the image is in numpy array format
        image_np = image if isinstance(image, np.ndarray) else np.array(image)
        data, _points, _ = cv2.QRCodeDetector().detectAndDecode(image_np)
    except Exception as e:
        # QR scanning is optional for the caller, so a failure is logged and
        # reported as "no QR code" rather than propagated.
        logger.error("QR code scanning failed: %s", e)
        return None

    payload = data.strip() if data else ""
    return payload or None
@functools.lru_cache(maxsize=1)
def _get_ocr_engine():
    """Build the English-only, CPU PaddleOCR engine once and reuse it afterwards."""
    return PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
                     det_model_dir='./models/det/en',
                     cls_model_dir='./models/cls/en',
                     rec_model_dir='./models/rec/en')


def _write_csv(row):
    """Write ``row`` as a one-line CSV to a persistent temp file; return its path."""
    # delete=False: the file has to outlive this call so the UI can serve it.
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_file:
        csv_path = tmp_file.name
    pd.DataFrame([row]).to_csv(csv_path, index=False)
    return csv_path


# Main inference function
def inference(img: "Image.Image", confidence):
    """Run OCR, entity extraction and QR scanning on a business-card image.

    Args:
        img: The uploaded image as a PIL image. ``None`` (nothing uploaded)
            is reported through the error slot instead of raising.
        confidence: Threshold passed to GLiNER's ``predict_entities``.

    Returns:
        A 4-tuple ``(ocr_text, entities, csv_path, error)``:
        * ``ocr_text`` - all recognised text fragments joined by spaces.
        * ``entities`` - dict mapping each title-cased label (plus ``"QR"``
          when a code was decoded) to its matches joined with ``"; "``.
        * ``csv_path`` - path of a temp CSV holding the same dict as one row.
        * ``error`` - empty string on success.
        On failure the tuple is ``("", {}, None, <error details>)``; this
        function never raises for processing errors.
    """
    if img is None:
        return "", {}, None, "Error: no image was provided"

    try:
        ocr = _get_ocr_engine()
        pages = ocr.ocr(np.array(img), cls=True)
        # Some PaddleOCR releases report a page without any detected text as
        # None rather than an empty list; treat both as "no lines".
        lines = (pages[0] if pages else None) or []

        # Concatenate all recognized texts
        ocr_text = " ".join(line[1][0] for line in lines)

        # Extract entities using GLiNER with updated labels (adding 'website')
        labels = ["person name", "company name", "job title", "phone", "email", "address", "website"]
        results = {label.title(): [] for label in labels}
        if ocr_text.strip():
            # Only query the model when there is text to analyse.
            entities = gliner_model.predict_entities(
                ocr_text, labels, threshold=confidence, flat_ner=True)
            for entity in entities:
                key = entity["label"].title()
                if key in results:
                    results[key].append(entity["text"])

        # Scan the original image for a QR code and add it if found
        qr_data = scan_qr_code(img)
        if qr_data:
            results["QR"] = [qr_data]

        # One flattened mapping feeds both the JSON output and the CSV row.
        row = {key: "; ".join(values) for key, values in results.items()}
        csv_path = _write_csv(row)

        # Return tuple: (OCR text, JSON entities, CSV file path, error message)
        return ocr_text, row, csv_path, ""
    except Exception as e:
        # Top-level boundary for the UI: log, then surface the details in the
        # error output instead of letting the request crash.
        details = traceback.format_exc()
        logger.error("Processing failed: %s", details)
        return "", {}, None, f"Error: {str(e)}\n{details}"
# Gradio Interface setup (output structure remains unchanged)
title = 'Business Card Information Extractor'
description = 'Extracts text using PaddleOCR and entities using GLiNER (with added website label) along with QR code scanning.'

# Examples can be updated accordingly
# Each entry is [image path, confidence threshold], matching the two inputs.
examples = [
    ['example_imgs/example.jpg', 0.5],
    ['example_imgs/demo003.jpeg', 0.7],
]

# Custom CSS handed to the interface: fixed height and full width for image panes.
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
if __name__ == '__main__':
    # Build the UI only when run as a script, so importing this module does
    # not start a server.
    demo = gr.Interface(
        inference,
        # Inputs, in the order of inference's parameters: (img, confidence).
        [gr.Image(type='pil', label='Upload Business Card'),
         gr.Slider(0.1, 1, 0.5, step=0.1, label='Confidence Threshold')],
        # Outputs, in the order of inference's return tuple:
        # (OCR text, entities, CSV path, error message).
        [gr.Textbox(label="Extracted Text"),
         gr.JSON(label="Entities"),
         gr.File(label="Download CSV"),
         gr.Textbox(label="Error Details")],
        title=title,
        description=description,
        examples=examples,
        css=css,
        cache_examples=True,
    )
    # Allow at most 10 requests to wait in the queue.
    demo.queue(max_size=10)
    demo.launch()