# vui-space/vui/inference.py
import re
import time
import inflect
import torch
import torch.nn.functional as F
from torchaudio.transforms import Resample
from torch import Tensor
from torch.nn.attention import SDPBackend, sdpa_kernel
from vui.model import Vui
from vui.sampling import multinomial, sample_top_k, sample_top_p, sample_top_p_top_k
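
# Module-level resampler, presumably for feeding 22.05 kHz codec audio into
# Whisper (which expects 16 kHz) in asr() below.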
resample = Resample(22050, 16000).cuda()

def ensure_spaces_around_tags(text: str):
    # Insert a space before '[' when it is not already preceded by
    # whitespace, '<', or another '['. The lookbehind excludes all
    # whitespace, so a plain replacement string is sufficient here.
    text = re.sub(r"(?<![<\[\s])(\[)", r" \1", text)
    # Insert a space after ']' when it is not preceded by digit+']' and not
    # followed by whitespace, '>', or another ']'.
    text = re.sub(r"(?<!\d\])(\])(?![>\]\s])", r"\1 ", text)
    return text.strip()
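
# Traced example of the tag-spacing behaviour above (derived from the
# regexes, not from repo docs):
#   ensure_spaces_around_tags("hello[breath]world") -> "hello [breath] world"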

# Punctuation normalisation pairs. Note: REPLACE is not referenced anywhere
# in this module; it appears to be kept for importers.
REPLACE = [
    ("—", ","),
    ("’", "'"),
    (":", ","),
    (";", ","),
]
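
# Lazily-initialised module singletons: the inflect engine and Whisper model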
engine = None
wm = None

def asr(chunk, model=None, prefix=None):
    """Transcribe an audio chunk with Whisper, lazily loading the model."""
    import whisper

    global wm
    if model is not None:
        wm = model
    elif wm is None:
        wm = whisper.load_model("turbo", "cuda")

    chunk = whisper.pad_or_trim(chunk)
    mel = whisper.log_mel_spectrogram(chunk, n_mels=wm.dims.n_mels).to(wm.device)
    options = whisper.DecodingOptions(
        language="en", without_timestamps=True, prefix=prefix
    )
    result = whisper.decode(wm, mel[None], options)
    return result[0].text
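
# Illustrative call pattern (paths and shapes are assumptions, not from this
# repo): load a 22.05 kHz clip, downsample with the module-level resampler,
# then transcribe.
#   import torchaudio
#   wav, sr = torchaudio.load("clip.wav")   # mono, 22.05 kHz assumed
#   text = asr(resample(wav.cuda())[0])     # Whisper wants 16 kHz audio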

def replace_numbers_with_words(text):
    global engine
    if engine is None:
        engine = inflect.engine()

    # Convert a matched run of digits to words
    def number_to_words(match):
        number = match.group()
        return engine.number_to_words(number) + " "

    # Replace digit runs with their word equivalents
    return re.sub(r"\d+", number_to_words, text)
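
# Traced example (note the trailing space inflect leaves behind; simple_clean
# collapses repeated spaces later):
#   replace_numbers_with_words("see you at 10") -> "see you at ten "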
valid_non_speech = ["breath", "sigh", "laugh", "tut", "hesitate", "clearthroat"]
valid_non_speech = [f"[{v}]" for v in valid_non_speech]

def remove_all_invalid_non_speech(txt):
    """
    Remove all non-speech markers that are not in the valid_non_speech list.
    Only keeps valid markers like [breath], [sigh], etc., plus [pause].
    """
    # Find all text within square brackets
    bracket_pattern = r"\[([^\]]+)\]"
    brackets = re.findall(bracket_pattern, txt)

    # For each bracketed text, check whether it is in our valid list
    for bracket in brackets:
        bracket_with_brackets = f"[{bracket}]"
        if bracket_with_brackets not in valid_non_speech and bracket != "pause":
            # Not valid, so remove it from the text
            txt = txt.replace(bracket_with_brackets, "")
    return txt
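
# Traced example ([cough] is not whitelisted, so it is stripped; the leftover
# double space is collapsed later by simple_clean):
#   remove_all_invalid_non_speech("Well [cough] okay [sigh]") -> "Well  okay [sigh]"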

def simple_clean(text):
    text = re.sub(r"(\d+)am", r"\1 AM", text)
    text = re.sub(r"(\d+)pm", r"\1 PM", text)
    text = replace_numbers_with_words(text)
    text = ensure_spaces_around_tags(text)
    text = remove_all_invalid_non_speech(text)
    text = text.replace('"', "")
    text = text.replace("”", "")
    text = text.replace("“", "")
    text = text.replace("’", "'")
    text = text.replace("%", " percent")
    text = text.replace("*", "")
    text = text.replace("(", "")
    text = text.replace(")", "")
    text = text.replace(";", "")
    text = text.replace("–", " ")
    text = text.replace("—", "")
    text = text.replace(":", "")
    text = text.replace("…", "...")
    text = text.replace("s...", "s")
    # Collapse runs of newlines into one
    text = re.sub(r"\n+", "\n", text)
    ntxt = re.sub(r" +", " ", text)
    # Ensure that ntxt ends with '.' or '?'
    ntxt = ntxt.strip()
    if not ntxt.endswith((".", "?")):
        ntxt += "."
    ntxt += " [pause]"
    return ntxt
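
# End-to-end trace of the cleaning pipeline:
#   simple_clean("Meet me at 9am") -> "Meet me at nine AM. [pause]"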

@torch.inference_mode()
def generate(
    self: Vui,
    text: str,
    prompt_codes: Tensor | None = None,
    temperature: float = 0.5,
    top_k: int | None = 150,
    top_p: float | None = None,
    max_gen_len: int = int(120 * 21.53),  # ~120 s at the ~21.53 Hz codec frame rate
):
    text = simple_clean(text)

    with (
        torch.autocast("cuda", torch.bfloat16, True),
        sdpa_kernel([SDPBackend.MATH]),
    ):
        t1 = time.perf_counter()
        batch_size = 1
        device = self.device
        self.decoder.allocate_inference_cache(batch_size, device, torch.bfloat16)
        texts = [text]
        encoded = self.tokenizer(
            texts,
            padding="longest",
            return_tensors="pt",
        )
        input_ids = encoded.input_ids.to(device)
        text_embeddings = self.token_emb(input_ids)

        B = batch_size
        Q = self.config.model.n_quantizers

        if prompt_codes is None:
            prompt_codes = torch.zeros(
                (batch_size, Q, 0), dtype=torch.int64, device=device
            )
        else:
            prompt_codes = prompt_codes[:, :Q].repeat(batch_size, 1, 1)

        start_offset = prompt_codes.size(-1)
        pattern = self.pattern_provider.get_pattern(max_gen_len)

        # this token is used as default value for codes that are not generated yet
        unknown_token = -1
        special_token_id = self.config.model.special_token_id

        # we generate codes up to max_gen_len that will be mapped to the pattern sequence
        codes = torch.full(
            (B, Q, max_gen_len), unknown_token, dtype=torch.int64, device=device
        )
        codes[:, :, :start_offset] = prompt_codes
        sequence, indexes, mask = pattern.build_pattern_sequence(
            codes, special_token_id
        )

        # retrieve the start_offset in the sequence:
        # it is the first sequence step that contains the `start_offset` timestep
        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        prev_offset = 0
        S = sequence.size(-1)
        do_prefill = True
        eos = self.config.model.audio_eos_id
        for offset in range(start_offset_sequence, S):
            curr_sequence = sequence[..., prev_offset:offset]
            # Average the per-quantizer embeddings for this step
            audio_embeddings = (
                sum([self.audio_embeddings[q](curr_sequence[:, q]) for q in range(Q)])
                / Q
            )

            if do_prefill:
                # First step: prefill the cache with text + prompt audio
                embeddings = torch.cat((text_embeddings, audio_embeddings), dim=1)
                T = embeddings.size(1)
                input_pos = torch.arange(0, T, device=device)
                do_prefill = False
            else:
                # Subsequent steps: feed one new frame at a time
                embeddings = audio_embeddings
                input_pos = torch.tensor([T], device=device)
                T += 1

            out = self.decoder(embeddings, input_pos)

            if offset == 15:
                # Crude time-to-first-block latency probe
                print("TTFB", time.perf_counter() - t1)

            logits = torch.stack(
                [self.audio_heads[q](out[:, -1]) for q in range(Q)], dim=1
            )
            repetition_penalty = 1.4
            history_window = 12

            # Penalize tokens generated in the recent history of each quantizer
            for q in range(Q):
                # Extract the history window for this quantizer
                history_start = max(0, offset - history_window)
                token_history = sequence[0, q, history_start:offset]

                # Only apply the penalty to real audio tokens in the history
                unique_tokens = torch.unique(token_history)
                unique_tokens = unique_tokens[unique_tokens != special_token_id]
                unique_tokens = unique_tokens[unique_tokens != eos]
                unique_tokens = unique_tokens[unique_tokens != unknown_token]

                if len(unique_tokens) > 0:
                    # NB: division damps positive logits but pushes negative
                    # ones toward zero (i.e. slightly boosts them)
                    logits[0, q, unique_tokens] = (
                        logits[0, q, unique_tokens] / repetition_penalty
                    )

            # Forbid EOS for the first ~98 pattern steps (the opening few seconds)
            if offset < 24.53 * 4:
                logits[..., eos] = -float("inf")
            probs = F.softmax(logits / temperature, dim=-1)

            if top_p is not None and top_k is not None:
                next_codes = sample_top_p_top_k(probs, top_p, top_k)
            elif top_p is not None and top_p > 0:
                next_codes = sample_top_p(probs, top_p)
            elif top_k is not None and top_k > 0:
                next_codes = sample_top_k(probs, top_k)
            else:
                next_codes = multinomial(probs, num_samples=1)

            next_codes = next_codes.repeat(batch_size, 1, 1)

            # Stop early once the model is confident about end-of-speech
            if (probs[..., eos] > 0.95).any():
                print("breaking at", offset)
                break

            # Only write positions that the pattern marks as valid for this step
            valid_mask = mask[..., offset : offset + 1].expand(B, -1, -1)
            next_codes[~valid_mask] = special_token_id
            sequence[..., offset : offset + 1] = torch.where(
                sequence[..., offset : offset + 1] == unknown_token,
                next_codes,
                sequence[..., offset : offset + 1],
            )
            prev_offset = offset

        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(
            sequence, special_token=unknown_token
        )
        # sanity checks over the returned codes and corresponding masks
        # assert (out_codes[..., :max_gen_len] != unknown_token).all()
        # assert (out_mask[..., :max_gen_len] == 1).all()
        out_codes = out_codes[..., prompt_codes.shape[-1] : offset]
        return out_codes[[0]]
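
# generate() returns codec indices only; render() below wraps it and decodes
# to a waveform. Illustrative direct use (model construction is repo-specific
# and assumed here):
#   codes = generate(model, "Hello there!", top_k=150)
#   audio = model.codec.from_indices(codes)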

@torch.inference_mode()
def render(
    self: Vui,
    text: str,
    prompt_codes: Tensor | None = None,
    temperature: float = 0.5,
    top_k: int | None = 100,
    top_p: float | None = None,
    max_secs: int = 100,
):
    """
    Render audio from text. Uses a single generate() call for text under
    1400 characters; otherwise breaks the text into lines and renders them
    in a chain, carrying the previous line's codes as context.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    text = remove_all_invalid_non_speech(text)
    text = simple_clean(text)

    SR = self.codec.config.sample_rate
    HZ = self.codec.hz
    max_gen_len = int(HZ * max_secs)
    t1 = time.perf_counter()

    if len(text) < 1400:
        codes = generate(
            self, text, prompt_codes, temperature, top_k, top_p, max_gen_len
        )
        codes = codes[..., :-10]  # trim the last few frames around EOS
        audio = self.codec.from_indices(codes)
        print("RTF", (audio.numel() / SR) / (time.perf_counter() - t1))
        return audio.cpu()
    # Otherwise we have to do some clever chaining!
    orig_codes = prompt_codes
    lines = text.split("\n")
    audios = []
    prev_codes = prompt_codes
    prev_text = ""

    for i, line in enumerate(lines):
        run = True
        while run:
            current_text = prev_text + "\n" + line if prev_text else line
            current_text = current_text.strip()
            current_text = current_text.replace("...", "")
            current_text = current_text + " [pause]"

            # Budget the generation length by text length (~500 chars per minute)
            maxlen = int(HZ * int(60 * len(current_text) / 500))
            try:
                print("rendering", current_text)
                codes = generate(
                    self,
                    current_text,
                    prompt_codes=prev_codes,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    max_gen_len=maxlen,
                )
                codes = codes[..., :-10]
                paudio = self.codec.from_indices(codes)
                prev_text = line
                prev_codes = codes
                audios.append(paudio)
                run = False  # success: move on to the next line
            except KeyboardInterrupt:
                break
            except RuntimeError as e:
                # Retry this line from scratch, dropping the carried context
                prev_codes = orig_codes
                prev_text = ""
                print(e)

    return torch.cat(audios, dim=-1)
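
# Minimal end-to-end sketch. Model construction is an assumption: this module
# never shows how a Vui instance is loaded, so swap in the repo's actual loader.
if __name__ == "__main__":
    import torchaudio

    model = Vui.from_pretrained().cuda().eval()  # hypothetical loader
    audio = render(model, "Hey! [laugh] This is a quick smoke test.")
    # Assuming a (batch, channels, samples) waveform comes back
    torchaudio.save("out.wav", audio[0].cpu(), model.codec.config.sample_rate)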