codic committed on
Commit
642c3ad
·
verified ·
1 Parent(s): 6bba885

Try fixing the website and QR code extraction (it was working before)

Browse files
Files changed (1) hide show
  1. app.py +92 -57
app.py CHANGED
@@ -18,10 +18,10 @@ import traceback
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
- # Set up GLiNER environment variables (adjust if needed)
22
  os.environ['GLINER_HOME'] = './gliner_models'
23
 
24
- # Load GLiNER model (do not change the model)
25
  try:
26
  logger.info("Loading GLiNER model...")
27
  gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
@@ -29,109 +29,144 @@ except Exception as e:
29
  logger.error("Failed to load GLiNER model")
30
  raise e
31
 
32
- # Get a random color (used for drawing bounding boxes, if needed)
def get_random_color():
    """Return a random three-channel color as a tuple of ints in [0, 255].

    ``inference`` calls this once per OCR box to pick an outline color for
    the optional bounding-box visualisation.
    """
    return tuple(np.random.randint(0, 256, 3).tolist())
36
 
37
- # Draw OCR bounding boxes (this function is kept for debugging/visualization purposes)
def draw_ocr_bbox(image, boxes, colors):
    """Outline every OCR box on a copy of the image (debug/visualisation aid).

    Args:
        image: Image to draw on; anything ``np.array`` can convert.
        boxes: Sequence of polygons, each reshaped to ``(-1, 1, 2)`` points.
        colors: One color per box, indexed in step with ``boxes``.

    Returns:
        A NumPy array with a closed, 2-pixel-wide polyline drawn per box.
        An array is returned even when ``boxes`` is empty.

    Raises:
        IndexError: If ``colors`` has fewer entries than ``boxes``.
    """
    # Convert once up front instead of re-copying the image on every iteration.
    canvas = np.array(image)
    for i, box in enumerate(boxes):
        points = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        canvas = cv2.polylines(canvas, [points], True, colors[i], 2)
    return canvas
43
 
44
- # Scan for a QR code using OpenCV's QRCodeDetector
def scan_qr_code(image):
    """Decode a QR code from the image using OpenCV's ``QRCodeDetector``.

    Args:
        image: Image to scan; converted with ``np.array`` unless it already
            is an ndarray.

    Returns:
        The decoded payload with surrounding whitespace stripped, or ``None``
        when nothing usable was decoded or the scan raised (the failure is
        logged with its traceback).
    """
    try:
        image_np = image if isinstance(image, np.ndarray) else np.array(image)
        data, _, _ = cv2.QRCodeDetector().detectAndDecode(image_np)
    except Exception as e:
        # Best-effort: a failed QR scan must not abort the whole extraction.
        logger.exception("QR code scanning failed: %s", e)
        return None
    payload = data.strip() if data else ""
    # A whitespace-only payload counts as "no QR code".
    return payload or None
57
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # Main inference function
def inference(img: Image.Image, confidence):
    """Extract text and contact entities from a business-card image.

    Runs English PaddleOCR on the image, labels the joined OCR text with
    GLiNER, adds the QR payload when ``scan_qr_code`` finds one, and writes
    the fields to a temporary CSV file.

    Args:
        img: Image to process.
        confidence: Threshold passed to ``gliner_model.predict_entities``.

    Returns:
        A 4-tuple ``(ocr_text, fields, csv_path, error)``. ``fields`` maps
        each title-cased label (plus ``"QR"`` when a code was decoded) to its
        values joined with ``"; "``. On failure the tuple is
        ``("", {}, None, "Error: ...")`` and the traceback is logged.
    """
    try:
        # CPU-only English OCR using the locally stored det/cls/rec models.
        ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
                        det_model_dir='./models/det/en',
                        cls_model_dir='./models/cls/en',
                        rec_model_dir='./models/rec/en')

        # The first page's result can be None when no text is detected
        # (PaddleOCR behaviour -- TODO confirm for the pinned version);
        # treat that as "no lines" instead of failing on iteration.
        result = ocr.ocr(np.array(img), cls=True)[0] or []
        ocr_text = " ".join(line[1][0] for line in result)

        labels = ["person name", "company name", "job title", "phone", "email", "address", "website"]
        results = {label.title(): [] for label in labels}

        # Nothing to label when OCR produced no text.
        entities = []
        if ocr_text.strip():
            entities = gliner_model.predict_entities(
                ocr_text, labels, threshold=confidence, flat_ner=True)
        for entity in entities:
            field = entity["label"].title()
            if field in results:
                results[field].append(entity["text"])

        qr_data = scan_qr_code(img)
        if qr_data:
            results["QR"] = [qr_data]

        fields = {k: "; ".join(v) for k, v in results.items()}

        # delete=False: the file must outlive this call so the UI can serve it.
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w",
                                         newline="", encoding="utf-8") as tmp_file:
            pd.DataFrame([fields]).to_csv(tmp_file, index=False)
            csv_path = tmp_file.name

        return ocr_text, fields, csv_path, ""
    except Exception as e:
        # UI boundary: report the failure in the error textbox instead of raising.
        logger.error("Processing failed: " + traceback.format_exc())
        return "", {}, None, f"Error: {str(e)}\n{traceback.format_exc()}"
108
 
# Gradio Interface setup (output structure remains unchanged)
title = 'Business Card Information Extractor'
description = 'Extracts text using PaddleOCR and entities using GLiNER (with added website label) along with QR code scanning.'

# Each example row is [image path, confidence threshold].
examples = [
    ['example_imgs/example.jpg', 0.5],
    ['example_imgs/demo003.jpeg', 0.7],
]

css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"

if __name__ == '__main__':
    # The four outputs line up with the 4-tuple returned by `inference`.
    demo = gr.Interface(
        inference,
        [gr.Image(type='pil', label='Upload Business Card'),
         gr.Slider(0.1, 1, 0.5, step=0.1, label='Confidence Threshold')],
        [gr.Textbox(label="Extracted Text"),
         gr.JSON(label="Entities"),
         gr.File(label="Download CSV"),
         gr.Textbox(label="Error Details")],
        title=title,
        description=description,
        examples=examples,
        css=css,
        cache_examples=True,
    )
    demo.queue(max_size=10)
    demo.launch()
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
+ # Set up GLiNER environment variables
22
  os.environ['GLINER_HOME'] = './gliner_models'
23
 
24
+ # Load GLiNER model
25
  try:
26
  logger.info("Loading GLiNER model...")
27
  gliner_model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
 
29
  logger.error("Failed to load GLiNER model")
30
  raise e
31
 
32
+ # Helper functions
def get_random_color():
    """Return a random three-channel color as a tuple of ints in [0, 255]."""
    # The committed line was missing its closing parenthesis (SyntaxError).
    return tuple(np.random.randint(0, 256, 3).tolist())
 
35
 
 
def draw_ocr_bbox(image, boxes, colors):
    """Outline every OCR box on a copy of the image.

    Args:
        image: Image to draw on; anything ``np.array`` can convert.
        boxes: Sequence of polygons, each reshaped to ``(-1, 1, 2)`` points.
        colors: One color per box, indexed in step with ``boxes``.

    Returns:
        A NumPy array with a closed, 2-pixel-wide polyline drawn per box.
        An array is returned even when ``boxes`` is empty.

    Raises:
        IndexError: If ``colors`` has fewer entries than ``boxes``.
    """
    # Convert once up front instead of re-copying the image on every iteration.
    canvas = np.array(image)
    for i, box in enumerate(boxes):
        points = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        canvas = cv2.polylines(canvas, [points], True, colors[i], 2)
    return canvas
41
 
 
def scan_qr_code(image):
    """Decode a QR code from the image using OpenCV's ``QRCodeDetector``.

    Args:
        image: Image to scan; converted with ``np.array``.

    Returns:
        The decoded payload with surrounding whitespace stripped, or ``None``
        when nothing usable was decoded or the scan raised (the failure is
        logged with its traceback).
    """
    try:
        image_np = np.array(image)
        data, _, _ = cv2.QRCodeDetector().detectAndDecode(image_np)
    except Exception as e:
        # Best-effort: a failed QR scan must not abort the whole extraction.
        logger.exception("QR scan failed: %s", e)
        return None
    payload = data.strip() if data else ""
    # A whitespace-only payload counts as "no QR code".
    return payload or None
51
 
def extract_emails(text):
    """Return every e-mail address found in ``text``, in order of appearance.

    Args:
        text: Text to search; ``None`` or an empty string yields ``[]``.

    Returns:
        A list of matched address strings (case preserved, duplicates kept).
    """
    if not text:
        return []
    # The TLD class is letters only: the previous ``[A-Z|a-z]`` also accepted
    # a literal ``|``, so "a@b.co|m" was captured whole.
    email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    return re.findall(email_regex, text)
55
+
def extract_websites(text):
    """Return website-looking tokens from ``text``, skipping e-mail addresses.

    A match is an optional ``http(s)://`` scheme, a dotted host name ending in
    an alphabetic label of two or more letters, and an optional ``/path``.

    Args:
        text: Text to search; ``None`` or an empty string yields ``[]``.

    Returns:
        Unique matches in order of first appearance, with trailing sentence
        punctuation removed.
    """
    if not text:
        return []
    website_regex = (
        # Must not start inside a token or right after '@' (an e-mail domain).
        r"(?<![A-Za-z0-9._%+@-])"
        r"(?:https?://)?"
        # One or more labels, so "www.acme.co.uk" is kept whole.
        r"(?:[A-Za-z0-9-]+\.)+[A-Za-z]{2,}\b"
        # Must not run into '@' (an e-mail local part such as "john.doe@").
        r"(?![A-Za-z0-9._%+-]*@)"
        r"(?:/\S*)?"
    )
    matches = (m.rstrip(".,;:!?)") for m in re.findall(website_regex, text))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(m for m in matches if m))
60
+
def clean_phone_number(phone):
    """Reduce a phone string to its ASCII digits, keeping one leading ``+``.

    Args:
        phone: Raw phone text, e.g. ``"+1 (555) 123-4567"``.

    Returns:
        The digits only, prefixed with ``+`` when the input starts with a plus
        sign (optionally behind whitespace or an opening parenthesis). Returns
        ``""`` when the input contains no digits.
    """
    if not phone:
        return ""
    digits = re.sub(r"[^0-9]", "", phone)
    if not digits:
        return ""
    # A '+' is only meaningful as the international prefix; the previous
    # pattern kept every '+' wherever it appeared (e.g. "12+34").
    has_prefix = re.match(r"[\s(]*\+", phone) is not None
    return f"+{digits}" if has_prefix else digits
63
+
64
  # Main inference function
def inference(img: Image.Image, confidence):
    """Extract structured contact fields from a business-card image.

    Runs English PaddleOCR on the image and labels the joined OCR text with
    GLiNER. It falls back to ``extract_emails`` / ``extract_websites`` when
    GLiNER found no e-mail or website, adds phone candidates found in the
    individual OCR lines, scans for a QR code, and writes the non-empty
    fields to a temporary CSV file.

    Args:
        img: Image to process.
        confidence: Threshold passed to ``gliner_model.predict_entities``.

    Returns:
        A 4-tuple ``(ocr_text, fields, csv_path, error)``. ``fields`` maps each
        non-empty field name to its values joined with ``"; "``. On failure
        the tuple is ``("", {}, None, "Error: ...")`` and the traceback is
        logged.
    """
    try:
        # Initialize PaddleOCR (CPU-only, English, locally stored models).
        # NOTE(review): the engine is rebuilt on every request; consider
        # creating it once at module level if start-up cost matters.
        ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False,
                        det_model_dir='./models/det/en',
                        cls_model_dir='./models/cls/en',
                        rec_model_dir='./models/rec/en')

        # OCR processing. The first page's result can be None when no text is
        # detected (PaddleOCR behaviour -- TODO confirm for the pinned
        # version); treat that as "no lines" instead of failing on iteration.
        result = ocr.ocr(np.array(img), cls=True)[0] or []
        ocr_texts = [line[1][0] for line in result]
        ocr_text = " ".join(ocr_texts)

        # Entity extraction
        labels = ["person name", "company name", "job title",
                  "phone number", "email address", "physical address",
                  "website url"]
        # Keys are the title-cased labels, in label order, plus the QR slot.
        results = {label.title(): [] for label in labels}
        results["QR Code"] = []

        # Nothing to label when OCR produced no text.
        entities = []
        if ocr_text.strip():
            entities = gliner_model.predict_entities(
                ocr_text, labels, threshold=confidence, flat_ner=True)

        # Process GLiNER results. The field name keeps its spaces so it
        # matches the keys of `results`; stripping them ("PersonName") made
        # the generic branch unreachable and silently dropped names,
        # companies, job titles and addresses.
        for entity in entities:
            field = entity["label"].title()
            value = entity["text"]
            if field == "Phone Number":
                value = clean_phone_number(value)
            elif field in ("Email Address", "Website Url"):
                value = value.lower()
            if value and field in results:
                results[field].append(value)

        # Regex fallbacks, used only when GLiNER found nothing.
        if not results["Email Address"]:
            results["Email Address"] = extract_emails(ocr_text)
        if not results["Website Url"]:
            results["Website Url"] = extract_websites(ocr_text)

        # Phone candidates from the raw OCR lines. Only runs with 7-15 digits
        # are kept (15 is the E.164 maximum), so street numbers and postal
        # codes are no longer reported as phone numbers.
        phone_pattern = re.compile(r"\+?\d[\d\s().-]{5,}\d")
        candidates = [clean_phone_number(match)
                      for line in ocr_texts
                      for match in phone_pattern.findall(line)]
        plausible = [c for c in candidates
                     if 7 <= sum(ch.isdigit() for ch in c) <= 15]
        # dict.fromkeys de-duplicates with a stable order (GLiNER hits first);
        # list(set(...)) shuffled the values between runs.
        results["Phone Number"] = list(
            dict.fromkeys(results["Phone Number"] + plausible))

        # QR code handling
        qr_data = scan_qr_code(img)
        if qr_data:
            results["QR Code"] = [qr_data]

        # Create CSV from the non-empty fields only.
        csv_data = {k: "; ".join(v) for k, v in results.items() if v}
        # delete=False: the file must outlive this call so the UI can serve it.
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w",
                                         newline="", encoding="utf-8") as tmp_file:
            pd.DataFrame([csv_data]).to_csv(tmp_file, index=False)
            csv_path = tmp_file.name

        return ocr_text, csv_data, csv_path, ""

    except Exception as e:
        # UI boundary: report the failure in the error textbox instead of raising.
        logger.error(f"Processing failed: {traceback.format_exc()}")
        return "", {}, None, f"Error: {str(e)}\n{traceback.format_exc()}"
143
 
# Gradio Interface
title = 'Enhanced Business Card Parser'
description = 'Extracts entities with combined AI and regex validation, including QR codes'

# Each example row is [image path, confidence threshold].
examples = [
    ['example_imgs/example.jpg', 0.4],
    ['example_imgs/demo003.jpeg', 0.5],
]

css = """.output_image, .input_image {height: 40rem !important; width: 100% !important;}
.gr-interface {max-width: 800px !important;}"""

if __name__ == '__main__':
    # The four outputs line up with the 4-tuple returned by `inference`.
    demo = gr.Interface(
        inference,
        [gr.Image(type='pil', label='Upload Business Card'),
         gr.Slider(0.1, 1, 0.4, step=0.1, label='Confidence Threshold')],
        [gr.Textbox(label="OCR Result"),
         gr.JSON(label="Structured Data"),
         gr.File(label="Download CSV"),
         gr.Textbox(label="Error Log")],
        title=title,
        description=description,
        examples=examples,
        css=css,
        cache_examples=True,
    )
    demo.queue(max_size=20)
    demo.launch()