|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
from transformers import BertForSequenceClassification, BertTokenizer |
|
import requests |
|
import json |
|
|
|
|
|
repo_id = "logasanjeev/goemotions-bert"

# Load the fine-tuned checkpoint and its matching tokenizer from the Hub repo.
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)

# Prefer a GPU when one is visible; fall back to CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Replicate across GPUs when more than one is present. `nn` is not imported
# in this file, so the fully-qualified `torch.nn` path is required here
# (the bare `nn.DataParallel` raised NameError on multi-GPU hosts).
# NOTE(review): predict_emotions feeds a batch of one, so DataParallel gives
# little benefit here — consider dropping it.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)

# Inference only: switch to evaluation mode (e.g. disables dropout).
model.eval()
|
|
|
|
|
# thresholds.json lives in the same Hub repo as the model. It must contain an
# "emotion_labels" list and a per-label "thresholds" list; presumably both are
# ordered like the model's output units (predict_emotions pairs them
# positionally) — confirm against the repo.
thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"

# Bound the wait so app start-up cannot hang indefinitely, and fail loudly on
# an HTTP error instead of trying to parse an error page as JSON.
response = requests.get(thresholds_url, timeout=30)
response.raise_for_status()
thresholds_data = response.json()

emotion_labels = thresholds_data["emotion_labels"]
best_thresholds = thresholds_data["thresholds"]
|
|
|
|
|
def predict_emotions(text):
    """Return the emotions whose predicted probability meets its threshold.

    The text is tokenized (padded/truncated to 128 tokens) and run through
    the module-level ``model``. Each output unit is passed through a sigmoid,
    and the resulting probability is compared against the threshold at the
    same position in ``best_thresholds``.

    Args:
        text: The input string from the Gradio textbox (may be empty).

    Returns:
        A newline-separated ``"label: probability"`` listing, sorted by
        descending probability, or a short message when the input is blank
        or no label clears its threshold.

    Raises:
        ValueError: If the number of model outputs, labels and thresholds
            disagree (e.g. thresholds.json out of sync with the checkpoint).
    """
    # A cleared textbox yields "" (or None); the model would still emit
    # scores for it, which are meaningless to show the user.
    if not text or not text.strip():
        return "Please enter some text."

    encodings = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    input_ids = encodings['input_ids'].to(device)
    attention_mask = encodings['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # Independent per-label sigmoid (multi-label), then take the single
    # example of this batch-of-one.
    probs = torch.sigmoid(outputs.logits).cpu().numpy()[0]

    # zip() would silently drop trailing labels on a length mismatch.
    if not (len(probs) == len(emotion_labels) == len(best_thresholds)):
        raise ValueError(
            f"Model produced {len(probs)} outputs but thresholds.json has "
            f"{len(emotion_labels)} labels and {len(best_thresholds)} thresholds."
        )

    predictions = [
        (label, prob)
        for label, prob, thresh in zip(emotion_labels, probs, best_thresholds)
        if prob >= thresh
    ]
    if not predictions:
        return "No emotions predicted above thresholds."

    predictions.sort(key=lambda pair: pair[1], reverse=True)
    return "\n".join(f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions)
|
|
|
|
|
# Gradio front-end: one text box in, the plain-text emotion listing out.
interface = gr.Interface(
    title="GoEmotions BERT Classifier",
    description="Predict emotions using a fine-tuned BERT-base model from logasanjeev/goemotions-bert.",
    fn=predict_emotions,
    inputs=gr.Textbox(placeholder="Enter your text here...", lines=2),
    outputs="text",
    # Clickable sample inputs shown beneath the form.
    examples=[
        "I’m just chilling today.",
        "Thank you for saving my life!",
        "I’m nervous about my exam tomorrow.",
    ],
)

if __name__ == "__main__":
    # Start the web UI only when run as a script, not when imported.
    interface.launch()