import torch
import torch.nn as nn
import lightning as L
import torchmetrics as tm
from tokenizers import Tokenizer
import gradio as gr
from huggingface_hub import hf_hub_download
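
# Gradio demo: classify a question by the type of answer it expects (the
# TREC question-classification task). An LSTM-with-attention model trained
# with Lightning is loaded from checkpoints on the Hugging Face Hub.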

COARSE_LABELS = [
    "ABBR (0): Abbreviation",
    "ENTY (1): Entity",
    "DESC (2): Description and abstract concept",
    "HUM (3): Human being",
    "LOC (4): Location",
    "NUM (5): Numeric value",
]

FINE_LABELS = [
    "ABBR (0): Abbreviation",
    "ABBR (1): Expression abbreviated",
    "ENTY (2): Animal",
    "ENTY (3): Organ of body",
    "ENTY (4): Color",
    "ENTY (5): Invention, book and other creative piece",
    "ENTY (6): Currency name",
    "ENTY (7): Disease and medicine",
    "ENTY (8): Event",
    "ENTY (9): Food",
    "ENTY (10): Musical instrument",
    "ENTY (11): Language",
    "ENTY (12): Letter like a-z",
    "ENTY (13): Other entity",
    "ENTY (14): Plant",
    "ENTY (15): Product",
    "ENTY (16): Religion",
    "ENTY (17): Sport",
    "ENTY (18): Element and substance",
    "ENTY (19): Symbols and sign",
    "ENTY (20): Techniques and method",
    "ENTY (21): Equivalent term",
    "ENTY (22): Vehicle",
    "ENTY (23): Word with a special property",
    "DESC (24): Definition of something",
    "DESC (25): Description of something",
    "DESC (26): Manner of an action",
    "DESC (27): Reason",
    "HUM (28): Group or organization of persons",
    "HUM (29): Individual",
    "HUM (30): Title of a person",
    "HUM (31): Description of a person",
    "LOC (32): City",
    "LOC (33): Country",
    "LOC (34): Mountain",
    "LOC (35): Other location",
    "LOC (36): State",
    "NUM (37): Postcode or other code",
    "NUM (38): Number of something",
    "NUM (39): Date",
    "NUM (40): Distance, linear measure",
    "NUM (41): Price",
    "NUM (42): Order, rank",
    "NUM (43): Other number",
    "NUM (44): Lasting time of something",
    "NUM (45): Percent, fraction",
    "NUM (46): Speed",
    "NUM (47): Temperature",
    "NUM (48): Size, area and volume",
    "NUM (49): Weight",
]
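
# The two lists above mirror the coarse (6-way) and fine (50-way) label
# sets of the TREC question-classification dataset; the index in
# parentheses is the class id the model predicts.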


class Classifier:
    """Bundles the tokenizer and trained model behind a single predict()."""

    def __init__(self, tokenizer_ckpt_path, model_ckpt_path):
        self.tokenizer = Tokenizer.from_file(tokenizer_ckpt_path)
        self.model = LSTMWithAttentionClassifier.load_from_checkpoint(
            model_ckpt_path,
            map_location="cpu",
        )
        self.model.eval()  # disable dropout for inference

    def predict(self, text):
        encoding = self.tokenizer.encode(text)
        ids = torch.tensor([encoding.ids])
        with torch.inference_mode():  # no gradients needed at inference
            logits, _ = self.model(ids)
        probs = torch.softmax(logits, dim=1).squeeze().tolist()
        # The checkpoint's `fine` hyperparameter decides which label set
        # the class ids map to.
        return {
            category: prob
            for category, prob in zip(
                FINE_LABELS if self.model.fine else COARSE_LABELS, probs
            )
        }


class Attention(nn.Module):
    """Additive-style attention that pools per-token hidden states into a
    single vector and also returns the attention weights for inspection."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.WQuery = nn.Linear(hidden_dim, hidden_dim)
        self.WKey = nn.Linear(hidden_dim, hidden_dim)
        self.WValue = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        query = torch.tanh(self.WQuery(x))
        key = torch.tanh(self.WKey(x))

        # One scalar score per time step, normalized over the sequence.
        attention_weights = torch.softmax(self.WValue(query + key), dim=1)

        # Weighted sum over time steps: (batch, hidden_dim).
        return (attention_weights * x).sum(dim=1), attention_weights
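
# Shape sanity check (a minimal sketch; dimensions are illustrative):
#   attn = Attention(hidden_dim=128)
#   pooled, weights = attn(torch.randn(4, 20, 128))
#   pooled.shape  -> torch.Size([4, 128])
#   weights.shape -> torch.Size([4, 20, 1])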


class LSTMWithAttentionClassifier(L.LightningModule):
    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_dim,
        num_classes,
        lr=1e-3,
        weight_decay=1e-2,
        num_layers=1,
        bidirectional=False,
        dropout=0.0,
        padding_idx=3,
        fine=False,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay
        self.fine = fine  # train/predict on fine labels instead of coarse

        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # Note: LSTM dropout only applies between layers, so it has
            # no effect unless num_layers > 1.
            dropout=dropout,
        )
        # A bidirectional LSTM doubles the feature dimension.
        feature_dim = hidden_dim * (2 if bidirectional else 1)
        self.attention = Attention(feature_dim)
        self.fc = nn.Linear(feature_dim, num_classes)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = tm.Accuracy(
            task="multiclass",
            num_classes=num_classes,
        )

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x, _ = self.lstm(x)
        x, attention_weights = self.attention(x)
        x = self.fc(x)
        return x, attention_weights

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        coarse = batch["coarse"]
        fine = batch["fine"]
        logits, _ = self(input_ids)
        loss = self.criterion(logits, fine if self.fine else coarse)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        coarse = batch["coarse"]
        fine = batch["fine"]
        logits, _ = self(input_ids)
        loss = self.criterion(logits, fine if self.fine else coarse)
        self.log("val_loss", loss)
        pred = logits.argmax(dim=1)
        self.accuracy(pred, fine if self.fine else coarse)
        self.log("val_acc", self.accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
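
# Minimal training sketch for reference (hypothetical; assumes DataLoaders
# yielding dicts with "input_ids", "coarse", and "fine" keys):
#   model = LSTMWithAttentionClassifier(
#       vocab_size=tokenizer.get_vocab_size(),
#       embedding_dim=128,
#       hidden_dim=128,
#       num_classes=len(COARSE_LABELS),
#   )
#   L.Trainer(max_epochs=10).fit(model, train_loader, val_loader)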


# Fetch the tokenizer and model checkpoints from the Hugging Face Hub.
tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/trec-classifier",
    filename="tokenizer.json",
)
model_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/trec-classifier",
    filename="lstm_attention.ckpt",
)
classifier = Classifier(tokenizer_ckpt_path, model_ckpt_path)
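
# Example call (illustrative; actual probabilities depend on the trained
# checkpoint):
#   classifier.predict("Who wrote Hamlet?")  -> {label: probability, ...}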

interface = gr.Interface(
    fn=classifier.predict,
    inputs=gr.components.Textbox(
        label="Question",
        placeholder="Enter a question here...",
    ),
    outputs=gr.components.Label(
        label="Predicted class",
        num_top_classes=3,  # show only the top-3 predicted classes
    ),
    examples=[
        ["What does LOL mean?"],
        ["What is the meaning of life?"],
        ["How long does it take for light from the sun to reach the earth?"],
        ["When is friendship day?"],
    ],
)
interface.launch()