|
import numpy as np |
|
import json |
|
from huggingface_hub import hf_hub_download |
|
import re |
|
import emoji |
|
from transformers import BertTokenizer |
|
import onnxruntime as ort |
|
|
|
def preprocess_text(text):
    """Normalize raw input text the way the training data was normalized.

    The steps run in a fixed order:

    1. ``u/<name>`` mentions are replaced with ``[USER]``.
    2. ``r/<name>`` references are replaced with ``[SUBREDDIT]``.
    3. ``http://`` / ``https://`` links are replaced with ``[URL]``.
    4. Emoji are passed through ``emoji.demojize`` with a single space as
       the delimiter on both sides.
    5. The whole string is lower-cased, placeholders included.

    Args:
        text (str): The raw text to clean.

    Returns:
        str: The cleaned, lower-cased text.
    """
    # Applied strictly in this order; the mention and subreddit patterns run
    # before the URL pattern, exactly as in the original pipeline.
    substitutions = (
        (r'u/\w+', '[USER]'),
        (r'r/\w+', '[SUBREDDIT]'),
        (r'http[s]?://\S+', '[URL]'),
    )

    cleaned = text
    for pattern, placeholder in substitutions:
        cleaned = re.sub(pattern, placeholder, cleaned)

    return emoji.demojize(cleaned, delimiters=(" ", " ")).lower()
|
|
|
def load_model_and_resources():
    """Load the ONNX model, tokenizer, emotion labels, and thresholds.

    All artifacts come from the ``logasanjeev/emotion-analyzer-bert``
    repository on the Hugging Face Hub. Three independent steps are
    performed, each wrapped so a failure reports which step broke:

    1. The tokenizer is loaded with ``BertTokenizer.from_pretrained``.
    2. ``model.onnx`` is downloaded and opened as an ONNX Runtime
       ``InferenceSession``.
    3. ``optimized_thresholds.json`` is downloaded, parsed, and checked to
       be a dict containing the keys ``emotion_labels`` and ``thresholds``.

    Returns:
        tuple: ``(session, tokenizer, emotion_labels, thresholds)`` where
        ``emotion_labels`` and ``thresholds`` are the values stored under
        the corresponding keys of the thresholds JSON file.

    Raises:
        RuntimeError: If any of the three steps fails. The underlying
            exception is attached as ``__cause__``.
    """
    repo_id = "logasanjeev/emotion-analyzer-bert"

    try:
        tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise RuntimeError(f"Error loading tokenizer: {str(e)}") from e

    try:
        model_path = hf_hub_download(repo_id=repo_id, filename="model.onnx")
        session = ort.InferenceSession(model_path)
    except Exception as e:
        raise RuntimeError(f"Error loading ONNX model: {str(e)}") from e

    try:
        thresholds_file = hf_hub_download(repo_id=repo_id, filename="optimized_thresholds.json")
        # JSON is UTF-8 by specification; do not rely on the platform's
        # default text encoding (which differs on e.g. Windows).
        with open(thresholds_file, "r", encoding="utf-8") as f:
            thresholds_data = json.load(f)

        has_expected_shape = (
            isinstance(thresholds_data, dict)
            and "emotion_labels" in thresholds_data
            and "thresholds" in thresholds_data
        )
        if not has_expected_shape:
            # Raised inside the try on purpose: it is re-wrapped below so
            # callers only ever need to handle RuntimeError.
            raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")

        emotion_labels = thresholds_data["emotion_labels"]
        thresholds = thresholds_data["thresholds"]
    except Exception as e:
        raise RuntimeError(f"Error loading thresholds: {str(e)}") from e

    return session, tokenizer, emotion_labels, thresholds
|
|
|
# Module-level cache for the loaded resources. All four names start as None
# and are filled in by predict_emotions() the first time it runs.
SESSION = None
TOKENIZER = None
EMOTION_LABELS = None
THRESHOLDS = None
|
|
|
def predict_emotions(text):
    """Run the GoEmotions BERT ONNX model on a piece of text.

    On the first call (while ``SESSION`` is still ``None``) the model,
    tokenizer, labels and thresholds are loaded via
    ``load_model_and_resources`` and stored in the module-level globals.

    Args:
        text (str): The input text to analyze.

    Returns:
        tuple: ``(predictions, processed_text)``
            - predictions (str): One ``"<emotion>: <score>"`` line per
              emotion whose score reaches its threshold, highest score
              first, or ``"No emotions predicted."`` when none do.
            - processed_text (str): The preprocessed input text.
    """
    global SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS

    # Lazy one-time initialisation of the shared resources.
    if SESSION is None:
        SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS = load_model_and_resources()

    processed_text = preprocess_text(text)

    # Tokenize to a fixed length of 128 tokens, returned as NumPy arrays.
    encoded = TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='np',
    )

    # Both model inputs are cast to int64 before being fed to the session.
    feed = {
        name: encoded[name].astype(np.int64)
        for name in ('input_ids', 'attention_mask')
    }

    # First output, first row: the raw scores for this single input.
    raw_scores = SESSION.run(None, feed)[0][0]
    # Sigmoid turns each raw score into an independent value in (0, 1).
    scores = 1 / (1 + np.exp(-raw_scores))

    # Keep every label whose score reaches its own threshold.
    selected = [
        (EMOTION_LABELS[idx], round(score, 4))
        for idx, (score, cutoff) in enumerate(zip(scores, THRESHOLDS))
        if score >= cutoff
    ]

    if not selected:
        return "No emotions predicted.", processed_text

    ranked = sorted(selected, key=lambda pair: pair[1], reverse=True)
    report_lines = [f"{emotion}: {confidence:.4f}" for emotion, confidence in ranked]
    return "\n".join(report_lines), processed_text
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: takes one positional argument (the text),
    # runs the prediction, and prints the raw input, the preprocessed text
    # and the formatted predictions.
    cli = argparse.ArgumentParser(description="Predict emotions using the GoEmotions BERT ONNX model.")
    cli.add_argument("text", type=str, help="The input text to analyze for emotions.")
    raw_text = cli.parse_args().text

    emotions, cleaned = predict_emotions(raw_text)

    print(
        f"Input: {raw_text}",
        f"Processed: {cleaned}",
        "Predicted Emotions:",
        emotions,
        sep="\n",
    )