Ashwin B committed on
Commit 0b6b733 · 1 Parent(s): 9afcdc1

Move project to Hugging Space

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. README.md +135 -9
  3. app/.DS_Store +0 -0
  4. app/app.py +62 -0
  5. notebooks/.DS_Store +0 -0
  6. notebooks/.ipynb_checkpoints/01_exploration-checkpoint.ipynb +358 -0
  7. notebooks/.ipynb_checkpoints/02_training-checkpoint.ipynb +368 -0
  8. notebooks/.ipynb_checkpoints/03_evaluation-checkpoint.ipynb +0 -0
  9. notebooks/.ipynb_checkpoints/04_model_comparison-checkpoint.ipynb +290 -0
  10. notebooks/01_exploration.ipynb +358 -0
  11. notebooks/02_training.ipynb +368 -0
  12. notebooks/03_evaluation.ipynb +0 -0
  13. notebooks/04_model_comparison.ipynb +290 -0
  14. outputs/.DS_Store +0 -0
  15. outputs/interpretations/.DS_Store +0 -0
  16. outputs/interpretations/sample_1_disapproval_bar.png +0 -0
  17. outputs/interpretations/sample_1_disapproval_heatmap.png +0 -0
  18. outputs/interpretations/sample_2_neutral_bar.png +0 -0
  19. outputs/interpretations/sample_2_neutral_heatmap.png +0 -0
  20. outputs/interpretations/sample_3_neutral_bar.png +0 -0
  21. outputs/interpretations/sample_3_neutral_heatmap.png +0 -0
  22. outputs/interpretations/sample_4_sadness_bar.png +0 -0
  23. outputs/interpretations/sample_4_sadness_heatmap.png +0 -0
  24. outputs/interpretations/sample_5_neutral_bar.png +0 -0
  25. outputs/interpretations/sample_5_neutral_heatmap.png +0 -0
  26. outputs/metrics/.DS_Store +0 -0
  27. outputs/metrics/.gitkeep +0 -0
  28. outputs/metrics/.ipynb_checkpoints/confusion_matrix-checkpoint.png +0 -0
  29. outputs/metrics/.ipynb_checkpoints/report-checkpoint.json +183 -0
  30. outputs/metrics/report.json +183 -0
  31. outputs/model-old/.gitkeep +0 -0
  32. outputs/model-old/.ipynb_checkpoints/special_tokens_map-checkpoint.json +7 -0
  33. outputs/model-old/config.json +87 -0
  34. outputs/model-old/model.safetensors +3 -0
  35. outputs/model-old/special_tokens_map.json +7 -0
  36. outputs/model-old/tokenizer_config.json +57 -0
  37. outputs/model-old/vocab.txt +0 -0
  38. outputs/model/config.json +87 -0
  39. outputs/model/model.safetensors +3 -0
  40. outputs/model/special_tokens_map.json +37 -0
  41. outputs/model/tokenizer.json +0 -0
  42. outputs/model/tokenizer_config.json +57 -0
  43. outputs/model/vocab.txt +0 -0
  44. requirements.txt +9 -0
  45. save_clean_model.py +12 -0
  46. src/__pycache__/data_loader.cpython-312.pyc +0 -0
  47. src/__pycache__/evaluate.cpython-312.pyc +0 -0
  48. src/__pycache__/model.cpython-312.pyc +0 -0
  49. src/__pycache__/model_custom.cpython-312.pyc +0 -0
  50. src/__pycache__/model_hartmann.cpython-312.pyc +0 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
README.md CHANGED
@@ -1,12 +1,138 @@
  ---
- title: Emotion Classifier Nlp
- emoji: 📚
- colorFrom: indigo
- colorTo: gray
- sdk: gradio
- sdk_version: 5.29.0
- app_file: app.py
- pinned: false
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Emotion Classifier (NLP)
+
+ A simple NLP emotion classification app that uses a transformer model fine-tuned on the GoEmotions dataset to predict the emotion conveyed in a given sentence.
+
+ ## Project Structure
+
+ Emotion-Classifier-NLP/
+ ├── notebooks/
+ │   ├── 01_exploration.ipynb
+ │   ├── 02_training.ipynb
+ │   ├── 03_evaluation.ipynb
+ │   └── 04_model_comparison.ipynb
+ ├── src/
+ │   ├── data_loader.py
+ │   ├── model.py
+ │   ├── model_hartmann.py
+ │   ├── model_custom.py
+ │   ├── train.py
+ │   └── evaluate.py
+ ├── app/
+ │   └── app.py
+ ├── outputs/
+ │   ├── model/             # Trained model
+ │   ├── metrics/           # Evaluation metrics
+ │   └── interpretations/   # Integrated gradients visualization plots
+ ├── requirements.txt
+ ├── README.md
+ └── .gitignore
+
+ ## Features
+ - Uses Hugging Face Transformers with a fine-tuned model
+ - Classifies emotion from text across 28 GoEmotions labels
+ - Streamlit frontend for interactive use
+ - Displays model prediction probabilities
+ - Shows sample integrated gradients visualizations per label (optional)
+
+ ## Running the App Locally
+ ### Make sure you have Streamlit and other dependencies installed:
+
+ pip install -r requirements.txt
+
+ ### Then start the app using:
+
+ streamlit run app/app.py
+
+ ## Running the App in a Web Browser
+
+ Visit https://emotion-classifier-nlp.streamlit.app/ to run the app online.
+
+ ## Example Output
+ - Input: "I find this funny"
+ - Output: Predicted Emotion: `amusement`
+ - Shows prediction probabilities across all 28 classes
+
+ ## Notes
+ - The fine-tuned model is saved in `outputs/model/`
+ - Integrated Gradients plots should be saved under `outputs/interpretations/` and named using the formats `sample_{n}_{label}_bar.png` and `sample_{n}_{label}_heatmap.png`
+
+
+ # Performance Metrics
+
+ | Metric | Score |
+ |------------|-------|
+ | Accuracy | 60.2% |
+ | Macro F1 | 48.3% |
+ | Weighted F1| 59.6% |
+
+ ### Confusion matrix
+
+ ![confusion_matrix](https://github.com/user-attachments/assets/f571bafa-daa9-4cf2-88cd-2bad069d187a)
+
+
+
+ # Model Comparison (Hartmann vs Custom Model)
+
+ | Sample Sentence | Hartmann Prediction(s) | Custom Model Prediction(s) |
+ |-----------------------------------------------------|-----------------------------|-----------------------------|
+ | I love spending time with my family. | joy, sadness, disgust | love, joy, admiration |
+ | This is the worst day of my life. | disgust, anger, sadness | anger, surprise, disgust |
+ | I'm feeling very nervous about the exam. | fear, sadness, joy | nervousness, fear, embarrassment |
+ | What a beautiful sunset! | joy, surprise, neutral | admiration, excitement, joy |
+ | I feel so disappointed and frustrated. | sadness, anger, disgust | disappointment, annoyance, anger |
+ | I'm not sure how to feel about this. | neutral, disgust, sadness | confusion, optimism, disapproval |
+ | That was hilarious, I can't stop laughing! | joy, surprise, neutral | amusement, joy, optimism |
+ | I feel completely empty and lost. | sadness, neutral, disgust | surprise, disappointment, optimism |
+
+ **Insights**:
+ - The custom model captured more nuanced emotions, while Hartmann’s model tended to favor high-level emotions.
+ - Some variance is due to differences in label granularity between the models.
+ - The custom model showed stronger performance in emotions like admiration, amusement, and disappointment.
+
+ ---
+
+ # Known Limitations
+
+ - **Single-label restriction**: While the data supports multi-label emotion classification, the model currently predicts only the highest-probability emotion.
+ - **Low support for some classes**: Emotions like grief and pride had low representation in the training data.
+ - **Data bias**: Results reflect Reddit comment biases present in the GoEmotions dataset.
+
  ---
+
+ # Confidence Threshold
+
+ A **confidence threshold of 0.6** is applied in the app.
+ If the top emotion’s probability is below this value, the app returns:
+
+ "Unclear / Not enough signal"
+
+ This prevents overconfident predictions on uncertain or ambiguous text.
+
  ---
 
+ # Future Work
+
+ - Expand to multi-label predictions to better capture complex emotions.
+ - Improve minority class performance via data augmentation or rebalancing.
+ - Incorporate explainability methods directly into the Streamlit app.
+ - Deploy the app to Streamlit Cloud or Hugging Face Spaces.
+ - Collect user feedback for real-world validation.
+
+ ---
+
+ # Credits
+
+ - **Model architecture**: [RoBERTa](https://huggingface.co/roberta-base) (Hugging Face)
+ - **Training dataset**: [GoEmotions](https://huggingface.co/datasets/go_emotions)
+ - **Reference model**: [Hartmann et al. (2023)](https://arxiv.org/abs/2305.05894)
+ - **Streamlit App Framework**: [Streamlit](https://streamlit.io/)
+ - **BERT Emotion Model**: [bert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/bert-base-uncased-emotion)
+ - **Transformers Library**: [Hugging Face Transformers](https://huggingface.co/docs/transformers/index)
+
+ ---
+
+ # License
+
+ This project is licensed under the MIT License.
+ You can use, modify, and distribute the software freely, but there is no warranty.
+
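A minimal sketch of the confidence-threshold rule described in the README's "Confidence Threshold" section, written with the standard library only so it can be read without the model. The helper names `softmax` and `decide` are illustrative and do not appear in this commit; the 0.6 threshold and the fallback string are taken from the README.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities (shifted by the max for stability)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decide(logits, labels, threshold=0.6):
    """Return (label, confidence), falling back to an 'unclear' label
    when the top probability is below the threshold."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] < threshold:
        return "Unclear / Not enough signal", probs[best]
    return labels[best], probs[best]
```

With three hypothetical labels, a sharply peaked score vector keeps its label, while a nearly flat one falls back to the "unclear" message.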
app/.DS_Store ADDED
Binary file (6.15 kB).
 
app/app.py ADDED
@@ -0,0 +1,62 @@
+ import os
+ os.system("git lfs pull")
+
+ import streamlit as st
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import numpy as np
+
+ st.set_page_config(page_title="Emotion Classifier", layout="centered")
+
+ # Load GoEmotions label names
+ GOEMOTIONS_LABELS = [
+     "admiration", "amusement", "anger", "annoyance", "approval", "caring",
+     "confusion", "curiosity", "desire", "disappointment", "disapproval",
+     "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
+     "joy", "love", "nervousness", "optimism", "pride", "realization", "relief",
+     "remorse", "sadness", "surprise", "neutral"
+ ]
+
+ # -----------------------------
+ # Load model and tokenizer
+ # -----------------------------
+ MODEL_PATH = "/mount/src/emotion-classifier-nlp/outputs/model"
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH, device_map="auto", low_cpu_mem_usage=True)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ print("Model device:", next(model.parameters()).device)
+
+ # -----------------------------
+ # Streamlit UI
+ # -----------------------------
+ st.title("🧠 Emotion Classifier (NLP)")
+ st.markdown("Enter a sentence to analyze:")
+
+ input_text = st.text_area(" ", height=100)
+
+ if st.button("Classify") and input_text.strip():
+
+     with torch.no_grad():
+         inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
+         outputs = model(**inputs)
+         probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
+
+     pred_label_idx = torch.argmax(probs).cpu().item()
+     pred_score = probs[pred_label_idx].cpu().item()
+     pred_emotion = GOEMOTIONS_LABELS[pred_label_idx]
+
+     # -----------------------------
+     # Confidence threshold logic
+     # -----------------------------
+     threshold = 0.6
+     if pred_score < threshold:
+         st.warning(f"**Predicted Emotion:** Unclear / Not enough signal (Confidence: {pred_score:.0%})")
+     else:
+         st.success(f"**Predicted Emotion:** {pred_emotion} (Confidence: {pred_score:.0%})")
+
+     # -----------------------------
+     # Show all probabilities
+     # -----------------------------
+     st.markdown("### Prediction Probabilities:")
+     for i, prob in enumerate(probs):
+         st.write(f"- {GOEMOTIONS_LABELS[i]}: {prob.item():.4f}")
+
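`app/app.py` hard-codes `MODEL_PATH` to an absolute `/mount/src/...` directory, which will not exist when the app is started locally with `streamlit run app/app.py` as the README describes. One possible alternative, sketched here under the assumption that the script stays at `<repo>/app/app.py` and the model at `<repo>/outputs/model`, is to derive the directory from the script's own location. `model_dir_for` is a hypothetical helper, not part of this commit.

```python
from pathlib import Path


def model_dir_for(app_file):
    """Return <repo>/outputs/model for a script located at <repo>/app/app.py."""
    # resolve() makes the path absolute, so the result does not depend on
    # the directory from which the app was launched.
    repo_root = Path(app_file).resolve().parent.parent
    return repo_root / "outputs" / "model"
```

Inside the app this could be used as `MODEL_PATH = str(model_dir_for(__file__))`, so the same code would work both on a hosted checkout and on a local clone.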
notebooks/.DS_Store ADDED
Binary file (6.15 kB).
 
notebooks/.ipynb_checkpoints/01_exploration-checkpoint.ipynb ADDED
@@ -0,0 +1,358 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "ec88103e-4da3-4eb9-9ed5-8aab3f7745a5",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from datasets import load_dataset\n",
11
+ "\n",
12
+ "dataset = load_dataset(\"go_emotions\")"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 2,
18
+ "id": "9e38e4c1-5c32-4e1b-8612-e723645e34bb",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "import pandas as pd\n",
23
+ "import matplotlib.pyplot as plt\n",
24
+ "from collections import Counter"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "id": "169cb7bc-8ac6-41d4-a109-c30ca1688899",
30
+ "metadata": {},
31
+ "source": [
32
+ "### Converting the dataset to a dataframe for ease"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 4,
38
+ "id": "fadfe4cd-e0a3-4f09-a411-56db306fbd0b",
39
+ "metadata": {},
40
+ "outputs": [
41
+ {
42
+ "data": {
43
+ "text/html": [
44
+ "<div>\n",
45
+ "<style scoped>\n",
46
+ " .dataframe tbody tr th:only-of-type {\n",
47
+ " vertical-align: middle;\n",
48
+ " }\n",
49
+ "\n",
50
+ " .dataframe tbody tr th {\n",
51
+ " vertical-align: top;\n",
52
+ " }\n",
53
+ "\n",
54
+ " .dataframe thead th {\n",
55
+ " text-align: right;\n",
56
+ " }\n",
57
+ "</style>\n",
58
+ "<table border=\"1\" class=\"dataframe\">\n",
59
+ " <thead>\n",
60
+ " <tr style=\"text-align: right;\">\n",
61
+ " <th></th>\n",
62
+ " <th>text</th>\n",
63
+ " <th>labels</th>\n",
64
+ " <th>id</th>\n",
65
+ " </tr>\n",
66
+ " </thead>\n",
67
+ " <tbody>\n",
68
+ " <tr>\n",
69
+ " <th>0</th>\n",
70
+ " <td>My favourite food is anything I didn't have to...</td>\n",
71
+ " <td>[27]</td>\n",
72
+ " <td>eebbqej</td>\n",
73
+ " </tr>\n",
74
+ " <tr>\n",
75
+ " <th>1</th>\n",
76
+ " <td>Now if he does off himself, everyone will thin...</td>\n",
77
+ " <td>[27]</td>\n",
78
+ " <td>ed00q6i</td>\n",
79
+ " </tr>\n",
80
+ " <tr>\n",
81
+ " <th>2</th>\n",
82
+ " <td>WHY THE FUCK IS BAYLESS ISOING</td>\n",
83
+ " <td>[2]</td>\n",
84
+ " <td>eezlygj</td>\n",
85
+ " </tr>\n",
86
+ " <tr>\n",
87
+ " <th>3</th>\n",
88
+ " <td>To make her feel threatened</td>\n",
89
+ " <td>[14]</td>\n",
90
+ " <td>ed7ypvh</td>\n",
91
+ " </tr>\n",
92
+ " <tr>\n",
93
+ " <th>4</th>\n",
94
+ " <td>Dirty Southern Wankers</td>\n",
95
+ " <td>[3]</td>\n",
96
+ " <td>ed0bdzj</td>\n",
97
+ " </tr>\n",
98
+ " </tbody>\n",
99
+ "</table>\n",
100
+ "</div>"
101
+ ],
102
+ "text/plain": [
103
+ " text labels id\n",
104
+ "0 My favourite food is anything I didn't have to... [27] eebbqej\n",
105
+ "1 Now if he does off himself, everyone will thin... [27] ed00q6i\n",
106
+ "2 WHY THE FUCK IS BAYLESS ISOING [2] eezlygj\n",
107
+ "3 To make her feel threatened [14] ed7ypvh\n",
108
+ "4 Dirty Southern Wankers [3] ed0bdzj"
109
+ ]
110
+ },
111
+ "execution_count": 4,
112
+ "metadata": {},
113
+ "output_type": "execute_result"
114
+ }
115
+ ],
116
+ "source": [
117
+ "df = pd.DataFrame(dataset[\"train\"])\n",
118
+ "df.head()"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 5,
124
+ "id": "20a1d0af-1cd6-45c9-8128-29be160713d4",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "label_counts = Counter() # counts the number of occurrences\n",
129
+ "\n",
130
+ "# Looping through all label lists and counting each label's total occurrences\n",
131
+ "for labels in df[\"labels\"]:\n",
132
+ " label_counts.update(labels)"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": 6,
138
+ "id": "11c20083-c1a7-412d-9f98-e51defaded84",
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "# Mapping indices to label names\n",
143
+ "label_names = dataset[\"train\"].features[\"labels\"].feature.names\n",
144
+ "label_distribution = {label_names[i]: count for i, count in label_counts.items()}"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "markdown",
149
+ "id": "967e6d78-3dec-4dc5-b035-8ff263e689c8",
150
+ "metadata": {},
151
+ "source": [
152
+ "### Plot for the number of occurrences of each emotion in the training set"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 8,
158
+ "id": "91e04be4-4132-4c78-8e4e-6d4c1b29fd6c",
159
+ "metadata": {},
160
+ "outputs": [
161
+ {
162
+ "data": {
163
+ "image/png": "",
164
+ "text/plain": [
165
+ "<Figure size 1200x600 with 1 Axes>"
166
+ ]
167
+ },
168
+ "metadata": {},
169
+ "output_type": "display_data"
170
+ }
171
+ ],
172
+ "source": [
173
+ "# Plot for the number of occurrences of each emotion in the training set\n",
174
+ "plt.figure(figsize=(12, 6))\n",
175
+ "pd.Series(label_distribution).sort_values(ascending=False).plot(kind='bar')\n",
176
+ "plt.title(\"GoEmotions Label Distribution\")\n",
177
+ "plt.ylabel(\"Count\")\n",
178
+ "plt.xticks(rotation=90)\n",
179
+ "plt.tight_layout()\n",
180
+ "plt.show()"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 9,
186
+ "id": "6d28b28c-d87f-41e5-afd2-cc3257dd974b",
187
+ "metadata": {},
188
+ "outputs": [
189
+ {
190
+ "data": {
191
+ "text/html": [
192
+ "<div>\n",
193
+ "<style scoped>\n",
194
+ " .dataframe tbody tr th:only-of-type {\n",
195
+ " vertical-align: middle;\n",
196
+ " }\n",
197
+ "\n",
198
+ " .dataframe tbody tr th {\n",
199
+ " vertical-align: top;\n",
200
+ " }\n",
201
+ "\n",
202
+ " .dataframe thead th {\n",
203
+ " text-align: right;\n",
204
+ " }\n",
205
+ "</style>\n",
206
+ "<table border=\"1\" class=\"dataframe\">\n",
207
+ " <thead>\n",
208
+ " <tr style=\"text-align: right;\">\n",
209
+ " <th></th>\n",
210
+ " <th>text</th>\n",
211
+ " <th>labels</th>\n",
212
+ " <th>id</th>\n",
213
+ " <th>clean_text</th>\n",
214
+ " </tr>\n",
215
+ " </thead>\n",
216
+ " <tbody>\n",
217
+ " <tr>\n",
218
+ " <th>0</th>\n",
219
+ " <td>My favourite food is anything I didn't have to...</td>\n",
220
+ " <td>[27]</td>\n",
221
+ " <td>eebbqej</td>\n",
222
+ " <td>my favourite food is anything i didnt have to ...</td>\n",
223
+ " </tr>\n",
224
+ " <tr>\n",
225
+ " <th>1</th>\n",
226
+ " <td>Now if he does off himself, everyone will thin...</td>\n",
227
+ " <td>[27]</td>\n",
228
+ " <td>ed00q6i</td>\n",
229
+ " <td>now if he does off himself everyone will think...</td>\n",
230
+ " </tr>\n",
231
+ " <tr>\n",
232
+ " <th>2</th>\n",
233
+ " <td>WHY THE FUCK IS BAYLESS ISOING</td>\n",
234
+ " <td>[2]</td>\n",
235
+ " <td>eezlygj</td>\n",
236
+ " <td>why the fuck is bayless isoing</td>\n",
237
+ " </tr>\n",
238
+ " <tr>\n",
239
+ " <th>3</th>\n",
240
+ " <td>To make her feel threatened</td>\n",
241
+ " <td>[14]</td>\n",
242
+ " <td>ed7ypvh</td>\n",
243
+ " <td>to make her feel threatened</td>\n",
244
+ " </tr>\n",
245
+ " <tr>\n",
246
+ " <th>4</th>\n",
247
+ " <td>Dirty Southern Wankers</td>\n",
248
+ " <td>[3]</td>\n",
249
+ " <td>ed0bdzj</td>\n",
250
+ " <td>dirty southern wankers</td>\n",
251
+ " </tr>\n",
252
+ " <tr>\n",
253
+ " <th>5</th>\n",
254
+ " <td>OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe...</td>\n",
255
+ " <td>[26]</td>\n",
256
+ " <td>edvnz26</td>\n",
257
+ " <td>omg peyton isnt good enough to help us in the ...</td>\n",
258
+ " </tr>\n",
259
+ " <tr>\n",
260
+ " <th>6</th>\n",
261
+ " <td>Yes I heard abt the f bombs! That has to be wh...</td>\n",
262
+ " <td>[15]</td>\n",
263
+ " <td>ee3b6wu</td>\n",
264
+ " <td>yes i heard abt the f bombs that has to be why...</td>\n",
265
+ " </tr>\n",
266
+ " <tr>\n",
267
+ " <th>7</th>\n",
268
+ " <td>We need more boards and to create a bit more s...</td>\n",
269
+ " <td>[8, 20]</td>\n",
270
+ " <td>ef4qmod</td>\n",
271
+ " <td>we need more boards and to create a bit more s...</td>\n",
272
+ " </tr>\n",
273
+ " <tr>\n",
274
+ " <th>8</th>\n",
275
+ " <td>Damn youtube and outrage drama is super lucrat...</td>\n",
276
+ " <td>[0]</td>\n",
277
+ " <td>ed8wbdn</td>\n",
278
+ " <td>damn youtube and outrage drama is super lucrat...</td>\n",
279
+ " </tr>\n",
280
+ " <tr>\n",
281
+ " <th>9</th>\n",
282
+ " <td>It might be linked to the trust factor of your...</td>\n",
283
+ " <td>[27]</td>\n",
284
+ " <td>eczgv1o</td>\n",
285
+ " <td>it might be linked to the trust factor of your...</td>\n",
286
+ " </tr>\n",
287
+ " </tbody>\n",
288
+ "</table>\n",
289
+ "</div>"
290
+ ],
291
+ "text/plain": [
292
+ " text labels id \\\n",
293
+ "0 My favourite food is anything I didn't have to... [27] eebbqej \n",
294
+ "1 Now if he does off himself, everyone will thin... [27] ed00q6i \n",
295
+ "2 WHY THE FUCK IS BAYLESS ISOING [2] eezlygj \n",
296
+ "3 To make her feel threatened [14] ed7ypvh \n",
297
+ "4 Dirty Southern Wankers [3] ed0bdzj \n",
298
+ "5 OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe... [26] edvnz26 \n",
299
+ "6 Yes I heard abt the f bombs! That has to be wh... [15] ee3b6wu \n",
300
+ "7 We need more boards and to create a bit more s... [8, 20] ef4qmod \n",
301
+ "8 Damn youtube and outrage drama is super lucrat... [0] ed8wbdn \n",
302
+ "9 It might be linked to the trust factor of your... [27] eczgv1o \n",
303
+ "\n",
304
+ " clean_text \n",
305
+ "0 my favourite food is anything i didnt have to ... \n",
306
+ "1 now if he does off himself everyone will think... \n",
307
+ "2 why the fuck is bayless isoing \n",
308
+ "3 to make her feel threatened \n",
309
+ "4 dirty southern wankers \n",
310
+ "5 omg peyton isnt good enough to help us in the ... \n",
311
+ "6 yes i heard abt the f bombs that has to be why... \n",
312
+ "7 we need more boards and to create a bit more s... \n",
313
+ "8 damn youtube and outrage drama is super lucrat... \n",
314
+ "9 it might be linked to the trust factor of your... "
315
+ ]
316
+ },
317
+ "execution_count": 9,
318
+ "metadata": {},
319
+ "output_type": "execute_result"
320
+ }
321
+ ],
322
+ "source": [
323
+ "import re\n",
324
+ "\n",
325
+ "def clean_text(text):\n",
326
+ " text = text.lower() # this makes every character in the text lower-cased\n",
327
+ " text = re.sub(r\"http\\S+|www\\S+|https\\S+\", '', text) # this removes links\n",
328
+ " text = re.sub(r'[^A-Za-z0-9\\s]+', '', text) # this removes special characters\n",
329
+ " text = re.sub(r'\\s+', ' ', text).strip() # this normalizes whitespace\n",
330
+ " return text\n",
331
+ "\n",
332
+ "df[\"clean_text\"] = df[\"text\"].apply(clean_text)\n",
333
+ "df.head(10)"
334
+ ]
335
+ }
336
+ ],
337
+ "metadata": {
338
+ "kernelspec": {
339
+ "display_name": "Python 3 (ipykernel)",
340
+ "language": "python",
341
+ "name": "python3"
342
+ },
343
+ "language_info": {
344
+ "codemirror_mode": {
345
+ "name": "ipython",
346
+ "version": 3
347
+ },
348
+ "file_extension": ".py",
349
+ "mimetype": "text/x-python",
350
+ "name": "python",
351
+ "nbconvert_exporter": "python",
352
+ "pygments_lexer": "ipython3",
353
+ "version": "3.12.2"
354
+ }
355
+ },
356
+ "nbformat": 4,
357
+ "nbformat_minor": 5
358
+ }
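For readers who want to try the notebook's cleaning step outside Jupyter, the following is a plain-Python rendering of the `clean_text` cell shown above in JSON-escaped form. The logic is copied from the notebook; only the docstring, the comments, and the two sample sentences used afterwards are added here for illustration.

```python
import re


def clean_text(text):
    """Lower-case text, drop links and special characters, normalize spaces."""
    text = text.lower()                                   # lower-case everything
    text = re.sub(r"http\S+|www\S+|https\S+", "", text)   # remove links
    text = re.sub(r"[^A-Za-z0-9\s]+", "", text)           # remove special characters
    text = re.sub(r"\s+", " ", text).strip()              # normalize whitespace
    return text
```

Note that the special-character step also deletes apostrophes, which is why "didn't" appears as "didnt" in the notebook output. For example, `clean_text("Check https://example.com now!!")` returns `"check now"`.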
notebooks/.ipynb_checkpoints/02_training-checkpoint.ipynb ADDED
@@ -0,0 +1,368 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "80f816c1-0839-41cb-847b-c79a62ca1465",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Load all required modules for loading data, model setup, training, and metric evaluation"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "2554d05b-f08a-4c21-953f-4f507407e426",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import sys\n",
19
+ "import os\n",
20
+ "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), \"..\", \"src\")))\n",
21
+ "from data_loader import load_and_prepare_data \n",
22
+ "from model import get_model, get_tokenizer \n",
23
+ "from train import get_training_args, train_model \n",
24
+ "from evaluate import compute_metrics \n",
25
+ "from torch.utils.data import Dataset \n",
26
+ "import torch"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "markdown",
31
+ "id": "3bfbb706-4b0b-43de-a95a-884d46343668",
32
+ "metadata": {},
33
+ "source": [
34
+ "### Define a class that wraps tokenized data and labels for Hugging Face’s Trainer to use"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 3,
40
+ "id": "c814c354-7962-4a2d-b7bd-5c498f1d004e",
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "class EmotionDataset(Dataset):\n",
45
+ " def __init__(self, encodings, labels):\n",
46
+ " self.encodings = encodings # BERT tokenized inputs (input_ids, attention_mask)\n",
47
+ " self.labels = labels # Encoded labels (integers)\n",
48
+ "\n",
49
+ " def __len__(self):\n",
50
+ " return len(self.labels) # Total number of samples\n",
51
+ "\n",
52
+ " def __getitem__(self, idx):\n",
53
+ " # Return dictionary of input tensors + label tensor for a single sample\n",
54
+ " return {\n",
55
+ " key: torch.tensor(val[idx]) for key, val in self.encodings.items()\n",
56
+ " } | {\"labels\": torch.tensor(self.labels[idx])}"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "id": "f9b87257-f0c0-4532-9eee-939d8747ef79",
62
+ "metadata": {},
63
+ "source": [
64
+ "### Load the dataset from Hugging Face, clean and encode it, then tokenize it using the BERT tokenizer."
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 5,
70
+ "id": "18e312be-5863-4e24-900a-843e42e145cc",
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "# Load train/test splits and label encoder\n",
75
+ "train_texts, test_texts, train_labels, test_labels, label_encoder = load_and_prepare_data()\n",
76
+ "\n",
77
+ "# Load BERT tokenizer\n",
78
+ "tokenizer = get_tokenizer()\n",
79
+ "\n",
80
+ "# Tokenize training and testing texts with truncation and padding\n",
81
+ "train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=128)\n",
82
+ "test_encodings = tokenizer(test_texts, truncation=True, padding=True, max_length=128)\n",
83
+ "\n",
84
+ "# Wrap the tokenized data into EmotionDataset objects\n",
85
+ "train_dataset = EmotionDataset(train_encodings, train_labels)\n",
86
+ "test_dataset = EmotionDataset(test_encodings, test_labels)"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "id": "66b99b4e-5297-4bc0-8cfb-20dbe22526c0",
92
+ "metadata": {},
93
+ "source": [
94
+ "### Samples from the dataset"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 7,
100
+ "id": "35db4426-db21-4438-ba0e-ebb51d52edfb",
101
+ "metadata": {},
102
+ "outputs": [
103
+ {
104
+ "name": "stdout",
105
+ "output_type": "stream",
106
+ "text": [
107
+ "Sample 1\n",
108
+ "Text: i'd just feel less out of place, i guess. my sa makes me feel like i'm so behind my peers in terms of a social life\n",
109
+ "Label (encoded): 9\n",
110
+ "\n",
111
+ "Sample 2\n",
112
+ "Text: i love the lady in the green jacket chasing after the second car looking back at the first car like \"look what you did\"\n",
113
+ "Label (encoded): 18\n",
114
+ "\n",
115
+ "Sample 3\n",
116
+ "Text: man. really bad last possession there. bummer.\n",
117
+ "Label (encoded): 10\n",
118
+ "\n",
119
+ "Sample 4\n",
120
+ "Text: never would’ve guessed that one.\n",
121
+ "Label (encoded): 20\n",
122
+ "\n",
123
+ "Sample 5\n",
124
+ "Text: i wasn’t even expecting the reply that’s why i’m literally bamboozled.\n",
125
+ "Label (encoded): 27\n",
126
+ "\n"
127
+ ]
128
+ }
129
+ ],
130
+ "source": [
131
+ "for i in range(5):\n",
132
+ " print(f\"Sample {i+1}\")\n",
133
+ " print(f\"Text: {train_texts[i]}\")\n",
134
+ " print(f\"Label (encoded): {train_labels[i]}\")\n",
135
+ " print()"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "markdown",
140
+ "id": "0883760a-a449-42ca-ba69-fa01d874e50b",
141
+ "metadata": {},
142
+ "source": [
143
+ "### Set up the BERT model for sequence classification and define training parameters."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 9,
149
+ "id": "3176ccf4-d20d-460c-b620-c73a1ab9cb6d",
150
+ "metadata": {},
151
+ "outputs": [
152
+ {
153
+ "name": "stderr",
154
+ "output_type": "stream",
155
+ "text": [
156
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
157
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
158
+ "/opt/anaconda3/lib/python3.12/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
159
+ " warnings.warn(\n"
160
+ ]
161
+ }
162
+ ],
163
+ "source": [
164
+ "# Load pre-trained BERT model with classification head for number of emotion classes\n",
165
+ "model = get_model(num_labels=len(label_encoder.classes_))\n",
166
+ "\n",
167
+ "# Set training configuration: batch size, epochs, logging, saving, evaluation\n",
168
+ "training_args = get_training_args()"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "markdown",
173
+ "id": "874a4e6a-80dd-470d-9283-e1c88e731b8e",
174
+ "metadata": {},
175
+ "source": [
176
+ "### Train the Model "
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 13,
182
+ "id": "4c312e56-52bf-417d-82c0-8a1f47b82670",
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "data": {
187
+ "text/html": [
188
+ "\n",
189
+ " <div>\n",
190
+ " \n",
191
+ " <progress value='5448' max='5448' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
192
+ " [5448/5448 1:46:28, Epoch 3/3]\n",
193
+ " </div>\n",
194
+ " <table border=\"1\" class=\"dataframe\">\n",
195
+ " <thead>\n",
196
+ " <tr style=\"text-align: left;\">\n",
197
+ " <th>Epoch</th>\n",
198
+ " <th>Training Loss</th>\n",
199
+ " <th>Validation Loss</th>\n",
200
+ " <th>Accuracy</th>\n",
201
+ " <th>F1</th>\n",
202
+ " </tr>\n",
203
+ " </thead>\n",
204
+ " <tbody>\n",
205
+ " <tr>\n",
206
+ " <td>1</td>\n",
207
+ " <td>1.358900</td>\n",
208
+ " <td>1.335635</td>\n",
209
+ " <td>0.613467</td>\n",
210
+ " <td>0.579882</td>\n",
211
+ " </tr>\n",
212
+ " <tr>\n",
213
+ " <td>2</td>\n",
214
+ " <td>0.947100</td>\n",
215
+ " <td>1.284574</td>\n",
216
+ " <td>0.615671</td>\n",
217
+ " <td>0.601428</td>\n",
218
+ " </tr>\n",
219
+ " <tr>\n",
220
+ " <td>3</td>\n",
221
+ " <td>0.970400</td>\n",
222
+ " <td>1.297894</td>\n",
223
+ " <td>0.617048</td>\n",
224
+ " <td>0.606042</td>\n",
225
+ " </tr>\n",
226
+ " </tbody>\n",
227
+ "</table><p>"
228
+ ],
229
+ "text/plain": [
230
+ "<IPython.core.display.HTML object>"
231
+ ]
232
+ },
233
+ "metadata": {},
234
+ "output_type": "display_data"
235
+ },
236
+ {
237
+ "data": {
238
+ "text/html": [
239
+ "\n",
240
+ " <div>\n",
241
+ " \n",
242
+ " <progress value='5448' max='5448' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
243
+ " [5448/5448 1:35:20, Epoch 3/3]\n",
244
+ " </div>\n",
245
+ " <table border=\"1\" class=\"dataframe\">\n",
246
+ " <thead>\n",
247
+ " <tr style=\"text-align: left;\">\n",
248
+ " <th>Epoch</th>\n",
249
+ " <th>Training Loss</th>\n",
250
+ " <th>Validation Loss</th>\n",
251
+ " <th>Accuracy</th>\n",
252
+ " <th>F1</th>\n",
253
+ " </tr>\n",
254
+ " </thead>\n",
255
+ " <tbody>\n",
256
+ " <tr>\n",
257
+ " <td>1</td>\n",
258
+ " <td>0.907200</td>\n",
259
+ " <td>1.365916</td>\n",
260
+ " <td>0.602313</td>\n",
261
+ " <td>0.595804</td>\n",
262
+ " </tr>\n",
263
+ " <tr>\n",
264
+ " <td>2</td>\n",
265
+ " <td>0.549100</td>\n",
266
+ " <td>1.488130</td>\n",
267
+ " <td>0.595566</td>\n",
268
+ " <td>0.591464</td>\n",
269
+ " </tr>\n",
270
+ " <tr>\n",
271
+ " <td>3</td>\n",
272
+ " <td>0.514400</td>\n",
273
+ " <td>1.593286</td>\n",
274
+ " <td>0.591297</td>\n",
275
+ " <td>0.589066</td>\n",
276
+ " </tr>\n",
277
+ " </tbody>\n",
278
+ "</table><p>"
279
+ ],
280
+ "text/plain": [
281
+ "<IPython.core.display.HTML object>"
282
+ ]
283
+ },
284
+ "metadata": {},
285
+ "output_type": "display_data"
286
+ },
287
+ {
288
+ "data": {
289
+ "text/plain": [
290
+ "TrainOutput(global_step=5448, training_loss=0.7054264770818002, metrics={'train_runtime': 5721.3012, 'train_samples_per_second': 15.23, 'train_steps_per_second': 0.952, 'total_flos': 5733080823638016.0, 'train_loss': 0.7054264770818002, 'epoch': 3.0})"
291
+ ]
292
+ },
293
+ "execution_count": 13,
294
+ "metadata": {},
295
+ "output_type": "execute_result"
296
+ }
297
+ ],
298
+ "source": [
299
+ "trainer = train_model(\n",
300
+ " model=model,\n",
301
+ " args=training_args,\n",
302
+ " train_dataset=train_dataset,\n",
303
+ " val_dataset=test_dataset,\n",
304
+ " compute_metrics=compute_metrics\n",
305
+ ")\n",
306
+ "\n",
307
+ "# Begin training\n",
308
+ "trainer.train()"
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "markdown",
313
+ "id": "020729b6-c545-42ba-bd2c-00ee5f9bbb80",
314
+ "metadata": {},
315
+ "source": [
316
+ "### Save both model weights and tokenizer files for future inference or deployment."
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": 23,
322
+ "id": "5f12aedb-b3f8-4a1b-8e1f-6a68eb29933f",
323
+ "metadata": {},
324
+ "outputs": [
325
+ {
326
+ "data": {
327
+ "text/plain": [
328
+ "('../outputs/model/tokenizer_config.json',\n",
329
+ " '../outputs/model/special_tokens_map.json',\n",
330
+ " '../outputs/model/vocab.txt',\n",
331
+ " '../outputs/model/added_tokens.json')"
332
+ ]
333
+ },
334
+ "execution_count": 23,
335
+ "metadata": {},
336
+ "output_type": "execute_result"
337
+ }
338
+ ],
339
+ "source": [
340
+ "from pathlib import Path\n",
341
+ "model_path = Path(\"..\") / \"outputs\" / \"model\"\n",
342
+ "model.save_pretrained(model_path)\n",
343
+ "tokenizer.save_pretrained(model_path)"
344
+ ]
345
+ }
346
+ ],
347
+ "metadata": {
348
+ "kernelspec": {
349
+ "display_name": "Python 3 (ipykernel)",
350
+ "language": "python",
351
+ "name": "python3"
352
+ },
353
+ "language_info": {
354
+ "codemirror_mode": {
355
+ "name": "ipython",
356
+ "version": 3
357
+ },
358
+ "file_extension": ".py",
359
+ "mimetype": "text/x-python",
360
+ "name": "python",
361
+ "nbconvert_exporter": "python",
362
+ "pygments_lexer": "ipython3",
363
+ "version": "3.12.2"
364
+ }
365
+ },
366
+ "nbformat": 4,
367
+ "nbformat_minor": 5
368
+ }
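A note on the two training tables logged in this notebook: in the second run the validation loss rises every epoch (1.366, 1.488, 1.593) while the training loss keeps falling, which is commonly read as a sign of overfitting. The following is a minimal sketch, not part of the commit, that copies the logged values by hand into plain tuples and picks the epoch with the lowest validation loss. The names `runs` and `best_epoch` are hypothetical and introduced only for illustration.

```python
# Values copied from the two logged tables:
# (epoch, training_loss, validation_loss, accuracy, f1)
runs = {
    "first_run": [
        (1, 1.358900, 1.335635, 0.613467, 0.579882),
        (2, 0.947100, 1.284574, 0.615671, 0.601428),
        (3, 0.970400, 1.297894, 0.617048, 0.606042),
    ],
    "second_run": [
        (1, 0.907200, 1.365916, 0.602313, 0.595804),
        (2, 0.549100, 1.488130, 0.595566, 0.591464),
        (3, 0.514400, 1.593286, 0.591297, 0.589066),
    ],
}


def best_epoch(rows):
    """Return the epoch number with the lowest validation loss (hypothetical helper)."""
    return min(rows, key=lambda row: row[2])[0]


best = {name: best_epoch(rows) for name, rows in runs.items()}
print(best)  # {'first_run': 2, 'second_run': 1}
```

By this criterion the second run never improves on its first epoch. Whether `get_training_args()` already keeps the best checkpoint is not visible in this diff, so treat that as something to check rather than a given.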
notebooks/.ipynb_checkpoints/03_evaluation-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/.ipynb_checkpoints/04_model_comparison-checkpoint.ipynb ADDED
@@ -0,0 +1,290 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "4032a920-2db8-4977-8b4f-a5a771dd022f",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import sys\n",
11
+ "import os\n",
12
+ "\n",
13
+ "project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
14
+ "sys.path.append(project_root)\n",
15
+ "\n",
16
+ "from transformers import pipeline\n",
17
+ "from src.model_hartmann import load_model as load_hartmann_model, load_tokenizer as load_hartmann_tokenizer\n",
18
+ "from src.model_custom import load_model as load_custom_model, load_tokenizer as load_custom_tokenizer"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "id": "525cf57e-4ec3-40fd-aca2-0e9700a73298",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "hartmann_model = load_hartmann_model()\n",
29
+ "hartmann_tokenizer = load_hartmann_tokenizer()\n",
30
+ "\n",
31
+ "custom_model = load_custom_model()\n",
32
+ "custom_tokenizer = load_custom_tokenizer()"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 3,
38
+ "id": "04f9415c-3d4f-4ac0-8f51-74ec4bd64293",
39
+ "metadata": {},
40
+ "outputs": [
41
+ {
42
+ "name": "stderr",
43
+ "output_type": "stream",
44
+ "text": [
45
+ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n",
46
+ "/opt/anaconda3/lib/python3.12/site-packages/transformers/pipelines/text_classification.py:104: UserWarning: `return_all_scores` is now deprecated, if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n",
47
+ " warnings.warn(\n",
48
+ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n"
49
+ ]
50
+ }
51
+ ],
52
+ "source": [
53
+ "# Create pipelines for easy predictions\n",
54
+ "hartmann_pipeline = pipeline(\"text-classification\", model=hartmann_model, tokenizer=hartmann_tokenizer, return_all_scores=True)\n",
55
+ "custom_pipeline = pipeline(\"text-classification\", model=custom_model, tokenizer=custom_tokenizer, return_all_scores=True)"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 4,
61
+ "id": "3fcdf650-3abc-42a6-b1fd-0129e49d1e68",
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "name": "stdout",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "========= Sentence 1 ==========\n",
69
+ "Text: I love spending time with my family.\n",
70
+ "\n",
71
+ "--- Hartmann Model Top 3 Predictions ---\n",
72
+ "joy: 0.9883\n",
73
+ "sadness: 0.0067\n",
74
+ "disgust: 0.0013\n",
75
+ "\n",
76
+ "--- Pretrained Model Top 3 Predictions ---\n",
77
+ "love: 0.9536\n",
78
+ "joy: 0.0134\n",
79
+ "admiration: 0.0108\n",
80
+ "\n",
81
+ "\n",
82
+ "========= Sentence 2 ==========\n",
83
+ "Text: This is the worst day of my life.\n",
84
+ "\n",
85
+ "--- Hartmann Model Top 3 Predictions ---\n",
86
+ "disgust: 0.9805\n",
87
+ "anger: 0.0086\n",
88
+ "sadness: 0.0055\n",
89
+ "\n",
90
+ "--- Pretrained Model Top 3 Predictions ---\n",
91
+ "anger: 0.3353\n",
92
+ "surprise: 0.2010\n",
93
+ "disgust: 0.1235\n",
94
+ "\n",
95
+ "\n",
96
+ "========= Sentence 3 ==========\n",
97
+ "Text: I'm feeling very nervous about the exam.\n",
98
+ "\n",
99
+ "--- Hartmann Model Top 3 Predictions ---\n",
100
+ "fear: 0.9947\n",
101
+ "sadness: 0.0013\n",
102
+ "joy: 0.0011\n",
103
+ "\n",
104
+ "--- Pretrained Model Top 3 Predictions ---\n",
105
+ "nervousness: 0.6201\n",
106
+ "fear: 0.0828\n",
107
+ "embarrassment: 0.0393\n",
108
+ "\n",
109
+ "\n",
110
+ "========= Sentence 4 ==========\n",
111
+ "Text: What a beautiful sunset!\n",
112
+ "\n",
113
+ "--- Hartmann Model Top 3 Predictions ---\n",
114
+ "joy: 0.8377\n",
115
+ "surprise: 0.1189\n",
116
+ "neutral: 0.0221\n",
117
+ "\n",
118
+ "--- Pretrained Model Top 3 Predictions ---\n",
119
+ "admiration: 0.8548\n",
120
+ "excitement: 0.0729\n",
121
+ "joy: 0.0351\n",
122
+ "\n",
123
+ "\n",
124
+ "========= Sentence 5 ==========\n",
125
+ "Text: I feel so disappointed and frustrated with the situation.\n",
126
+ "\n",
127
+ "--- Hartmann Model Top 3 Predictions ---\n",
128
+ "sadness: 0.9310\n",
129
+ "anger: 0.0381\n",
130
+ "disgust: 0.0158\n",
131
+ "\n",
132
+ "--- Pretrained Model Top 3 Predictions ---\n",
133
+ "disappointment: 0.5645\n",
134
+ "annoyance: 0.1864\n",
135
+ "anger: 0.0736\n",
136
+ "\n",
137
+ "\n",
138
+ "========= Sentence 6 ==========\n",
139
+ "Text: I'm not sure how to feel about this.\n",
140
+ "\n",
141
+ "--- Hartmann Model Top 3 Predictions ---\n",
142
+ "neutral: 0.5698\n",
143
+ "disgust: 0.2213\n",
144
+ "sadness: 0.0720\n",
145
+ "\n",
146
+ "--- Pretrained Model Top 3 Predictions ---\n",
147
+ "confusion: 0.9011\n",
148
+ "optimism: 0.0230\n",
149
+ "disapproval: 0.0223\n",
150
+ "\n",
151
+ "\n",
152
+ "========= Sentence 7 ==========\n",
153
+ "Text: That was hilarious, I can't stop laughing!\n",
154
+ "\n",
155
+ "--- Hartmann Model Top 3 Predictions ---\n",
156
+ "joy: 0.9336\n",
157
+ "surprise: 0.0306\n",
158
+ "neutral: 0.0178\n",
159
+ "\n",
160
+ "--- Pretrained Model Top 3 Predictions ---\n",
161
+ "amusement: 0.9551\n",
162
+ "joy: 0.0286\n",
163
+ "optimism: 0.0032\n",
164
+ "\n",
165
+ "\n",
166
+ "========= Sentence 8 ==========\n",
167
+ "Text: I feel completely empty and lost.\n",
168
+ "\n",
169
+ "--- Hartmann Model Top 3 Predictions ---\n",
170
+ "sadness: 0.9808\n",
171
+ "neutral: 0.0086\n",
172
+ "disgust: 0.0051\n",
173
+ "\n",
174
+ "--- Pretrained Model Top 3 Predictions ---\n",
175
+ "surprise: 0.8055\n",
176
+ "disappointment: 0.1067\n",
177
+ "optimism: 0.0222\n",
178
+ "\n",
179
+ "\n",
180
+ "========= Sentence 9 ==========\n",
181
+ "Text: Your help means a lot to me, thank you!\n",
182
+ "\n",
183
+ "--- Hartmann Model Top 3 Predictions ---\n",
184
+ "joy: 0.9760\n",
185
+ "neutral: 0.0104\n",
186
+ "surprise: 0.0057\n",
187
+ "\n",
188
+ "--- Pretrained Model Top 3 Predictions ---\n",
189
+ "gratitude: 0.9890\n",
190
+ "caring: 0.0014\n",
191
+ "sadness: 0.0009\n",
192
+ "\n",
193
+ "\n",
194
+ "========= Sentence 10 ==========\n",
195
+ "Text: I'm so angry I could scream.\n",
196
+ "\n",
197
+ "--- Hartmann Model Top 3 Predictions ---\n",
198
+ "anger: 0.9785\n",
199
+ "fear: 0.0084\n",
200
+ "neutral: 0.0047\n",
201
+ "\n",
202
+ "--- Pretrained Model Top 3 Predictions ---\n",
203
+ "anger: 0.9155\n",
204
+ "annoyance: 0.0223\n",
205
+ "optimism: 0.0082\n",
206
+ "\n",
207
+ "\n"
208
+ ]
209
+ }
210
+ ],
211
+ "source": [
212
+ "from tabulate import tabulate\n",
213
+ "\n",
214
+ "goemotions_labels = [\n",
215
+ " \"admiration\", \"amusement\", \"anger\", \"annoyance\", \"approval\", \"caring\", \"confusion\", \"curiosity\",\n",
216
+ " \"desire\", \"disappointment\", \"disapproval\", \"disgust\", \"embarrassment\", \"excitement\", \"fear\",\n",
217
+ " \"gratitude\", \"grief\", \"joy\", \"love\", \"nervousness\", \"optimism\", \"pride\", \"realization\", \"relief\",\n",
218
+ " \"remorse\", \"sadness\", \"surprise\", \"neutral\"\n",
219
+ "]\n",
220
+ "\n",
221
+ "\n",
222
+ "# Your 10 test sentences\n",
223
+ "sentences = [\n",
224
+ " \"I love spending time with my family.\",\n",
225
+ " \"This is the worst day of my life.\",\n",
226
+ " \"I'm feeling very nervous about the exam.\",\n",
227
+ " \"What a beautiful sunset!\",\n",
228
+ " \"I feel so disappointed and frustrated with the situation.\",\n",
229
+ " \"I'm not sure how to feel about this.\",\n",
230
+ " \"That was hilarious, I can't stop laughing!\",\n",
231
+ " \"I feel completely empty and lost.\",\n",
232
+ " \"Your help means a lot to me, thank you!\",\n",
233
+ " \"I'm so angry I could scream.\"\n",
234
+ "]\n",
235
+ "\n",
236
+ "# Loop over sentences and collect results\n",
237
+ "for i, sentence in enumerate(sentences):\n",
238
+ " print(f\"========= Sentence {i+1} ==========\")\n",
239
+ " print(f\"Text: {sentence}\\n\")\n",
240
+ "\n",
241
+ " # Get predictions\n",
242
+ " hartmann_results = hartmann_pipeline(sentence, return_all_scores=True)\n",
243
+ " custom_results = custom_pipeline(sentence, return_all_scores=True)\n",
244
+ "\n",
245
+ " # Unwrap the list to get the actual results\n",
246
+ " hartmann_results = hartmann_results[0]\n",
247
+ " custom_results = custom_results[0]\n",
248
+ "\n",
249
+ " # Sort and get top 3 predictions for each\n",
250
+ " hartmann_top3 = sorted(hartmann_results, key=lambda x: x['score'], reverse=True)[:3]\n",
251
+ " custom_top3 = sorted(custom_results, key=lambda x: x['score'], reverse=True)[:3]\n",
252
+ "\n",
253
+ " # Display Hartmann predictions\n",
254
+ " print(\"--- Hartmann Model Top 3 Predictions ---\")\n",
255
+ " for res in hartmann_top3:\n",
256
+ " print(f\"{res['label']}: {res['score']:.4f}\")\n",
257
+ "\n",
258
+ " # Display Custom Model predictions\n",
259
+ " print(\"\\n--- Pretrained Model Top 3 Predictions ---\")\n",
260
+ " for res in custom_top3:\n",
261
+ " label_idx = int(res['label'].split(\"_\")[-1])\n",
262
+ " emotion = goemotions_labels[label_idx]\n",
263
+ " print(f\"{emotion}: {res['score']:.4f}\")\n",
264
+ "\n",
265
+ " print(\"\\n\")\n"
266
+ ]
267
+ }
268
+ ],
269
+ "metadata": {
270
+ "kernelspec": {
271
+ "display_name": "Python 3 (ipykernel)",
272
+ "language": "python",
273
+ "name": "python3"
274
+ },
275
+ "language_info": {
276
+ "codemirror_mode": {
277
+ "name": "ipython",
278
+ "version": 3
279
+ },
280
+ "file_extension": ".py",
281
+ "mimetype": "text/x-python",
282
+ "name": "python",
283
+ "nbconvert_exporter": "python",
284
+ "pygments_lexer": "ipython3",
285
+ "version": "3.12.2"
286
+ }
287
+ },
288
+ "nbformat": 4,
289
+ "nbformat_minor": 5
290
+ }
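The comparison loop in this notebook turns labels of the form `LABEL_<idx>` into emotion names by indexing into `goemotions_labels`. Below is a small standalone sketch of that mapping step, assuming the custom model's class indices follow the GoEmotions order listed in the notebook. If the labels were re-encoded during training (the training notebook uses a `label_encoder`), the names should come from that encoder instead. The helper `top_k_emotions` and the scores in `fake_results` are hypothetical and invented for illustration; they are not model output.

```python
# Label order as listed in the notebook's goemotions_labels
goemotions_labels = [
    "admiration", "amusement", "anger", "annoyance", "approval", "caring",
    "confusion", "curiosity", "desire", "disappointment", "disapproval",
    "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
    "joy", "love", "nervousness", "optimism", "pride", "realization",
    "relief", "remorse", "sadness", "surprise", "neutral",
]


def top_k_emotions(results, label_names, k=3):
    """Map 'LABEL_<idx>' entries to names and keep the k highest scores."""
    ranked = sorted(results, key=lambda r: r["score"], reverse=True)[:k]
    return [
        (label_names[int(r["label"].split("_")[-1])], r["score"])
        for r in ranked
    ]


# Invented scores standing in for one sentence's pipeline output
fake_results = [
    {"label": "LABEL_17", "score": 0.10},
    {"label": "LABEL_18", "score": 0.80},
    {"label": "LABEL_0", "score": 0.05},
    {"label": "LABEL_27", "score": 0.03},
]

top3 = top_k_emotions(fake_results, goemotions_labels)
print(top3)  # [('love', 0.8), ('joy', 0.1), ('admiration', 0.05)]
```

Keeping the mapping in one function makes it easier to swap the label source later, for example to the model's own `id2label` configuration if that is populated.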
notebooks/01_exploration.ipynb ADDED
@@ -0,0 +1,358 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "ec88103e-4da3-4eb9-9ed5-8aab3f7745a5",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from datasets import load_dataset\n",
11
+ "\n",
12
+ "dataset = load_dataset(\"go_emotions\")"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 2,
18
+ "id": "9e38e4c1-5c32-4e1b-8612-e723645e34bb",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "import pandas as pd\n",
23
+ "import matplotlib.pyplot as plt\n",
24
+ "from collections import Counter"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "id": "169cb7bc-8ac6-41d4-a109-c30ca1688899",
30
+ "metadata": {},
31
+ "source": [
32
+ "### Converting the dataset to a dataframe for ease"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 4,
38
+ "id": "fadfe4cd-e0a3-4f09-a411-56db306fbd0b",
39
+ "metadata": {},
40
+ "outputs": [
41
+ {
42
+ "data": {
43
+ "text/html": [
44
+ "<div>\n",
45
+ "<style scoped>\n",
46
+ " .dataframe tbody tr th:only-of-type {\n",
47
+ " vertical-align: middle;\n",
48
+ " }\n",
49
+ "\n",
50
+ " .dataframe tbody tr th {\n",
51
+ " vertical-align: top;\n",
52
+ " }\n",
53
+ "\n",
54
+ " .dataframe thead th {\n",
55
+ " text-align: right;\n",
56
+ " }\n",
57
+ "</style>\n",
58
+ "<table border=\"1\" class=\"dataframe\">\n",
59
+ " <thead>\n",
60
+ " <tr style=\"text-align: right;\">\n",
61
+ " <th></th>\n",
62
+ " <th>text</th>\n",
63
+ " <th>labels</th>\n",
64
+ " <th>id</th>\n",
65
+ " </tr>\n",
66
+ " </thead>\n",
67
+ " <tbody>\n",
68
+ " <tr>\n",
69
+ " <th>0</th>\n",
70
+ " <td>My favourite food is anything I didn't have to...</td>\n",
71
+ " <td>[27]</td>\n",
72
+ " <td>eebbqej</td>\n",
73
+ " </tr>\n",
74
+ " <tr>\n",
75
+ " <th>1</th>\n",
76
+ " <td>Now if he does off himself, everyone will thin...</td>\n",
77
+ " <td>[27]</td>\n",
78
+ " <td>ed00q6i</td>\n",
79
+ " </tr>\n",
80
+ " <tr>\n",
81
+ " <th>2</th>\n",
82
+ " <td>WHY THE FUCK IS BAYLESS ISOING</td>\n",
83
+ " <td>[2]</td>\n",
84
+ " <td>eezlygj</td>\n",
85
+ " </tr>\n",
86
+ " <tr>\n",
87
+ " <th>3</th>\n",
88
+ " <td>To make her feel threatened</td>\n",
89
+ " <td>[14]</td>\n",
90
+ " <td>ed7ypvh</td>\n",
91
+ " </tr>\n",
92
+ " <tr>\n",
93
+ " <th>4</th>\n",
94
+ " <td>Dirty Southern Wankers</td>\n",
95
+ " <td>[3]</td>\n",
96
+ " <td>ed0bdzj</td>\n",
97
+ " </tr>\n",
98
+ " </tbody>\n",
99
+ "</table>\n",
100
+ "</div>"
101
+ ],
102
+ "text/plain": [
103
+ " text labels id\n",
104
+ "0 My favourite food is anything I didn't have to... [27] eebbqej\n",
105
+ "1 Now if he does off himself, everyone will thin... [27] ed00q6i\n",
106
+ "2 WHY THE FUCK IS BAYLESS ISOING [2] eezlygj\n",
107
+ "3 To make her feel threatened [14] ed7ypvh\n",
108
+ "4 Dirty Southern Wankers [3] ed0bdzj"
109
+ ]
110
+ },
111
+ "execution_count": 4,
112
+ "metadata": {},
113
+ "output_type": "execute_result"
114
+ }
115
+ ],
116
+ "source": [
117
+ "df = pd.DataFrame(dataset[\"train\"])\n",
118
+ "df.head()"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 5,
124
+ "id": "20a1d0af-1cd6-45c9-8128-29be160713d4",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "label_counts = Counter() # counts the number of occurrences of each label\n",
129
+ "\n",
130
+ "# Looping through all label lists and counting each label's total occurrences\n",
131
+ "for labels in df[\"labels\"]:\n",
132
+ " label_counts.update(labels)"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": 6,
138
+ "id": "11c20083-c1a7-412d-9f98-e51defaded84",
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "# Mapping indices to label names\n",
143
+ "label_names = dataset[\"train\"].features[\"labels\"].feature.names\n",
144
+ "label_distribution = {label_names[i]: count for i, count in label_counts.items()}"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "markdown",
149
+ "id": "967e6d78-3dec-4dc5-b035-8ff263e689c8",
150
+ "metadata": {},
151
+ "source": [
152
+ "### Plot for the number of occurrences of each emotion in the training set"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 8,
158
+ "id": "91e04be4-4132-4c78-8e4e-6d4c1b29fd6c",
159
+ "metadata": {},
160
+ "outputs": [
161
+ {
162
+ "data": {
163
+ "image/png": "",
164
+ "text/plain": [
165
+ "<Figure size 1200x600 with 1 Axes>"
166
+ ]
167
+ },
168
+ "metadata": {},
169
+ "output_type": "display_data"
170
+ }
171
+ ],
172
+ "source": [
173
+ "# Plot for the number of occurrences of each emotion in the training set\n",
174
+ "plt.figure(figsize=(12, 6))\n",
175
+ "pd.Series(label_distribution).sort_values(ascending=False).plot(kind='bar')\n",
176
+ "plt.title(\"GoEmotions Label Distribution\")\n",
177
+ "plt.ylabel(\"Count\")\n",
178
+ "plt.xticks(rotation=90)\n",
179
+ "plt.tight_layout()\n",
180
+ "plt.show()"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 9,
186
+ "id": "6d28b28c-d87f-41e5-afd2-cc3257dd974b",
187
+ "metadata": {},
188
+ "outputs": [
189
+ {
190
+ "data": {
191
+ "text/html": [
192
+ "<div>\n",
193
+ "<style scoped>\n",
194
+ " .dataframe tbody tr th:only-of-type {\n",
195
+ " vertical-align: middle;\n",
196
+ " }\n",
197
+ "\n",
198
+ " .dataframe tbody tr th {\n",
199
+ " vertical-align: top;\n",
200
+ " }\n",
201
+ "\n",
202
+ " .dataframe thead th {\n",
203
+ " text-align: right;\n",
204
+ " }\n",
205
+ "</style>\n",
206
+ "<table border=\"1\" class=\"dataframe\">\n",
207
+ " <thead>\n",
208
+ " <tr style=\"text-align: right;\">\n",
209
+ " <th></th>\n",
210
+ " <th>text</th>\n",
211
+ " <th>labels</th>\n",
212
+ " <th>id</th>\n",
213
+ " <th>clean_text</th>\n",
214
+ " </tr>\n",
215
+ " </thead>\n",
216
+ " <tbody>\n",
217
+ " <tr>\n",
218
+ " <th>0</th>\n",
219
+ " <td>My favourite food is anything I didn't have to...</td>\n",
220
+ " <td>[27]</td>\n",
221
+ " <td>eebbqej</td>\n",
222
+ " <td>my favourite food is anything i didnt have to ...</td>\n",
223
+ " </tr>\n",
224
+ " <tr>\n",
225
+ " <th>1</th>\n",
226
+ " <td>Now if he does off himself, everyone will thin...</td>\n",
227
+ " <td>[27]</td>\n",
228
+ " <td>ed00q6i</td>\n",
229
+ " <td>now if he does off himself everyone will think...</td>\n",
230
+ " </tr>\n",
231
+ " <tr>\n",
232
+ " <th>2</th>\n",
233
+ " <td>WHY THE FUCK IS BAYLESS ISOING</td>\n",
234
+ " <td>[2]</td>\n",
235
+ " <td>eezlygj</td>\n",
236
+ " <td>why the fuck is bayless isoing</td>\n",
237
+ " </tr>\n",
238
+ " <tr>\n",
239
+ " <th>3</th>\n",
240
+ " <td>To make her feel threatened</td>\n",
241
+ " <td>[14]</td>\n",
242
+ " <td>ed7ypvh</td>\n",
243
+ " <td>to make her feel threatened</td>\n",
244
+ " </tr>\n",
245
+ " <tr>\n",
246
+ " <th>4</th>\n",
247
+ " <td>Dirty Southern Wankers</td>\n",
248
+ " <td>[3]</td>\n",
249
+ " <td>ed0bdzj</td>\n",
250
+ " <td>dirty southern wankers</td>\n",
251
+ " </tr>\n",
252
+ " <tr>\n",
253
+ " <th>5</th>\n",
254
+ " <td>OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe...</td>\n",
255
+ " <td>[26]</td>\n",
256
+ " <td>edvnz26</td>\n",
257
+ " <td>omg peyton isnt good enough to help us in the ...</td>\n",
258
+ " </tr>\n",
259
+ " <tr>\n",
260
+ " <th>6</th>\n",
261
+ " <td>Yes I heard abt the f bombs! That has to be wh...</td>\n",
262
+ " <td>[15]</td>\n",
263
+ " <td>ee3b6wu</td>\n",
264
+ " <td>yes i heard abt the f bombs that has to be why...</td>\n",
265
+ " </tr>\n",
266
+ " <tr>\n",
267
+ " <th>7</th>\n",
268
+ " <td>We need more boards and to create a bit more s...</td>\n",
269
+ " <td>[8, 20]</td>\n",
270
+ " <td>ef4qmod</td>\n",
271
+ " <td>we need more boards and to create a bit more s...</td>\n",
272
+ " </tr>\n",
273
+ " <tr>\n",
274
+ " <th>8</th>\n",
275
+ " <td>Damn youtube and outrage drama is super lucrat...</td>\n",
276
+ " <td>[0]</td>\n",
277
+ " <td>ed8wbdn</td>\n",
278
+ " <td>damn youtube and outrage drama is super lucrat...</td>\n",
279
+ " </tr>\n",
280
+ " <tr>\n",
281
+ " <th>9</th>\n",
282
+ " <td>It might be linked to the trust factor of your...</td>\n",
283
+ " <td>[27]</td>\n",
284
+ " <td>eczgv1o</td>\n",
285
+ " <td>it might be linked to the trust factor of your...</td>\n",
286
+ " </tr>\n",
287
+ " </tbody>\n",
288
+ "</table>\n",
289
+ "</div>"
290
+ ],
291
+ "text/plain": [
292
+ " text labels id \\\n",
293
+ "0 My favourite food is anything I didn't have to... [27] eebbqej \n",
294
+ "1 Now if he does off himself, everyone will thin... [27] ed00q6i \n",
295
+ "2 WHY THE FUCK IS BAYLESS ISOING [2] eezlygj \n",
296
+ "3 To make her feel threatened [14] ed7ypvh \n",
297
+ "4 Dirty Southern Wankers [3] ed0bdzj \n",
298
+ "5 OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe... [26] edvnz26 \n",
299
+ "6 Yes I heard abt the f bombs! That has to be wh... [15] ee3b6wu \n",
300
+ "7 We need more boards and to create a bit more s... [8, 20] ef4qmod \n",
301
+ "8 Damn youtube and outrage drama is super lucrat... [0] ed8wbdn \n",
302
+ "9 It might be linked to the trust factor of your... [27] eczgv1o \n",
303
+ "\n",
304
+ " clean_text \n",
305
+ "0 my favourite food is anything i didnt have to ... \n",
306
+ "1 now if he does off himself everyone will think... \n",
307
+ "2 why the fuck is bayless isoing \n",
308
+ "3 to make her feel threatened \n",
309
+ "4 dirty southern wankers \n",
310
+ "5 omg peyton isnt good enough to help us in the ... \n",
311
+ "6 yes i heard abt the f bombs that has to be why... \n",
312
+ "7 we need more boards and to create a bit more s... \n",
313
+ "8 damn youtube and outrage drama is super lucrat... \n",
314
+ "9 it might be linked to the trust factor of your... "
315
+ ]
316
+ },
317
+ "execution_count": 9,
318
+ "metadata": {},
319
+ "output_type": "execute_result"
320
+ }
321
+ ],
322
+ "source": [
323
+ "import re\n",
324
+ "\n",
325
+ "def clean_text(text):\n",
326
+ " text = text.lower() # this makes every character in the text lower-cased\n",
327
+ " text = re.sub(r\"http\\S+|www\\S+|https\\S+\", '', text) # this removes links\n",
328
+ " text = re.sub(r'[^A-Za-z0-9\\s]+', '', text) # this removes special characters\n",
329
+ " text = re.sub(r'\\s+', ' ', text).strip() # this normalizes whitespace\n",
330
+ " return text\n",
331
+ "\n",
332
+ "df[\"clean_text\"] = df[\"text\"].apply(clean_text)\n",
333
+ "df.head(10)"
334
+ ]
335
+ }
336
+ ],
337
+ "metadata": {
338
+ "kernelspec": {
339
+ "display_name": "Python 3 (ipykernel)",
340
+ "language": "python",
341
+ "name": "python3"
342
+ },
343
+ "language_info": {
344
+ "codemirror_mode": {
345
+ "name": "ipython",
346
+ "version": 3
347
+ },
348
+ "file_extension": ".py",
349
+ "mimetype": "text/x-python",
350
+ "name": "python",
351
+ "nbconvert_exporter": "python",
352
+ "pygments_lexer": "ipython3",
353
+ "version": "3.12.2"
354
+ }
355
+ },
356
+ "nbformat": 4,
357
+ "nbformat_minor": 5
358
+ }
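The exploration notebook counts labels with `Counter.update` because each row carries a list of label indices and some rows have more than one (row 7 shows `[8, 20]`). The sketch below replays that counting on the ten label lists visible in the `df.head(10)` output, without loading the dataset. It assumes the index-to-name order matches the `goemotions_labels` list from the comparison notebook; the variable `sample_label_lists` is a hypothetical name.

```python
from collections import Counter

# Label order assumed to match the goemotions_labels list used elsewhere in the project
goemotions_labels = [
    "admiration", "amusement", "anger", "annoyance", "approval", "caring",
    "confusion", "curiosity", "desire", "disappointment", "disapproval",
    "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
    "joy", "love", "nervousness", "optimism", "pride", "realization",
    "relief", "remorse", "sadness", "surprise", "neutral",
]

# Label lists copied from the ten rows shown by df.head(10)
sample_label_lists = [[27], [27], [2], [14], [3], [26], [15], [8, 20], [0], [27]]

label_counts = Counter()
for labels in sample_label_lists:
    # update() adds one count per index, so multi-label rows count once per label
    label_counts.update(labels)

# Map indices to names, as the notebook does with the dataset's feature names
label_distribution = {goemotions_labels[i]: n for i, n in label_counts.items()}
print(label_distribution["neutral"])      # 3
print(sum(label_distribution.values()))   # 11 labels across 10 rows
```

The total exceeds the row count because of the multi-label row, which is worth keeping in mind when reading the bar chart: it shows label occurrences, not numbers of examples.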
notebooks/02_training.ipynb ADDED
@@ -0,0 +1,368 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "80f816c1-0839-41cb-847b-c79a62ca1465",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Load all required modules for loading data, model setup, training, and metric evaluation"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "2554d05b-f08a-4c21-953f-4f507407e426",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import sys\n",
19
+ "import os\n",
20
+ "sys.path.append(os.path.abspath(os.path.join(os.getcwd(), \"..\", \"src\")))\n",
21
+ "from data_loader import load_and_prepare_data \n",
22
+ "from model import get_model, get_tokenizer \n",
23
+ "from train import get_training_args, train_model \n",
24
+ "from evaluate import compute_metrics \n",
25
+ "from torch.utils.data import Dataset \n",
26
+ "import torch"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "markdown",
31
+ "id": "3bfbb706-4b0b-43de-a95a-884d46343668",
32
+ "metadata": {},
33
+ "source": [
34
+ "### Define a class that wraps tokenized data and labels for Hugging Face’s Trainer to use"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 3,
40
+ "id": "c814c354-7962-4a2d-b7bd-5c498f1d004e",
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "class EmotionDataset(Dataset):\n",
45
+ " def __init__(self, encodings, labels):\n",
46
+ " self.encodings = encodings # BERT tokenized inputs (input_ids, attention_mask)\n",
47
+ " self.labels = labels # Encoded labels (integers)\n",
48
+ "\n",
49
+ " def __len__(self):\n",
50
+ " return len(self.labels) # Total number of samples\n",
51
+ "\n",
52
+ " def __getitem__(self, idx):\n",
53
+ " # Return dictionary of input tensors + label tensor for a single sample\n",
54
+ " return {\n",
55
+ " key: torch.tensor(val[idx]) for key, val in self.encodings.items()\n",
56
+ " } | {\"labels\": torch.tensor(self.labels[idx])}"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "id": "f9b87257-f0c0-4532-9eee-939d8747ef79",
62
+ "metadata": {},
63
+ "source": [
64
+ "### Load the dataset from Hugging Face, clean and encode it, then tokenize it using the BERT tokenizer."
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 5,
70
+ "id": "18e312be-5863-4e24-900a-843e42e145cc",
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "# Load train/test splits and label encoder\n",
75
+ "train_texts, test_texts, train_labels, test_labels, label_encoder = load_and_prepare_data()\n",
76
+ "\n",
77
+ "# Load BERT tokenizer\n",
78
+ "tokenizer = get_tokenizer()\n",
79
+ "\n",
80
+ "# Tokenize training and testing texts with truncation and padding\n",
81
+ "train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=128)\n",
82
+ "test_encodings = tokenizer(test_texts, truncation=True, padding=True, max_length=128)\n",
83
+ "\n",
84
+ "# Wrap the tokenized data into EmotionDataset objects\n",
85
+ "train_dataset = EmotionDataset(train_encodings, train_labels)\n",
86
+ "test_dataset = EmotionDataset(test_encodings, test_labels)"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "markdown",
91
+ "id": "66b99b4e-5297-4bc0-8cfb-20dbe22526c0",
92
+ "metadata": {},
93
+ "source": [
94
+ "### Samples from the dataset"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 7,
100
+ "id": "35db4426-db21-4438-ba0e-ebb51d52edfb",
101
+ "metadata": {},
102
+ "outputs": [
103
+ {
104
+ "name": "stdout",
105
+ "output_type": "stream",
106
+ "text": [
107
+ "Sample 1\n",
108
+ "Text: i'd just feel less out of place, i guess. my sa makes me feel like i'm so behind my peers in terms of a social life\n",
109
+ "Label (encoded): 9\n",
110
+ "\n",
111
+ "Sample 2\n",
112
+ "Text: i love the lady in the green jacket chasing after the second car looking back at the first car like \"look what you did\"\n",
113
+ "Label (encoded): 18\n",
114
+ "\n",
115
+ "Sample 3\n",
116
+ "Text: man. really bad last possession there. bummer.\n",
117
+ "Label (encoded): 10\n",
118
+ "\n",
119
+ "Sample 4\n",
120
+ "Text: never would’ve guessed that one.\n",
121
+ "Label (encoded): 20\n",
122
+ "\n",
123
+ "Sample 5\n",
124
+ "Text: i wasn’t even expecting the reply that’s why i’m literally bamboozled.\n",
125
+ "Label (encoded): 27\n",
126
+ "\n"
127
+ ]
128
+ }
129
+ ],
130
+ "source": [
131
+ "for i in range(5):\n",
132
+ " print(f\"Sample {i+1}\")\n",
133
+ " print(f\"Text: {train_texts[i]}\")\n",
134
+ " print(f\"Label (encoded): {train_labels[i]}\")\n",
135
+ " print()"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "markdown",
140
+ "id": "0883760a-a449-42ca-ba69-fa01d874e50b",
141
+ "metadata": {},
142
+ "source": [
143
+ "### Set up the BERT model for sequence classification and define training parameters."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 9,
149
+ "id": "3176ccf4-d20d-460c-b620-c73a1ab9cb6d",
150
+ "metadata": {},
151
+ "outputs": [
152
+ {
153
+ "name": "stderr",
154
+ "output_type": "stream",
155
+ "text": [
156
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
157
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
158
+ "/opt/anaconda3/lib/python3.12/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
159
+ " warnings.warn(\n"
160
+ ]
161
+ }
162
+ ],
163
+ "source": [
164
+ "# Load pre-trained BERT model with classification head for number of emotion classes\n",
165
+ "model = get_model(num_labels=len(label_encoder.classes_))\n",
166
+ "\n",
167
+ "# Set training configuration: batch size, epochs, logging, saving, evaluation\n",
168
+ "training_args = get_training_args()"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "markdown",
173
+ "id": "874a4e6a-80dd-470d-9283-e1c88e731b8e",
174
+ "metadata": {},
175
+ "source": [
176
+ "### Train the Model "
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 13,
182
+ "id": "4c312e56-52bf-417d-82c0-8a1f47b82670",
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "data": {
187
+ "text/html": [
188
+ "\n",
189
+ " <div>\n",
190
+ " \n",
191
+ " <progress value='5448' max='5448' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
192
+ " [5448/5448 1:46:28, Epoch 3/3]\n",
193
+ " </div>\n",
194
+ " <table border=\"1\" class=\"dataframe\">\n",
195
+ " <thead>\n",
196
+ " <tr style=\"text-align: left;\">\n",
197
+ " <th>Epoch</th>\n",
198
+ " <th>Training Loss</th>\n",
199
+ " <th>Validation Loss</th>\n",
200
+ " <th>Accuracy</th>\n",
201
+ " <th>F1</th>\n",
202
+ " </tr>\n",
203
+ " </thead>\n",
204
+ " <tbody>\n",
205
+ " <tr>\n",
206
+ " <td>1</td>\n",
207
+ " <td>1.358900</td>\n",
208
+ " <td>1.335635</td>\n",
209
+ " <td>0.613467</td>\n",
210
+ " <td>0.579882</td>\n",
211
+ " </tr>\n",
212
+ " <tr>\n",
213
+ " <td>2</td>\n",
214
+ " <td>0.947100</td>\n",
215
+ " <td>1.284574</td>\n",
216
+ " <td>0.615671</td>\n",
217
+ " <td>0.601428</td>\n",
218
+ " </tr>\n",
219
+ " <tr>\n",
220
+ " <td>3</td>\n",
221
+ " <td>0.970400</td>\n",
222
+ " <td>1.297894</td>\n",
223
+ " <td>0.617048</td>\n",
224
+ " <td>0.606042</td>\n",
225
+ " </tr>\n",
226
+ " </tbody>\n",
227
+ "</table><p>"
228
+ ],
229
+ "text/plain": [
230
+ "<IPython.core.display.HTML object>"
231
+ ]
232
+ },
233
+ "metadata": {},
234
+ "output_type": "display_data"
235
+ },
236
+ {
237
+ "data": {
238
+ "text/html": [
239
+ "\n",
240
+ " <div>\n",
241
+ " \n",
242
+ " <progress value='5448' max='5448' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
243
+ " [5448/5448 1:35:20, Epoch 3/3]\n",
244
+ " </div>\n",
245
+ " <table border=\"1\" class=\"dataframe\">\n",
246
+ " <thead>\n",
247
+ " <tr style=\"text-align: left;\">\n",
248
+ " <th>Epoch</th>\n",
249
+ " <th>Training Loss</th>\n",
250
+ " <th>Validation Loss</th>\n",
251
+ " <th>Accuracy</th>\n",
252
+ " <th>F1</th>\n",
253
+ " </tr>\n",
254
+ " </thead>\n",
255
+ " <tbody>\n",
256
+ " <tr>\n",
257
+ " <td>1</td>\n",
258
+ " <td>0.907200</td>\n",
259
+ " <td>1.365916</td>\n",
260
+ " <td>0.602313</td>\n",
261
+ " <td>0.595804</td>\n",
262
+ " </tr>\n",
263
+ " <tr>\n",
264
+ " <td>2</td>\n",
265
+ " <td>0.549100</td>\n",
266
+ " <td>1.488130</td>\n",
267
+ " <td>0.595566</td>\n",
268
+ " <td>0.591464</td>\n",
269
+ " </tr>\n",
270
+ " <tr>\n",
271
+ " <td>3</td>\n",
272
+ " <td>0.514400</td>\n",
273
+ " <td>1.593286</td>\n",
274
+ " <td>0.591297</td>\n",
275
+ " <td>0.589066</td>\n",
276
+ " </tr>\n",
277
+ " </tbody>\n",
278
+ "</table><p>"
279
+ ],
280
+ "text/plain": [
281
+ "<IPython.core.display.HTML object>"
282
+ ]
283
+ },
284
+ "metadata": {},
285
+ "output_type": "display_data"
286
+ },
287
+ {
288
+ "data": {
289
+ "text/plain": [
290
+ "TrainOutput(global_step=5448, training_loss=0.7054264770818002, metrics={'train_runtime': 5721.3012, 'train_samples_per_second': 15.23, 'train_steps_per_second': 0.952, 'total_flos': 5733080823638016.0, 'train_loss': 0.7054264770818002, 'epoch': 3.0})"
291
+ ]
292
+ },
293
+ "execution_count": 13,
294
+ "metadata": {},
295
+ "output_type": "execute_result"
296
+ }
297
+ ],
298
+ "source": [
299
+ "trainer = train_model(\n",
300
+ " model=model,\n",
301
+ " args=training_args,\n",
302
+ " train_dataset=train_dataset,\n",
303
+ " val_dataset=test_dataset,\n",
304
+ " compute_metrics=compute_metrics\n",
305
+ ")\n",
306
+ "\n",
307
+ "# Begin training\n",
308
+ "trainer.train()"
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "markdown",
313
+ "id": "020729b6-c545-42ba-bd2c-00ee5f9bbb80",
314
+ "metadata": {},
315
+ "source": [
316
+ "### Save both model weights and tokenizer files for future inference or deployment."
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": 23,
322
+ "id": "5f12aedb-b3f8-4a1b-8e1f-6a68eb29933f",
323
+ "metadata": {},
324
+ "outputs": [
325
+ {
326
+ "data": {
327
+ "text/plain": [
328
+ "('../outputs/model/tokenizer_config.json',\n",
329
+ " '../outputs/model/special_tokens_map.json',\n",
330
+ " '../outputs/model/vocab.txt',\n",
331
+ " '../outputs/model/added_tokens.json')"
332
+ ]
333
+ },
334
+ "execution_count": 23,
335
+ "metadata": {},
336
+ "output_type": "execute_result"
337
+ }
338
+ ],
339
+ "source": [
340
+ "from pathlib import Path\n",
341
+ "model_path = Path(\"..\") / \"outputs\" / \"model\"\n",
342
+ "model.save_pretrained(model_path)\n",
343
+ "tokenizer.save_pretrained(model_path)"
344
+ ]
345
+ }
346
+ ],
347
+ "metadata": {
348
+ "kernelspec": {
349
+ "display_name": "Python 3 (ipykernel)",
350
+ "language": "python",
351
+ "name": "python3"
352
+ },
353
+ "language_info": {
354
+ "codemirror_mode": {
355
+ "name": "ipython",
356
+ "version": 3
357
+ },
358
+ "file_extension": ".py",
359
+ "mimetype": "text/x-python",
360
+ "name": "python",
361
+ "nbconvert_exporter": "python",
362
+ "pygments_lexer": "ipython3",
363
+ "version": "3.12.2"
364
+ }
365
+ },
366
+ "nbformat": 4,
367
+ "nbformat_minor": 5
368
+ }
notebooks/03_evaluation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/04_model_comparison.ipynb ADDED
@@ -0,0 +1,290 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "4032a920-2db8-4977-8b4f-a5a771dd022f",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import sys\n",
11
+ "import os\n",
12
+ "\n",
13
+ "project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
14
+ "sys.path.append(project_root)\n",
15
+ "\n",
16
+ "from transformers import pipeline\n",
17
+ "from src.model_hartmann import load_model as load_hartmann_model, load_tokenizer as load_hartmann_tokenizer\n",
18
+ "from src.model_custom import load_model as load_custom_model, load_tokenizer as load_custom_tokenizer"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "id": "525cf57e-4ec3-40fd-aca2-0e9700a73298",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "hartmann_model = load_hartmann_model()\n",
29
+ "hartmann_tokenizer = load_hartmann_tokenizer()\n",
30
+ "\n",
31
+ "custom_model = load_custom_model()\n",
32
+ "custom_tokenizer = load_custom_tokenizer()"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 3,
38
+ "id": "04f9415c-3d4f-4ac0-8f51-74ec4bd64293",
39
+ "metadata": {},
40
+ "outputs": [
41
+ {
42
+ "name": "stderr",
43
+ "output_type": "stream",
44
+ "text": [
45
+ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n",
46
+ "/opt/anaconda3/lib/python3.12/site-packages/transformers/pipelines/text_classification.py:104: UserWarning: `return_all_scores` is now deprecated, if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n",
47
+ " warnings.warn(\n",
48
+ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n"
49
+ ]
50
+ }
51
+ ],
52
+ "source": [
53
+ "# Create pipelines for easy predictions\n",
54
+ "hartmann_pipeline = pipeline(\"text-classification\", model=hartmann_model, tokenizer=hartmann_tokenizer, return_all_scores=True)\n",
55
+ "custom_pipeline = pipeline(\"text-classification\", model=custom_model, tokenizer=custom_tokenizer, return_all_scores=True)"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 4,
61
+ "id": "3fcdf650-3abc-42a6-b1fd-0129e49d1e68",
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "name": "stdout",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "========= Sentence 1 ==========\n",
69
+ "Text: I love spending time with my family.\n",
70
+ "\n",
71
+ "--- Hartmann Model Top 3 Predictions ---\n",
72
+ "joy: 0.9883\n",
73
+ "sadness: 0.0067\n",
74
+ "disgust: 0.0013\n",
75
+ "\n",
76
+ "--- Pretrained Model Top 3 Predictions ---\n",
77
+ "love: 0.9536\n",
78
+ "joy: 0.0134\n",
79
+ "admiration: 0.0108\n",
80
+ "\n",
81
+ "\n",
82
+ "========= Sentence 2 ==========\n",
83
+ "Text: This is the worst day of my life.\n",
84
+ "\n",
85
+ "--- Hartmann Model Top 3 Predictions ---\n",
86
+ "disgust: 0.9805\n",
87
+ "anger: 0.0086\n",
88
+ "sadness: 0.0055\n",
89
+ "\n",
90
+ "--- Pretrained Model Top 3 Predictions ---\n",
91
+ "anger: 0.3353\n",
92
+ "surprise: 0.2010\n",
93
+ "disgust: 0.1235\n",
94
+ "\n",
95
+ "\n",
96
+ "========= Sentence 3 ==========\n",
97
+ "Text: I'm feeling very nervous about the exam.\n",
98
+ "\n",
99
+ "--- Hartmann Model Top 3 Predictions ---\n",
100
+ "fear: 0.9947\n",
101
+ "sadness: 0.0013\n",
102
+ "joy: 0.0011\n",
103
+ "\n",
104
+ "--- Pretrained Model Top 3 Predictions ---\n",
105
+ "nervousness: 0.6201\n",
106
+ "fear: 0.0828\n",
107
+ "embarrassment: 0.0393\n",
108
+ "\n",
109
+ "\n",
110
+ "========= Sentence 4 ==========\n",
111
+ "Text: What a beautiful sunset!\n",
112
+ "\n",
113
+ "--- Hartmann Model Top 3 Predictions ---\n",
114
+ "joy: 0.8377\n",
115
+ "surprise: 0.1189\n",
116
+ "neutral: 0.0221\n",
117
+ "\n",
118
+ "--- Pretrained Model Top 3 Predictions ---\n",
119
+ "admiration: 0.8548\n",
120
+ "excitement: 0.0729\n",
121
+ "joy: 0.0351\n",
122
+ "\n",
123
+ "\n",
124
+ "========= Sentence 5 ==========\n",
125
+ "Text: I feel so disappointed and frustrated with the situation.\n",
126
+ "\n",
127
+ "--- Hartmann Model Top 3 Predictions ---\n",
128
+ "sadness: 0.9310\n",
129
+ "anger: 0.0381\n",
130
+ "disgust: 0.0158\n",
131
+ "\n",
132
+ "--- Pretrained Model Top 3 Predictions ---\n",
133
+ "disappointment: 0.5645\n",
134
+ "annoyance: 0.1864\n",
135
+ "anger: 0.0736\n",
136
+ "\n",
137
+ "\n",
138
+ "========= Sentence 6 ==========\n",
139
+ "Text: I'm not sure how to feel about this.\n",
140
+ "\n",
141
+ "--- Hartmann Model Top 3 Predictions ---\n",
142
+ "neutral: 0.5698\n",
143
+ "disgust: 0.2213\n",
144
+ "sadness: 0.0720\n",
145
+ "\n",
146
+ "--- Pretrained Model Top 3 Predictions ---\n",
147
+ "confusion: 0.9011\n",
148
+ "optimism: 0.0230\n",
149
+ "disapproval: 0.0223\n",
150
+ "\n",
151
+ "\n",
152
+ "========= Sentence 7 ==========\n",
153
+ "Text: That was hilarious, I can't stop laughing!\n",
154
+ "\n",
155
+ "--- Hartmann Model Top 3 Predictions ---\n",
156
+ "joy: 0.9336\n",
157
+ "surprise: 0.0306\n",
158
+ "neutral: 0.0178\n",
159
+ "\n",
160
+ "--- Pretrained Model Top 3 Predictions ---\n",
161
+ "amusement: 0.9551\n",
162
+ "joy: 0.0286\n",
163
+ "optimism: 0.0032\n",
164
+ "\n",
165
+ "\n",
166
+ "========= Sentence 8 ==========\n",
167
+ "Text: I feel completely empty and lost.\n",
168
+ "\n",
169
+ "--- Hartmann Model Top 3 Predictions ---\n",
170
+ "sadness: 0.9808\n",
171
+ "neutral: 0.0086\n",
172
+ "disgust: 0.0051\n",
173
+ "\n",
174
+ "--- Pretrained Model Top 3 Predictions ---\n",
175
+ "surprise: 0.8055\n",
176
+ "disappointment: 0.1067\n",
177
+ "optimism: 0.0222\n",
178
+ "\n",
179
+ "\n",
180
+ "========= Sentence 9 ==========\n",
181
+ "Text: Your help means a lot to me, thank you!\n",
182
+ "\n",
183
+ "--- Hartmann Model Top 3 Predictions ---\n",
184
+ "joy: 0.9760\n",
185
+ "neutral: 0.0104\n",
186
+ "surprise: 0.0057\n",
187
+ "\n",
188
+ "--- Pretrained Model Top 3 Predictions ---\n",
189
+ "gratitude: 0.9890\n",
190
+ "caring: 0.0014\n",
191
+ "sadness: 0.0009\n",
192
+ "\n",
193
+ "\n",
194
+ "========= Sentence 10 ==========\n",
195
+ "Text: I'm so angry I could scream.\n",
196
+ "\n",
197
+ "--- Hartmann Model Top 3 Predictions ---\n",
198
+ "anger: 0.9785\n",
199
+ "fear: 0.0084\n",
200
+ "neutral: 0.0047\n",
201
+ "\n",
202
+ "--- Pretrained Model Top 3 Predictions ---\n",
203
+ "anger: 0.9155\n",
204
+ "annoyance: 0.0223\n",
205
+ "optimism: 0.0082\n",
206
+ "\n",
207
+ "\n"
208
+ ]
209
+ }
210
+ ],
211
+ "source": [
212
+ "from tabulate import tabulate\n",
213
+ "\n",
214
+ "goemotions_labels = [\n",
215
+ " \"admiration\", \"amusement\", \"anger\", \"annoyance\", \"approval\", \"caring\", \"confusion\", \"curiosity\",\n",
216
+ " \"desire\", \"disappointment\", \"disapproval\", \"disgust\", \"embarrassment\", \"excitement\", \"fear\",\n",
217
+ " \"gratitude\", \"grief\", \"joy\", \"love\", \"nervousness\", \"neutral\", \"optimism\", \"pride\", \"realization\",\n",
218
+ " \"relief\", \"remorse\", \"sadness\", \"surprise\"\n",
219
+ "]\n",
220
+ "\n",
221
+ "\n",
222
+ "# Your 10 test sentences\n",
223
+ "sentences = [\n",
224
+ " \"I love spending time with my family.\",\n",
225
+ " \"This is the worst day of my life.\",\n",
226
+ " \"I'm feeling very nervous about the exam.\",\n",
227
+ " \"What a beautiful sunset!\",\n",
228
+ " \"I feel so disappointed and frustrated with the situation.\",\n",
229
+ " \"I'm not sure how to feel about this.\",\n",
230
+ " \"That was hilarious, I can't stop laughing!\",\n",
231
+ " \"I feel completely empty and lost.\",\n",
232
+ " \"Your help means a lot to me, thank you!\",\n",
233
+ " \"I'm so angry I could scream.\"\n",
234
+ "]\n",
235
+ "\n",
236
+ "# Loop over sentences and print each model's top-3 predictions\n",
237
+ "for i, sentence in enumerate(sentences):\n",
238
+ " print(f\"========= Sentence {i+1} ==========\")\n",
239
+ " print(f\"Text: {sentence}\\n\")\n",
240
+ "\n",
241
+ " # Get predictions\n",
242
+ " hartmann_results = hartmann_pipeline(sentence, return_all_scores=True)\n",
243
+ " custom_results = custom_pipeline(sentence, return_all_scores=True)\n",
244
+ "\n",
245
+ " # Unwrap the list to get the actual results\n",
246
+ " hartmann_results = hartmann_results[0]\n",
247
+ " custom_results = custom_results[0]\n",
248
+ "\n",
249
+ " # Sort and get top 3 predictions for each\n",
250
+ " hartmann_top3 = sorted(hartmann_results, key=lambda x: x['score'], reverse=True)[:3]\n",
251
+ " custom_top3 = sorted(custom_results, key=lambda x: x['score'], reverse=True)[:3]\n",
252
+ "\n",
253
+ " # Display Hartmann predictions\n",
254
+ " print(\"--- Hartmann Model Top 3 Predictions ---\")\n",
255
+ " for res in hartmann_top3:\n",
256
+ " print(f\"{res['label']}: {res['score']:.4f}\")\n",
257
+ "\n",
258
+ " # Display Custom Model predictions\n",
259
+ " print(\"\\n--- Pretrained Model Top 3 Predictions ---\")\n",
260
+ " for res in custom_top3:\n",
261
+ " label_idx = int(res['label'].split(\"_\")[-1])\n",
262
+ " emotion = goemotions_labels[label_idx]\n",
263
+ " print(f\"{emotion}: {res['score']:.4f}\")\n",
264
+ "\n",
265
+ " print(\"\\n\")\n"
266
+ ]
267
+ }
268
+ ],
269
+ "metadata": {
270
+ "kernelspec": {
271
+ "display_name": "Python 3 (ipykernel)",
272
+ "language": "python",
273
+ "name": "python3"
274
+ },
275
+ "language_info": {
276
+ "codemirror_mode": {
277
+ "name": "ipython",
278
+ "version": 3
279
+ },
280
+ "file_extension": ".py",
281
+ "mimetype": "text/x-python",
282
+ "name": "python",
283
+ "nbconvert_exporter": "python",
284
+ "pygments_lexer": "ipython3",
285
+ "version": "3.12.2"
286
+ }
287
+ },
288
+ "nbformat": 4,
289
+ "nbformat_minor": 5
290
+ }
outputs/.DS_Store ADDED
Binary file (8.2 kB). View file
 
outputs/interpretations/.DS_Store ADDED
Binary file (6.15 kB). View file
 
outputs/interpretations/sample_1_disapproval_bar.png ADDED
outputs/interpretations/sample_1_disapproval_heatmap.png ADDED
outputs/interpretations/sample_2_neutral_bar.png ADDED
outputs/interpretations/sample_2_neutral_heatmap.png ADDED
outputs/interpretations/sample_3_neutral_bar.png ADDED
outputs/interpretations/sample_3_neutral_heatmap.png ADDED
outputs/interpretations/sample_4_sadness_bar.png ADDED
outputs/interpretations/sample_4_sadness_heatmap.png ADDED
outputs/interpretations/sample_5_neutral_bar.png ADDED
outputs/interpretations/sample_5_neutral_heatmap.png ADDED
outputs/metrics/.DS_Store ADDED
Binary file (6.15 kB). View file
 
outputs/metrics/.gitkeep ADDED
File without changes
outputs/metrics/.ipynb_checkpoints/confusion_matrix-checkpoint.png ADDED
outputs/metrics/.ipynb_checkpoints/report-checkpoint.json ADDED
@@ -0,0 +1,183 @@
1
+ {
2
+ "admiration": {
3
+ "precision": 0.6699186991869919,
4
+ "recall": 0.7601476014760148,
5
+ "f1-score": 0.7121866897147796,
6
+ "support": 542.0
7
+ },
8
+ "amusement": {
9
+ "precision": 0.727979274611399,
10
+ "recall": 0.8515151515151516,
11
+ "f1-score": 0.7849162011173184,
12
+ "support": 330.0
13
+ },
14
+ "anger": {
15
+ "precision": 0.4329501915708812,
16
+ "recall": 0.551219512195122,
17
+ "f1-score": 0.48497854077253216,
18
+ "support": 205.0
19
+ },
20
+ "annoyance": {
21
+ "precision": 0.2983425414364641,
22
+ "recall": 0.18620689655172415,
23
+ "f1-score": 0.22929936305732485,
24
+ "support": 290.0
25
+ },
26
+ "approval": {
27
+ "precision": 0.38440111420612816,
28
+ "recall": 0.368,
29
+ "f1-score": 0.3760217983651226,
30
+ "support": 375.0
31
+ },
32
+ "caring": {
33
+ "precision": 0.375,
34
+ "recall": 0.4846153846153846,
35
+ "f1-score": 0.4228187919463087,
36
+ "support": 130.0
37
+ },
38
+ "confusion": {
39
+ "precision": 0.4965986394557823,
40
+ "recall": 0.42441860465116277,
41
+ "f1-score": 0.45768025078369906,
42
+ "support": 172.0
43
+ },
44
+ "curiosity": {
45
+ "precision": 0.4628099173553719,
46
+ "recall": 0.60431654676259,
47
+ "f1-score": 0.5241809672386896,
48
+ "support": 278.0
49
+ },
50
+ "desire": {
51
+ "precision": 0.640625,
52
+ "recall": 0.5256410256410257,
53
+ "f1-score": 0.5774647887323944,
54
+ "support": 78.0
55
+ },
56
+ "disappointment": {
57
+ "precision": 0.3258426966292135,
58
+ "recall": 0.20422535211267606,
59
+ "f1-score": 0.2510822510822511,
60
+ "support": 142.0
61
+ },
62
+ "disapproval": {
63
+ "precision": 0.3228915662650602,
64
+ "recall": 0.4785714285714286,
65
+ "f1-score": 0.3856115107913669,
66
+ "support": 280.0
67
+ },
68
+ "disgust": {
69
+ "precision": 0.5113636363636364,
70
+ "recall": 0.45,
71
+ "f1-score": 0.4787234042553192,
72
+ "support": 100.0
73
+ },
74
+ "embarrassment": {
75
+ "precision": 0.7931034482758621,
76
+ "recall": 0.5609756097560976,
77
+ "f1-score": 0.6571428571428571,
78
+ "support": 41.0
79
+ },
80
+ "excitement": {
81
+ "precision": 0.42105263157894735,
82
+ "recall": 0.3137254901960784,
83
+ "f1-score": 0.3595505617977528,
84
+ "support": 102.0
85
+ },
86
+ "fear": {
87
+ "precision": 0.8113207547169812,
88
+ "recall": 0.5,
89
+ "f1-score": 0.6187050359712231,
90
+ "support": 86.0
91
+ },
92
+ "gratitude": {
93
+ "precision": 0.9085872576177285,
94
+ "recall": 0.8840970350404312,
95
+ "f1-score": 0.8961748633879781,
96
+ "support": 371.0
97
+ },
98
+ "grief": {
99
+ "precision": 0.0,
100
+ "recall": 0.0,
101
+ "f1-score": 0.0,
102
+ "support": 8.0
103
+ },
104
+ "joy": {
105
+ "precision": 0.55,
106
+ "recall": 0.5789473684210527,
107
+ "f1-score": 0.5641025641025641,
108
+ "support": 171.0
109
+ },
110
+ "love": {
111
+ "precision": 0.7774193548387097,
112
+ "recall": 0.8456140350877193,
113
+ "f1-score": 0.8100840336134454,
114
+ "support": 285.0
115
+ },
116
+ "nervousness": {
117
+ "precision": 0.46153846153846156,
118
+ "recall": 0.35294117647058826,
119
+ "f1-score": 0.4,
120
+ "support": 17.0
121
+ },
122
+ "neutral": {
123
+ "precision": 0.6760843613211301,
124
+ "recall": 0.6623781676413255,
125
+ "f1-score": 0.6691610870421426,
126
+ "support": 2565.0
127
+ },
128
+ "optimism": {
129
+ "precision": 0.6713286713286714,
130
+ "recall": 0.5581395348837209,
131
+ "f1-score": 0.6095238095238096,
132
+ "support": 172.0
133
+ },
134
+ "pride": {
135
+ "precision": 0.3333333333333333,
136
+ "recall": 0.1,
137
+ "f1-score": 0.15384615384615385,
138
+ "support": 10.0
139
+ },
140
+ "realization": {
141
+ "precision": 0.5625,
142
+ "recall": 0.15384615384615385,
143
+ "f1-score": 0.24161073825503357,
144
+ "support": 117.0
145
+ },
146
+ "relief": {
147
+ "precision": 0.0,
148
+ "recall": 0.0,
149
+ "f1-score": 0.0,
150
+ "support": 17.0
151
+ },
152
+ "remorse": {
153
+ "precision": 0.6122448979591837,
154
+ "recall": 0.8450704225352113,
155
+ "f1-score": 0.7100591715976331,
156
+ "support": 71.0
157
+ },
158
+ "sadness": {
159
+ "precision": 0.5540540540540541,
160
+ "recall": 0.5030674846625767,
161
+ "f1-score": 0.5273311897106109,
162
+ "support": 163.0
163
+ },
164
+ "surprise": {
165
+ "precision": 0.5688622754491018,
166
+ "recall": 0.6597222222222222,
167
+ "f1-score": 0.6109324758842444,
168
+ "support": 144.0
169
+ },
170
+ "accuracy": 0.6023134122831176,
171
+ "macro avg": {
172
+ "precision": 0.5125054563961818,
173
+ "recall": 0.47883579303055207,
174
+ "f1-score": 0.48261389641901975,
175
+ "support": 7262.0
176
+ },
177
+ "weighted avg": {
178
+ "precision": 0.6008835076974202,
179
+ "recall": 0.6023134122831176,
180
+ "f1-score": 0.595804330387673,
181
+ "support": 7262.0
182
+ }
183
+ }
outputs/metrics/report.json ADDED
@@ -0,0 +1,183 @@
1
+ {
2
+ "admiration": {
3
+ "precision": 0.6699186991869919,
4
+ "recall": 0.7601476014760148,
5
+ "f1-score": 0.7121866897147796,
6
+ "support": 542.0
7
+ },
8
+ "amusement": {
9
+ "precision": 0.727979274611399,
10
+ "recall": 0.8515151515151516,
11
+ "f1-score": 0.7849162011173184,
12
+ "support": 330.0
13
+ },
14
+ "anger": {
15
+ "precision": 0.4329501915708812,
16
+ "recall": 0.551219512195122,
17
+ "f1-score": 0.48497854077253216,
18
+ "support": 205.0
19
+ },
20
+ "annoyance": {
21
+ "precision": 0.2983425414364641,
22
+ "recall": 0.18620689655172415,
23
+ "f1-score": 0.22929936305732485,
24
+ "support": 290.0
25
+ },
26
+ "approval": {
27
+ "precision": 0.38440111420612816,
28
+ "recall": 0.368,
29
+ "f1-score": 0.3760217983651226,
30
+ "support": 375.0
31
+ },
32
+ "caring": {
33
+ "precision": 0.375,
34
+ "recall": 0.4846153846153846,
35
+ "f1-score": 0.4228187919463087,
36
+ "support": 130.0
37
+ },
38
+ "confusion": {
39
+ "precision": 0.4965986394557823,
40
+ "recall": 0.42441860465116277,
41
+ "f1-score": 0.45768025078369906,
42
+ "support": 172.0
43
+ },
44
+ "curiosity": {
45
+ "precision": 0.4628099173553719,
46
+ "recall": 0.60431654676259,
47
+ "f1-score": 0.5241809672386896,
48
+ "support": 278.0
49
+ },
50
+ "desire": {
51
+ "precision": 0.640625,
52
+ "recall": 0.5256410256410257,
53
+ "f1-score": 0.5774647887323944,
54
+ "support": 78.0
55
+ },
56
+ "disappointment": {
57
+ "precision": 0.3258426966292135,
58
+ "recall": 0.20422535211267606,
59
+ "f1-score": 0.2510822510822511,
60
+ "support": 142.0
61
+ },
62
+ "disapproval": {
63
+ "precision": 0.3228915662650602,
64
+ "recall": 0.4785714285714286,
65
+ "f1-score": 0.3856115107913669,
66
+ "support": 280.0
67
+ },
68
+ "disgust": {
69
+ "precision": 0.5113636363636364,
70
+ "recall": 0.45,
71
+ "f1-score": 0.4787234042553192,
72
+ "support": 100.0
73
+ },
74
+ "embarrassment": {
75
+ "precision": 0.7931034482758621,
76
+ "recall": 0.5609756097560976,
77
+ "f1-score": 0.6571428571428571,
78
+ "support": 41.0
79
+ },
80
+ "excitement": {
81
+ "precision": 0.42105263157894735,
82
+ "recall": 0.3137254901960784,
83
+ "f1-score": 0.3595505617977528,
84
+ "support": 102.0
85
+ },
86
+ "fear": {
87
+ "precision": 0.8113207547169812,
88
+ "recall": 0.5,
89
+ "f1-score": 0.6187050359712231,
90
+ "support": 86.0
91
+ },
92
+ "gratitude": {
93
+ "precision": 0.9085872576177285,
94
+ "recall": 0.8840970350404312,
95
+ "f1-score": 0.8961748633879781,
96
+ "support": 371.0
97
+ },
98
+ "grief": {
99
+ "precision": 0.0,
100
+ "recall": 0.0,
101
+ "f1-score": 0.0,
102
+ "support": 8.0
103
+ },
104
+ "joy": {
105
+ "precision": 0.55,
106
+ "recall": 0.5789473684210527,
107
+ "f1-score": 0.5641025641025641,
108
+ "support": 171.0
109
+ },
110
+ "love": {
111
+ "precision": 0.7774193548387097,
112
+ "recall": 0.8456140350877193,
113
+ "f1-score": 0.8100840336134454,
114
+ "support": 285.0
115
+ },
116
+ "nervousness": {
117
+ "precision": 0.46153846153846156,
118
+ "recall": 0.35294117647058826,
119
+ "f1-score": 0.4,
120
+ "support": 17.0
121
+ },
122
+ "neutral": {
123
+ "precision": 0.6760843613211301,
124
+ "recall": 0.6623781676413255,
125
+ "f1-score": 0.6691610870421426,
126
+ "support": 2565.0
127
+ },
128
+ "optimism": {
129
+ "precision": 0.6713286713286714,
130
+ "recall": 0.5581395348837209,
131
+ "f1-score": 0.6095238095238096,
132
+ "support": 172.0
133
+ },
134
+ "pride": {
135
+ "precision": 0.3333333333333333,
136
+ "recall": 0.1,
137
+ "f1-score": 0.15384615384615385,
138
+ "support": 10.0
139
+ },
140
+ "realization": {
141
+ "precision": 0.5625,
142
+ "recall": 0.15384615384615385,
143
+ "f1-score": 0.24161073825503357,
144
+ "support": 117.0
145
+ },
146
+ "relief": {
147
+ "precision": 0.0,
148
+ "recall": 0.0,
149
+ "f1-score": 0.0,
150
+ "support": 17.0
151
+ },
152
+ "remorse": {
153
+ "precision": 0.6122448979591837,
154
+ "recall": 0.8450704225352113,
155
+ "f1-score": 0.7100591715976331,
156
+ "support": 71.0
157
+ },
158
+ "sadness": {
159
+ "precision": 0.5540540540540541,
160
+ "recall": 0.5030674846625767,
161
+ "f1-score": 0.5273311897106109,
162
+ "support": 163.0
163
+ },
164
+ "surprise": {
165
+ "precision": 0.5688622754491018,
166
+ "recall": 0.6597222222222222,
167
+ "f1-score": 0.6109324758842444,
168
+ "support": 144.0
169
+ },
170
+ "accuracy": 0.6023134122831176,
171
+ "macro avg": {
172
+ "precision": 0.5125054563961818,
173
+ "recall": 0.47883579303055207,
174
+ "f1-score": 0.48261389641901975,
175
+ "support": 7262.0
176
+ },
177
+ "weighted avg": {
178
+ "precision": 0.6008835076974202,
179
+ "recall": 0.6023134122831176,
180
+ "f1-score": 0.595804330387673,
181
+ "support": 7262.0
182
+ }
183
+ }
outputs/model-old/.gitkeep ADDED
File without changes
outputs/model-old/.ipynb_checkpoints/special_tokens_map-checkpoint.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
outputs/model-old/config.json ADDED
@@ -0,0 +1,87 @@
1
+ {
2
+ "_name_or_path": "bert-base-uncased",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "id2label": {
13
+ "0": "LABEL_0",
14
+ "1": "LABEL_1",
15
+ "2": "LABEL_2",
16
+ "3": "LABEL_3",
17
+ "4": "LABEL_4",
18
+ "5": "LABEL_5",
19
+ "6": "LABEL_6",
20
+ "7": "LABEL_7",
21
+ "8": "LABEL_8",
22
+ "9": "LABEL_9",
23
+ "10": "LABEL_10",
24
+ "11": "LABEL_11",
25
+ "12": "LABEL_12",
26
+ "13": "LABEL_13",
27
+ "14": "LABEL_14",
28
+ "15": "LABEL_15",
29
+ "16": "LABEL_16",
30
+ "17": "LABEL_17",
31
+ "18": "LABEL_18",
32
+ "19": "LABEL_19",
33
+ "20": "LABEL_20",
34
+ "21": "LABEL_21",
35
+ "22": "LABEL_22",
36
+ "23": "LABEL_23",
37
+ "24": "LABEL_24",
38
+ "25": "LABEL_25",
39
+ "26": "LABEL_26",
40
+ "27": "LABEL_27"
41
+ },
42
+ "initializer_range": 0.02,
43
+ "intermediate_size": 3072,
44
+ "label2id": {
45
+ "LABEL_0": 0,
46
+ "LABEL_1": 1,
47
+ "LABEL_10": 10,
48
+ "LABEL_11": 11,
49
+ "LABEL_12": 12,
50
+ "LABEL_13": 13,
51
+ "LABEL_14": 14,
52
+ "LABEL_15": 15,
53
+ "LABEL_16": 16,
54
+ "LABEL_17": 17,
55
+ "LABEL_18": 18,
56
+ "LABEL_19": 19,
57
+ "LABEL_2": 2,
58
+ "LABEL_20": 20,
59
+ "LABEL_21": 21,
60
+ "LABEL_22": 22,
61
+ "LABEL_23": 23,
62
+ "LABEL_24": 24,
63
+ "LABEL_25": 25,
64
+ "LABEL_26": 26,
65
+ "LABEL_27": 27,
66
+ "LABEL_3": 3,
67
+ "LABEL_4": 4,
68
+ "LABEL_5": 5,
69
+ "LABEL_6": 6,
70
+ "LABEL_7": 7,
71
+ "LABEL_8": 8,
72
+ "LABEL_9": 9
73
+ },
74
+ "layer_norm_eps": 1e-12,
75
+ "max_position_embeddings": 512,
76
+ "model_type": "bert",
77
+ "num_attention_heads": 12,
78
+ "num_hidden_layers": 12,
79
+ "pad_token_id": 0,
80
+ "position_embedding_type": "absolute",
81
+ "problem_type": "single_label_classification",
82
+ "torch_dtype": "float32",
83
+ "transformers_version": "4.45.2",
84
+ "type_vocab_size": 2,
85
+ "use_cache": true,
86
+ "vocab_size": 30522
87
+ }
outputs/model-old/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a22a9c126ae06332bdd125ee414560a9539321c42dcfd4882c50f56b7d5a8c35
3
+ size 438038624
outputs/model-old/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
outputs/model-old/tokenizer_config.json ADDED
@@ -0,0 +1,57 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_basic_tokenize": true,
47
+ "do_lower_case": true,
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "never_split": null,
51
+ "pad_token": "[PAD]",
52
+ "sep_token": "[SEP]",
53
+ "strip_accents": null,
54
+ "tokenize_chinese_chars": true,
55
+ "tokenizer_class": "BertTokenizer",
56
+ "unk_token": "[UNK]"
57
+ }
outputs/model-old/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
outputs/model/config.json ADDED
@@ -0,0 +1,87 @@
1
+ {
2
+ "_name_or_path": "outputs/model",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "id2label": {
13
+ "0": "LABEL_0",
14
+ "1": "LABEL_1",
15
+ "2": "LABEL_2",
16
+ "3": "LABEL_3",
17
+ "4": "LABEL_4",
18
+ "5": "LABEL_5",
19
+ "6": "LABEL_6",
20
+ "7": "LABEL_7",
21
+ "8": "LABEL_8",
22
+ "9": "LABEL_9",
23
+ "10": "LABEL_10",
24
+ "11": "LABEL_11",
25
+ "12": "LABEL_12",
26
+ "13": "LABEL_13",
27
+ "14": "LABEL_14",
28
+ "15": "LABEL_15",
29
+ "16": "LABEL_16",
30
+ "17": "LABEL_17",
31
+ "18": "LABEL_18",
32
+ "19": "LABEL_19",
33
+ "20": "LABEL_20",
34
+ "21": "LABEL_21",
35
+ "22": "LABEL_22",
36
+ "23": "LABEL_23",
37
+ "24": "LABEL_24",
38
+ "25": "LABEL_25",
39
+ "26": "LABEL_26",
40
+ "27": "LABEL_27"
41
+ },
42
+ "initializer_range": 0.02,
43
+ "intermediate_size": 3072,
44
+ "label2id": {
45
+ "LABEL_0": 0,
46
+ "LABEL_1": 1,
47
+ "LABEL_10": 10,
48
+ "LABEL_11": 11,
49
+ "LABEL_12": 12,
50
+ "LABEL_13": 13,
51
+ "LABEL_14": 14,
52
+ "LABEL_15": 15,
53
+ "LABEL_16": 16,
54
+ "LABEL_17": 17,
55
+ "LABEL_18": 18,
56
+ "LABEL_19": 19,
57
+ "LABEL_2": 2,
58
+ "LABEL_20": 20,
59
+ "LABEL_21": 21,
60
+ "LABEL_22": 22,
61
+ "LABEL_23": 23,
62
+ "LABEL_24": 24,
63
+ "LABEL_25": 25,
64
+ "LABEL_26": 26,
65
+ "LABEL_27": 27,
66
+ "LABEL_3": 3,
67
+ "LABEL_4": 4,
68
+ "LABEL_5": 5,
69
+ "LABEL_6": 6,
70
+ "LABEL_7": 7,
71
+ "LABEL_8": 8,
72
+ "LABEL_9": 9
73
+ },
74
+ "layer_norm_eps": 1e-12,
75
+ "max_position_embeddings": 512,
76
+ "model_type": "bert",
77
+ "num_attention_heads": 12,
78
+ "num_hidden_layers": 12,
79
+ "pad_token_id": 0,
80
+ "position_embedding_type": "absolute",
81
+ "problem_type": "single_label_classification",
82
+ "torch_dtype": "float32",
83
+ "transformers_version": "4.45.2",
84
+ "type_vocab_size": 2,
85
+ "use_cache": true,
86
+ "vocab_size": 30522
87
+ }
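Note on the config above: `id2label` and `label2id` carry only the placeholder names `LABEL_0` to `LABEL_27`, so predictions from this checkpoint would be reported under those names. The following is a minimal sketch, using only the standard library, of how both maps could be rewritten with readable class names. The helper `with_label_names` and the `emotion_N` names are hypothetical; the real names and their order must come from the label encoding used during training, which this diff does not show.

```python
def with_label_names(config: dict, names: list[str]) -> dict:
    """Return a copy of a model config with readable id2label/label2id maps."""
    expected = len(config["id2label"])
    if len(names) != expected:
        raise ValueError(f"expected {expected} names, got {len(names)}")
    if len(set(names)) != len(names):
        raise ValueError("label names must be unique")
    updated = dict(config)
    # JSON object keys are strings, so id2label keeps string indices.
    updated["id2label"] = {str(i): name for i, name in enumerate(names)}
    updated["label2id"] = {name: i for i, name in enumerate(names)}
    return updated


# Stand-in for the parsed outputs/model/config.json, trimmed to the two label maps.
config = {
    "id2label": {str(i): f"LABEL_{i}" for i in range(28)},
    "label2id": {f"LABEL_{i}": i for i in range(28)},
}

# Hypothetical names; replace with the label order used during training.
names = [f"emotion_{i}" for i in range(28)]
updated = with_label_names(config, names)
```

The function returns a new dict rather than editing in place, so the original config can still be compared against the result before anything is written back to disk.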
outputs/model/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a22a9c126ae06332bdd125ee414560a9539321c42dcfd4882c50f56b7d5a8c35
+ size 438038624
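Note on the entry above: these three lines are a Git LFS pointer, not the weights themselves; the pointer records the hash and byte size of the real `model.safetensors`. The following is a minimal sketch, assuming the `key value` layout shown above, of how a downloaded file could be checked against the pointer. The helpers `parse_lfs_pointer` and `matches_pointer` are illustrative names, not part of Git LFS or any library.

```python
import hashlib
from pathlib import Path


def parse_lfs_pointer(text: str) -> dict:
    """Parse the 'key value' lines of a Git LFS pointer file."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    # The oid has the form '<algorithm>:<hex digest>'.
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


def matches_pointer(path: Path, pointer: dict, chunk_size: int = 1 << 20) -> bool:
    """Return True if the file at `path` has the size and hash the pointer records."""
    if path.stat().st_size != pointer["size"]:
        return False
    hasher = hashlib.new(pointer["algo"])
    with path.open("rb") as fh:
        # Read in chunks so a large weights file is never fully held in memory.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest() == pointer["digest"]


POINTER_TEXT = """\
version https://git-lfs.github.com/spec/v1
oid sha256:a22a9c126ae06332bdd125ee414560a9539321c42dcfd4882c50f56b7d5a8c35
size 438038624
"""

pointer = parse_lfs_pointer(POINTER_TEXT)
```

The size comparison runs first because it is cheap and rules out a truncated download before any hashing is done.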
outputs/model/special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "cls_token": {
+     "content": "[CLS]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "[MASK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "[PAD]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "[SEP]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "[UNK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
outputs/model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
outputs/model/tokenizer_config.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
outputs/model/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ numpy
+ pandas
+ scikit-learn
+ transformers
+ torch
+ streamlit
+ captum
+ Pillow
+
save_clean_model.py ADDED
@@ -0,0 +1,12 @@
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ # Load the current model from the existing folder
+ model = AutoModelForSequenceClassification.from_pretrained("outputs/model")
+ tokenizer = AutoTokenizer.from_pretrained("outputs/model")
+
+ # Save to a new folder without any GPU/device info
+ model.save_pretrained("outputs/model-clean")
+ tokenizer.save_pretrained("outputs/model-clean")
+
+ print("Model saved to outputs/model-clean")
+
src/__pycache__/data_loader.cpython-312.pyc ADDED
Binary file (2.17 kB). View file
 
src/__pycache__/evaluate.cpython-312.pyc ADDED
Binary file (2.11 kB). View file
 
src/__pycache__/model.cpython-312.pyc ADDED
Binary file (652 Bytes). View file
 
src/__pycache__/model_custom.cpython-312.pyc ADDED
Binary file (706 Bytes). View file
 
src/__pycache__/model_hartmann.cpython-312.pyc ADDED
Binary file (715 Bytes). View file