# paddle-ocr-demo / app.py
# Source: commit 6bba885 by codic ("trying the other update"), 5.15 kB.
from paddleocr import PaddleOCR
from gliner import GLiNER
import json
from PIL import Image
import gradio as gr
import numpy as np
import cv2
import logging
import os
from pathlib import Path
import tempfile
import pandas as pd
import io
import re
import traceback
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up GLiNER environment variables (adjust if needed).
# NOTE(review): this runs after `from gliner import GLiNER`; if the library
# reads GLINER_HOME at import time, this assignment is too late — confirm.
os.environ['GLINER_HOME'] = './gliner_models'

# Load the GLiNER model once at import time (do not change the model).
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
except Exception:
    # logger.exception records the traceback in the log; a bare `raise`
    # re-raises the original exception with its traceback untouched.
    logger.exception("Failed to load GLiNER model")
    raise
# Random color helper (used for drawing bounding boxes, if needed).
def get_random_color():
    """Return a random color as a 3-tuple of plain Python ints in [0, 255].

    Uses NumPy's global random state, drawing three integers in one call.
    """
    low, high = 0, 256  # randint's upper bound is exclusive -> values 0..255
    # int() turns numpy integer scalars into Python ints, exactly like .tolist().
    return tuple(int(channel) for channel in np.random.randint(low, high, 3))
# Draw OCR bounding boxes (kept for debugging/visualization purposes).
def draw_ocr_bbox(image, boxes, colors):
    """Draw each OCR box as a closed polygon and return the annotated array.

    Args:
        image: PIL image or numpy array. It is copied, so the input is not
            modified.
        boxes: Sequence of polygons, each a sequence of (x, y) points.
            Float coordinates are truncated to integers.
        colors: One color per box (e.g. from get_random_color()). Indexed by
            box position, so it must be at least as long as ``boxes``.

    Returns:
        A numpy array with the polygons drawn, also when ``boxes`` is empty.
    """
    # Copy once up front instead of re-converting the image for every box.
    canvas = np.array(image)
    for i, box in enumerate(boxes):
        # cv2.polylines only accepts int32 (CV_32S) points shaped (N, 1, 2);
        # int64 points trigger an OpenCV assertion error.
        points = np.reshape(np.asarray(box), (-1, 1, 2)).astype(np.int32)
        cv2.polylines(canvas, [points], True, colors[i], 2)
    return canvas
# Scan for a QR code using OpenCV's QRCodeDetector.
def scan_qr_code(image):
    """Decode a QR code in *image* with OpenCV's QRCodeDetector.

    Args:
        image: PIL image or numpy array.

    Returns:
        The decoded payload with surrounding whitespace stripped, or None when
        nothing (or only whitespace) was decoded or scanning raised. Errors
        are logged and swallowed so a QR problem never aborts the caller.
    """
    try:
        # asarray avoids an extra copy; the detector only reads the pixels.
        image_np = image if isinstance(image, np.ndarray) else np.asarray(image)
        data, _points, _ = cv2.QRCodeDetector().detectAndDecode(image_np)
    except Exception as e:
        logger.error("QR code scanning failed: %s", e)
        return None
    # An empty string means "no QR code found"; normalize that to None.
    return data.strip() or None
# Entity labels passed to GLiNER ('website' was added to the original set).
# Their title-cased forms become the keys of the output dictionary.
_ENTITY_LABELS = ["person name", "company name", "job title", "phone", "email", "address", "website"]

# Lazily created PaddleOCR engine shared by all requests (see _get_ocr_engine).
_ocr_engine = None


def _get_ocr_engine():
    """Return the shared English-only PaddleOCR engine, creating it on first use.

    Constructing PaddleOCR loads its detection, angle-classification and
    recognition models from ./models, so it is built once rather than per call.
    If construction raises, nothing is cached and the next call retries.
    """
    global _ocr_engine
    if _ocr_engine is None:
        _ocr_engine = PaddleOCR(
            use_angle_cls=True,
            lang='en',
            use_gpu=False,
            det_model_dir='./models/det/en',
            cls_model_dir='./models/cls/en',
            rec_model_dir='./models/rec/en',
        )
    return _ocr_engine


def _run_ocr(img):
    """Run OCR on a PIL image and return the recognized text lines in order."""
    # Normalize RGBA / grayscale / palette uploads to a 3-channel array.
    img_np = np.array(img.convert('RGB'))
    # ocr() returns one entry per page. When no text is detected, that entry
    # may be None (PaddleOCR 2.6+), so treat it as an empty page.
    page = _get_ocr_engine().ocr(img_np, cls=True)[0] or []
    # Each line is [box, (text, score)]; keep only the text. The boxes
    # (line[0]) can be passed to draw_ocr_bbox when debugging.
    return [line[1][0] for line in page]


def _extract_entities(text, confidence):
    """Group GLiNER entities found in *text* into lists keyed by title-cased label."""
    results = {label.title(): [] for label in _ENTITY_LABELS}
    if not text:
        return results  # nothing to tag; keep every key with an empty list
    entities = gliner_model.predict_entities(text, _ENTITY_LABELS, threshold=confidence, flat_ner=True)
    for entity in entities:
        key = entity["label"].title()
        if key in results:
            results[key].append(entity["text"])
    return results


def _write_csv(row):
    """Write *row* (a flat dict) as a one-row CSV to a new temp file; return its path."""
    # delete=False: the file must outlive this function so Gradio can serve it.
    # pandas requires newline='' for text-mode handles to avoid blank lines.
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w",
                                     encoding="utf-8", newline="") as tmp_file:
        pd.DataFrame([row]).to_csv(tmp_file, index=False)
        return tmp_file.name


def inference(img: Image.Image, confidence):
    """Extract business-card fields from an uploaded image.

    Runs PaddleOCR on *img*, tags the concatenated text with GLiNER, adds any
    QR-code payload under "QR", and writes the flattened results to a CSV.

    Args:
        img: Uploaded card image. It may be None if the form is submitted
            without an image.
        confidence: GLiNER score threshold, passed as ``threshold``.

    Returns:
        A 4-tuple matching the Gradio outputs:
        (OCR text, {label: "; "-joined entities}, CSV path or None, error text).
        On failure, the first three are empty/None and the error text holds
        the exception message and traceback.
    """
    if img is None:
        return "", {}, None, "Error: no image provided."
    try:
        ocr_text = " ".join(_run_ocr(img))
        results = _extract_entities(ocr_text, confidence)

        # Scan the original image for a QR code and add it if found.
        qr_data = scan_qr_code(img)
        if qr_data:
            results["QR"] = [qr_data]

        # Flatten once; the same mapping feeds both the JSON output and the CSV.
        flat = {label: "; ".join(values) for label, values in results.items()}
        csv_path = _write_csv(flat)
        return ocr_text, flat, csv_path, ""
    except Exception as e:
        logger.exception("Processing failed")
        return "", {}, None, f"Error: {str(e)}\n{traceback.format_exc()}"
# ---------------------------------------------------------------------------
# Gradio interface. The four outputs line up with inference's return tuple:
# (text, entities dict, CSV path, error details).
# ---------------------------------------------------------------------------
title = 'Business Card Information Extractor'
description = 'Extracts text using PaddleOCR and entities using GLiNER (with added website label) along with QR code scanning.'

# Sample inputs: [image path, confidence threshold]. Update as needed.
examples = [
    ['example_imgs/example.jpg', 0.5],
    ['example_imgs/demo003.jpeg', 0.7],
]

css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"


def _build_demo():
    """Assemble the Gradio Interface that wraps `inference`."""
    input_components = [
        gr.Image(type='pil', label='Upload Business Card'),
        gr.Slider(0.1, 1, 0.5, step=0.1, label='Confidence Threshold'),
    ]
    output_components = [
        gr.Textbox(label="Extracted Text"),
        gr.JSON(label="Entities"),
        gr.File(label="Download CSV"),
        gr.Textbox(label="Error Details"),
    ]
    return gr.Interface(
        fn=inference,
        inputs=input_components,
        outputs=output_components,
        title=title,
        description=description,
        examples=examples,
        css=css,
        cache_examples=True,
    )


if __name__ == '__main__':
    demo = _build_demo()
    demo.queue(max_size=10)  # cap pending requests at 10
    demo.launch()