import torch
import torch.nn as nn
import lightning as L
import torchmetrics as tm
from tokenizers import Tokenizer
import gradio as gr
from huggingface_hub import hf_hub_download

# TREC question-type labels: 6 coarse categories and 50 fine-grained ones.
COARSE_LABELS = [
    "ABBR (0): Abbreviation",
    "ENTY (1): Entity",
    "DESC (2): Description and abstract concept",
    "HUM (3): Human being",
    "LOC (4): Location",
    "NUM (5): Numeric value",
]
FINE_LABELS = [
    "ABBR (0): Abbreviation",
    "ABBR (1): Expression abbreviated",
    "ENTY (2): Animal",
    "ENTY (3): Organ of body",
    "ENTY (4): Color",
    "ENTY (5): Invention, book and other creative piece",
    "ENTY (6): Currency name",
    "ENTY (7): Disease and medicine",
    "ENTY (8): Event",
    "ENTY (9): Food",
    "ENTY (10): Musical instrument",
    "ENTY (11): Language",
    "ENTY (12): Letter like a-z",
    "ENTY (13): Other entity",
    "ENTY (14): Plant",
    "ENTY (15): Product",
    "ENTY (16): Religion",
    "ENTY (17): Sport",
    "ENTY (18): Element and substance",
    "ENTY (19): Symbols and sign",
    "ENTY (20): Techniques and method",
    "ENTY (21): Equivalent term",
    "ENTY (22): Vehicle",
    "ENTY (23): Word with a special property",
    "DESC (24): Definition of something",
    "DESC (25): Description of something",
    "DESC (26): Manner of an action",
    "DESC (27): Reason",
    "HUM (28): Group or organization of persons",
    "HUM (29): Individual",
    "HUM (30): Title of a person",
    "HUM (31): Description of a person",
    "LOC (32): City",
    "LOC (33): Country",
    "LOC (34): Mountain",
    "LOC (35): Other location",
    "LOC (36): State",
    "NUM (37): Postcode or other code",
    "NUM (38): Number of something",
    "NUM (39): Date",
    "NUM (40): Distance, linear measure",
    "NUM (41): Price",
    "NUM (42): Order, rank",
    "NUM (43): Other number",
    "NUM (44): Lasting time of something",
    "NUM (45): Percent, fraction",
    "NUM (46): Speed",
    "NUM (47): Temperature",
    "NUM (48): Size, area and volume",
    "NUM (49): Weight",
]


class Classifier:
    """Wraps the tokenizer and trained model for single-question inference."""

    def __init__(self, tokenizer_ckpt_path, model_ckpt_path):
        self.tokenizer = Tokenizer.from_file(tokenizer_ckpt_path)
        self.model = LSTMWithAttentionClassifier.load_from_checkpoint(
            model_ckpt_path,
            map_location="cpu",
        )
        # Switch to eval mode so dropout is disabled during inference.
        self.model.eval()

    def predict(self, text):
        encoding = self.tokenizer.encode(text)
        ids = torch.tensor([encoding.ids])
        with torch.inference_mode():
            logits, _ = self.model(ids)
        probs = torch.softmax(logits, dim=1).squeeze().tolist()
        # Map each label string to its predicted probability for gr.Label.
        return {
            category: prob
            for category, prob in zip(
                FINE_LABELS if self.model.fine else COARSE_LABELS, probs
            )
        }


class Attention(nn.Module):
    """Additive (tanh) attention that pools LSTM outputs over the sequence."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.WQuery = nn.Linear(hidden_dim, hidden_dim)
        self.WKey = nn.Linear(hidden_dim, hidden_dim)
        self.WValue = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        query = torch.tanh(self.WQuery(x))
        key = torch.tanh(self.WKey(x))
        # Per-timestep scores, normalized across the sequence dimension.
        attention_weights = torch.softmax(self.WValue(query + key), dim=1)
        # Weighted sum over time -> (batch, hidden_dim); also return the
        # weights themselves so callers can inspect them.
        return (attention_weights * x).sum(dim=1), attention_weights
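
# Quick shape sanity check for the attention block (a minimal sketch with
# illustrative dimensions, unrelated to the trained checkpoint): pooling a
# (batch, seq_len, hidden) input should yield a (batch, hidden) summary plus
# one normalized weight per timestep. Skipped when Python runs with -O.
if __debug__:
    _attn = Attention(hidden_dim=8)
    _pooled, _weights = _attn(torch.randn(2, 5, 8))
    assert _pooled.shape == (2, 8)
    assert _weights.shape == (2, 5, 1)
    del _attn, _pooled, _weights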
class LSTMWithAttentionClassifier(L.LightningModule):
    """(Bi)LSTM encoder with attention pooling and a linear classifier head."""

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_dim,
        num_classes,
        lr=1e-3,
        weight_decay=1e-2,
        num_layers=1,
        bidirectional=False,
        dropout=0.0,
        padding_idx=3,
        fine=False,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay
        self.fine = fine
        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        # A bidirectional LSTM doubles the output feature dimension.
        lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
        self.attention = Attention(lstm_out_dim)
        self.fc = nn.Linear(lstm_out_dim, num_classes)
        self.criteria = nn.CrossEntropyLoss()
        self.accuracy = tm.Accuracy(
            task="multiclass",
            num_classes=num_classes,
        )

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x, _ = self.lstm(x)
        x, attention_weights = self.attention(x)
        x = self.fc(x)
        return x, attention_weights

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        coarse = batch["coarse"]
        fine = batch["fine"]
        logits, _ = self(input_ids)
        # Train against fine or coarse labels depending on configuration.
        loss = self.criteria(logits, fine if self.fine else coarse)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        coarse = batch["coarse"]
        fine = batch["fine"]
        logits, _ = self(input_ids)
        loss = self.criteria(logits, fine if self.fine else coarse)
        self.log("val_loss", loss)
        pred = logits.argmax(dim=1)
        self.accuracy(pred, fine if self.fine else coarse)
        self.log("val_acc", self.accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )


# Fetch the tokenizer and model checkpoint from the Hugging Face Hub.
tokenizer_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/trec-classifier",
    filename="tokenizer.json",
)
model_ckpt_path = hf_hub_download(
    repo_id="SatwikKambham/trec-classifier",
    filename="lstm_attention.ckpt",
)
classifier = Classifier(tokenizer_ckpt_path, model_ckpt_path)

interface = gr.Interface(
    fn=classifier.predict,
    inputs=gr.components.Textbox(
        label="Question",
        placeholder="Enter a question here...",
    ),
    outputs=gr.components.Label(
        label="Predicted class",
        num_top_classes=3,
    ),
    examples=[
        ["What does LOL mean?"],
        ["What is the meaning of life?"],
        ["How long does it take for light from the sun to reach the earth?"],
        ["When is friendship day?"],
    ],
)
interface.launch()
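
# A sketch of inspecting the per-token attention weights that predict()
# discards (hypothetical helper, shown commented out for reference; note that
# interface.launch() above blocks, so to actually use this, define it before
# the launch call or run it in a separate session):
#
#   def show_attention(text):
#       encoding = classifier.tokenizer.encode(text)
#       ids = torch.tensor([encoding.ids])
#       with torch.inference_mode():
#           _, weights = classifier.model(ids)
#       for token, weight in zip(encoding.tokens, weights.squeeze().tolist()):
#           print(f"{token}\t{weight:.3f}")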