from typing import (
    Dict,
    List,
    Tuple,
    Union,
)

import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

from src.application.config import AI_TEXT_CLASSIFICATION_MODEL


def load_model_and_tokenizer(
    model_path: str = AI_TEXT_CLASSIFICATION_MODEL,
) -> Tuple[AutoTokenizer, AutoModelForSequenceClassification]:
    """
    Load the trained model and tokenizer from the specified path.

    Args:
        model_path: Directory (or model identifier accepted by
            ``from_pretrained``) containing the saved model and tokenizer.

    Returns:
        A tuple ``(tokenizer, model)``. The model is put in evaluation mode
        (``model.eval()``) before being returned.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()
    return tokenizer, model


def predict(
    texts: Union[str, List[str]],
    model: AutoModelForSequenceClassification,
    tokenizer: AutoTokenizer,
) -> List[Dict[str, Union[str, float]]]:
    """
    Classify input texts as GPT-4o or GPT-4o mini output.

    Args:
        texts: The input text(s) to classify. A single string is treated as
            a batch containing just that string.
        model: The loaded model for sequence classification.
        tokenizer: The loaded tokenizer.

    Returns:
        One dictionary per input text, in input order, with keys
        ``"input"`` (the original text), ``"prediction"`` (the label name)
        and ``"confidence"`` (softmax probability of the predicted label,
        as a float). An empty input yields an empty list.
    """
    label_map = {0: "GPT-4o", 1: "GPT-4o mini"}

    # A bare string would otherwise be zipped character by character below,
    # pairing the single prediction with only the first character.
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        return []

    inputs = tokenizer(
        texts,
        # Pad only to the longest sample in the batch; the attention mask
        # makes padding length irrelevant to the logits, so padding every
        # batch to the model maximum just wastes compute.
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Follow the model's device so callers may move the model (e.g. to GPU).
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    confidence, predictions = torch.max(probabilities, dim=-1)

    return [
        {"input": text, "prediction": label_map[pred], "confidence": conf}
        for text, pred, conf in zip(
            texts,
            predictions.tolist(),
            confidence.tolist(),
        )
    ]


if __name__ == "__main__":
    text = (
        "The resignation brings a long political chapter to an end. "
        "Trudeau has been in office since 2015, when he brought the "
        "Liberals back to power from the political wilderness.\n"
    )
    tokenizer, model = load_model_and_tokenizer("ductuan024/gpts-detector")
    # predict works on batches; wrap the single sample in a list.
    predictions = predict([text], model, tokenizer)
    print(predictions[0]["prediction"])
    print(predictions[0]["confidence"])