# Spaces: Sleeping
from paddleocr import PaddleOCR | |
from gliner import GLiNER | |
from PIL import Image | |
import gradio as gr | |
import numpy as np | |
import cv2 | |
import logging | |
import os | |
import tempfile | |
import pandas as pd | |
import io | |
import re | |
import traceback | |
import zxingcpp # Added zxingcpp for QR decoding | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up GLiNER environment variables
# NOTE(review): GLINER_HOME is assigned after `gliner` has already been
# imported above; confirm the library reads it lazily rather than at import.
os.environ['GLINER_HOME'] = './gliner_models'

# Load GLiNER model once at import time. A failure is logged together with its
# traceback and re-raised unchanged, because nothing below works without it.
try:
    logger.info("Loading GLiNER model...")
    gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
except Exception:
    logger.exception("Failed to load GLiNER model")
    raise
def get_random_color():
    """Return a random colour as a tuple of three Python ints in 0..255.

    Meant for drawing bounding boxes, should they be needed.
    """
    red, green, blue = np.random.randint(0, 256, 3).tolist()
    return (red, green, blue)
def _first_barcode_text(img_cv):
    """Return the stripped text of the first decoded barcode that has any, else None."""
    for barcode in zxingcpp.read_barcodes(img_cv):
        text = barcode.text.strip() if barcode.text else ""
        if text:
            return text
    return None


def scan_qr_code(image, qr_color="#000000", tolerance=50):
    """
    Attempts to scan a QR code from the given PIL image using zxingcpp.

    The image is decoded entirely in memory (no temporary files are written).
    If direct decoding fails, a fallback pass binarises the image: every pixel
    whose R, G and B values all lie in ``[c - tolerance, c + tolerance)`` of the
    corresponding ``qr_color`` channel becomes black, everything else white,
    and decoding is retried on that image.

    Args:
        image: PIL image to scan.
        qr_color: Hex colour (``"#RRGGBB"``) of the QR modules used by the
            fallback pass. Defaults to black.
        tolerance: Per-channel tolerance around ``qr_color``. Defaults to 50.

    Returns:
        The decoded text with surrounding whitespace removed, or ``None`` when
        nothing could be decoded or any step failed (failures are logged).
    """
    try:
        rgb = np.array(image.convert("RGB"))
        # Hand zxingcpp a 3-channel BGR array, the layout cv2.imread produces.
        img_cv = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

        # First attempt: direct decoding with zxingcpp
        try:
            if (text := _first_barcode_text(img_cv)):
                return text
        except Exception as e:
            logger.warning(f"Direct zxingcpp decoding failed: {e}")

        # Fallback: keep only pixels close to the expected QR colour.
        hex_digits = qr_color.lstrip("#")
        target = np.array(
            [int(hex_digits[i:i + 2], 16) for i in (0, 2, 4)], dtype=np.int16
        )
        # int16 so that `target - tolerance` may go below 0 without wrapping.
        pixels = rgb.astype(np.int16)
        near_target = np.all(
            (pixels >= target - tolerance) & (pixels < target + tolerance),
            axis=-1,
        )
        binary = np.where(near_target, 0, 255).astype(np.uint8)
        fallback_cv = cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)

        try:
            return _first_barcode_text(fallback_cv)
        except Exception as e:
            logger.error(f"Fallback decoding failed: {e}")
        return None
    except Exception as e:
        logger.error(f"QR scan failed: {str(e)}")
        return None
def extract_emails(text):
    """Return every e-mail address found in *text*, in order of appearance.

    The top-level domain must consist of at least two ASCII letters; a ``|``
    is not accepted as part of it.
    """
    email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    return re.findall(email_regex, text)
def extract_websites(text):
    """Return the ``name.tld`` part of every website-like token in *text*.

    An optional ``http(s)://`` scheme, an optional ``www.`` prefix and any
    ``/path`` are matched but not returned. Only one label before the TLD is
    captured, so ``www.example.co.uk`` yields ``example.co``.

    E-mail addresses are blanked out before matching, and any remaining match
    that directly touches an ``@`` is skipped, so neither the local part nor
    the domain of an address is reported as a website.
    """
    email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    website_regex = r"\b(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})(?:/\S*)?\b"

    # Replace addresses with a space so the neighbouring tokens stay separated.
    searchable = re.sub(email_regex, " ", text)

    websites = []
    for match in re.finditer(website_regex, searchable):
        start, end = match.span(1)
        # Catches address-like leftovers the e-mail pattern did not remove.
        if searchable[start - 1:start] == "@" or searchable[end:end + 1] == "@":
            continue
        websites.append(match.group(1))
    return websites
def clean_phone_number(phone):
    """Reduce *phone* to its digits, keeping a ``+`` only if it is the first character.

    Returns ``None`` when fewer than nine digits remain.
    """
    digits = "".join(re.findall(r"\d", phone))
    if len(digits) < 9:
        return None
    prefix = "+" if phone.startswith("+") else ""
    return prefix + digits
def normalize_website(url):
    """Normalise a website string to the form ``www.<name>.<tld>``.

    Surrounding whitespace, letter case, a leading ``http://``/``https://``
    scheme, leading ``www.`` prefixes and anything from the first ``/``, ``?``
    or ``#`` onwards are discarded. Only a leading ``www.`` is removed, so a
    host that merely contains that text (``awww.com``) is left intact.

    Returns:
        The normalised string, or ``None`` when what remains is not a single
        ``name.tld`` pair made of ``a-z``, ``0-9`` and ``-``.
    """
    host = re.sub(r"^(?:https?://)?(?:www\.)*", "", url.strip().lower())
    host = re.split(r"[/?#]", host, maxsplit=1)[0]
    if not re.fullmatch(r"[a-z0-9-]+\.[a-z]{2,}", host):
        return None
    return f"www.{host}"
def extract_address(ocr_texts):
    """Join, with single spaces, every OCR line that contains an address keyword.

    Matching is a case-insensitive substring test against a fixed keyword
    list. Returns ``None`` when no line matches.
    """
    keywords = ("block", "street", "ave", "area", "industrial", "road")
    matching = [
        line
        for line in ocr_texts
        if any(keyword in line.lower() for keyword in keywords)
    ]
    if not matching:
        return None
    return " ".join(matching)
def _unique_in_order(items):
    """Return the truthy entries of *items*, dropping repeats but keeping first-seen order."""
    seen = set()
    unique = []
    for item in items:
        if item and item not in seen:
            seen.add(item)
            unique.append(item)
    return unique


def inference(img: Image.Image, confidence):
    """
    Run OCR and entity extraction on a business-card image.

    The OCR text is passed to GLiNER; its entities are validated and then
    complemented by regex fallbacks for e-mails, websites and phone numbers,
    keyword/pattern fallbacks for address, company and person name, and a QR
    code scan of the original image.

    Args:
        img: PIL image of the card, or ``None`` when nothing was supplied.
        confidence: Threshold handed to ``gliner_model.predict_entities``.

    Returns:
        A 4-tuple ``(ocr_text, csv_data, csv_path, error)``. ``csv_data`` maps
        each non-empty field to its values joined by ``"; "``; ``csv_path`` is
        a temporary CSV file holding the same row. On failure the tuple is
        ``("", {}, None, <error message>)``.
    """
    if img is None:
        return "", {}, None, "Error: no image provided"
    try:
        # NOTE(review): the OCR engine is rebuilt on every call; reusing one
        # instance would be faster but needs a check for concurrent requests.
        ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
                        det_model_dir='./models/det/en',
                        cls_model_dir='./models/cls/en',
                        rec_model_dir='./models/rec/en')
        # Normalise palette / grayscale / alpha images to 3 channels.
        img_np = np.array(img.convert("RGB"))
        result = ocr.ocr(img_np, cls=True)
        # The first page is None when no text was detected.
        lines = (result[0] if result else None) or []
        ocr_texts = [line[1][0] for line in lines]
        ocr_text = " ".join(ocr_texts)

        labels = ["person name", "company name", "job title",
                  "phone number", "email address", "address",
                  "website"]
        if ocr_text.strip():
            entities = gliner_model.predict_entities(
                ocr_text, labels, threshold=confidence, flat_ner=True
            )
        else:
            entities = []

        results = {
            "Person Name": [],
            "Company Name": [],
            "Job Title": [],
            "Phone Number": [],
            "Email Address": [],
            "Address": [],
            "Website": [],
            "QR Code": []
        }

        # Process entities with validation
        for entity in entities:
            text = entity["text"].strip()
            label = entity["label"].lower()
            if label == "phone number":
                if (cleaned := clean_phone_number(text)):
                    results["Phone Number"].append(cleaned)
            elif label == "email address" and "@" in text:
                results["Email Address"].append(text.lower())
            elif label == "website":
                if (normalized := normalize_website(text)):
                    results["Website"].append(normalized)
            elif label == "address":
                results["Address"].append(text)
            elif label == "company name":
                results["Company Name"].append(text)
            elif label == "person name":
                results["Person Name"].append(text)
            elif label == "job title":
                results["Job Title"].append(text.title())

        # Regex fallbacks, merged with the model output. Lower-casing the
        # regex hits lets them de-duplicate against the GLiNER ones.
        results["Email Address"] = _unique_in_order(
            results["Email Address"]
            + [email.lower() for email in extract_emails(ocr_text)]
        )
        results["Website"] = _unique_in_order(
            results["Website"]
            + [normalize_website(w) for w in extract_websites(ocr_text)]
        )

        # Phone number validation; order of first appearance is preserved.
        phone_candidates = (
            results["Phone Number"] + re.findall(r'\+\d{8,}|\d{9,}', ocr_text)
        )
        results["Phone Number"] = _unique_in_order(
            clean_phone_number(phone) for phone in phone_candidates
        )

        # Address processing
        if not results["Address"]:
            if (address := extract_address(ocr_texts)):
                results["Address"].append(address)

        # Company name fallback: derive it from the e-mail or website domain.
        if not results["Company Name"]:
            if results["Email Address"]:
                domain = results["Email Address"][0].split('@')[-1].split('.')[0]
                results["Company Name"].append(domain.title())
            elif results["Website"]:
                # Websites are stored as "www.<name>.<tld>", so index 1 is the name.
                domain = results["Website"][0].split('.')[1]
                results["Company Name"].append(domain.title())

        # Name fallback: first OCR line made of two or more capitalised words.
        if not results["Person Name"]:
            for text in ocr_texts:
                if re.match(r"^(?:[A-Z][a-z]+\s?){2,}$", text):
                    results["Person Name"].append(text)
                    break

        # QR Code scanning using the zxingcpp-based function
        if (qr_data := scan_qr_code(img)):
            results["QR Code"].append(qr_data)

        # Create CSV file containing the results. The file must outlive this
        # call so it can be offered for download, hence delete=False.
        csv_data = {k: "; ".join(v) for k, v in results.items() if v}
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w",
                                         newline="", encoding="utf-8") as tmp_file:
            pd.DataFrame([csv_data]).to_csv(tmp_file, index=False)
            csv_path = tmp_file.name
        return ocr_text, csv_data, csv_path, ""
    except Exception as e:
        details = traceback.format_exc()
        logger.error(f"Processing failed: {details}")
        return "", {}, None, f"Error: {str(e)}\n{details}"
# Gradio Interface: page title and description shown above the form.
title = 'Enhanced Business Card Parser'
description = 'Accurate entity extraction with combined AI and regex validation'

# Build and launch the UI only when this file is run as a script.
if __name__ == '__main__':
    demo = gr.Interface(
        fn=inference,
        # Inputs map, in order, to inference(img, confidence).
        inputs=[
            gr.Image(type='pil', label='Upload Business Card'),
            gr.Slider(0.1, 1, 0.4, step=0.1, label='Confidence Threshold'),
        ],
        # Outputs map, in order, to the 4-tuple returned by inference.
        outputs=[
            gr.Textbox(label="OCR Result"),
            gr.JSON(label="Structured Data"),
            gr.File(label="Download CSV"),
            gr.Textbox(label="Error Log"),
        ],
        title=title,
        description=description,
        css=".gr-interface {max-width: 800px !important;}",
    )
    demo.launch()