okewunmi commited on
Commit
6f34f59
·
verified ·
1 Parent(s): 1c22498

Delete yarngpt_utils.py

Browse files
Files changed (1) hide show
  1. yarngpt_utils.py +0 -52
yarngpt_utils.py DELETED
@@ -1,52 +0,0 @@
1
- # yarngpt_utils.py
2
- import torch
3
- import torchaudio
4
- import re
5
- from outetts.wav_tokenizer.decoder import WavTokenizer
6
- from transformers import AutoTokenizer
7
-
8
- class AudioTokenizer:
9
- def __init__(self, hf_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
10
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
11
- self.tokenizer = AutoTokenizer.from_pretrained(hf_path)
12
-
13
- # Fix: Use the correct parameter names for WavTokenizer
14
- self.wav_tokenizer = WavTokenizer(
15
- checkpoint_path=wav_tokenizer_model_path,
16
- config_path=wav_tokenizer_config_path,
17
- device=self.device
18
- )
19
-
20
- self.speakers = ["idera", "emma", "jude", "osagie", "tayo", "zainab",
21
- "joke", "regina", "remi", "umar", "chinenye"]
22
-
23
- def create_prompt(self, text, speaker_name=None):
24
- if speaker_name is None or speaker_name not in self.speakers:
25
- speaker_name = self.speakers[torch.randint(0, len(self.speakers), (1,)).item()]
26
-
27
- # Create a prompt similar to the original YarnGPT
28
- prompt = f"<|system|>\nYou are a helpful assistant that speaks in {speaker_name}'s voice.\n<|user|>\nSpeak this text: {text}\n<|assistant|>"
29
- return prompt
30
-
31
- def tokenize_prompt(self, prompt):
32
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
33
- return input_ids
34
-
35
- def get_codes(self, output):
36
- # Decode the sequence
37
- decoded_str = self.tokenizer.decode(output[0])
38
-
39
- # Extract the part after <|assistant|>
40
- speech_part = decoded_str.split("<|assistant|>")[-1].strip()
41
-
42
- # Extract code tokens - assuming format like "<audio_001>"
43
- audio_codes = []
44
- for match in re.finditer(r"<audio_(\d+)>", speech_part):
45
- code = int(match.group(1))
46
- audio_codes.append(code)
47
-
48
- return audio_codes
49
-
50
- def get_audio(self, codes):
51
- audio = self.wav_tokenizer.decode(torch.tensor(codes, device=self.device))
52
- return audio