File size: 574 Bytes
88afac1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from transformers import ByT5Tokenizer
class CustomByT5Tokenizer(ByT5Tokenizer):
    """ByT5 tokenizer whose ``encode`` returns a ``torch`` tensor of token ids.

    Differences from the parent ``ByT5Tokenizer``:

    * ``encode`` returns a ``torch.Tensor`` (dtype ``torch.long``) instead of a
      Python list of ints.
    * ``add_special_tokens`` defaults to ``False`` here.
    """

    def encode(self, text, add_special_tokens=False, **kwargs):
        """Encode ``text`` into a tensor of token ids.

        Args:
            text (str): Input text.
            add_special_tokens (bool): Whether the parent tokenizer should add
                its special tokens (for ByT5 this is an appended EOS token;
                there is no BOS). Defaults to ``False`` in this subclass.
            **kwargs: Forwarded unchanged to ``ByT5Tokenizer.encode``.

        Returns:
            torch.Tensor: The token ids with dtype ``torch.long``. For a plain
            list of ids from the parent this is a 1-D tensor; if
            ``return_tensors`` is forwarded, the parent's array/tensor shape is
            kept.
        """
        ids = super().encode(text, add_special_tokens=add_special_tokens, **kwargs)
        # Pin the dtype explicitly: ``torch.tensor([])`` infers float32, so an
        # empty encoding would otherwise not be an integer id tensor.
        # ``as_tensor`` also avoids a redundant copy (and torch's copy-construct
        # UserWarning) when the parent already returned a tensor/array because
        # ``return_tensors`` was passed via ``kwargs``.
        return torch.as_tensor(ids, dtype=torch.long)
# Shared module-level tokenizer built from the "google/byt5-small" checkpoint.
# NOTE(review): this runs at import time; from_pretrained presumably reads a
# local cache or downloads from the Hugging Face Hub — confirm that an
# import-time load is intended.
_BYT5_CHECKPOINT = "google/byt5-small"
tok = CustomByT5Tokenizer.from_pretrained(_BYT5_CHECKPOINT)
|