Commit · abea982
Parent(s): 2cb750c
Add gradio app file
app.py
ADDED
@@ -0,0 +1,262 @@
import math

import gradio as gr
import lightning as L
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer


class Translator:
    def __init__(
        self,
        src_tokenizer_ckpt_path,
        tgt_tokenizer_ckpt_path,
        model_ckpt_path,
    ):
        self.src_tokenizer = Tokenizer.from_file(src_tokenizer_ckpt_path)
        self.tgt_tokenizer = Tokenizer.from_file(tgt_tokenizer_ckpt_path)

        # Disable BPE dropout so tokenization is deterministic at inference time.
        self.src_tokenizer.model.dropout = 0
        self.tgt_tokenizer.model.dropout = 0

        self.model = TransformerSeq2Seq.load_from_checkpoint(
            model_ckpt_path,
            map_location="cpu",
        )
        self.model.eval()

    def predict(self, src):
        tokenized_text = self.src_tokenizer.encode(src)
        # Shape (seq_len, 1): the model expects sequence-first tensors.
        src = torch.LongTensor(tokenized_text.ids).view(-1, 1)
        with torch.no_grad():
            tgt = self.model.greedy_decode(src, max_len=100)
        tgt = tgt.squeeze(1).tolist()
        tgt_text = self.tgt_tokenizer.decode(tgt)
        return tgt_text


def generate_square_subsequent_mask(sz):
    # Causal mask: position i may attend only to positions <= i.
    # Allowed positions hold 0.0, masked positions hold -inf.
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask


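# For illustration, generate_square_subsequent_mask(3) returns
#   tensor([[0., -inf, -inf],
#           [0.,   0., -inf],
#           [0.,   0.,   0.]])
# so each decoding step can only see the tokens generated so far.
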
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, dropout, maxlen=5000):
        super(PositionalEncoding, self).__init__()
        # Standard sinusoidal encoding: sin on even indices, cos on odd ones.
        den = torch.exp(
            -torch.arange(0, embedding_dim, 2) * math.log(10000) / embedding_dim
        )
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, embedding_dim))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # Shape (maxlen, 1, embedding_dim) to broadcast over the batch dimension.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding):
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )


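# For reference, the buffer above implements the usual sinusoidal closed form:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / embedding_dim))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / embedding_dim))
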
class TransformerSeq2Seq(L.LightningModule):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        embedding_dim=512,
        hidden_dim=512,
        dropout=0.1,
        nhead=8,
        num_layers=3,
        batch_size=32,
        lr=1e-4,
        weight_decay=1e-4,
        sos_idx=1,
        eos_idx=2,
        padding_idx=3,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.src_embedding = nn.Embedding(
            src_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.tgt_embedding = nn.Embedding(
            tgt_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.positional_encoding = PositionalEncoding(
            embedding_dim=embedding_dim,
            dropout=dropout,
        )
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.fc = nn.Linear(embedding_dim, tgt_vocab_size)

        # Xavier initialization for all weight matrices (biases keep defaults).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # Ignore padding positions when computing the loss.
        self.criteria = nn.CrossEntropyLoss(ignore_index=padding_idx)

    def forward(
        self,
        src,
        tgt,
        src_mask,
        tgt_mask,
        src_padding_mask,
        tgt_padding_mask,
    ):
        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        tgt = self.tgt_embedding(tgt) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        tgt = self.positional_encoding(tgt)
        out = self.transformer(
            src,
            tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        out = self.fc(out)
        return out

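    # Shape conventions for forward above: nn.Transformer here runs
    # sequence-first, so src is (S, N), tgt is (T, N), and the returned
    # logits are (T, N, tgt_vocab_size).
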
    def greedy_decode(self, src, max_len):
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        # Encode the source once; reuse the memory for every decoding step.
        memory = self.transformer.encoder(src)
        ys = torch.ones(1, 1).fill_(self.hparams.sos_idx).type(torch.long)
        for i in range(max_len - 1):
            tgt = self.tgt_embedding(ys) * (self.hparams.embedding_dim**0.5)
            tgt = self.positional_encoding(tgt)
            tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool)
            out = self.transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
            )
            out = self.fc(out)
            # Greedily take the most likely token at the last position.
            out = out.transpose(0, 1)[:, -1]
            prob = out.softmax(dim=-1)
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            ys = torch.cat(
                [ys, torch.ones(1, 1).fill_(next_word).type(torch.long)],
                dim=0,
            )

            if next_word == self.hparams.eos_idx:
                break

        return ys

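    # For illustration, decoding grows ys one token per iteration, e.g. with
    # hypothetical token ids: [[1]] -> [[1], [742]] -> [[1], [742], [88]] -> ...
    # until eos_idx (2) is produced or max_len is reached.
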
    def training_step(self, batch, batch_idx):
        src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
        # Teacher forcing: feed the target shifted right, predict it shifted left.
        tgt_input = tgt[:-1, :]
        logits = self(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
        )
        tgt_out = tgt[1:, :]
        loss = self.criteria(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1),
        )
        self.log("train_loss", loss, batch_size=self.hparams.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
        tgt_input = tgt[:-1, :]
        logits = self(
            src,
            tgt_input,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
        )
        tgt_out = tgt[1:, :]
        loss = self.criteria(
            logits.reshape(-1, logits.shape[-1]),
            tgt_out.reshape(-1),
        )
        self.log("val_loss", loss, batch_size=self.hparams.batch_size)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        # One-cycle LR schedule, stepped every batch rather than every epoch.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                    optimizer=optimizer,
                    max_lr=self.hparams.lr,
                    total_steps=self.trainer.estimated_stepping_batches,
                ),
                "interval": "step",
            },
        }


# Download tokenizer and model checkpoints from the Hugging Face Hub.
src_tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="tokenizer-en.json",
)
tgt_tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="tokenizer-hi.json",
)
model_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/opus100-en-hi-transformer",
    filename="transformer.ckpt",
)
translator = Translator(
    src_tokenizer_ckpt_path,
    tgt_tokenizer_ckpt_path,
    model_ckpt_path,
)
interface = gr.Interface(
    fn=translator.predict,
    inputs=gr.components.Textbox(
        label="Source Language (English)",
        placeholder="Enter text here...",
    ),
    outputs=gr.components.Textbox(
        label="Target Language (Hindi)",
        placeholder="Translation",
    ),
    examples=[
        ["Hi how are you?"],
        ["Today is a very important day."],
        ["I like playing the guitar."],
    ],
)
interface.launch()
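
For a quick smoke test, the function backing the interface can also be called directly, without the UI (a minimal sketch, assuming the checkpoints above downloaded successfully):

    translator.predict("Hi how are you?")  # returns the decoded Hindi string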