from paddleocr import PaddleOCR
from gliner import GLiNER
from PIL import Image
import gradio as gr
import numpy as np
import logging
import os
import tempfile
import pandas as pd
import re
import traceback
import zxingcpp  # QR/barcode decoding

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment setup
os.environ['GLINER_HOME'] = './gliner_models'

# Output fields, in display/CSV order
FIELDS = ['Person Name', 'Company Name', 'Job Title', 'Phone Number',
          'Email Address', 'Address', 'Website', 'QR Code']

# Load models once at startup instead of on every request
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
    logger.info("Loading PaddleOCR...")
    ocr_engine = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False)
except Exception:
    logger.exception("Failed to load models")
    raise

# Regex patterns
EMAIL_REGEX = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
# Domain with optional scheme/www; multi-label domains such as acme.co.ae are allowed.
# The look-arounds stop it from matching the local part or domain of an e-mail address.
WEBSITE_REGEX = re.compile(
    r"(?<![@\w.-])(?:https?://)?(?:www\.)?((?:[A-Za-z0-9-]+\.)+[A-Za-z]{2,})(?![\w@-])"
)

# UAE phone country code
UAE_CODE = '+971'

# UAE numbers as usually printed: +971 / 00971 / 971 (optionally followed by "(0)"),
# or a local trunk "0", then a 1-2 digit area/mobile code and 3 + 4 digits,
# with optional space or hyphen separators.
PHONE_REGEX = re.compile(
    r"(?<![\d+])(?:(?:\+|00)?971[\s-]?(?:\(?0\)?[\s-]?)?|0)"
    r"\d{1,2}[\s-]?\d{3}[\s-]?\d{4}(?!\d)"
)

# Whole-word match, so "ave" no longer fires inside words like "have" or "travel"
ADDRESS_REGEX = re.compile(r"\b(?:block|street|ave|avenue|area|industrial|road)\b", re.IGNORECASE)


# Utility functions
def extract_emails(text: str) -> list[str]:
    return [e.lower() for e in EMAIL_REGEX.findall(text)]


def extract_websites(text: str) -> list[str]:
    return [m.lower() for m in WEBSITE_REGEX.findall(text)]


def normalize_website(url: str) -> str | None:
    """Return 'www.<domain>' for a URL or bare domain, or None if it isn't one."""
    u = re.sub(r"^[a-z]+://", "", url.strip().lower())  # drop http:// or https://
    u = u.removeprefix('www.').split('/')[0]             # drop leading www. and any path
    return f"www.{u}" if re.fullmatch(r"(?:[a-z0-9-]+\.)+[a-z]{2,}", u) else None


def clean_phone_number(phone: str) -> str | None:
    """Normalize a UAE number to +971 followed by its national number, or return None."""
    digits = re.sub(r"\D", "", phone)
    if digits.startswith('00'):    # international dialing prefix, e.g. 00971...
        digits = digits[2:]
    if digits.startswith('9710'):  # "+971 (0)50 ..." style: drop the redundant trunk 0
        digits = '971' + digits[4:]
    # 971 + 8-digit landline or 9-digit mobile national number
    if digits.startswith('971') and len(digits) in (11, 12):
        return '+' + digits
    # Local format with trunk 0: 0X XXX XXXX (landline) or 05X XXX XXXX (mobile)
    if digits.startswith('0') and len(digits) in (9, 10):
        return UAE_CODE + digits[1:]
    return None


def process_phone_numbers(text: str) -> list[str]:
    """Find and normalize every UAE phone number in free text."""
    found = (clean_phone_number(m.group()) for m in PHONE_REGEX.finditer(text))
    return list(dict.fromkeys(p for p in found if p))  # de-duplicate, keep order


# Address extraction
def extract_address(ocr_texts: list[str]) -> str | None:
    parts = [t for t in ocr_texts if ADDRESS_REGEX.search(t)]
    return " ".join(parts) if parts else None


# QR scanning
def scan_qr_code(image: Image.Image) -> str | None:
    """Decode the first barcode/QR code in the image, or return None."""
    try:
        rgb = np.array(image.convert('RGB'))
        # Direct decoding
        try:
            res = zxingcpp.read_barcodes(rgb)
            if res and res[0].text:
                return res[0].text.strip()
        except Exception:
            logger.warning("Direct QR decode failed", exc_info=True)
        # Fallback: pixels within `tol` of black become black and everything else
        # white, which can help with coloured or low-contrast codes
        tol = 50
        dark = np.all(rgb.astype(np.int16) <= tol, axis=-1)
        binary = np.where(dark, 0, 255).astype(np.uint8)
        res = zxingcpp.read_barcodes(binary)
        if res and res[0].text:
            return res[0].text.strip()
    except Exception:
        logger.exception("QR scan error")
    return None


# Deduplication
def deduplicate_data(results: dict[str, list[str]]) -> None:
    """Normalize and de-duplicate every field in place, preserving first-seen order."""
    def clean_list(items, normalizer=lambda x: x, split=True):
        seen, out = set(), []
        for raw in items:
            for part in (re.split(r'[;,]\s*', raw) if split else [raw]):
                p = part.strip()
                norm = normalizer(p) if p else None
                if norm and norm not in seen:
                    seen.add(norm)
                    out.append(norm)
        return out

    results['Email Address'] = clean_list(results.get('Email Address', []), str.lower)
    results['Website'] = clean_list(results.get('Website', []), normalize_website)
    results['Phone Number'] = clean_list(results.get('Phone Number', []), clean_phone_number)
    # Free-text fields may legitimately contain commas (e.g. addresses), so don't split them
    for key in ['Person Name', 'Company Name', 'Job Title', 'Address', 'QR Code']:
        results[key] = clean_list(results.get(key, []), split=False)


# Inference pipeline
def inference(img: Image.Image, confidence: float):
    try:
        if img is None:
            raise ValueError("No image uploaded")
        # PaddleOCR follows OpenCV's BGR channel order; convert('RGB') also drops alpha
        arr = np.ascontiguousarray(np.array(img.convert('RGB'))[:, :, ::-1])
        raw = ocr_engine.ocr(arr, cls=True)[0] or []  # may be None when no text is detected
        ocr_texts = [line[1][0] for line in raw]
        full_text = ' '.join(ocr_texts)

        labels = ['person name', 'company name', 'job title', 'phone number',
                  'email address', 'address', 'website']
        entities = gliner_model.predict_entities(full_text, labels, threshold=confidence, flat_ner=True)
        results = {k: [] for k in FIELDS}

        # Process NER entities, validating the structured ones
        for ent in entities:
            txt, lbl = ent['text'].strip(), ent['label'].lower()
            if lbl == 'person name':
                results['Person Name'].append(txt)
            elif lbl == 'company name':
                results['Company Name'].append(txt)
            elif lbl == 'job title':
                results['Job Title'].append(txt.title())
            elif lbl == 'phone number':
                if c := clean_phone_number(txt):
                    results['Phone Number'].append(c)
            elif lbl == 'email address' and EMAIL_REGEX.fullmatch(txt):
                results['Email Address'].append(txt.lower())
            elif lbl == 'website' and (m := WEBSITE_REGEX.search(txt)):
                if n := normalize_website(m.group(1)):
                    results['Website'].append(n)
            elif lbl == 'address':
                results['Address'].append(txt)

        # Regex fallbacks over the full OCR text
        results['Email Address'] += extract_emails(full_text)
        results['Website'] += extract_websites(full_text)
        results['Phone Number'] += process_phone_numbers(full_text)

        # QR code
        if qr := scan_qr_code(img):
            results['QR Code'].append(qr)

        # Address fallback
        if not results['Address'] and (addr := extract_address(ocr_texts)):
            results['Address'].append(addr)

        # Deduplicate all fields
        deduplicate_data(results)

        # Company fallback: derive a name from the first e-mail or website domain
        if not results['Company Name'] and (sources := results['Email Address'] or results['Website']):
            domain = sources[0].split('@')[-1].removeprefix('www.')  # avoid "Www" as company
            results['Company Name'].append(domain.split('.')[0].title())

        # Name fallback: first OCR line made of two or more capitalized words
        if not results['Person Name']:
            for t in ocr_texts:
                if re.fullmatch(r'[A-Z][a-z]+(?:\s+[A-Z][a-z]+)+', t.strip()):
                    results['Person Name'].append(t.strip())
                    break

        # Prepare CSV: one row, multiple values joined with "; "
        csv_map = {k: '; '.join(v) for k, v in results.items()}
        with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w',
                                         newline='', encoding='utf-8') as f:
            pd.DataFrame([csv_map]).to_csv(f, index=False)
            csv_path = f.name

        return full_text, results, csv_path, ''
    except Exception:
        err = traceback.format_exc()
        logger.exception("Processing failed")
        return '', {k: [] for k in FIELDS}, None, f"Error:\n{err}"


# Gradio Interface
if __name__ == '__main__':
    demo = gr.Interface(
        fn=inference,
        inputs=[
            gr.Image(type='pil', label='Upload Business Card'),
            gr.Slider(0.1, 1, 0.4, step=0.1, label='Confidence Threshold'),
        ],
        outputs=[
            gr.Textbox(label="OCR Result"),
            gr.JSON(label="Structured Data"),
            gr.File(label="Download CSV"),
            gr.Textbox(label="Error Log"),
        ],
        title='Enhanced Business Card Parser',
        description='Entity extraction with AI and regex validation (UAE-focused phone support)',
        css=".gr-interface {max-width: 800px !important;}",
    )
    demo.launch()