# Hugging Face Space (status: Running)
from paddleocr import PaddleOCR | |
from gliner import GLiNER | |
from PIL import Image | |
import gradio as gr | |
import numpy as np | |
import logging | |
import tempfile | |
import pandas as pd | |
import re | |
import traceback | |
import zxingcpp | |
# -------------------------- | |
# Configuration & Constants | |
# -------------------------- | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Per-country mobile-number rules.
#   code:    international dialling prefix used when normalising a number.
#   pattern: anchored regex accepting either the international form
#            (+<code>5XXXXXXXX) or the local form (05XXXXXXXX).
# NOTE(review): the local form is identical for both countries, so a bare
# 05XXXXXXXX number matches whichever entry is tried first (SAUDI).
COUNTRY_CODES = {
    'SAUDI': {'code': '+966', 'pattern': r'^(\+9665\d{8}|05\d{8})$'},
    'UAE': {'code': '+971', 'pattern': r'^(\+9715\d{8}|05\d{8})$'},
}

# Pre-compiled patterns shared by the extraction functions.
VALIDATION_PATTERNS = {
    # The TLD class is [A-Za-z]; the previous [A-Z|a-z] also accepted a
    # literal '|' as part of the address.
    'email': re.compile(
        r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        re.IGNORECASE,
    ),
    # Group 1 is the host without scheme or leading "www.". Any number of
    # dot-separated labels is allowed so that multi-label hosts such as
    # "example.co.uk" are captured whole instead of being cut to "example.co".
    'website': re.compile(
        r'(?:https?://)?(?:www\.)?'
        r'([A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*\.[A-Za-z]{2,})'
    ),
    # Two or three capitalised words, e.g. "John Smith" or "Mary Ann Lee".
    'name': re.compile(r'^[A-Z][a-z]+(?:\s+[A-Z][a-z]+){1,2}$'),
}
# -------------------------- | |
# Core Processing Functions | |
# -------------------------- | |
def process_phone_number(raw_number: str) -> "str | None":
    """Validate a phone number and return it in international form.

    Everything except digits and '+' is stripped, then the result is checked
    against each entry of ``COUNTRY_CODES`` in order. The first matching
    country decides the output.

    Args:
        raw_number: Candidate phone number; may contain separators. ``None``
            or an empty string is treated as "no number".

    Returns:
        The number prefixed with the country's dialling code (a leading
        local ``0`` is replaced by the code), or ``None`` when the input is
        empty or matches no supported country.
    """
    if not raw_number:
        return None

    cleaned = re.sub(r'[^\d+]', '', raw_number)
    if not cleaned:
        return None

    for config in COUNTRY_CODES.values():
        if not re.match(config['pattern'], cleaned):
            continue
        if cleaned.startswith('0'):
            # Local form: swap the trunk '0' for the international prefix.
            return f"{config['code']}{cleaned[1:]}"
        # The only other form the patterns accept already starts with the
        # '+<code>' prefix, so it is returned unchanged. (The former
        # startswith('5') branch could never be reached.)
        return cleaned
    return None
def extract_contact_info(text: str) -> dict:
    """Extract phone numbers, email addresses and websites from *text*.

    Args:
        text: Free text to scan (e.g. the joined OCR output). ``None`` or an
            empty string yields empty results.

    Returns:
        A dict that always contains the keys ``'phones'``, ``'emails'`` and
        ``'websites'``, each mapped to a list of unique values in order of
        first appearance (empty when nothing was found). Phones are the
        values accepted by ``process_phone_number``; emails are lower-cased;
        websites are lower-cased hosts prefixed with ``www.``.
    """
    # Plain dicts serve as insertion-ordered sets so that "[0]" on a result
    # list is the first value found in the text, on every run.
    phones, emails, websites = {}, {}, {}

    if text:
        # Phone numbers: contiguous digit runs, validated per country.
        for match in re.finditer(r'(\+?\d{10,13}|05\d{8})', text):
            if processed := process_phone_number(match.group()):
                phones[processed] = None

        # Email addresses, normalised to lower case.
        email_pattern = VALIDATION_PATTERNS['email']
        for email in email_pattern.findall(text):
            emails[email.lower()] = None

        # Websites: blank out the email addresses first, otherwise pieces of
        # an address (e.g. "john.doe" from "john.doe@acme.com") would be
        # reported as websites.
        text_without_emails = email_pattern.sub(' ', text)
        for match in VALIDATION_PATTERNS['website'].finditer(text_without_emails):
            websites[f"www.{match.group(1).lower()}"] = None

    # Every key is always present: callers index the result directly.
    return {
        'phones': list(phones),
        'emails': list(emails),
        'websites': list(websites),
    }
def process_entities(entities: list, ocr_text: list) -> dict:
    """Pick name, company, job title and address from recognised entities.

    Args:
        entities: Entity dicts with at least ``'label'`` and ``'text'``.
            A numeric ``'score'`` is used when present (assumed to be the
            recogniser's confidence - TODO confirm against the GLiNER
            output format); entities without one are treated as score 0.0.
        ocr_text: Individual OCR text lines, used only as a fallback source
            for the person name.

    Returns:
        A dict with the keys ``'name'``, ``'company'``, ``'title'`` and
        ``'address'``; a value is ``None`` when nothing suitable was found.
    """
    result = {
        'name': None,
        'company': None,
        'title': None,
        'address': None,
    }
    best_score = dict.fromkeys(result, float('-inf'))

    def _offer(field, value, score):
        # Keep the highest-scoring candidate per field. ">=" lets a later
        # entity win a tie, which matches the previous last-one-wins
        # behaviour when no scores are supplied.
        if score >= best_score[field]:
            result[field] = value
            best_score[field] = score

    for entity in entities or []:
        label = entity['label'].lower()
        text = entity['text'].strip()
        if not text:
            # Whitespace-only text must not replace a real value.
            continue
        score = entity.get('score', 0.0)

        if label == 'person name':
            # Names are only accepted when they look like 2-3 capitalised words.
            if VALIDATION_PATTERNS['name'].match(text):
                _offer('name', text.title(), score)
        elif label == 'company name':
            _offer('company', text, score)
        elif label == 'job title':
            _offer('title', text.title(), score)
        elif label == 'address':
            _offer('address', text, score)

    # Name fallback: first OCR line that looks like a person name.
    if not result['name']:
        for line in ocr_text or []:
            candidate = line.strip()  # the name pattern is anchored at both ends
            if VALIDATION_PATTERNS['name'].match(candidate):
                result['name'] = candidate.title()
                break

    return result
# -------------------------- | |
# Main Processing Pipeline | |
# -------------------------- | |
# Checkpoint loaded when the module does not already define ``gliner_model``.
# NOTE(review): the visible source never creates ``gliner_model``; this
# checkpoint name is an assumption - confirm which model the app should use.
_DEFAULT_GLINER_CHECKPOINT = "urchade/gliner_medium-v2.1"


def _get_gliner_model():
    """Return the module-level GLiNER model, loading it on first use."""
    model = globals().get('gliner_model')
    if model is None:
        model = GLiNER.from_pretrained(_DEFAULT_GLINER_CHECKPOINT)
        globals()['gliner_model'] = model
    return model


def _run_ocr(img) -> list:
    """Return the text lines PaddleOCR recognises in *img* (possibly empty)."""
    # NOTE(review): the engine is rebuilt on every call, as before; consider
    # caching it once it is confirmed safe to share between requests.
    ocr_engine = PaddleOCR(lang='en', use_gpu=False)
    ocr_result = ocr_engine.ocr(np.array(img), cls=True)
    # The per-page entry may be None when no text is detected, so it must
    # not be iterated unconditionally.
    page = ocr_result[0] if ocr_result else None
    return [line[1][0] for line in page or []]


def process_business_card(img: "Image.Image", confidence: float) -> tuple:
    """Run OCR, entity recognition and contact extraction on a card image.

    Args:
        img: The uploaded card as a PIL image, or ``None`` when nothing was
            uploaded.
        confidence: Threshold passed to the entity recogniser.

    Returns:
        A 4-tuple ``(ocr_text, results, csv_path, error)``. On success
        ``ocr_text`` is the space-joined OCR output, ``results`` the
        structured fields, ``csv_path`` the path of a temporary CSV file
        holding the same fields, and ``error`` an empty string. On failure
        the tuple is ``("", {}, None, "Error: ...")``; exceptions are logged
        and never propagated to the caller.
    """
    if img is None:
        return "", {}, None, "Error: no image provided"

    try:
        # OCR Processing
        ocr_text = _run_ocr(img)
        full_text = " ".join(ocr_text)

        # Entity Recognition (skipped when OCR produced no text)
        labels = ["person name", "company name", "job title",
                  "phone number", "email address", "address",
                  "website"]
        if full_text.strip():
            entities = _get_gliner_model().predict_entities(
                full_text, labels, threshold=confidence
            )
        else:
            entities = []

        # Data Extraction
        contacts = extract_contact_info(full_text)
        # .get() because the extractor may omit categories with no matches.
        phones = contacts.get('phones', [])
        emails = contacts.get('emails', [])
        websites = contacts.get('websites', [])
        entity_data = process_entities(entities, ocr_text)
        qr_data = zxingcpp.read_barcodes(np.array(img.convert('RGB')))

        # Fallbacks used when the recogniser found no company / address.
        company_from_email = (
            emails[0].split('@')[1].split('.')[0].title() if emails else None
        )
        address_from_ocr = next(
            (t for t in ocr_text
             if any(kw in t.lower() for kw in ('street', 'ave', 'road'))),
            None,
        )

        # Compile Final Results
        results = {
            'Person Name': entity_data['name'],
            'Company Name': entity_data['company'] or company_from_email,
            'Job Title': entity_data['title'],
            'Phone Numbers': phones,
            'Email Addresses': emails,
            'Address': entity_data['address'] or address_from_ocr,
            'Website': websites[0] if websites else None,
            'QR Code': qr_data[0].text if qr_data else None,
        }

        # Generate CSV Output. delete=False keeps the file for the caller to
        # serve; explicit encoding/newline keep the CSV portable, and
        # index=False omits the meaningless DataFrame row index.
        with tempfile.NamedTemporaryFile(
            suffix='.csv', delete=False, mode='w',
            encoding='utf-8', newline='',
        ) as f:
            pd.DataFrame([results]).to_csv(f, index=False)
            csv_path = f.name

        return full_text, results, csv_path, ""
    except Exception as e:
        # Top-level UI boundary: log the traceback and report the error in
        # the UI instead of raising.
        logger.exception("Processing Error")
        return "", {}, None, f"Error: {str(e)}"
# -------------------------- | |
# Gradio Interface | |
# -------------------------- | |
# Input widgets: the card image (delivered to the handler as a PIL image)
# and the entity-recognition confidence threshold.
_card_inputs = [
    gr.Image(type='pil', label='Upload Business Card'),
    gr.Slider(0.1, 1.0, value=0.4, label='Confidence Threshold'),
]

# Output widgets, in the order of the handler's 4-tuple return value:
# raw OCR text, structured fields, CSV download, error message.
_card_outputs = [
    gr.Textbox(label='OCR Result'),
    gr.JSON(label='Structured Data'),
    gr.File(label='Download CSV'),
    gr.Textbox(label='Error Log'),
]

interface = gr.Interface(
    fn=process_business_card,
    inputs=_card_inputs,
    outputs=_card_outputs,
    title='Enterprise Business Card Parser',
    description='Multi-country support with comprehensive validation',
)

# Start the web UI only when run as a script, not on import.
if __name__ == '__main__':
    interface.launch()