from paddleocr import PaddleOCR
from gliner import GLiNER
from PIL import Image
import gradio as gr
import numpy as np
import cv2
import logging
import os
import tempfile
import pandas as pd
import re
import traceback
import zxingcpp  # QR decoding

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment setup
os.environ['GLINER_HOME'] = './gliner_models'

# Load the GLiNER and OCR models once at startup, so they are not re-created per request
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
    logger.info("Initializing PaddleOCR engine...")
    ocr_engine = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False)
except Exception:
    logger.exception("Failed to load models")
    raise

# Regex patterns
EMAIL_REGEX = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
# Only the bare domain is captured, so findall() yields domains without scheme/'www.'
WEBSITE_REGEX = re.compile(r"(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})")

# UAE phone country code
UAE_CODE = '+971'

# Utility functions
def extract_emails(text: str) -> list[str]:
    return [e.lower() for e in EMAIL_REGEX.findall(text)]

def extract_websites(text: str) -> list[str]:
    return [m.lower() for m in WEBSITE_REGEX.findall(text)]

def normalize_website(url: str) -> str | None:
    # Strip the scheme and path, then canonicalize to 'www.<domain>'
    u = re.sub(r'^https?://', '', url.lower()).replace('www.', '').split('/')[0]
    return f"www.{u}" if re.match(r"^[a-z0-9-]+(?:\.[a-z0-9-]+)*\.[a-z]{2,}$", u) else None

# Phone cleaning: treat all local '0XXXXXXXXX' numbers as UAE
def clean_phone_number(phone: str) -> str | None:
    cleaned = re.sub(r"\D", "", phone)
    # Local UAE numbers (10 digits starting with 0)
    if len(cleaned) == 10 and cleaned.startswith('0'):
        return UAE_CODE + cleaned[1:]
    # International UAE numbers (12 digits starting with '971'), with or without
    # a leading '+'; always returned in normalized '+971XXXXXXXXX' form
    if len(cleaned) == 12 and cleaned.startswith('971'):
        return '+' + cleaned
    return None
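# Illustrative (hypothetical inputs):
#   clean_phone_number('0501234567')       -> '+971501234567'
#   clean_phone_number('+971 50 123 4567') -> '+971501234567'
#   clean_phone_number('12345')            -> None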

# Extract phone numbers from free text
def process_phone_numbers(text: str) -> list[str]:
    found = []
    # Match local '05XXXXXXXX' numbers or any 8-12 digit run (optionally '+'-prefixed);
    # clean_phone_number() then filters out non-UAE candidates
    for match in re.finditer(r'(?:05\d{8}|\+?\d{8,12})', text):
        if (c := clean_phone_number(match.group().strip())):
            found.append(c)
    return list(dict.fromkeys(found))  # dedupe while preserving order
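# e.g. process_phone_numbers('Tel 0501234567 / 971501234567') -> ['+971501234567']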

# Address extraction: join the OCR lines that contain address-like keywords
def extract_address(ocr_texts: list[str]) -> str | None:
    keywords = ["block", "street", "ave", "area", "industrial", "road"]
    parts = [t for t in ocr_texts if any(kw in t.lower() for kw in keywords)]
    return " ".join(parts) if parts else None

# QR scanning
def scan_qr_code(image: Image.Image) -> str | None:
    try:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image.save(tmp, format="PNG")
            path = tmp.name
        img_cv = cv2.imread(path)
        # First pass: decode the image as-is
        try:
            res = zxingcpp.read_barcodes(img_cv)
            if res and res[0].text:
                return res[0].text.strip()
        except Exception:
            logger.warning("Direct QR decode failed")
        # Fallback: recolor near-black pixels to pure black and everything else
        # to white, which helps with low-contrast or colored QR codes
        default_color = (0, 0, 0)
        tol = 50
        pix = list(image.convert('RGB').getdata())
        new_pix = [default_color if all(abs(p[i] - default_color[i]) <= tol for i in range(3))
                   else (255, 255, 255) for p in pix]
        img_conv = Image.new('RGB', image.size)
        img_conv.putdata(new_pix)
        cv2.imwrite(path + '_conv.png', cv2.cvtColor(np.array(img_conv), cv2.COLOR_RGB2BGR))
        res = zxingcpp.read_barcodes(cv2.imread(path + '_conv.png'))
        if res and res[0].text:
            return res[0].text.strip()
    except Exception:
        logger.exception("QR scan error")
    return None
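# Note: zxingcpp.read_barcodes can typically take a numpy array directly
# (e.g. np.array(image)), so the temp-file round-trip above is mainly kept
# for writing out the recolored fallback image.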

# Deduplication: split multi-valued strings, normalize, and drop repeats in place
def deduplicate_data(results: dict[str, list[str]]) -> None:
    def clean_list(items, normalizer=lambda x: x):
        seen, out = set(), []
        for raw in items:
            for part in re.split(r'[;,]\s*', raw):
                p = part.strip()
                if not p:
                    continue
                norm = normalizer(p)
                if norm and norm not in seen:
                    seen.add(norm)
                    out.append(norm)
        return out

    results['Email Address'] = clean_list(results.get('Email Address', []), lambda e: e.lower())
    results['Website'] = clean_list(results.get('Website', []), normalize_website)
    results['Phone Number'] = clean_list(results.get('Phone Number', []), clean_phone_number)
    # Plain string fields: trim and dedupe without normalization
    for key in ['Person Name', 'Company Name', 'Job Title', 'Address', 'QR Code']:
        seen, out = set(), []
        for v in results.get(key, []):
            vv = v.strip()
            if vv and vv not in seen:
                seen.add(vv)
                out.append(vv)
        results[key] = out
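# Illustrative: {'Email Address': ['A@x.com; a@x.com']} collapses to ['a@x.com']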

# Inference pipeline
def inference(img: Image.Image, confidence: float):
    try:
        # OCR the card image
        arr = np.array(img)
        raw = ocr_engine.ocr(arr, cls=True)[0] or []  # guard against empty OCR results
        ocr_texts = [ln[1][0] for ln in raw]
        full_text = ' '.join(ocr_texts)
        # Run NER over the concatenated OCR text
        labels = ['person name', 'company name', 'job title', 'phone number', 'email address', 'address', 'website']
        entities = gliner_model.predict_entities(full_text, labels, threshold=confidence, flat_ner=True)
        results = {k: [] for k in ['Person Name', 'Company Name', 'Job Title', 'Phone Number',
                                   'Email Address', 'Address', 'Website', 'QR Code']}
        # Process NER entities, validating with the regex helpers where possible
        for ent in entities:
            txt, lbl = ent['text'].strip(), ent['label'].lower()
            if lbl == 'person name':
                results['Person Name'].append(txt)
            elif lbl == 'company name':
                results['Company Name'].append(txt)
            elif lbl == 'job title':
                results['Job Title'].append(txt.title())
            elif lbl == 'phone number':
                if (c := clean_phone_number(txt)):
                    results['Phone Number'].append(c)
            elif lbl == 'email address' and EMAIL_REGEX.fullmatch(txt):
                results['Email Address'].append(txt.lower())
            elif lbl == 'website' and WEBSITE_REGEX.search(txt):
                if (n := normalize_website(txt)):
                    results['Website'].append(n)
            elif lbl == 'address':
                results['Address'].append(txt)
        # Regex fallbacks over the full text
        results['Email Address'] += extract_emails(full_text)
        results['Website'] += extract_websites(full_text)
        results['Phone Number'] += process_phone_numbers(full_text)
        # QR code
        if qr := scan_qr_code(img):
            results['QR Code'].append(qr)
        # Address fallback via keyword matching
        if not results['Address'] and (addr := extract_address(ocr_texts)):
            results['Address'].append(addr)
        # Deduplicate all fields
        deduplicate_data(results)
        # Company fallback: derive a name from the email or website domain
        if not results['Company Name'] and (dom := (results['Email Address'] or results['Website'])):
            domain = dom[0].split('@')[-1].removeprefix('www.').split('.')[0]
            results['Company Name'].append(domain.title())
        # Name fallback: first OCR line that looks like 'Firstname Lastname'
        if not results['Person Name']:
            for t in ocr_texts:
                if re.match(r'^(?:[A-Z][a-z]+\s?){2,}$', t):
                    results['Person Name'].append(t)
                    break
        # Prepare CSV download
        csv_map = {k: '; '.join(v) for k, v in results.items()}
        with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w') as f:
            pd.DataFrame([csv_map]).to_csv(f, index=False)
            csv_path = f.name
        return full_text, results, csv_path, ''
    except Exception:
        err = traceback.format_exc()
        logger.error(f"Processing failed: {err}")
        empty = {k: [] for k in ['Person Name', 'Company Name', 'Job Title', 'Phone Number',
                                 'Email Address', 'Address', 'Website', 'QR Code']}
        return '', empty, None, f"Error:\n{err}"
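# inference() returns (full_text, results_dict, csv_path, error_string); on failure
# the first three are ''/empty/None and the traceback is surfaced in the last slot.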

# Gradio Interface
if __name__ == '__main__':
    demo = gr.Interface(
        inference,
        [gr.Image(type='pil', label='Upload Business Card'),
         gr.Slider(0.1, 1, 0.4, step=0.1, label='Confidence Threshold')],
        [gr.Textbox(label="OCR Result"),
         gr.JSON(label="Structured Data"),
         gr.File(label="Download CSV"),
         gr.Textbox(label="Error Log")],
        title='Enhanced Business Card Parser',
        description='Entity extraction with AI and regex validation (UAE-focused phone support)',
        css=".gr-interface {max-width: 800px !important;}"
    )
    demo.launch()
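
# Quick smoke test without the UI (the sample path below is hypothetical):
#   text, data, csv_path, err = inference(Image.open('sample_card.png'), 0.4)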