SatwikKambham committed on
Commit
abea982
·
1 Parent(s): 2cb750c

Add gradio app file

Files changed (1)
  1. app.py +262 -0
app.py ADDED
@@ -0,0 +1,262 @@
import math

import gradio as gr
import lightning as L
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer


class Translator:
    def __init__(
        self,
        src_tokenizer_ckpt_path,
        tgt_tokenizer_ckpt_path,
        model_ckpt_path,
    ):
        self.src_tokenizer = Tokenizer.from_file(src_tokenizer_ckpt_path)
        self.tgt_tokenizer = Tokenizer.from_file(tgt_tokenizer_ckpt_path)

        # Disable BPE dropout so tokenization is deterministic at inference time.
        self.src_tokenizer.model.dropout = 0
        self.tgt_tokenizer.model.dropout = 0

        self.model = TransformerSeq2Seq.load_from_checkpoint(
            model_ckpt_path,
            map_location="cpu",
        )
        self.model.eval()

    def predict(self, src):
        tokenized_text = self.src_tokenizer.encode(src)
        # Shape (seq_len, 1): the model expects sequence-first tensors
        # with a batch size of one.
        src = torch.LongTensor(tokenized_text.ids).view(-1, 1)
        with torch.no_grad():
            tgt = self.model.greedy_decode(src, max_len=100)
        tgt = tgt.squeeze(1).tolist()
        tgt_text = self.tgt_tokenizer.decode(tgt)
        return tgt_text

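# Minimal usage sketch (paths here are illustrative; the real files are
# resolved via hf_hub_download at the bottom of this script):
#   translator = Translator(
#       "tokenizer-en.json", "tokenizer-hi.json", "transformer.ckpt"
#   )
#   print(translator.predict("Hi how are you?"))
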
def generate_square_subsequent_mask(sz):
    # Causal (look-ahead) mask: position i may attend only to positions <= i.
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask

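# For example, generate_square_subsequent_mask(3) yields (0.0 = attend,
# -inf = blocked):
#   tensor([[0., -inf, -inf],
#           [0., 0., -inf],
#           [0., 0., 0.]])
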
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, dropout, maxlen=5000):
        super().__init__()
        den = torch.exp(
            -torch.arange(0, embedding_dim, 2) * math.log(10000) / embedding_dim
        )
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, embedding_dim))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # Insert a batch dimension: (maxlen, 1, embedding_dim).
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding):
        # token_embedding: (seq_len, batch, embedding_dim)
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )

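# The buffer above implements the standard sinusoidal encoding from
# "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# since den[i] = exp(-2i * ln(10000) / d_model) = 10000^(-2i / d_model).
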
class TransformerSeq2Seq(L.LightningModule):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        embedding_dim=512,
        hidden_dim=512,
        dropout=0.1,
        nhead=8,
        num_layers=3,
        batch_size=32,
        lr=1e-4,
        weight_decay=1e-4,
        sos_idx=1,
        eos_idx=2,
        padding_idx=3,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.src_embedding = nn.Embedding(
            src_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.tgt_embedding = nn.Embedding(
            tgt_vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.positional_encoding = PositionalEncoding(
            embedding_dim=embedding_dim,
            dropout=dropout,
        )
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.fc = nn.Linear(embedding_dim, tgt_vocab_size)

        # Xavier initialization for all weight matrices (biases are left as-is).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # Ignore padding positions when computing the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)

    def forward(
        self,
        src,
        tgt,
        src_mask,
        tgt_mask,
        src_padding_mask,
        tgt_padding_mask,
    ):
        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        tgt = self.tgt_embedding(tgt) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        tgt = self.positional_encoding(tgt)
        out = self.transformer(
            src,
            tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        out = self.fc(out)
        return out

    def greedy_decode(self, src, max_len):
        src = self.src_embedding(src) * (self.hparams.embedding_dim**0.5)
        src = self.positional_encoding(src)
        # Encode the source once, then decode one token at a time.
        memory = self.transformer.encoder(src)
        ys = torch.ones(1, 1).fill_(self.hparams.sos_idx).type(torch.long)
        for _ in range(max_len - 1):
            tgt = self.tgt_embedding(ys) * (self.hparams.embedding_dim**0.5)
            tgt = self.positional_encoding(tgt)
            tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool)
            out = self.transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
            )
            out = self.fc(out)
            # Take the logits for the last generated position only.
            out = out.transpose(0, 1)[:, -1]
            prob = out.softmax(dim=-1)
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            ys = torch.cat(
                [ys, torch.ones(1, 1).fill_(next_word).type(torch.long)],
                dim=0,
            )

            if next_word == self.hparams.eos_idx:
                break

        return ys

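    # Illustrative trace of greedy decoding (token ids here are hypothetical):
    #   ys = [[1]]                         start from <sos> (sos_idx=1)
    #   step 1: argmax gives 57  ->  ys = [[1], [57]]
    #   step 2: argmax gives 2   ->  ys = [[1], [57], [2]]  (eos_idx=2, stop)
    # The loop runs at most max_len - 1 steps, so the output length stays
    # bounded even if <eos> is never produced.
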
173
+ def training_step(self, batch, batch_idx):
174
+ src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
175
+ tgt_input = tgt[:-1, :]
176
+ logits = self(
177
+ src,
178
+ tgt_input,
179
+ src_mask,
180
+ tgt_mask,
181
+ src_padding_mask,
182
+ tgt_padding_mask,
183
+ )
184
+ tgt_out = tgt[1:, :]
185
+ loss = self.criteria(
186
+ logits.reshape(-1, logits.shape[-1]),
187
+ tgt_out.reshape(-1),
188
+ )
189
+ self.log("train_loss", loss, batch_size=self.hparams.batch_size)
190
+ return loss
191
+
192
+ def validation_step(self, batch, batch_idx):
193
+ src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = batch
194
+ tgt_input = tgt[:-1, :]
195
+ logits = self(
196
+ src,
197
+ tgt_input,
198
+ src_mask,
199
+ tgt_mask,
200
+ src_padding_mask,
201
+ tgt_padding_mask,
202
+ )
203
+ tgt_out = tgt[1:, :]
204
+ loss = self.criteria(
205
+ logits.reshape(-1, logits.shape[-1]),
206
+ tgt_out.reshape(-1),
207
+ )
208
+ self.log("val_loss", loss, batch_size=self.hparams.batch_size)
209
+
210
+ def configure_optimizers(self):
211
+ optimizer = torch.optim.AdamW(
212
+ self.parameters(),
213
+ lr=self.hparams.lr,
214
+ weight_decay=self.hparams.weight_decay,
215
+ )
216
+ return {
217
+ "optimizer": optimizer,
218
+ "lr_scheduler": {
219
+ "scheduler": torch.optim.lr_scheduler.OneCycleLR(
220
+ optimizer=optimizer,
221
+ max_lr=self.hparams.lr,
222
+ total_steps=self.trainer.estimated_stepping_batches,
223
+ ),
224
+ "interval": "step",
225
+ },
226
+ }
227
+
228
+
229
+ src_tokenizer_ckpt_path = hf_hub_download(
230
+ repo_id="SatwikKambham/opus100-en-hi-transformer",
231
+ filename="tokenizer-en.json",
232
+ )
233
+ tgt_tokenizer_ckpt_path = hf_hub_download(
234
+ repo_id="SatwikKambham/opus100-en-hi-transformer",
235
+ filename="tokenizer-hi.json",
236
+ )
237
+ model_ckpt_path = hf_hub_download(
238
+ repo_id="SatwikKambham/opus100-en-hi-transformer",
239
+ filename="transformer.ckpt",
240
+ )
241
+ classifier = Translator(
242
+ src_tokenizer_ckpt_path,
243
+ tgt_tokenizer_ckpt_path,
244
+ model_ckpt_path,
245
+ )
246
+ interface = gr.Interface(
247
+ fn=classifier.predict,
248
+ inputs=gr.components.Textbox(
249
+ label="Source Language (English)",
250
+ placeholder="Enter text here...",
251
+ ),
252
+ outputs=gr.components.Textbox(
253
+ label="Target Language (Hindi)",
254
+ placeholder="Translation",
255
+ ),
256
+ examples=[
257
+ ["Hi how are you?"],
258
+ ["Today is a very important day."],
259
+ ["I like playing the guitar."],
260
+ ],
261
+ )
262
+ interface.launch()