import os
import warnings

# Point the Transformers cache at a writable directory and pin the visible GPU
# before torch/transformers are imported; these settings are read at import /
# CUDA-initialization time and have no effect if set afterwards.
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# Suppress the NVML warning emitted on hosts without a usable GPU driver
warnings.filterwarnings("ignore", message="Can't initialize NVML")

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_NAME = "s-nlp/roberta-base-formality-ranker"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)
model.eval()
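
# Optional sanity check (an addition, not part of the original app): print the
# checkpoint's label mapping so the class index used in predict_formality below
# matches the class you intend to report as "formal".
print(model.config.id2label)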

def calculate_formality_percentages(score):
    # Map the formality probability to complementary integer percentages,
    # e.g. score 0.87 -> 87% formal / 13% informal
    formal_percent = int(score * 100)
    informal_percent = 100 - formal_percent
    return formal_percent, informal_percent

def predict_formality(text):
    # Tokenize the input text and move it to the model's device
    encoding = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    encoding = {k: v.to(device) for k, v in encoding.items()}

    # Predict the formality score: probability of the class at index 1,
    # which this app reports as "formal"
    with torch.no_grad():
        logits = model(**encoding).logits
    score = logits.softmax(dim=1)[:, 1].item()

    formal_percent, informal_percent = calculate_formality_percentages(score)

    return {
        "formality_score": round(score, 3),
        "formal_percent": formal_percent,
        "informal_percent": informal_percent,
        "classification": f"Your speech is {formal_percent}% formal and {informal_percent}% informal.",
    }

demo = gr.Interface(
    fn=predict_formality,
    inputs=gr.Textbox(label="Enter your text", lines=3),
    outputs=gr.JSON(label="Formality Analysis"),
    title="Formality Classifier",
    description="Enter text to analyze its formality level.",
    examples=[
        ["Hello, how are you doing today?"],
        ["Hey, what's up?"],
        ["I would like to request your assistance with this matter."],
    ],
)

# Launch the app
if __name__ == "__main__":
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        # share=True is omitted: Spaces already serves the app publicly and
        # Gradio ignores the flag in that environment.
    )
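
Once the Space is running, the same endpoint can be called programmatically with the gradio_client package. A minimal sketch, assuming a placeholder Space ID your-username/formality-classifier (the real ID will differ) and the default /predict endpoint that a single gr.Interface exposes:

from gradio_client import Client

# Placeholder Space ID; replace with the actual one
client = Client("your-username/formality-classifier")
result = client.predict(
    "Hey, what's up?",    # text passed to predict_formality
    api_name="/predict",  # default endpoint name for a gr.Interface
)
print(result)  # dict with formality_score, formal_percent, informal_percent, classification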