# Hugging Face Spaces page status at capture time: Sleeping
import functools
import io
import json
import logging
import os
import re
import tempfile
import traceback
from pathlib import Path

# Third-party imports keep their original relative order on purpose:
# these native-extension packages can be sensitive to import order.
from paddleocr import PaddleOCR
from gliner import GLiNER
from PIL import Image
import gradio as gr
import numpy as np
import cv2
import pandas as pd
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up GLiNER environment variables
# NOTE(review): this runs after `gliner` has already been imported; confirm
# the library reads GLINER_HOME lazily rather than at import time.
os.environ['GLINER_HOME'] = './gliner_models'

# Load the GLiNER model once, at import time.
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
except Exception:
    # logger.exception attaches the traceback, so the log shows *why* loading
    # failed; the bare `raise` re-raises without adding an extra frame.
    logger.exception("Failed to load GLiNER model")
    raise
# Helper functions
def get_random_color():
    """Return a random three-component colour.

    Returns:
        A tuple of three Python ints, each drawn uniformly from 0..255
        (the upper bound 256 passed to ``randint`` is exclusive).
    """
    return tuple(np.random.randint(0, 256, 3).tolist())
def draw_ocr_bbox(image, boxes, colors):
    """Outline each OCR bounding box on a copy of *image*.

    Args:
        image: Anything ``np.array`` can convert to a pixel array.
        boxes: Sequence of polygons; each is reshaped to ``(-1, 1, 2)``
            integer points before drawing.
        colors: Sequence indexed in parallel with *boxes*; ``colors[i]`` is
            the outline colour of ``boxes[i]``.

    Returns:
        *image* itself when *boxes* is empty, otherwise a NumPy array with
        every polygon drawn as a closed outline of thickness 2.

    Raises:
        IndexError: If *colors* has fewer entries than *boxes*.
    """
    # len() rather than truthiness: boxes may be a NumPy array.
    if len(boxes) == 0:
        return image

    # Convert once up front instead of re-copying the whole image per box.
    canvas = np.array(image)
    for index, box in enumerate(boxes):
        polygon = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        canvas = cv2.polylines(canvas, [polygon], True, colors[index], 2)
    return canvas
def scan_qr_code(image):
    """Try to decode a QR code from *image*.

    The image is converted to a NumPy array and passed to OpenCV's
    ``QRCodeDetector.detectAndDecode``.

    Returns:
        The decoded payload with surrounding whitespace stripped, or None
        when nothing was decoded or when any exception occurred (the
        exception is logged, not propagated).
    """
    try:
        detector = cv2.QRCodeDetector()
        pixels = np.array(image)
        payload, _points, _rectified = detector.detectAndDecode(pixels)
        if not payload:
            return None
        return payload.strip()
    except Exception as e:
        # Best-effort: a failed QR scan must not abort the rest of the parse.
        logger.error(f"QR scan failed: {str(e)}")
        return None
def extract_emails(text):
    """Return every e-mail address found in *text*, in order of appearance.

    Args:
        text: The string to scan.

    Returns:
        A list of matched addresses (possibly empty); case is preserved.
    """
    # The TLD class is [A-Za-z]: writing it as [A-Z|a-z] would also accept a
    # literal '|' there, since '|' is not alternation inside a character class.
    email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    return re.findall(email_regex, text)
_WEBSITE_RE = re.compile(
    r"(?<![\w@.%+-])"                   # start of a token, never right after '@'
    r"(?:https?://)?"                   # optional scheme
    r"(?:[A-Za-z0-9-]+\.)+[A-Za-z]{2,}"  # host: one or more labels plus a TLD
    r"(?![\w.%+-]*@)"                   # host must not be an e-mail local part
    r"(?:/\S*)?"                        # optional path
)


def extract_websites(text):
    """Return website-looking strings found in *text*, in order of appearance.

    E-mail addresses are excluded on both sides of the ``@``: neither the
    local part (``john.doe`` in ``john.doe@example.com``) nor the domain
    (``example.com``) is reported. Multi-label hosts such as
    ``www.example.co.uk`` are returned whole, and a path may contain ``@``
    (``medium.com/@user``).

    Args:
        text: The string to scan.

    Returns:
        A list of matches (possibly empty), each including its scheme and
        path when present.
    """
    # The pattern has no capturing groups, so findall yields whole matches.
    return _WEBSITE_RE.findall(text)
def clean_phone_number(phone):
    """Strip *phone* down to its digits and ``+`` signs.

    Every character that is neither a digit nor ``+`` is removed; the
    surviving characters keep their original order.
    """
    # Splitting on each unwanted character and re-joining the pieces removes
    # exactly those characters.
    kept_pieces = re.split(r"[^\d+]", phone)
    return "".join(kept_pieces)
# Main inference function
# Labels handed to GLiNER; title-casing a label yields its results key
# ("person name" -> "Person Name").
_ENTITY_LABELS = ["person name", "company name", "job title",
                  "phone number", "email address", "physical address",
                  "website url"]

# A run of digits with common separators, optionally led by '+' and/or '('.
_PHONE_CANDIDATE_RE = re.compile(r"\+?\(?\d[\d\s().-]{5,}\d")

# Heuristic digit-count bounds for a regex-detected phone number
# (15 is the E.164 maximum).
_MIN_PHONE_DIGITS = 7
_MAX_PHONE_DIGITS = 15


@functools.lru_cache(maxsize=1)
def _get_ocr_engine():
    """Build the PaddleOCR engine on first use and reuse it afterwards."""
    return PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
                     det_model_dir='./models/det/en',
                     cls_model_dir='./models/cls/en',
                     rec_model_dir='./models/rec/en')


def _find_phone_numbers(lines):
    """Return cleaned phone-number candidates found in the OCR *lines*."""
    found = []
    for line in lines:
        for candidate in _PHONE_CANDIDATE_RE.findall(line):
            cleaned = clean_phone_number(candidate)
            digit_count = sum(ch.isdigit() for ch in cleaned)
            # Drop short fragments (suite numbers, zip codes) and overlong runs.
            if _MIN_PHONE_DIGITS <= digit_count <= _MAX_PHONE_DIGITS:
                found.append(cleaned)
    return found


def _write_csv(row):
    """Write *row* as a one-record CSV temp file and return the file's path."""
    # newline="" lets pandas control line endings on a text-mode handle.
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w",
                                     newline="", encoding="utf-8") as tmp_file:
        pd.DataFrame([row]).to_csv(tmp_file, index=False)
        return tmp_file.name


def inference(img: Image.Image, confidence):
    """Parse a business-card image into text, structured fields and a CSV.

    Pipeline: PaddleOCR text recognition, GLiNER entity extraction, regex
    fallbacks for e-mails and websites, regex phone detection, QR decoding,
    and CSV export.

    Args:
        img: The uploaded card image, or None when nothing was uploaded.
        confidence: Threshold passed to GLiNER's ``predict_entities``.

    Returns:
        A 4-tuple ``(ocr_text, fields, csv_path, error)``:
        * ``ocr_text`` - the recognised lines joined by single spaces;
        * ``fields`` - dict mapping each non-empty field name to its values
          joined by ``"; "``;
        * ``csv_path`` - path of a temporary CSV holding ``fields``, or None
          on failure;
        * ``error`` - empty on success, otherwise the message and traceback.
        Exceptions are never propagated; they are logged and reported through
        ``error``.
    """
    if img is None:
        return "", {}, None, "Error: no image provided"

    try:
        # OCR Processing
        pages = _get_ocr_engine().ocr(np.array(img), cls=True)
        # The first page can be None when no text is detected.
        detections = (pages[0] if pages else None) or []
        ocr_texts = [line[1][0] for line in detections]
        ocr_text = " ".join(ocr_texts)

        results = {
            "Person Name": [],
            "Company Name": [],
            "Job Title": [],
            "Phone Number": [],
            "Email Address": [],
            "Physical Address": [],
            "Website Url": [],
            "QR Code": []
        }

        # Entity Extraction (skipped when OCR produced no text)
        entities = []
        if ocr_text.strip():
            entities = gliner_model.predict_entities(
                ocr_text, _ENTITY_LABELS, threshold=confidence, flat_ner=True)

        # Process GLiNER results
        for entity in entities:
            # Keep the space when title-casing so the label matches the
            # results keys exactly.
            field = entity["label"].title()
            value = entity["text"]
            if field == "Phone Number":
                cleaned = clean_phone_number(value)
                if cleaned:
                    results[field].append(cleaned)
            elif field in ("Email Address", "Website Url"):
                results[field].append(value.lower())
            elif field in results:
                results[field].append(value)

        # Regex fallbacks
        if not results["Email Address"]:
            results["Email Address"] = extract_emails(ocr_text)
        if not results["Website Url"]:
            results["Website Url"] = extract_websites(ocr_text)

        # Merge model and regex phone numbers; dict.fromkeys de-duplicates
        # while keeping a deterministic, first-seen order.
        results["Phone Number"] = list(dict.fromkeys(
            results["Phone Number"] + _find_phone_numbers(ocr_texts)))

        # QR Code handling
        qr_data = scan_qr_code(img)
        if qr_data:
            results["QR Code"] = [qr_data]

        # Create CSV
        csv_data = {k: "; ".join(v) for k, v in results.items() if v}
        csv_path = _write_csv(csv_data)
        return ocr_text, csv_data, csv_path, ""
    except Exception as e:
        # Top-level boundary for the UI: log with traceback and surface the
        # failure in the "Error Log" output instead of crashing the request.
        logger.exception("Processing failed")
        return "", {}, None, f"Error: {str(e)}\n{traceback.format_exc()}"
# Gradio Interface
# Static text, example inputs and styling consumed when the UI is built.
title = "Enhanced Business Card Parser"
description = "Extracts entities with combined AI and regex validation, including QR codes"

# Each example row is [image path, confidence threshold].
examples = [
    ["example_imgs/example.jpg", 0.4],
    ["example_imgs/demo003.jpeg", 0.5],
]

# Two CSS rules, one per line.
css = "\n".join([
    ".output_image, .input_image {height: 40rem !important; width: 100% !important;}",
    ".gr-interface {max-width: 800px !important;}",
])
if __name__ == '__main__':
    # Inputs: the card image (delivered to `inference` as a PIL image) and
    # the confidence threshold slider (0.1 to 1, default 0.4, step 0.1).
    input_components = [
        gr.Image(type='pil', label='Upload Business Card'),
        gr.Slider(0.1, 1, 0.4, step=0.1, label='Confidence Threshold'),
    ]

    # Outputs, in the order `inference` returns them.
    output_components = [
        gr.Textbox(label="OCR Result"),
        gr.JSON(label="Structured Data"),
        gr.File(label="Download CSV"),
        gr.Textbox(label="Error Log"),
    ]

    demo = gr.Interface(
        fn=inference,
        inputs=input_components,
        outputs=output_components,
        title=title,
        description=description,
        examples=examples,
        css=css,
        cache_examples=True,
    )

    # Allow at most 20 queued requests, then start serving.
    demo.queue(max_size=20)
    demo.launch()