# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ["HuggingfaceTokenizer"]


def basic_clean(text):
    """Fix mojibake with ftfy and unescape (possibly double-escaped) HTML entities."""
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    """Collapse runs of whitespace into single spaces and strip the ends."""
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    """Lowercase, strip punctuation, and normalize whitespace.

    If `keep_punctuation_exact_string` is given, that exact substring is
    preserved while punctuation is stripped from the pieces around it.
    """
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans("", "", string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.strip()


class HuggingfaceTokenizer:
    """Thin wrapper around `transformers.AutoTokenizer` that adds optional
    text cleaning and fixed-length padding/truncation."""

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, "whitespace", "lower", "canonicalize")
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop("return_mask", False)

        # arguments: pad/truncate to a fixed length when `seq_len` is set
        _kwargs = {"return_tensors": "pt"}
        if self.seq_len is not None:
            _kwargs.update({
                "padding": "max_length",
                "truncation": True,
                "max_length": self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization: accept a single string or a list of strings
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == "whitespace":
            text = whitespace_clean(basic_clean(text))
        elif self.clean == "lower":
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == "canonicalize":
            text = canonicalize(basic_clean(text))
        return text
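

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The checkpoint
    # name below is only an illustrative assumption; any tokenizer hosted on
    # the Hugging Face Hub can be substituted.
    tokenizer = HuggingfaceTokenizer(
        name="google/umt5-xxl", seq_len=512, clean="whitespace")
    ids, mask = tokenizer(
        "A cinematic shot of a corgi running on a beach.", return_mask=True)
    print(ids.shape, mask.shape)  # both torch.Size([1, 512]) with these settings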