# Provenance: Hugging Face Space by cheesecz, commit a6e5417 ("Update app.py", verified).
import asyncio
import os
import threading
import warnings

import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import spaces
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Silence every warning category for the whole process.
warnings.simplefilter("ignore")
# Redirect the transformers download cache to /tmp (presumably because it is
# writable in the deployment environment — confirm).
# NOTE(review): transformers is imported above this assignment; versions that
# resolve TRANSFORMERS_CACHE at import time will not pick this value up.
os.environ.update(TRANSFORMERS_CACHE="/tmp")
# Initialize GPU for Hugging Face Spaces
@spaces.GPU
def init_gpu():
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the formality ranker and its tokenizer once, at import time.
MODEL_NAME = "s-nlp/roberta-base-formality-ranker"

# transformers is imported before TRANSFORMERS_CACHE is assigned in this file,
# and some transformers versions resolve their default cache directory at
# import time, so the env var alone may be ignored. Pass it explicitly instead;
# None falls back to the library default.
_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, cache_dir=_CACHE_DIR)

# Move the model to the selected device. eval() makes inference mode explicit
# (dropout disabled) rather than relying on from_pretrained's default.
device = init_gpu()
model = model.to(device)
model.eval()

app = FastAPI(title="Formality Classifier API")
# Request body schema for POST /predict; pydantic validates incoming JSON against it.
class TextInput(BaseModel):
    text: str  # the text whose formality is scored
def calculate_formality_percentages(score):
    """Split a formality probability into whole-number formal/informal percentages.

    Args:
        score: Probability that the text is formal, expected in [0, 1]
            (predict_formality passes a softmax output).

    Returns:
        A tuple ``(formal_percent, informal_percent)`` of ints that always sum to 100.
    """
    # round() rather than int(): int() truncates, and floating-point error turns
    # e.g. 0.57 * 100 == 56.99999999999999 into 56 instead of 57.
    formal_percent = round(score * 100)
    # Clamp so an out-of-range score cannot yield negative or >100 percentages.
    formal_percent = max(0, min(100, formal_percent))
    informal_percent = 100 - formal_percent
    return formal_percent, informal_percent
@app.get("/")
async def home():
    # Liveness endpoint: a fixed JSON message that points clients at /predict.
    # (Comments rather than a docstring, so the OpenAPI description is unchanged.)
    status_message = "Formality Classifier API is running! Use /predict to classify text."
    return {"message": status_message}
# Serializes tokenizer/model use. Inference runs on executor threads, and a
# shared fast tokenizer called concurrently with per-call truncation/padding
# settings can fail (tokenizers "Already borrowed" RuntimeError).
_INFERENCE_LOCK = threading.Lock()


def _score_formality(text):
    """Run the classifier on *text* and return the softmax probability of class 1.

    This call blocks (tokenization plus forward pass), so code on the event loop
    should run it in an executor.
    """
    with _INFERENCE_LOCK:
        encoding = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        encoding = {k: v.to(device) for k, v in encoding.items()}
        with torch.no_grad():
            logits = model(**encoding).logits
    # NOTE(review): assumes label index 1 means "formal" for this checkpoint —
    # confirm against the model's id2label config.
    return logits.softmax(dim=1)[:, 1].item()


@app.post("/predict")
async def predict_formality(input_data: TextInput):
    """Score the formality of ``input_data.text``.

    Returns a JSON object with the raw ``formality_score`` (rounded to 3
    decimals), integer ``formal_percent`` / ``informal_percent``, and a
    human-readable ``classification`` sentence. Any failure is reported as
    HTTP 500 with the exception message as ``detail``.
    """
    try:
        # Model inference is blocking. Running it in the default executor keeps
        # the event loop free to serve other requests (e.g. GET /) meanwhile.
        loop = asyncio.get_running_loop()
        score = await loop.run_in_executor(None, _score_formality, input_data.text)
        formal_percent, informal_percent = calculate_formality_percentages(score)
        return {
            "formality_score": round(score, 3),
            "formal_percent": formal_percent,
            "informal_percent": informal_percent,
            "classification": f"Your speech is {formal_percent}% formal and {informal_percent}% informal.",
        }
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Running the file directly serves the app with uvicorn on all interfaces,
    # port 7860.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)