import torch


class SimpleTokenizer:
    """Character-level tokenizer backed by a torch-saved char->index vocabulary."""

    def __init__(self, vocab_path):
        """Load the char->index dict from *vocab_path* and build the inverse map.

        Ensures a '' (empty-string) entry exists; its index serves as the
        unknown-character fallback in :meth:`encode`.
        """
        # BUG FIX: original referenced the undefined name `vocab_pth`,
        # raising NameError on construction; the parameter is `vocab_path`.
        self.char_to_idx = torch.load(vocab_path)
        # Reserve '' as the unknown-token entry if the vocab lacks one.
        if '' not in self.char_to_idx:
            self.char_to_idx[''] = max(self.char_to_idx.values()) + 1
        # Inverse mapping used by decode().
        self.idx_to_char = {i: c for c, i in self.char_to_idx.items()}

    def encode(self, text):
        """Map each character of *text* to its index.

        Characters absent from the vocabulary map to the '' entry's index
        (guaranteed to exist by __init__).
        """
        # Hoist the loop-invariant fallback lookup out of the comprehension.
        unk = self.char_to_idx.get('', 0)
        return [self.char_to_idx.get(c, unk) for c in text]

    def decode(self, indices):
        """Map *indices* back to characters; unknown indices contribute ''."""
        return ''.join(self.idx_to_char.get(i, '') for i in indices)


if __name__ == "__main__":
    # Example usage — guarded so importing this module has no side effects
    # (the original ran this at import time and needed vocab.pth to exist).
    vocab_path = 'vocab.pth'  # Replace with the actual path to your vocab file
    tokenizer = SimpleTokenizer(vocab_path)
    text = "Hello, world!"
    tokens = tokenizer.encode(text)
    print(tokens)
    decoded_text = tokenizer.decode(tokens)
    print(decoded_text)