File size: 1,390 Bytes
1205948
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from smolagents import Tool
from typing import Any, Optional

class SimpleTool(Tool):
    """smolagents tool that decides whether a Vietnamese text relates to a topic.

    Uses a zero-shot classification pipeline
    (``vicgalle/xlm-roberta-large-xnli-anli``) that scores the text against two
    candidate labels: the topic itself and the Vietnamese label
    ``"không liên quan <topic>"`` ("not related <topic>"). The text is
    considered related when the topic label ranks first.
    """

    name = "classify_topic"
    description = "This tool classifies whether the given Vietnamese text is related to the specified topic."
    inputs = {'text': {'type': 'string', 'description': 'The Vietnamese text to be classified.'}, 'topic': {'type': 'string', 'description': 'The string representing the topic to be checked.'}}
    output_type = "boolean"

    def _get_classifier(self):
        """Return the zero-shot pipeline, building it on first use and caching it on the instance.

        Loading a large model is expensive, so it is done at most once per tool
        instance instead of on every ``forward`` call.
        """
        classifier = getattr(self, "_classifier", None)
        if classifier is None:
            # Imports are kept local, as in the original forward(); presumably so
            # the tool stays self-contained when smolagents serializes/shares it.
            from transformers import pipeline
            import torch

            device = "cuda" if torch.cuda.is_available() else "cpu"
            classifier = pipeline(
                "zero-shot-classification",
                model="vicgalle/xlm-roberta-large-xnli-anli",
                device=device,
                # NOTE(review): trust_remote_code allows the Hub repository to run
                # arbitrary Python on this machine. XLM-RoBERTa is a built-in
                # transformers architecture, so this is likely unnecessary —
                # confirm and drop it if so.
                trust_remote_code=True,
            )
            self._classifier = classifier
        return classifier

    def forward(self, text: str, topic: str) -> bool:
        """
        This tool classifies whether the given Vietnamese text is related to the specified topic.

        Args:
            text: The Vietnamese text to be classified.
            topic: The string representing the topic to be checked.

        Returns:
            bool: True if the text is related to the topic; False otherwise.
                Blank (empty or whitespace-only) text is never related to any
                topic and returns False without running the model.
        """
        # Nothing to classify: avoid loading/running the model and avoid an
        # arbitrary answer from scoring an empty string.
        if not text or not text.strip():
            return False

        classifier = self._get_classifier()

        # Two mutually exclusive labels; the pipeline softmaxes over them
        # (multi_label is left at its default), so the top label is the verdict.
        candidate_labels = [topic, f"không liên quan {topic}"]
        result = classifier(text, candidate_labels)
        predicted_label = result["labels"][0]

        return predicted_label == topic