import os
from typing import List, Dict, Union

import torch
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Pydantic data models for request and response payloads
class ProblematicItem(BaseModel):
    text: str

class ProblematicList(BaseModel):
    problematics: List[str]

class PredictionResponse(BaseModel):
    predicted_class: str
    score: float

class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]

class BatchPredictionScoreItem(BaseModel):
    problematic: str
    score: float
# Model configuration from environment variables
# (LABEL_0 and LABEL_1 name the two classes; os.getenv returns None for any that are unset)
MODEL_NAME = os.getenv("MODEL_NAME")
LABEL_0 = os.getenv("LABEL_0")
LABEL_1 = os.getenv("LABEL_1")

if not MODEL_NAME:
    raise ValueError("Environment variable MODEL_NAME is not set.")
# Lazy loading of the model and tokenizer
tokenizer = None
model = None

def load_model():
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
def health_check():
    global model, tokenizer
    if model is None or tokenizer is None:
        # Report the failure instead of claiming the service is healthy
        if not load_model():
            return {"status": "error", "model": MODEL_NAME, "detail": "Model not available"}
    return {"status": "ok", "model": MODEL_NAME}
def predict_single(item: ProblematicItem) -> PredictionResponse:
    global model, tokenizer
    if model is None or tokenizer is None:
        # Abort instead of falling through with a None model/tokenizer
        if not load_model():
            raise RuntimeError("Model not available.")
    try:
        # Tokenization
        inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")
        # Prediction
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()
        confidence_score = probabilities[0][predicted_class].item()
        # Map the class index to its configured label
        predicted_label = LABEL_0 if predicted_class == 0 else LABEL_1
        return PredictionResponse(predicted_class=predicted_label, score=confidence_score)
    except Exception as e:
        # Surface the error to the caller rather than silently returning None
        raise RuntimeError(f"Error during prediction: {e}") from e
def predict_batch(items: ProblematicList) -> List[BatchPredictionScoreItem]:
    global model, tokenizer
    if model is None or tokenizer is None:
        # Abort instead of falling through with a None model/tokenizer
        if not load_model():
            raise RuntimeError("Model not available.")
    try:
        results = []
        # Process the texts in mini-batches
        batch_size = 8
        for i in range(0, len(items.problematics), batch_size):
            batch_texts = items.problematics[i:i+batch_size]
            # Tokenization
            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
            # Prediction
            with torch.no_grad():
                outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # Alternative: return each item's predicted class and its confidence
            # predicted_classes = torch.argmax(probabilities, dim=-1).tolist()
            # confidence_scores = [probabilities[j][predicted_classes[j]].item() for j in range(len(predicted_classes))]
            for j in range(len(batch_texts)):
                # Score is the probability of class 1 (LABEL_1)
                score_specific_class = probabilities[j][1].item()
                results.append(
                    BatchPredictionScoreItem(
                        problematic=batch_texts[j],
                        score=score_specific_class,
                    )
                )
        return results
    except Exception as e:
        # Surface the error to the caller rather than silently returning None
        raise RuntimeError(f"Error during prediction: {e}") from e