codic committed on
Commit
c66181c
·
verified ·
1 Parent(s): 97de021

update -- before was working

Browse files
Files changed (1) hide show
  1. app.py +154 -218
app.py CHANGED
@@ -3,243 +3,179 @@ from gliner import GLiNER
3
  from PIL import Image
4
  import gradio as gr
5
  import numpy as np
6
- import cv2
7
  import logging
8
- import os
9
  import tempfile
10
  import pandas as pd
11
- import io
12
  import re
13
  import traceback
14
- import zxingcpp # Added zxingcpp for QR decoding
15
 
16
- # Configure logging
 
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
20
- # Set up GLiNER environment variables
21
- os.environ['GLINER_HOME'] = './gliner_models'
22
-
23
- # Load GLiNER model
24
- try:
25
- logger.info("Loading GLiNER model...")
26
- gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
27
- except Exception as e:
28
- logger.error("Failed to load GLiNER model")
29
- raise e
30
-
31
- # Get a random color (used for drawing bounding boxes, if needed)
32
- def get_random_color():
33
- return tuple(np.random.randint(0, 256, 3).tolist())
34
-
35
- def scan_qr_code(image):
36
- """
37
- Attempts to scan a QR code from the given PIL image using zxingcpp.
38
- The image is first saved to a temporary file to be read by zxingcpp.
39
- If the direct decoding fails, the function tries a fallback
40
- where the image is converted based on a default QR color (black) and tolerance.
41
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  try:
43
- # Save the PIL image to a temporary file
44
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
45
- image.save(tmp, format="PNG")
46
- tmp_path = tmp.name
47
 
48
- # Convert the saved image to a CV2 image
49
- img_cv = cv2.imread(tmp_path)
50
- # First attempt: direct decoding with zxingcpp
51
- try:
52
- results = zxingcpp.read_barcodes(img_cv)
53
- if results and results[0].text:
54
- return results[0].text.strip()
55
- except Exception as e:
56
- logger.warning(f"Direct zxingcpp decoding failed: {e}")
57
 
58
- # Fallback: Process image by converting specific QR colors with default parameters.
59
- default_color = "#000000" # Default QR color assumed (black)
60
- tolerance = 50 # Fixed tolerance value
61
- qr_img = image.convert("RGB")
62
- datas = list(qr_img.getdata())
63
- newData = []
64
- # Convert hex default color to an RGB tuple
65
- h1 = default_color.strip("#")
66
- rgb_tup = tuple(int(h1[i:i+2], 16) for i in (0, 2, 4))
67
- for item in datas:
68
- # Check if the pixel is within the tolerance of the default color
69
- if (item[0] in range(rgb_tup[0]-tolerance, rgb_tup[0]+tolerance) and
70
- item[1] in range(rgb_tup[1]-tolerance, rgb_tup[1]+tolerance) and
71
- item[2] in range(rgb_tup[2]-tolerance, rgb_tup[2]+tolerance)):
72
- newData.append((0, 0, 0))
73
- else:
74
- newData.append((255, 255, 255))
75
- qr_img.putdata(newData)
76
- fallback_path = tmp_path + "_converted.png"
77
- qr_img.save(fallback_path)
78
- img_cv = cv2.imread(fallback_path)
79
- try:
80
- results = zxingcpp.read_barcodes(img_cv)
81
- if results and results[0].text:
82
- return results[0].text.strip()
83
- except Exception as e:
84
- logger.error(f"Fallback decoding failed: {e}")
85
- return None
86
- except Exception as e:
87
- logger.error(f"QR scan failed: {str(e)}")
88
- return None
89
-
90
- def extract_emails(text):
91
- email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
92
- return re.findall(email_regex, text)
93
-
94
- def extract_websites(text):
95
- website_regex = r"\b(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})(?:/\S*)?\b"
96
- matches = re.findall(website_regex, text)
97
- return [m for m in matches if '@' not in m]
98
-
99
- def clean_phone_number(phone):
100
- cleaned = re.sub(r"(?!^\+)[^\d]", "", phone)
101
- if len(cleaned) < 9 or (len(cleaned) == 9 and cleaned.startswith("+")):
102
- return None
103
- return cleaned
104
-
105
- def normalize_website(url):
106
- url = url.lower().replace("www.", "").split('/')[0]
107
- if not re.match(r"^[a-z0-9-]+\.[a-z]{2,}$", url):
108
- return None
109
- return f"www.{url}"
110
-
111
- def extract_address(ocr_texts):
112
- address_keywords = ["block", "street", "ave", "area", "industrial", "road"]
113
- address_parts = []
114
- for text in ocr_texts:
115
- if any(kw in text.lower() for kw in address_keywords):
116
- address_parts.append(text)
117
- return " ".join(address_parts) if address_parts else None
118
-
119
- def inference(img: Image.Image, confidence):
120
- try:
121
- ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
122
- det_model_dir='./models/det/en',
123
- cls_model_dir='./models/cls/en',
124
- rec_model_dir='./models/rec/en')
125
 
126
- img_np = np.array(img)
127
- result = ocr.ocr(img_np, cls=True)[0]
128
- ocr_texts = [line[1][0] for line in result]
129
- ocr_text = " ".join(ocr_texts)
130
-
131
- labels = ["person name", "company name", "job title",
132
- "phone number", "email address", "address",
133
- "website"]
134
- entities = gliner_model.predict_entities(ocr_text, labels, threshold=confidence, flat_ner=True)
135
 
 
136
  results = {
137
- "Person Name": [],
138
- "Company Name": [],
139
- "Job Title": [],
140
- "Phone Number": [],
141
- "Email Address": [],
142
- "Address": [],
143
- "Website": [],
144
- "QR Code": []
 
 
 
 
 
 
145
  }
146
-
147
- # Process entities with validation
148
- for entity in entities:
149
- text = entity["text"].strip()
150
- label = entity["label"].lower()
151
-
152
- if label == "phone number":
153
- if (cleaned := clean_phone_number(text)):
154
- results["Phone Number"].append(cleaned)
155
- elif label == "email address" and "@" in text:
156
- results["Email Address"].append(text.lower())
157
- elif label == "website":
158
- if (normalized := normalize_website(text)):
159
- results["Website"].append(normalized)
160
- elif label == "address":
161
- results["Address"].append(text)
162
- elif label == "company name":
163
- results["Company Name"].append(text)
164
- elif label == "person name":
165
- results["Person Name"].append(text)
166
- elif label == "job title":
167
- results["Job Title"].append(text.title())
168
-
169
- # Regex fallbacks
170
- results["Email Address"] += extract_emails(ocr_text)
171
- results["Website"] += [normalize_website(w) for w in extract_websites(ocr_text)]
172
 
173
- # Phone number validation
174
- seen_phones = set()
175
- for phone in results["Phone Number"] + re.findall(r'\+\d{8,}|\d{9,}', ocr_text):
176
- if (cleaned := clean_phone_number(phone)) and cleaned not in seen_phones:
177
- results["Phone Number"].append(cleaned)
178
- seen_phones.add(cleaned)
179
- results["Phone Number"] = list(seen_phones)
180
-
181
- # Address processing
182
- if not results["Address"]:
183
- if (address := extract_address(ocr_texts)):
184
- results["Address"].append(address)
185
-
186
- # Website normalization
187
- seen_websites = set()
188
- final_websites = []
189
- for web in results["Website"]:
190
- if web and web not in seen_websites:
191
- final_websites.append(web)
192
- seen_websites.add(web)
193
- results["Website"] = final_websites
194
-
195
- # Company name fallback
196
- if not results["Company Name"]:
197
- if results["Email Address"]:
198
- domain = results["Email Address"][0].split('@')[-1].split('.')[0]
199
- results["Company Name"].append(domain.title())
200
- elif results["Website"]:
201
- domain = results["Website"][0].split('.')[1]
202
- results["Company Name"].append(domain.title())
203
-
204
- # Name fallback
205
- if not results["Person Name"]:
206
- for text in ocr_texts:
207
- if re.match(r"^(?:[A-Z][a-z]+\s?){2,}$", text):
208
- results["Person Name"].append(text)
209
- break
210
-
211
- # QR Code scanning using the new zxingcpp-based function
212
- if (qr_data := scan_qr_code(img)):
213
- results["QR Code"].append(qr_data)
214
-
215
- # Create CSV file containing the results
216
- csv_data = {k: "; ".join(v) for k, v in results.items() if v}
217
- with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") as tmp_file:
218
- pd.DataFrame([csv_data]).to_csv(tmp_file, index=False)
219
- csv_path = tmp_file.name
220
-
221
- return ocr_text, csv_data, csv_path, ""
222
-
223
  except Exception as e:
224
- logger.error(f"Processing failed: {traceback.format_exc()}")
225
- return "", {}, None, f"Error: {str(e)}\n{traceback.format_exc()}"
226
 
 
227
  # Gradio Interface
228
- title = 'Enhanced Business Card Parser'
229
- description = 'Accurate entity extraction with combined AI and regex validation'
230
-
231
- if __name__ == '__main__':
232
- demo = gr.Interface(
233
- inference,
234
- [gr.Image(type='pil', label='Upload Business Card'),
235
- gr.Slider(0.1, 1, 0.4, step=0.1, label='Confidence Threshold')],
236
- [gr.Textbox(label="OCR Result"),
237
- gr.JSON(label="Structured Data"),
238
- gr.File(label="Download CSV"),
239
- gr.Textbox(label="Error Log")],
240
- title=title,
241
- description=description,
242
- css=".gr-interface {max-width: 800px !important;}"
 
243
  )
244
 
245
- demo.launch()
 
 
3
  from PIL import Image
4
  import gradio as gr
5
  import numpy as np
 
6
  import logging
 
7
  import tempfile
8
  import pandas as pd
 
9
  import re
10
  import traceback
11
+ import zxingcpp
12
 
13
+ # --------------------------
14
+ # Configuration & Constants
15
+ # --------------------------
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
+ COUNTRY_CODES = {
20
+ 'SAUDI': {'code': '+966', 'pattern': r'^(\+9665\d{8}|05\d{8})$'},
21
+ 'UAE': {'code': '+971', 'pattern': r'^(\+9715\d{8}|05\d{8})$'}
22
+ }
23
+
24
+ VALIDATION_PATTERNS = {
25
+ 'email': re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', re.IGNORECASE),
26
+ 'website': re.compile(r'(?:https?://)?(?:www\.)?([A-Za-z0-9-]+\.[A-Za-z]{2,})'),
27
+ 'name': re.compile(r'^[A-Z][a-z]+(?:\s+[A-Z][a-z]+){1,2}$')
28
+ }
29
+
30
+ # --------------------------
31
+ # Core Processing Functions
32
+ # --------------------------
33
+
34
+ def process_phone_number(raw_number: str) -> str:
35
+ """Validate and standardize phone numbers for supported countries"""
36
+ cleaned = re.sub(r'[^\d+]', '', raw_number)
37
+
38
+ for country, config in COUNTRY_CODES.items():
39
+ if re.match(config['pattern'], cleaned):
40
+ if cleaned.startswith('0'):
41
+ return f"{config['code']}{cleaned[1:]}"
42
+ if cleaned.startswith('5'):
43
+ return f"{config['code']}{cleaned}"
44
+ return cleaned
45
+ return None
46
+
47
+ def extract_contact_info(text: str) -> dict:
48
+ """Extract and validate all contact information from text"""
49
+ contacts = {
50
+ 'phones': set(),
51
+ 'emails': set(),
52
+ 'websites': set()
53
+ }
54
+
55
+ # Phone number extraction
56
+ for match in re.finditer(r'(\+?\d{10,13}|05\d{8})', text):
57
+ if processed := process_phone_number(match.group()):
58
+ contacts['phones'].add(processed)
59
+
60
+ # Email validation
61
+ contacts['emails'].update(
62
+ email.lower() for email in VALIDATION_PATTERNS['email'].findall(text)
63
+ )
64
+
65
+ # Website normalization
66
+ for match in VALIDATION_PATTERNS['website'].finditer(text):
67
+ domain = match.group(1).lower()
68
+ if '.' in domain:
69
+ contacts['websites'].add(f"www.{domain.split('/')[0]}")
70
+
71
+ return {k: list(v) for k, v in contacts.items() if v}
72
+
73
+ def process_entities(entities: list, ocr_text: list) -> dict:
74
+ """Process GLiNER entities with validation and fallbacks"""
75
+ result = {
76
+ 'name': None,
77
+ 'company': None,
78
+ 'title': None,
79
+ 'address': None
80
+ }
81
+
82
+ # Entity extraction
83
+ for entity in entities:
84
+ label = entity['label'].lower()
85
+ text = entity['text'].strip()
86
+
87
+ if label == 'person name' and VALIDATION_PATTERNS['name'].match(text):
88
+ result['name'] = text.title()
89
+ elif label == 'company name':
90
+ result['company'] = text
91
+ elif label == 'job title':
92
+ result['title'] = text.title()
93
+ elif label == 'address':
94
+ result['address'] = text
95
+
96
+ # Name fallback from OCR text
97
+ if not result['name']:
98
+ for text in ocr_text:
99
+ if VALIDATION_PATTERNS['name'].match(text):
100
+ result['name'] = text.title()
101
+ break
102
+
103
+ return result
104
+
105
+ # --------------------------
106
+ # Main Processing Pipeline
107
+ # --------------------------
108
+
109
+ def process_business_card(img: Image.Image, confidence: float) -> tuple:
110
+ """Full processing pipeline for business card images"""
111
  try:
112
+ # Initialize OCR
113
+ ocr_engine = PaddleOCR(lang='en', use_gpu=False)
 
 
114
 
115
+ # OCR Processing
116
+ ocr_result = ocr_engine.ocr(np.array(img), cls=True)
117
+ ocr_text = [line[1][0] for line in ocr_result[0]]
118
+ full_text = " ".join(ocr_text)
 
 
 
 
 
119
 
120
+ # Entity Recognition
121
+ labels = ["person name", "company name", "job title",
122
+ "phone number", "email address", "address",
123
+ "website"]
124
+ entities = gliner_model.predict_entities(full_text, labels, threshold=confidence)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # Data Extraction
127
+ contacts = extract_contact_info(full_text)
128
+ entity_data = process_entities(entities, ocr_text)
129
+ qr_data = zxingcpp.read_barcodes(np.array(img.convert('RGB')))
 
 
 
 
 
130
 
131
+ # Compile Final Results
132
  results = {
133
+ 'Person Name': entity_data['name'],
134
+ 'Company Name': entity_data['company'] or (
135
+ contacts['emails'][0].split('@')[1].split('.')[0].title()
136
+ if contacts['emails'] else None
137
+ ),
138
+ 'Job Title': entity_data['title'],
139
+ 'Phone Numbers': contacts['phones'],
140
+ 'Email Addresses': contacts['emails'],
141
+ 'Address': entity_data['address'] or next(
142
+ (t for t in ocr_text if any(kw in t.lower()
143
+ for kw in {'street', 'ave', 'road'})), None
144
+ ),
145
+ 'Website': contacts['websites'][0] if contacts['websites'] else None,
146
+ 'QR Code': qr_data[0].text if qr_data else None
147
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ # Generate CSV Output
150
+ with tempfile.NamedTemporaryFile(suffix='.csv', delete=False, mode='w') as f:
151
+ pd.DataFrame([results]).to_csv(f)
152
+ csv_path = f.name
153
+
154
+ return full_text, results, csv_path, ""
155
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  except Exception as e:
157
+ logger.error(f"Processing Error: {traceback.format_exc()}")
158
+ return "", {}, None, f"Error: {str(e)}"
159
 
160
+ # --------------------------
161
  # Gradio Interface
162
+ # --------------------------
163
+
164
+ interface = gr.Interface(
165
+ fn=process_business_card,
166
+ inputs=[
167
+ gr.Image(type='pil', label='Upload Business Card'),
168
+ gr.Slider(0.1, 1.0, value=0.4, label='Confidence Threshold')
169
+ ],
170
+ outputs=[
171
+ gr.Textbox(label='OCR Result'),
172
+ gr.JSON(label='Structured Data'),
173
+ gr.File(label='Download CSV'),
174
+ gr.Textbox(label='Error Log')
175
+ ],
176
+ title='Enterprise Business Card Parser',
177
+ description='Multi-country support with comprehensive validation'
178
  )
179
 
180
+ if __name__ == '__main__':
181
+ interface.launch()