BiBiER / synthetic_utils /parler_tts_wrapper.py
farbverlauf's picture
gpu
960b1a0
# parler_tts_wrapper.py
import torch
import soundfile as sf
import time
import os
import logging
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
class ParlerTTS:
def __init__(self, model_name="parler-tts/parler-tts-mini-v1", device="cuda"):
self.device = device
logging.info(f"[ParlerTTS] Загрузка модели {model_name} на {device} ...")
self.model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.sr = self.model.config.sampling_rate
def generate_audio_from_text(self, text: str, description: str) -> torch.Tensor:
"""
Генерирует аудио (без сохранения на диск).
Возвращает PyTorch-тензор формы (1, num_samples).
"""
input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)
with torch.no_grad():
generation = self.model.generate(
input_ids=input_ids,
prompt_input_ids=prompt_input_ids
)
audio_arr = generation.cpu().numpy().squeeze() # (samples,)
wf = torch.from_numpy(audio_arr).unsqueeze(0) # -> (1, samples)
return wf
def generate_and_save_audio(self, text: str, description: str, out_dir="tts_outputs", filename_prefix="tts") -> torch.Tensor:
"""
Генерирует аудио И сохраняет результат в WAV-файл (для отладки/проверки).
Возвращает PyTorch-тензор (1, num_samples).
"""
os.makedirs(out_dir, exist_ok=True)
wf = self.generate_audio_from_text(text, description)
np_wf = wf.squeeze().cpu().numpy()
# Формируем имя файла
timestr = time.strftime("%Y%m%d_%H%M%S")
filename = f"{filename_prefix}_{timestr}.wav"
out_path = os.path.join(out_dir, filename)
# Сохраняем
sf.write(out_path, np_wf, self.sr)
logging.info(f"[ParlerTTS] Сохранено аудио: {out_path}")
return wf
def get_sample_rate(self):
return self.sr