Hunter-Pax commited on
Commit
f62f5b8
·
verified ·
1 Parent(s): 2a9c553

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -108
app.py CHANGED
@@ -1,108 +1,107 @@
1
- import gradio as gr
2
- import torch
3
- import torch.nn.functional as F
4
- from transformers import AutoTokenizer
5
- import pickle
6
-
7
- from models.rnn import RNNClassifier
8
- from models.lstm import LSTMClassifier
9
- from models.transformer import TransformerClassifier
10
- from utility import simple_tokenizer
11
-
12
- # =========================
13
- # Load models and vocab
14
- # =========================
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- model_name = "prajjwal1/bert-tiny"
17
-
18
- def load_vocab():
19
- with open("pretrained_models/vocab.pkl", "rb") as f:
20
- return pickle.load(f)
21
-
22
- def load_models(vocab_size, output_dim=6, padding_idx=0):
23
- rnn_model = RNNClassifier(vocab_size, 128, 128, output_dim, padding_idx)
24
- rnn_model.load_state_dict(torch.load("pretrained_models/best_rnn.pt"))
25
- rnn_model = rnn_model.to(device)
26
- rnn_model.eval()
27
-
28
- lstm_model = LSTMClassifier(vocab_size, 128, 128, output_dim, padding_idx)
29
- lstm_model.load_state_dict(torch.load("pretrained_models/best_lstm.pt"))
30
- lstm_model = lstm_model.to(device)
31
- lstm_model.eval()
32
-
33
- transformer_model = TransformerClassifier(model_name, output_dim)
34
- transformer_model.load_state_dict(torch.load("pretrained_models/best_transformer.pt", map_location=device))
35
- transformer_model = transformer_model.to(device)
36
- transformer_model.eval()
37
-
38
- return rnn_model, lstm_model, transformer_model
39
-
40
-
41
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
- vocab = load_vocab()
43
- tokenizer = AutoTokenizer.from_pretrained(model_name)
44
- rnn_model, lstm_model, transformer_model = load_models(len(vocab))
45
-
46
- emotions = ["anger", "fear", "joy", "love", "sadness", "surprise"]
47
-
48
- def predict(model, text, model_type, vocab, tokenizer=None, max_length=32):
49
- if model_type in ["rnn", "lstm"]:
50
- # Match collate_fn_rnn but with no random truncation
51
- tokens = simple_tokenizer(text)
52
- ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
53
-
54
- if len(ids) < max_length:
55
- ids += [vocab["<PAD>"]] * (max_length - len(ids))
56
- else:
57
- ids = ids[:max_length]
58
-
59
- input_ids = torch.tensor([ids], dtype=torch.long).to(device)
60
- outputs = model(input_ids)
61
-
62
- else:
63
- # Match collate_fn_transformer but with no partial_prob
64
- encoding = tokenizer(
65
- text,
66
- padding="max_length",
67
- truncation=True,
68
- max_length=128,
69
- return_tensors="pt"
70
- )
71
- input_ids = encoding["input_ids"].to(device)
72
- attention_mask = encoding["attention_mask"].to(device)
73
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
74
-
75
- probs = F.softmax(outputs, dim=-1)
76
- return probs.squeeze().detach().cpu().numpy()
77
-
78
- # =========================
79
- # Gradio App
80
- # =========================
81
-
82
- def emotion_typeahead(text):
83
- if len(text.strip()) <= 2:
84
- return {}, {}, {}
85
-
86
- rnn_probs = predict(rnn_model, text.strip(), "rnn", vocab)
87
- lstm_probs = predict(lstm_model, text.strip(), "lstm", vocab)
88
- transformer_probs = predict(transformer_model, text.strip(), "transformer", vocab, tokenizer)
89
-
90
- rnn_dict = {emo: float(prob) for emo, prob in zip(emotions, rnn_probs)}
91
- lstm_dict = {emo: float(prob) for emo, prob in zip(emotions, lstm_probs)}
92
- transformer_dict = {emo: float(prob) for emo, prob in zip(emotions, transformer_probs)}
93
-
94
- return rnn_dict, lstm_dict, transformer_dict
95
-
96
- with gr.Blocks() as demo:
97
- gr.Markdown("## 🎯 Emotion Typeahead Predictor (RNN, LSTM, Transformer)")
98
-
99
- text_input = gr.Textbox(label="Type your sentence here...")
100
-
101
- with gr.Row():
102
- rnn_output = gr.Label(label="🧠 RNN Prediction")
103
- lstm_output = gr.Label(label="🧠 LSTM Prediction")
104
- transformer_output = gr.Label(label="🧠 Transformer Prediction")
105
-
106
- text_input.change(emotion_typeahead, inputs=text_input, outputs=[rnn_output, lstm_output, transformer_output])
107
-
108
- demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers import AutoTokenizer
5
+ import pickle
6
+
7
+ from models.rnn import RNNClassifier
8
+ from models.lstm import LSTMClassifier
9
+ from models.transformer import TransformerClassifier
10
+ from utility import simple_tokenizer
11
+
12
+ # =========================
13
+ # Load models and vocab
14
+ # =========================
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ model_name = "prajjwal1/bert-tiny"
17
+
18
+ def load_vocab():
19
+ with open("pretrained_models/vocab.pkl", "rb") as f:
20
+ return pickle.load(f)
21
+
22
+ def load_models(vocab_size, output_dim=6, padding_idx=0):
23
+ rnn_model = RNNClassifier(vocab_size, 128, 128, output_dim, padding_idx)
24
+ rnn_model.load_state_dict(torch.load("pretrained_models/best_rnn.pt", map_location=device))
25
+ rnn_model = rnn_model.to(device)
26
+ rnn_model.eval()
27
+
28
+ lstm_model = LSTMClassifier(vocab_size, 128, 128, output_dim, padding_idx)
29
+ lstm_model.load_state_dict(torch.load("pretrained_models/best_lstm.pt", map_location=device))
30
+ lstm_model = lstm_model.to(device)
31
+ lstm_model.eval()
32
+
33
+ transformer_model = TransformerClassifier(model_name, output_dim)
34
+ transformer_model.load_state_dict(torch.load("pretrained_models/best_transformer.pt", map_location=device))
35
+ transformer_model = transformer_model.to(device)
36
+ transformer_model.eval()
37
+
38
+ return rnn_model, lstm_model, transformer_model
39
+
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+ vocab = load_vocab()
42
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
43
+ rnn_model, lstm_model, transformer_model = load_models(len(vocab))
44
+
45
+ emotions = ["anger", "fear", "joy", "love", "sadness", "surprise"]
46
+
47
+ def predict(model, text, model_type, vocab, tokenizer=None, max_length=32):
48
+ if model_type in ["rnn", "lstm"]:
49
+ # Match collate_fn_rnn but with no random truncation
50
+ tokens = simple_tokenizer(text)
51
+ ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
52
+
53
+ if len(ids) < max_length:
54
+ ids += [vocab["<PAD>"]] * (max_length - len(ids))
55
+ else:
56
+ ids = ids[:max_length]
57
+
58
+ input_ids = torch.tensor([ids], dtype=torch.long).to(device)
59
+ outputs = model(input_ids)
60
+
61
+ else:
62
+ # Match collate_fn_transformer but with no partial_prob
63
+ encoding = tokenizer(
64
+ text,
65
+ padding="max_length",
66
+ truncation=True,
67
+ max_length=128,
68
+ return_tensors="pt"
69
+ )
70
+ input_ids = encoding["input_ids"].to(device)
71
+ attention_mask = encoding["attention_mask"].to(device)
72
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
73
+
74
+ probs = F.softmax(outputs, dim=-1)
75
+ return probs.squeeze().detach().cpu().numpy()
76
+
77
+ # =========================
78
+ # Gradio App
79
+ # =========================
80
+
81
+ def emotion_typeahead(text):
82
+ if len(text.strip()) <= 2:
83
+ return {}, {}, {}
84
+
85
+ rnn_probs = predict(rnn_model, text.strip(), "rnn", vocab)
86
+ lstm_probs = predict(lstm_model, text.strip(), "lstm", vocab)
87
+ transformer_probs = predict(transformer_model, text.strip(), "transformer", vocab, tokenizer)
88
+
89
+ rnn_dict = {emo: float(prob) for emo, prob in zip(emotions, rnn_probs)}
90
+ lstm_dict = {emo: float(prob) for emo, prob in zip(emotions, lstm_probs)}
91
+ transformer_dict = {emo: float(prob) for emo, prob in zip(emotions, transformer_probs)}
92
+
93
+ return rnn_dict, lstm_dict, transformer_dict
94
+
95
+ with gr.Blocks() as demo:
96
+ gr.Markdown("## 🎯 Emotion Typeahead Predictor (RNN, LSTM, Transformer)")
97
+
98
+ text_input = gr.Textbox(label="Type your sentence here...")
99
+
100
+ with gr.Row():
101
+ rnn_output = gr.Label(label="🧠 RNN Prediction")
102
+ lstm_output = gr.Label(label="🧠 LSTM Prediction")
103
+ transformer_output = gr.Label(label="🧠 Transformer Prediction")
104
+
105
+ text_input.change(emotion_typeahead, inputs=text_input, outputs=[rnn_output, lstm_output, transformer_output])
106
+
107
+ demo.launch()