# Source: cheesecz's Hugging Face Space, app.py
# Last commit: "Update app.py" (a6e5417, verified), 2.38 kB
import logging
import os
import warnings

import torch
from fastapi import FastAPI, HTTPException
# NOTE(review): the ZeroGPU `@spaces.GPU` decorator normally comes from the
# standalone `spaces` package (`import spaces`); confirm this import resolves.
from huggingface_hub import spaces
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Silence every Python warning raised by the libraries used below.
warnings.filterwarnings(action="ignore")

# Point the transformers cache at /tmp.
# NOTE(review): transformers appears to read TRANSFORMERS_CACHE when it is first
# imported (that import happens above this line), so this assignment likely does
# not change where from_pretrained() caches files by itself. Setting it before
# the imports (or using HF_HOME) would be more reliable. TODO confirm.
os.environ.update(TRANSFORMERS_CACHE="/tmp")
# Initialize GPU for Hugging Face Spaces
@spaces.GPU
def init_gpu():
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pretrained formality ranker and its tokenizer once, at module load.
MODEL_NAME = "s-nlp/roberta-base-formality-ranker"

# transformers resolves its default cache directory when it is imported, so a
# TRANSFORMERS_CACHE value assigned after the imports would otherwise be ignored.
# Passing it explicitly as cache_dir makes the intended location take effect;
# None (variable unset) keeps the library's default cache.
_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, cache_dir=_CACHE_DIR
)

# Move the model to the device picked by init_gpu() (CUDA if available, else CPU).
device = init_gpu()
model = model.to(device)
# FastAPI application object: the routes below register on it, and the
# __main__ block serves it with uvicorn.
app = FastAPI(title="Formality Classifier API")
# Request body for POST /predict: an object with a single "text" string field.
class TextInput(BaseModel):
    text: str  # text to score; no length or emptiness constraint is declared here
def calculate_formality_percentages(score: float):
    """Convert a formality probability into whole-number percentages.

    Args:
        score: Probability that the text is formal, expected in [0, 1]
            (the softmax output computed by ``/predict``).

    Returns:
        A ``(formal_percent, informal_percent)`` tuple of ints that always
        sums to 100.
    """
    # Round to the nearest percent instead of truncating: int(0.29 * 100) is 28
    # because 0.29 * 100 == 28.999999999999996 in binary floating point, which
    # disagreed with the rounded score reported next to it.
    formal_percent = round(score * 100)
    # Clamp so an out-of-range score can never produce a negative percentage.
    formal_percent = max(0, min(100, formal_percent))
    informal_percent = 100 - formal_percent
    return formal_percent, informal_percent
@app.get("/")
async def home():
    # Root endpoint: returns a fixed status message that points callers at /predict.
    banner = "Formality Classifier API is running! Use /predict to classify text."
    return dict(message=banner)
def _score_formality(text: str) -> float:
    """Return the model's probability (0-1) that *text* is formal.

    Uses the module-level ``tokenizer``, ``model`` and ``device``. Input longer
    than the tokenizer's maximum length is truncated.
    """
    encoding = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    encoding = {name: tensor.to(device) for name, tensor in encoding.items()}
    with torch.no_grad():
        logits = model(**encoding).logits
    # Column 1 of the softmax is taken as the "formal" class probability.
    return logits.softmax(dim=1)[:, 1].item()


@app.post("/predict")
async def predict_formality(input_data: TextInput):
    """Score how formal the submitted text is.

    Returns the raw score (rounded to 3 decimals), whole-number formal and
    informal percentages, and a human-readable sentence. Responds with 422
    for blank text and 500 if inference fails.
    """
    # Blank input would still yield a (meaningless) score, so reject it up front.
    # This check sits outside the try so it is not turned into a 500 below.
    if not input_data.text.strip():
        raise HTTPException(status_code=422, detail="Text must not be empty.")

    # NOTE(review): inference runs synchronously inside this async handler, so it
    # blocks the event loop for each request. Offloading it to a thread would need
    # care because the tokenizer and model are shared module-level objects.
    try:
        score = _score_formality(input_data.text)
    except Exception as exc:
        # Keep the traceback in the server logs; it used to be discarded.
        logging.getLogger(__name__).exception("Formality prediction failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    formal_percent, informal_percent = calculate_formality_percentages(score)
    return {
        "formality_score": round(score, 3),
        "formal_percent": formal_percent,
        "informal_percent": informal_percent,
        "classification": f"Your speech is {formal_percent}% formal and {informal_percent}% informal.",
    }
if __name__ == "__main__":
    # Run directly with uvicorn, listening on all interfaces at port 7860.
    from uvicorn import run as serve_app

    serve_app(app, port=7860, host="0.0.0.0")