# paddle-ocr-demo / app.py
# Last commit by codic: c66181c (verified) "update -- before was working"
# File size: 6.06 kB
import logging
import re
import tempfile
import traceback
from typing import Optional

# Third-party imports keep their original relative order; heavy ML frameworks
# can be sensitive to import order, so reorder only after testing.
from paddleocr import PaddleOCR
from gliner import GLiNER
from PIL import Image
import gradio as gr
import numpy as np
import pandas as pd
import zxingcpp
# --------------------------
# Configuration & Constants
# --------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Supported countries: calling code plus a full-match pattern accepting the
# international form (+<code>5 followed by 8 digits) or the local form
# (05 followed by 8 digits).
# NOTE: the local form is identical for every entry, so a local number matches
# all of them; process_phone_number assigns it to the first entry it tries.
COUNTRY_CODES = {
    'SAUDI': {'code': '+966', 'pattern': r'^(\+9665\d{8}|05\d{8})$'},
    'UAE': {'code': '+971', 'pattern': r'^(\+9715\d{8}|05\d{8})$'},
}

VALIDATION_PATTERNS = {
    # The TLD class is [A-Za-z]; the former [A-Z|a-z] also accepted a literal
    # "|", so OCR noise such as "x@ex.c|m" was reported as an address.
    'email': re.compile(
        r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', re.IGNORECASE
    ),
    # Optional scheme and "www." prefix, then a domain of one or more labels
    # followed by an alphabetic TLD. Group 1 is the bare domain, captured whole
    # for multi-label domains ("acme.co.uk", not "acme.co").
    'website': re.compile(
        r'(?:https?://)?(?:www\.)?([A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*\.[A-Za-z]{2,})'
    ),
    # Two or three capitalised ASCII words, e.g. "John Smith".
    'name': re.compile(r'^[A-Z][a-z]+(?:\s+[A-Z][a-z]+){1,2}$'),
}
# --------------------------
# Core Processing Functions
# --------------------------
def process_phone_number(raw_number: str, *, default_country: str = 'SAUDI') -> Optional[str]:
    """Validate a phone number and return it in international ``+<code>...`` form.

    Everything except digits and ``+`` is stripped, then the result is
    matched against each ``COUNTRY_CODES`` pattern.

    Args:
        raw_number: Number as it appears in the OCR text; spaces, dashes,
            dots and parentheses are allowed and removed.
        default_country: ``COUNTRY_CODES`` key tried first. A local number
            (leading ``0``) can match several countries' patterns, and the
            first match decides which calling code it receives. The default
            ``'SAUDI'`` keeps the previous behaviour, since that is the first
            entry of ``COUNTRY_CODES``.

    Returns:
        The normalised number, or ``None`` when no supported country's
        pattern matches.
    """
    cleaned = re.sub(r'[^\d+]', '', raw_number)
    # Stable sort: default_country first, the others in declaration order.
    ordered = sorted(COUNTRY_CODES.items(), key=lambda item: item[0] != default_country)
    for _country, config in ordered:
        if not re.match(config['pattern'], cleaned):
            continue
        if cleaned.startswith('0'):
            # Local trunk prefix: replace the "0" with the calling code.
            return f"{config['code']}{cleaned[1:]}"
        if cleaned.startswith('5'):
            # Bare subscriber number; reachable only if a pattern accepts one
            # (the current patterns require a leading "+" or "0").
            return f"{config['code']}{cleaned}"
        return cleaned  # already in international form
    return None
# Phone-number candidates: "+" with a 1-3 digit calling code, or a local
# trunk "0", followed by a "5" and eight more digits. A single space, dot or
# hyphen may separate the digits because OCR keeps the card's formatting
# ("+966 50 123 4567", "050-123-4567"), which a contiguous-digits pattern
# misses. The look-arounds stop a candidate from starting or ending inside a
# longer digit run. Validation and normalisation are left to
# process_phone_number.
_PHONE_CANDIDATE_RE = re.compile(
    r'(?<![\d+])(?:\+\d{1,3}|0)[\s.-]?5(?:[\s.-]?\d){8}(?!\d)'
)


def extract_contact_info(text: str) -> dict:
    """Extract phone numbers, e-mail addresses and websites from OCR text.

    Args:
        text: OCR output as a single string (process_business_card passes all
            OCR lines joined with spaces).

    Returns:
        Dict with any of the keys ``'phones'``, ``'emails'`` and
        ``'websites'``, each mapped to a list of unique values in order of
        first appearance. A key is omitted when nothing of that kind was
        found, so callers should use ``.get(key, [])``.
    """
    # Plain dicts serve as insertion-ordered sets: duplicates collapse and the
    # first item on the card stays first. A set gave an arbitrary order, which
    # made callers' "first e-mail" / "first website" choice random.
    phones = {}
    for match in _PHONE_CANDIDATE_RE.finditer(text):
        if processed := process_phone_number(match.group()):
            phones[processed] = None

    email_re = VALIDATION_PATTERNS['email']
    emails = dict.fromkeys(address.lower() for address in email_re.findall(text))

    # Blank out e-mail addresses before looking for websites; otherwise their
    # local part and domain were reported as websites ("jo.doe@acme.com"
    # produced "www.jo.doe" and "www.acme.com").
    text_without_emails = email_re.sub(' ', text)
    websites = dict.fromkeys(
        f"www.{match.group(1).lower()}"
        for match in VALIDATION_PATTERNS['website'].finditer(text_without_emails)
    )

    contacts = {'phones': phones, 'emails': emails, 'websites': websites}
    return {key: list(values) for key, values in contacts.items() if values}
def process_entities(entities: list, ocr_text: list) -> dict:
    """Map GLiNER entity predictions onto card fields, falling back to OCR for the name.

    Args:
        entities: Entity dicts from ``predict_entities``. Each is read for
            ``'label'`` and ``'text'``, and for ``'score'`` when present.
        ocr_text: OCR lines, scanned in order for a plausible person name when
            no ``'person name'`` entity passes validation.

    Returns:
        Dict with keys ``'name'``, ``'company'``, ``'title'`` and
        ``'address'``; a value stays ``None`` when nothing was found.
    """
    # Only the person name is re-cased. Job titles keep their OCR casing
    # because str.title() turned "CEO" into "Ceo" and "VP" into "Vp".
    field_for_label = {
        'person name': 'name',
        'company name': 'company',
        'job title': 'title',
        'address': 'address',
    }
    result = {'name': None, 'company': None, 'title': None, 'address': None}
    # Highest score accepted so far per field. When several spans share a
    # label the most confident one wins (ties keep the earlier span); before,
    # whichever span came last silently overwrote the others.
    best_score = {}

    for entity in entities:
        field = field_for_label.get(entity['label'].lower())
        if field is None:
            continue  # other labels (phone, e-mail, website) are not mapped here
        text = entity['text'].strip()
        if not text:
            continue
        if field == 'name':
            if not VALIDATION_PATTERNS['name'].match(text):
                continue
            text = text.title()
        score = entity.get('score', 0.0)
        if field in best_score and score <= best_score[field]:
            continue
        best_score[field] = score
        result[field] = text

    # Name fallback: the first OCR line that looks like a full name. Lines are
    # stripped because the name pattern is anchored at both ends.
    if result['name'] is None:
        for line in ocr_text:
            candidate = line.strip()
            if VALIDATION_PATTERNS['name'].match(candidate):
                result['name'] = candidate.title()
                break

    return result
# --------------------------
# Main Processing Pipeline
# --------------------------
# Entity labels requested from GLiNER for every card.
_ENTITY_LABELS = ["person name", "company name", "job title",
                  "phone number", "email address", "address",
                  "website"]

# Address fallback: OCR lines containing one of these words. They are matched
# as whole words so that words which merely contain "ave" ("Dave", "Travel")
# are not taken for an address; "avenue" is listed because \bave\b does not
# match it.
_ADDRESS_KEYWORD_RE = re.compile(r'\b(?:street|ave|avenue|road)\b', re.IGNORECASE)

# NOTE(review): this pipeline used a module-level ``gliner_model`` that is not
# defined anywhere in the visible source, so every request failed with
# NameError. This checkpoint is an assumption; confirm it is the model the demo
# was built with.
GLINER_MODEL_NAME = 'urchade/gliner_multi-v2.1'

# Models created on first use and reused across requests. Previously a new
# PaddleOCR engine was constructed on every call.
_MODEL_CACHE: dict = {}


def _get_ocr_engine():
    """Return the shared PaddleOCR engine, creating it on first use."""
    engine = _MODEL_CACHE.get('ocr')
    if engine is None:
        # use_angle_cls=True loads the text-direction classifier; without it
        # the cls=True passed to .ocr() has nothing to run.
        engine = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False)
        _MODEL_CACHE['ocr'] = engine
    return engine


def _get_gliner_model():
    """Return a module-level ``gliner_model`` if one exists, else a lazily loaded GLINER_MODEL_NAME."""
    model = globals().get('gliner_model')
    if model is None:
        model = _MODEL_CACHE.get('gliner')
    if model is None:
        model = GLiNER.from_pretrained(GLINER_MODEL_NAME)
        _MODEL_CACHE['gliner'] = model
    return model


def _run_ocr(image_array) -> list:
    """Return the recognised text lines for an image array, in PaddleOCR's output order."""
    ocr_result = _get_ocr_engine().ocr(image_array, cls=True)
    # PaddleOCR can return [None] for an image with no detected text.
    page = ocr_result[0] if ocr_result else None
    return [line[1][0] for line in page or []]


def _compile_results(entity_data: dict, contacts: dict, ocr_text: list, qr_data) -> dict:
    """Merge entity, regex and QR results into one record; key order is the CSV column order."""
    # extract_contact_info omits categories with no matches, hence .get().
    emails = contacts.get('emails', [])
    websites = contacts.get('websites', [])
    return {
        'Person Name': entity_data['name'],
        # Fallback: first label of the first e-mail's domain ("acme" in x@acme.com).
        'Company Name': entity_data['company'] or (
            emails[0].split('@')[1].split('.')[0].title() if emails else None
        ),
        'Job Title': entity_data['title'],
        'Phone Numbers': contacts.get('phones', []),
        'Email Addresses': emails,
        'Address': entity_data['address'] or next(
            (line for line in ocr_text if _ADDRESS_KEYWORD_RE.search(line)), None
        ),
        'Website': websites[0] if websites else None,
        'QR Code': qr_data[0].text if qr_data else None,
    }


def _write_csv(results: dict) -> str:
    """Write *results* as a one-row CSV to a kept temp file and return its path."""
    # List-valued fields are joined so cells read "a; b" instead of the Python
    # repr "['a', 'b']"; index=False drops pandas' unnamed index column.
    row = {
        key: '; '.join(value) if isinstance(value, list) else value
        for key, value in results.items()
    }
    # newline='' is what pandas expects for a text-mode handle passed to to_csv.
    with tempfile.NamedTemporaryFile(
        suffix='.csv', delete=False, mode='w', encoding='utf-8', newline=''
    ) as f:
        pd.DataFrame([row]).to_csv(f, index=False)
        return f.name


def process_business_card(img: Image.Image, confidence: float) -> tuple:
    """Run the full OCR + entity-extraction pipeline on one business-card image.

    Args:
        img: Uploaded card image; ``None`` when the user submits without one.
        confidence: Score threshold passed to GLiNER's ``predict_entities``.

    Returns:
        ``(ocr_text, structured_data, csv_path, error_message)``. On success
        the error message is ``""``. On failure the first three are ``""``,
        ``{}`` and ``None``, and the last carries the error text.
    """
    if img is None:
        return "", {}, None, "Error: no image uploaded"
    try:
        # Convert once to a 3-channel RGB array (RGBA, palette and greyscale
        # uploads included) and use it for both OCR and QR decoding.
        rgb = np.array(img.convert('RGB'))

        ocr_text = _run_ocr(rgb)
        full_text = " ".join(ocr_text)

        # Nothing to tag when OCR found no text.
        entities = (
            _get_gliner_model().predict_entities(full_text, _ENTITY_LABELS, threshold=confidence)
            if full_text else []
        )
        contacts = extract_contact_info(full_text)
        entity_data = process_entities(entities, ocr_text)
        qr_data = zxingcpp.read_barcodes(rgb)

        results = _compile_results(entity_data, contacts, ocr_text, qr_data)
        csv_path = _write_csv(results)
        return full_text, results, csv_path, ""
    except Exception as e:
        # Top-level UI boundary: log the traceback and surface the message in the UI.
        logger.exception("Processing Error")
        return "", {}, None, f"Error: {str(e)}"
# --------------------------
# Gradio Interface
# --------------------------
def _build_interface() -> gr.Interface:
    """Assemble the Gradio UI around process_business_card.

    The four output components line up positionally with the
    (ocr_text, structured_data, csv_path, error_message) tuple that
    process_business_card returns.
    """
    card_inputs = [
        gr.Image(type='pil', label='Upload Business Card'),
        gr.Slider(0.1, 1.0, value=0.4, label='Confidence Threshold'),
    ]
    card_outputs = [
        gr.Textbox(label='OCR Result'),
        gr.JSON(label='Structured Data'),
        gr.File(label='Download CSV'),
        gr.Textbox(label='Error Log'),
    ]
    return gr.Interface(
        fn=process_business_card,
        inputs=card_inputs,
        outputs=card_outputs,
        title='Enterprise Business Card Parser',
        description='Multi-country support with comprehensive validation',
    )


interface = _build_interface()

if __name__ == '__main__':
    interface.launch()