from json import load as json_load

from numpy import asarray as np_asarray, int64 as np_int64

import torch
import torch.nn as nn
from torch import device as Device


class Encoder(nn.Module):
    """
    Character-level RNN encoder: an Embedding layer followed by a GRU/LSTM.
    """
    def __init__(self, input_dim, embed_dim, hidden_dim,
                 rnn_type='gru', layers=1,
                 bidirectional=False,
                 dropout=0, device="cpu"):
        super(Encoder, self).__init__()

        self.input_dim = input_dim          # source vocabulary size
        self.enc_embed_dim = embed_dim
        self.enc_hidden_dim = hidden_dim
        self.enc_rnn_type = rnn_type
        self.enc_layers = layers
        self.enc_directions = 2 if bidirectional else 1
        self.device = device

        self.embedding = nn.Embedding(self.input_dim, self.enc_embed_dim)

        if self.enc_rnn_type == "gru":
            self.enc_rnn = nn.GRU(input_size=self.enc_embed_dim,
                                  hidden_size=self.enc_hidden_dim,
                                  num_layers=self.enc_layers,
                                  bidirectional=bidirectional)
        elif self.enc_rnn_type == "lstm":
            self.enc_rnn = nn.LSTM(input_size=self.enc_embed_dim,
                                   hidden_size=self.enc_hidden_dim,
                                   num_layers=self.enc_layers,
                                   bidirectional=bidirectional)
        else:
            raise Exception("XlitError: unknown RNN type mentioned")

    def forward(self, x, x_sz, hidden=None):
        """
        x: (batch_size, max_length) - padded token indices
        x_sz: (batch_size, 1) - unpadded sequence lengths used for pack_pad
        """
        batch_sz = x.shape[0]

        # (batch, max_len) -> (batch, max_len, embed_dim)
        x = self.embedding(x)

        # RNN expects time-major input: (max_len, batch, embed_dim)
        x = x.permute(1, 0, 2)
        x = nn.utils.rnn.pack_padded_sequence(x, x_sz, enforce_sorted=False)

        output, hidden = self.enc_rnn(x)

        # Unpack back to a padded tensor: (max_len, batch, hidden_dim * directions)
        output, _ = nn.utils.rnn.pad_packed_sequence(output)

        # Return batch-major output: (batch, max_len, hidden_dim * directions)
        output = output.permute(1, 0, 2)

        return output, hidden

    def get_word_embedding(self, x):
        """
        x: sequence of token indices for a single word.
        Returns the encoder's final hidden state as a fixed-size word embedding.
        """
        x_sz = torch.tensor([len(x)])
        x_ = torch.tensor(x).unsqueeze(0).to(dtype=torch.long)

        x = self.embedding(x_)

        x = x.permute(1, 0, 2)
        x = nn.utils.rnn.pack_padded_sequence(x, x_sz, enforce_sorted=False)

        output, hidden = self.enc_rnn(x)

        # hidden[0] is h_n for an LSTM, or the first layer's state for a GRU
        out_embed = hidden[0].squeeze()

        return out_embed

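
# --- Illustrative sketch (not part of the original model code) ---
# A minimal, hypothetical example of how Encoder can be exercised with dummy
# data; the dimensions and random tokens below are assumptions for illustration only.
def _demo_encoder():
    enc = Encoder(input_dim=30, embed_dim=8, hidden_dim=16,
                  rnn_type="lstm", layers=1, bidirectional=True)
    x = torch.randint(1, 30, (2, 5))          # batch of 2 sequences, padded length 5
    x_sz = torch.tensor([5, 3])               # unpadded lengths
    out, hidden = enc(x, x_sz)
    # out: (2, 5, 16 * 2) since the encoder is bidirectional
    return out.shape, hidden[0].shape

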
class Decoder(nn.Module):
    """
    Character-level RNN decoder with optional Bahdanau-style attention
    over the encoder outputs.
    """
    def __init__(self, output_dim, embed_dim, hidden_dim,
                 rnn_type='gru', layers=1,
                 use_attention=True,
                 enc_outstate_dim=None,
                 dropout=0, device="cpu"):
        super(Decoder, self).__init__()

        self.output_dim = output_dim        # target vocabulary size
        self.dec_hidden_dim = hidden_dim
        self.dec_embed_dim = embed_dim
        self.dec_rnn_type = rnn_type
        self.dec_layers = layers
        self.use_attention = use_attention
        self.device = device
        if self.use_attention:
            self.enc_outstate_dim = enc_outstate_dim if enc_outstate_dim else hidden_dim
        else:
            self.enc_outstate_dim = 0

        self.embedding = nn.Embedding(self.output_dim, self.dec_embed_dim)

        # When attention is used, the context vector is concatenated to the
        # embedded input, hence the enlarged input_size.
        if self.dec_rnn_type == 'gru':
            self.dec_rnn = nn.GRU(input_size=self.dec_embed_dim + self.enc_outstate_dim,
                                  hidden_size=self.dec_hidden_dim,
                                  num_layers=self.dec_layers,
                                  batch_first=True)
        elif self.dec_rnn_type == "lstm":
            self.dec_rnn = nn.LSTM(input_size=self.dec_embed_dim + self.enc_outstate_dim,
                                   hidden_size=self.dec_hidden_dim,
                                   num_layers=self.dec_layers,
                                   batch_first=True)
        else:
            raise Exception("XlitError: unknown RNN type mentioned")

        self.fc = nn.Sequential(
            nn.Linear(self.dec_hidden_dim, self.dec_embed_dim), nn.LeakyReLU(),
            nn.Linear(self.dec_embed_dim, self.output_dim),
        )

        # Additive (Bahdanau) attention parameters
        if self.use_attention:
            self.W1 = nn.Linear(self.enc_outstate_dim, self.dec_hidden_dim)
            self.W2 = nn.Linear(self.dec_hidden_dim, self.dec_hidden_dim)
            self.V = nn.Linear(self.dec_hidden_dim, 1)

    def attention(self, x, hidden, enc_output):
        '''
        x: (batch_size, 1, dec_embed_dim) -> after Embedding
        enc_output: (batch_size, max_length, enc_hidden_dim * num_directions)
        hidden: (n_layers, batch_size, hidden_size) | if LSTM (h_n, c_n)
        '''
        # Collapse the layer axis of the hidden state and add a time axis:
        # (batch_size, 1, hidden_size)
        hidden_with_time_axis = torch.sum(hidden, axis=0) if self.dec_rnn_type != "lstm" \
            else torch.sum(hidden[0], axis=0)
        hidden_with_time_axis = hidden_with_time_axis.unsqueeze(1)

        # Additive attention score: (batch_size, max_length, hidden_size)
        score = torch.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))

        # Normalise over the source time axis: (batch_size, max_length, 1)
        attention_weights = torch.softmax(self.V(score), dim=1)

        # Weighted sum of encoder outputs: (batch_size, 1, enc_outstate_dim)
        context_vector = attention_weights * enc_output
        context_vector = torch.sum(context_vector, dim=1)
        context_vector = context_vector.unsqueeze(1)

        # Concatenate context with the embedded input along the feature axis
        attend_out = torch.cat((context_vector, x), -1)

        return attend_out, attention_weights

    def forward(self, x, hidden, enc_output):
        '''
        x: (batch_size, 1) - previous target token index
        enc_output: (batch_size, max_length, enc_hidden_dim * num_directions)
        hidden: (n_layers, batch_size, hidden_size) | lstm: (h_n, c_n)
        '''
        if (hidden is None) and (self.use_attention is False):
            raise Exception("XlitError: No use of a decoder with No attention and No Hidden")

        batch_sz = x.shape[0]

        if hidden is None:
            # Placeholder hidden state when nothing is passed from the encoder
            hid_for_att = torch.zeros((self.dec_layers, batch_sz,
                                       self.dec_hidden_dim)).to(self.device)
        elif self.dec_rnn_type == 'lstm':
            hid_for_att = hidden[1]  # cell state c_n

        # (batch_size, 1) -> (batch_size, 1, dec_embed_dim)
        x = self.embedding(x)

        if self.use_attention:
            # x: (batch_size, 1, dec_embed_dim + enc_outstate_dim)
            x, aw = self.attention(x, hidden, enc_output)
        else:
            x, aw = x, 0

        output, hidden = self.dec_rnn(x, hidden) if hidden is not None else self.dec_rnn(x)

        # (batch_size, 1, hidden_dim) -> (batch_size, hidden_dim)
        output = output.view(-1, output.size(2))

        # Project to the target vocabulary
        output = self.fc(output)

        return output, hidden, aw

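
# --- Illustrative sketch (not part of the original model code) ---
# A minimal, hypothetical single decoding step pairing the Encoder above with the
# Decoder; all dimensions are assumptions, and summing the encoder directions into
# one initial state is only a demo choice (Seq2Seq below handles this differently).
def _demo_decoder_step():
    enc = Encoder(input_dim=30, embed_dim=8, hidden_dim=16, rnn_type="gru",
                  layers=1, bidirectional=True)
    dec = Decoder(output_dim=40, embed_dim=8, hidden_dim=16, rnn_type="gru",
                  layers=1, use_attention=True, enc_outstate_dim=16 * 2)
    src = torch.randint(1, 30, (1, 5))
    enc_out, enc_hidden = enc(src, torch.tensor([5]))
    # enc_hidden: (layers * directions, batch, hidden) = (2, 1, 16); the decoder
    # here has 1 layer, so sum the directions to get a (1, 1, 16) initial state.
    dec_hidden = enc_hidden.sum(dim=0, keepdim=True)
    prev_tok = torch.tensor([[1]])            # assumed start-token index
    logits, dec_hidden, attn = dec(prev_tok, dec_hidden, enc_out)
    return logits.shape                       # (1, 40)

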
class Seq2Seq(nn.Module):
    """
    Class dependency: Encoder, Decoder
    """
    def __init__(self, encoder, decoder, pass_enc2dec_hid=False, dropout=0, device="cpu"):
        super(Seq2Seq, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.pass_enc2dec_hid = pass_enc2dec_hid
        _force_en2dec_hid_conv = False

        if self.pass_enc2dec_hid:
            assert decoder.dec_hidden_dim == encoder.enc_hidden_dim, \
                "Hidden dimensions of encoder and decoder must be the same, or unset `pass_enc2dec_hid`"
        if decoder.use_attention:
            assert decoder.enc_outstate_dim == encoder.enc_directions * encoder.enc_hidden_dim, \
                "Set `enc_outstate_dim` correctly in decoder"
        assert self.pass_enc2dec_hid or decoder.use_attention, \
            "No use of a decoder with No attention and No Hidden from Encoder"

        # If the number of encoder hidden states (layers * directions) does not
        # match the decoder's layer count, a 1x1 Conv1d maps between them.
        self.use_conv_4_enc2dec_hid = False
        if (
            (self.pass_enc2dec_hid and
             (encoder.enc_directions * encoder.enc_layers != decoder.dec_layers))
            or _force_en2dec_hid_conv
        ):
            if encoder.enc_rnn_type == "lstm" or decoder.dec_rnn_type == "lstm":
                raise Exception("XlitError: conv for enc2dec_hid not implemented; change the layer numbers appropriately")

            self.use_conv_4_enc2dec_hid = True
            self.enc_hid_1ax = encoder.enc_directions * encoder.enc_layers
            self.dec_hid_1ax = decoder.dec_layers
            self.e2d_hidden_conv = nn.Conv1d(self.enc_hid_1ax, self.dec_hid_1ax, 1)

    def enc2dec_hidden(self, enc_hidden):
        """
        enc_hidden: (n_layers * num_directions, batch_size, hidden_dim)
        Maps the encoder's stacked hidden states onto the decoder's layer count.
        TODO: Implement the logic for LSTM based model
        """
        # (enc_layers * directions, batch, hidden) -> (batch, enc_layers * directions, hidden)
        hidden = enc_hidden.permute(1, 0, 2).contiguous()

        # 1x1 convolution over the layer axis: -> (batch, dec_layers, hidden)
        hidden = self.e2d_hidden_conv(hidden)

        # Back to (dec_layers, batch, hidden) for the decoder RNN
        hidden_for_dec = hidden.permute(1, 0, 2).contiguous()

        return hidden_for_dec

    def active_beam_inference(self, src, beam_width=3, max_tgt_sz=50):
        '''Beam-search based decoding
        src: (sequence_len,) - token indices, beginning with the start token
             and ending with the end token
        '''
        def _avg_score(p_tup):
            """Used for sorting candidates by accumulated log-probability.
            TODO: divide by (sequence length ** alpha) as a length-normalisation hyperparameter
            """
            return p_tup[0]

        batch_size = 1
        start_tok = src[0]
        end_tok = src[-1]
        src_sz = torch.tensor([len(src)])
        src_ = src.unsqueeze(0)

        enc_output, enc_hidden = self.encoder(src_, src_sz)

        if self.pass_enc2dec_hid:
            if self.use_conv_4_enc2dec_hid:
                init_dec_hidden = self.enc2dec_hidden(enc_hidden)
            else:
                init_dec_hidden = enc_hidden
        else:
            # attention-only decoding
            init_dec_hidden = None

        # Each candidate is a tuple: (accumulated log-prob, sequence tensor, decoder hidden)
        top_pred_list = [(0, start_tok.unsqueeze(0), init_dec_hidden)]

        for t in range(max_tgt_sz):
            cur_pred_list = []

            for p_tup in top_pred_list:
                # Finished hypotheses are carried over unchanged
                if p_tup[1][-1] == end_tok:
                    cur_pred_list.append(p_tup)
                    continue

                # One decoding step for the last token of this hypothesis
                dec_output, dec_hidden, _ = self.decoder(x=p_tup[1][-1].view(1, 1),
                                                         hidden=p_tup[2],
                                                         enc_output=enc_output)

                dec_output = nn.functional.log_softmax(dec_output, dim=1)

                # Expand the hypothesis with its `beam_width` best next tokens
                pred_topk = torch.topk(dec_output, k=beam_width, dim=1)

                for i in range(beam_width):
                    sig_logsmx_ = p_tup[0] + pred_topk.values[0][i]
                    seq_tensor_ = torch.cat((p_tup[1], pred_topk.indices[0][i].view(1)))
                    cur_pred_list.append((sig_logsmx_, seq_tensor_, dec_hidden))

            # Keep only the `beam_width` best hypotheses
            cur_pred_list.sort(key=_avg_score, reverse=True)
            top_pred_list = cur_pred_list[:beam_width]

            # Stop early once every hypothesis has emitted the end token
            end_flags_ = [1 if t[1][-1] == end_tok else 0 for t in top_pred_list]
            if beam_width == sum(end_flags_):
                break

        pred_tnsr_list = [t[1] for t in top_pred_list]

        return pred_tnsr_list

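
# --- Illustrative sketch (not part of the original model code) ---
# A minimal, hypothetical end-to-end beam search over untrained weights; the
# vocabulary sizes and the start/end token indices (1 and 2) are assumptions.
def _demo_beam_search():
    enc = Encoder(input_dim=30, embed_dim=8, hidden_dim=16, rnn_type="gru")
    dec = Decoder(output_dim=30, embed_dim=8, hidden_dim=16, rnn_type="gru",
                  use_attention=True, enc_outstate_dim=16)
    model = Seq2Seq(enc, dec, pass_enc2dec_hid=True)
    model.eval()
    src = torch.tensor([1, 7, 12, 5, 2])      # assumed: 1 = start token, 2 = end token
    with torch.no_grad():
        hypotheses = model.active_beam_inference(src, beam_width=3, max_tgt_sz=10)
    return [h.tolist() for h in hypotheses]   # 3 candidate index sequences

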
class GlyphStrawboss():
    def __init__(self, glyphs='en'):
        """Holds the list of letters of a language (in Unicode) and the
        mapping between characters and integer indices.
        glyphs: 'en' for the basic English lowercase alphabet, or the path to
                a JSON file with script information
        """
        if glyphs == 'en':
            # Basic English lowercase alphabet: a-z
            self.glyphs = [chr(alpha) for alpha in range(97, 122 + 1)]
        else:
            self.dossier = json_load(open(glyphs, encoding='utf-8'))
            self.glyphs = self.dossier["glyphs"]
            self.numsym_map = self.dossier["numsym_map"]

        self.char2idx = {}
        self.idx2char = {}
        self._create_index()

    def _create_index(self):
        # Special tokens: '_' pad, '$' start, '#' end (see word2xlitvec /
        # xlitvec2word); the remaining symbols are reserved.
        self.char2idx['_'] = 0
        self.char2idx['$'] = 1
        self.char2idx['#'] = 2
        self.char2idx['*'] = 3
        self.char2idx["'"] = 4
        self.char2idx['%'] = 5
        self.char2idx['!'] = 6

        # Script glyphs start from index 7
        for idx, char in enumerate(self.glyphs):
            self.char2idx[char] = idx + 7

        # Reverse mapping for decoding predictions
        for char, idx in self.char2idx.items():
            self.idx2char[idx] = char

    def size(self):
        return len(self.char2idx)

    def word2xlitvec(self, word):
        """Converts a given string of glyphs (word) to a numpy vector.
        Also adds the start and end tokens.
        """
        try:
            vec = [self.char2idx['$']]  # start token
            for i in list(word):
                vec.append(self.char2idx[i])
            vec.append(self.char2idx['#'])  # end token

            vec = np_asarray(vec, dtype=np_int64)
            return vec

        except Exception as error:
            print("XlitError: In word:", word)
            exit("Error: char not in token set: " + str(error))

    def xlitvec2word(self, vector):
        """Converts a numpy vector back to a string of glyphs (word)."""
        char_list = []
        for i in vector:
            char_list.append(self.idx2char[i])

        word = "".join(char_list).replace('$', '').replace('#', '')  # strip start/end tokens
        word = word.replace("_", "").replace('*', '')                # strip pad/reserved tokens
        return word

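
# --- Illustrative sketch (not part of the original model code) ---
# A small, hypothetical round trip through GlyphStrawboss using the built-in
# English glyph set; the sample word is an arbitrary choice.
def _demo_glyph_roundtrip():
    gs = GlyphStrawboss('en')
    vec = gs.word2xlitvec("namaste")   # [1, ...indices..., 2] with start/end tokens
    word = gs.xlitvec2word(vec)        # tokens are stripped -> "namaste"
    return vec.tolist(), word

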
class VocabSanitizer():
    def __init__(self, data_file):
        '''
        data_file: path to a JSON file containing the vocabulary list
        '''
        self.vocab_set = set(json_load(open(data_file, encoding='utf-8')))

    def reposition(self, word_list):
        '''Reorder words in the list: known vocabulary words are moved to the
        front (preserving their relative order); the rest follow.
        '''
        new_list = []
        temp_ = word_list.copy()
        for v in word_list:
            if v in self.vocab_set:
                new_list.append(v)
                temp_.remove(v)
        new_list.extend(temp_)

        return new_list

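
# --- Illustrative sketch (not part of the original model code) ---
# A hypothetical example of VocabSanitizer: a temporary vocabulary file is
# written just for the demo, then in-vocabulary candidates are moved up front.
# The sample words are invented for illustration only.
def _demo_vocab_reposition():
    import json, tempfile
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False,
                                     encoding="utf-8") as f:
        json.dump(["नमस्ते", "नमस्कार"], f, ensure_ascii=False)
        vocab_path = f.name
    sanitizer = VocabSanitizer(vocab_path)
    # "नमस्ते" is in the vocabulary, so it is promoted ahead of the other candidates
    return sanitizer.reposition(["नमसते", "नमस्ते", "नमस्त"])

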
class XlitPiston():
    """
    For handling prediction & post-processing of transliteration for a single language

    Class dependency: Seq2Seq, GlyphStrawboss, VocabSanitizer
    Global Variables: F_DIR
    """
    def __init__(self, weight_path, tglyph_cfg_file, vocab_file, device: Device, iglyph_cfg_file="en"):
        self.device = device
        self.in_glyph_obj = GlyphStrawboss(iglyph_cfg_file)
        self.tgt_glyph_obj = GlyphStrawboss(glyphs=tglyph_cfg_file)
        if vocab_file:
            self.voc_sanitizer = VocabSanitizer(vocab_file)
        else:
            self.voc_sanitizer = None

        # Character sets used by the segmenter below
        self._numsym_set = set(json_load(open(tglyph_cfg_file, encoding='utf-8'))["numsym_map"].keys())
        self._inchar_set = set("abcdefghijklmnopqrstuvwxyz")
        self._natscr_set = set().union(self.tgt_glyph_obj.glyphs,
                                       sum(self.tgt_glyph_obj.numsym_map.values(), []))

        # Network architecture hyperparameters (must match the trained weights)
        input_dim = self.in_glyph_obj.size()
        output_dim = self.tgt_glyph_obj.size()
        enc_emb_dim = 300
        dec_emb_dim = 300
        enc_hidden_dim = 512
        dec_hidden_dim = 512
        rnn_type = "lstm"
        enc2dec_hid = True
        attention = True
        enc_layers = 1
        dec_layers = 2
        m_dropout = 0
        enc_bidirect = True
        enc_outstate_dim = enc_hidden_dim * (2 if enc_bidirect else 1)

        enc = Encoder(input_dim=input_dim, embed_dim=enc_emb_dim,
                      hidden_dim=enc_hidden_dim,
                      rnn_type=rnn_type, layers=enc_layers,
                      dropout=m_dropout, device=self.device,
                      bidirectional=enc_bidirect)
        dec = Decoder(output_dim=output_dim, embed_dim=dec_emb_dim,
                      hidden_dim=dec_hidden_dim,
                      rnn_type=rnn_type, layers=dec_layers,
                      dropout=m_dropout,
                      use_attention=attention,
                      enc_outstate_dim=enc_outstate_dim,
                      device=self.device)
        self.model = Seq2Seq(enc, dec, pass_enc2dec_hid=enc2dec_hid, device=self.device)
        self.model = self.model.to(self.device)
        weights = torch.load(weight_path, map_location=torch.device(self.device))

        self.model.load_state_dict(weights)
        self.model.eval()

    def character_model(self, word, beam_width=1):
        # Beam-search the seq2seq model for a single romanised word
        in_vec = torch.from_numpy(self.in_glyph_obj.word2xlitvec(word)).to(self.device)

        p_out_list = self.model.active_beam_inference(in_vec, beam_width=beam_width)
        p_result = [self.tgt_glyph_obj.xlitvec2word(out.cpu().numpy()) for out in p_out_list]

        # Move in-vocabulary predictions to the front, if a vocabulary is available
        if self.voc_sanitizer:
            return self.voc_sanitizer.reposition(p_result)

        return p_result

    def numsym_model(self, seg):
        '''tgt_glyph_obj.numsym_map[x] returns a list object'''
        if len(seg) == 1:
            return [seg] + self.tgt_glyph_obj.numsym_map[seg]

        # For multi-character segments, map each symbol to its first equivalent
        a = [self.tgt_glyph_obj.numsym_map[n][0] for n in seg]
        return [seg] + ["".join(a)]

    def _word_segementer(self, sequence):
        # Split the input into consecutive runs of: target-script numerals/symbols,
        # native-script characters, English characters, and anything else.
        sequence = sequence.lower()
        accepted = set().union(self._numsym_set, self._inchar_set, self._natscr_set)

        segment = []
        idx = 0
        seq_ = list(sequence)
        while len(seq_):
            # numerals / symbols
            temp = ""
            while len(seq_) and seq_[0] in self._numsym_set:
                temp += seq_[0]
                seq_.pop(0)
            if temp != "":
                segment.append(temp)

            # native-script characters
            temp = ""
            while len(seq_) and seq_[0] in self._natscr_set:
                temp += seq_[0]
                seq_.pop(0)
            if temp != "":
                segment.append(temp)

            # English (input) characters
            temp = ""
            while len(seq_) and seq_[0] in self._inchar_set:
                temp += seq_[0]
                seq_.pop(0)
            if temp != "":
                segment.append(temp)

            # anything not in the accepted character sets
            temp = ""
            while len(seq_) and seq_[0] not in accepted:
                temp += seq_[0]
                seq_.pop(0)
            if temp != "":
                segment.append(temp)

        return segment

    def inferencer(self, sequence, beam_width=10):
        # Transliterate each segment according to the script it belongs to
        seg = self._word_segementer(sequence[:120])
        lit_seg = []

        p = 0
        while p < len(seg):
            if seg[p][0] in self._natscr_set:
                # already in the native script: pass through unchanged
                lit_seg.append([seg[p]])
                p += 1

            elif seg[p][0] in self._inchar_set:
                # romanised text: transliterate with the seq2seq model
                lit_seg.append(self.character_model(seg[p], beam_width=beam_width))
                p += 1

            elif seg[p][0] in self._numsym_set:
                # numerals / symbols: use the direct mapping
                lit_seg.append(self.numsym_model(seg[p]))
                p += 1
            else:
                # unrecognised characters: pass through unchanged
                lit_seg.append([seg[p]])
                p += 1

        # Combine the per-segment candidate lists into full-sequence candidates
        if len(lit_seg) == 1:
            final_result = lit_seg[0]

        elif len(lit_seg) == 2:
            # For two segments, return the cartesian product of their candidates
            final_result = [""]
            for seg_candidates in lit_seg:
                new_result = []
                for s in seg_candidates:
                    for f in final_result:
                        new_result.append(f + s)
                final_result = new_result

        else:
            # For longer inputs, keep only the top candidate of each segment
            new_result = []
            for seg_candidates in lit_seg:
                new_result.append(seg_candidates[0])
            final_result = ["".join(new_result)]

        return final_result

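
# --- Illustrative sketch (not part of the original model code) ---
# A hedged, hypothetical usage example of XlitPiston. The weight, glyph-config
# and vocabulary paths below are placeholders and must point to real files
# produced for a specific language; they are not shipped with this module.
if __name__ == "__main__":
    _device = Device("cuda" if torch.cuda.is_available() else "cpu")
    piston = XlitPiston(weight_path="path/to/hi_model.pth",            # hypothetical
                        tglyph_cfg_file="path/to/hi_scripts.json",     # hypothetical
                        vocab_file="path/to/hi_words.json",            # hypothetical
                        device=_device)
    print(piston.inferencer("namaste", beam_width=10))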