# emotion-analyzer-bert / onnx_inference.py
# Uploaded by: logasanjeev
# Commit message: Update onnx_inference.py
# Commit: c6cef66 (verified)
import numpy as np
import json
from huggingface_hub import hf_hub_download
import re
import emoji
from transformers import BertTokenizer
import onnxruntime as ort
def preprocess_text(text):
    """Normalize raw input text so it matches the conditions used in training.

    Steps, applied in this fixed order:
      1. ``u/<name>`` mentions become ``[USER]``.
      2. ``r/<name>`` mentions become ``[SUBREDDIT]``.
      3. ``http://`` / ``https://`` links become ``[URL]``.
      4. Emojis are converted to text via ``emoji.demojize`` using a single
         space as both the opening and closing delimiter.
      5. The whole string (placeholders included) is lower-cased.

    Args:
        text (str): The raw input text.

    Returns:
        str: The normalized text.
    """
    # (pattern, placeholder) pairs; the sequence is significant because each
    # substitution runs on the output of the previous one.
    substitutions = (
        (r'u/\w+', '[USER]'),
        (r'r/\w+', '[SUBREDDIT]'),
        (r'http[s]?://\S+', '[URL]'),
    )
    cleaned = text
    for pattern, placeholder in substitutions:
        cleaned = re.sub(pattern, placeholder, cleaned)
    return emoji.demojize(cleaned, delimiters=(" ", " ")).lower()
def load_model_and_resources():
    """Load the ONNX model, tokenizer, emotion labels, and thresholds from Hugging Face.

    Everything is fetched from the ``logasanjeev/emotion-analyzer-bert`` repo:
    the tokenizer via ``BertTokenizer.from_pretrained``, and ``model.onnx`` and
    ``optimized_thresholds.json`` via ``hf_hub_download``.

    Returns:
        tuple: (session, tokenizer, emotion_labels, thresholds)
            - session: ``onnxruntime.InferenceSession`` built from ``model.onnx``.
            - tokenizer: the loaded ``BertTokenizer``.
            - emotion_labels: value of the ``"emotion_labels"`` key in the JSON file.
            - thresholds: value of the ``"thresholds"`` key in the JSON file.

    Raises:
        RuntimeError: If the tokenizer, the ONNX model, or the thresholds file
            cannot be loaded or validated. The underlying exception is
            attached as ``__cause__``.
    """
    repo_id = "logasanjeev/emotion-analyzer-bert"

    try:
        tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        raise RuntimeError(f"Error loading tokenizer: {str(e)}") from e

    try:
        model_path = hf_hub_download(repo_id=repo_id, filename="model.onnx")
        session = ort.InferenceSession(model_path)
    except Exception as e:
        raise RuntimeError(f"Error loading ONNX model: {str(e)}") from e

    try:
        thresholds_file = hf_hub_download(repo_id=repo_id, filename="optimized_thresholds.json")
        # Explicit encoding: do not depend on the platform's default locale.
        with open(thresholds_file, "r", encoding="utf-8") as f:
            thresholds_data = json.load(f)
        if not (isinstance(thresholds_data, dict) and "emotion_labels" in thresholds_data and "thresholds" in thresholds_data):
            raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")
        emotion_labels = thresholds_data["emotion_labels"]
        thresholds = thresholds_data["thresholds"]
        # Labels and thresholds are consumed pairwise by index downstream, so a
        # length mismatch would otherwise surface as silently dropped labels or
        # an IndexError far from the real cause.
        if len(emotion_labels) != len(thresholds):
            raise ValueError(
                "Mismatched lengths in optimized_thresholds.json: "
                f"{len(emotion_labels)} emotion labels vs {len(thresholds)} thresholds."
            )
    except Exception as e:
        raise RuntimeError(f"Error loading thresholds: {str(e)}") from e

    return session, tokenizer, emotion_labels, thresholds
# Module-level cache, filled on the first call to predict_emotions() so that
# importing this module does not trigger any download or model load.
SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS = None, None, None, None


def predict_emotions(text):
    """Predict emotions for the given text using the GoEmotions BERT ONNX model.

    On the first call the model, tokenizer, labels and thresholds are loaded
    via ``load_model_and_resources()`` and cached in module globals. The text
    is preprocessed, tokenized (padded/truncated to 128 tokens), run through
    the ONNX session, and the first output row is passed through a sigmoid.
    Every label whose score is at or above its own threshold is reported.

    Args:
        text (str): The input text to analyze.

    Returns:
        tuple: (predictions, processed_text)
            - predictions (str): One ``"<emotion>: <score>"`` line per selected
              emotion (score with 4 decimals), highest score first, or
              ``"No emotions predicted."`` when no label reaches its threshold.
            - processed_text (str): The preprocessed input text.

    Raises:
        RuntimeError: Propagated from ``load_model_and_resources()`` when the
            resources cannot be loaded.
        ValueError: If the number of model outputs does not match the number
            of loaded thresholds and emotion labels.
    """
    global SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS
    if SESSION is None:
        SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS = load_model_and_resources()

    processed_text = preprocess_text(text)
    encodings = TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='np'
    )
    inputs = {
        'input_ids': encodings['input_ids'].astype(np.int64),
        'attention_mask': encodings['attention_mask'].astype(np.int64)
    }
    # First output tensor, first (only) row of the batch.
    logits = SESSION.run(None, inputs)[0][0]

    # Fail loudly instead of letting zip() silently drop labels or indexing
    # raise an unexplained IndexError when the resources disagree.
    if not (len(logits) == len(THRESHOLDS) == len(EMOTION_LABELS)):
        raise ValueError(
            f"Model produced {len(logits)} outputs but {len(THRESHOLDS)} thresholds "
            f"and {len(EMOTION_LABELS)} emotion labels were loaded."
        )

    # Sigmoid. For very negative logits exp() overflows to inf, which still
    # yields the correct score of 0.0; only the spurious warning is silenced.
    with np.errstate(over='ignore'):
        probabilities = 1 / (1 + np.exp(-logits))

    predictions = [
        (label, round(score, 4))
        for label, score, thresh in zip(EMOTION_LABELS, probabilities, THRESHOLDS)
        if score >= thresh
    ]
    predictions.sort(key=lambda x: x[1], reverse=True)
    result = "\n".join(f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions) or "No emotions predicted."
    return result, processed_text
if __name__ == "__main__":
    import argparse

    # Command-line entry point: takes one positional text argument, runs the
    # prediction, and prints the raw input, the preprocessed text and the
    # formatted predictions.
    cli = argparse.ArgumentParser(
        description="Predict emotions using the GoEmotions BERT ONNX model."
    )
    cli.add_argument("text", type=str, help="The input text to analyze for emotions.")
    raw_text = cli.parse_args().text

    report, cleaned_text = predict_emotions(raw_text)
    print(
        f"Input: {raw_text}",
        f"Processed: {cleaned_text}",
        "Predicted Emotions:",
        report,
        sep="\n",
    )