{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "4032a920-2db8-4977-8b4f-a5a771dd022f", "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "\n", "project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", "sys.path.append(project_root)\n", "\n", "from transformers import pipeline\n", "from src.model_hartmann import load_model as load_hartmann_model, load_tokenizer as load_hartmann_tokenizer\n", "from src.model_custom import load_model as load_custom_model, load_tokenizer as load_custom_tokenizer" ] }, { "cell_type": "code", "execution_count": 2, "id": "525cf57e-4ec3-40fd-aca2-0e9700a73298", "metadata": {}, "outputs": [], "source": [ "hartmann_model = load_hartmann_model()\n", "hartmann_tokenizer = load_hartmann_tokenizer()\n", "\n", "custom_model = load_custom_model()\n", "custom_tokenizer = load_custom_tokenizer()" ] }, { "cell_type": "code", "execution_count": 3, "id": "04f9415c-3d4f-4ac0-8f51-74ec4bd64293", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n", "/opt/anaconda3/lib/python3.12/site-packages/transformers/pipelines/text_classification.py:104: UserWarning: `return_all_scores` is now deprecated, if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n", " warnings.warn(\n", "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n" ] } ], "source": [ "# Create pipelines for easy predictions\n", "hartmann_pipeline = pipeline(\"text-classification\", model=hartmann_model, tokenizer=hartmann_tokenizer, return_all_scores=True)\n", "custom_pipeline = pipeline(\"text-classification\", model=custom_model, tokenizer=custom_tokenizer, return_all_scores=True)" ] }, { "cell_type": "code", "execution_count": 4, "id": "3fcdf650-3abc-42a6-b1fd-0129e49d1e68", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "========= Sentence 1 ==========\n", "Text: I love spending time with my family.\n", "\n", "--- Hartmann Model Top 3 Predictions ---\n", "joy: 0.9883\n", "sadness: 0.0067\n", "disgust: 0.0013\n", "\n", "--- Pretrained Model Top 3 Predictions ---\n", "love: 0.9536\n", "joy: 0.0134\n", "admiration: 0.0108\n", "\n", "\n", "========= Sentence 2 ==========\n", "Text: This is the worst day of my life.\n", "\n", "--- Hartmann Model Top 3 Predictions ---\n", "disgust: 0.9805\n", "anger: 0.0086\n", "sadness: 0.0055\n", "\n", "--- Pretrained Model Top 3 Predictions ---\n", "anger: 0.3353\n", "surprise: 0.2010\n", "disgust: 0.1235\n", "\n", "\n", "========= Sentence 3 ==========\n", "Text: I'm feeling very nervous about the exam.\n", "\n", "--- Hartmann Model Top 3 Predictions ---\n", "fear: 0.9947\n", "sadness: 0.0013\n", "joy: 0.0011\n", "\n", "--- Pretrained Model Top 3 Predictions ---\n", "nervousness: 0.6201\n", "fear: 0.0828\n", "embarrassment: 0.0393\n", "\n", "\n", "========= Sentence 4 ==========\n", "Text: What a beautiful sunset!\n", "\n", "--- Hartmann Model Top 3 Predictions ---\n", "joy: 0.8377\n", "surprise: 0.1189\n", "neutral: 0.0221\n", "\n", "--- Pretrained Model Top 3 Predictions ---\n", "admiration: 0.8548\n", "excitement: 0.0729\n", "joy: 0.0351\n", "\n", "\n", "========= Sentence 5 ==========\n", "Text: I feel so disappointed and frustrated with the situation.\n", "\n", "--- Hartmann Model Top 3 Predictions ---\n", "sadness: 0.9310\n", "anger: 0.0381\n", "disgust: 0.0158\n", "\n", "--- Pretrained Model Top 3 Predictions ---\n", "disappointment: 0.5645\n", "annoyance: 0.1864\n", "anger: 0.0736\n", "\n", "\n", "========= Sentence 6 ==========\n", "Text: I'm not sure how to feel about this.\n", "\n", "--- Hartmann Model Top 3 Predictions ---\n", "neutral: 0.5698\n", "disgust: 0.2213\n", "sadness: 0.0720\n", "\n", "--- Pretrained Model Top 3 Predictions ---\n", "confusion: 0.9011\n", "optimism: 0.0230\n", "disapproval: 0.0223\n", "\n", "\n", "========= Sentence 7 ==========\n", "Text: That was hilarious, I can't stop laughing!\n", "\n", "--- Hartmann Model Top 3 Predictions ---\n", "joy: 0.9336\n", "surprise: 0.0306\n", "neutral: 0.0178\n", "\n", "--- Pretrained Model Top 3 Predictions ---\n", "amusement: 0.9551\n", "joy: 0.0286\n", "optimism: 0.0032\n", "\n", "\n", "========= Sentence 8 ==========\n", "Text: I feel completely empty and lost.\n", "\n", "--- Hartmann Model Top 3 Predictions ---\n", "sadness: 0.9808\n", "neutral: 0.0086\n", "disgust: 0.0051\n", "\n", "--- Pretrained Model Top 3 Predictions ---\n", "surprise: 0.8055\n", "disappointment: 0.1067\n", "optimism: 0.0222\n", "\n", "\n", "========= Sentence 9 ==========\n", "Text: Your help means a lot to me, thank you!\n", "\n", "--- Hartmann Model Top 3 Predictions ---\n", "joy: 0.9760\n", "neutral: 0.0104\n", "surprise: 0.0057\n", "\n", "--- Pretrained Model Top 3 Predictions ---\n", "gratitude: 0.9890\n", "caring: 0.0014\n", "sadness: 0.0009\n", "\n", "\n", "========= Sentence 10 ==========\n", "Text: I'm so angry I could scream.\n", "\n", "--- Hartmann Model Top 3 Predictions ---\n", "anger: 0.9785\n", "fear: 0.0084\n", "neutral: 0.0047\n", "\n", "--- Pretrained Model Top 3 Predictions ---\n", "anger: 0.9155\n", "annoyance: 0.0223\n", "optimism: 0.0082\n", "\n", "\n" ] } ], "source": [ "from tabulate import tabulate\n", "\n", "goemotions_labels = [\n", " \"admiration\", \"amusement\", \"anger\", \"annoyance\", \"approval\", \"caring\", \"confusion\", \"curiosity\",\n", " \"desire\", \"disappointment\", \"disapproval\", \"disgust\", \"embarrassment\", \"excitement\", \"fear\",\n", " \"gratitude\", \"grief\", \"joy\", \"love\", \"nervousness\", \"optimism\", \"pride\", \"realization\", \"relief\",\n", " \"remorse\", \"sadness\", \"surprise\", \"neutral\"\n", "]\n", "\n", "\n", "# Your 10 test sentences\n", "sentences = [\n", " \"I love spending time with my family.\",\n", " \"This is the worst day of my life.\",\n", " \"I'm feeling very nervous about the exam.\",\n", " \"What a beautiful sunset!\",\n", " \"I feel so disappointed and frustrated with the situation.\",\n", " \"I'm not sure how to feel about this.\",\n", " \"That was hilarious, I can't stop laughing!\",\n", " \"I feel completely empty and lost.\",\n", " \"Your help means a lot to me, thank you!\",\n", " \"I'm so angry I could scream.\"\n", "]\n", "\n", "# Loop over sentences and collect results\n", "for i, sentence in enumerate(sentences):\n", " print(f\"========= Sentence {i+1} ==========\")\n", " print(f\"Text: {sentence}\\n\")\n", "\n", " # Get predictions\n", " hartmann_results = hartmann_pipeline(sentence, return_all_scores=True)\n", " custom_results = custom_pipeline(sentence, return_all_scores=True)\n", "\n", " # Unwrap the list to get the actual results\n", " hartmann_results = hartmann_results[0]\n", " custom_results = custom_results[0]\n", "\n", " # Sort and get top 3 predictions for each\n", " hartmann_top3 = sorted(hartmann_results, key=lambda x: x['score'], reverse=True)[:3]\n", " custom_top3 = sorted(custom_results, key=lambda x: x['score'], reverse=True)[:3]\n", "\n", " # Display Hartmann predictions\n", " print(\"--- Hartmann Model Top 3 Predictions ---\")\n", " for res in hartmann_top3:\n", " print(f\"{res['label']}: {res['score']:.4f}\")\n", "\n", " # Display Custom Model predictions\n", " print(\"\\n--- Pretrained Model Top 3 Predictions ---\")\n", " for res in custom_top3:\n", " label_idx = int(res['label'].split(\"_\")[-1])\n", " emotion = goemotions_labels[label_idx]\n", " print(f\"{emotion}: {res['score']:.4f}\")\n", "\n", " print(\"\\n\")\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 5 }