import torch
import pickle
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class AttrDict(dict):
    """Dictionary subclass whose entries are also accessible as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

class BeamEntry:
    "information about one single beam at specific time-step"

    def __init__(self):
        self.prTotal = 0        # blank and non-blank
        self.prNonBlank = 0     # non-blank
        self.prBlank = 0        # blank
        self.prText = 1         # LM score
        self.lmApplied = False  # flag if LM was already applied to this beam
        self.labeling = ()      # beam-labeling

class BeamState:
    "information about the beams at specific time-step"

    def __init__(self):
        self.entries = {}

    def norm(self):
        "length-normalise LM score"
        for (k, _) in self.entries.items():
            labelingLen = len(self.entries[k].labeling)
            self.entries[k].prText = self.entries[k].prText ** (1.0 / (labelingLen if labelingLen else 1.0))

    def sort(self):
        "return beam-labelings, sorted by probability"
        beams = [v for (_, v) in self.entries.items()]
        sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal * x.prText)
        return [x.labeling for x in sortedBeams]

    def wordsearch(self, classes, ignore_idx, beamWidth, dict_list):
        "decode beams from best to worst and return the first decoded text found in dict_list, falling back to the best beam's text"
        beams = [v for (_, v) in self.entries.items()]
        sortedBeams = sorted(beams, reverse=True, key=lambda x: x.prTotal * x.prText)[:beamWidth]

        best_text = ''  # safe fallback in case there are no beams
        for j, candidate in enumerate(sortedBeams):
            idx_list = candidate.labeling
            text = ''
            for i, l in enumerate(idx_list):
                if l not in ignore_idx and (not (i > 0 and idx_list[i - 1] == idx_list[i])):
                    text += classes[l]

            if j == 0: best_text = text
            if text in dict_list:
                print('found text: ', text)
                best_text = text
                break
            else:
                print('not in dict: ', text)
        return best_text

def applyLM(parentBeam, childBeam, classes, lm):
    "calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars"
    if lm and not childBeam.lmApplied:
        # last char of parent beam (or ' ' for the empty beam) and the newly added char
        c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')]
        c2 = classes[childBeam.labeling[-1]]
        lmFactor = 0.01  # influence of language model
        bigramProb = lm.getCharBigram(c1, c2) ** lmFactor
        childBeam.prText = parentBeam.prText * bigramProb  # probability of char sequence
        childBeam.lmApplied = True  # only apply LM once per beam entry

def addBeam(beamState, labeling):
    "add beam if it does not yet exist"
    if labeling not in beamState.entries:
        beamState.entries[labeling] = BeamEntry()

def ctcBeamSearch(mat, classes, ignore_idx, lm, beamWidth=25, dict_list=[]):
    "beam search as described by the paper of Hwang et al. and the paper of Graves et al."

    blankIdx = 0
    maxT, maxC = mat.shape

    # initialise beam state
    last = BeamState()
    labeling = ()
    last.entries[labeling] = BeamEntry()
    last.entries[labeling].prBlank = 1
    last.entries[labeling].prTotal = 1

    # go over all time-steps
    for t in range(maxT):
        curr = BeamState()

        # get beam-labelings of best beams
        bestLabelings = last.sort()[0:beamWidth]

        # go over best beams
        for labeling in bestLabelings:

            # probability of paths ending with a non-blank
            prNonBlank = 0
            # in case of non-empty beam
            if labeling:
                # probability of paths with repeated last char at the end
                prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]]

            # probability of paths ending with a blank
            prBlank = last.entries[labeling].prTotal * mat[t, blankIdx]

            # add beam at current time-step if needed
            addBeam(curr, labeling)

            # fill in data
            curr.entries[labeling].labeling = labeling
            curr.entries[labeling].prNonBlank += prNonBlank
            curr.entries[labeling].prBlank += prBlank
            curr.entries[labeling].prTotal += prBlank + prNonBlank
            curr.entries[labeling].prText = last.entries[labeling].prText  # beam-labeling unchanged, so LM score unchanged
            curr.entries[labeling].lmApplied = True  # LM already applied at previous time-step for this beam-labeling

            # extend current beam-labeling; the blank (index 0) is never appended,
            # so start at 1 (range(maxC - 1) would both include the blank and
            # skip the last character class)
            for c in range(1, maxC):
                # add new char to current beam-labeling
                newLabeling = labeling + (c,)

                # if new labeling contains duplicate char at the end, only consider paths ending with a blank
                if labeling and labeling[-1] == c:
                    prNonBlank = mat[t, c] * last.entries[labeling].prBlank
                else:
                    prNonBlank = mat[t, c] * last.entries[labeling].prTotal

                # add beam at current time-step if needed
                addBeam(curr, newLabeling)

                # fill in data
                curr.entries[newLabeling].labeling = newLabeling
                curr.entries[newLabeling].prNonBlank += prNonBlank
                curr.entries[newLabeling].prTotal += prNonBlank

                # apply LM score (no-op when lm is None)
                applyLM(curr.entries[labeling], curr.entries[newLabeling], classes, lm)

        # set new beam state
        last = curr

    # normalise LM scores according to beam-labeling-length
    last.norm()

    if dict_list == []:
        bestLabeling = last.sort()[0]  # get most probable labeling
        # map labels to chars, dropping repeats and ignored indices
        res = ''
        for i, l in enumerate(bestLabeling):
            if l not in ignore_idx and (not (i > 0 and bestLabeling[i - 1] == bestLabeling[i])):
                res += classes[l]
    else:
        res = last.wordsearch(classes, ignore_idx, beamWidth, dict_list)
    return res

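# A minimal sketch of calling ctcBeamSearch directly, assuming a toy 3-class
# posterior matrix with the blank at index 0; the classes and probabilities
# below are made up for illustration:
#
#   toy_classes = ['[blank]', 'a', 'b']
#   toy_mat = np.array([[0.1, 0.8, 0.1],   # t=0: 'a' most likely
#                       [0.7, 0.2, 0.1],   # t=1: blank most likely
#                       [0.1, 0.1, 0.8]])  # t=2: 'b' most likely
#   ctcBeamSearch(toy_mat, toy_classes, ignore_idx=[0], lm=None)  # -> 'ab'
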
def consecutive(data, mode='first', stepsize=1):
    "group consecutive runs in data and keep the first or last element of each run"
    group = np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)
    group = [item for item in group if len(item) > 0]

    if mode == 'first': result = [l[0] for l in group]
    elif mode == 'last': result = [l[-1] for l in group]
    return result

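# For illustration, consecutive() keeps one element per run of consecutive
# values (the input array below is made up):
#
#   consecutive(np.array([1, 2, 3, 7, 8, 20]), mode='first')  # -> [1, 7, 20]
#   consecutive(np.array([1, 2, 3, 7, 8, 20]), mode='last')   # -> [3, 8, 20]
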
def word_segmentation(mat, separator_idx={'th': [1, 2], 'en': [3, 4]}, separator_idx_list=[1, 2, 3, 4]):
    "split an argmax label sequence into language-tagged segments using separator indices"
    result = []
    sep_list = []
    start_idx = 0
    sep_lang = ''  # language of the pending start separator, '' if none
    for sep_idx in separator_idx_list:
        if sep_idx % 2 == 0: mode = 'first'
        else: mode = 'last'
        a = consecutive(np.argwhere(mat == sep_idx).flatten(), mode)
        new_sep = [[item, sep_idx] for item in a]
        sep_list += new_sep
    sep_list = sorted(sep_list, key=lambda x: x[0])

    for sep in sep_list:
        for lang in separator_idx.keys():
            if sep[1] == separator_idx[lang][0]:  # start separator
                sep_lang = lang
                sep_start_idx = sep[0]
            elif sep[1] == separator_idx[lang][1]:  # end separator
                if sep_lang == lang:  # check that it closes a start separator of the same language
                    new_sep_pair = [lang, [sep_start_idx + 1, sep[0] - 1]]
                    if sep_start_idx > start_idx:
                        result.append(['', [start_idx, sep_start_idx - 1]])
                    start_idx = sep[0] + 1
                    result.append(new_sep_pair)
                else:
                    sep_lang = ''  # mismatched end separator: reset

    if start_idx <= len(mat) - 1:
        result.append(['', [start_idx, len(mat) - 1]])
    return result

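# Illustrative example with the default separator indices (1/2 delimit 'th',
# 3/4 delimit 'en'); the label sequence below is made up:
#
#   word_segmentation(np.array([5, 1, 6, 7, 2, 8, 9]))
#   # -> [['', [0, 0]], ['th', [2, 3]], ['', [5, 6]]]
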
class CTCLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character, separator_list={}, dict_pathlist={}):
        # character (str): set of the possible characters
        dict_character = list(character)

        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i + 1  # index 0 is reserved for the CTC blank

        self.character = ['[blank]'] + dict_character  # dummy '[blank]' token for CTCLoss (index 0)

        self.separator_list = separator_list
        separator_char = []
        for lang, sep in separator_list.items():
            separator_char += sep
        # blank and separator indices are skipped when decoding; this assumes
        # separator characters come first in the character set
        self.ignore_idx = [0] + [i + 1 for i, item in enumerate(separator_char)]

        # load per-language word dictionaries (pickled word-count dicts)
        dict_list = {}
        for lang, dict_path in dict_pathlist.items():
            with open(dict_path, "rb") as input_file:
                word_count = pickle.load(input_file)
            dict_list[lang] = word_count
        self.dict_list = dict_list

    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]

        output:
            text: concatenated text index for CTCLoss.
                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]
        text = ''.join(text)
        text = [self.dict[char] for char in text]

        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode_greedy(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]

            char_list = []
            for i in range(l):
                # skip blank/separator indices and repeated characters
                if t[i] not in self.ignore_idx and (not (i > 0 and t[i - 1] == t[i])):
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)

            texts.append(text)
            index += l
        return texts

    def decode_beamsearch(self, mat, beamWidth=5):
        # mat: [batch_size, seq_len, num_classes] probability matrix
        texts = []
        for i in range(mat.shape[0]):
            t = ctcBeamSearch(mat[i], self.character, self.ignore_idx, None, beamWidth=beamWidth)
            texts.append(t)
        return texts

    def decode_wordbeamsearch(self, mat, beamWidth=5):
        texts = []
        argmax = np.argmax(mat, axis=2)
        for i in range(mat.shape[0]):
            words = word_segmentation(argmax[i])
            string = ''
            for word in words:
                matrix = mat[i, word[1][0]:word[1][1] + 1, :]
                if word[0] == '': dict_list = []
                else: dict_list = self.dict_list[word[0]]
                t = ctcBeamSearch(matrix, self.character, self.ignore_idx, None, beamWidth=beamWidth, dict_list=dict_list)
                string += t
            texts.append(string)
        return texts

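# A minimal round-trip sketch of CTCLabelConverter, assuming the tiny
# character set 'abc'; the labels are made up for illustration:
#
#   converter = CTCLabelConverter('abc')
#   text, length = converter.encode(['ab', 'c'])  # text=[1, 2, 3], length=[2, 1]
#   converter.decode_greedy(text, length)         # -> ['ab', 'c']
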
class AttnLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character):
        # [GO] is the start token of the attention decoder,
        # [s] is the end-of-sentence token
        list_token = ['[GO]', '[s]']
        list_character = list(character)
        self.character = list_token + list_character

        self.dict = {}
        for i, char in enumerate(self.character):
            self.dict[char] = i

    def encode(self, text, batch_max_length=25):
        """ convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default

        output:
            text : the input of the attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token.
                text[:, 0] is the [GO] token and text is padded with [GO] tokens after the [s] token.
            length : the length of the output of the attention decoder, which also counts the [s] token. [3, 7, ....] [batch_size]
        """
        length = [len(s) + 1 for s in text]  # +1 for [s] at the end of sentence
        batch_max_length += 1

        # additional +1 for [GO] at the first step; batch_text is padded with the [GO] token (index 0)
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for i, t in enumerate(text):
            # use a new name instead of rebinding `text`, the loop's own source
            txt = list(t)
            txt.append('[s]')
            txt = [self.dict[char] for char in txt]
            batch_text[i][1:1 + len(txt)] = torch.LongTensor(txt)
        return (batch_text.to(device), torch.IntTensor(length).to(device))

    def decode(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        for index, l in enumerate(length):
            text = ''.join([self.character[i] for i in text_index[index, :]])
            texts.append(text)
        return texts

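# A minimal sketch of the attention converter's encoding, assuming the
# character set 'ab'; the label below is made up for illustration. Trimming
# the decoded string at the first '[s]' is left to the caller.
#
#   converter = AttnLabelConverter('ab')
#   text, length = converter.encode(['ab'])
#   text[0, :4]  # -> tensor([0, 2, 3, 1]), i.e. [GO], 'a', 'b', [s]
#   length       # -> tensor([3], dtype=torch.int32): 'a', 'b' and the [s] token
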
class Averager(object):
    """Compute average for torch.Tensor, used for loss average."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
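
# A minimal sketch of Averager, e.g. for tracking a running loss average; the
# tensors below are made up for illustration:
#
#   avg = Averager()
#   avg.add(torch.tensor([1.0, 2.0]))  # sum=3.0, count=2
#   avg.add(torch.tensor([3.0]))       # sum=6.0, count=3
#   avg.val()                          # -> tensor(2.)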