File size: 12,198 Bytes
88afac1 c64babc 88afac1 c64babc 88afac1 3d69f83 88afac1 1a21598 88afac1 3d69f83 88afac1 3d69f83 88afac1 3d69f83 88afac1 3d69f83 88afac1 c976192 88afac1 1a21598 88afac1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 |
import re
import time
import inflect
import torch
import torch.nn.functional as F
from torchaudio.transforms import Resample
from torch import Tensor
from torch.nn.attention import SDPBackend, sdpa_kernel
from vui.model import Vui
from vui.sampling import multinomial, sample_top_k, sample_top_p, sample_top_p_top_k
# Module-level resampler from 22050 Hz to 16000 Hz. It is moved to CUDA at
# import time, so importing this module needs a CUDA-capable torch install.
resample = Resample(orig_freq=22050, new_freq=16000).cuda()
def ensure_spaces_around_tags(text: str):
    """Put a space on the outside of ``[...]`` tags that touch other text.

    A space is inserted before every ``[`` that is not already preceded by
    whitespace, ``<`` or ``[`` (a ``[`` at the very start also gets one).
    A space is inserted after every ``]`` that is not already followed by
    whitespace, ``>`` or ``]`` (a ``]`` at the very end also gets one). The
    exception is a ``]`` whose two preceding characters are a digit and a
    ``]``. Finally, leading and trailing whitespace is stripped.

    >>> ensure_spaces_around_tags("hi[breath]there")
    'hi [breath] there'
    """
    # Whitespace (newlines included) already satisfies both look-arounds, so
    # the only separator ever inserted is a single space. Newline-preserving
    # branches would be unreachable and are intentionally absent.
    text = re.sub(r"(?<![<\[\s])(\[)", r" \1", text)
    text = re.sub(r"(?<!\d\])(\])(?![>\]\s])", r"\1 ", text)
    return text.strip()
# Character substitution pairs (old, new). Not referenced by the code in this
# section; presumably consumed elsewhere -- confirm before changing.
# NOTE(review): "β" (Greek small beta) looks like a mis-decoded typographic
# character (a UTF-8 sequence starting with byte 0xE2, e.g. an em dash, read
# as ISO-8859-7). ("'", "'") is a no-op pair, probably originally a curly
# apostrophe -> "'". Verify against the upstream source before relying on it.
REPLACE = [
    ("β", ","),
    ("'", "'"),
    (":", ","),
    (";", ","),
]
# Lazily initialised module-level singletons:
#   engine -- inflect engine, created on first use by replace_numbers_with_words()
#   wm     -- Whisper model, loaded or replaced by asr()
engine, wm = None, None
def asr(chunk, model=None, prefix=None):
    """Transcribe one audio chunk as English text with Whisper.

    The chunk is padded/trimmed with ``whisper.pad_or_trim``, turned into a
    log-mel spectrogram, and decoded in English without timestamps. No
    voice-activity detection is done here.

    Args:
        chunk: audio accepted by ``whisper.pad_or_trim`` (presumably 16 kHz
            mono samples -- confirm against callers).
        model: optional Whisper model. When given, it replaces the cached
            module-level model ``wm`` and is also used by later calls.
        prefix: optional text forwarded to ``whisper.DecodingOptions``.

    Returns:
        The decoded text of the first (only) decoding result.
    """
    # Local import: whisper is only required when asr() is actually called.
    import whisper

    global wm
    if model is not None:
        wm = model
    elif wm is None:
        # First call without an explicit model: load "turbo" on CUDA and cache it.
        wm = whisper.load_model("turbo", "cuda")

    chunk = whisper.pad_or_trim(chunk)
    mel = whisper.log_mel_spectrogram(chunk, n_mels=wm.dims.n_mels).to(wm.device)
    options = whisper.DecodingOptions(
        language="en", without_timestamps=True, prefix=prefix
    )
    # mel[None] adds a batch dimension of 1, hence result[0].
    result = whisper.decode(wm, mel[None], options)
    return result[0].text
def replace_numbers_with_words(text):
    """Spell out every run of digits in *text* using inflect.

    Each digit run is replaced by its English words followed by one space
    (e.g. "7" -> "seven "). The inflect engine is created on first use and
    cached in the module-level ``engine`` global.
    """
    global engine
    if engine is None:
        engine = inflect.engine()

    def _spell(match):
        return engine.number_to_words(match.group()) + " "

    return re.sub(r"\d+", _spell, text)
# Non-speech markers the model accepts, in their bracketed form ("[breath]", ...).
valid_non_speech = [
    f"[{tag}]"
    for tag in ("breath", "sigh", "laugh", "tut", "hesitate", "clearthroat")
]
def remove_all_invalid_non_speech(txt):
    """
    Remove every ``[...]`` marker that is not a recognised non-speech tag.

    Markers listed in ``valid_non_speech`` (e.g. ``[breath]``, ``[sigh]``) and
    ``[pause]`` are kept. Every other bracketed span is deleted, where a span
    runs from a ``[`` to the next ``]``. Surrounding whitespace is left as is.
    """

    def _keep_or_drop(match):
        marker = match.group(0)
        if marker in valid_non_speech or match.group(1) == "pause":
            return marker
        return ""

    # One left-to-right substitution removes exactly the matched spans.
    # Repeated global str.replace calls could also strip copies of a short
    # marker that sit inside a different, longer match, leaving fragments.
    return re.sub(r"\[([^\]]+)\]", _keep_or_drop, txt)
def simple_clean(text):
    """Normalise raw text into the form the TTS model is prompted with.

    1. "<digits>am"/"<digits>pm" become "<digits> AM"/"<digits> PM", then
       every digit run is spelled out by ``replace_numbers_with_words``.
    2. Spacing around ``[...]`` tags is normalised. Tags other than the valid
       non-speech markers (and ``[pause]``) are removed.
    3. Punctuation the model should not see is stripped or rewritten (see the
       replacement table below).
    4. Runs of newlines and of spaces are collapsed and the text is stripped.
       A "." is appended unless the text already ends in "." or "?", and then
       " [pause]" is appended.

    Returns the cleaned string.
    """
    text = re.sub(r"(\d+)am", r"\1 AM", text)
    text = re.sub(r"(\d+)pm", r"\1 PM", text)
    text = replace_numbers_with_words(text)
    text = ensure_spaces_around_tags(text)
    text = remove_all_invalid_non_speech(text)

    # Applied in order. Typographic characters are written as \u escapes so
    # the source stays ASCII. The literal characters had been corrupted into
    # "β", which turned these rules into no-ops (and deleted Greek betas).
    replacements = (
        ('"', ""),
        ("\u201c", ""),  # left double quotation mark
        ("\u201d", ""),  # right double quotation mark
        ("\u2019", "'"),  # right single quotation mark (typographic apostrophe)
        ("%", " percent"),
        ("*", ""),
        ("(", ""),
        (")", ""),
        (";", ""),
        ("\u2013", " "),  # en dash; extra spaces are collapsed below
        ("\u2014", " "),  # em dash
        (":", ""),
        ("\u2026", "..."),  # horizontal ellipsis
        ("s...", "s"),  # must run after the ellipsis rule above
    )
    for old, new in replacements:
        text = text.replace(old, new)

    text = re.sub(r"\n+", "\n", text)
    text = re.sub(r" +", " ", text).strip()

    # Ensure the text ends with "." or "?". A tuple suffix avoids the
    # `not a or b` precedence trap that appended "." after a final "?".
    if not text.endswith((".", "?")):
        text += "."
    return text + " [pause]"
@torch.inference_mode()
def generate(
    self: Vui,
    text: str,
    prompt_codes: Tensor | None = None,
    temperature: float = 0.5,
    top_k: int | None = 150,
    top_p: float | None = None,
    max_gen_len: int = int(120 * 21.53),
):
    """Sample audio codec codes for ``text`` autoregressively.

    ``text`` is cleaned with ``simple_clean``, tokenized and embedded. The
    decoder is then stepped once per position of the codebook pattern
    sequence. The first step prefills the text embeddings followed by the
    prompt-audio embeddings; every later step feeds one new position. At each
    step the per-quantizer logits get a repetition penalty, EOS is masked out
    for the first steps, and the next codes are sampled.

    Args:
        self: the ``Vui`` model (passed explicitly; this is a module function).
        text: raw text; ``simple_clean`` is applied here.
        prompt_codes: optional codes to continue from. Only the first
            ``n_quantizers`` entries of dim 1 are kept, and its last dim is the
            prompt length. It is written into ``codes[:, :, :start_offset]``,
            so it must fit a (batch, Q, time) layout.
        temperature: logits are divided by it before softmax, so 0 would
            divide by zero.
        top_k: sampling filter, see the branch order at the sampling step.
        top_p: sampling filter, see the branch order at the sampling step.
        max_gen_len: number of code timesteps allocated, prompt included.
            NOTE(review): the default ``int(120 * 21.53)`` presumably means
            120 s at a ~21.53 Hz code rate -- confirm.

    Returns:
        ``out_codes[[0]]``: the reverted codes of batch item 0, with the
        prompt timesteps removed and cut at the step where the loop stopped.
    """
    text = simple_clean(text)
    with (
        torch.autocast("cuda", torch.bfloat16, True),
        sdpa_kernel([SDPBackend.MATH]),
    ):
        t1 = time.perf_counter()
        batch_size = 1
        device = self.device
        # NOTE(review): bare attribute read; the value is unused.
        self.dtype
        # Presumably (re)allocates the decoder's KV cache for this batch -- confirm.
        self.decoder.allocate_inference_cache(batch_size, device, torch.bfloat16)
        texts = [text]
        encoded = self.tokenizer(
            texts,
            padding="longest",
            return_tensors="pt",
        )
        input_ids = encoded.input_ids.to(device)
        text_embeddings = self.token_emb(input_ids)
        B = batch_size
        Q = self.config.model.n_quantizers
        # No prompt: an empty (B, Q, 0) prompt so start_offset becomes 0.
        if prompt_codes is None:
            prompt_codes = torch.zeros(
                (batch_size, Q, 0), dtype=torch.int64, device=device
            )
        else:
            prompt_codes = prompt_codes[:, :Q].repeat(batch_size, 1, 1)
        start_offset = prompt_codes.size(-1)
        pattern = self.pattern_provider.get_pattern(max_gen_len)
        # this token is used as default value for codes that are not generated yet
        unknown_token = -1
        special_token_id = self.config.model.special_token_id
        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
        codes = torch.full(
            (B, Q, max_gen_len), unknown_token, dtype=torch.int64, device=device
        )
        codes[:, :, :start_offset] = prompt_codes
        sequence, indexes, mask = pattern.build_pattern_sequence(
            codes, special_token_id
        )
        # retrieve the start_offset in the sequence:
        # it is the first sequence step that contains the `start_offset` timestep
        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None
        prev_offset = 0
        S = sequence.size(-1)
        do_prefill = True
        eos = self.config.model.audio_eos_id
        for offset in range(start_offset_sequence, S):
            # print(f"{prev_offset}:{offset}")
            # First iteration: positions [0, start_offset_sequence).
            # Later iterations: exactly one position (prev_offset == offset - 1).
            curr_sequence = sequence[..., prev_offset:offset]
            # Embed each quantizer's codes and average the Q embeddings.
            audio_embeddings = (
                sum([self.audio_embeddings[q](curr_sequence[:, q]) for q in range(Q)])
                / Q
            )
            if do_prefill:
                # Prefill: text embeddings followed by the prompt-audio
                # embeddings. T is the next free input position.
                embeddings = torch.cat((text_embeddings, audio_embeddings), dim=1)
                T = embeddings.size(1)
                input_pos = torch.arange(0, T, device=device)
                do_prefill = False
            else:
                # Incremental step: feed only the newest position at index T.
                embeddings = audio_embeddings
                input_pos = torch.tensor([T], device=device)
                T += 1
            out = self.decoder(embeddings, input_pos)
            # Debug timing, printed once at sequence step 15.
            if offset == 15:
                print("TTFB", time.perf_counter() - t1)
            # One head per quantizer on the last output position, stacked on dim 1.
            logits = torch.stack(
                [self.audio_heads[q](out[:, -1]) for q in range(Q)], dim=1
            )
            repetition_penalty = 1.4
            history_window = 12
            # Get the history of generated tokens for each quantizer
            for q in range(Q):
                # Extract the history window for this quantizer
                history_start = max(0, offset - history_window)
                token_history = sequence[0, q, history_start:offset]
                # Only apply penalty to tokens that appear in the history
                unique_tokens = torch.unique(token_history)
                unique_tokens = unique_tokens[unique_tokens != special_token_id]
                unique_tokens = unique_tokens[unique_tokens != eos]
                unique_tokens = unique_tokens[unique_tokens != unknown_token]
                if len(unique_tokens) > 0:
                    # Apply penalty by dividing the logits for tokens that have appeared recently
                    # NOTE(review): dividing a *negative* logit by 1.4 moves it
                    # toward 0 and so raises that token's probability. A common
                    # fix multiplies negative logits instead. Only batch row 0
                    # is penalized.
                    logits[0, q, unique_tokens] = (
                        logits[0, q, unique_tokens] / repetition_penalty
                    )
            # Forbid EOS for the first 24.53 * 4 (~98) sequence steps.
            # NOTE(review): presumably "about 4 s of audio", but 24.53 differs
            # from the 21.53 used in the max_gen_len default -- confirm.
            if offset < 24.53 * 4:
                logits[..., eos] = -float("inf")
            probs = F.softmax(logits / temperature, dim=-1)
            # print(probs.shape)
            # Branch order: both filters set -> combined; else top_p if > 0;
            # else top_k if > 0; else plain multinomial sampling.
            if top_p is not None and top_k is not None:
                next_codes = sample_top_p_top_k(probs, top_p, top_k)
            elif top_p is not None and top_p > 0:
                next_codes = sample_top_p(probs, top_p)
            elif top_k is not None and top_k > 0:
                next_codes = sample_top_k(probs, top_k)
            else:
                next_codes = multinomial(probs, num_samples=1)
            next_codes = next_codes.repeat(batch_size, 1, 1)
            # Stop once any quantizer puts > 0.95 probability on EOS. This runs
            # before the write below, so step `offset` is left unfilled.
            if (probs[..., eos] > 0.95).any():
                print("breaking at", offset)
                break
            # Positions the pattern marks invalid at this step get the special
            # token. Only still-unknown entries of `sequence` are overwritten,
            # so prompt codes already placed there are preserved.
            valid_mask = mask[..., offset : offset + 1].expand(B, -1, -1)
            next_codes[~valid_mask] = special_token_id
            sequence[..., offset : offset + 1] = torch.where(
                sequence[..., offset : offset + 1] == unknown_token,
                next_codes,
                sequence[..., offset : offset + 1],
            )
            prev_offset = offset
        # print(sequence.shape)
        # Map the pattern sequence back to per-timestep codes.
        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(
            sequence, special_token=unknown_token
        )
        # sanity checks over the returned codes and corresponding masks
        # assert (out_codes[..., :max_gen_len] != unknown_token).all()
        # assert (out_mask[..., :max_gen_len] == 1).all()
        # NOTE(review): `offset` is a pattern-sequence step index used here as
        # a time bound; for delay-style patterns the two differ -- confirm.
        # If the loop body never ran, `offset` is unbound (NameError).
        out_codes = out_codes[..., prompt_codes.shape[-1] : offset]
        return out_codes[[0]]
@torch.inference_mode()
def render(
    self: Vui,
    text: str,
    prompt_codes: Tensor | None = None,
    temperature: float = 0.5,
    top_k: int | None = 100,
    top_p: float | None = None,
    max_secs: int = 100,
):
    """
    Render audio from text.

    The text is cleaned with ``simple_clean``. Cleaned text shorter than 1400
    characters is rendered with a single ``generate`` call capped at
    ``max_secs`` seconds. Longer text is rendered line by line. Each line is
    generated together with the previous line's text, using the previous
    line's codes as the audio prompt, so consecutive chunks stay coherent.

    Args:
        self: the ``Vui`` model.
        text: raw input text.
        prompt_codes: optional audio-prompt codes for the first chunk. It is
            also the fallback prompt when a chunk is retried without context.
        temperature: forwarded to ``generate``.
        top_k: forwarded to ``generate``.
        top_p: forwarded to ``generate``.
        max_secs: duration cap for the single-call path only. The line-by-line
            path sizes each chunk from its text length instead.

    Returns:
        Single-call path: the decoded audio, moved to the CPU.
        Line-by-line path: the per-line audio concatenated on the last dim.

    Raises:
        RuntimeError: if the line-by-line path produced no audio at all.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    text = remove_all_invalid_non_speech(text)
    text = simple_clean(text)
    SR = self.codec.config.sample_rate
    HZ = self.codec.hz
    max_gen_len = int(HZ * max_secs)
    t1 = time.perf_counter()
    if len(text) < 1400:
        codes = generate(
            self, text, prompt_codes, temperature, top_k, top_p, max_gen_len
        )
        # Drop the final 10 code frames before decoding (presumably trailing
        # EOS/silence frames -- confirm).
        codes = codes[..., :-10]
        audio = self.codec.from_indices(codes)
        print("RTF", (audio.numel()/SR)/(time.perf_counter() - t1))
        return audio.cpu()

    # Long text: render line by line, chaining text and audio context.
    orig_codes = prompt_codes
    audios = []
    prev_codes = prompt_codes
    prev_text = ""
    for line in text.split("\n"):
        # A whitespace-only line would just re-render the previous line.
        if not line.strip():
            continue
        # Per-line retry loop. It exits after exactly one successful render,
        # on Ctrl-C (which skips to the next line), or when an attempt without
        # previous-line context has also failed.
        while True:
            current_text = prev_text + "\n" + line if prev_text else line
            current_text = current_text.strip()
            current_text = current_text.replace("...", "")
            current_text = current_text + " [pause]"
            # 60 s of audio per 500 characters (0.12 s/char), truncated to
            # whole seconds, then converted to codec frames.
            maxlen = int(HZ * int(60 * len(current_text) / 500))
            try:
                print("rendering", current_text)
                codes = generate(
                    self,
                    current_text,
                    prompt_codes=prev_codes,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    max_gen_len=maxlen,
                )
                codes = codes[..., :-10]
                paudio = self.codec.from_indices(codes)
            except KeyboardInterrupt:
                break
            except RuntimeError as e:
                print(e)
                if not prev_text and prev_codes is orig_codes:
                    # This attempt already ran without the previous line's
                    # context; retrying it unchanged could loop forever.
                    break
                # Retry this line without the previous line's context.
                prev_codes = orig_codes
                prev_text = ""
                continue
            prev_text = line
            prev_codes = codes
            audios.append(paudio)
            break
    if not audios:
        raise RuntimeError("render produced no audio for any line")
    return torch.cat(audios, dim=-1)
|