import math

import gradio as gr
import lightning as L
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
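

# Thin inference wrapper: loads the source/target tokenizers and the trained
# checkpoint, and exposes predict(), which the Gradio interface calls.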
class Translator:
    def __init__(
        self,
        src_tokenizer_ckpt_path,
        tgt_tokenizer_ckpt_path,
        model_ckpt_path,
    ):
        self.src_tokenizer = Tokenizer.from_file(src_tokenizer_ckpt_path)
        self.tgt_tokenizer = Tokenizer.from_file(tgt_tokenizer_ckpt_path)
        # Turn off tokenizer dropout so encoding is deterministic at inference time.
        self.src_tokenizer.model.dropout = 0
        self.tgt_tokenizer.model.dropout = 0
        self.model = TransformerSeq2Seq.load_from_checkpoint(
            model_ckpt_path,
            map_location="cpu",
        )
        self.model.eval()

    def predict(self, src):
        tokenized_text = self.src_tokenizer.encode(src)
        # Shape (seq_len, 1): sequence-first with a batch of one.
        src = torch.LongTensor(tokenized_text.ids).view(-1, 1)
        # No gradients are needed for inference.
        with torch.no_grad():
            tgt = self.model.greedy_decode(src, max_len=100)
        tgt = tgt.squeeze(1).tolist()
        tgt_text = self.tgt_tokenizer.decode(tgt)
        return tgt_text
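

# Additive attention mask that blocks each position from attending to
# future positions (0.0 = keep, -inf = mask out).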
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask
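

# Fixed sinusoidal positional encodings, added to the token embeddings
# (sequence-first layout: (seq_len, batch, embedding_dim)).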
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, dropout, maxlen=5000):
        super().__init__()
        den = torch.exp(
            -torch.arange(0, embedding_dim, 2) * math.log(10000) / embedding_dim
        )
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, embedding_dim))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding):
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )
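

# Encoder-decoder Transformer for translation, implemented as a LightningModule.
# Uses PyTorch's nn.Transformer with sequence-first tensors and scales the
# embeddings by sqrt(embedding_dim) before adding positional encodings.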
class TransformerSeq2Seq(L.LightningModule):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        embedding_dim=512,
        hidden_dim=512,
        dropout=0.1,
        nhead=8,
        num_layers=3,
        batch_size=32,
        lr=1e-4,
        weight_decay=1e-4,
        sos_idx=1,
        eos_idx=2,
        padding_idx=3,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.src_embedding = nn.Embedding(
            src_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.tgt_embedding = nn.Embedding(
            tgt_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.positional_encoding = PositionalEncoding(
            embedding_dim=embedding_dim,
            dropout=dropout,
        )
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.fc = nn.Linear(embedding_dim, tgt_vocab_size)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # Ignore padding positions when computing the loss.
        self.criteria = nn.CrossEntropyLoss(ignore_index=padding_idx)

    def forward(
        self,
        src,
        tgt,
        src_mask,
        tgt_mask,
        src_padding_mask,
        tgt_padding_mask,
    ):
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        tgt = self.tgt_embedding(tgt) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        tgt = self.positional_encoding(tgt)
        out = self.transformer(
            src,
            tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        out = self.fc(out)
        return out
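
    # Greedy autoregressive decoding: encode the source once, then repeatedly
    # run the decoder and append the highest-probability next token until the
    # end-of-sequence token is produced or max_len is reached.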
    def greedy_decode(self, src, max_len):
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        memory = self.transformer.encoder(src)
        # Start the target sequence with the start-of-sequence token.
        ys = torch.ones(1, 1).fill_(self.hparams.sos_idx).type(torch.long)
        for _ in range(max_len - 1):
            tgt = self.tgt_embedding(ys) * (self.hparams.embedding_dim**0.5)
            tgt = self.positional_encoding(tgt)
            tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool)
            out = self.transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
            )
            out = self.fc(out)
            # Logits for the last generated position only.
            out = out.transpose(0, 1)[:, -1]
            prob = out.softmax(dim=-1)
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            ys = torch.cat(
                [ys, torch.ones(1, 1).fill_(next_word).type(torch.long)],
                dim=0,
            )
            if next_word == self.hparams.eos_idx:
                break
        return ys

    def training_step(self, batch, batch_idx):
        src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
        # Teacher forcing: the decoder sees the target shifted by one position
        # and learns to predict the next token.
        tgt_input = tgt[:-1, :]
        logits = self(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
        )
        tgt_out = tgt[1:, :]
        loss = self.criteria(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1),
        )
        self.log("train_loss", loss, batch_size=self.hparams.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
        tgt_input = tgt[:-1, :]
        logits = self(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
        )
        tgt_out = tgt[1:, :]
        loss = self.criteria(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1),
        )
        self.log("val_loss", loss, batch_size=self.hparams.batch_size)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                    optimizer=optimizer,
                    max_lr=self.hparams.lr,
                    total_steps=self.trainer.estimated_stepping_batches,
                ),
                "interval": "step",
            },
        }
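

# Download the tokenizer files and model checkpoint from the Hugging Face Hub.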
src_tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="tokenizer-en.json",
)
tgt_tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="tokenizer-hi.json",
)
model_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="transformer.ckpt",
)
translator = Translator(
    src_tokenizer_ckpt_path,
    tgt_tokenizer_ckpt_path,
    model_ckpt_path,
)
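
# Text-in/text-out Gradio UI for English-to-Hindi translation with a few
# example prompts.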
interface = gr.Interface(
    fn=translator.predict,
    inputs=gr.components.Textbox(
        label="Source Language (English)",
        placeholder="Enter text here...",
    ),
    outputs=gr.components.Textbox(
        label="Target Language (Hindi)",
        placeholder="Translation",
    ),
    examples=[
        ["Hi how are you?"],
        ["Today is a very important day."],
        ["I like playing the guitar."],
    ],
)
interface.launch()