|
import re
import time

import inflect
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.attention import SDPBackend, sdpa_kernel
from torchaudio.transforms import Resample

from vui.model import Vui
from vui.sampling import multinomial, sample_top_k, sample_top_p, sample_top_p_top_k

# Downsample 22.05 kHz audio to the 16 kHz rate that Whisper expects.
resample = Resample(22050, 16000).cuda()
|
|
|
def ensure_spaces_around_tags(text: str):
    """Separate bracketed tags such as [laugh] from adjacent words."""
    # Insert a space before "[" when it directly follows a non-space character
    # (preserving a newline boundary if one is already there).
    text = re.sub(
        r"(?<![<\[\s])(\[)",
        lambda m: (
            f"\n{m.group(1)}"
            if m.start() > 0 and text[m.start() - 1] == "\n"
            else f" {m.group(1)}"
        ),
        text,
    )
    # Insert a space after "]" when it is directly followed by a non-space
    # character (again preserving an existing newline boundary).
    text = re.sub(
        r"(?<!\d\])(\])(?![>\]\s])",
        lambda m: (
            f"{m.group(1)}\n"
            if m.end() < len(text) and text[m.end()] == "\n"
            else f"{m.group(1)} "
        ),
        text,
    )
    return text.strip()
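
# Example (illustrative): ensure_spaces_around_tags("hi[laugh]there")
# returns "hi [laugh] there".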
|
|
|
|
|
# Punctuation normalization pairs (not referenced below).
REPLACE = [
    ("—", ","),
    ("'", "'"),
    (":", ","),
    (";", ","),
]

# Lazily-initialized globals: the inflect engine and the Whisper model.
engine = None
wm = None
|
|
|
|
|
def asr(chunk, model=None, prefix=None):
    """Transcribe an audio chunk with Whisper, loading the model on first use."""
    import whisper

    global wm
    if model is not None:
        wm = model
    elif wm is None:
        wm = whisper.load_model("turbo", "cuda")

    chunk = whisper.pad_or_trim(chunk)
    mel = whisper.log_mel_spectrogram(chunk, n_mels=wm.dims.n_mels).to(wm.device)
    options = whisper.DecodingOptions(
        language="en", without_timestamps=True, prefix=prefix
    )
    result = whisper.decode(wm, mel[None], options)
    return result[0].text
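
# Usage sketch (illustrative; `wav_22k` is a hypothetical 22.05 kHz mono CUDA
# tensor). Whisper expects 16 kHz input, so resample first:
#   text = asr(resample(wav_22k))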
|
|
|
|
|
def replace_numbers_with_words(text):
    """Spell out digit sequences as words using inflect."""
    global engine
    if engine is None:
        engine = inflect.engine()

    def number_to_words(match):
        number = match.group()
        # The trailing space keeps the spelled-out number separated from
        # whatever follows; extra spaces are collapsed later in simple_clean.
        return engine.number_to_words(number) + " "

    return re.sub(r"\d+", number_to_words, text)
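
# Example (illustrative): replace_numbers_with_words("room 42") returns
# "room forty-two ".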
|
|
|
|
|
# Non-speech markers the model understands; everything else is stripped.
valid_non_speech = ["breath", "sigh", "laugh", "tut", "hesitate", "clearthroat"]
valid_non_speech = [f"[{v}]" for v in valid_non_speech]


def remove_all_invalid_non_speech(txt):
    """
    Remove all non-speech markers that are not in the valid_non_speech list.
    Only keeps valid markers such as [breath] and [sigh], plus [pause].
    """
    bracket_pattern = r"\[([^\]]+)\]"
    brackets = re.findall(bracket_pattern, txt)

    for bracket in brackets:
        bracket_with_brackets = f"[{bracket}]"
        if bracket_with_brackets not in valid_non_speech and bracket != "pause":
            txt = txt.replace(bracket_with_brackets, "")

    return txt
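
# Example (illustrative): remove_all_invalid_non_speech("[cough] hi [laugh]")
# returns " hi [laugh]"; the unrecognized [cough] marker is dropped.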
|
|
|
|
|
def simple_clean(text):
    text = re.sub(r"(\d+)am", r"\1 AM", text)
    text = re.sub(r"(\d+)pm", r"\1 PM", text)
    text = replace_numbers_with_words(text)
    text = ensure_spaces_around_tags(text)
    text = remove_all_invalid_non_speech(text)

    text = text.replace('"', "")
    text = text.replace(""", "")
    text = text.replace(""", "")
    text = text.replace("'", "'")
    text = text.replace("%", " percent")
    text = text.replace("*", "")
    text = text.replace("(", "")
    text = text.replace(")", "")
    text = text.replace(";", "")
    text = text.replace("—", " ")
    text = text.replace("–", "")
    text = text.replace(":", "")
    text = text.replace("…", "...")
    text = text.replace("s...", "s")
|
    text = re.sub(r"\n+", "\n", text)
    ntxt = re.sub(r" +", " ", text)

    ntxt = ntxt.strip()
    # Terminate the text, then append a trailing [pause] marker.
    if not ntxt.endswith((".", "?")):
        ntxt += "."
    ntxt += " [pause]"
    return ntxt
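
# Example (illustrative):
#   simple_clean('She said "hi: 3 cats"') -> "She said hi three cats. [pause]"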
|
|
|
|
|
@torch.inference_mode()
def generate(
    self: Vui,
    text: str,
    prompt_codes: Tensor | None = None,
    temperature: float = 0.5,
    top_k: int | None = 150,
    top_p: float | None = None,
    max_gen_len: int = int(120 * 21.53),  # ~120 s at ~21.53 codec frames/s
):
    text = simple_clean(text)
    with (
        torch.autocast("cuda", torch.bfloat16, True),
        sdpa_kernel([SDPBackend.MATH]),
    ):
        t1 = time.perf_counter()
        batch_size = 1
        device = self.device
        self.decoder.allocate_inference_cache(batch_size, device, torch.bfloat16)

        texts = [text]

        encoded = self.tokenizer(
            texts,
            padding="longest",
            return_tensors="pt",
        )

        input_ids = encoded.input_ids.to(device)
        text_embeddings = self.token_emb(input_ids)

        B = batch_size
        Q = self.config.model.n_quantizers

        if prompt_codes is None:
            prompt_codes = torch.zeros(
                (batch_size, Q, 0), dtype=torch.int64, device=device
            )
        else:
            prompt_codes = prompt_codes[:, :Q].repeat(batch_size, 1, 1)

        start_offset = prompt_codes.size(-1)

        pattern = self.pattern_provider.get_pattern(max_gen_len)

        unknown_token = -1
        special_token_id = self.config.model.special_token_id

        # Start from all-unknown codes, with any prompt codes filled in.
        codes = torch.full(
            (B, Q, max_gen_len), unknown_token, dtype=torch.int64, device=device
        )
        codes[:, :, :start_offset] = prompt_codes

        sequence, indexes, mask = pattern.build_pattern_sequence(
            codes, special_token_id
        )

        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        prev_offset = 0
        S = sequence.size(-1)

        do_prefill = True
        eos = self.config.model.audio_eos_id
        for offset in range(start_offset_sequence, S):
            curr_sequence = sequence[..., prev_offset:offset]
            # Average the embeddings across the Q codebooks.
            audio_embeddings = (
                sum([self.audio_embeddings[q](curr_sequence[:, q]) for q in range(Q)])
                / Q
            )

            if do_prefill:
                # First step: prefill the KV cache with text plus prompt audio.
                embeddings = torch.cat((text_embeddings, audio_embeddings), dim=1)
                T = embeddings.size(1)
                input_pos = torch.arange(0, T, device=device)
                do_prefill = False
            else:
                # Later steps: decode one position at a time.
                embeddings = audio_embeddings
                input_pos = torch.tensor([T], device=device)
                T += 1

            out = self.decoder(embeddings, input_pos)

            # Rough "time to first byte": latency to the first decoded steps.
            if offset == 15:
                print("TTFB", time.perf_counter() - t1)

            logits = torch.stack(
                [self.audio_heads[q](out[:, -1]) for q in range(Q)], dim=1
            )

            # Down-weight tokens seen in the recent history of each codebook.
            repetition_penalty = 1.4
            history_window = 12

            for q in range(Q):
                history_start = max(0, offset - history_window)
                token_history = sequence[0, q, history_start:offset]

                unique_tokens = torch.unique(token_history)
                unique_tokens = unique_tokens[unique_tokens != special_token_id]
                unique_tokens = unique_tokens[unique_tokens != eos]
                unique_tokens = unique_tokens[unique_tokens != unknown_token]

                if len(unique_tokens) > 0:
                    logits[0, q, unique_tokens] = (
                        logits[0, q, unique_tokens] / repetition_penalty
                    )

            # Disallow EOS for the first ~100 steps (a few seconds of audio).
            if offset < 24.53 * 4:
                logits[..., eos] = -float("inf")

            probs = F.softmax(logits / temperature, dim=-1)

            if top_p is not None and top_k is not None:
                next_codes = sample_top_p_top_k(probs, top_p, top_k)
            elif top_p is not None and top_p > 0:
                next_codes = sample_top_p(probs, top_p)
            elif top_k is not None and top_k > 0:
                next_codes = sample_top_k(probs, top_k)
            else:
                next_codes = multinomial(probs, num_samples=1)

            next_codes = next_codes.repeat(batch_size, 1, 1)

            # Stop once the model is confident the audio has ended.
            if (probs[..., eos] > 0.95).any():
                print("breaking at", offset)
                break

            valid_mask = mask[..., offset : offset + 1].expand(B, -1, -1)
            next_codes[~valid_mask] = special_token_id

            # Only fill positions that are still unknown, preserving prompt codes.
            sequence[..., offset : offset + 1] = torch.where(
                sequence[..., offset : offset + 1] == unknown_token,
                next_codes,
                sequence[..., offset : offset + 1],
            )

            prev_offset = offset

        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(
            sequence, special_token=unknown_token
        )

        # Strip the prompt and anything beyond the last generated step.
        out_codes = out_codes[..., prompt_codes.shape[-1] : offset]
        return out_codes[[0]]
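
# Usage sketch (illustrative; assumes a loading helper such as
# Vui.from_pretrained exists on the model class):
#   model = Vui.from_pretrained().cuda()
#   codes = generate(model, "Hello there!")  # (1, Q, T) codec indices
#   audio = model.codec.from_indices(codes)  # decode indices to a waveform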
|
|
|
|
|
@torch.inference_mode()
def render(
    self: Vui,
    text: str,
    prompt_codes: Tensor | None = None,
    temperature: float = 0.5,
    top_k: int | None = 100,
    top_p: float | None = None,
    max_secs: int = 100,
):
    """
    Render audio from text. Uses a single generate call for texts under
    1400 characters; otherwise splits the text into lines and renders them
    one at a time, carrying the previous line's codes over as context.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    text = remove_all_invalid_non_speech(text)
    text = simple_clean(text)
    SR = self.codec.config.sample_rate
    HZ = self.codec.hz
    max_gen_len = int(HZ * max_secs)
    t1 = time.perf_counter()

    if len(text) < 1400:
        codes = generate(
            self, text, prompt_codes, temperature, top_k, top_p, max_gen_len
        )
        codes = codes[..., :-10]  # trim the trailing frames near EOS
        audio = self.codec.from_indices(codes)
        print("RTF", (audio.numel() / SR) / (time.perf_counter() - t1))
        return audio.cpu()

    orig_codes = prompt_codes

    lines = text.split("\n")
    audios = []
    prev_codes = prompt_codes
    prev_text = ""

    for line in lines:
        run = True
        while run:
            current_text = prev_text + "\n" + line if prev_text else line
            current_text = current_text.strip()
            current_text = current_text.replace("...", "")
            current_text = current_text + " [pause]"

            # Budget roughly 60 seconds of audio per 500 characters of text.
            maxlen = int(HZ * int(60 * len(current_text) / 500))

            try:
                print("rendering", current_text)
                codes = generate(
                    self,
                    current_text,
                    prompt_codes=prev_codes,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    max_gen_len=maxlen,
                )

                codes = codes[..., :-10]
                paudio = self.codec.from_indices(codes)
                prev_text = line
                prev_codes = codes
                audios.append(paudio)
                run = False  # success: move on to the next line
            except KeyboardInterrupt:
                break
            except RuntimeError as e:
                # Generation failed: drop the accumulated context and retry.
                prev_codes = orig_codes
                prev_text = ""
                print(e)

    return torch.cat(audios, dim=-1)
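
# Usage sketch (illustrative): the returned tensor is a waveform at the
# codec's sample rate (self.codec.config.sample_rate):
#   audio = render(model, long_text, max_secs=60)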
|
|