Harry Coultas Blum committed
Commit 88afac1 · 1 Parent(s): a2e6acf
app.py ADDED
@@ -0,0 +1,385 @@
1
+ import time
2
+
3
+ import gradio as gr
4
+ import torch
5
+
6
+ from vui.inference import render
7
+ from vui.model import Vui
8
+
9
+
10
+ def get_available_models():
11
+ """Extract all CAPs static variables from Vui class that end with .pt"""
12
+ models = {}
13
+ for attr_name in dir(Vui):
14
+ if attr_name.isupper():
15
+ attr_value = getattr(Vui, attr_name)
16
+ if isinstance(attr_value, str) and attr_value.endswith(".pt"):
17
+ models[attr_name] = attr_value
18
+ return models
19
+
20
+
21
+ AVAILABLE_MODELS = get_available_models()
22
+ print(f"Available models: {list(AVAILABLE_MODELS.keys())}")
23
+
24
+ current_model = None
25
+ current_model_name = None
26
+
27
+
28
+ def load_and_warm_model(model_name):
29
+ """Load and warm up a specific model"""
30
+ global current_model, current_model_name
31
+
32
+ if current_model_name == model_name and current_model is not None:
33
+ print(f"Model {model_name} already loaded and warmed up!")
34
+ return current_model
35
+
36
+ print(f"Loading model {model_name}...")
37
+ model_path = AVAILABLE_MODELS[model_name]
38
+ model = Vui.from_pretrained_inf(model_path).cuda()
39
+
40
+ print(f"Compiling model {model_name}...")
41
+ model.decoder = torch.compile(model.decoder, fullgraph=True)
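+ # torch.compile is lazy: the warmup render below triggers compilation up front instead of on the first user request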
42
+
43
+ print(f"Warming up model {model_name}...")
44
+ warmup_text = "Hello, this is a test. Let's say some random shizz"
45
+ render(
46
+ model,
47
+ warmup_text,
48
+ max_secs=10,
49
+ )
50
+
51
+ current_model = model
52
+ current_model_name = model_name
53
+ print(f"Model {model_name} loaded and warmed up successfully!")
54
+ return model
55
+
56
+
57
+ # Load default model (COHOST)
58
+ default_model = (
59
+ "COHOST" if "COHOST" in AVAILABLE_MODELS else list(AVAILABLE_MODELS.keys())[0]
60
+ )
61
+ model = load_and_warm_model(default_model)
62
+
63
+ # Preload sample 1 (index 0) with current model
64
+ print("Preloading sample 1...")
65
+ sample_1_text = """Welcome to Fluxions, the podcast where... we uh explore how technology is shaping the world around us. I'm your host, Alex.
66
+ [breath] And I'm Jamie um [laugh] today, we're diving into a [hesitate] topic that's transforming customer service uh voice technology for agents.
67
+ That's right. We're [hesitate] talking about the AI-driven tools that are making those long, frustrating customer service calls a little more bearable, for both the customer and the agents."""
68
+
69
+ sample_1_audio = render(
70
+ current_model,
71
+ sample_1_text,
72
+ )
73
+ sample_1_audio = sample_1_audio.cpu()
74
+ sample_1_audio = sample_1_audio[..., :-2000] # Trim end artifacts
75
+ preloaded_sample_1 = (model.codec.config.sample_rate, sample_1_audio.flatten().numpy())
76
+ print("Sample 1 preloaded successfully!")
77
+ print("Models ready for inference!")
78
+
79
+ # Sample texts for quick testing - keeping original examples intact
80
+ SAMPLE_TEXTS = [
81
+ """Welcome to Fluxions, the podcast where... we uh explore how technology is shaping the world around us. I'm your host, Alex.
82
+ [breath] And I'm Jamie um [laugh] today, we're diving into a [hesitate] topic that's transforming customer service uh voice technology for agents.
83
+ That's right. We're [hesitate] talking about the AI-driven tools that are making those long, frustrating customer service calls a little more bearable, for both the customer and the agents.""",
84
+ """Um, hey Sarah, so I just left the meeting with the, uh, rabbit focus group and they are absolutely loving the new heritage carrots! Like, I've never seen such enthusiastic thumping in my life! The purple ones are testing through the roof - apparently the flavor profile is just amazing - and they're willing to pay a premium for them! We need to, like, triple production on those immediately and maybe consider a subscription model? Anyway, gotta go, but let's touch base tomorrow about scaling this before the Easter rush hits!""",
85
+ """What an absolute joke, like I'm really not enjoying this situation where I'm just forced to say things.""",
86
+ """ So [breath] I don't know if you've been there [breath] but I'm really pissed off.
87
+ Oh no! Why, what happened?
88
+ Well I went to this cafe hearth, and they gave me the worst toastie I've ever had, it didn't come with salad it was just raw.
89
+ Well that's awful what kind of toastie was it?
90
+ It was supposed to be a chicken bacon lettuce tomatoe, but it was fucking shite, like really bad and I honestly would have preferred to eat my own shit.
91
+ [laugh] well, it must have been awful for you, I'm sorry to hear that, why don't we move on to brighter topics, like the good old weather?""",
92
+ ]
93
+
94
+
95
+ def text_to_speech(text, temperature=0.5, top_k=100, top_p=None, max_duration=60):
96
+ """
97
+ Convert text to speech using the current Vui model
98
+
99
+ Args:
100
+ text (str): Input text to convert to speech
101
+ temperature (float): Sampling temperature (0.1-1.0)
102
+ top_k (int): Top-k sampling parameter
103
+ top_p (float): Top-p sampling parameter (None to disable)
104
+ max_duration (int): Maximum audio duration in seconds
105
+
106
+ Returns:
107
+ tuple: ((sample_rate, audio_array), info_message) for the Gradio audio and info outputs
108
+ """
109
+ if not text.strip():
110
+ return None, "Please enter some text to convert to speech."
111
+
112
+ if current_model is None:
113
+ return None, "No model loaded. Please select a model first."
114
+
115
+ print(f"Generating speech for: {text[:50]}... using model {current_model_name}")
116
+
117
+ # Generate speech using render
118
+ t1 = time.perf_counter()
119
+ result = render(
120
+ current_model,
121
+ text.strip(),
122
+ temperature=temperature,
123
+ top_k=top_k,
124
+ top_p=top_p,
125
+ max_secs=max_duration,
126
+ )
127
+
128
+ # render returns the decoded waveform tensor
129
+ waveform = result
130
+
131
+ # waveform is already decoded audio; move it to CPU for numpy conversion
132
+ waveform = waveform.cpu()
133
+ sr = current_model.codec.config.sample_rate
134
+
135
+ # Calculate generation speed
136
+ generation_time = time.perf_counter() - t1
137
+ audio_duration = waveform.shape[-1] / sr
138
+ speed_factor = audio_duration / generation_time
139
+
140
+ # Trim end artifacts if needed
141
+ if waveform.shape[-1] > 2000:
142
+ waveform = waveform[..., :-2000]
143
+
144
+ # Convert to numpy array for Gradio
145
+ audio_array = waveform.flatten().numpy()
146
+
147
+ info = f"Generated {audio_duration:.1f}s of audio in {generation_time:.1f}s ({speed_factor:.1f}x realtime) with {current_model_name}"
148
+ print(info)
149
+
150
+ return (sr, audio_array), info
151
+
152
+
153
+ def change_model(model_name):
154
+ """Change the active model and return status"""
155
+ try:
156
+ load_and_warm_model(model_name)
157
+ return f"Successfully loaded and warmed up model: {model_name}"
158
+ except Exception as e:
159
+ return f"Error loading model {model_name}: {str(e)}"
160
+
161
+
162
+ def load_sample_text(sample_index):
163
+ """Load a sample text for quick testing"""
164
+ if 0 <= sample_index < len(SAMPLE_TEXTS):
165
+ return SAMPLE_TEXTS[sample_index]
166
+ return ""
167
+
168
+
169
+ # Create Gradio interface
170
+ with gr.Blocks(
171
+ title="Vui",
172
+ theme=gr.themes.Soft(),
173
+ head="""
174
+ <script>
175
+ document.addEventListener('DOMContentLoaded', function() {
176
+ // Add keyboard shortcuts
177
+ document.addEventListener('keydown', function(e) {
178
+ // Ctrl/Cmd + Enter to generate (but not when Shift is pressed)
179
+ if ((e.ctrlKey || e.metaKey) && e.key === 'Enter' && !e.shiftKey) {
180
+ e.preventDefault();
181
+ const generateBtn = document.querySelector('button[variant="primary"]');
182
+ if (generateBtn && !generateBtn.disabled) {
183
+ generateBtn.click();
184
+ }
185
+ }
186
+ else if ((e.ctrlKey) && e.code === 'Space') {
187
+ e.preventDefault();
188
+ const audioElement = document.querySelector('audio');
189
+ if (audioElement) {
190
+ if (audioElement.paused) {
191
+ audioElement.play();
192
+ } else {
193
+ audioElement.pause();
194
+ }
195
+ }
196
+ }
197
+ });
198
+
199
+ // Auto-play audio when it's updated
200
+ const observer = new MutationObserver(function(mutations) {
201
+ mutations.forEach(function(mutation) {
202
+ if (mutation.type === 'childList') {
203
+ const audioElements = document.querySelectorAll('audio');
204
+ audioElements.forEach(function(audio) {
205
+ if (audio.src && !audio.dataset.hasAutoplayListener) {
206
+ audio.dataset.hasAutoplayListener = 'true';
207
+ audio.addEventListener('loadeddata', function() {
208
+ // Small delay to ensure audio is ready
209
+ setTimeout(() => {
210
+ audio.play().catch(e => {
211
+ console.log('Autoplay prevented by browser:', e);
212
+ });
213
+ }, 100);
214
+ });
215
+ }
216
+ });
217
+ }
218
+ });
219
+ });
220
+
221
+ observer.observe(document.body, {
222
+ childList: true,
223
+ subtree: true
224
+ });
225
+
226
+ });
227
+ </script>
228
+ """,
229
+ ) as demo:
230
+
231
+ gr.Markdown(
232
+ "**Keyboard Shortcuts:** `Ctrl + Enter` to generate` or Ctrl + Space to pause"
233
+ )
234
+
235
+ with gr.Row():
236
+ with gr.Column(scale=2):
237
+ # Model selector
238
+ model_dropdown = gr.Dropdown(
239
+ choices=list(AVAILABLE_MODELS.keys()),
240
+ value=default_model,
241
+ label=None,
242
+ info="Select a voice model",
243
+ )
244
+
245
+ # Model status
246
+ model_status = gr.Textbox(
247
+ label=None,
248
+ value=f"Model {default_model} loaded and ready",
249
+ interactive=False,
250
+ lines=1,
251
+ )
252
+
253
+ # Text input
254
+ text_input = gr.Textbox(
255
+ label=None,
256
+ placeholder="Enter the text you want to convert to speech...",
257
+ lines=5,
258
+ max_lines=10,
259
+ )
260
+
261
+ with gr.Column(scale=1):
262
+ # Audio output with autoplay
263
+ audio_output = gr.Audio(
264
+ label="Generated Speech", type="numpy", autoplay=True # Enable autoplay
265
+ )
266
+
267
+ # Info output
268
+ info_output = gr.Textbox(
269
+ label="Generation Info", lines=3, interactive=False
270
+ )
271
+
272
+ with gr.Row():
273
+ with gr.Column(scale=2):
274
+
275
+ # Sample text buttons
276
+ gr.Markdown("**Quick samples:**")
277
+ with gr.Row():
278
+ sample_btns = []
279
+ for i, sample in enumerate(SAMPLE_TEXTS):
280
+ btn = gr.Button(f"Sample {i+1}", size="sm")
281
+ if i == 0: # Sample 1 (index 0) - use preloaded audio
282
+
283
+ def load_preloaded_sample_1():
284
+ return (
285
+ SAMPLE_TEXTS[0],
286
+ preloaded_sample_1,
287
+ "Preloaded sample 1 audio",
288
+ )
289
+
290
+ btn.click(
291
+ fn=load_preloaded_sample_1,
292
+ outputs=[text_input, audio_output, info_output],
293
+ )
294
+ else:
295
+ btn.click(
296
+ fn=lambda idx=i: SAMPLE_TEXTS[idx], outputs=text_input
297
+ )
298
+
299
+ # Generation parameters
300
+ with gr.Accordion("Advanced Settings", open=False):
301
+ temperature = gr.Slider(
302
+ minimum=0.1,
303
+ maximum=1.0,
304
+ value=0.5,
305
+ step=0.1,
306
+ label="Temperature",
307
+ info="Higher values = more varied speech",
308
+ )
309
+
310
+ top_k = gr.Slider(
311
+ minimum=1,
312
+ maximum=200,
313
+ value=100,
314
+ step=1,
315
+ label="Top-K",
316
+ info="Number of top tokens to consider",
317
+ )
318
+
319
+ use_top_p = gr.Checkbox(label="Use Top-P sampling", value=False)
320
+ top_p = gr.Slider(
321
+ minimum=0.1,
322
+ maximum=1.0,
323
+ value=0.9,
324
+ step=0.05,
325
+ label="Top-P",
326
+ info="Cumulative probability threshold",
327
+ visible=False,
328
+ )
329
+
330
+ max_duration = gr.Slider(
331
+ minimum=5,
332
+ maximum=120,
333
+ value=60,
334
+ step=5,
335
+ label="Max Duration (seconds)",
336
+ info="Maximum length of generated audio",
337
+ )
338
+
339
+ # Show/hide top_p based on checkbox
340
+ use_top_p.change(
341
+ fn=lambda x: gr.update(visible=x), inputs=use_top_p, outputs=top_p
342
+ )
343
+
344
+ # Generate button
345
+ generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
346
+
347
+ # Examples section
348
+ gr.Markdown("## 📝 Example Texts")
349
+ with gr.Accordion("View example texts", open=False):
350
+ for i, sample in enumerate(SAMPLE_TEXTS):
351
+ gr.Markdown(f"**Sample {i+1}:** {sample}")
352
+
353
+ # Connect the model change function
354
+ model_dropdown.change(fn=change_model, inputs=model_dropdown, outputs=model_status)
355
+
356
+ # Connect the generate function
357
+ def generate_wrapper(text, temp, k, use_p, p, duration):
358
+ top_p_val = p if use_p else None
359
+ return text_to_speech(text, temp, k, top_p_val, duration)
360
+
361
+ generate_btn.click(
362
+ fn=generate_wrapper,
363
+ inputs=[text_input, temperature, top_k, use_top_p, top_p, max_duration],
364
+ outputs=[audio_output, info_output],
365
+ )
366
+
367
+ # Also allow Enter key to generate
368
+ text_input.submit(
369
+ fn=generate_wrapper,
370
+ inputs=[text_input, temperature, top_k, use_top_p, top_p, max_duration],
371
+ outputs=[audio_output, info_output],
372
+ )
373
+
374
+ # Auto-load sample 1 on startup
375
+ demo.load(
376
+ fn=lambda: (
377
+ SAMPLE_TEXTS[0],
378
+ preloaded_sample_1,
379
+ "Sample 1 preloaded and ready!",
380
+ ),
381
+ outputs=[text_input, audio_output, info_output],
382
+ )
383
+
384
+ if __name__ == "__main__":
385
+ demo.launch(server_name="0.0.0.0", share=True)
inference.py ADDED
@@ -0,0 +1,12 @@
1
+ import torchaudio
2
+
3
+ from vui.inference import render
4
+ from vui.model import Vui
5
+
6
+ model = Vui.from_pretrained().cuda()
7
+ waveform = render(
8
+ model,
9
+ "Hey, here is some random stuff, usually something quite long as the shorter the text the less likely the model can cope!",
10
+ )
11
+ print(waveform.shape)
12
+ torchaudio.save("out.opus", waveform[0], 22050)
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ einops
2
+ huggingface_hub[hf_transfer]
3
+ inflect
4
+ gradio
5
+ numba
6
+ numpy
7
+ openai-whisper
8
+ feedparser
9
+ pydantic
10
+ pyannote.audio
11
+ soundfile
12
+ sphn
13
+ tiktoken
14
+ torch
15
+ torchaudio
16
+ tqdm
17
+ transformers
src/vui/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
src/vui/config.py ADDED
@@ -0,0 +1,41 @@
1
+ import sys
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class VuiConfig(BaseModel):
7
+ max_text_tokens: int = 100
8
+ text_size: int = -1
9
+ max_audio_tokens: int = 100
10
+
11
+ n_quantizers: int = 9
12
+ codebook_size: int = 1000
13
+ special_token_id: int = 1000
14
+ audio_eos_id: int = 1000 + 1
15
+ audio_pad_id: int = 1000 + 1 + 1
16
+ d_model: int = 512
17
+ n_layers: int = 6
18
+ n_heads: int = 8
19
+ bias: bool = False
20
+ dropout: float = 0.0
21
+ use_rotary_emb: bool = True
22
+ rope_dim: int | None = None
23
+ rope_theta: float = 10_000.0
24
+ rope_theta_rescale_factor: float = 1.0
25
+
26
+
27
+ class Config(BaseModel):
28
+ name: str = "base"
29
+
30
+ checkpoint: str | dict | None = None
31
+
32
+ model: VuiConfig = VuiConfig()
33
+
34
+
35
+ ALL = []
36
+ current_module = sys.modules[__name__]
37
+ for name in dir(current_module):
38
+ if name.isupper() and isinstance(getattr(current_module, name), Config):
39
+ ALL.append(getattr(current_module, name))
40
+
41
+ CONFIGS = {v.name: v for v in ALL}
src/vui/fluac.py ADDED
@@ -0,0 +1,707 @@
1
+ import math
2
+ from contextlib import nullcontext
3
+ from functools import partial, wraps
4
+ from os import path
5
+ from typing import List, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import pack, rearrange, unpack
11
+ from einops.layers.torch import Rearrange
12
+ from pydantic import BaseModel
13
+ from torch import Tensor, int32
14
+ from torch.amp import autocast
15
+ from torch.nn import Module
16
+ from torch.nn.utils.parametrizations import weight_norm
17
+
18
+ from vui.utils import decompile_state_dict
19
+
20
+
21
+ def exists(v):
22
+ return v is not None
23
+
24
+
25
+ def default(*args):
26
+ for arg in args:
27
+ if exists(arg):
28
+ return arg
29
+ return None
30
+
31
+
32
+ def maybe(fn):
33
+ @wraps(fn)
34
+ def inner(x, *args, **kwargs):
35
+ if not exists(x):
36
+ return x
37
+ return fn(x, *args, **kwargs)
38
+
39
+ return inner
40
+
41
+
42
+ def pack_one(t, pattern):
43
+ return pack([t], pattern)
44
+
45
+
46
+ def unpack_one(t, ps, pattern):
47
+ return unpack(t, ps, pattern)[0]
48
+
49
+
50
+ def round_ste(z: Tensor) -> Tensor:
51
+ """Round with straight through gradients."""
52
+ zhat = z.round()
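+ # Straight-through estimator: forward returns the rounded value, while (zhat - z) is detached so gradients flow through z unchanged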
53
+ return z + (zhat - z).detach()
54
+
55
+
56
+ class FSQ(Module):
57
+ def __init__(
58
+ self,
59
+ levels: List[int],
60
+ dim: int | None = None,
61
+ num_codebooks: int = 1,
62
+ keep_num_codebooks_dim: bool | None = None,
63
+ allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
64
+ channel_first: bool = True,
65
+ projection_has_bias: bool = True,
66
+ return_indices=True,
67
+ force_quantization_f32: bool = True,
68
+ ):
69
+ super().__init__()
70
+
71
+ _levels = torch.tensor(levels, dtype=int32)
72
+ self.register_buffer("_levels", _levels, persistent=False)
73
+
74
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
75
+ self.register_buffer("_basis", _basis, persistent=False)
76
+
77
+ codebook_dim = len(levels)
78
+ self.codebook_dim = codebook_dim
79
+
80
+ effective_codebook_dim = codebook_dim * num_codebooks
81
+ self.num_codebooks = num_codebooks
82
+ self.effective_codebook_dim = effective_codebook_dim
83
+
84
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
85
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
86
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
87
+
88
+ self.dim = default(dim, len(_levels) * num_codebooks)
89
+
90
+ self.channel_first = channel_first
91
+
92
+ has_projections = self.dim != effective_codebook_dim
93
+ self.project_in = (
94
+ nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
95
+ if has_projections
96
+ else nn.Identity()
97
+ )
98
+ self.project_out = (
99
+ nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
100
+ if has_projections
101
+ else nn.Identity()
102
+ )
103
+
104
+ self.has_projections = has_projections
105
+
106
+ self.return_indices = return_indices
107
+ if return_indices:
108
+ self.codebook_size = self._levels.prod().item()
109
+ implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
110
+ self.register_buffer(
111
+ "implicit_codebook", implicit_codebook, persistent=False
112
+ )
113
+
114
+ self.allowed_dtypes = allowed_dtypes
115
+ self.force_quantization_f32 = force_quantization_f32
116
+
117
+ def bound(self, z, eps: float = 1e-3):
118
+ """Bound `z`, an array of shape (..., d)."""
119
+ half_l = (self._levels - 1) * (1 + eps) / 2
120
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
121
+ shift = (offset / half_l).atanh()
122
+ return (z + shift).tanh() * half_l - offset
123
+
124
+ def quantize(self, z):
125
+ """Quantizes z, returns quantized zhat, same shape as z."""
126
+ quantized = round_ste(self.bound(z))
127
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
128
+ return quantized / half_width
129
+
130
+ def _scale_and_shift(self, zhat_normalized):
131
+ half_width = self._levels // 2
132
+ return (zhat_normalized * half_width) + half_width
133
+
134
+ def _scale_and_shift_inverse(self, zhat):
135
+ half_width = self._levels // 2
136
+ return (zhat - half_width) / half_width
137
+
138
+ def _indices_to_codes(self, indices):
139
+ level_indices = self.indices_to_level_indices(indices)
140
+ codes = self._scale_and_shift_inverse(level_indices)
141
+ return codes
142
+
143
+ def codes_to_indices(self, zhat):
144
+ """Converts a `code` to an index in the codebook."""
145
+ assert zhat.shape[-1] == self.codebook_dim
146
+ zhat = self._scale_and_shift(zhat)
147
+ return (zhat * self._basis).sum(dim=-1).to(int32)
148
+
149
+ def indices_to_level_indices(self, indices):
150
+ """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
151
+ indices = rearrange(indices, "... -> ... 1")
152
+ codes_non_centered = (indices // self._basis) % self._levels
153
+ return codes_non_centered
154
+
155
+ def indices_to_codes(self, indices):
156
+ """Inverse of `codes_to_indices`."""
157
+ assert exists(indices)
158
+
159
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
160
+
161
+ codes = self._indices_to_codes(indices)
162
+
163
+ if self.keep_num_codebooks_dim:
164
+ codes = rearrange(codes, "... c d -> ... (c d)")
165
+
166
+ codes = self.project_out(codes)
167
+
168
+ if is_img_or_video or self.channel_first:
169
+ codes = rearrange(codes, "b ... d -> b d ...")
170
+
171
+ return codes
172
+
173
+ def forward(self, z: Tensor):
174
+ """
175
+ einstein notation
176
+ b - batch
177
+ n - sequence (or flattened spatial dimensions)
178
+ d - feature dimension
179
+ c - number of codebook dim
180
+ """
181
+ device_type = z.device.type
182
+
183
+ with torch.autocast(device_type=device_type, enabled=False):
184
+ if self.channel_first:
185
+ z = rearrange(z, "b d ... -> b ... d")
186
+ z, ps = pack_one(z, "b * d")
187
+
188
+ assert (
189
+ z.shape[-1] == self.dim
190
+ ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
191
+
192
+ z = self.project_in(z)
193
+
194
+ z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
195
+
196
+ # whether to force quantization step to be full precision or not
197
+
198
+ force_f32 = self.force_quantization_f32
199
+ quantization_context = (
200
+ partial(autocast, device_type=device_type, enabled=False)
201
+ if force_f32
202
+ else nullcontext
203
+ )
204
+
205
+ with quantization_context():
206
+ orig_dtype = z.dtype
207
+
208
+ if force_f32 and orig_dtype not in self.allowed_dtypes:
209
+ z = z.float()
210
+
211
+ codes = self.quantize(z)
212
+
213
+ # returning indices could be optional
214
+
215
+ indices = None
216
+
217
+ if self.return_indices:
218
+ indices = self.codes_to_indices(codes)
219
+
220
+ codes = rearrange(codes, "b n c d -> b n (c d)")
221
+
222
+ codes = codes.type(orig_dtype)
223
+
224
+ # project out
225
+
226
+ out = self.project_out(codes)
227
+
228
+ # reconstitute image or video dimensions
229
+
230
+ if self.channel_first:
231
+ out = unpack_one(out, ps, "b * d")
232
+ out = rearrange(out, "b ... d -> b d ...")
233
+
234
+ indices = maybe(unpack_one)(indices, ps, "b * c")
235
+
236
+ if not self.keep_num_codebooks_dim and self.return_indices:
237
+ indices = maybe(rearrange)(indices, "... 1 -> ...")
238
+
239
+ # return quantized output and indices
240
+
241
+ return out, indices
242
+
243
+
244
+ def WNConv1d(*args, **kwargs):
245
+ return weight_norm(nn.Conv1d(*args, **kwargs))
246
+
247
+
248
+ def WNConvTranspose1d(*args, **kwargs):
249
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
250
+
251
+
252
+ # Scripting this brings model speed up 1.4x
253
+ @torch.jit.script
254
+ def snake(x, alpha):
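+ # Snake activation: x + (1/alpha) * sin^2(alpha * x), computed on a (B, C, -1) flattened view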
255
+ shape = x.shape
256
+ x = x.reshape(shape[0], shape[1], -1)
257
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
258
+ x = x.reshape(shape)
259
+ return x
260
+
261
+
262
+ class Snake1d(nn.Module):
263
+ def __init__(self, channels):
264
+ super().__init__()
265
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
266
+
267
+ def forward(self, x):
268
+ return snake(x, self.alpha)
269
+
270
+
271
+ def init_weights(m):
272
+ if isinstance(m, nn.Conv1d):
273
+ nn.init.trunc_normal_(m.weight, std=0.02)
274
+ nn.init.constant_(m.bias, 0)
275
+
276
+
277
+ class ResidualUnit(nn.Module):
278
+ def __init__(self, dim: int = 16, dilation: int = 1):
279
+ super().__init__()
280
+ pad = ((7 - 1) * dilation) // 2
281
+ self.block = nn.Sequential(
282
+ Snake1d(dim),
283
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
284
+ Snake1d(dim),
285
+ WNConv1d(dim, dim, kernel_size=1),
286
+ )
287
+
288
+ def forward(self, x):
289
+ y = self.block(x)
290
+ pad = (x.shape[-1] - y.shape[-1]) // 2
291
+ if pad > 0:
292
+ x = x[..., pad:-pad]
293
+ return x + y
294
+
295
+
296
+ class EncoderBlock(nn.Module):
297
+ def __init__(self, dim: int = 16, stride: int = 1):
298
+ super().__init__()
299
+ self.block = nn.Sequential(
300
+ ResidualUnit(dim // 2, dilation=1),
301
+ ResidualUnit(dim // 2, dilation=3),
302
+ ResidualUnit(dim // 2, dilation=9),
303
+ Snake1d(dim // 2),
304
+ WNConv1d(
305
+ dim // 2,
306
+ dim,
307
+ kernel_size=2 * stride,
308
+ stride=stride,
309
+ padding=math.ceil(stride / 2),
310
+ ),
311
+ )
312
+
313
+ def forward(self, x):
314
+ return self.block(x)
315
+
316
+
317
+ class Encoder(nn.Module):
318
+ def __init__(
319
+ self,
320
+ d_model: int = 64,
321
+ strides: list = [2, 4, 8, 8],
322
+ d_latent: int = 64,
323
+ ):
324
+ super().__init__()
325
+ # Create first convolution
326
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
327
+
328
+ # Create EncoderBlocks that double channels as they downsample by `stride`
329
+ for stride in strides:
330
+ d_model *= 2
331
+ self.block += [EncoderBlock(d_model, stride=stride)]
332
+
333
+ # Create last convolution
334
+ self.block += [
335
+ Snake1d(d_model),
336
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
337
+ ]
338
+
339
+ # Wrap block into nn.Sequential
340
+ self.block = nn.Sequential(*self.block)
341
+ self.enc_dim = d_model
342
+
343
+ def forward(self, x):
344
+ return self.block(x)
345
+
346
+
347
+ class DecoderBlock(nn.Module):
348
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
349
+ super().__init__()
350
+ self.block = nn.Sequential(
351
+ Snake1d(input_dim),
352
+ WNConvTranspose1d(
353
+ input_dim,
354
+ output_dim,
355
+ kernel_size=2 * stride,
356
+ stride=stride,
357
+ padding=math.ceil(stride / 2),
358
+ ),
359
+ ResidualUnit(output_dim, dilation=1),
360
+ ResidualUnit(output_dim, dilation=3),
361
+ ResidualUnit(output_dim, dilation=9),
362
+ )
363
+
364
+ def forward(self, x):
365
+ return self.block(x)
366
+
367
+
368
+ class Decoder(nn.Module):
369
+ def __init__(
370
+ self,
371
+ input_channel: int,
372
+ channels: int,
373
+ rates: list[int],
374
+ d_out: int = 1,
375
+ ):
376
+ super().__init__()
377
+
378
+ # Add first conv layer
379
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
380
+
381
+ # Add upsampling + MRF blocks
382
+ for i, stride in enumerate(rates):
383
+ input_dim = channels // 2**i
384
+ output_dim = channels // 2 ** (i + 1)
385
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
386
+
387
+ # Add final conv layer
388
+ layers += [
389
+ Snake1d(output_dim),
390
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
391
+ nn.Tanh(),
392
+ ]
393
+
394
+ self.model = nn.Sequential(*layers)
395
+
396
+ # @torch.compile(dynamic=True)
397
+ def forward(self, z: Tensor):
398
+ return self.model(z)
399
+
400
+
401
+ class FiniteScalarQuantize(nn.Module):
402
+ def __init__(
403
+ self, latent_dim: int, levels: list[int], *, stride: int = 1, mlp: bool = False
404
+ ):
405
+ super().__init__()
406
+
407
+ self.stride = stride
408
+
409
+ codebook_dim = len(levels)
410
+
411
+ self.in_proj = WNConv1d(latent_dim, codebook_dim, kernel_size=1)
412
+ self.quantize = FSQ(levels=levels, channel_first=True)
413
+ self.out_proj = WNConv1d(codebook_dim, latent_dim, kernel_size=1)
414
+
415
+ if mlp:
416
+ self.mlp = nn.Sequential(
417
+ Rearrange("B C T -> B T C"),
418
+ nn.Linear(latent_dim, 4 * latent_dim),
419
+ nn.GELU(),
420
+ nn.Linear(4 * latent_dim, latent_dim),
421
+ Rearrange("B T C -> B C T"),
422
+ )
423
+ else:
424
+ self.mlp = None
425
+
426
+ def from_indices(self, indices: Tensor):
427
+ B, T = indices.size()
428
+ z_q = self.quantize.indices_to_codes(indices)
429
+ z_q = self.out_proj(z_q)
430
+ return z_q
431
+
432
+ def forward(self, z: Tensor, *args):
433
+ if self.stride > 1:
434
+ z = F.avg_pool1d(z, self.stride, stride=self.stride)
435
+
436
+ z_e = self.in_proj(z) # z_e : (B x D x T)
437
+
438
+ # we're channels first
439
+ # scale = scale.unsqueeze(-1)
440
+
441
+ # z_e = z_e / scale
442
+ z_q, indices = self.quantize(z_e)
443
+ # z_q = z_q * scale
444
+
445
+ z_q = self.out_proj(z_q)
446
+
447
+ if self.stride > 1:
448
+ z_e = z_e.repeat_interleave(self.stride, dim=-1)
449
+ z_q = z_q.repeat_interleave(self.stride, dim=-1)
450
+ indices = indices.repeat_interleave(self.stride, dim=-1)
451
+
452
+ if self.mlp is not None:
453
+ z_q = self.mlp(z_q)
454
+
455
+ return z_q, indices, z_e
456
+
457
+
458
+ class ResidualFiniteScalarQuantize(nn.Module):
459
+ def __init__(
460
+ self,
461
+ *,
462
+ latent_dim: int,
463
+ n_quantizers: int,
464
+ levels: list[int],
465
+ strides: list[int] | None = None,
466
+ quantizer_dropout: float = 0.0,
467
+ mlp: bool = False,
468
+ ):
469
+ super().__init__()
470
+
471
+ self.n_quantizers = n_quantizers
472
+ self.quantizer_dropout = quantizer_dropout
473
+
474
+ strides = [1] * n_quantizers if strides is None else strides
475
+
476
+ assert (
477
+ len(strides) == n_quantizers
478
+ ), "Strides must be provided for each codebook"
479
+
480
+ scales = []
481
+ quantizers = []
482
+ levels_tensor = torch.tensor(levels, dtype=torch.float32)
483
+
484
+ for i in range(n_quantizers):
485
+ scales.append((levels_tensor - 1) ** -i)
486
+ quantizers.append(
487
+ FiniteScalarQuantize(
488
+ latent_dim=latent_dim, levels=levels, stride=strides[i], mlp=mlp
489
+ )
490
+ )
491
+
492
+ self.quantizers = nn.ModuleList(quantizers)
493
+
494
+ self.register_buffer("scales", torch.stack(scales), persistent=False)
495
+
496
+ codebooks = [
497
+ quantizer.quantize.implicit_codebook for quantizer in self.quantizers
498
+ ]
499
+ self.codebooks = torch.stack(codebooks, dim=0)
500
+
501
+ def from_indices(self, indices: Tensor):
502
+ B, Q, T = indices.size()
503
+
504
+ z_q = 0.0
505
+
506
+ for i, quantizer in enumerate(self.quantizers):
507
+ z_q_i = quantizer.from_indices(indices[:, i])
508
+ z_q = z_q + z_q_i
509
+
510
+ return z_q
511
+
512
+ def forward(self, z: Tensor, n_quantizers: int | None = None):
513
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
514
+ the corresponding codebook vectors
515
+ Parameters
516
+ ----------
517
+ z : Tensor[B x D x T]
518
+ n_quantizers : int, optional
519
+ No. of quantizers to use
520
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
521
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
522
+ when in training mode, and a random number of quantizers is used.
523
+ Returns
524
+ -------
525
+ tuple
526
+ A tuple (z, codes, latents) where:
527
+
528
+ "z" : Tensor[B x D x T]
529
+ Quantized continuous representation of input
530
+ "codes" : Tensor[B x N x T]
531
+ Codebook indices for each codebook
532
+ (quantized discrete representation of input)
533
+ "latents" : Tensor[B x N*D x T]
534
+ Projected latents (continuous representation of input before quantization)
535
+ """
536
+ B = z.shape[0]
537
+ z_q = 0
538
+ residual = z
539
+
540
+ indices = []
541
+ latents = []
542
+
543
+ if n_quantizers is None:
544
+ n_quantizers = self.n_quantizers
545
+
546
+ if self.training:
547
+ n_quantizers = torch.ones((B,)) * self.n_quantizers + 1
548
+ dropout = torch.randint(1, self.n_quantizers + 1, (B,))
549
+ n_dropout = int(B * self.quantizer_dropout)
550
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
551
+ n_quantizers = n_quantizers.to(z.device)
552
+
553
+ for i, quantizer in enumerate(self.quantizers):
554
+ if not self.training and i >= n_quantizers:
555
+ break
556
+
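+ # Residual quantization: each stage quantizes what the previous stages could not reconstruct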
557
+ z_q_i, indices_i, z_e_i = quantizer(residual)
558
+
559
+ residual = residual - z_q_i.detach()
560
+
561
+ mask = torch.full((B,), fill_value=i, device=z.device) < n_quantizers
562
+ z_q = z_q + z_q_i * mask[:, None, None]
563
+
564
+ indices.append(indices_i)
565
+ latents.append(z_e_i)
566
+
567
+ indices = torch.stack(indices, dim=1)
568
+ latents = torch.cat(latents, dim=1)
569
+
570
+ return z_q, indices, latents
571
+
572
+
573
+ class FluacConfig(BaseModel):
574
+ sample_rate: int = 44100
575
+
576
+ codebook_size: int | None = None
577
+
578
+ encoder_dim: int = 64
579
+ encoder_rates: list[int] = [2, 4, 8, 8]
580
+
581
+ quantizer_strides: list[int] | None = None # SNAC style strides
582
+ n_quantizers: int = 1
583
+ fsq_levels: list[int] | None = [8, 5, 5, 5] # 1000
584
+ decoder_dim: int = 1536
585
+ decoder_rates: list[int] = [8, 8, 4, 2]
586
+
587
+ @property
588
+ def hop_length(self) -> int:
589
+ return math.prod(self.encoder_rates)
590
+
591
+ @property
592
+ def latent_dim(self) -> int:
593
+ return self.encoder_dim * (2 ** len(self.encoder_rates))
594
+
595
+ @property
596
+ def effective_codebook_size(self) -> int:
597
+ return math.prod(self.fsq_levels)
598
+
599
+
600
+ class Fluac(nn.Module):
601
+ Q9_22KHZ = "fluac-22hz-22khz.pt"
602
+
603
+ def __init__(self, config: FluacConfig):
604
+ super().__init__()
605
+
606
+ self.config = config
607
+
608
+ self.encoder = Encoder(
609
+ config.encoder_dim, config.encoder_rates, config.latent_dim
610
+ )
611
+
612
+ self.quantizer = ResidualFiniteScalarQuantize(
613
+ latent_dim=config.latent_dim,
614
+ n_quantizers=config.n_quantizers,
615
+ levels=config.fsq_levels,
616
+ strides=config.quantizer_strides,
617
+ )
618
+
619
+ self.decoder = Decoder(
620
+ config.latent_dim,
621
+ config.decoder_dim,
622
+ config.decoder_rates,
623
+ )
624
+
625
+ self.apply(init_weights)
626
+
627
+ @staticmethod
628
+ def from_pretrained(name: str = Q9_22KHZ):
629
+ if path.exists(name):
630
+ checkpoint_path = name
631
+ else:
632
+ from huggingface_hub import hf_hub_download
633
+
634
+ checkpoint_path = hf_hub_download(
635
+ "fluxions/vui",
636
+ name,
637
+ )
638
+
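+ # Registry of configs: collect any upper-case Config instances defined in this module, keyed by name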
639
+ checkpoint = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
640
+ config = checkpoint["config"]
641
+ if "model" in config:
642
+ model_config = FluacConfig(**config["model"])
643
+ else:
644
+ model_config = FluacConfig(**config)
645
+
646
+ generator = Fluac(model_config).eval()
647
+ ckpt = decompile_state_dict(checkpoint["generator"])
648
+ generator.load_state_dict(ckpt)
649
+ return generator
650
+
651
+ def pad(self, waveform: Tensor):
652
+ T = waveform.size(-1)
653
+ right_pad = math.ceil(T / self.config.hop_length) * self.config.hop_length - T
654
+ waveform = F.pad(waveform, (0, right_pad))
655
+ return waveform
656
+
657
+ @torch.inference_mode()
658
+ def from_indices(self, indices: Tensor):
659
+ z_q = self.quantizer.from_indices(indices)
660
+ waveform = self.decoder(z_q)
661
+ return waveform
662
+
663
+ @torch.inference_mode()
664
+ def encode(self, waveforms: Tensor, n_quantizers: int | None = None):
665
+ # Ensure that waveforms is 3-dimensional: (B, C, T)
666
+ waveforms = waveforms.flatten()[None][None]
667
+ waveforms = self.pad(waveforms)
668
+ B, C, T = waveforms.size()
669
+ z = self.encoder(waveforms)
670
+ z_q, codes, latents = self.quantizer(z, n_quantizers=n_quantizers)
671
+ return codes
672
+
673
+ def forward(self, waveforms: Tensor, n_quantizers: int | None = None):
674
+ B, C, T = waveforms.size()
675
+ waveforms = self.pad(waveforms)
676
+ z = self.encoder(waveforms)
677
+ z_q, codes, latents = self.quantizer(z, n_quantizers=n_quantizers)
678
+
679
+ recons = self.decoder(z_q)
680
+ recons = recons[..., :T]
681
+ return {
682
+ "recons": recons,
683
+ "codes": codes,
684
+ }
685
+
686
+ @property
687
+ def device(self):
688
+ return next(self.parameters()).device
689
+
690
+ @property
691
+ def dtype(self):
692
+ return next(self.parameters()).dtype
693
+
694
+ @property
695
+ def hz(self):
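+ # Codec frame rate: sample_rate divided by the total encoder downsampling factor (hop length)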
696
+ import numpy as np
697
+
698
+ return self.config.sample_rate / np.prod(self.config.encoder_rates).item()
699
+
700
+
701
+ if __name__ == "__main__":
702
+ codec = Fluac.from_pretrained(Fluac.Q9_22KHZ)
703
+ print(codec.config)
704
+ wav = torch.rand(1, 1, 22050)
705
+ wav = codec.pad(wav)
706
+ codes = codec.encode(wav)
707
+ breakpoint()
src/vui/inference.py ADDED
@@ -0,0 +1,405 @@
1
+ import re
2
+ import time
3
+
4
+ import inflect
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ from torch import Tensor
9
+ from torch.nn.attention import SDPBackend, sdpa_kernel
10
+
11
+ from vui.model import Vui
12
+ from vui.sampling import multinomial, sample_top_k, sample_top_p, sample_top_p_top_k
13
+ from vui.utils import timer
14
+ from vui.vad import detect_voice_activity as vad
15
+
16
+
17
+ def ensure_spaces_around_tags(text: str):
18
+ # Add space before '[' if not preceded by space, '<', or '['
19
+ text = re.sub(
20
+ r"(?<![<\[\s])(\[)",
21
+ lambda m: (
22
+ f"\n{m.group(1)}"
23
+ if m.start() > 0 and text[m.start() - 1] == "\n"
24
+ else f" {m.group(1)}"
25
+ ),
26
+ text,
27
+ )
28
+ # Add space after ']' if not preceded by digit+']' and not followed by space, '>', or ']'
29
+ text = re.sub(
30
+ r"(?<!\d\])(\])(?![>\]\s])",
31
+ lambda m: (
32
+ f"{m.group(1)}\n"
33
+ if m.end() < len(text) and text[m.end()] == "\n"
34
+ else f"{m.group(1)} "
35
+ ),
36
+ text,
37
+ )
38
+ text = text.strip()
39
+ return text
40
+
41
+
42
+ REPLACE = [
43
+ ("—", ","),
44
+ ("'", "'"),
45
+ (":", ","),
46
+ (";", ","),
47
+ ]
48
+
49
+ engine = None
50
+ wm = None
51
+
52
+
53
+ def asr(chunk, model=None, prefix=None):
54
+ import whisper
55
+
56
+ global wm
57
+ if model is not None:
58
+ wm = model
59
+ elif wm is None:
60
+ wm = whisper.load_model("turbo", "cuda")
61
+
62
+ """Process audio with VAD and transcribe"""
63
+ chunk = whisper.pad_or_trim(chunk)
64
+ mel = whisper.log_mel_spectrogram(chunk, n_mels=wm.dims.n_mels).to(wm.device)
65
+ options = whisper.DecodingOptions(
66
+ language="en", without_timestamps=True, prefix=prefix
67
+ )
68
+ result = whisper.decode(wm, mel[None], options)
69
+ return result[0].text
70
+
71
+
72
+ def replace_numbers_with_words(text):
73
+ global engine
74
+
75
+ if engine is None:
76
+ engine = inflect.engine()
77
+
78
+ # Function to convert a number match to words
79
+ def number_to_words(match):
80
+ number = match.group()
81
+ return engine.number_to_words(number) + " "
82
+
83
+ # Replace digits with their word equivalents
84
+ return re.sub(r"\d+", number_to_words, text)
85
+
86
+
87
+ valid_non_speech = ["breath", "sigh", "laugh", "tut", "hesitate"]
88
+ valid_non_speech = [f"[{v}]" for v in valid_non_speech]
89
+
90
+
91
+ def remove_all_invalid_non_speech(txt):
92
+ """
93
+ Remove all non-speech markers that are not in the valid_non_speech list.
94
+ Only keeps valid non-speech markers like [breath], [sigh], etc.
95
+ """
96
+ # Find all text within square brackets
97
+ bracket_pattern = r"\[([^\]]+)\]"
98
+ brackets = re.findall(bracket_pattern, txt)
99
+
100
+ # For each bracketed text, check if it's in our valid list
101
+ for bracket in brackets:
102
+ bracket_with_brackets = f"[{bracket}]"
103
+ if bracket_with_brackets not in valid_non_speech and bracket != "pause":
104
+ # If not valid, remove it from the text
105
+ txt = txt.replace(bracket_with_brackets, "")
106
+
107
+ return txt
108
+
109
+
110
+ def simple_clean(text):
111
+ text = re.sub(r"(\d+)am", r"\1 AM", text)
112
+ text = re.sub(r"(\d+)pm", r"\1 PM", text)
113
+ text = replace_numbers_with_words(text)
114
+ text = ensure_spaces_around_tags(text)
115
+ text = remove_all_invalid_non_speech(text)
116
+
117
+ text = text.replace('"', "")
118
+ text = text.replace("”", "")
119
+ text = text.replace("“", "")
120
+ text = text.replace("’", "'")
121
+ text = text.replace("%", " percent")
122
+ text = text.replace("*", "")
123
+ text = text.replace("(", "")
124
+ text = text.replace(")", "")
125
+ text = text.replace(";", "")
126
+ text = text.replace("–", " ")
127
+ text = text.replace("—", "")
128
+ text = text.replace(":", "")
129
+ text = text.replace("…", "...")
130
+ text = text.replace("s...", "s")
131
+
132
+ # replace repeating \n with just one \n
133
+ text = re.sub(r"\n+", "\n", text)
134
+ ntxt = re.sub(r" +", " ", text)
135
+
136
+ # Ensure that ntxt ends with . or ?
137
+ ntxt = ntxt.strip()
138
+ if not (ntxt.endswith(".") or ntxt.endswith("?")):
139
+ ntxt += "."
140
+ ntxt += " [pause]"
141
+ return ntxt
142
+
143
+
144
+ @torch.inference_mode()
145
+ def generate(
146
+ self: Vui,
147
+ text: str,
148
+ prompt_codes: Tensor | None = None,
149
+ temperature: float = 0.5,
150
+ top_k: int | None = 150,
151
+ top_p: float | None = None,
152
+ max_gen_len: int = int(120 * 21.53),
153
+ ):
154
+ text = simple_clean(text)
155
+ with (
156
+ torch.autocast("cuda", torch.bfloat16, True),
157
+ sdpa_kernel([SDPBackend.MATH]),
158
+ timer("generate"),
159
+ ):
160
+ t1 = time.perf_counter()
161
+ batch_size = 1
162
+ device = self.device
163
+ self.dtype
164
+ self.decoder.allocate_inference_cache(batch_size, device, torch.bfloat16)
165
+
166
+ texts = [text]
167
+
168
+ encoded = self.tokenizer(
169
+ texts,
170
+ padding="longest",
171
+ return_tensors="pt",
172
+ )
173
+
174
+ input_ids = encoded.input_ids.to(device)
175
+ text_embeddings = self.token_emb(input_ids)
176
+
177
+ B = batch_size
178
+ Q = self.config.model.n_quantizers
179
+
180
+ if prompt_codes is None:
181
+ prompt_codes = torch.zeros(
182
+ (batch_size, Q, 0), dtype=torch.int64, device=device
183
+ )
184
+ else:
185
+ prompt_codes = prompt_codes[:, :Q].repeat(batch_size, 1, 1)
186
+
187
+ start_offset = prompt_codes.size(-1)
188
+
189
+ pattern = self.pattern_provider.get_pattern(max_gen_len)
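+ # Delayed codebook pattern: each quantizer stream is shifted so later codebooks are predicted a step after earlier ones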
190
+ # this token is used as default value for codes that are not generated yet
191
+ unknown_token = -1
192
+ special_token_id = self.config.model.special_token_id
193
+
194
+ # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
195
+ codes = torch.full(
196
+ (B, Q, max_gen_len), unknown_token, dtype=torch.int64, device=device
197
+ )
198
+
199
+ codes[:, :, :start_offset] = prompt_codes
200
+
201
+ sequence, indexes, mask = pattern.build_pattern_sequence(
202
+ codes, special_token_id
203
+ )
204
+ # retrieve the start_offset in the sequence:
205
+ # it is the first sequence step that contains the `start_offset` timestep
206
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
207
+ assert start_offset_sequence is not None
208
+
209
+ prev_offset = 0
210
+ S = sequence.size(-1)
211
+
212
+ do_prefill = True
213
+ eos = self.config.model.audio_eos_id
214
+
215
+ for offset in range(start_offset_sequence, S):
216
+ # print(f"{prev_offset}:{offset}")
217
+ curr_sequence = sequence[..., prev_offset:offset]
218
+ audio_embeddings = (
219
+ sum([self.audio_embeddings[q](curr_sequence[:, q]) for q in range(Q)])
220
+ / Q
221
+ )
222
+
223
+ if do_prefill:
224
+ embeddings = torch.cat((text_embeddings, audio_embeddings), dim=1)
225
+ T = embeddings.size(1)
226
+ input_pos = torch.arange(0, T, device=device)
227
+ do_prefill = False
228
+ else:
229
+ embeddings = audio_embeddings
230
+ input_pos = torch.tensor([T], device=device)
231
+ T += 1
232
+
233
+ out = self.decoder(embeddings, input_pos)
234
+
235
+ if offset == 15:
236
+ print("TTFB", time.perf_counter() - t1)
237
+
238
+ logits = torch.stack(
239
+ [self.audio_heads[q](out[:, -1]) for q in range(Q)], dim=1
240
+ )
241
+
242
+ repetition_penalty = 1.4
243
+ history_window = 12
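+ # Penalize tokens seen in the last history_window frames to discourage audible looping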
244
+
245
+ # Get the history of generated tokens for each quantizer
246
+ for q in range(Q):
247
+ # Extract the history window for this quantizer
248
+ history_start = max(0, offset - history_window)
249
+ token_history = sequence[0, q, history_start:offset]
250
+
251
+ # Only apply penalty to tokens that appear in the history
252
+ unique_tokens = torch.unique(token_history)
253
+ unique_tokens = unique_tokens[unique_tokens != special_token_id]
254
+ unique_tokens = unique_tokens[unique_tokens != eos]
255
+ unique_tokens = unique_tokens[unique_tokens != unknown_token]
256
+
257
+ if len(unique_tokens) > 0:
258
+ # Apply penalty by dividing the logits for tokens that have appeared recently
259
+ logits[0, q, unique_tokens] = (
260
+ logits[0, q, unique_tokens] / repetition_penalty
261
+ )
262
+
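+ # Mask out the EOS logit for the first ~98 frames so generation cannot terminate almost immediately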
263
+ if offset < 24.53 * 4:
264
+ logits[..., eos] = -float("inf")
265
+
266
+ probs = F.softmax(logits / temperature, dim=-1)
267
+
268
+ # print(probs.shape)
269
+ if top_p is not None and top_k is not None:
270
+ next_codes = sample_top_p_top_k(probs, top_p, top_k)
271
+ elif top_p is not None and top_p > 0:
272
+ next_codes = sample_top_p(probs, top_p)
273
+ elif top_k is not None and top_k > 0:
274
+ next_codes = sample_top_k(probs, top_k)
275
+ else:
276
+ next_codes = multinomial(probs, num_samples=1)
277
+
278
+ next_codes = next_codes.repeat(batch_size, 1, 1)
279
+
280
+ if (probs[..., eos] > 0.95).any():
281
+ print("breaking at", offset)
282
+ break
283
+
284
+ valid_mask = mask[..., offset : offset + 1].expand(B, -1, -1)
285
+ next_codes[~valid_mask] = special_token_id
286
+
287
+ sequence[..., offset : offset + 1] = torch.where(
288
+ sequence[..., offset : offset + 1] == unknown_token,
289
+ next_codes,
290
+ sequence[..., offset : offset + 1],
291
+ )
292
+
293
+ prev_offset = offset
294
+
295
+ # print(sequence.shape)
296
+ out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(
297
+ sequence, special_token=unknown_token
298
+ )
299
+
300
+ # sanity checks over the returned codes and corresponding masks
301
+ # assert (out_codes[..., :max_gen_len] != unknown_token).all()
302
+ # assert (out_mask[..., :max_gen_len] == 1).all()
303
+ out_codes = out_codes[..., prompt_codes.shape[-1] : offset]
304
+ return out_codes[[0]]
305
+
306
+
307
+ @torch.inference_mode()
308
+ def render(
309
+ self: Vui,
310
+ text: str,
311
+ prompt_codes: Tensor | None = None,
312
+ temperature: float = 0.5,
313
+ top_k: int | None = 100,
314
+ top_p: float | None = None,
315
+ max_secs: int = 100,
316
+ ):
317
+ """
318
+ Render audio from text. Uses generate for text < 1000 characters,
319
+ otherwise breaks text into sections and uses chunking with context.
320
+ """
321
+ text = remove_all_invalid_non_speech(text)
322
+ text = simple_clean(text)
323
+ SR = self.codec.config.sample_rate
324
+ HZ = self.codec.hz
325
+ max_gen_len = int(HZ * max_secs)
326
+
327
+ if len(text) < 1000:
328
+ codes = generate(
329
+ self, text, prompt_codes, temperature, top_k, top_p, max_gen_len
330
+ )
331
+ codes = codes[..., :-10]
332
+ audio = self.codec.from_indices(codes)
333
+ paudio = torchaudio.functional.resample(audio[0], 22050, 16000)
334
+ results = vad(paudio)
335
+
336
+ if len(results):
337
+ # Cut the audio based on VAD results, add 200ms silence at end
338
+ s, e = results[0][0], results[-1][1]
339
+ return audio[..., int(s * SR) : int((e + 0.2) * SR)].cpu()
340
+
341
+ raise Exception("Failed to render")
342
+
343
+ # Otherwise we have to do some clever chaining!
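+ # Long inputs are rendered line by line, conditioning each chunk on the previous chunk's codes so the voice stays consistent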
344
+
345
+ orig_codes = prompt_codes
346
+
347
+ lines = text.split("\n")
348
+ audios = []
349
+ prev_codes = prompt_codes
350
+ prev_text = ""
351
+
352
+ for i, line in enumerate(lines):
353
+ run = True
354
+ while run:
355
+ current_text = prev_text + "\n" + line if prev_text else line
356
+ current_text = current_text.strip()
357
+ current_text = current_text.replace("...", "")
358
+ current_text = current_text + " [pause]"
359
+
360
+ # Calculate max length based on text length
361
+ maxlen = int(HZ * int(60 * len(current_text) / 500))
362
+
363
+ try:
364
+ print("rendering", current_text)
365
+ with (
366
+ torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH),
367
+ torch.autocast("cuda", dtype=torch.bfloat16, enabled=True),
368
+ ):
369
+ codes = generate(
370
+ self,
371
+ current_text,
372
+ prompt_codes=prev_codes,
373
+ temperature=temperature,
374
+ top_k=top_k,
375
+ top_p=top_p,
376
+ max_gen_len=maxlen,
377
+ )
378
+
379
+ codes = codes[..., :-10]
380
+ audio = self.codec.from_indices(codes)
381
+ # Resample for VAD
382
+ paudio = torchaudio.functional.resample(audio[0], 22050, 16000)
383
+
384
+ results = vad(paudio)
385
+ run = len(results) == 0
386
+
387
+ if len(results):
388
+ prev_text = line
389
+ # Cut the audio based on VAD results, add 200ms silence at end
390
+ s, e = results[0][0], results[0][1]
391
+ codes = codes[..., int(s * HZ) : int(e * HZ)]
392
+ prev_codes = codes
393
+ audio = audio[..., int(s * SR) : int((e + 0.2) * SR)].cpu()
394
+ audios.append(audio)
395
+ else:
396
+ prev_codes = orig_codes
397
+ prev_text = ""
398
+ except KeyboardInterrupt:
399
+ break
400
+ except RuntimeError as e:
401
+ prev_codes = orig_codes
402
+ prev_text = ""
403
+ print(e)
404
+
405
+ return torch.cat(audios, dim=-1)
src/vui/model.py ADDED
@@ -0,0 +1,445 @@
1
+ import math
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch import Tensor
9
+ from transformers import AutoTokenizer
10
+
11
+ from vui.fluac import Fluac
12
+ from vui.utils import load_what_you_can
13
+
14
+ from .config import Config
15
+ from .patterns import DelayedPatternProvider
16
+ from .rope import apply_rotary_emb, precompute_freqs_cis
17
+
18
+
19
+ class KVCache(nn.Module):
20
+ def __init__(
21
+ self,
22
+ batch_size: int,
23
+ max_seqlen: int,
24
+ n_kv_heads: int,
25
+ head_dim: int,
26
+ dtype: torch.dtype = torch.bfloat16,
27
+ ):
28
+ super().__init__()
29
+
30
+ cache_shape = (batch_size, n_kv_heads, max_seqlen, head_dim)
31
+
32
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
33
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
34
+
35
+ def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
36
+ # input_pos: (T,), k_val: (B, nh, T, d)
37
+ assert input_pos.size(0) == k_val.size(-2)
38
+
39
+ k_out = self.k_cache
40
+ v_out = self.v_cache
41
+ k_out[:, :, input_pos] = k_val
42
+ v_out[:, :, input_pos] = v_val
43
+
44
+ return k_out, v_out
45
+
46
+
47
+ def repeat_kv(x: torch.Tensor, n_reps: int) -> torch.Tensor:
48
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
49
+ bs, n_kv_heads, T, head_dim = x.shape
50
+
51
+ return (
52
+ x[:, :, :, None, :]
53
+ .expand(bs, n_kv_heads, n_reps, T, head_dim)
54
+ .reshape(bs, n_kv_heads * n_reps, T, head_dim)
55
+ )
56
+
57
+
58
+ class MHA(nn.Module):
59
+ def __init__(
60
+ self,
61
+ dim: int,
62
+ n_heads: int,
63
+ n_kv_heads: int,
64
+ *,
65
+ block_idx: int,
66
+ bias: bool = False,
67
+ dropout: float = 0.0,
68
+ causal: bool = False,
69
+ use_rotary_emb: bool = True,
70
+ ):
71
+ super().__init__()
72
+
73
+ head_dim = dim // n_heads
74
+
75
+ self.use_rotary_emb = use_rotary_emb
76
+ self.block_idx = block_idx
77
+ self.dim = dim
78
+ self.n_heads = n_heads
79
+ self.n_kv_heads = n_kv_heads
80
+ self.head_dim = head_dim
81
+ self.dropout = dropout
82
+ self.causal = causal
83
+ self.n_reps = n_heads // n_kv_heads
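+ # Grouped-query attention: each kv head serves n_heads // n_kv_heads query heads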
84
+ qkv_dim = (n_heads + 2 * n_kv_heads) * head_dim
85
+ self.Wqkv = nn.Linear(dim, qkv_dim, bias=bias)
86
+ self.out_proj = nn.Linear(dim, dim, bias=bias)
87
+ self.kv_cache = None
88
+
89
+ def forward(
90
+ self,
91
+ x: Tensor,
92
+ freqs_cis: Tensor | None = None,
93
+ input_pos: Tensor | None = None,
94
+ attn_mask: Tensor | None = None,
95
+ ):
96
+ B, T, d = x.size()
97
+ x.dtype
98
+
99
+ dropout_p = self.dropout if self.training else 0.0
100
+
101
+ qkv = self.Wqkv(x)
102
+ if self.n_heads == self.n_kv_heads:
103
+ qkv = rearrange(
104
+ qkv, "B T (three h d) -> B three h T d", three=3, h=self.n_heads
105
+ )
106
+ q, k, v = qkv.unbind(dim=1) # (B, h, T, d)
107
+ else:
108
+ q, k, v = torch.split(
109
+ qkv,
110
+ [
111
+ self.head_dim * self.n_heads,
112
+ self.head_dim * self.n_kv_heads,
113
+ self.head_dim * self.n_kv_heads,
114
+ ],
115
+ dim=1,
116
+ )
117
+ q, k, v = map(lambda t: rearrange(t, "B T (h d) -> B h T d"), (q, k, v))
118
+
119
+ if self.use_rotary_emb:
120
+ q = apply_rotary_emb(freqs_cis, q)
121
+ k = apply_rotary_emb(freqs_cis, k)
122
+
123
+ if self.kv_cache is not None:
124
+ k, v = self.kv_cache.update(input_pos, k, v)
125
+
126
+ if self.n_reps > 1:
127
+ k = repeat_kv(k, self.n_reps)
128
+ v = repeat_kv(v, self.n_reps)
129
+
130
+ is_causal = self.causal and self.kv_cache is None
131
+
132
+ out = F.scaled_dot_product_attention(
133
+ q,
134
+ k,
135
+ v,
136
+ dropout_p=dropout_p,
137
+ is_causal=is_causal,
138
+ attn_mask=attn_mask,
139
+ )
140
+
141
+ out = self.out_proj(rearrange(out, "B h T d -> B T (h d)"))
142
+
143
+ return out
144
+
145
+
146
+ class MLP(nn.Module):
147
+ def __init__(
148
+ self, *, d_model: int, bias: bool, dropout: float, act=nn.GELU, **kwargs
149
+ ):
150
+ super().__init__()
151
+ self.fc1 = nn.Linear(d_model, 4 * d_model, bias=bias)
152
+ self.act = act()
153
+ self.fc2 = nn.Linear(4 * d_model, d_model, bias=bias)
154
+ self.dropout = nn.Dropout(dropout)
155
+
156
+ def forward(self, x):
157
+ return self.dropout(self.fc2(self.act(self.fc1(x))))
158
+
159
+
160
+ class LlamaMLP(nn.Module):
161
+ def __init__(
162
+ self, *, d_model: int, multiple_of: int = 256, bias: bool = False, **kwargs
163
+ ) -> None:
164
+ super().__init__()
165
+ hidden_dim = 4 * d_model
166
+ hidden_dim = int(2 * hidden_dim / 3)
167
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
168
+ self.w1 = nn.Linear(d_model, hidden_dim, bias=bias)
169
+ self.w3 = nn.Linear(d_model, hidden_dim, bias=bias)
170
+ self.w2 = nn.Linear(hidden_dim, d_model, bias=bias)
171
+
172
+ def forward(self, x: Tensor) -> Tensor:
173
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
174
+
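+ # Worked example of the hidden size (assuming d_model=512, multiple_of=256): 4*512 = 2048,
+ # int(2*2048/3) = 1365, rounded up to a multiple of 256 -> 1536, so w1/w3 project 512 -> 1536
+ # and w2 projects 1536 -> 512 (SwiGLU-style gating).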
175
+
176
+ class RMSNorm(nn.Module):
177
+ def __init__(self, dim: int, eps: float = 1e-6):
178
+ super().__init__()
179
+ self.eps = eps
180
+ self.weight = nn.Parameter(torch.ones(dim))
181
+
182
+ def _norm(self, x):
183
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
184
+
185
+ def forward(self, x: Tensor):
186
+ output = self._norm(x.float()).type_as(x)
187
+ return output * self.weight
188
+
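+ # RMSNorm computes y = x / sqrt(mean(x^2, dim=-1) + eps) * weight: it rescales by the
+ # root-mean-square of the features and, unlike LayerNorm, does not subtract the mean.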
189
+
190
+ class Block(nn.Module):
191
+ def __init__(
192
+ self,
193
+ *,
194
+ d_model: int,
195
+ n_heads: int,
196
+ n_kv_heads: int,
197
+ block_idx: int,
198
+ bias: bool,
199
+ dropout: float,
200
+ norm_eps: float = 1e-5, # use 1e-6 for rms
201
+ use_rotary_emb: bool = True,
202
+ ):
203
+ super().__init__()
204
+
205
+ self.block_idx = block_idx
206
+ self.n_heads = n_heads
207
+ self.n_kv_heads = n_kv_heads
208
+ self.head_dim = d_model // n_heads
209
+
210
+ self.attn_norm = RMSNorm(d_model, eps=norm_eps)
211
+ self.attn = MHA(
212
+ d_model,
213
+ n_heads,
214
+ n_kv_heads,
215
+ block_idx=block_idx,
216
+ bias=bias,
217
+ dropout=dropout,
218
+ causal=True,
219
+ use_rotary_emb=use_rotary_emb,
220
+ )
221
+ self.mlp_norm = RMSNorm(d_model, eps=norm_eps)
222
+ self.mlp = LlamaMLP(d_model=d_model, bias=bias, dropout=dropout)
223
+
224
+ def forward(
225
+ self,
226
+ x: Tensor,
227
+ freqs_cis: Tensor | None = None,
228
+ input_pos: Tensor | None = None,
229
+ attn_mask: Tensor | None = None,
230
+ ):
231
+ x = x + self.attn(
232
+ self.attn_norm(x),
233
+ freqs_cis=freqs_cis,
234
+ input_pos=input_pos,
235
+ attn_mask=attn_mask,
236
+ )
237
+ x = x + self.mlp(self.mlp_norm(x))
238
+
239
+ return x
240
+
241
+
242
+ class Decoder(nn.Module):
243
+ def __init__(
244
+ self,
245
+ *,
246
+ n_layers: int,
247
+ d_model: int,
248
+ n_heads: int,
249
+ n_kv_heads: int,
250
+ bias: bool,
251
+ dropout: float,
252
+ max_seqlen: int = 4096,
253
+ rope_theta: float = 10000.0,
254
+ rope_theta_rescale_factor: float = 1.0,
255
+ norm_eps: float = 1e-5,
256
+ use_rotary_emb: bool = True,
257
+ rope_dim: int | None = None,
258
+ ):
259
+ super().__init__()
260
+ assert d_model % n_heads == 0
261
+
262
+ self.use_rotary_emb = use_rotary_emb
263
+
264
+ self.max_seqlen = max_seqlen
265
+ self.blocks = nn.ModuleList(
266
+ [
267
+ Block(
268
+ d_model=d_model,
269
+ n_heads=n_heads,
270
+ n_kv_heads=n_kv_heads,
271
+ block_idx=block_idx,
272
+ bias=bias,
273
+ dropout=dropout,
274
+ norm_eps=norm_eps,
275
+ use_rotary_emb=use_rotary_emb,
276
+ )
277
+ for block_idx in range(n_layers)
278
+ ]
279
+ )
280
+ self.norm = RMSNorm(d_model, eps=norm_eps)
281
+
282
+ self.attn_mask = None
283
+
284
+ head_dim = d_model // n_heads
285
+
286
+ rope_dim = rope_dim or head_dim
287
+
288
+ assert rope_dim <= head_dim # apply RoPE to a fraction of embeddings
289
+
290
+ freqs_cis = precompute_freqs_cis(
291
+ rope_dim,
292
+ max_seqlen,
293
+ theta=rope_theta,
294
+ theta_rescale_factor=rope_theta_rescale_factor,
295
+ )
296
+ self.register_buffer("freqs_cis", freqs_cis, persistent=False)
297
+
298
+ def allocate_inference_cache(
299
+ self, batch_size: int, device: str, dtype=torch.bfloat16
300
+ ):
301
+ for block in self.blocks:
302
+ block.attn.kv_cache = KVCache(
303
+ batch_size, self.max_seqlen, block.n_kv_heads, block.head_dim, dtype
304
+ ).to(device)
305
+
306
+ # Needed because MHA.forward sets is_causal=False whenever a KV cache is attached,
+ # so this mask is what supplies causality during cached decoding.
307
+ self.attn_mask = torch.tril(
308
+ torch.ones(
309
+ self.max_seqlen, self.max_seqlen, dtype=torch.bool, device=device
310
+ )
311
+ )
312
+
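+ # Typical inference flow (sketch): call allocate_inference_cache(batch_size, device) once before
+ # generation, run forward() with the growing input_pos, then deallocate_kv_cache() afterwards to
+ # free the per-block caches and the mask.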
313
+ def deallocate_kv_cache(self):
314
+ for block in self.blocks:
315
+ block.attn.kv_cache = None
316
+
317
+ self.attn_mask = None
318
+
319
+ def forward(self, x: Tensor, input_pos: Tensor):
320
+ if self.use_rotary_emb:
321
+ freqs_cis = self.freqs_cis[input_pos]
322
+ else:
323
+ freqs_cis = None
324
+
325
+ attn_mask = (
326
+ self.attn_mask[None, None, input_pos]
327
+ if self.attn_mask is not None
328
+ else None
329
+ )
330
+
331
+ for block in self.blocks:
332
+ x = block(x, freqs_cis=freqs_cis, input_pos=input_pos, attn_mask=attn_mask)
333
+
334
+ x = self.norm(x)
335
+
336
+ return x
337
+
338
+
339
+ class Vui(nn.Module):
340
+ BASE = "vui-100m-base.pt"
341
+ COHOST = "vui-cohost-100m.pt"
342
+ ABRAHAM = "vui-abraham-100m.pt"
343
+
344
+ def __init__(self, config: Config = Config()):
345
+ super().__init__()
346
+ self.codec = Fluac.from_pretrained()
347
+ self.config = config
348
+ cfg = config.model
349
+ self.tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
350
+ self.use_rotary_emb = cfg.use_rotary_emb
351
+ self.token_emb = nn.Embedding(self.tokenizer.vocab_size, cfg.d_model)
352
+
353
+ self.pattern_provider = DelayedPatternProvider(n_q=cfg.n_quantizers)
354
+
355
+ self.audio_embeddings = nn.ModuleList(
356
+ [
357
+ nn.Embedding(cfg.codebook_size + 8, cfg.d_model)
358
+ for _ in range(cfg.n_quantizers)
359
+ ]
360
+ )
361
+
362
+ n_kv_heads = cfg.n_heads
363
+
364
+ max_seqlen = cfg.max_text_tokens + cfg.max_audio_tokens
365
+ self.decoder = Decoder(
366
+ n_layers=cfg.n_layers,
367
+ d_model=cfg.d_model,
368
+ n_heads=cfg.n_heads,
369
+ n_kv_heads=n_kv_heads,
370
+ bias=cfg.bias,
371
+ dropout=cfg.dropout,
372
+ max_seqlen=max_seqlen + cfg.n_quantizers,
373
+ rope_dim=cfg.rope_dim,
374
+ rope_theta=cfg.rope_theta,
375
+ rope_theta_rescale_factor=cfg.rope_theta_rescale_factor,
376
+ )
377
+
378
+ self.audio_heads = nn.ModuleList(
379
+ [
380
+ nn.Linear(cfg.d_model, cfg.codebook_size + 8, bias=cfg.bias)
381
+ for _ in range(cfg.n_quantizers)
382
+ ]
383
+ )
384
+
385
+ self.apply(self._init_weights)
386
+
387
+ for pn, p in self.named_parameters():
388
+ if pn.endswith("out_proj.weight"):
389
+ torch.nn.init.normal_(
390
+ p, mean=0.0, std=0.02 / math.sqrt(2 * cfg.n_layers)
391
+ )
392
+
393
+ def _init_weights(self, module):
394
+ if isinstance(module, nn.Linear):
395
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
396
+ if module.bias is not None:
397
+ torch.nn.init.zeros_(module.bias)
398
+ elif isinstance(module, nn.Embedding):
399
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
400
+
401
+ @staticmethod
402
+ def from_pretrained(
403
+ checkpoint_path: str | dict = ABRAHAM,
404
+ **config_kwargs,
405
+ ):
406
+ if isinstance(checkpoint_path, dict):
407
+ checkpoint = checkpoint_path
408
+ else:
409
+ if not os.path.exists(checkpoint_path):
410
+ from huggingface_hub import hf_hub_download
411
+
412
+ checkpoint_path = hf_hub_download(
413
+ "fluxions/vui",
414
+ checkpoint_path,
415
+ )
416
+ checkpoint = torch.load(
417
+ checkpoint_path, map_location="cpu", weights_only=True
418
+ )
419
+
420
+ config = {**checkpoint["config"], **config_kwargs}
421
+ config = Config(**config)
422
+ state_dict = checkpoint["model"]
423
+
424
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
425
+ state_dict = {
426
+ k.replace("text_embedding.", "token_emb."): v for k, v in state_dict.items()
427
+ }
428
+ model = Vui(config)
429
+ load_what_you_can(state_dict, model)
430
+ return model
431
+
432
+ @staticmethod
433
+ def from_pretrained_inf(
434
+ checkpoint_path: str | dict,
435
+ **config_kwargs,
436
+ ):
437
+ return Vui.from_pretrained(checkpoint_path, **config_kwargs).eval()
438
+
439
+ @property
440
+ def device(self):
441
+ return next(self.parameters()).device
442
+
443
+ @property
444
+ def dtype(self):
445
+ return next(self.parameters()).dtype
src/vui/notebook.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ def play(audio: torch.Tensor | np.ndarray | str, sr=16000, autoplay=True):
6
+ import torchaudio
7
+ from IPython.display import Audio, display
8
+
9
+ if isinstance(audio, str):
+ audio, sr = torchaudio.load(audio)  # torchaudio.load returns (waveform, sample_rate)
+ if isinstance(audio, np.ndarray):
+ audio = torch.from_numpy(audio)
+
+ assert audio.numel() > 100, "play() needs a non-empty audio array"
+
+ if audio.dim() < 2:
+ audio = audio[None]
+
+ # Sum channels down to mono before flattening
+ if audio.shape[0] > 1:
+ audio = audio.sum(dim=0)
+
+ audio = audio.flatten()
23
+
24
+ display(Audio(audio.cpu().detach(), rate=sr, autoplay=autoplay, normalize=True))
25
+
26
+
27
+ def plot_mel_spec(mel_spec: torch.Tensor | np.ndarray, title: str = None):
28
+ import matplotlib.pyplot as plt
29
+
30
+ mel_spec = mel_spec.squeeze()
31
+ if isinstance(mel_spec, torch.Tensor):
32
+ mel_spec = mel_spec.cpu().numpy()
33
+
34
+ fig, ax = plt.subplots(figsize=(16, 4))
35
+ im = ax.imshow(mel_spec, aspect="auto", origin="lower", interpolation="none")
36
+ fig.colorbar(im, ax=ax)
37
+ ax.set_xlabel("frames")
38
+ ax.set_ylabel("channels")
39
+
40
+ if title is not None:
41
+ ax.set_title(title)
src/vui/patterns.py ADDED
@@ -0,0 +1,423 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from collections import namedtuple
4
+ from dataclasses import dataclass
5
+ from functools import lru_cache
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+
11
+ def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
12
+ codes = F.pad(codes, (0, codes.shape[1] + 1), value=mask_token)
13
+ return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)
14
+
15
+
16
+ def revert_delay_pattern(codes: torch.Tensor):
17
+ _, n_q, seq_len = codes.shape
18
+ return torch.stack(
19
+ [codes[:, k, k + 1 : seq_len - n_q + k] for k in range(n_q)], dim=1
20
+ )
21
+
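+ # Worked example (n_q=2, mask token M): apply_delay_pattern turns [[a1 a2 a3], [b1 b2 b3]] into
+ # [[M a1 a2 a3 ...], [M M b1 b2 ...]], i.e. codebook k is shifted right by k+1 steps with M fill;
+ # revert_delay_pattern slices those shifts back out to recover the aligned [B, n_q, T'] codes.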
22
+
23
+ LayoutCoord = namedtuple("LayoutCoord", ["t", "q"]) # (timestep, codebook index)
24
+ PatternLayout = list[list[LayoutCoord]] # Sequence of coordinates
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @dataclass
29
+ class Pattern:
30
+ """Base implementation of a pattern over a sequence with multiple codebooks.
31
+
32
+ The codebook pattern consists of a layout, defining for each sequence step
33
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
34
+ The first item of the pattern is always an empty list in order to properly insert a special token
35
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
36
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
37
+
38
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
39
+ ``build_pattern_sequence`` maps a dense input tensor of a multi-codebook sequence from [B, K, T]
40
+ to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
41
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
42
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
43
+ is returned along with a mask indicating valid tokens.
44
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
45
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
46
+ to fill and specify invalid positions if needed.
47
+ See the dedicated methods for more details.
48
+ """
49
+
50
+ # Pattern layout, for each sequence step, we have a list of coordinates
51
+ # corresponding to the original codebook timestep and position.
52
+ # The first list is always an empty list in order to properly insert
53
+ # a special token to start with.
54
+ layout: PatternLayout
55
+ timesteps: int
56
+ n_q: int
57
+
58
+ def __post_init__(self):
59
+ assert len(self.layout) > 0
60
+ self._validate_layout()
61
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(
62
+ self._build_reverted_sequence_scatter_indexes
63
+ )
64
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(
65
+ self._build_pattern_sequence_scatter_indexes
66
+ )
67
+ logger.info(
68
+ "New pattern, time steps: %d, sequence steps: %d",
69
+ self.timesteps,
70
+ len(self.layout),
71
+ )
72
+
73
+ def _validate_layout(self):
74
+ """Runs checks on the layout to ensure a valid pattern is defined.
75
+ A pattern is considered invalid if:
76
+ - Multiple timesteps for a same codebook are defined in the same sequence step
77
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
78
+ (this would mean that we have future timesteps before past timesteps).
79
+ """
80
+ q_timesteps = {q: 0 for q in range(self.n_q)}
81
+ for s, seq_coords in enumerate(self.layout):
82
+ if len(seq_coords) > 0:
83
+ qs = set()
84
+ for coord in seq_coords:
85
+ qs.add(coord.q)
86
+ last_q_timestep = q_timesteps[coord.q]
87
+ assert (
88
+ coord.t >= last_q_timestep
89
+ ), f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
90
+ q_timesteps[coord.q] = coord.t
91
+ # each sequence step contains at max 1 coordinate per codebook
92
+ assert len(qs) == len(
93
+ seq_coords
94
+ ), f"Multiple entries for a same codebook are found at step {s}"
95
+
96
+ @property
97
+ def num_sequence_steps(self):
98
+ return len(self.layout) - 1
99
+
100
+ @property
101
+ def max_delay(self):
102
+ max_t_in_seq_coords = 0
103
+ for seq_coords in self.layout[1:]:
104
+ for coords in seq_coords:
105
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
106
+ return max_t_in_seq_coords - self.timesteps
107
+
108
+ @property
109
+ def valid_layout(self):
110
+ valid_step = len(self.layout) - self.max_delay
111
+ return self.layout[:valid_step]
112
+
113
+ def starts_with_special_token(self):
114
+ return self.layout[0] == []
115
+
116
+ def get_sequence_coords_with_timestep(self, t: int, q: int | None = None):
117
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
118
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
119
+ and the actual codebook coordinates.
120
+ """
121
+ assert (
122
+ t <= self.timesteps
123
+ ), "provided timesteps is greater than the pattern's number of timesteps"
124
+ if q is not None:
125
+ assert (
126
+ q <= self.n_q
127
+ ), "provided number of codebooks is greater than the pattern's number of codebooks"
128
+ coords = []
129
+ for s, seq_codes in enumerate(self.layout):
130
+ for code in seq_codes:
131
+ if code.t == t and (q is None or code.q == q):
132
+ coords.append((s, code))
133
+ return coords
134
+
135
+ def get_steps_with_timestep(self, t: int, q: int | None = None) -> list[int]:
136
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
137
+
138
+ def get_first_step_with_timesteps(self, t: int, q: int | None = None) -> int | None:
139
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
140
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
141
+
142
+ def _build_pattern_sequence_scatter_indexes(
143
+ self,
144
+ timesteps: int,
145
+ n_q: int,
146
+ keep_only_valid_steps: bool,
147
+ device: torch.device | str = "cpu",
148
+ ):
149
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
150
+
151
+ Args:
152
+ timesteps (int): Maximum number of timesteps steps to consider.
153
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
154
+ device (torch.device or str): Device for created tensors.
155
+ Returns:
156
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
157
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
158
+ """
159
+ assert (
160
+ n_q == self.n_q
161
+ ), f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
162
+ assert (
163
+ timesteps <= self.timesteps
164
+ ), "invalid number of timesteps used to build the sequence from the pattern"
165
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
166
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
167
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
168
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
169
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
170
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
171
+ # fill indexes with last sequence step value that will correspond to our special token
172
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
173
+ # which will correspond to the index: n_q * timesteps
174
+ indexes[:] = n_q * timesteps
175
+ # iterate over the pattern and fill scattered indexes and mask
176
+ for s, sequence_coords in enumerate(ref_layout):
177
+ for coords in sequence_coords:
178
+ if coords.t < timesteps:
179
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
180
+ mask[coords.q, s] = 1
181
+ indexes = torch.from_numpy(indexes).to(device)
182
+ mask = torch.from_numpy(mask).to(device)
183
+ return indexes, mask
184
+
185
+ def build_pattern_sequence(
186
+ self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False
187
+ ):
188
+ """Build sequence corresponding to the pattern from the input tensor z.
189
+ The sequence is built using up to sequence_steps if specified, and non-pattern
190
+ coordinates are filled with the special token.
191
+
192
+ Args:
193
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
194
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
195
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
196
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
197
+ Returns:
198
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
199
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
200
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
201
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
202
+ """
203
+ B, K, T = z.shape
204
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
205
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
206
+ )
207
+ z = z.reshape(B, -1)
208
+ # we append the special token as the last index of our flattened z tensor
209
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
210
+ values = z[:, indexes.view(-1)]
211
+ values = values.view(B, K, indexes.shape[-1])
212
+ return values, indexes, mask
213
+
214
+ def _build_reverted_sequence_scatter_indexes(
215
+ self,
216
+ sequence_steps: int,
217
+ n_q: int,
218
+ keep_only_valid_steps: bool = False,
219
+ is_model_output: bool = False,
220
+ device: torch.device | str = "cpu",
221
+ ):
222
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
223
+ from interleaving pattern.
224
+
225
+ Args:
226
+ sequence_steps (int): Sequence steps.
227
+ n_q (int): Number of codebooks.
228
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
229
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
230
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
231
+ device (torch.device or str): Device for created tensors.
232
+ Returns:
233
+ indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
234
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
235
+ """
236
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
237
+ timesteps = self.timesteps
238
+ assert (
239
+ n_q == self.n_q
240
+ ), f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
241
+ assert sequence_steps <= len(
242
+ ref_layout
243
+ ), f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
244
+
245
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
246
+ if is_model_output and self.starts_with_special_token():
247
+ ref_layout = ref_layout[1:]
248
+
249
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
250
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
251
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
252
+ # fill indexes with last sequence step value that will correspond to our special token
253
+ indexes[:] = n_q * sequence_steps
254
+ for s, sequence_codes in enumerate(ref_layout):
255
+ if s < sequence_steps:
256
+ for code in sequence_codes:
257
+ if code.t < timesteps:
258
+ indexes[code.q, code.t] = s + code.q * sequence_steps
259
+ mask[code.q, code.t] = 1
260
+ indexes = torch.from_numpy(indexes).to(device)
261
+ mask = torch.from_numpy(mask).to(device)
262
+ return indexes, mask
263
+
264
+ def revert_pattern_sequence(
265
+ self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False
266
+ ):
267
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
268
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
269
+ are filled with the special token.
270
+
271
+ Args:
272
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
273
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
274
+ Returns:
275
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
276
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
277
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
278
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
279
+ """
280
+ B, K, S = s.shape
281
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
282
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
283
+ )
284
+ s = s.view(B, -1)
285
+ # we append the special token as the last index of our flattened z tensor
286
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
287
+ values = s[:, indexes.view(-1)]
288
+ values = values.view(B, K, indexes.shape[-1])
289
+ return values, indexes, mask
290
+
291
+ def revert_pattern_logits(
292
+ self,
293
+ logits: torch.Tensor,
294
+ special_token: float,
295
+ keep_only_valid_steps: bool = False,
296
+ ):
297
+ """Revert model logits obtained on a sequence built from the pattern
298
+ back to a tensor matching the original sequence.
299
+
300
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
301
+ 1. It is designed to work with the extra cardinality dimension
302
+ 2. We return the logits for the first sequence item that matches the special_token and
303
+ whose matching target in the original sequence is the first item of the sequence,
304
+ while we skip the last logits as there is no matching target
305
+ """
306
+ B, n, Q, S = logits.shape
307
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
308
+ S, Q, keep_only_valid_steps, is_model_output=True, device=logits.device
309
+ )
310
+ logits = logits.reshape(B, n, -1)
311
+ # we append the special token as the last index of our flattened z tensor
312
+ logits = torch.cat(
313
+ [logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1
314
+ ) # [B, card, K x S]
315
+ values = logits[:, :, indexes.view(-1)]
316
+ values = values.view(B, n, Q, indexes.shape[-1])
317
+ return values, indexes, mask
318
+
319
+
320
+ class CodebooksPatternProvider(ABC):
321
+ """Abstraction around providing pattern for interleaving codebooks.
322
+
323
+ The CodebooksPatternProvider abstraction allows to implement various strategies to
324
+ define interleaving pattern of sequences composed of multiple codebooks. For a given
325
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
326
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
327
+ can be used to construct a new sequence from the original codes respecting the specified
328
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
329
+ being a tuple with the original timestep and codebook to build the new sequence.
330
+ Note that all patterns must start with an empty list that is then used to insert a first
331
+ sequence step of special tokens in the newly generated sequence.
332
+
333
+ Args:
334
+ n_q (int): number of codebooks.
335
+ cached (bool): if True, patterns for a given length are cached. In general
336
+ that should be true for efficiency reasons, to avoid synchronization points.
337
+ """
338
+
339
+ def __init__(self, n_q: int, cached: bool = True):
340
+ assert n_q > 0
341
+ self.n_q = n_q
342
+ if cached:
+ self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore
343
+
344
+ @abstractmethod
345
+ def get_pattern(self, timesteps: int) -> Pattern:
346
+ """Builds pattern with specific interleaving between codebooks.
347
+
348
+ Args:
349
+ timesteps (int): Total number of timesteps.
350
+ """
351
+ raise NotImplementedError()
352
+
353
+
354
+ class DelayedPatternProvider(CodebooksPatternProvider):
355
+ """Provider for delayed pattern across delayed codebooks.
356
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
357
+ from different timesteps.
358
+
359
+ Example:
360
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
361
+ [[1, 2, 3, 4],
362
+ [1, 2, 3, 4],
363
+ [1, 2, 3, 4]]
364
+ The resulting sequence obtained from the returned pattern is:
365
+ [[S, 1, 2, 3, 4],
366
+ [S, S, 1, 2, 3],
367
+ [S, S, S, 1, 2]]
368
+ (with S being a special token)
369
+
370
+ Args:
371
+ n_q (int): Number of codebooks.
372
+ delays (list of int, optional): Delay for each of the codebooks.
373
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
374
+ flatten_first (int): Flatten the first N timesteps.
375
+ empty_initial (int): Prepend with N empty list of coordinates.
376
+ """
377
+
378
+ def __init__(
379
+ self,
380
+ n_q: int,
381
+ delays: list[int] | None = None,
382
+ flatten_first: int = 0,
383
+ empty_initial: int = 0,
384
+ ):
385
+ super().__init__(n_q)
386
+ if delays is None:
387
+ delays = list(range(n_q))
388
+ self.delays = delays
389
+ self.flatten_first = flatten_first
390
+ self.empty_initial = empty_initial
391
+ assert len(self.delays) == self.n_q
392
+ assert sorted(self.delays) == self.delays
393
+
394
+ def get_pattern(self, timesteps: int) -> Pattern:
395
+ omit_special_token = self.empty_initial < 0
396
+ out: PatternLayout = [] if omit_special_token else [[]]
397
+ max_delay = max(self.delays)
398
+ if self.empty_initial:
399
+ out += [[] for _ in range(self.empty_initial)]
400
+ if self.flatten_first:
401
+ for t in range(min(timesteps, self.flatten_first)):
402
+ for q in range(self.n_q):
403
+ out.append([LayoutCoord(t, q)])
404
+ for t in range(self.flatten_first, timesteps + max_delay):
405
+ v = []
406
+ for q, delay in enumerate(self.delays):
407
+ t_for_q = t - delay
408
+ if t_for_q >= self.flatten_first:
409
+ v.append(LayoutCoord(t_for_q, q))
410
+ out.append(v)
411
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
412
+
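+ # Round-trip sketch (assumed n_q=4, T=100, special token 1001):
+ #   pattern = DelayedPatternProvider(4).get_pattern(100)
+ #   seq, _, _ = pattern.build_pattern_sequence(codes, 1001)    # codes [B, 4, 100] -> seq [B, 4, S]
+ #   codes2, _, _ = pattern.revert_pattern_sequence(seq, 1001)  # back to [B, 4, 100]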
413
+
414
+ if __name__ == "__main__":
415
+ # Tried to use the simple patterns to train and something very odd happened.
416
+ Q = 4
417
+
418
+ codes = torch.randint(0, 1000, (1, Q, 100))
419
+ pcodes = apply_delay_pattern(codes, 1001)
420
+ provider = DelayedPatternProvider(Q)
421
+ provider = provider.get_pattern(100)
422
+ pcodes2 = provider.build_pattern_sequence(codes, 1001)
423
+ breakpoint()
src/vui/rope.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ from einops import rearrange, repeat
3
+ from torch import Tensor
4
+ from torch.amp import autocast
5
+
6
+
7
+ def rotate_half(x):
8
+ """Also known as "interleaved" style or GPT-J style."""
9
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
10
+ x1, x2 = x.unbind(dim=-1)
11
+ x = torch.stack((-x2, x1), dim=-1)
12
+ return rearrange(x, "... d r -> ... (d r)")
13
+
14
+
15
+ @autocast("cuda", enabled=False)
16
+ def apply_rotary_emb(
17
+ freqs: Tensor, t: Tensor, start_index: int = 0, scale: float = 1.0, seq_dim=-2
18
+ ):
19
+ dtype = t.dtype
20
+
21
+ if t.ndim == 3:
22
+ seq_len = t.shape[seq_dim]
23
+ freqs = freqs[-seq_len:]
24
+
25
+ rot_dim = freqs.shape[-1]
26
+ end_index = start_index + rot_dim
27
+
28
+ assert (
29
+ rot_dim <= t.shape[-1]
30
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
31
+
32
+ t_left, t, t_right = (
33
+ t[..., :start_index],
34
+ t[..., start_index:end_index],
35
+ t[..., end_index:],
36
+ )
37
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
38
+ out = torch.cat((t_left, t, t_right), dim=-1)
39
+ return out.to(dtype)
40
+
41
+
42
+ def precompute_freqs_cis(
43
+ dim: int,
44
+ max_seqlen: int,
45
+ theta: float = 10_000.0,
46
+ theta_rescale_factor: float = 1.0,
47
+ dtype: torch.dtype = torch.float32,
48
+ ):
49
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
50
+ pos = torch.arange(max_seqlen, dtype=dtype)
51
+ inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype) / dim))
52
+ freqs = torch.einsum("..., f -> ... f", pos.to(inv_freqs.dtype), inv_freqs)
53
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
54
+ return freqs
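+
+ # Usage sketch (assumed sizes): freqs = precompute_freqs_cis(dim=64, max_seqlen=2048) has shape
+ # (2048, 64); apply_rotary_emb(freqs[:T], q) rotates the first 64 channels of q (B, h, T, d >= 64)
+ # and leaves the remaining channels untouched, which is how Decoder applies partial RoPE when
+ # rope_dim < head_dim.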
src/vui/sampling.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+
5
+ def multinomial(input: Tensor, num_samples: int, replacement=False, *, generator=None):
6
+ input_ = input.reshape(-1, input.shape[-1])
7
+ output_ = torch.multinomial(
8
+ input_, num_samples=num_samples, replacement=replacement, generator=generator
9
+ )
10
+ output = output_.reshape(*list(input.shape[:-1]), -1)
11
+ return output
12
+
13
+
14
+ def sample_top_k(probs: Tensor, k: int) -> Tensor:
15
+ top_k_value, _ = torch.topk(probs, k, dim=-1)
16
+ min_value_top_k = top_k_value[..., [-1]]
17
+ probs *= (probs >= min_value_top_k).float()
18
+ probs.div_(probs.sum(dim=-1, keepdim=True))
19
+ next_token = multinomial(probs, num_samples=1)
20
+ return next_token
21
+
22
+
23
+ def sample_top_p(probs: Tensor, p: float) -> Tensor:
24
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
25
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
26
+ mask = probs_sum - probs_sort > p
27
+ probs_sort *= (~mask).float()
28
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
29
+ next_token = multinomial(probs_sort, num_samples=1)
30
+ next_token = torch.gather(probs_idx, -1, next_token)
31
+
32
+ return next_token
33
+
34
+
35
+ def sample_top_p_top_k(probs: Tensor, p: float, top_k: int):
36
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
37
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
38
+ mask = probs_sum - probs_sort > p
39
+ probs_sort *= (~mask).float()
40
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
41
+ next_token = sample_top_k(probs_sort, top_k)
42
+ next_token = torch.gather(probs_idx, -1, next_token)
43
+ return next_token
src/vui/tok.py ADDED
@@ -0,0 +1,19 @@
1
+ import torch
2
+ from transformers import ByT5Tokenizer
3
+
4
+
5
+ class CustomByT5Tokenizer(ByT5Tokenizer):
6
+ def encode(self, text, add_special_tokens=False, **kwargs):
7
+ """
8
+ Override the encode method.
9
+
10
+ Args:
11
+ text (str): Input text
12
+ add_special_tokens (bool): Whether to add BOS/EOS tokens
13
+ """
14
+ # Use the parent class's encode method
15
+ tokens = super().encode(text, add_special_tokens=add_special_tokens, **kwargs)
16
+ return torch.tensor(tokens)
17
+
18
+
19
+ tok = CustomByT5Tokenizer.from_pretrained("google/byt5-small")
src/vui/utils.py ADDED
@@ -0,0 +1,422 @@
1
+ import math
2
+ import time
3
+ from functools import partial
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+
9
+
10
+ def load_what_you_can(checkpoint: dict, model: torch.nn.Module):
11
+ """
12
+ This method takes a checkpoint and loads as many weights from it as possible:
13
+
14
+ If they are the same shape, there's nothing to do
15
+
16
+ Will load the smallest shape otherwise.
17
+ """
18
+ import torch
19
+
20
+ model_state_dict = model.state_dict()
21
+ checkpoint_state_dict = checkpoint
22
+
23
+ for name, param in checkpoint_state_dict.items():
24
+ if name not in model_state_dict:
25
+ print(f"Ignoring parameter '{name}' because it is not found in the model")
26
+ continue
27
+
28
+ model_state = model_state_dict[name]
29
+ mshape = model_state.shape
30
+ pshape = param.shape
31
+
32
+ if pshape == mshape:
33
+ model_state.copy_(param)
34
+ continue
35
+
36
+ if len(pshape) != len(mshape):
37
+ # Completely different shapes so probably unwise to merge
38
+ continue
39
+
40
+ min_shape = [
41
+ min(param.shape[i], model_state.shape[i]) for i in range(len(param.shape))
42
+ ]
43
+ print(name, "model:", mshape, "chkpt:", pshape, "loading:", min_shape)
44
+ idxs = torch.meshgrid(*[torch.arange(s) for s in min_shape], indexing="ij")
+ # index-assign rather than .copy_(): advanced indexing returns a copy, so copy_ would write to a temporary
+ model_state[tuple(idxs)] = param[tuple(idxs)]
46
+
47
+ return model.load_state_dict(model_state_dict)
48
+
49
+
50
+ def multimap(
51
+ items: list, func: callable, workers=4, desc=None, thread=False, chunk_size=128
52
+ ) -> list:
53
+ """
54
+ Quick and dirty multiprocessing/threading map that keeps only the non-None results of func.
55
+ """
56
+ from tqdm.contrib.concurrent import process_map, thread_map
57
+
58
+ m = thread_map if thread else process_map
59
+ length = None
60
+ try:
61
+ length = len(items)
62
+ except Exception as e:
63
+ print(e, "getting length")
64
+
65
+ results = m(
66
+ func,
67
+ items,
68
+ leave=False,
69
+ desc=desc,
70
+ max_workers=workers,
71
+ total=length,
72
+ chunksize=chunk_size,
73
+ )
74
+ return list(filter(lambda x: x is not None, results))
75
+
76
+
77
+ def round_up(num: float, factor: int):
78
+ return factor * math.ceil(num / factor)
79
+
80
+
81
+ def left_padding_mask(lengths, max_len, device=None, dtype=None):
82
+ masks = []
83
+ if not max_len:
84
+ max_len = max(lengths)
85
+ for l in lengths:
86
+ mask = torch.empty(l, l, device=device, dtype=dtype).fill_(-torch.inf).triu_(1)
87
+ diff = max_len - l
88
+ mask = F.pad(mask, (diff, 0, diff, 0), value=-torch.inf)
89
+ masks.append(mask)
90
+
91
+ masks = torch.stack(masks)
92
+ return masks[:, None]
93
+
94
+
95
+ def seed_all(seed: int):
96
+ import random
97
+
98
+ import numpy as np
99
+ import torch
100
+
101
+ torch.manual_seed(seed)
102
+ np.random.seed(seed)
103
+ random.seed(seed)
104
+
105
+
106
+ def split_bucket_path(url: str) -> tuple[str, str]:
107
+ url = url.replace("s3://", "")
108
+ url = url.replace("sj://", "")
109
+ url = url.replace("r2://", "")
110
+ bucket = url.split("/")[0]
111
+ path = "/".join(url.split("/")[1:])
112
+ return bucket, path
113
+
114
+
115
+ def prob_mask_like(shape, prob: float, device):
116
+ import torch
117
+
118
+ if prob == 1:
119
+ return torch.ones(shape, device=device, dtype=torch.bool)
120
+ elif prob == 0:
121
+ return torch.zeros(shape, device=device, dtype=torch.bool)
122
+ else:
123
+ return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
124
+
125
+
126
+ def round_up_to_multiple(n: int, multiple: int) -> int:
127
+ if n % multiple != 0:
128
+ n += multiple - (n % multiple)
129
+
130
+ return n
131
+
132
+
133
+ def warmup_then_cosine_decay(
134
+ step: int, *, warmup_steps: int, steps: int, min_lr: float, max_lr: float
135
+ ):
136
+ eps = 1e-9
137
+ cooldown_steps = warmup_steps
138
+ if step < warmup_steps:
139
+ return min_lr + step * (max_lr - min_lr) / (warmup_steps)
140
+ elif step > steps:
141
+ return min_lr
142
+ elif step < steps - cooldown_steps:
143
+ decay_ratio = (step - warmup_steps) / (steps - warmup_steps - cooldown_steps)
144
+ # assert 0 <= decay_ratio <= 1
145
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
146
+ return min_lr + coeff * (max_lr - min_lr)
147
+ else:
148
+ # decay from min_lr to 0
149
+ return min_lr * (steps - step) / cooldown_steps + eps
150
+
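+ # Schedule shape: linear warmup from min_lr to max_lr over warmup_steps, cosine decay from max_lr
+ # back to min_lr until the final warmup_steps-sized cooldown window, then a linear ramp from
+ # min_lr towards zero.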
151
+
152
+ def decay_to_zero(step: int, *, decay_steps: int, steps: int, max_lr: float):
153
+ if step > steps:
154
+ return 0.0
155
+ else:
156
+ gradient = -max_lr / decay_steps
157
+
158
+ return max_lr + gradient * step
159
+
160
+
161
+ def cross_entropy_loss(logits, mask, targets):
162
+ import torch
163
+ import torch.nn.functional as F
164
+
165
+ B, Q, T, _ = logits.size()
166
+ assert logits.shape[:-1] == targets.shape
167
+ assert mask.shape == targets.shape
168
+ loss = torch.zeros([], device=targets.device)
169
+ codebook_losses = []
170
+ for q in range(Q):
171
+ logits_q = (
172
+ logits[:, q, ...].contiguous().view(-1, logits.size(-1))
173
+ ) # [B x T, card]
174
+ targets_q = targets[:, q, ...].contiguous().view(-1) # [B x T]
175
+ mask_q = mask[:, q, ...].contiguous().view(-1) # [B x T]
176
+ ce_targets = targets_q[mask_q]
177
+ ce_logits = logits_q[mask_q]
178
+ q_ce = F.cross_entropy(ce_logits, ce_targets)
179
+ loss += q_ce
180
+ codebook_losses.append(q_ce.detach())
181
+ # average cross entropy across codebooks
182
+ loss = loss / Q
183
+ return loss, codebook_losses
184
+
185
+
186
+ def build_optimizer(
187
+ module, *, weight_decay: float, lr: float, betas: tuple[float, float]
188
+ ):
189
+ import torch
190
+
191
+ param_dict = {pn: p for pn, p in module.named_parameters() if p.requires_grad}
192
+
193
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
194
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
195
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
196
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
197
+ optim_groups = [
198
+ {"params": decay_params, "weight_decay": weight_decay},
199
+ {"params": nodecay_params, "weight_decay": 0.0},
200
+ ]
201
+ # num_decay_params = sum(p.numel() for p in decay_params)
202
+ # num_nodecay_params = sum(p.numel() for p in nodecay_params)
203
+ # print(
204
+ # f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
205
+ # )
206
+ # print(
207
+ # f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
208
+ # )
209
+ optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=True)
210
+
211
+ return optimizer
212
+
213
+
214
+ def pad_or_cut_right(t: Tensor, padlen: int, value=0) -> Tensor:
215
+ current_len = t.shape[-1]
216
+
217
+ if current_len == padlen:
218
+ return t
219
+
220
+ if current_len < padlen:
221
+ # Need to pad
222
+ pad_size = (0, padlen - current_len)
223
+ return F.pad(t, pad_size, value=value)
224
+ # Need to cut
225
+ return t[..., :padlen]  # cut along the last dimension, matching the padding above
226
+
227
+
228
+ def pad_or_cut_left(t: Tensor, value: int) -> Tensor:
229
+ dims = t.ndim
230
+ current_len = t.shape[0]
231
+
232
+ if current_len == value:
233
+ return t
234
+
235
+ if current_len < value:
236
+ # Need to pad
237
+ pad_size = (0,) * (2 * (dims - 1)) + (value - current_len, 0)
238
+ return F.pad(t, pad_size)
239
+ # Need to cut
240
+ return t[-value:]
241
+
242
+
243
+ def dl_pt(orig: str):
244
+ from os.path import exists
245
+
246
+ import torch
247
+
248
+ from vui.storage import s3, split_bucket_path
249
+
250
+ if not orig.endswith(".pt"):
251
+ orig = orig + ".pt"
252
+
253
+ load = partial(torch.load, weights_only=True)
254
+ if exists(orig):
255
+ return load(orig)
256
+ url = "/data/" + orig
257
+
258
+ if exists(url):
259
+ return load(url)
260
+ url = "s3://fluxions/" + orig
261
+
262
+ bucket, key = split_bucket_path(url)
263
+ response = s3.get_object(Bucket=bucket, Key=key)
264
+ return load(response["Body"])
265
+
266
+
267
+ def dl_ogg(url: str, start=0, end=-1, sr=None):
268
+ import re
269
+ from os.path import exists
270
+
271
+ import soundfile as sf
272
+ import torch
273
+
274
+ search_sr = re.search(r"(\d+)/", url)
275
+ if search_sr:
276
+ sr = int(search_sr.group(1))
277
+
278
+ local_file = exists(url)
279
+
280
+ if exists("/data/audio/" + url):
281
+ local_file = True
282
+ url = "/data/audio/" + url
283
+
284
+ if not local_file:
285
+ from vui.storage import s3
286
+
287
+ url = "s3://fluxions/" + url
288
+ b, p = split_bucket_path(url)
289
+ url = s3.get_object(Bucket=b, Key=p)["Body"]
290
+
291
+ if sr is None:
292
+ if local_file:
293
+ sr = sf.info(url).samplerate
294
+ else:
295
+ sr = sf.info(url.read()).samplerate
296
+
297
+ start_frame = int(start * sr)
298
+ num_frames = int(end * sr) - start_frame
299
+ wav, _ = sf.read(url, frames=num_frames, start=start_frame, always_2d=True)
300
+ wav = torch.from_numpy(wav).float()
301
+ wav = wav.T.mean(0, keepdim=True)
302
+ return wav, sr
303
+
304
+
305
+ class timer:
306
+ def __init__(self, name=""):
307
+ self.name = name
308
+
309
+ def __enter__(self):
310
+ self.t = time.perf_counter()
311
+ return self
312
+
313
+ def __exit__(self, exc_type, exc_val, exc_tb):
314
+ elapsed = time.perf_counter() - self.t
315
+ print(f"{self.name} {elapsed:.4f}")
316
+
317
+
318
+ @torch.inference_mode()
319
+ def decode_audio_from_indices(model, indices, chunk_size=64):
320
+ """
321
+ Decodes audio from indices in batches to avoid memory issues.
322
+
323
+ Args:
324
+ model: Codec
325
+ indices: Tensor of shape (1, n_quantizers, sequence_length)
326
+ chunk_size: Number of samples to process at once
327
+
328
+ Returns:
329
+ Tensor of reconstructed audio
330
+ """
331
+ device = model.device
332
+ indices = indices.to(device)
333
+ _, _, seq_len = indices.shape
334
+ chunks = seq_len // chunk_size + (1 if seq_len % chunk_size != 0 else 0)
335
+
336
+ audio_chunks = []
337
+ for i in range(chunks):
338
+ start_idx = i * chunk_size
339
+ end_idx = min(start_idx + chunk_size, seq_len)
340
+ chunk_indices = indices[:, :, start_idx:end_idx]
341
+ chunk_audio = model.from_indices(chunk_indices)
342
+ audio_chunks.append(chunk_audio.cpu())
343
+
344
+ full_audio = torch.cat(audio_chunks, dim=-1)
345
+ return full_audio.flatten()
346
+
347
+
348
+ def normalize_loudness(waveform, sample_rate: int, lufs: float = -12.0):
349
+ """
350
+ Normalize the loudness of an audio tensor using torchaudio.transforms.Loudness.
351
+
352
+ Args:
353
+ waveform (torch.Tensor): Input audio tensor of shape (channels, samples)
+ sample_rate (int): Sampling rate of the audio
+ lufs (float): Target loudness in LUFS (default: -12.0 LUFS)
356
+
357
+ Returns:
358
+ torch.Tensor: Loudness-normalized audio tensor
359
+ """
360
+ import torchaudio
361
+
362
+ # Ensure the input tensor is 2D (add channel dimension if it's 1D)
363
+ if waveform.ndim == 1:
364
+ waveform = waveform.unsqueeze(0)
365
+
366
+ # Create a Loudness transform
367
+ loudness_transform = torchaudio.transforms.Loudness(sample_rate)
368
+
369
+ # Measure the current loudness
370
+ current_loudness = loudness_transform(waveform)
371
+
372
+ # Calculate the required gain
373
+ gain_db = lufs - current_loudness
374
+
375
+ # Convert gain from dB to linear scale
376
+ gain_linear = torch.pow(10, gain_db / 20)
377
+
378
+ # Apply the gain to normalize loudness
379
+ normalized_audio = waveform * gain_linear
380
+
381
+ return normalized_audio
382
+
383
+
384
+ def get_basename_without_extension(file_path):
385
+ from pathlib import Path
386
+
387
+ p = Path(file_path)
388
+ return p.stem
389
+
390
+
391
+ def ollama(prompt, MODEL=None):
392
+ import os
393
+
394
+ import requests
395
+
396
+ OLLAMA_HOST = "http://localhost:11434"
397
+ API = f"{OLLAMA_HOST}/api/generate"
398
+
399
+ if MODEL is None:
400
+ MODEL = os.environ.get("OLLAMA_MODEL", "gemma:1b")
401
+
402
+ payload = {
403
+ "model": MODEL,
404
+ "prompt": prompt,
405
+ "stream": False,
406
+ "options": {"temperature": 0.9, "top_p": 0.9, "max_tokens": 1000},
407
+ }
408
+
409
+ try:
410
+ response = requests.post(API, json=payload)
411
+ response.raise_for_status() # Raise exception for HTTP errors
412
+ result = response.json()
413
+ return result.get("response", "")
414
+ except requests.exceptions.RequestException as e:
415
+ print(f"Error calling Ollama API: {e}")
416
+ return ""
417
+
418
+
419
+ def decompile_state_dict(state_dict):
420
+ state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
421
+ # state_dict = convert_old_weight_norm_to_new(state_dict)
422
+ return {k.replace("module.", ""): v for k, v in state_dict.items()}
src/vui/vad.py ADDED
@@ -0,0 +1,363 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ from collections.abc import Callable
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ from pyannote.audio import Model, Pipeline
10
+ from pyannote.audio.core.io import AudioFile
11
+ from pyannote.audio.pipelines import VoiceActivityDetection
12
+ from pyannote.audio.pipelines.utils import PipelineModel
13
+ from pyannote.core import Annotation, Segment, SlidingWindowFeature
14
+ from tqdm import tqdm
15
+
16
+ VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
17
+
18
+
19
+ pipeline = None
20
+ pipeline_name = "pyannote/voice-activity-detection"
21
+
22
+
23
+ @torch.autocast("cuda", enabled=False)
24
+ def detect_voice_activity(waveform, pipe=None):
25
+ """16khz"""
26
+ waveform = waveform.flatten().float()[None]
27
+ global pipeline
28
+
29
+ if pipe is not None:
30
+ pipeline = pipe
31
+ elif pipeline is None:
32
+ pipeline = Pipeline.from_pretrained(pipeline_name)
33
+ initial_params = {
34
+ "onset": 0.8,
35
+ "offset": 0.5,
36
+ "min_duration_on": 0,
37
+ "min_duration_off": 0.0,
38
+ }
39
+ pipeline.instantiate(initial_params)
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+ pipeline = pipeline.to(device)
42
+
43
+ vad = pipeline({"waveform": waveform, "sample_rate": 16000})
44
+ segments = [
45
+ (segment.start, segment.end) for segment in vad.get_timeline().support()
46
+ ]
47
+
48
+ return segments
49
+
50
+
51
+ def load_vad_model(
52
+ device,
53
+ vad_onset=0.500,
54
+ vad_offset=0.363,
55
+ use_auth_token=None,
56
+ model_fp=None,
57
+ batch_size=32,
58
+ ):
59
+ model_dir = torch.hub._get_torch_home()
60
+ os.makedirs(model_dir, exist_ok=True)
61
+ if model_fp is None:
62
+ model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
63
+ if os.path.exists(model_fp) and not os.path.isfile(model_fp):
64
+ raise RuntimeError(f"{model_fp} exists and is not a regular file")
65
+
66
+ if not os.path.isfile(model_fp):
67
+ with (
68
+ urllib.request.urlopen(VAD_SEGMENTATION_URL) as source,
69
+ open(model_fp, "wb") as output,
70
+ ):
71
+ with tqdm(
72
+ total=int(source.info().get("Content-Length")),
73
+ ncols=80,
74
+ unit="iB",
75
+ unit_scale=True,
76
+ unit_divisor=1024,
77
+ ) as loop:
78
+ while True:
79
+ buffer = source.read(8192)
80
+ if not buffer:
81
+ break
82
+
83
+ output.write(buffer)
84
+ loop.update(len(buffer))
85
+
86
+ model_bytes = open(model_fp, "rb").read()
87
+ if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split("/")[-2]:
88
+ raise RuntimeError(
89
+ "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
90
+ )
91
+
92
+ vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
93
+ hyperparameters = {
94
+ "onset": vad_onset,
95
+ "offset": vad_offset,
96
+ "min_duration_on": 0.1,
97
+ "min_duration_off": 0.1,
98
+ }
99
+ vad_pipeline = VoiceActivitySegmentation(
100
+ segmentation=vad_model, device=torch.device(device), batch_size=batch_size
101
+ )
102
+ vad_pipeline.instantiate(hyperparameters)
103
+
104
+ return vad_pipeline
105
+
106
+
107
+ class Binarize:
108
+ """Binarize detection scores using hysteresis thresholding, with min-cut operation
109
+ to ensure no segment is longer than max_duration.
110
+
111
+ Parameters
112
+ ----------
113
+ onset : float, optional
114
+ Onset threshold. Defaults to 0.5.
115
+ offset : float, optional
116
+ Offset threshold. Defaults to `onset`.
117
+ min_duration_on : float, optional
118
+ Remove active regions shorter than that many seconds. Defaults to 0s.
119
+ min_duration_off : float, optional
120
+ Fill inactive regions shorter than that many seconds. Defaults to 0s.
121
+ pad_onset : float, optional
122
+ Extend active regions by moving their start time by that many seconds.
123
+ Defaults to 0s.
124
+ pad_offset : float, optional
125
+ Extend active regions by moving their end time by that many seconds.
126
+ Defaults to 0s.
127
+ max_duration: float
128
+ The maximum length of an active segment, divides segment at timestamp with lowest score.
129
+ Reference
130
+ ---------
131
+ Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
132
+ RNN-based Voice Activity Detection", InterSpeech 2015.
133
+
134
+ Modified by Max Bain to include WhisperX's min-cut operation
135
+ https://arxiv.org/abs/2303.00747
136
+
137
+ Pyannote-audio
138
+ """
139
+
140
+ def __init__(
141
+ self,
142
+ onset: float = 0.5,
143
+ offset: float | None = None,
144
+ min_duration_on: float = 0.0,
145
+ min_duration_off: float = 0.0,
146
+ pad_onset: float = 0.0,
147
+ pad_offset: float = 0.0,
148
+ max_duration: float = float("inf"),
149
+ ):
150
+ super().__init__()
151
+
152
+ self.onset = onset
153
+ self.offset = offset or onset
154
+
155
+ self.pad_onset = pad_onset
156
+ self.pad_offset = pad_offset
157
+
158
+ self.min_duration_on = min_duration_on
159
+ self.min_duration_off = min_duration_off
160
+
161
+ self.max_duration = max_duration
162
+
163
+ def __call__(self, scores: SlidingWindowFeature) -> Annotation:
164
+ """Binarize detection scores
165
+ Parameters
166
+ ----------
167
+ scores : SlidingWindowFeature
168
+ Detection scores.
169
+ Returns
170
+ -------
171
+ active : Annotation
172
+ Binarized scores.
173
+ """
174
+
175
+ num_frames, num_classes = scores.data.shape
176
+ frames = scores.sliding_window
177
+ timestamps = [frames[i].middle for i in range(num_frames)]
178
+
179
+ # annotation meant to store 'active' regions
180
+ active = Annotation()
181
+ for k, k_scores in enumerate(scores.data.T):
182
+ label = k if scores.labels is None else scores.labels[k]
183
+
184
+ # initial state
185
+ start = timestamps[0]
186
+ is_active = k_scores[0] > self.onset
187
+ curr_scores = [k_scores[0]]
188
+ curr_timestamps = [start]
189
+ t = start
190
+ for t, y in zip(timestamps[1:], k_scores[1:], strict=False):
191
+ # currently active
192
+ if is_active:
193
+ curr_duration = t - start
194
+ if curr_duration > self.max_duration:
195
+ search_after = len(curr_scores) // 2
196
+ # divide segment
197
+ min_score_div_idx = search_after + np.argmin(
198
+ curr_scores[search_after:]
199
+ )
200
+ min_score_t = curr_timestamps[min_score_div_idx]
201
+ region = Segment(
202
+ start - self.pad_onset, min_score_t + self.pad_offset
203
+ )
204
+ active[region, k] = label
205
+ start = curr_timestamps[min_score_div_idx]
206
+ curr_scores = curr_scores[min_score_div_idx + 1 :]
207
+ curr_timestamps = curr_timestamps[min_score_div_idx + 1 :]
208
+ # switching from active to inactive
209
+ elif y < self.offset:
210
+ region = Segment(start - self.pad_onset, t + self.pad_offset)
211
+ active[region, k] = label
212
+ start = t
213
+ is_active = False
214
+ curr_scores = []
215
+ curr_timestamps = []
216
+ curr_scores.append(y)
217
+ curr_timestamps.append(t)
218
+ # currently inactive
219
+ else:
220
+ # switching from inactive to active
221
+ if y > self.onset:
222
+ start = t
223
+ is_active = True
224
+
225
+ # if active at the end, add final region
226
+ if is_active:
227
+ region = Segment(start - self.pad_onset, t + self.pad_offset)
228
+ active[region, k] = label
229
+
230
+ # because of padding, some active regions might be overlapping: merge them.
231
+ # also: fill same speaker gaps shorter than min_duration_off
232
+ if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
233
+ if self.max_duration < float("inf"):
234
+ raise NotImplementedError("This would break current max_duration param")
235
+ active = active.support(collar=self.min_duration_off)
236
+
237
+ # remove tracks shorter than min_duration_on
238
+ if self.min_duration_on > 0:
239
+ for segment, track in list(active.itertracks()):
240
+ if segment.duration < self.min_duration_on:
241
+ del active[segment, track]
242
+
243
+ return active
244
+
245
+
246
+ class VoiceActivitySegmentation(VoiceActivityDetection):
247
+ def __init__(
248
+ self,
249
+ segmentation: PipelineModel = "pyannote/segmentation",
250
+ fscore: bool = False,
251
+ use_auth_token: str | None = None,
252
+ **inference_kwargs,
253
+ ):
254
+ super().__init__(
255
+ segmentation=segmentation,
256
+ fscore=fscore,
257
+ use_auth_token=use_auth_token,
258
+ **inference_kwargs,
259
+ )
260
+
261
+ def apply(self, file: AudioFile, hook: Callable | None = None) -> Annotation:
262
+ """Apply voice activity detection
263
+
264
+ Parameters
265
+ ----------
266
+ file : AudioFile
267
+ Processed file.
268
+ hook : callable, optional
269
+ Hook called after each major step of the pipeline with the following
270
+ signature: hook("step_name", step_artefact, file=file)
271
+
272
+ Returns
273
+ -------
274
+ speech : Annotation
275
+ Speech regions.
276
+ """
277
+
278
+ # setup hook (e.g. for debugging purposes)
279
+ hook = self.setup_hook(file, hook=hook)
280
+
281
+ # apply segmentation model (only if needed)
282
+ # output shape is (num_chunks, num_frames, 1)
283
+ if self.training:
284
+ if self.CACHED_SEGMENTATION in file:
285
+ segmentations = file[self.CACHED_SEGMENTATION]
286
+ else:
287
+ segmentations = self._segmentation(file)
288
+ file[self.CACHED_SEGMENTATION] = segmentations
289
+ else:
290
+ segmentations: SlidingWindowFeature = self._segmentation(file)
291
+
292
+ return segmentations
293
+
294
+
295
+ def merge_vad(
296
+ vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0
297
+ ):
298
+ active = Annotation()
299
+ for k, vad_t in enumerate(vad_arr):
300
+ region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
301
+ active[region, k] = 1
302
+
303
+ if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
304
+ active = active.support(collar=min_duration_off)
305
+
306
+ # remove tracks shorter than min_duration_on
307
+ if min_duration_on > 0:
308
+ for segment, track in list(active.itertracks()):
309
+ if segment.duration < min_duration_on:
310
+ del active[segment, track]
311
+
312
+ active = active.for_json()
313
+ active_segs = pd.DataFrame([x["segment"] for x in active["content"]])
314
+ return active_segs
315
+
316
+
317
+ def merge_chunks(
318
+ segments,
319
+ chunk_size,
320
+ onset: float = 0.5,
321
+ offset: float | None = None,
322
+ ):
323
+ """
324
+ Merge the binarized VAD segments into chunks of at most chunk_size seconds (the merge operation described in the WhisperX paper).
325
+ """
326
+ curr_end = 0
327
+ merged_segments = []
328
+ seg_idxs = []
329
+
330
+ assert chunk_size > 0
331
+ binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
332
+ segments = binarize(segments)
333
+ segments_list = []
334
+ for speech_turn in segments.get_timeline():
335
+ segments_list.append(Segment(speech_turn.start, speech_turn.end))
336
+
337
+ if len(segments_list) == 0:
338
+ print("No active speech found in audio")
339
+ return []
340
+ # assert segments_list, "segments_list is empty."
341
+ # Make sure the starting point is the start of the segment.
342
+ curr_start = segments_list[0].start
343
+
344
+ for seg in segments_list:
345
+ if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
346
+ merged_segments.append(
347
+ {
348
+ "start": curr_start,
349
+ "end": curr_end,
350
+ }
351
+ )
352
+ curr_start = seg.start
353
+ seg_idxs = []
354
+ curr_end = seg.end
355
+ seg_idxs.append((seg.start, seg.end))
356
+
357
+ merged_segments.append(
358
+ {
359
+ "start": curr_start,
360
+ "end": curr_end,
361
+ }
362
+ )
363
+ return merged_segments