File size: 2,382 Bytes
038a896
4bf9f99
 
038a896
 
63183b6
a6e5417
63183b6
7cf74f6
 
038a896
 
 
a6e5417
 
 
 
70b07d3
a6e5417
038a896
 
 
 
a6e5417
 
038a896
 
4bf9f99
 
 
 
 
70b07d3
718c378
70b07d3
718c378
70b07d3
 
 
 
4bf9f99
 
 
 
 
 
7cf74f6
 
4bf9f99
7cf74f6
038a896
7cf74f6
 
 
 
038a896
7cf74f6
 
 
 
 
 
 
 
 
 
 
 
 
4bf9f99
038a896
 
4bf9f99
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import warnings
from huggingface_hub import spaces
# NOTE(review): the ZeroGPU decorator is normally provided by the standalone
# `spaces` package (`import spaces`), not by `huggingface_hub` — verify this
# import actually resolves in the deployment environment.

# Suppress all warnings
# (silences transformers/torch deprecation noise in the Space logs)
warnings.filterwarnings("ignore")

# Cache downloaded model files under /tmp — a writable location on
# Hugging Face Spaces containers.
os.environ["TRANSFORMERS_CACHE"] = "/tmp"

# Device selection for Hugging Face Spaces (ZeroGPU-decorated).
@spaces.GPU
def init_gpu():
    """Return the compute device: CUDA when available, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

# Initialize model and tokenizer
# Pretrained sequence classifier that scores how formal a piece of text is.
MODEL_NAME = "s-nlp/roberta-base-formality-ranker"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Move model to GPU
# (init_gpu falls back to CPU when CUDA is unavailable)
device = init_gpu()
model = model.to(device)

app = FastAPI(title="Formality Classifier API")

class TextInput(BaseModel):
    """Request body for /predict: the raw text to classify."""
    text: str

def calculate_formality_percentages(score):
    """Split a formality probability into formal/informal percentages.

    Args:
        score: Probability in [0.0, 1.0] that the text is formal (the
            classifier's softmax output for the "formal" class).

    Returns:
        Tuple ``(formal_percent, informal_percent)`` of ints summing to 100.
    """
    # round() instead of int(): truncation systematically under-reported
    # formality (e.g. 0.996 showed 99%, and 100% was unreachable for any
    # score below exactly 1.0).
    formal_percent = round(score * 100)
    informal_percent = 100 - formal_percent
    return formal_percent, informal_percent

@app.get("/")
async def home():
    """Health-check endpoint confirming the API is running."""
    welcome = "Formality Classifier API is running! Use /predict to classify text."
    return {"message": welcome}

@app.post("/predict")
async def predict_formality(input_data: TextInput):
    """Classify the formality of the submitted text.

    Runs the RoBERTa formality ranker on ``input_data.text`` and returns
    the raw formality probability together with integer formal/informal
    percentages and a human-readable summary.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Tokenize and move every input tensor onto the model's device.
        inputs = tokenizer(
            input_data.text, return_tensors="pt", truncation=True, padding=True
        )
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        # Forward pass without gradient tracking; column 1 of the softmax
        # is taken as the "formal" class probability.
        with torch.no_grad():
            logits = model(**inputs).logits
        formality = logits.softmax(dim=1)[:, 1].item()

        # Integer percentages derived from the probability.
        formal_percent, informal_percent = calculate_formality_percentages(formality)

        return {
            "formality_score": round(formality, 3),
            "formal_percent": formal_percent,
            "informal_percent": informal_percent,
            "classification": f"Your speech is {formal_percent}% formal and {informal_percent}% informal.",
        }
    except Exception as e:
        # Surface any tokenization/inference failure as a 500 response.
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    # Port 7860 is the port Hugging Face Spaces exposes by default.
    uvicorn.run(app, host="0.0.0.0", port=7860)