codic committed on
Commit
cef9fba
·
verified ·
1 Parent(s): 21c5eee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -55
app.py CHANGED
@@ -27,14 +27,12 @@ except Exception:
27
  logger.exception("Failed to load GLiNER model")
28
  raise
29
 
30
- # Regex patterns for emails and websites
31
  EMAIL_REGEX = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
32
  WEBSITE_REGEX = re.compile(r"(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})")
33
 
34
- # Phone number constants and regex for Saudi/UAE support
35
- SAUDI_CODE = '+966'
36
  UAE_CODE = '+971'
37
- PHONE_REGEX = re.compile(r'^(?:\+9665\d{8}|\+9715\d{8}|05\d{8}|5\d{8})$')
38
 
39
  # Utility functions
40
  def extract_emails(text: str) -> list[str]:
@@ -47,50 +45,54 @@ def normalize_website(url: str) -> str | None:
47
  u = url.lower().replace('www.', '').split('/')[0]
48
  return f"www.{u}" if re.match(r"^[a-z0-9-]+\.[a-z]{2,}$", u) else None
49
 
 
 
50
  def clean_phone_number(phone: str) -> str | None:
51
- cleaned = re.sub(r"[^\d+]", "", phone)
52
- # International formats
53
- if cleaned.startswith(SAUDI_CODE + '5') and len(cleaned) == 12:
54
- return cleaned
55
- if cleaned.startswith(UAE_CODE + '5') and len(cleaned) == 12:
56
- return cleaned
57
- # Local to international
58
- if cleaned.startswith('05') and len(cleaned) == 10:
59
- # Determine country by leading digit after 0 (6 Saudi, 5 UAE)
60
- return (SAUDI_CODE if cleaned[1]=='5' and cleaned[1:2] == '5' else UAE_CODE) + cleaned[1:]
61
- if cleaned.startswith('5') and len(cleaned) == 9:
62
- return UAE_CODE + cleaned
63
- if cleaned.startswith('9665') and len(cleaned) == 12:
64
  return '+' + cleaned
 
 
 
65
  return None
66
 
 
 
67
  def process_phone_numbers(text: str) -> list[str]:
68
  found = []
69
- for match in re.finditer(r'(?:\+?\d{8,13}|05\d{8})', text):
 
70
  raw = match.group().strip()
71
  if (c := clean_phone_number(raw)):
72
  found.append(c)
73
  return list(set(found))
74
 
 
 
75
  def extract_address(ocr_texts: list[str]) -> str | None:
76
  keywords = ["block","street","ave","area","industrial","road"]
77
  parts = [t for t in ocr_texts if any(kw in t.lower() for kw in keywords)]
78
  return " ".join(parts) if parts else None
79
 
80
  # QR scanning
 
81
  def scan_qr_code(image: Image.Image) -> str | None:
82
  try:
83
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
84
  image.save(tmp, format="PNG")
85
  path = tmp.name
86
  img_cv = cv2.imread(path)
87
- # Direct decode
88
  try:
89
  res = zxingcpp.read_barcodes(img_cv)
90
  if res and res[0].text:
91
  return res[0].text.strip()
92
  except:
93
- logger.warning("Direct ZXing decode failed")
94
  # Fallback recolor
95
  default_color = (0, 0, 0)
96
  tol = 50
@@ -107,6 +109,7 @@ def scan_qr_code(image: Image.Image) -> str | None:
107
  return None
108
 
109
  # Deduplication
 
110
  def deduplicate_data(results: dict[str, list[str]]) -> None:
111
  def clean_list(items, normalizer=lambda x: x):
112
  seen = set(); out = []
@@ -118,11 +121,11 @@ def deduplicate_data(results: dict[str, list[str]]) -> None:
118
  if norm and norm not in seen:
119
  seen.add(norm); out.append(norm)
120
  return out
121
- # Normalize lists
122
  results['Email Address'] = clean_list(results.get('Email Address', []), lambda e: e.lower())
123
  results['Website'] = clean_list(results.get('Website', []), normalize_website)
124
  results['Phone Number'] = clean_list(results.get('Phone Number', []), clean_phone_number)
125
- # Others: simple dedupe
126
  for key in ['Person Name','Company Name','Job Title','Address','QR Code']:
127
  seen = set(); out = []
128
  for v in results.get(key, []):
@@ -134,10 +137,7 @@ def deduplicate_data(results: dict[str, list[str]]) -> None:
134
  # Inference pipeline
135
  def inference(img: Image.Image, confidence: float):
136
  try:
137
- ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
138
- det_model_dir='./models/det/en',
139
- cls_model_dir='./models/cls/en',
140
- rec_model_dir='./models/rec/en')
141
  arr = np.array(img)
142
  raw = ocr.ocr(arr, cls=True)[0]
143
  ocr_texts = [ln[1][0] for ln in raw]
@@ -147,63 +147,61 @@ def inference(img: Image.Image, confidence: float):
147
  entities = gliner_model.predict_entities(full_text, labels, threshold=confidence, flat_ner=True)
148
 
149
  results = {k: [] for k in ['Person Name','Company Name','Job Title','Phone Number','Email Address','Address','Website','QR Code']}
150
- # Entity processing
 
151
  for ent in entities:
152
  txt, lbl = ent['text'].strip(), ent['label'].lower()
153
- if lbl == 'person name':
154
- results['Person Name'].append(txt)
155
- elif lbl == 'company name':
156
- results['Company Name'].append(txt)
157
- elif lbl == 'job title':
158
- results['Job Title'].append(txt.title())
159
  elif lbl == 'phone number':
160
- if (c:=clean_phone_number(txt)):
161
- results['Phone Number'].append(c)
162
  elif lbl == 'email address' and EMAIL_REGEX.fullmatch(txt):
163
  results['Email Address'].append(txt.lower())
164
  elif lbl == 'website' and WEBSITE_REGEX.search(txt):
165
- if (n:=normalize_website(txt)):
166
- results['Website'].append(n)
167
- elif lbl == 'address':
168
- results['Address'].append(txt)
169
  # Regex fallbacks
170
  results['Email Address'] += extract_emails(full_text)
171
  results['Website'] += extract_websites(full_text)
172
- # Phone regex fallback
173
  results['Phone Number'] += process_phone_numbers(full_text)
 
174
  # QR code
175
  if qr := scan_qr_code(img):
176
  results['QR Code'].append(qr)
 
177
  # Address fallback
178
- if not results['Address']:
179
- if addr := extract_address(ocr_texts):
180
- results['Address'].append(addr)
181
- # Deduplicate
182
  deduplicate_data(results)
 
183
  # Company fallback
184
- if not results['Company Name']:
185
- if results['Email Address']:
186
- dom = results['Email Address'][0].split('@')[-1].split('.')[0]
187
- results['Company Name'].append(dom.title())
188
- elif results['Website']:
189
- dom = results['Website'][0].split('.')[1]
190
- results['Company Name'].append(dom.title())
191
  # Name fallback
192
  if not results['Person Name']:
193
  for t in ocr_texts:
194
  if re.match(r'^(?:[A-Z][a-z]+\s?){2,}$', t):
195
  results['Person Name'].append(t)
196
  break
197
- # Build CSV map including all keys
198
- csv_map = {k: '; '.join(v) for k,v in results.items()}
 
199
  with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w') as f:
200
  pd.DataFrame([csv_map]).to_csv(f, index=False)
201
  csv_path = f.name
 
202
  return full_text, results, csv_path, ''
203
  except Exception:
204
  err = traceback.format_exc()
205
  logger.error(f"Processing failed: {err}")
206
- return '', {k: [] for k in ['Person Name','Company Name','Job Title','Phone Number','Email Address','Address','Website','QR Code']}, None, f"Error:\n{err}"
 
207
 
208
  # Gradio Interface
209
  if __name__ == '__main__':
@@ -216,7 +214,8 @@ if __name__ == '__main__':
216
  gr.File(label="Download CSV"),
217
  gr.Textbox(label="Error Log")],
218
  title='Enhanced Business Card Parser',
219
- description='Accurate entity extraction with combined AI and regex validation (with Saudi/UAE support)',
220
  css=".gr-interface {max-width: 800px !important;}"
221
  )
222
  demo.launch()
 
 
27
  logger.exception("Failed to load GLiNER model")
28
  raise
29
 
30
+ # Regex patterns
31
  EMAIL_REGEX = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
32
  WEBSITE_REGEX = re.compile(r"(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})")
33
 
34
+ # UAE phone country code
 
35
  UAE_CODE = '+971'
 
36
 
37
  # Utility functions
38
  def extract_emails(text: str) -> list[str]:
 
45
  u = url.lower().replace('www.', '').split('/')[0]
46
  return f"www.{u}" if re.match(r"^[a-z0-9-]+\.[a-z]{2,}$", u) else None
47
 
48
+ # Phone cleaning: treat all local '0XXXXXXXXX' as UAE
49
+
50
  def clean_phone_number(phone: str) -> str | None:
51
+ cleaned = re.sub(r"\D", "", phone)
52
+ # Local UAE numbers (10 digits starting with 0)
53
+ if len(cleaned) == 10 and cleaned.startswith('0'):
54
+ return UAE_CODE + cleaned[1:]
55
+ # International UAE numbers without plus (12 digits starting '971')
56
+ if len(cleaned) == 12 and cleaned.startswith('971'):
 
 
 
 
 
 
 
57
  return '+' + cleaned
58
+ # Already plus-prefixed UAE number
59
+ if phone.strip().startswith('+971') and len(cleaned) == 12:
60
+ return phone.strip()
61
  return None
62
 
63
+ # Extract phone numbers from text
64
+
65
def process_phone_numbers(text: str) -> list[str]:
    """Extract and normalize all phone numbers found in free text.

    Scans for candidate digit runs ('05' + 8 digits, or an optional '+'
    followed by 8-12 digits) and keeps those that clean_phone_number accepts.

    Args:
        text: The full OCR text to scan.

    Returns:
        Unique normalized phone numbers in first-seen order. (The original
        used list(set(...)), which produced nondeterministic ordering;
        dict.fromkeys preserves insertion order while deduplicating.)
    """
    found = []
    # Match '05' followed by 8 digits, or a plus/plain international variant
    for match in re.finditer(r'(?:05\d{8}|\+?\d{8,12})', text):
        raw = match.group().strip()
        if (c := clean_phone_number(raw)):
            found.append(c)
    return list(dict.fromkeys(found))
73
 
74
+ # Address extraction
75
+
76
  def extract_address(ocr_texts: list[str]) -> str | None:
77
  keywords = ["block","street","ave","area","industrial","road"]
78
  parts = [t for t in ocr_texts if any(kw in t.lower() for kw in keywords)]
79
  return " ".join(parts) if parts else None
80
 
81
  # QR scanning
82
+
83
  def scan_qr_code(image: Image.Image) -> str | None:
84
  try:
85
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
86
  image.save(tmp, format="PNG")
87
  path = tmp.name
88
  img_cv = cv2.imread(path)
89
+ # Direct decoding
90
  try:
91
  res = zxingcpp.read_barcodes(img_cv)
92
  if res and res[0].text:
93
  return res[0].text.strip()
94
  except:
95
+ logger.warning("Direct QR decode failed")
96
  # Fallback recolor
97
  default_color = (0, 0, 0)
98
  tol = 50
 
109
  return None
110
 
111
  # Deduplication
112
+
113
  def deduplicate_data(results: dict[str, list[str]]) -> None:
114
  def clean_list(items, normalizer=lambda x: x):
115
  seen = set(); out = []
 
121
  if norm and norm not in seen:
122
  seen.add(norm); out.append(norm)
123
  return out
124
+
125
  results['Email Address'] = clean_list(results.get('Email Address', []), lambda e: e.lower())
126
  results['Website'] = clean_list(results.get('Website', []), normalize_website)
127
  results['Phone Number'] = clean_list(results.get('Phone Number', []), clean_phone_number)
128
+
129
  for key in ['Person Name','Company Name','Job Title','Address','QR Code']:
130
  seen = set(); out = []
131
  for v in results.get(key, []):
 
137
  # Inference pipeline
138
  def inference(img: Image.Image, confidence: float):
139
  try:
140
+ ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False)
 
 
 
141
  arr = np.array(img)
142
  raw = ocr.ocr(arr, cls=True)[0]
143
  ocr_texts = [ln[1][0] for ln in raw]
 
147
  entities = gliner_model.predict_entities(full_text, labels, threshold=confidence, flat_ner=True)
148
 
149
  results = {k: [] for k in ['Person Name','Company Name','Job Title','Phone Number','Email Address','Address','Website','QR Code']}
150
+
151
+ # Process NER entities
152
  for ent in entities:
153
  txt, lbl = ent['text'].strip(), ent['label'].lower()
154
+ if lbl == 'person name': results['Person Name'].append(txt)
155
+ elif lbl == 'company name': results['Company Name'].append(txt)
156
+ elif lbl == 'job title': results['Job Title'].append(txt.title())
 
 
 
157
  elif lbl == 'phone number':
158
+ if (c := clean_phone_number(txt)): results['Phone Number'].append(c)
 
159
  elif lbl == 'email address' and EMAIL_REGEX.fullmatch(txt):
160
  results['Email Address'].append(txt.lower())
161
  elif lbl == 'website' and WEBSITE_REGEX.search(txt):
162
+ if (n := normalize_website(txt)): results['Website'].append(n)
163
+ elif lbl == 'address': results['Address'].append(txt)
164
+
 
165
  # Regex fallbacks
166
  results['Email Address'] += extract_emails(full_text)
167
  results['Website'] += extract_websites(full_text)
 
168
  results['Phone Number'] += process_phone_numbers(full_text)
169
+
170
  # QR code
171
  if qr := scan_qr_code(img):
172
  results['QR Code'].append(qr)
173
+
174
  # Address fallback
175
+ if not results['Address'] and (addr := extract_address(ocr_texts)):
176
+ results['Address'].append(addr)
177
+
178
+ # Deduplicate all fields
179
  deduplicate_data(results)
180
+
181
  # Company fallback
182
+ if not results['Company Name'] and (dom := (results['Email Address'] or results['Website'])):
183
+ domain = dom[0].split('@')[-1].split('.')[0]
184
+ results['Company Name'].append(domain.title())
185
+
 
 
 
186
  # Name fallback
187
  if not results['Person Name']:
188
  for t in ocr_texts:
189
  if re.match(r'^(?:[A-Z][a-z]+\s?){2,}$', t):
190
  results['Person Name'].append(t)
191
  break
192
+
193
+ # Prepare CSV
194
+ csv_map = {k: '; '.join(v) for k, v in results.items()}
195
  with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w') as f:
196
  pd.DataFrame([csv_map]).to_csv(f, index=False)
197
  csv_path = f.name
198
+
199
  return full_text, results, csv_path, ''
200
  except Exception:
201
  err = traceback.format_exc()
202
  logger.error(f"Processing failed: {err}")
203
+ empty = {k: [] for k in ['Person Name','Company Name','Job Title','Phone Number','Email Address','Address','Website','QR Code']}
204
+ return '', empty, None, f"Error:\n{err}"
205
 
206
  # Gradio Interface
207
  if __name__ == '__main__':
 
214
  gr.File(label="Download CSV"),
215
  gr.Textbox(label="Error Log")],
216
  title='Enhanced Business Card Parser',
217
+ description='Entity extraction with AI and regex validation (UAE-focused phone support)',
218
  css=".gr-interface {max-width: 800px !important;}"
219
  )
220
  demo.launch()
221
+