Spaces:
Runtime error
Runtime error
Commit
·
7bd8456
1
Parent(s):
da5809b
Add application file
Browse files- easyocr.py +579 -0
- model.py +24 -0
easyocr.py
ADDED
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from .recognition import get_recognizer, get_text
|
4 |
+
from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\
|
5 |
+
download_and_unzip, printProgressBar, diff, reformat_input,\
|
6 |
+
make_rotated_img_list, set_result_with_confidence,\
|
7 |
+
reformat_input_batched, merge_to_free
|
8 |
+
from .config import *
|
9 |
+
from bidi import get_display
|
10 |
+
import numpy as np
|
11 |
+
import cv2
|
12 |
+
import torch
|
13 |
+
import os
|
14 |
+
import sys
|
15 |
+
from PIL import Image
|
16 |
+
from logging import getLogger
|
17 |
+
import yaml
|
18 |
+
import json
|
19 |
+
|
20 |
+
if sys.version_info[0] == 2:
|
21 |
+
from io import open
|
22 |
+
from six.moves.urllib.request import urlretrieve
|
23 |
+
from pathlib2 import Path
|
24 |
+
else:
|
25 |
+
from urllib.request import urlretrieve
|
26 |
+
from pathlib import Path
|
27 |
+
|
28 |
+
LOGGER = getLogger(__name__)
|
29 |
+
|
30 |
+
class Reader(object):
    """EasyOCR front end: bundles a text detector and a text recognizer.

    Holds the selected device, the model storage paths, the language/character
    configuration, and the loaded detector/recognizer networks.
    """

    def __init__(self, lang_list, gpu=True, model_storage_directory=None,
                 user_network_directory=None, detect_network="craft",
                 recog_network='standard', download_enabled=True,
                 detector=True, recognizer=True, verbose=True,
                 quantize=True, cudnn_benchmark=False):
        """Create an EasyOCR Reader

        Parameters:
            lang_list (list): Language codes (ISO 639) for languages to be recognized during analysis.

            gpu (bool): Enable GPU support (default)

            model_storage_directory (string): Path to directory for model data. If not specified,
            models will be read from a directory as defined by the environment variable
            EASYOCR_MODULE_PATH (preferred), MODULE_PATH (if defined), or ~/.EasyOCR/.

            user_network_directory (string): Path to directory for custom network architecture.
            If not specified, it is as defined by the environment variable
            EASYOCR_MODULE_PATH (preferred), MODULE_PATH (if defined), or ~/.EasyOCR/.

            download_enabled (bool): Enabled downloading of model data via HTTP (default).
        """
        self.verbose = verbose
        self.download_enabled = download_enabled

        self.model_storage_directory = MODULE_PATH + '/model'
        if model_storage_directory:
            self.model_storage_directory = model_storage_directory
        Path(self.model_storage_directory).mkdir(parents=True, exist_ok=True)

        self.user_network_directory = MODULE_PATH + '/user_network'
        if user_network_directory:
            self.user_network_directory = user_network_directory
        Path(self.user_network_directory).mkdir(parents=True, exist_ok=True)
        # Custom recognition networks (*.py next to the *.yaml) are imported
        # by name, so their directory must be on sys.path.
        sys.path.append(self.user_network_directory)

        # Device selection: explicit False -> CPU, True -> best available
        # accelerator, any other value (e.g. 'cuda:1') is used verbatim.
        if gpu is False:
            self.device = 'cpu'
            if verbose:
                LOGGER.warning('Using CPU. Note: This module is much faster with a GPU.')
        elif gpu is True:
            if torch.cuda.is_available():
                self.device = 'cuda'
            elif torch.backends.mps.is_available():
                self.device = 'mps'
            else:
                self.device = 'cpu'
                if verbose:
                    LOGGER.warning('Neither CUDA nor MPS are available - defaulting to CPU. Note: This module is much faster with a GPU.')
        else:
            self.device = gpu

        self.detection_models = detection_models
        self.recognition_models = recognition_models

        # check and download detection model
        self.support_detection_network = ['craft', 'dbnet18']
        # BUG FIX: the original line was "self.quantize=quantize," — the
        # trailing comma silently stored the 1-tuple (quantize,).  Since
        # (False,) is truthy, quantize=False could never actually disable
        # quantization when self.quantize was consumed downstream.
        self.quantize = quantize
        self.cudnn_benchmark = cudnn_benchmark
        if detector:
            detector_path = self.getDetectorPath(detect_network)

        # recognition model
        separator_list = {}

        if recog_network in ['standard'] + [model for model in recognition_models['gen1']] + [model for model in recognition_models['gen2']]:
            if recog_network in [model for model in recognition_models['gen1']]:
                model = recognition_models['gen1'][recog_network]
                recog_network = 'generation1'
                self.model_lang = model['model_script']
            elif recog_network in [model for model in recognition_models['gen2']]:
                model = recognition_models['gen2'][recog_network]
                recog_network = 'generation2'
                self.model_lang = model['model_script']
            else: # auto-detect from lang_list
                unknown_lang = set(lang_list) - set(all_lang_list)
                if unknown_lang != set():
                    raise ValueError(unknown_lang, 'is not supported')
                # choose recognition model
                if lang_list == ['en']:
                    self.setModelLanguage('english', lang_list, ['en'], '["en"]')
                    model = recognition_models['gen2']['english_g2']
                    recog_network = 'generation2'
                elif 'th' in lang_list:
                    self.setModelLanguage('thai', lang_list, ['th','en'], '["th","en"]')
                    model = recognition_models['gen1']['thai_g1']
                    recog_network = 'generation1'
                elif 'ch_tra' in lang_list:
                    self.setModelLanguage('chinese_tra', lang_list, ['ch_tra','en'], '["ch_tra","en"]')
                    model = recognition_models['gen1']['zh_tra_g1']
                    recog_network = 'generation1'
                elif 'ch_sim' in lang_list:
                    self.setModelLanguage('chinese_sim', lang_list, ['ch_sim','en'], '["ch_sim","en"]')
                    model = recognition_models['gen2']['zh_sim_g2']
                    recog_network = 'generation2'
                elif 'ja' in lang_list:
                    self.setModelLanguage('japanese', lang_list, ['ja','en'], '["ja","en"]')
                    model = recognition_models['gen2']['japanese_g2']
                    recog_network = 'generation2'
                elif 'ko' in lang_list:
                    self.setModelLanguage('korean', lang_list, ['ko','en'], '["ko","en"]')
                    model = recognition_models['gen2']['korean_g2']
                    recog_network = 'generation2'
                elif 'ta' in lang_list:
                    self.setModelLanguage('tamil', lang_list, ['ta','en'], '["ta","en"]')
                    model = recognition_models['gen1']['tamil_g1']
                    recog_network = 'generation1'
                elif 'te' in lang_list:
                    self.setModelLanguage('telugu', lang_list, ['te','en'], '["te","en"]')
                    model = recognition_models['gen2']['telugu_g2']
                    recog_network = 'generation2'
                elif 'kn' in lang_list:
                    self.setModelLanguage('kannada', lang_list, ['kn','en'], '["kn","en"]')
                    model = recognition_models['gen2']['kannada_g2']
                    recog_network = 'generation2'
                elif set(lang_list) & set(bengali_lang_list):
                    self.setModelLanguage('bengali', lang_list, bengali_lang_list+['en'], '["bn","as","en"]')
                    model = recognition_models['gen1']['bengali_g1']
                    recog_network = 'generation1'
                elif set(lang_list) & set(arabic_lang_list):
                    self.setModelLanguage('arabic', lang_list, arabic_lang_list+['en'], '["ar","fa","ur","ug","en"]')
                    model = recognition_models['gen1']['arabic_g1']
                    recog_network = 'generation1'
                elif set(lang_list) & set(devanagari_lang_list):
                    self.setModelLanguage('devanagari', lang_list, devanagari_lang_list+['en'], '["hi","mr","ne","en"]')
                    model = recognition_models['gen1']['devanagari_g1']
                    recog_network = 'generation1'
                elif set(lang_list) & set(cyrillic_lang_list):
                    self.setModelLanguage('cyrillic', lang_list, cyrillic_lang_list+['en'],
                                          '["ru","rs_cyrillic","be","bg","uk","mn","en"]')
                    model = recognition_models['gen2']['cyrillic_g2']
                    recog_network = 'generation2'
                else:
                    # Fallback: the Latin model covers everything remaining.
                    self.model_lang = 'latin'
                    model = recognition_models['gen2']['latin_g2']
                    recog_network = 'generation2'
            self.character = model['characters']

            model_path = os.path.join(self.model_storage_directory, model['filename'])
            # check recognition model file: download if missing, re-download on
            # MD5 mismatch (possible corruption).
            if recognizer:
                if not os.path.isfile(model_path):
                    if not self.download_enabled:
                        raise FileNotFoundError("Missing %s and downloads disabled" % model_path)
                    LOGGER.warning('Downloading recognition model, please wait. '
                                   'This may take several minutes depending upon your network connection.')
                    download_and_unzip(model['url'], model['filename'], self.model_storage_directory, verbose)
                    assert calculate_md5(model_path) == model['md5sum'], corrupt_msg
                    LOGGER.info('Download complete.')
                elif calculate_md5(model_path) != model['md5sum']:
                    if not self.download_enabled:
                        raise FileNotFoundError("MD5 mismatch for %s and downloads disabled" % model_path)
                    LOGGER.warning(corrupt_msg)
                    os.remove(model_path)
                    LOGGER.warning('Re-downloading the recognition model, please wait. '
                                   'This may take several minutes depending upon your network connection.')
                    download_and_unzip(model['url'], model['filename'], self.model_storage_directory, verbose)
                    assert calculate_md5(model_path) == model['md5sum'], corrupt_msg
                    LOGGER.info('Download complete')
            self.setLanguageList(lang_list, model)

        else: # user-defined model
            with open(os.path.join(self.user_network_directory, recog_network+ '.yaml'), encoding='utf8') as file:
                recog_config = yaml.load(file, Loader=yaml.FullLoader)

            global imgH # if custom model, save this variable. (from *.yaml)
            if recog_config['imgH']:
                imgH = recog_config['imgH']

            available_lang = recog_config['lang_list']
            self.setModelLanguage(recog_network, lang_list, available_lang, str(available_lang))
            #char_file = os.path.join(self.user_network_directory, recog_network+ '.txt')
            self.character = recog_config['character_list']
            model_file = recog_network+ '.pth'
            model_path = os.path.join(self.model_storage_directory, model_file)
            self.setLanguageList(lang_list, recog_config)

        # Per-language dictionary files used by the recognizer's converter.
        dict_list = {}
        for lang in lang_list:
            dict_list[lang] = os.path.join(BASE_PATH, 'dict', lang + ".txt")

        if detector:
            self.detector = self.initDetector(detector_path)

        if recognizer:
            if recog_network == 'generation1':
                network_params = {
                    'input_channel': 1,
                    'output_channel': 512,
                    'hidden_size': 512
                    }
            elif recog_network == 'generation2':
                network_params = {
                    'input_channel': 1,
                    'output_channel': 256,
                    'hidden_size': 256
                    }
            else:
                # Custom network: architecture parameters come from the YAML.
                network_params = recog_config['network_params']
            self.recognizer, self.converter = get_recognizer(recog_network, network_params,\
                                                             self.character, separator_list,\
                                                             dict_list, model_path, device = self.device, quantize=quantize)
+
    def getDetectorPath(self, detect_network):
        # Resolve the weight-file path for the requested detection network,
        # importing the matching backend module and downloading/verifying the
        # weights as needed.  Returns the local path to the detector weights.
        if detect_network in self.support_detection_network:
            self.detect_network = detect_network
            # Bind the backend-specific detector/textbox functions onto self so
            # detect()/initDetector() can call them without knowing the backend.
            if self.detect_network == 'craft':
                from .detection import get_detector, get_textbox
            elif self.detect_network in ['dbnet18']:
                from .detection_db import get_detector, get_textbox
            else:
                # Unreachable while support_detection_network == ['craft', 'dbnet18'],
                # kept as a guard if the list grows out of sync with this chain.
                raise RuntimeError("Unsupport detector network. Support networks are craft and dbnet18.")
            self.get_textbox = get_textbox
            self.get_detector = get_detector
            corrupt_msg = 'MD5 hash mismatch, possible file corruption'
            detector_path = os.path.join(self.model_storage_directory, self.detection_models[self.detect_network]['filename'])
            # Missing file: download (if allowed) and verify the checksum.
            if os.path.isfile(detector_path) == False:
                if not self.download_enabled:
                    raise FileNotFoundError("Missing %s and downloads disabled" % detector_path)
                LOGGER.warning('Downloading detection model, please wait. '
                               'This may take several minutes depending upon your network connection.')
                download_and_unzip(self.detection_models[self.detect_network]['url'], self.detection_models[self.detect_network]['filename'], self.model_storage_directory, self.verbose)
                assert calculate_md5(detector_path) == self.detection_models[self.detect_network]['md5sum'], corrupt_msg
                LOGGER.info('Download complete')
            # Present but corrupt: delete and re-download once.
            elif calculate_md5(detector_path) != self.detection_models[self.detect_network]['md5sum']:
                if not self.download_enabled:
                    raise FileNotFoundError("MD5 mismatch for %s and downloads disabled" % detector_path)
                LOGGER.warning(corrupt_msg)
                os.remove(detector_path)
                LOGGER.warning('Re-downloading the detection model, please wait. '
                               'This may take several minutes depending upon your network connection.')
                download_and_unzip(self.detection_models[self.detect_network]['url'], self.detection_models[self.detect_network]['filename'], self.model_storage_directory, self.verbose)
                assert calculate_md5(detector_path) == self.detection_models[self.detect_network]['md5sum'], corrupt_msg
        else:
            raise RuntimeError("Unsupport detector network. Support networks are {}.".format(', '.join(self.support_detection_network)))

        return detector_path
+
|
270 |
+
def initDetector(self, detector_path):
|
271 |
+
return self.get_detector(detector_path,
|
272 |
+
device = self.device,
|
273 |
+
quantize = self.quantize,
|
274 |
+
cudnn_benchmark = self.cudnn_benchmark
|
275 |
+
)
|
276 |
+
|
277 |
+
def setDetector(self, detect_network):
|
278 |
+
detector_path = self.getDetectorPath(detect_network)
|
279 |
+
self.detector = self.initDetector(detector_path)
|
280 |
+
|
281 |
+
def setModelLanguage(self, language, lang_list, list_lang, list_lang_string):
|
282 |
+
self.model_lang = language
|
283 |
+
if set(lang_list) - set(list_lang) != set():
|
284 |
+
if language == 'ch_tra' or language == 'ch_sim':
|
285 |
+
language = 'chinese'
|
286 |
+
raise ValueError(language.capitalize() + ' is only compatible with English, try lang_list=' + list_lang_string)
|
287 |
+
|
288 |
+
def getChar(self, fileName):
|
289 |
+
char_file = os.path.join(BASE_PATH, 'character', fileName)
|
290 |
+
with open(char_file, "r", encoding="utf-8-sig") as input_file:
|
291 |
+
list = input_file.read().splitlines()
|
292 |
+
char = ''.join(list)
|
293 |
+
return char
|
294 |
+
|
295 |
+
    def setLanguageList(self, lang_list, model):
        # Build self.lang_char: the union of every requested language's
        # character set plus the model's symbol characters, as one string.
        # *model* is either a recognition_models entry or a custom-model YAML
        # dict; both are plain dicts here.
        self.lang_char = []
        for lang in lang_list:
            char_file = os.path.join(BASE_PATH, 'character', lang + "_char.txt")
            with open(char_file, "r", encoding = "utf-8-sig") as input_file:
                char_list = input_file.read().splitlines()
            self.lang_char += char_list
        # Prefer the model's explicit symbol set; fall back to its full
        # character list, then to a generic ASCII digit/punctuation set.
        if model.get('symbols'):
            symbol = model['symbols']
        elif model.get('character_list'):
            symbol = model['character_list']
        else:
            symbol = '0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
        self.lang_char = set(self.lang_char).union(set(symbol))
        # NOTE(review): joining a set, so the character order of lang_char is
        # not deterministic across runs — fine as long as consumers treat it
        # as a membership set, not an ordered alphabet.
        self.lang_char = ''.join(self.lang_char)
+
|
311 |
+
    def detect(self, img, min_size = 20, text_threshold = 0.7, low_text = 0.4,\
               link_threshold = 0.4,canvas_size = 2560, mag_ratio = 1.,\
               slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
               width_ths = 0.5, add_margin = 0.1, reformat=True, optimal_num_chars=None,
               threshold = 0.2, bbox_min_score = 0.2, bbox_min_size = 3, max_candidates = 0,
               ):
        """Detect text regions in *img*.

        Returns a pair (horizontal_list_agg, free_list_agg) with one entry per
        image processed by the detector: axis-aligned boxes and free-form
        (rotated/quadrilateral) regions respectively.  Boxes appear to be
        [x_min, x_max, y_min, y_max] (see the min_size filter below) —
        confirm against group_text_box.
        """
        if reformat:
            img, img_cv_grey = reformat_input(img)

        # get_textbox was bound by getDetectorPath() to the selected backend
        # (CRAFT or DBNet); threshold/bbox_* parameters are DBNet-specific.
        text_box_list = self.get_textbox(self.detector,
                                         img,
                                         canvas_size = canvas_size,
                                         mag_ratio = mag_ratio,
                                         text_threshold = text_threshold,
                                         link_threshold = link_threshold,
                                         low_text = low_text,
                                         poly = False,
                                         device = self.device,
                                         optimal_num_chars = optimal_num_chars,
                                         threshold = threshold,
                                         bbox_min_score = bbox_min_score,
                                         bbox_min_size = bbox_min_size,
                                         max_candidates = max_candidates,
                                         )

        # Group raw text boxes into lines, then drop anything smaller than
        # min_size in both dimensions.
        horizontal_list_agg, free_list_agg = [], []
        for text_box in text_box_list:
            horizontal_list, free_list = group_text_box(text_box, slope_ths,
                                                        ycenter_ths, height_ths,
                                                        width_ths, add_margin,
                                                        (optimal_num_chars is None))
            if min_size:
                horizontal_list = [i for i in horizontal_list if max(
                    i[1] - i[0], i[3] - i[2]) > min_size]
                free_list = [i for i in free_list if max(
                    diff([c[0] for c in i]), diff([c[1] for c in i])) > min_size]
            horizontal_list_agg.append(horizontal_list)
            free_list_agg.append(free_list)

        return horizontal_list_agg, free_list_agg
+
|
353 |
+
def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\
|
354 |
+
decoder = 'greedy', beamWidth= 5, batch_size = 1,\
|
355 |
+
workers = 0, allowlist = None, blocklist = None, detail = 1,\
|
356 |
+
rotation_info = None,paragraph = False,\
|
357 |
+
contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
|
358 |
+
y_ths = 0.5, x_ths = 1.0, reformat=True, output_format='standard'):
|
359 |
+
|
360 |
+
if reformat:
|
361 |
+
img, img_cv_grey = reformat_input(img_cv_grey)
|
362 |
+
|
363 |
+
if allowlist:
|
364 |
+
ignore_char = ''.join(set(self.character)-set(allowlist))
|
365 |
+
elif blocklist:
|
366 |
+
ignore_char = ''.join(set(blocklist))
|
367 |
+
else:
|
368 |
+
ignore_char = ''.join(set(self.character)-set(self.lang_char))
|
369 |
+
|
370 |
+
if self.model_lang in ['chinese_tra','chinese_sim']: decoder = 'greedy'
|
371 |
+
|
372 |
+
if (horizontal_list==None) and (free_list==None):
|
373 |
+
y_max, x_max = img_cv_grey.shape
|
374 |
+
horizontal_list = [[0, x_max, 0, y_max]]
|
375 |
+
free_list = []
|
376 |
+
|
377 |
+
# without gpu/parallelization, it is faster to process image one by one
|
378 |
+
if ((batch_size == 1) or (self.device == 'cpu')) and not rotation_info:
|
379 |
+
result = []
|
380 |
+
for bbox in horizontal_list:
|
381 |
+
h_list = [bbox]
|
382 |
+
f_list = []
|
383 |
+
image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH)
|
384 |
+
result0 = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
|
385 |
+
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
|
386 |
+
workers, self.device)
|
387 |
+
result += result0
|
388 |
+
for bbox in free_list:
|
389 |
+
h_list = []
|
390 |
+
f_list = [bbox]
|
391 |
+
image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH)
|
392 |
+
result0 = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
|
393 |
+
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
|
394 |
+
workers, self.device)
|
395 |
+
result += result0
|
396 |
+
# default mode will try to process multiple boxes at the same time
|
397 |
+
else:
|
398 |
+
image_list, max_width = get_image_list(horizontal_list, free_list, img_cv_grey, model_height = imgH)
|
399 |
+
image_len = len(image_list)
|
400 |
+
if rotation_info and image_list:
|
401 |
+
image_list = make_rotated_img_list(rotation_info, image_list)
|
402 |
+
max_width = max(max_width, imgH)
|
403 |
+
|
404 |
+
result = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
|
405 |
+
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
|
406 |
+
workers, self.device)
|
407 |
+
|
408 |
+
if rotation_info and (horizontal_list+free_list):
|
409 |
+
# Reshape result to be a list of lists, each row being for
|
410 |
+
# one of the rotations (first row being no rotation)
|
411 |
+
result = set_result_with_confidence(
|
412 |
+
[result[image_len*i:image_len*(i+1)] for i in range(len(rotation_info) + 1)])
|
413 |
+
|
414 |
+
if self.model_lang == 'arabic':
|
415 |
+
direction_mode = 'rtl'
|
416 |
+
result = [list(item) for item in result]
|
417 |
+
for item in result:
|
418 |
+
item[1] = get_display(item[1])
|
419 |
+
else:
|
420 |
+
direction_mode = 'ltr'
|
421 |
+
|
422 |
+
if paragraph:
|
423 |
+
result = get_paragraph(result, x_ths=x_ths, y_ths=y_ths, mode = direction_mode)
|
424 |
+
|
425 |
+
if detail == 0:
|
426 |
+
return [item[1] for item in result]
|
427 |
+
elif output_format == 'dict':
|
428 |
+
if paragraph:
|
429 |
+
return [ {'boxes':item[0],'text':item[1]} for item in result]
|
430 |
+
return [ {'boxes':item[0],'text':item[1],'confident':item[2]} for item in result]
|
431 |
+
elif output_format == 'json':
|
432 |
+
if paragraph:
|
433 |
+
return [json.dumps({'boxes':[list(map(int, lst)) for lst in item[0]],'text':item[1]}, ensure_ascii=False) for item in result]
|
434 |
+
return [json.dumps({'boxes':[list(map(int, lst)) for lst in item[0]],'text':item[1],'confident':item[2]}, ensure_ascii=False) for item in result]
|
435 |
+
elif output_format == 'free_merge':
|
436 |
+
return merge_to_free(result, free_list)
|
437 |
+
else:
|
438 |
+
return result
|
439 |
+
|
440 |
+
def readtext(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
|
441 |
+
workers = 0, allowlist = None, blocklist = None, detail = 1,\
|
442 |
+
rotation_info = None, paragraph = False, min_size = 20,\
|
443 |
+
contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
|
444 |
+
text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
|
445 |
+
canvas_size = 2560, mag_ratio = 1.,\
|
446 |
+
slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
|
447 |
+
width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1,
|
448 |
+
threshold = 0.2, bbox_min_score = 0.2, bbox_min_size = 3, max_candidates = 0,
|
449 |
+
output_format='standard'):
|
450 |
+
'''
|
451 |
+
Parameters:
|
452 |
+
image: file path or numpy-array or a byte stream object
|
453 |
+
'''
|
454 |
+
img, img_cv_grey = reformat_input(image)
|
455 |
+
|
456 |
+
horizontal_list, free_list = self.detect(img,
|
457 |
+
min_size = min_size, text_threshold = text_threshold,\
|
458 |
+
low_text = low_text, link_threshold = link_threshold,\
|
459 |
+
canvas_size = canvas_size, mag_ratio = mag_ratio,\
|
460 |
+
slope_ths = slope_ths, ycenter_ths = ycenter_ths,\
|
461 |
+
height_ths = height_ths, width_ths= width_ths,\
|
462 |
+
add_margin = add_margin, reformat = False,\
|
463 |
+
threshold = threshold, bbox_min_score = bbox_min_score,\
|
464 |
+
bbox_min_size = bbox_min_size, max_candidates = max_candidates
|
465 |
+
)
|
466 |
+
# get the 1st result from hor & free list as self.detect returns a list of depth 3
|
467 |
+
horizontal_list, free_list = horizontal_list[0], free_list[0]
|
468 |
+
result = self.recognize(img_cv_grey, horizontal_list, free_list,\
|
469 |
+
decoder, beamWidth, batch_size,\
|
470 |
+
workers, allowlist, blocklist, detail, rotation_info,\
|
471 |
+
paragraph, contrast_ths, adjust_contrast,\
|
472 |
+
filter_ths, y_ths, x_ths, False, output_format)
|
473 |
+
|
474 |
+
return result
|
475 |
+
|
476 |
+
    def readtextlang(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
                     workers = 0, allowlist = None, blocklist = None, detail = 1,\
                     rotation_info = None, paragraph = False, min_size = 20,\
                     contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
                     text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
                     canvas_size = 2560, mag_ratio = 1.,\
                     slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
                     width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1,
                     threshold = 0.2, bbox_min_score = 0.2, bbox_min_size = 3, max_candidates = 0,
                     output_format='standard'):
        '''
        Parameters:
        image: file path or numpy-array or a byte stream object
        '''
        # NOTE(review): unlike readtext, this method PRINTS its matches and
        # implicitly returns None — verify that is the intended contract.
        img, img_cv_grey = reformat_input(image)

        # Detection + recognition, same pipeline as readtext.
        horizontal_list, free_list = self.detect(img,
                                                 min_size = min_size, text_threshold = text_threshold,\
                                                 low_text = low_text, link_threshold = link_threshold,\
                                                 canvas_size = canvas_size, mag_ratio = mag_ratio,\
                                                 slope_ths = slope_ths, ycenter_ths = ycenter_ths,\
                                                 height_ths = height_ths, width_ths= width_ths,\
                                                 add_margin = add_margin, reformat = False,\
                                                 threshold = threshold, bbox_min_score = bbox_min_score,\
                                                 bbox_min_size = bbox_min_size, max_candidates = max_candidates
                                                 )
        # get the 1st result from hor & free list as self.detect returns a list of depth 3
        horizontal_list, free_list = horizontal_list[0], free_list[0]
        result = self.recognize(img_cv_grey, horizontal_list, free_list,\
                                decoder, beamWidth, batch_size,\
                                workers, allowlist, blocklist, detail, rotation_info,\
                                paragraph, contrast_ths, adjust_contrast,\
                                filter_ths, y_ths, x_ths, False, output_format)

        # Collect the recognized text strings (item[1] of each result tuple).
        char = []
        # NOTE(review): relative to the current working directory, not to this
        # module — breaks when the process is started elsewhere.
        directory = 'characters/'
        for i in range(len(result)):
            char.append(result[i][1])

        def search(arr,x):
            # Linear scan for x in arr; returns 1 on a match, -1 otherwise.
            # NOTE(review): arr iterates single characters of a stringified
            # list while x is a whole recognized string, so this can only
            # match single-character results — confirm intent.
            g = False
            for i in range(len(arr)):
                if arr[i]==x:
                    g = True
                    return 1
            if g == False:
                return -1
        def tupleadd(i):
            # Append the 2-letter language prefix of the current file to
            # result[i].  NOTE(review): 'filename' is captured from the
            # enclosing for-loop below (late binding).
            a = result[i]
            b = a + (filename[0:2],)
            return b

        # For each language character file, print every recognized item whose
        # text appears in that file (en/ch prefixes only).
        for filename in os.listdir(directory):
            if filename.endswith(".txt"):
                with open ('characters/'+ filename,'rt',encoding="utf8") as myfile:
                    chartrs = str(myfile.read().splitlines()).replace('\n','')
                    for i in range(len(char)):
                        res = search(chartrs,char[i])
                        if res != -1:
                            if filename[0:2]=="en" or filename[0:2]=="ch":
                                print(tupleadd(i))
|
538 |
+
    def readtext_batched(self, image, n_width=None, n_height=None,\
                         decoder = 'greedy', beamWidth= 5, batch_size = 1,\
                         workers = 0, allowlist = None, blocklist = None, detail = 1,\
                         rotation_info = None, paragraph = False, min_size = 20,\
                         contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
                         text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
                         canvas_size = 2560, mag_ratio = 1.,\
                         slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
                         width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1,
                         threshold = 0.2, bbox_min_score = 0.2, bbox_min_size = 3, max_candidates = 0,
                         output_format='standard'):
        '''
        Parameters:
        image: file path or numpy-array or a byte stream object
        When sending a list of images, they all must of the same size,
        the following parameters will automatically resize if they are not None
        n_width: int, new width
        n_height: int, new height
        '''
        img, img_cv_grey = reformat_input_batched(image, n_width, n_height)

        # One detection pass over the whole batch; results are aggregated
        # per image (hence the *_agg names).
        horizontal_list_agg, free_list_agg = self.detect(img,
                                                         min_size = min_size, text_threshold = text_threshold,\
                                                         low_text = low_text, link_threshold = link_threshold,\
                                                         canvas_size = canvas_size, mag_ratio = mag_ratio,\
                                                         slope_ths = slope_ths, ycenter_ths = ycenter_ths,\
                                                         height_ths = height_ths, width_ths= width_ths,\
                                                         add_margin = add_margin, reformat = False,\
                                                         threshold = threshold, bbox_min_score = bbox_min_score,\
                                                         bbox_min_size = bbox_min_size, max_candidates = max_candidates
                                                         )
        result_agg = []
        # put img_cv_grey in a list if its a single img
        img_cv_grey = [img_cv_grey] if len(img_cv_grey.shape) == 2 else img_cv_grey
        # Recognize each image against its own detected boxes; returns one
        # result list per input image.
        for grey_img, horizontal_list, free_list in zip(img_cv_grey, horizontal_list_agg, free_list_agg):
            result_agg.append(self.recognize(grey_img, horizontal_list, free_list,\
                                             decoder, beamWidth, batch_size,\
                                             workers, allowlist, blocklist, detail, rotation_info,\
                                             paragraph, contrast_ths, adjust_contrast,\
                                             filter_ths, y_ths, x_ths, False, output_format))

        return result_agg
model.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import easyocr
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
class EasyOCRModel:
    """Thin wrapper exposing EasyOCR English-language text extraction."""

    def __init__(self):
        # English-only reader; extend the language list here if needed.
        self.reader = easyocr.Reader(['en'])

    def predict(self, image_path: str) -> List[str]:
        """
        Perform OCR on the given image.

        Args:
            image_path (str): Path to the input image.

        Returns:
            List[str]: Extracted text from the image.
        """
        # detail=0 makes readtext return plain strings instead of
        # (box, text, confidence) tuples.
        extracted = self.reader.readtext(image_path, detail=0)
        return extracted
19 |
+
|
20 |
+
# Test the model locally
if __name__ == "__main__":
    # Requires a 'sample_image.jpg' next to the working directory; the first
    # run also downloads the EasyOCR model weights.
    model = EasyOCRModel()
    result = model.predict("sample_image.jpg")
    print(result)
|