# Uploaded by VLAIResearchLab — "Upload agent" (revision 1205948, verified).
from smolagents import Tool
from typing import Any, Optional
class SimpleTool(Tool):
    """smolagents tool that decides whether Vietnamese text relates to a topic.

    Uses a zero-shot NLI classification pipeline that scores the text against
    two candidate labels: the topic itself and "không liên quan <topic>"
    ("not related to <topic>"). The tool answers True when the topic label
    ranks first.
    """

    name = "classify_topic"
    description = "This tool classifies whether the given Vietnamese text is related to the specified topic."
    inputs = {'text': {'type': 'string', 'description': 'The Vietnamese text to be classified.'}, 'topic': {'type': 'string', 'description': 'The string representing the topic to be checked.'}}
    output_type = "boolean"

    def _get_classifier(self):
        """Return the zero-shot pipeline, creating it on first use.

        Building the pipeline loads a large model, so it is done at most once
        per tool instance and cached on ``self`` rather than on every call.
        Imports stay local so the tool's source remains self-contained.
        """
        classifier = getattr(self, "_classifier", None)
        if classifier is None:
            from transformers import pipeline
            import torch

            device = "cuda" if torch.cuda.is_available() else "cpu"
            classifier = pipeline(
                "zero-shot-classification",
                model="vicgalle/xlm-roberta-large-xnli-anli",
                device=device,
                # NOTE(review): trust_remote_code lets the model repo run
                # arbitrary Python at load time; it looks unnecessary for an
                # XLM-RoBERTa checkpoint — confirm and consider removing.
                trust_remote_code=True,
            )
            self._classifier = classifier
        return classifier

    def forward(self, text: str, topic: str) -> bool:
        """
        This tool classifies whether the given Vietnamese text is related to the specified topic.

        Args:
            text: The Vietnamese text to be classified.
            topic: The string representing the topic to be checked.

        Returns:
            bool: True if the text is related to the topic; False otherwise.

        Raises:
            ValueError: If ``text`` or ``topic`` is empty or only whitespace.
        """
        # Empty text is rejected by the pipeline anyway; a blank topic would
        # compare "" against "không liên quan " and give a meaningless answer.
        if not text or not text.strip():
            raise ValueError("`text` must be a non-empty string.")
        if not topic or not topic.strip():
            raise ValueError("`topic` must be a non-empty string.")

        classifier = self._get_classifier()
        # Two competing labels: the topic vs. "not related to <topic>".
        # NOTE(review): no hypothesis_template is passed, so the pipeline's
        # default (English) template wraps these Vietnamese labels.
        candidate_labels = [topic, f"không liên quan {topic}"]
        result = classifier(text, candidate_labels)
        # The pipeline returns labels sorted by descending score.
        return result["labels"][0] == topic