"""Gradio demo for English-to-Hindi translation with a Transformer seq2seq model.

Downloads the tokenizers and trained checkpoint from the Hugging Face Hub and
serves a simple text-to-text interface.
"""

import math

import gradio as gr
import lightning as L
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer


class Translator:
    def __init__(
        self,
        src_tokenizer_ckpt_path,
        tgt_tokenizer_ckpt_path,
        model_ckpt_path,
    ):
        self.src_tokenizer = Tokenizer.from_file(src_tokenizer_ckpt_path)
        self.tgt_tokenizer = Tokenizer.from_file(tgt_tokenizer_ckpt_path)
        # Disable any tokenizer dropout so encoding is deterministic at inference time.
        self.src_tokenizer.model.dropout = 0
        self.tgt_tokenizer.model.dropout = 0

        self.model = TransformerSeq2Seq.load_from_checkpoint(
            model_ckpt_path,
            map_location="cpu",
        )
        self.model.eval()

    def predict(self, src):
        tokenized_text = self.src_tokenizer.encode(src)
        # Shape (seq_len, 1): the model expects sequence-first tensors with batch size 1.
        src = torch.LongTensor(tokenized_text.ids).view(-1, 1)
        tgt = self.model.greedy_decode(src, max_len=100)
        tgt = tgt.squeeze(1).tolist()
        tgt_text = self.tgt_tokenizer.decode(tgt)
        return tgt_text


def generate_square_subsequent_mask(sz):
    # Causal mask: position i may only attend to positions <= i.
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask


class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, dropout, maxlen=5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(
            -torch.arange(0, embedding_dim, 2) * math.log(10000) / embedding_dim
        )
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, embedding_dim))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # Shape (maxlen, 1, embedding_dim) to broadcast over sequence-first batches.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding):
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )


class TransformerSeq2Seq(L.LightningModule):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        embedding_dim=512,
        hidden_dim=512,
        dropout=0.1,
        nhead=8,
        num_layers=3,
        batch_size=32,
        lr=1e-4,
        weight_decay=1e-4,
        sos_idx=1,
        eos_idx=2,
        padding_idx=3,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.src_embedding = nn.Embedding(
            src_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.tgt_embedding = nn.Embedding(
            tgt_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.positional_encoding = PositionalEncoding(
            embedding_dim=embedding_dim,
            dropout=dropout,
        )
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.fc = nn.Linear(embedding_dim, tgt_vocab_size)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        self.criteria = nn.CrossEntropyLoss()

    def forward(
        self,
        src,
        tgt,
        src_mask,
        tgt_mask,
        src_padding_mask,
        tgt_padding_mask,
    ):
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        tgt = self.tgt_embedding(tgt) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        tgt = self.positional_encoding(tgt)
        out = self.transformer(
            src,
            tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        out = self.fc(out)
        return out

    def greedy_decode(self, src, max_len):
        # Encode the source once, then decode one token at a time until EOS or max_len.
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        memory = self.transformer.encoder(src)

        ys = torch.ones(1, 1).fill_(self.hparams.sos_idx).type(torch.long)
        for i in range(max_len - 1):
            tgt = self.tgt_embedding(ys) * (self.hparams.embedding_dim**0.5)
            tgt = self.positional_encoding(tgt)
            tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool)
            out = self.transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
            )
            out = self.fc(out)
            # Keep only the logits for the last generated position.
            out = out.transpose(0, 1)[:, -1]
            prob = out.softmax(dim=-1)
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            ys = torch.cat(
                [ys, torch.ones(1, 1).fill_(next_word).type(torch.long)],
                dim=0,
            )
            if next_word == self.hparams.eos_idx:
                break
        return ys

    def training_step(self, batch, batch_idx):
        src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
        # Teacher forcing: feed the target shifted right, predict the next token.
        tgt_input = tgt[:-1, :]
        logits = self(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
        )
        tgt_out = tgt[1:, :]
        loss = self.criteria(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1),
        )
        self.log("train_loss", loss, batch_size=self.hparams.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
        tgt_input = tgt[:-1, :]
        logits = self(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
        )
        tgt_out = tgt[1:, :]
        loss = self.criteria(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1),
        )
        self.log("val_loss", loss, batch_size=self.hparams.batch_size)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                    optimizer=optimizer,
                    max_lr=self.hparams.lr,
                    total_steps=self.trainer.estimated_stepping_batches,
                ),
                "interval": "step",
            },
        }


# Fetch the tokenizers and trained checkpoint from the Hugging Face Hub.
src_tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="tokenizer-en.json",
)
tgt_tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="tokenizer-hi.json",
)
model_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="transformer.ckpt",
)
classifier = Translator(
    src_tokenizer_ckpt_path,
    tgt_tokenizer_ckpt_path,
    model_ckpt_path,
)
interface = gr.Interface(
    fn=classifier.predict,
    inputs=gr.components.Textbox(
        label="Source Language (English)",
        placeholder="Enter text here...",
    ),
    outputs=gr.components.Textbox(
        label="Target Language (Hindi)",
        placeholder="Translation",
    ),
    examples=[
        ["Hi how are you?"],
        ["Today is a very important day."],
        ["I like playing the guitar."],
    ],
)
interface.launch()
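
# Optional smoke test (a minimal sketch): comment out interface.launch() above
# and call the translator directly during local debugging, e.g.
#
#     print(classifier.predict("Hi how are you?"))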