"""Gradio demo: English-to-Hindi translation with a Transformer seq2seq model."""

import math

import gradio as gr
import lightning as L
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer


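# Inference wrapper: loads the trained checkpoint plus the source/target
# tokenizers and exposes a single text-in, text-out `predict` method.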
class Translator:
    def __init__(
        self,
        src_tokenizer_ckpt_path,
        tgt_tokenizer_ckpt_path,
        model_ckpt_path,
    ):
        self.src_tokenizer = Tokenizer.from_file(src_tokenizer_ckpt_path)
        self.tgt_tokenizer = Tokenizer.from_file(tgt_tokenizer_ckpt_path)

        # Disable BPE dropout so tokenization is deterministic at inference time.
        self.src_tokenizer.model.dropout = 0
        self.tgt_tokenizer.model.dropout = 0

        self.model = TransformerSeq2Seq.load_from_checkpoint(
            model_ckpt_path,
            map_location="cpu",
        )
        self.model.eval()

    def predict(self, src):
        # Encode to token ids with shape (seq_len, 1): sequence-first with a
        # batch size of 1, matching nn.Transformer's default layout.
        tokenized_text = self.src_tokenizer.encode(src)
        src_ids = torch.LongTensor(tokenized_text.ids).view(-1, 1)
        with torch.inference_mode():
            tgt = self.model.greedy_decode(src_ids, max_len=100)
        tgt = tgt.squeeze(1).tolist()
        return self.tgt_tokenizer.decode(tgt)


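# Causal mask for the decoder: position i may attend only to positions <= i.
# Allowed entries are 0.0 and masked entries are -inf, the float-mask
# convention used by nn.Transformer.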
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask


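# Sinusoidal positional encoding from "Attention Is All You Need": even
# embedding dimensions get sine, odd dimensions get cosine, at wavelengths
# forming a geometric progression up to 10000 * 2*pi.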
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, dropout, maxlen=5000):
        super().__init__()
        den = torch.exp(
            -torch.arange(0, embedding_dim, 2) * math.log(10000) / embedding_dim
        )
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, embedding_dim))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # Shape (maxlen, 1, embedding_dim) so it broadcasts over the batch dim.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer: moves with the module but is not a parameter.
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding):
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )


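# Encoder-decoder Transformer for translation. All tensors are sequence-first,
# i.e. (seq_len, batch, embedding_dim), which is nn.Transformer's default.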
class TransformerSeq2Seq(L.LightningModule):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        embedding_dim=512,
        hidden_dim=512,
        dropout=0.1,
        nhead=8,
        num_layers=3,
        batch_size=32,
        lr=1e-4,
        weight_decay=1e-4,
        sos_idx=1,
        eos_idx=2,
        padding_idx=3,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.src_embedding = nn.Embedding(
            src_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.tgt_embedding = nn.Embedding(
            tgt_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.positional_encoding = PositionalEncoding(
            embedding_dim=embedding_dim,
            dropout=dropout,
        )
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.fc = nn.Linear(embedding_dim, tgt_vocab_size)

        # Xavier-initialize all weight matrices (biases keep their defaults).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # Exclude padding positions from the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)

    def forward(
        self,
        src,
        tgt,
        src_mask,
        tgt_mask,
        src_padding_mask,
        tgt_padding_mask,
    ):
        # Scale embeddings by sqrt(d_model) before adding positional encodings,
        # as in the original Transformer paper.
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        tgt = self.tgt_embedding(tgt) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        tgt = self.positional_encoding(tgt)
        out = self.transformer(
            src,
            tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        # Project decoder outputs to logits over the target vocabulary.
        out = self.fc(out)
        return out

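    # Greedy autoregressive decoding: encode the source once, then repeatedly
    # feed the tokens generated so far and append the highest-scoring next
    # token, stopping at EOS or after max_len tokens.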
    def greedy_decode(self, src, max_len):
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        memory = self.transformer.encoder(src)
        # Start the target sequence with the start-of-sequence token.
        ys = torch.ones(1, 1).fill_(self.hparams.sos_idx).type(torch.long)
        for _ in range(max_len - 1):
            tgt = self.tgt_embedding(ys) * (self.hparams.embedding_dim**0.5)
            tgt = self.positional_encoding(tgt)
            tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool)
            out = self.transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
            )
            out = self.fc(out)
            # Logits for the last position only, shape (1, tgt_vocab_size).
            out = out.transpose(0, 1)[:, -1]
            next_word = out.argmax(dim=-1).item()
            ys = torch.cat(
                [ys, torch.ones(1, 1).fill_(next_word).type(torch.long)],
                dim=0,
            )

            if next_word == self.hparams.eos_idx:
                break

        return ys

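    # Teacher-forced training: the decoder sees the target shifted right by one
    # (tgt[:-1]) and the loss compares its predictions against tgt[1:].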
    def training_step(self, batch, batch_idx):
        src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
        tgt_input = tgt[:-1, :]
        logits = self(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
        )
        tgt_out = tgt[1:, :]
        # Flatten (seq_len, batch, vocab) logits and targets for cross-entropy.
        loss = self.criterion(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1),
        )
        self.log("train_loss", loss, batch_size=self.hparams.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
        tgt_input = tgt[:-1, :]
        logits = self(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
        )
        tgt_out = tgt[1:, :]
        loss = self.criterion(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1),
        )
        self.log("val_loss", loss, batch_size=self.hparams.batch_size)

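    # AdamW with a one-cycle learning-rate schedule, stepped per batch rather
    # than per epoch.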
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                    optimizer=optimizer,
                    max_lr=self.hparams.lr,
                    total_steps=self.trainer.estimated_stepping_batches,
                ),
                "interval": "step",
            },
        }


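# Fetch the tokenizer and checkpoint artifacts from the Hugging Face Hub and
# wire the translator into a simple Gradio text-to-text interface.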
src_tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="tokenizer-en.json",
)
tgt_tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="tokenizer-hi.json",
)
model_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="transformer.ckpt",
)
translator = Translator(
    src_tokenizer_ckpt_path,
    tgt_tokenizer_ckpt_path,
    model_ckpt_path,
)

interface = gr.Interface(
    fn=translator.predict,
    inputs=gr.components.Textbox(
        label="Source Language (English)",
        placeholder="Enter text here...",
    ),
    outputs=gr.components.Textbox(
        label="Target Language (Hindi)",
        placeholder="Translation",
    ),
    examples=[
        ["Hi how are you?"],
        ["Today is a very important day."],
        ["I like playing the guitar."],
    ],
)
interface.launch()