Delete yarngpt_utils.py
Browse files — yarngpt_utils.py (+0 −52)
yarngpt_utils.py
DELETED
@@ -1,52 +0,0 @@
|
|
1 |
-
# yarngpt_utils.py
|
2 |
-
import torch
|
3 |
-
import torchaudio
|
4 |
-
import re
|
5 |
-
from outetts.wav_tokenizer.decoder import WavTokenizer
|
6 |
-
from transformers import AutoTokenizer
|
7 |
-
|
8 |
-
class AudioTokenizer:
    """Bridge between a text LM and a neural audio codec for YarnGPT-style TTS.

    Responsibilities, in pipeline order:
      1. build a speaker-conditioned chat prompt (``create_prompt``),
      2. encode it for the language model (``tokenize_prompt``),
      3. pull ``<audio_NNN>`` code tokens out of the model's reply (``get_codes``),
      4. decode those codes to a waveform via WavTokenizer (``get_audio``).
    """

    def __init__(self, hf_path, wav_tokenizer_model_path, wav_tokenizer_config_path):
        """Load the text tokenizer and the WavTokenizer codec.

        hf_path: HuggingFace model id/path for the text tokenizer.
        wav_tokenizer_model_path: WavTokenizer checkpoint file.
        wav_tokenizer_config_path: WavTokenizer config file.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(hf_path)
        # Construct the codec with explicit keyword arguments so the
        # checkpoint/config paths cannot be swapped by position.
        self.wav_tokenizer = WavTokenizer(
            checkpoint_path=wav_tokenizer_model_path,
            config_path=wav_tokenizer_config_path,
            device=self.device,
        )
        # Supported speaker voices; an unknown or missing name falls back
        # to a random pick in create_prompt.
        self.speakers = ["idera", "emma", "jude", "osagie", "tayo", "zainab",
                         "joke", "regina", "remi", "umar", "chinenye"]

    def create_prompt(self, text, speaker_name=None):
        """Return a chat-style prompt asking the model to speak ``text``.

        When ``speaker_name`` is None or not a known speaker, a random
        supported speaker is substituted (None never compares equal to a
        string, so a single membership test covers both cases).
        """
        if speaker_name not in self.speakers:
            pick = torch.randint(0, len(self.speakers), (1,)).item()
            speaker_name = self.speakers[pick]
        return f"<|system|>\nYou are a helpful assistant that speaks in {speaker_name}'s voice.\n<|user|>\nSpeak this text: {text}\n<|assistant|>"

    def tokenize_prompt(self, prompt):
        """Encode ``prompt`` to a tensor of input ids on the model device."""
        return self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

    def get_codes(self, output):
        """Extract the integer audio codes (``<audio_NNN>`` tokens) from model output.

        ``output`` is a batch of generated token ids; only the first row is read.
        """
        decoded = self.tokenizer.decode(output[0])
        # Audio tokens live only in the assistant's reply, after the last tag.
        reply = decoded.split("<|assistant|>")[-1].strip()
        return [int(m.group(1)) for m in re.finditer(r"<audio_(\d+)>", reply)]

    def get_audio(self, codes):
        """Decode a list of audio codes into a waveform via WavTokenizer."""
        code_tensor = torch.tensor(codes, device=self.device)
        return self.wav_tokenizer.decode(code_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|