Spaces:
Sleeping
Sleeping
update -- before was working
Browse files
app.py
CHANGED
@@ -3,243 +3,179 @@ from gliner import GLiNER
|
|
3 |
from PIL import Image
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
-
import cv2
|
7 |
import logging
|
8 |
-
import os
|
9 |
import tempfile
|
10 |
import pandas as pd
|
11 |
-
import io
|
12 |
import re
|
13 |
import traceback
|
14 |
-
import zxingcpp
|
15 |
|
16 |
-
#
|
|
|
|
|
17 |
logging.basicConfig(level=logging.INFO)
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
#
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
def
|
36 |
-
"""
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
try:
|
43 |
-
#
|
44 |
-
|
45 |
-
image.save(tmp, format="PNG")
|
46 |
-
tmp_path = tmp.name
|
47 |
|
48 |
-
#
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
results = zxingcpp.read_barcodes(img_cv)
|
53 |
-
if results and results[0].text:
|
54 |
-
return results[0].text.strip()
|
55 |
-
except Exception as e:
|
56 |
-
logger.warning(f"Direct zxingcpp decoding failed: {e}")
|
57 |
|
58 |
-
#
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
newData = []
|
64 |
-
# Convert hex default color to an RGB tuple
|
65 |
-
h1 = default_color.strip("#")
|
66 |
-
rgb_tup = tuple(int(h1[i:i+2], 16) for i in (0, 2, 4))
|
67 |
-
for item in datas:
|
68 |
-
# Check if the pixel is within the tolerance of the default color
|
69 |
-
if (item[0] in range(rgb_tup[0]-tolerance, rgb_tup[0]+tolerance) and
|
70 |
-
item[1] in range(rgb_tup[1]-tolerance, rgb_tup[1]+tolerance) and
|
71 |
-
item[2] in range(rgb_tup[2]-tolerance, rgb_tup[2]+tolerance)):
|
72 |
-
newData.append((0, 0, 0))
|
73 |
-
else:
|
74 |
-
newData.append((255, 255, 255))
|
75 |
-
qr_img.putdata(newData)
|
76 |
-
fallback_path = tmp_path + "_converted.png"
|
77 |
-
qr_img.save(fallback_path)
|
78 |
-
img_cv = cv2.imread(fallback_path)
|
79 |
-
try:
|
80 |
-
results = zxingcpp.read_barcodes(img_cv)
|
81 |
-
if results and results[0].text:
|
82 |
-
return results[0].text.strip()
|
83 |
-
except Exception as e:
|
84 |
-
logger.error(f"Fallback decoding failed: {e}")
|
85 |
-
return None
|
86 |
-
except Exception as e:
|
87 |
-
logger.error(f"QR scan failed: {str(e)}")
|
88 |
-
return None
|
89 |
-
|
90 |
-
def extract_emails(text):
|
91 |
-
email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
|
92 |
-
return re.findall(email_regex, text)
|
93 |
-
|
94 |
-
def extract_websites(text):
|
95 |
-
website_regex = r"\b(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})(?:/\S*)?\b"
|
96 |
-
matches = re.findall(website_regex, text)
|
97 |
-
return [m for m in matches if '@' not in m]
|
98 |
-
|
99 |
-
def clean_phone_number(phone):
|
100 |
-
cleaned = re.sub(r"(?!^\+)[^\d]", "", phone)
|
101 |
-
if len(cleaned) < 9 or (len(cleaned) == 9 and cleaned.startswith("+")):
|
102 |
-
return None
|
103 |
-
return cleaned
|
104 |
-
|
105 |
-
def normalize_website(url):
|
106 |
-
url = url.lower().replace("www.", "").split('/')[0]
|
107 |
-
if not re.match(r"^[a-z0-9-]+\.[a-z]{2,}$", url):
|
108 |
-
return None
|
109 |
-
return f"www.{url}"
|
110 |
-
|
111 |
-
def extract_address(ocr_texts):
|
112 |
-
address_keywords = ["block", "street", "ave", "area", "industrial", "road"]
|
113 |
-
address_parts = []
|
114 |
-
for text in ocr_texts:
|
115 |
-
if any(kw in text.lower() for kw in address_keywords):
|
116 |
-
address_parts.append(text)
|
117 |
-
return " ".join(address_parts) if address_parts else None
|
118 |
-
|
119 |
-
def inference(img: Image.Image, confidence):
|
120 |
-
try:
|
121 |
-
ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
|
122 |
-
det_model_dir='./models/det/en',
|
123 |
-
cls_model_dir='./models/cls/en',
|
124 |
-
rec_model_dir='./models/rec/en')
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
labels = ["person name", "company name", "job title",
|
132 |
-
"phone number", "email address", "address",
|
133 |
-
"website"]
|
134 |
-
entities = gliner_model.predict_entities(ocr_text, labels, threshold=confidence, flat_ner=True)
|
135 |
|
|
|
136 |
results = {
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
}
|
146 |
-
|
147 |
-
# Process entities with validation
|
148 |
-
for entity in entities:
|
149 |
-
text = entity["text"].strip()
|
150 |
-
label = entity["label"].lower()
|
151 |
-
|
152 |
-
if label == "phone number":
|
153 |
-
if (cleaned := clean_phone_number(text)):
|
154 |
-
results["Phone Number"].append(cleaned)
|
155 |
-
elif label == "email address" and "@" in text:
|
156 |
-
results["Email Address"].append(text.lower())
|
157 |
-
elif label == "website":
|
158 |
-
if (normalized := normalize_website(text)):
|
159 |
-
results["Website"].append(normalized)
|
160 |
-
elif label == "address":
|
161 |
-
results["Address"].append(text)
|
162 |
-
elif label == "company name":
|
163 |
-
results["Company Name"].append(text)
|
164 |
-
elif label == "person name":
|
165 |
-
results["Person Name"].append(text)
|
166 |
-
elif label == "job title":
|
167 |
-
results["Job Title"].append(text.title())
|
168 |
-
|
169 |
-
# Regex fallbacks
|
170 |
-
results["Email Address"] += extract_emails(ocr_text)
|
171 |
-
results["Website"] += [normalize_website(w) for w in extract_websites(ocr_text)]
|
172 |
|
173 |
-
#
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
# Address processing
|
182 |
-
if not results["Address"]:
|
183 |
-
if (address := extract_address(ocr_texts)):
|
184 |
-
results["Address"].append(address)
|
185 |
-
|
186 |
-
# Website normalization
|
187 |
-
seen_websites = set()
|
188 |
-
final_websites = []
|
189 |
-
for web in results["Website"]:
|
190 |
-
if web and web not in seen_websites:
|
191 |
-
final_websites.append(web)
|
192 |
-
seen_websites.add(web)
|
193 |
-
results["Website"] = final_websites
|
194 |
-
|
195 |
-
# Company name fallback
|
196 |
-
if not results["Company Name"]:
|
197 |
-
if results["Email Address"]:
|
198 |
-
domain = results["Email Address"][0].split('@')[-1].split('.')[0]
|
199 |
-
results["Company Name"].append(domain.title())
|
200 |
-
elif results["Website"]:
|
201 |
-
domain = results["Website"][0].split('.')[1]
|
202 |
-
results["Company Name"].append(domain.title())
|
203 |
-
|
204 |
-
# Name fallback
|
205 |
-
if not results["Person Name"]:
|
206 |
-
for text in ocr_texts:
|
207 |
-
if re.match(r"^(?:[A-Z][a-z]+\s?){2,}$", text):
|
208 |
-
results["Person Name"].append(text)
|
209 |
-
break
|
210 |
-
|
211 |
-
# QR Code scanning using the new zxingcpp-based function
|
212 |
-
if (qr_data := scan_qr_code(img)):
|
213 |
-
results["QR Code"].append(qr_data)
|
214 |
-
|
215 |
-
# Create CSV file containing the results
|
216 |
-
csv_data = {k: "; ".join(v) for k, v in results.items() if v}
|
217 |
-
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") as tmp_file:
|
218 |
-
pd.DataFrame([csv_data]).to_csv(tmp_file, index=False)
|
219 |
-
csv_path = tmp_file.name
|
220 |
-
|
221 |
-
return ocr_text, csv_data, csv_path, ""
|
222 |
-
|
223 |
except Exception as e:
|
224 |
-
logger.error(f"Processing
|
225 |
-
return "", {}, None, f"Error: {str(e)}
|
226 |
|
|
|
227 |
# Gradio Interface
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
[
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
|
|
243 |
)
|
244 |
|
245 |
-
|
|
|
|
3 |
from PIL import Image
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
|
|
6 |
import logging
|
|
|
7 |
import tempfile
|
8 |
import pandas as pd
|
|
|
9 |
import re
|
10 |
import traceback
|
11 |
+
import zxingcpp
|
12 |
|
13 |
+
# --------------------------
|
14 |
+
# Configuration & Constants
|
15 |
+
# --------------------------
|
16 |
logging.basicConfig(level=logging.INFO)
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
19 |
+
# --------------------------
# Configuration & Constants
# --------------------------

# Supported dialing plans: international prefix plus a regex accepting either
# the international form (+9665XXXXXXXX / +9715XXXXXXXX) or the local form
# (05XXXXXXXX). NOTE(review): the local '05...' alternative is identical for
# both countries, so a local number is always attributed to the first entry
# that matches in dict order (SAUDI) — confirm this default is intended.
COUNTRY_CODES = {
    'SAUDI': {'code': '+966', 'pattern': r'^(\+9665\d{8}|05\d{8})$'},
    'UAE': {'code': '+971', 'pattern': r'^(\+9715\d{8}|05\d{8})$'}
}

# Pre-compiled patterns shared by the extraction pipeline.
VALIDATION_PATTERNS = {
    'email': re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', re.IGNORECASE),
    'website': re.compile(r'(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})'),
    'name': re.compile(r'^[A-Z][a-z]+(?:\s+[A-Z][a-z]+){1,2}$')
}

# --------------------------
# Core Processing Functions
# --------------------------

def process_phone_number(raw_number: str) -> "str | None":
    """Validate and standardize a phone number for the supported countries.

    Strips every character except digits and '+', then tests the result
    against each country's pattern. A local number (leading '0') is rewritten
    to international form; an already-international number is returned as-is.

    Args:
        raw_number: Phone number as extracted from OCR text (may contain
            spaces, dashes, parentheses, etc.).

    Returns:
        The normalized number string, or None when no country pattern matches.
        (The previous annotation claimed ``str``, but the no-match path has
        always returned None.)
    """
    cleaned = re.sub(r'[^\d+]', '', raw_number)

    for config in COUNTRY_CODES.values():
        if re.match(config['pattern'], cleaned):
            if cleaned.startswith('0'):
                # Local form: replace the leading 0 with the country prefix.
                return f"{config['code']}{cleaned[1:]}"
            # A match that does not start with '0' necessarily starts with
            # '+<code>' already, so it needs no rewriting. (The old
            # startswith('5') branch was unreachable for the same reason.)
            return cleaned
    return None
|
46 |
+
|
47 |
+
def extract_contact_info(text: str) -> dict:
    """Extract and validate phone numbers, emails and websites from raw text.

    Args:
        text: Full OCR text of the business card.

    Returns:
        A dict that ALWAYS contains the keys ``'phones'``, ``'emails'`` and
        ``'websites'``, each a sorted, de-duplicated list (possibly empty).
        The previous version dropped empty categories entirely, which made
        callers that index e.g. ``result['phones']`` crash with KeyError on
        sparse business cards; sorting also makes downstream JSON/CSV output
        deterministic (set iteration order is not).
    """
    phones, emails, websites = set(), set(), set()

    # Phone numbers: long digit runs (optionally '+'-prefixed) or the local
    # 05XXXXXXXX form; each candidate is validated/normalized before keeping.
    for match in re.finditer(r'(\+?\d{10,13}|05\d{8})', text):
        if processed := process_phone_number(match.group()):
            phones.add(processed)

    # Emails via the shared pre-compiled pattern, lower-cased for dedup.
    emails.update(
        email.lower() for email in VALIDATION_PATTERNS['email'].findall(text)
    )

    # Websites: bare domain, lower-cased, normalized to a 'www.' prefix.
    for match in VALIDATION_PATTERNS['website'].finditer(text):
        domain = match.group(1).lower()
        if '.' in domain:
            websites.add(f"www.{domain.split('/')[0]}")

    return {
        'phones': sorted(phones),
        'emails': sorted(emails),
        'websites': sorted(websites),
    }
|
72 |
+
|
73 |
+
def process_entities(entities: list, ocr_text: list) -> dict:
    """Turn GLiNER entity predictions into a structured person/company record.

    Args:
        entities: GLiNER predictions, each a dict with 'label' and 'text'.
        ocr_text: Individual OCR lines, used as a fallback source for the
            person name when no valid name entity was predicted.

    Returns:
        Dict with keys 'name', 'company', 'title', 'address' (values may be
        None). When several entities share a label, the last one wins.
    """
    # Labels that map straight onto a result field without extra validation.
    plain_fields = {'company name': 'company', 'job title': 'title', 'address': 'address'}
    record = dict.fromkeys(('name', 'company', 'title', 'address'))

    for prediction in entities:
        label = prediction['label'].lower()
        value = prediction['text'].strip()

        if label == 'person name':
            # Names must additionally look like "Firstname Lastname".
            if VALIDATION_PATTERNS['name'].match(value):
                record['name'] = value.title()
        elif label in plain_fields:
            # Job titles are title-cased for presentation; the rest kept raw.
            record[plain_fields[label]] = value.title() if label == 'job title' else value

    # Fallback: first OCR line that matches the name shape.
    if record['name'] is None:
        record['name'] = next(
            (line.title() for line in ocr_text if VALIDATION_PATTERNS['name'].match(line)),
            None,
        )

    return record
|
104 |
+
|
105 |
+
# --------------------------
|
106 |
+
# Main Processing Pipeline
|
107 |
+
# --------------------------
|
108 |
+
|
109 |
+
def process_business_card(img: Image.Image, confidence: float) -> tuple:
    """Full processing pipeline for business card images.

    Runs OCR, GLiNER entity extraction, regex-based contact extraction and
    QR decoding, then writes the consolidated record to a temporary CSV.

    Args:
        img: Uploaded business-card image.
        confidence: GLiNER entity-confidence threshold (from the UI slider).

    Returns:
        Tuple of (full OCR text, results dict, CSV file path, error message).
        On failure returns ("", {}, None, "Error: ...").
    """
    try:
        # Initialize OCR (CPU only).
        ocr_engine = PaddleOCR(lang='en', use_gpu=False)

        # OCR Processing: one recognized string per detected text line.
        # PaddleOCR can yield an empty/None page result when nothing is
        # detected — guard before iterating, instead of crashing.
        ocr_result = ocr_engine.ocr(np.array(img), cls=True)
        page = ocr_result[0] if ocr_result else None
        ocr_text = [line[1][0] for line in page] if page else []
        full_text = " ".join(ocr_text)

        # Entity Recognition
        labels = ["person name", "company name", "job title",
                  "phone number", "email address", "address",
                  "website"]
        entities = gliner_model.predict_entities(full_text, labels, threshold=confidence)

        # Data Extraction
        contacts = extract_contact_info(full_text)
        entity_data = process_entities(entities, ocr_text)
        qr_data = zxingcpp.read_barcodes(np.array(img.convert('RGB')))

        # Defensive lookups: extract_contact_info may omit empty categories,
        # so never index the contacts dict directly (previously a KeyError
        # on any card missing a phone/email/website).
        phones = contacts.get('phones', [])
        emails = contacts.get('emails', [])
        websites = contacts.get('websites', [])

        # Compile Final Results
        results = {
            'Person Name': entity_data['name'],
            # Company fallback: derive from the first email's domain.
            'Company Name': entity_data['company'] or (
                emails[0].split('@')[1].split('.')[0].title()
                if emails else None
            ),
            'Job Title': entity_data['title'],
            'Phone Numbers': phones,
            'Email Addresses': emails,
            # Address fallback: first OCR line containing a street keyword.
            'Address': entity_data['address'] or next(
                (t for t in ocr_text if any(kw in t.lower()
                                            for kw in {'street', 'ave', 'road'})), None
            ),
            'Website': websites[0] if websites else None,
            'QR Code': qr_data[0].text if qr_data else None
        }

        # Generate CSV Output. index=False keeps the spurious pandas
        # row-index column out of the user-facing download.
        with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w') as f:
            pd.DataFrame([results]).to_csv(f, index=False)
            csv_path = f.name

        return full_text, results, csv_path, ""

    except Exception as e:
        logger.error(f"Processing Error: {traceback.format_exc()}")
        return "", {}, None, f"Error: {str(e)}"
|
159 |
|
160 |
+
# --------------------------
|
161 |
# Gradio Interface
|
162 |
+
# --------------------------
|
163 |
+
|
164 |
+
# UI wiring: the parser takes an image plus a confidence slider and surfaces
# the raw OCR text, the structured record, a CSV download and any error log.
_card_input = gr.Image(type='pil', label='Upload Business Card')
_threshold_input = gr.Slider(0.1, 1.0, value=0.4, label='Confidence Threshold')

_result_outputs = [
    gr.Textbox(label='OCR Result'),
    gr.JSON(label='Structured Data'),
    gr.File(label='Download CSV'),
    gr.Textbox(label='Error Log'),
]

interface = gr.Interface(
    fn=process_business_card,
    inputs=[_card_input, _threshold_input],
    outputs=_result_outputs,
    title='Enterprise Business Card Parser',
    description='Multi-country support with comprehensive validation',
)

if __name__ == '__main__':
    interface.launch()
|