import os
import warnings

# Cache downloaded models under /tmp (writable on Hugging Face Spaces).
# Set before importing transformers so the cache location is picked up.
os.environ["TRANSFORMERS_CACHE"] = "/tmp"

# Suppress all warnings
warnings.filterwarnings("ignore")

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Initialize GPU for Hugging Face Spaces (falls back to CPU if none is available)
def init_gpu():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model and tokenizer
MODEL_NAME = "s-nlp/roberta-base-formality-ranker"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Move model to GPU
device = init_gpu()
model = model.to(device)

app = FastAPI(title="Formality Classifier API")

class TextInput(BaseModel):
    text: str
def calculate_formality_percentages(score):
    # Convert score to an integer percentage (0-100)
    grayscale = int(score * 100)
    # Use it to determine formal/informal percentages
    formal_percent = grayscale
    informal_percent = 100 - grayscale
    return formal_percent, informal_percent
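
# For example, a score of 0.873 becomes formal_percent = 87 and informal_percent = 13.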
@app.get("/")
async def home():
    return {"message": "Formality Classifier API is running! Use /predict to classify text."}
@app.post("/predict")
async def predict_formality(input_data: TextInput):
    try:
        # Tokenize input and move tensors to the model's device
        encoding = tokenizer(input_data.text, return_tensors="pt", truncation=True, padding=True)
        encoding = {k: v.to(device) for k, v in encoding.items()}

        # Predict formality score
        with torch.no_grad():
            logits = model(**encoding).logits
            score = logits.softmax(dim=1)[:, 1].item()

        # Convert the score into formal/informal percentages
        formal_percent, informal_percent = calculate_formality_percentages(score)

        # Build the response
        response = {
            "formality_score": round(score, 3),
            "formal_percent": formal_percent,
            "informal_percent": informal_percent,
            "classification": f"Your speech is {formal_percent}% formal and {informal_percent}% informal."
        }
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
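
# --- Example usage (a sketch, not part of the app) ---
# Once the Space builds and starts, the /predict endpoint can be called from any
# HTTP client. The URL and the sentence below are placeholders; substitute the
# Space's actual *.hf.space address:
#
#   import requests
#   resp = requests.post(
#       "https://<your-username>-<your-space>.hf.space/predict",
#       json={"text": "Good afternoon, could you please review the attached report?"},
#       timeout=30,
#   )
#   print(resp.json()["classification"])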