SatwikKambham commited on
Commit
5c036af
·
1 Parent(s): 0e28335

Add TREC classifier code and requirements

Browse files
Files changed (2) hide show
  1. app.py +227 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import lightning as L
4
+ import torchmetrics as tm
5
+ from tokenizers import Tokenizer
6
+ import gradio as gr
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ COARSE_LABELS = [
10
+ "ABBR (0): Abbreviation",
11
+ "ENTY (1): Entity",
12
+ "DESC (2): Description and abstract concept",
13
+ "HUM (3): Human being",
14
+ "LOC (4): Location",
15
+ "NUM (5): Numeric value",
16
+ ]
17
+
18
+ FINE_LABELS = [
19
+ "ABBR (0): Abbreviation",
20
+ "ABBR (1): Expression abbreviated",
21
+ "ENTY (2): Animal",
22
+ "ENTY (3): Organ of body",
23
+ "ENTY (4): Color",
24
+ "ENTY (5): Invention, book and other creative piece",
25
+ "ENTY (6): Currency name",
26
+ "ENTY (7): Disease and medicine",
27
+ "ENTY (8): Event",
28
+ "ENTY (9): Food",
29
+ "ENTY (10): Musical instrument",
30
+ "ENTY (11): Language",
31
+ "ENTY (12): Letter like a-z",
32
+ "ENTY (13): Other entity",
33
+ "ENTY (14): Plant",
34
+ "ENTY (15): Product",
35
+ "ENTY (16): Religion",
36
+ "ENTY (17): Sport",
37
+ "ENTY (18): Element and substance",
38
+ "ENTY (19): Symbols and sign",
39
+ "ENTY (20): Techniques and method",
40
+ "ENTY (21): Equivalent term",
41
+ "ENTY (22): Vehicle",
42
+ "ENTY (23): Word with a special property",
43
+ "DESC (24): Definition of something",
44
+ "DESC (25): Description of something",
45
+ "DESC (26): Manner of an action",
46
+ "DESC (27): Reason",
47
+ "HUM (28): Group or organization of persons",
48
+ "HUM (29): Individual",
49
+ "HUM (30): Title of a person",
50
+ "HUM (31): Description of a person",
51
+ "LOC (32): City",
52
+ "LOC (33): Country",
53
+ "LOC (34): Mountain",
54
+ "LOC (35): Other location",
55
+ "LOC (36): State",
56
+ "NUM (37): Postcode or other code",
57
+ "NUM (38): Number of something",
58
+ "NUM (39): Date",
59
+ "NUM (40): Distance, linear measure",
60
+ "NUM (41): Price",
61
+ "NUM (42): Order, rank",
62
+ "NUM (43): Other number",
63
+ "NUM (44): Lasting time of something",
64
+ "NUM (45): Percent, fraction",
65
+ "NUM (46): Speed",
66
+ "NUM (47): Temperature",
67
+ "NUM (48): Size, area and volume",
68
+ "NUM (49): Weight",
69
+ ]
70
+
71
+
72
+ class Classifier:
73
+ def __init__(self, tokenizer_ckpt_path, model_ckpt_path):
74
+ self.tokenizer = Tokenizer.from_file(tokenizer_ckpt_path)
75
+ self.model = LSTMWithAttentionClassifier.load_from_checkpoint(
76
+ model_ckpt_path,
77
+ map_location="cpu",
78
+ )
79
+
80
+ def predict(self, text):
81
+ encoding = self.tokenizer.encode(text)
82
+ ids = torch.tensor([encoding.ids])
83
+ logits, _ = self.model(ids)
84
+ probs = torch.softmax(logits, dim=1).squeeze().tolist()
85
+ return {
86
+ category: prob
87
+ for category, prob in zip(
88
+ FINE_LABELS if self.model.fine else COARSE_LABELS, probs
89
+ )
90
+ }
91
+
92
+
93
+ class Attention(nn.Module):
94
+ def __init__(self, hidden_dim):
95
+ super().__init__()
96
+ self.WQuery = nn.Linear(hidden_dim, hidden_dim)
97
+ self.WKey = nn.Linear(hidden_dim, hidden_dim)
98
+ self.WValue = nn.Linear(hidden_dim, 1)
99
+
100
+ def forward(self, x):
101
+ query = torch.tanh(self.WQuery(x))
102
+ key = torch.tanh(self.WKey(x))
103
+
104
+ attention_weights = torch.softmax(self.WValue(query + key), dim=1)
105
+
106
+ return (attention_weights * x).sum(dim=1), attention_weights
107
+
108
+
109
+ class LSTMWithAttentionClassifier(L.LightningModule):
110
+ def __init__(
111
+ self,
112
+ vocab_size,
113
+ embedding_dim,
114
+ hidden_dim,
115
+ num_classes,
116
+ lr=1e-3,
117
+ weight_decay=1e-2,
118
+ num_layers=1,
119
+ bidirectional=False,
120
+ dropout=0.0,
121
+ padding_idx=3,
122
+ fine=False,
123
+ **kwargs,
124
+ ):
125
+ super().__init__()
126
+ self.save_hyperparameters()
127
+ self.lr = lr
128
+ self.weight_decay = weight_decay
129
+ self.fine = fine
130
+
131
+ self.embedding = nn.Embedding(
132
+ vocab_size,
133
+ embedding_dim,
134
+ padding_idx=padding_idx,
135
+ )
136
+ self.lstm = nn.LSTM(
137
+ embedding_dim,
138
+ hidden_dim,
139
+ num_layers=num_layers,
140
+ batch_first=True,
141
+ bidirectional=bidirectional,
142
+ dropout=dropout,
143
+ )
144
+ self.attention = Attention(
145
+ hidden_dim * (1 + bidirectional),
146
+ )
147
+ self.fc = nn.Linear(
148
+ hidden_dim * (1 + bidirectional),
149
+ num_classes,
150
+ )
151
+
152
+ self.criteria = nn.CrossEntropyLoss()
153
+ self.accuracy = tm.Accuracy(
154
+ task="multiclass",
155
+ num_classes=num_classes,
156
+ )
157
+
158
+ def forward(self, input_ids):
159
+ x = self.embedding(input_ids)
160
+ x, _ = self.lstm(x)
161
+ x, attention_weights = self.attention(x)
162
+ x = self.fc(x)
163
+ return x, attention_weights
164
+
165
+ def training_step(self, batch, batch_idx):
166
+ input_ids = batch["input_ids"]
167
+ coarse = batch["coarse"]
168
+ fine = batch["fine"]
169
+ logits, _ = self(input_ids)
170
+ loss = self.criteria(logits, fine if self.fine else coarse)
171
+ self.log("train_loss", loss)
172
+ return loss
173
+
174
+ def validation_step(self, batch, batch_idx):
175
+ input_ids = batch["input_ids"]
176
+ coarse = batch["coarse"]
177
+ fine = batch["fine"]
178
+ logits, _ = self(input_ids)
179
+ loss = self.criteria(logits, fine if self.fine else coarse)
180
+ self.log("val_loss", loss)
181
+ pred = logits.argmax(dim=1)
182
+ self.accuracy(pred, fine if self.fine else coarse)
183
+ self.log("val_acc", self.accuracy, prog_bar=True)
184
+
185
+ def configure_optimizers(self):
186
+ return torch.optim.AdamW(
187
+ self.parameters(),
188
+ lr=self.lr,
189
+ weight_decay=self.weight_decay,
190
+ )
191
+
192
+
193
+ tokenizer_ckpt_path = hf_hub_download(
194
+ repo_id="SatwikKambham/trec-classifier",
195
+ filename="tokenizer.json",
196
+ )
197
+ model_ckpt_path = hf_hub_download(
198
+ repo_id="SatwikKambham/trec-classifier",
199
+ filename="lstm_attention.ckpt",
200
+ )
201
+ classifier = Classifier(tokenizer_ckpt_path, model_ckpt_path)
202
+ interface = gr.Interface(
203
+ fn=classifier.predict,
204
+ inputs=gr.components.Textbox(
205
+ label="Question",
206
+ placeholder="Enter a question here...",
207
+ ),
208
+ outputs=gr.components.Label(
209
+ label="Predicted class",
210
+ num_top_classes=3,
211
+ ),
212
+ examples=[
213
+ [
214
+ "What does LOL mean?",
215
+ ],
216
+ [
217
+ "What is the meaning of life?",
218
+ ],
219
+ [
220
+ "How long does it take for light from the sun to reach the earth?",
221
+ ],
222
+ [
223
+ "When is friendship day?",
224
+ ],
225
+ ],
226
+ )
227
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ tokenizers
3
+ lightning
4
+ torchmetrics
5
+ huggingface_hub