Spaces:
Running
on
Zero
Running
on
Zero
Harry Coultas Blum
commited on
Commit
·
88afac1
1
Parent(s):
a2e6acf
INIT
Browse files- app.py +385 -0
- inference.py +12 -0
- requirements.txt +17 -0
- src/vui/__init__.py +1 -0
- src/vui/config.py +41 -0
- src/vui/fluac.py +707 -0
- src/vui/inference.py +405 -0
- src/vui/model.py +445 -0
- src/vui/notebook.py +41 -0
- src/vui/patterns.py +423 -0
- src/vui/rope.py +54 -0
- src/vui/sampling.py +43 -0
- src/vui/tok.py +19 -0
- src/vui/utils.py +422 -0
- src/vui/vad.py +363 -0
app.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from vui.inference import render
|
7 |
+
from vui.model import Vui
|
8 |
+
|
9 |
+
|
10 |
+
def get_available_models():
|
11 |
+
"""Extract all CAPs static variables from Vui class that end with .pt"""
|
12 |
+
models = {}
|
13 |
+
for attr_name in dir(Vui):
|
14 |
+
if attr_name.isupper():
|
15 |
+
attr_value = getattr(Vui, attr_name)
|
16 |
+
if isinstance(attr_value, str) and attr_value.endswith(".pt"):
|
17 |
+
models[attr_name] = attr_value
|
18 |
+
return models
|
19 |
+
|
20 |
+
|
21 |
+
AVAILABLE_MODELS = get_available_models()
|
22 |
+
print(f"Available models: {list(AVAILABLE_MODELS.keys())}")
|
23 |
+
|
24 |
+
current_model = None
|
25 |
+
current_model_name = None
|
26 |
+
|
27 |
+
|
28 |
+
def load_and_warm_model(model_name):
|
29 |
+
"""Load and warm up a specific model"""
|
30 |
+
global current_model, current_model_name
|
31 |
+
|
32 |
+
if current_model_name == model_name and current_model is not None:
|
33 |
+
print(f"Model {model_name} already loaded and warmed up!")
|
34 |
+
return current_model
|
35 |
+
|
36 |
+
print(f"Loading model {model_name}...")
|
37 |
+
model_path = AVAILABLE_MODELS[model_name]
|
38 |
+
model = Vui.from_pretrained_inf(model_path).cuda()
|
39 |
+
|
40 |
+
print(f"Compiling model {model_name}...")
|
41 |
+
model.decoder = torch.compile(model.decoder, fullgraph=True)
|
42 |
+
|
43 |
+
print(f"Warming up model {model_name}...")
|
44 |
+
warmup_text = "Hello, this is a test. Let's say some random shizz"
|
45 |
+
render(
|
46 |
+
model,
|
47 |
+
warmup_text,
|
48 |
+
max_secs=10,
|
49 |
+
)
|
50 |
+
|
51 |
+
current_model = model
|
52 |
+
current_model_name = model_name
|
53 |
+
print(f"Model {model_name} loaded and warmed up successfully!")
|
54 |
+
return model
|
55 |
+
|
56 |
+
|
57 |
+
# Load default model (COHOST)
|
58 |
+
default_model = (
|
59 |
+
"COHOST" if "COHOST" in AVAILABLE_MODELS else list(AVAILABLE_MODELS.keys())[0]
|
60 |
+
)
|
61 |
+
model = load_and_warm_model(default_model)
|
62 |
+
|
63 |
+
# Preload sample 1 (index 0) with current model
|
64 |
+
print("Preloading sample 1...")
|
65 |
+
sample_1_text = """Welcome to Fluxions, the podcast where... we uh explore how technology is shaping the world around us. I'm your host, Alex.
|
66 |
+
[breath] And I'm Jamie um [laugh] today, we're diving into a [hesitate] topic that's transforming customer service uh voice technology for agents.
|
67 |
+
That's right. We're [hesitate] talking about the AI-driven tools that are making those long, frustrating customer service calls a little more bearable, for both the customer and the agents."""
|
68 |
+
|
69 |
+
sample_1_audio = render(
|
70 |
+
current_model,
|
71 |
+
sample_1_text,
|
72 |
+
)
|
73 |
+
sample_1_audio = sample_1_audio.cpu()
|
74 |
+
sample_1_audio = sample_1_audio[..., :-2000] # Trim end artifacts
|
75 |
+
preloaded_sample_1 = (model.codec.config.sample_rate, sample_1_audio.flatten().numpy())
|
76 |
+
print("Sample 1 preloaded successfully!")
|
77 |
+
print("Models ready for inference!")
|
78 |
+
|
79 |
+
# Sample texts for quick testing - keeping original examples intact
|
80 |
+
SAMPLE_TEXTS = [
|
81 |
+
"""Welcome to Fluxions, the podcast where... we uh explore how technology is shaping the world around us. I'm your host, Alex.
|
82 |
+
[breath] And I'm Jamie um [laugh] today, we're diving into a [hesitate] topic that's transforming customer service uh voice technology for agents.
|
83 |
+
That's right. We're [hesitate] talking about the AI-driven tools that are making those long, frustrating customer service calls a little more bearable, for both the customer and the agents.""",
|
84 |
+
"""Um, hey Sarah, so I just left the meeting with the, uh, rabbit focus group and they are absolutely loving the new heritage carrots! Like, I've never seen such enthusiastic thumping in my life! The purple ones are testing through the roof - apparently the flavor profile is just amazing - and they're willing to pay a premium for them! We need to, like, triple production on those immediately and maybe consider a subscription model? Anyway, gotta go, but let's touch base tomorrow about scaling this before the Easter rush hits!""",
|
85 |
+
"""What an absolute joke, like I'm really not enjoying this situation where I'm just forced to say things.""",
|
86 |
+
""" So [breath] I don't know if you've been there [breath] but I'm really pissed off.
|
87 |
+
Oh no! Why, what happened?
|
88 |
+
Well I went to this cafe hearth, and they gave me the worst toastie I've ever had, it didn't come with salad it was just raw.
|
89 |
+
Well that's awful what kind of toastie was it?
|
90 |
+
It was supposed to be a chicken bacon lettuce tomatoe, but it was fucking shite, like really bad and I honestly would have preferred to eat my own shit.
|
91 |
+
[laugh] well, it must have been awful for you, I'm sorry to hear that, why don't we move on to brighter topics, like the good old weather?""",
|
92 |
+
]
|
93 |
+
|
94 |
+
|
95 |
+
def text_to_speech(text, temperature=0.5, top_k=100, top_p=None, max_duration=60):
|
96 |
+
"""
|
97 |
+
Convert text to speech using the current Vui model
|
98 |
+
|
99 |
+
Args:
|
100 |
+
text (str): Input text to convert to speech
|
101 |
+
temperature (float): Sampling temperature (0.1-1.0)
|
102 |
+
top_k (int): Top-k sampling parameter
|
103 |
+
top_p (float): Top-p sampling parameter (None to disable)
|
104 |
+
max_duration (int): Maximum audio duration in seconds
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
tuple: (sample_rate, audio_array) for Gradio audio output
|
108 |
+
"""
|
109 |
+
if not text.strip():
|
110 |
+
return None, "Please enter some text to convert to speech."
|
111 |
+
|
112 |
+
if current_model is None:
|
113 |
+
return None, "No model loaded. Please select a model first."
|
114 |
+
|
115 |
+
print(f"Generating speech for: {text[:50]}... using model {current_model_name}")
|
116 |
+
|
117 |
+
# Generate speech using render
|
118 |
+
t1 = time.perf_counter()
|
119 |
+
result = render(
|
120 |
+
current_model,
|
121 |
+
text.strip(),
|
122 |
+
temperature=temperature,
|
123 |
+
top_k=top_k,
|
124 |
+
top_p=top_p,
|
125 |
+
max_secs=max_duration,
|
126 |
+
)
|
127 |
+
|
128 |
+
# Long text: render returns (codes, text, audio) tuple
|
129 |
+
waveform = result
|
130 |
+
|
131 |
+
# waveform is already decoded audio from generate_infinite
|
132 |
+
waveform = waveform.cpu()
|
133 |
+
sr = current_model.codec.config.sample_rate
|
134 |
+
|
135 |
+
# Calculate generation speed
|
136 |
+
generation_time = time.perf_counter() - t1
|
137 |
+
audio_duration = waveform.shape[-1] / sr
|
138 |
+
speed_factor = audio_duration / generation_time
|
139 |
+
|
140 |
+
# Trim end artifacts if needed
|
141 |
+
if waveform.shape[-1] > 2000:
|
142 |
+
waveform = waveform[..., :-2000]
|
143 |
+
|
144 |
+
# Convert to numpy array for Gradio
|
145 |
+
audio_array = waveform.flatten().numpy()
|
146 |
+
|
147 |
+
info = f"Generated {audio_duration:.1f}s of audio in {generation_time:.1f}s ({speed_factor:.1f}x realtime) with {current_model_name}"
|
148 |
+
print(info)
|
149 |
+
|
150 |
+
return (sr, audio_array), info
|
151 |
+
|
152 |
+
|
153 |
+
def change_model(model_name):
|
154 |
+
"""Change the active model and return status"""
|
155 |
+
try:
|
156 |
+
load_and_warm_model(model_name)
|
157 |
+
return f"Successfully loaded and warmed up model: {model_name}"
|
158 |
+
except Exception as e:
|
159 |
+
return f"Error loading model {model_name}: {str(e)}"
|
160 |
+
|
161 |
+
|
162 |
+
def load_sample_text(sample_index):
|
163 |
+
"""Load a sample text for quick testing"""
|
164 |
+
if 0 <= sample_index < len(SAMPLE_TEXTS):
|
165 |
+
return SAMPLE_TEXTS[sample_index]
|
166 |
+
return ""
|
167 |
+
|
168 |
+
|
169 |
+
# Create Gradio interface
|
170 |
+
with gr.Blocks(
|
171 |
+
title="Vui",
|
172 |
+
theme=gr.themes.Soft(),
|
173 |
+
head="""
|
174 |
+
<script>
|
175 |
+
document.addEventListener('DOMContentLoaded', function() {
|
176 |
+
// Add keyboard shortcuts
|
177 |
+
document.addEventListener('keydown', function(e) {
|
178 |
+
// Ctrl/Cmd + Enter to generate (but not when Shift is pressed)
|
179 |
+
if ((e.ctrlKey) && e.key === 'Enter' && !e.shiftKey) {
|
180 |
+
e.preventDefault();
|
181 |
+
const generateBtn = document.querySelector('button[variant="primary"]');
|
182 |
+
if (generateBtn && !generateBtn.disabled) {
|
183 |
+
generateBtn.click();
|
184 |
+
}
|
185 |
+
}
|
186 |
+
else if ((e.ctrlKey) && e.code === 'Space') {
|
187 |
+
e.preventDefault();
|
188 |
+
const audioElement = document.querySelector('audio');
|
189 |
+
if (audioElement) {
|
190 |
+
if (audioElement.paused) {
|
191 |
+
audioElement.play();
|
192 |
+
} else {
|
193 |
+
audioElement.pause();
|
194 |
+
}
|
195 |
+
}
|
196 |
+
}
|
197 |
+
});
|
198 |
+
|
199 |
+
// Auto-play audio when it's updated
|
200 |
+
const observer = new MutationObserver(function(mutations) {
|
201 |
+
mutations.forEach(function(mutation) {
|
202 |
+
if (mutation.type === 'childList') {
|
203 |
+
const audioElements = document.querySelectorAll('audio');
|
204 |
+
audioElements.forEach(function(audio) {
|
205 |
+
if (audio.src && !audio.dataset.hasAutoplayListener) {
|
206 |
+
audio.dataset.hasAutoplayListener = 'true';
|
207 |
+
audio.addEventListener('loadeddata', function() {
|
208 |
+
// Small delay to ensure audio is ready
|
209 |
+
setTimeout(() => {
|
210 |
+
audio.play().catch(e => {
|
211 |
+
console.log('Autoplay prevented by browser:', e);
|
212 |
+
});
|
213 |
+
}, 100);
|
214 |
+
});
|
215 |
+
}
|
216 |
+
});
|
217 |
+
}
|
218 |
+
});
|
219 |
+
});
|
220 |
+
|
221 |
+
observer.observe(document.body, {
|
222 |
+
childList: true,
|
223 |
+
subtree: true
|
224 |
+
});
|
225 |
+
|
226 |
+
});
|
227 |
+
</script>
|
228 |
+
""",
|
229 |
+
) as demo:
|
230 |
+
|
231 |
+
gr.Markdown(
|
232 |
+
"**Keyboard Shortcuts:** `Ctrl + Enter` to generate` or Ctrl + Space to pause"
|
233 |
+
)
|
234 |
+
|
235 |
+
with gr.Row():
|
236 |
+
with gr.Column(scale=2):
|
237 |
+
# Model selector
|
238 |
+
model_dropdown = gr.Dropdown(
|
239 |
+
choices=list(AVAILABLE_MODELS.keys()),
|
240 |
+
value=default_model,
|
241 |
+
label=None,
|
242 |
+
info="Select a voice model",
|
243 |
+
)
|
244 |
+
|
245 |
+
# Model status
|
246 |
+
model_status = gr.Textbox(
|
247 |
+
label=None,
|
248 |
+
value=f"Model {default_model} loaded and ready",
|
249 |
+
interactive=False,
|
250 |
+
lines=1,
|
251 |
+
)
|
252 |
+
|
253 |
+
# Text input
|
254 |
+
text_input = gr.Textbox(
|
255 |
+
label=None,
|
256 |
+
placeholder="Enter the text you want to convert to speech...",
|
257 |
+
lines=5,
|
258 |
+
max_lines=10,
|
259 |
+
)
|
260 |
+
|
261 |
+
with gr.Column(scale=1):
|
262 |
+
# Audio output with autoplay
|
263 |
+
audio_output = gr.Audio(
|
264 |
+
label="Generated Speech", type="numpy", autoplay=True # Enable autoplay
|
265 |
+
)
|
266 |
+
|
267 |
+
# Info output
|
268 |
+
info_output = gr.Textbox(
|
269 |
+
label="Generation Info", lines=3, interactive=False
|
270 |
+
)
|
271 |
+
|
272 |
+
with gr.Row():
|
273 |
+
with gr.Column(scale=2):
|
274 |
+
|
275 |
+
# Sample text buttons
|
276 |
+
gr.Markdown("**Quick samples:**")
|
277 |
+
with gr.Row():
|
278 |
+
sample_btns = []
|
279 |
+
for i, sample in enumerate(SAMPLE_TEXTS):
|
280 |
+
btn = gr.Button(f"Sample {i+1}", size="sm")
|
281 |
+
if i == 0: # Sample 1 (index 0) - use preloaded audio
|
282 |
+
|
283 |
+
def load_preloaded_sample_1():
|
284 |
+
return (
|
285 |
+
SAMPLE_TEXTS[0],
|
286 |
+
preloaded_sample_1,
|
287 |
+
"Preloaded sample 1 audio",
|
288 |
+
)
|
289 |
+
|
290 |
+
btn.click(
|
291 |
+
fn=load_preloaded_sample_1,
|
292 |
+
outputs=[text_input, audio_output, info_output],
|
293 |
+
)
|
294 |
+
else:
|
295 |
+
btn.click(
|
296 |
+
fn=lambda idx=i: SAMPLE_TEXTS[idx], outputs=text_input
|
297 |
+
)
|
298 |
+
|
299 |
+
# Generation parameters
|
300 |
+
with gr.Accordion("Advanced Settings", open=False):
|
301 |
+
temperature = gr.Slider(
|
302 |
+
minimum=0.1,
|
303 |
+
maximum=1.0,
|
304 |
+
value=0.5,
|
305 |
+
step=0.1,
|
306 |
+
label="Temperature",
|
307 |
+
info="Higher values = more varied speech",
|
308 |
+
)
|
309 |
+
|
310 |
+
top_k = gr.Slider(
|
311 |
+
minimum=1,
|
312 |
+
maximum=200,
|
313 |
+
value=100,
|
314 |
+
step=1,
|
315 |
+
label="Top-K",
|
316 |
+
info="Number of top tokens to consider",
|
317 |
+
)
|
318 |
+
|
319 |
+
use_top_p = gr.Checkbox(label="Use Top-P sampling", value=False)
|
320 |
+
top_p = gr.Slider(
|
321 |
+
minimum=0.1,
|
322 |
+
maximum=1.0,
|
323 |
+
value=0.9,
|
324 |
+
step=0.05,
|
325 |
+
label="Top-P",
|
326 |
+
info="Cumulative probability threshold",
|
327 |
+
visible=False,
|
328 |
+
)
|
329 |
+
|
330 |
+
max_duration = gr.Slider(
|
331 |
+
minimum=5,
|
332 |
+
maximum=120,
|
333 |
+
value=60,
|
334 |
+
step=5,
|
335 |
+
label="Max Duration (seconds)",
|
336 |
+
info="Maximum length of generated audio",
|
337 |
+
)
|
338 |
+
|
339 |
+
# Show/hide top_p based on checkbox
|
340 |
+
use_top_p.change(
|
341 |
+
fn=lambda x: gr.update(visible=x), inputs=use_top_p, outputs=top_p
|
342 |
+
)
|
343 |
+
|
344 |
+
# Generate button
|
345 |
+
generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
|
346 |
+
|
347 |
+
# Examples section
|
348 |
+
gr.Markdown("## 📝 Example Texts")
|
349 |
+
with gr.Accordion("View example texts", open=False):
|
350 |
+
for i, sample in enumerate(SAMPLE_TEXTS):
|
351 |
+
gr.Markdown(f"**Sample {i+1}:** {sample}")
|
352 |
+
|
353 |
+
# Connect the model change function
|
354 |
+
model_dropdown.change(fn=change_model, inputs=model_dropdown, outputs=model_status)
|
355 |
+
|
356 |
+
# Connect the generate function
|
357 |
+
def generate_wrapper(text, temp, k, use_p, p, duration):
|
358 |
+
top_p_val = p if use_p else None
|
359 |
+
return text_to_speech(text, temp, k, top_p_val, duration)
|
360 |
+
|
361 |
+
generate_btn.click(
|
362 |
+
fn=generate_wrapper,
|
363 |
+
inputs=[text_input, temperature, top_k, use_top_p, top_p, max_duration],
|
364 |
+
outputs=[audio_output, info_output],
|
365 |
+
)
|
366 |
+
|
367 |
+
# Also allow Enter key to generate
|
368 |
+
text_input.submit(
|
369 |
+
fn=generate_wrapper,
|
370 |
+
inputs=[text_input, temperature, top_k, use_top_p, top_p, max_duration],
|
371 |
+
outputs=[audio_output, info_output],
|
372 |
+
)
|
373 |
+
|
374 |
+
# Auto-load sample 1 on startup
|
375 |
+
demo.load(
|
376 |
+
fn=lambda: (
|
377 |
+
SAMPLE_TEXTS[0],
|
378 |
+
preloaded_sample_1,
|
379 |
+
"Sample 1 preloaded and ready!",
|
380 |
+
),
|
381 |
+
outputs=[text_input, audio_output, info_output],
|
382 |
+
)
|
383 |
+
|
384 |
+
if __name__ == "__main__":
|
385 |
+
demo.launch(server_name="0.0.0.0", share=True)
|
inference.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchaudio
|
2 |
+
|
3 |
+
from vui.inference import render
|
4 |
+
from vui.model import Vui
|
5 |
+
|
6 |
+
model = Vui.from_pretrained().cuda()
|
7 |
+
waveform = render(
|
8 |
+
model,
|
9 |
+
"Hey, here is some random stuff, usually something quite long as the shorter the text the less likely the model can cope!",
|
10 |
+
)
|
11 |
+
print(waveform.shape)
|
12 |
+
torchaudio.save("out.opus", waveform[0], 22050)
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops
|
2 |
+
huggingface_hub[hf_transfer]
|
3 |
+
inflect
|
4 |
+
gradio
|
5 |
+
numba
|
6 |
+
numpy
|
7 |
+
openai-whisper
|
8 |
+
feedparser
|
9 |
+
pydantic
|
10 |
+
pyannote.audio
|
11 |
+
soundfile
|
12 |
+
sphn
|
13 |
+
tiktoken
|
14 |
+
torch
|
15 |
+
torchaudio
|
16 |
+
tqdm
|
17 |
+
transformers
|
src/vui/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = "0.1.0"
|
src/vui/config.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
|
6 |
+
class VuiConfig(BaseModel):
|
7 |
+
max_text_tokens: int = 100
|
8 |
+
text_size: int = -1
|
9 |
+
max_audio_tokens: int = 100
|
10 |
+
|
11 |
+
n_quantizers: int = 9
|
12 |
+
codebook_size: int = 1000
|
13 |
+
special_token_id: int = 1000
|
14 |
+
audio_eos_id: int = 1000 + 1
|
15 |
+
audio_pad_id: int = 1000 + 1 + 1
|
16 |
+
d_model: int = 512
|
17 |
+
n_layers: int = 6
|
18 |
+
n_heads: int = 8
|
19 |
+
bias: bool = False
|
20 |
+
dropout: float = 0.0
|
21 |
+
use_rotary_emb: bool = True
|
22 |
+
rope_dim: int | None = None
|
23 |
+
rope_theta: float = 10_000.0
|
24 |
+
rope_theta_rescale_factor: float = 1.0
|
25 |
+
|
26 |
+
|
27 |
+
class Config(BaseModel):
|
28 |
+
name: str = "base"
|
29 |
+
|
30 |
+
checkpoint: str | dict | None = None
|
31 |
+
|
32 |
+
model: VuiConfig = VuiConfig()
|
33 |
+
|
34 |
+
|
35 |
+
ALL = []
|
36 |
+
current_module = sys.modules[__name__]
|
37 |
+
for name in dir(current_module):
|
38 |
+
if name.isupper() and isinstance(getattr(current_module, name), Config):
|
39 |
+
ALL.append(getattr(current_module, name))
|
40 |
+
|
41 |
+
CONFIGS = {v.name: v for v in ALL}
|
src/vui/fluac.py
ADDED
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from contextlib import nullcontext
|
3 |
+
from functools import partial, wraps
|
4 |
+
from os import path
|
5 |
+
from typing import List, Tuple
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from einops import pack, rearrange, unpack
|
11 |
+
from einops.layers.torch import Rearrange
|
12 |
+
from pydantic import BaseModel
|
13 |
+
from torch import Tensor, int32
|
14 |
+
from torch.amp import autocast
|
15 |
+
from torch.nn import Module
|
16 |
+
from torch.nn.utils.parametrizations import weight_norm
|
17 |
+
|
18 |
+
from vui.utils import decompile_state_dict
|
19 |
+
|
20 |
+
|
21 |
+
def exists(v):
|
22 |
+
return v is not None
|
23 |
+
|
24 |
+
|
25 |
+
def default(*args):
|
26 |
+
for arg in args:
|
27 |
+
if exists(arg):
|
28 |
+
return arg
|
29 |
+
return None
|
30 |
+
|
31 |
+
|
32 |
+
def maybe(fn):
|
33 |
+
@wraps(fn)
|
34 |
+
def inner(x, *args, **kwargs):
|
35 |
+
if not exists(x):
|
36 |
+
return x
|
37 |
+
return fn(x, *args, **kwargs)
|
38 |
+
|
39 |
+
return inner
|
40 |
+
|
41 |
+
|
42 |
+
def pack_one(t, pattern):
|
43 |
+
return pack([t], pattern)
|
44 |
+
|
45 |
+
|
46 |
+
def unpack_one(t, ps, pattern):
|
47 |
+
return unpack(t, ps, pattern)[0]
|
48 |
+
|
49 |
+
|
50 |
+
def round_ste(z: Tensor) -> Tensor:
|
51 |
+
"""Round with straight through gradients."""
|
52 |
+
zhat = z.round()
|
53 |
+
return z + (zhat - z).detach()
|
54 |
+
|
55 |
+
|
56 |
+
class FSQ(Module):
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
levels: List[int],
|
60 |
+
dim: int | None = None,
|
61 |
+
num_codebooks: int = 1,
|
62 |
+
keep_num_codebooks_dim: bool | None = None,
|
63 |
+
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
|
64 |
+
channel_first: bool = True,
|
65 |
+
projection_has_bias: bool = True,
|
66 |
+
return_indices=True,
|
67 |
+
force_quantization_f32: bool = True,
|
68 |
+
):
|
69 |
+
super().__init__()
|
70 |
+
|
71 |
+
_levels = torch.tensor(levels, dtype=int32)
|
72 |
+
self.register_buffer("_levels", _levels, persistent=False)
|
73 |
+
|
74 |
+
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
|
75 |
+
self.register_buffer("_basis", _basis, persistent=False)
|
76 |
+
|
77 |
+
codebook_dim = len(levels)
|
78 |
+
self.codebook_dim = codebook_dim
|
79 |
+
|
80 |
+
effective_codebook_dim = codebook_dim * num_codebooks
|
81 |
+
self.num_codebooks = num_codebooks
|
82 |
+
self.effective_codebook_dim = effective_codebook_dim
|
83 |
+
|
84 |
+
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
85 |
+
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
86 |
+
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
87 |
+
|
88 |
+
self.dim = default(dim, len(_levels) * num_codebooks)
|
89 |
+
|
90 |
+
self.channel_first = channel_first
|
91 |
+
|
92 |
+
has_projections = self.dim != effective_codebook_dim
|
93 |
+
self.project_in = (
|
94 |
+
nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias)
|
95 |
+
if has_projections
|
96 |
+
else nn.Identity()
|
97 |
+
)
|
98 |
+
self.project_out = (
|
99 |
+
nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias)
|
100 |
+
if has_projections
|
101 |
+
else nn.Identity()
|
102 |
+
)
|
103 |
+
|
104 |
+
self.has_projections = has_projections
|
105 |
+
|
106 |
+
self.return_indices = return_indices
|
107 |
+
if return_indices:
|
108 |
+
self.codebook_size = self._levels.prod().item()
|
109 |
+
implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
|
110 |
+
self.register_buffer(
|
111 |
+
"implicit_codebook", implicit_codebook, persistent=False
|
112 |
+
)
|
113 |
+
|
114 |
+
self.allowed_dtypes = allowed_dtypes
|
115 |
+
self.force_quantization_f32 = force_quantization_f32
|
116 |
+
|
117 |
+
def bound(self, z, eps: float = 1e-3):
|
118 |
+
"""Bound `z`, an array of shape (..., d)."""
|
119 |
+
half_l = (self._levels - 1) * (1 + eps) / 2
|
120 |
+
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
|
121 |
+
shift = (offset / half_l).atanh()
|
122 |
+
return (z + shift).tanh() * half_l - offset
|
123 |
+
|
124 |
+
def quantize(self, z):
|
125 |
+
"""Quantizes z, returns quantized zhat, same shape as z."""
|
126 |
+
quantized = round_ste(self.bound(z))
|
127 |
+
half_width = self._levels // 2 # Renormalize to [-1, 1].
|
128 |
+
return quantized / half_width
|
129 |
+
|
130 |
+
def _scale_and_shift(self, zhat_normalized):
|
131 |
+
half_width = self._levels // 2
|
132 |
+
return (zhat_normalized * half_width) + half_width
|
133 |
+
|
134 |
+
def _scale_and_shift_inverse(self, zhat):
|
135 |
+
half_width = self._levels // 2
|
136 |
+
return (zhat - half_width) / half_width
|
137 |
+
|
138 |
+
def _indices_to_codes(self, indices):
|
139 |
+
level_indices = self.indices_to_level_indices(indices)
|
140 |
+
codes = self._scale_and_shift_inverse(level_indices)
|
141 |
+
return codes
|
142 |
+
|
143 |
+
def codes_to_indices(self, zhat):
|
144 |
+
"""Converts a `code` to an index in the codebook."""
|
145 |
+
assert zhat.shape[-1] == self.codebook_dim
|
146 |
+
zhat = self._scale_and_shift(zhat)
|
147 |
+
return (zhat * self._basis).sum(dim=-1).to(int32)
|
148 |
+
|
149 |
+
def indices_to_level_indices(self, indices):
|
150 |
+
"""Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings"""
|
151 |
+
indices = rearrange(indices, "... -> ... 1")
|
152 |
+
codes_non_centered = (indices // self._basis) % self._levels
|
153 |
+
return codes_non_centered
|
154 |
+
|
155 |
+
def indices_to_codes(self, indices):
|
156 |
+
"""Inverse of `codes_to_indices`."""
|
157 |
+
assert exists(indices)
|
158 |
+
|
159 |
+
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
160 |
+
|
161 |
+
codes = self._indices_to_codes(indices)
|
162 |
+
|
163 |
+
if self.keep_num_codebooks_dim:
|
164 |
+
codes = rearrange(codes, "... c d -> ... (c d)")
|
165 |
+
|
166 |
+
codes = self.project_out(codes)
|
167 |
+
|
168 |
+
if is_img_or_video or self.channel_first:
|
169 |
+
codes = rearrange(codes, "b ... d -> b d ...")
|
170 |
+
|
171 |
+
return codes
|
172 |
+
|
173 |
+
def forward(self, z: Tensor):
|
174 |
+
"""
|
175 |
+
einstein notation
|
176 |
+
b - batch
|
177 |
+
n - sequence (or flattened spatial dimensions)
|
178 |
+
d - feature dimension
|
179 |
+
c - number of codebook dim
|
180 |
+
"""
|
181 |
+
device_type = z.device.type
|
182 |
+
|
183 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
184 |
+
if self.channel_first:
|
185 |
+
z = rearrange(z, "b d ... -> b ... d")
|
186 |
+
z, ps = pack_one(z, "b * d")
|
187 |
+
|
188 |
+
assert (
|
189 |
+
z.shape[-1] == self.dim
|
190 |
+
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
191 |
+
|
192 |
+
z = self.project_in(z)
|
193 |
+
|
194 |
+
z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
|
195 |
+
|
196 |
+
# whether to force quantization step to be full precision or not
|
197 |
+
|
198 |
+
force_f32 = self.force_quantization_f32
|
199 |
+
quantization_context = (
|
200 |
+
partial(autocast, device_type=device_type, enabled=False)
|
201 |
+
if force_f32
|
202 |
+
else nullcontext
|
203 |
+
)
|
204 |
+
|
205 |
+
with quantization_context():
|
206 |
+
orig_dtype = z.dtype
|
207 |
+
|
208 |
+
if force_f32 and orig_dtype not in self.allowed_dtypes:
|
209 |
+
z = z.float()
|
210 |
+
|
211 |
+
codes = self.quantize(z)
|
212 |
+
|
213 |
+
# returning indices could be optional
|
214 |
+
|
215 |
+
indices = None
|
216 |
+
|
217 |
+
if self.return_indices:
|
218 |
+
indices = self.codes_to_indices(codes)
|
219 |
+
|
220 |
+
codes = rearrange(codes, "b n c d -> b n (c d)")
|
221 |
+
|
222 |
+
codes = codes.type(orig_dtype)
|
223 |
+
|
224 |
+
# project out
|
225 |
+
|
226 |
+
out = self.project_out(codes)
|
227 |
+
|
228 |
+
# reconstitute image or video dimensions
|
229 |
+
|
230 |
+
if self.channel_first:
|
231 |
+
out = unpack_one(out, ps, "b * d")
|
232 |
+
out = rearrange(out, "b ... d -> b d ...")
|
233 |
+
|
234 |
+
indices = maybe(unpack_one)(indices, ps, "b * c")
|
235 |
+
|
236 |
+
if not self.keep_num_codebooks_dim and self.return_indices:
|
237 |
+
indices = maybe(rearrange)(indices, "... 1 -> ...")
|
238 |
+
|
239 |
+
# return quantized output and indices
|
240 |
+
|
241 |
+
return out, indices
|
242 |
+
|
243 |
+
|
244 |
+
def WNConv1d(*args, **kwargs):
|
245 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
246 |
+
|
247 |
+
|
248 |
+
def WNConvTranspose1d(*args, **kwargs):
|
249 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
250 |
+
|
251 |
+
|
252 |
+
# Scripting this brings model speed up 1.4x
|
253 |
+
@torch.jit.script
|
254 |
+
def snake(x, alpha):
|
255 |
+
shape = x.shape
|
256 |
+
x = x.reshape(shape[0], shape[1], -1)
|
257 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
258 |
+
x = x.reshape(shape)
|
259 |
+
return x
|
260 |
+
|
261 |
+
|
262 |
+
class Snake1d(nn.Module):
|
263 |
+
def __init__(self, channels):
|
264 |
+
super().__init__()
|
265 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
266 |
+
|
267 |
+
def forward(self, x):
|
268 |
+
return snake(x, self.alpha)
|
269 |
+
|
270 |
+
|
271 |
+
def init_weights(m):
|
272 |
+
if isinstance(m, nn.Conv1d):
|
273 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
274 |
+
nn.init.constant_(m.bias, 0)
|
275 |
+
|
276 |
+
|
277 |
+
class ResidualUnit(nn.Module):
|
278 |
+
def __init__(self, dim: int = 16, dilation: int = 1):
|
279 |
+
super().__init__()
|
280 |
+
pad = ((7 - 1) * dilation) // 2
|
281 |
+
self.block = nn.Sequential(
|
282 |
+
Snake1d(dim),
|
283 |
+
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
284 |
+
Snake1d(dim),
|
285 |
+
WNConv1d(dim, dim, kernel_size=1),
|
286 |
+
)
|
287 |
+
|
288 |
+
def forward(self, x):
|
289 |
+
y = self.block(x)
|
290 |
+
pad = (x.shape[-1] - y.shape[-1]) // 2
|
291 |
+
if pad > 0:
|
292 |
+
x = x[..., pad:-pad]
|
293 |
+
return x + y
|
294 |
+
|
295 |
+
|
296 |
+
class EncoderBlock(nn.Module):
|
297 |
+
def __init__(self, dim: int = 16, stride: int = 1):
|
298 |
+
super().__init__()
|
299 |
+
self.block = nn.Sequential(
|
300 |
+
ResidualUnit(dim // 2, dilation=1),
|
301 |
+
ResidualUnit(dim // 2, dilation=3),
|
302 |
+
ResidualUnit(dim // 2, dilation=9),
|
303 |
+
Snake1d(dim // 2),
|
304 |
+
WNConv1d(
|
305 |
+
dim // 2,
|
306 |
+
dim,
|
307 |
+
kernel_size=2 * stride,
|
308 |
+
stride=stride,
|
309 |
+
padding=math.ceil(stride / 2),
|
310 |
+
),
|
311 |
+
)
|
312 |
+
|
313 |
+
def forward(self, x):
|
314 |
+
return self.block(x)
|
315 |
+
|
316 |
+
|
317 |
+
class Encoder(nn.Module):
|
318 |
+
def __init__(
|
319 |
+
self,
|
320 |
+
d_model: int = 64,
|
321 |
+
strides: list = [2, 4, 8, 8],
|
322 |
+
d_latent: int = 64,
|
323 |
+
):
|
324 |
+
super().__init__()
|
325 |
+
# Create first convolution
|
326 |
+
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
327 |
+
|
328 |
+
# Create EncoderBlocks that double channels as they downsample by `stride`
|
329 |
+
for stride in strides:
|
330 |
+
d_model *= 2
|
331 |
+
self.block += [EncoderBlock(d_model, stride=stride)]
|
332 |
+
|
333 |
+
# Create last convolution
|
334 |
+
self.block += [
|
335 |
+
Snake1d(d_model),
|
336 |
+
WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
|
337 |
+
]
|
338 |
+
|
339 |
+
# Wrap black into nn.Sequential
|
340 |
+
self.block = nn.Sequential(*self.block)
|
341 |
+
self.enc_dim = d_model
|
342 |
+
|
343 |
+
def forward(self, x):
|
344 |
+
return self.block(x)
|
345 |
+
|
346 |
+
|
347 |
+
class DecoderBlock(nn.Module):
|
348 |
+
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
|
349 |
+
super().__init__()
|
350 |
+
self.block = nn.Sequential(
|
351 |
+
Snake1d(input_dim),
|
352 |
+
WNConvTranspose1d(
|
353 |
+
input_dim,
|
354 |
+
output_dim,
|
355 |
+
kernel_size=2 * stride,
|
356 |
+
stride=stride,
|
357 |
+
padding=math.ceil(stride / 2),
|
358 |
+
),
|
359 |
+
ResidualUnit(output_dim, dilation=1),
|
360 |
+
ResidualUnit(output_dim, dilation=3),
|
361 |
+
ResidualUnit(output_dim, dilation=9),
|
362 |
+
)
|
363 |
+
|
364 |
+
def forward(self, x):
|
365 |
+
return self.block(x)
|
366 |
+
|
367 |
+
|
368 |
+
class Decoder(nn.Module):
|
369 |
+
def __init__(
|
370 |
+
self,
|
371 |
+
input_channel: int,
|
372 |
+
channels: int,
|
373 |
+
rates: list[int],
|
374 |
+
d_out: int = 1,
|
375 |
+
):
|
376 |
+
super().__init__()
|
377 |
+
|
378 |
+
# Add first conv layer
|
379 |
+
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
380 |
+
|
381 |
+
# Add upsampling + MRF blocks
|
382 |
+
for i, stride in enumerate(rates):
|
383 |
+
input_dim = channels // 2**i
|
384 |
+
output_dim = channels // 2 ** (i + 1)
|
385 |
+
layers += [DecoderBlock(input_dim, output_dim, stride)]
|
386 |
+
|
387 |
+
# Add final conv layer
|
388 |
+
layers += [
|
389 |
+
Snake1d(output_dim),
|
390 |
+
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
391 |
+
nn.Tanh(),
|
392 |
+
]
|
393 |
+
|
394 |
+
self.model = nn.Sequential(*layers)
|
395 |
+
|
396 |
+
# @torch.compile(dynamic=True)
|
397 |
+
def forward(self, z: Tensor):
|
398 |
+
return self.model(z)
|
399 |
+
|
400 |
+
|
401 |
+
class FiniteScalarQuantize(nn.Module):
|
402 |
+
def __init__(
|
403 |
+
self, latent_dim: int, levels: list[int], *, stride: int = 1, mlp: bool = False
|
404 |
+
):
|
405 |
+
super().__init__()
|
406 |
+
|
407 |
+
self.stride = stride
|
408 |
+
|
409 |
+
codebook_dim = len(levels)
|
410 |
+
|
411 |
+
self.in_proj = WNConv1d(latent_dim, codebook_dim, kernel_size=1)
|
412 |
+
self.quantize = FSQ(levels=levels, channel_first=True)
|
413 |
+
self.out_proj = WNConv1d(codebook_dim, latent_dim, kernel_size=1)
|
414 |
+
|
415 |
+
if mlp:
|
416 |
+
self.mlp = nn.Sequential(
|
417 |
+
Rearrange("B C T -> B T C"),
|
418 |
+
nn.Linear(latent_dim, 4 * latent_dim),
|
419 |
+
nn.GELU(),
|
420 |
+
nn.Linear(4 * latent_dim, latent_dim),
|
421 |
+
Rearrange("B T C -> B C T"),
|
422 |
+
)
|
423 |
+
else:
|
424 |
+
self.mlp = None
|
425 |
+
|
426 |
+
def from_indices(self, indices: Tensor):
|
427 |
+
B, T = indices.size()
|
428 |
+
z_q = self.quantize.indices_to_codes(indices)
|
429 |
+
z_q = self.out_proj(z_q)
|
430 |
+
return z_q
|
431 |
+
|
432 |
+
def forward(self, z: Tensor, *args):
|
433 |
+
if self.stride > 1:
|
434 |
+
z = F.avg_pool1d(z, self.stride, stride=self.stride)
|
435 |
+
|
436 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
437 |
+
|
438 |
+
# we're channels first
|
439 |
+
# scale = scale.unsqueeze(-1)
|
440 |
+
|
441 |
+
# z_e = z_e / scale
|
442 |
+
z_q, indices = self.quantize(z_e)
|
443 |
+
# z_q = z_q * scale
|
444 |
+
|
445 |
+
z_q = self.out_proj(z_q)
|
446 |
+
|
447 |
+
if self.stride > 1:
|
448 |
+
z_e = z_e.repeat_interleave(self.stride, dim=-1)
|
449 |
+
z_q = z_q.repeat_interleave(self.stride, dim=-1)
|
450 |
+
indices = indices.repeat_interleave(self.stride, dim=-1)
|
451 |
+
|
452 |
+
if self.mlp is not None:
|
453 |
+
z_q = self.mlp(z_q)
|
454 |
+
|
455 |
+
return z_q, indices, z_e
|
456 |
+
|
457 |
+
|
458 |
+
class ResidualFiniteScalarQuantize(nn.Module):
|
459 |
+
def __init__(
|
460 |
+
self,
|
461 |
+
*,
|
462 |
+
latent_dim: int,
|
463 |
+
n_quantizers: int,
|
464 |
+
levels: list[int],
|
465 |
+
strides: list[int] | None = None,
|
466 |
+
quantizer_dropout: float = 0.0,
|
467 |
+
mlp: bool = False,
|
468 |
+
):
|
469 |
+
super().__init__()
|
470 |
+
|
471 |
+
self.n_quantizers = n_quantizers
|
472 |
+
self.quantizer_dropout = quantizer_dropout
|
473 |
+
|
474 |
+
strides = [1] * n_quantizers if strides is None else strides
|
475 |
+
|
476 |
+
assert (
|
477 |
+
len(strides) == n_quantizers
|
478 |
+
), "Strides must be provided for each codebook"
|
479 |
+
|
480 |
+
scales = []
|
481 |
+
quantizers = []
|
482 |
+
levels_tensor = torch.tensor(levels, dtype=torch.float32)
|
483 |
+
|
484 |
+
for i in range(n_quantizers):
|
485 |
+
scales.append((levels_tensor - 1) ** -i)
|
486 |
+
quantizers.append(
|
487 |
+
FiniteScalarQuantize(
|
488 |
+
latent_dim=latent_dim, levels=levels, stride=strides[i], mlp=mlp
|
489 |
+
)
|
490 |
+
)
|
491 |
+
|
492 |
+
self.quantizers = nn.ModuleList(quantizers)
|
493 |
+
|
494 |
+
self.register_buffer("scales", torch.stack(scales), persistent=False)
|
495 |
+
|
496 |
+
codebooks = [
|
497 |
+
quantizer.quantize.implicit_codebook for quantizer in self.quantizers
|
498 |
+
]
|
499 |
+
self.codebooks = torch.stack(codebooks, dim=0)
|
500 |
+
|
501 |
+
def from_indices(self, indices: Tensor):
|
502 |
+
B, Q, T = indices.size()
|
503 |
+
|
504 |
+
z_q = 0.0
|
505 |
+
|
506 |
+
for i, quantizer in enumerate(self.quantizers):
|
507 |
+
z_q_i = quantizer.from_indices(indices[:, i])
|
508 |
+
z_q = z_q + z_q_i
|
509 |
+
|
510 |
+
return z_q
|
511 |
+
|
512 |
+
def forward(self, z: Tensor, n_quantizers: int | None = None):
|
513 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
514 |
+
the corresponding codebook vectors
|
515 |
+
Parameters
|
516 |
+
----------
|
517 |
+
z : Tensor[B x D x T]
|
518 |
+
n_quantizers : int, optional
|
519 |
+
No. of quantizers to use
|
520 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
521 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
522 |
+
when in training mode, and a random number of quantizers is used.
|
523 |
+
Returns
|
524 |
+
-------
|
525 |
+
dict
|
526 |
+
A dictionary with the following keys:
|
527 |
+
|
528 |
+
"z" : Tensor[B x D x T]
|
529 |
+
Quantized continuous representation of input
|
530 |
+
"codes" : Tensor[B x N x T]
|
531 |
+
Codebook indices for each codebook
|
532 |
+
(quantized discrete representation of input)
|
533 |
+
"latents" : Tensor[B x N*D x T]
|
534 |
+
Projected latents (continuous representation of input before quantization)
|
535 |
+
"""
|
536 |
+
B = z.shape[0]
|
537 |
+
z_q = 0
|
538 |
+
residual = z
|
539 |
+
|
540 |
+
indices = []
|
541 |
+
latents = []
|
542 |
+
|
543 |
+
if n_quantizers is None:
|
544 |
+
n_quantizers = self.n_quantizers
|
545 |
+
|
546 |
+
if self.training:
|
547 |
+
n_quantizers = torch.ones((B,)) * self.n_quantizers + 1
|
548 |
+
dropout = torch.randint(1, self.n_quantizers + 1, (B,))
|
549 |
+
n_dropout = int(B * self.quantizer_dropout)
|
550 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
551 |
+
n_quantizers = n_quantizers.to(z.device)
|
552 |
+
|
553 |
+
for i, quantizer in enumerate(self.quantizers):
|
554 |
+
if not self.training and i >= n_quantizers:
|
555 |
+
break
|
556 |
+
|
557 |
+
z_q_i, indices_i, z_e_i = quantizer(residual)
|
558 |
+
|
559 |
+
residual = residual - z_q_i.detach()
|
560 |
+
|
561 |
+
mask = torch.full((B,), fill_value=i, device=z.device) < n_quantizers
|
562 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
563 |
+
|
564 |
+
indices.append(indices_i)
|
565 |
+
latents.append(z_e_i)
|
566 |
+
|
567 |
+
indices = torch.stack(indices, dim=1)
|
568 |
+
latents = torch.cat(latents, dim=1)
|
569 |
+
|
570 |
+
return z_q, indices, latents
|
571 |
+
|
572 |
+
|
573 |
+
class FluacConfig(BaseModel):
|
574 |
+
sample_rate: int = 44100
|
575 |
+
|
576 |
+
codebook_size: int | None = None
|
577 |
+
|
578 |
+
encoder_dim: int = 64
|
579 |
+
encoder_rates: list[int] = [2, 4, 8, 8]
|
580 |
+
|
581 |
+
quantizer_strides: list[int] | None = None # SNAC style strides
|
582 |
+
n_quantizers: int = 1
|
583 |
+
fsq_levels: list[int] | None = [8, 5, 5, 5] # 1000
|
584 |
+
decoder_dim: int = 1536
|
585 |
+
decoder_rates: list[int] = [8, 8, 4, 2]
|
586 |
+
|
587 |
+
@property
|
588 |
+
def hop_length(self) -> int:
|
589 |
+
return math.prod(self.encoder_rates)
|
590 |
+
|
591 |
+
@property
|
592 |
+
def latent_dim(self) -> int:
|
593 |
+
return self.encoder_dim * (2 ** len(self.encoder_rates))
|
594 |
+
|
595 |
+
@property
|
596 |
+
def effective_codebook_size(self) -> int:
|
597 |
+
return math.prod(self.fsq_levels)
|
598 |
+
|
599 |
+
|
600 |
+
class Fluac(nn.Module):
|
601 |
+
Q9_22KHZ = "fluac-22hz-22khz.pt"
|
602 |
+
|
603 |
+
def __init__(self, config: FluacConfig):
|
604 |
+
super().__init__()
|
605 |
+
|
606 |
+
self.config = config
|
607 |
+
|
608 |
+
self.encoder = Encoder(
|
609 |
+
config.encoder_dim, config.encoder_rates, config.latent_dim
|
610 |
+
)
|
611 |
+
|
612 |
+
self.quantizer = ResidualFiniteScalarQuantize(
|
613 |
+
latent_dim=config.latent_dim,
|
614 |
+
n_quantizers=config.n_quantizers,
|
615 |
+
levels=config.fsq_levels,
|
616 |
+
strides=config.quantizer_strides,
|
617 |
+
)
|
618 |
+
|
619 |
+
self.decoder = Decoder(
|
620 |
+
config.latent_dim,
|
621 |
+
config.decoder_dim,
|
622 |
+
config.decoder_rates,
|
623 |
+
)
|
624 |
+
|
625 |
+
self.apply(init_weights)
|
626 |
+
|
627 |
+
@staticmethod
|
628 |
+
def from_pretrained(name: str = Q9_22KHZ):
|
629 |
+
if path.exists(name):
|
630 |
+
checkpoint_path = name
|
631 |
+
else:
|
632 |
+
from huggingface_hub import hf_hub_download
|
633 |
+
|
634 |
+
checkpoint_path = hf_hub_download(
|
635 |
+
"fluxions/vui",
|
636 |
+
name,
|
637 |
+
)
|
638 |
+
|
639 |
+
checkpoint = torch.load(checkpoint_path, weights_only=True, map_location="cpu")
|
640 |
+
config = checkpoint["config"]
|
641 |
+
if "model" in config:
|
642 |
+
model_config = FluacConfig(**config["model"])
|
643 |
+
else:
|
644 |
+
model_config = FluacConfig(**config)
|
645 |
+
|
646 |
+
generator = Fluac(model_config).eval()
|
647 |
+
ckpt = decompile_state_dict(checkpoint["generator"])
|
648 |
+
generator.load_state_dict(ckpt)
|
649 |
+
return generator
|
650 |
+
|
651 |
+
def pad(self, waveform: Tensor):
|
652 |
+
T = waveform.size(-1)
|
653 |
+
right_pad = math.ceil(T / self.config.hop_length) * self.config.hop_length - T
|
654 |
+
waveform = F.pad(waveform, (0, right_pad))
|
655 |
+
return waveform
|
656 |
+
|
657 |
+
@torch.inference_mode()
|
658 |
+
def from_indices(self, indices: Tensor):
|
659 |
+
z_q = self.quantizer.from_indices(indices)
|
660 |
+
waveform = self.decoder(z_q)
|
661 |
+
return waveform
|
662 |
+
|
663 |
+
@torch.inference_mode()
|
664 |
+
def encode(self, waveforms: Tensor, n_quantizers: int | None = None):
|
665 |
+
# Ensure that waveforms is 3 dima
|
666 |
+
waveforms = waveforms.flatten()[None][None]
|
667 |
+
waveforms = self.pad(waveforms)
|
668 |
+
B, C, T = waveforms.size()
|
669 |
+
z = self.encoder(waveforms)
|
670 |
+
z_q, codes, latents = self.quantizer(z, n_quantizers=n_quantizers)
|
671 |
+
return codes
|
672 |
+
|
673 |
+
def forward(self, waveforms: Tensor, n_quantizers: int | None = None):
|
674 |
+
B, C, T = waveforms.size()
|
675 |
+
waveforms = self.pad(waveforms)
|
676 |
+
z = self.encoder(waveforms)
|
677 |
+
z_q, codes, latents = self.quantizer(z, n_quantizers=n_quantizers)
|
678 |
+
|
679 |
+
recons = self.decoder(z_q)
|
680 |
+
recons = recons[..., :T]
|
681 |
+
return {
|
682 |
+
"recons": recons,
|
683 |
+
"codes": codes,
|
684 |
+
}
|
685 |
+
|
686 |
+
@property
|
687 |
+
def device(self):
|
688 |
+
return next(self.parameters()).device
|
689 |
+
|
690 |
+
@property
|
691 |
+
def dtype(self):
|
692 |
+
return next(self.parameters()).dtype
|
693 |
+
|
694 |
+
@property
|
695 |
+
def hz(self):
|
696 |
+
import numpy as np
|
697 |
+
|
698 |
+
return self.config.sample_rate / np.prod(self.config.encoder_rates).item()
|
699 |
+
|
700 |
+
|
701 |
+
if __name__ == "__main__":
|
702 |
+
codec = Fluac.from_pretrained(Fluac.Q9_22KHZ)
|
703 |
+
print(codec.config)
|
704 |
+
wav = torch.rand(1, 1, 22050)
|
705 |
+
wav = codec.pad(wav)
|
706 |
+
codes = codec.encode(wav)
|
707 |
+
breakpoint()
|
src/vui/inference.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import time
|
3 |
+
|
4 |
+
import inflect
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchaudio
|
8 |
+
from torch import Tensor
|
9 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
10 |
+
|
11 |
+
from vui.model import Vui
|
12 |
+
from vui.sampling import multinomial, sample_top_k, sample_top_p, sample_top_p_top_k
|
13 |
+
from vui.utils import timer
|
14 |
+
from vui.vad import detect_voice_activity as vad
|
15 |
+
|
16 |
+
|
17 |
+
def ensure_spaces_around_tags(text: str):
|
18 |
+
# Add space before '[' if not preceded by space, '<', or '['
|
19 |
+
text = re.sub(
|
20 |
+
r"(?<![<\[\s])(\[)",
|
21 |
+
lambda m: (
|
22 |
+
f"\n{m.group(1)}"
|
23 |
+
if m.start() > 0 and text[m.start() - 1] == "\n"
|
24 |
+
else f" {m.group(1)}"
|
25 |
+
),
|
26 |
+
text,
|
27 |
+
)
|
28 |
+
# Add space after ']' if not preceded by digit+']' and not followed by space, '>', or ']'
|
29 |
+
text = re.sub(
|
30 |
+
r"(?<!\d\])(\])(?![>\]\s])",
|
31 |
+
lambda m: (
|
32 |
+
f"{m.group(1)}\n"
|
33 |
+
if m.end() < len(text) and text[m.end()] == "\n"
|
34 |
+
else f"{m.group(1)} "
|
35 |
+
),
|
36 |
+
text,
|
37 |
+
)
|
38 |
+
text = text.strip()
|
39 |
+
return text
|
40 |
+
|
41 |
+
|
42 |
+
REPLACE = [
|
43 |
+
("—", ","),
|
44 |
+
("'", "'"),
|
45 |
+
(":", ","),
|
46 |
+
(";", ","),
|
47 |
+
]
|
48 |
+
|
49 |
+
engine = None
|
50 |
+
wm = None
|
51 |
+
|
52 |
+
|
53 |
+
def asr(chunk, model=None, prefix=None):
|
54 |
+
import whisper
|
55 |
+
|
56 |
+
global wm
|
57 |
+
if model is not None:
|
58 |
+
wm = model
|
59 |
+
elif wm is None:
|
60 |
+
wm = whisper.load_model("turbo", "cuda")
|
61 |
+
|
62 |
+
"""Process audio with VAD and transcribe"""
|
63 |
+
chunk = whisper.pad_or_trim(chunk)
|
64 |
+
mel = whisper.log_mel_spectrogram(chunk, n_mels=wm.dims.n_mels).to(wm.device)
|
65 |
+
options = whisper.DecodingOptions(
|
66 |
+
language="en", without_timestamps=True, prefix=prefix
|
67 |
+
)
|
68 |
+
result = whisper.decode(wm, mel[None], options)
|
69 |
+
return result[0].text
|
70 |
+
|
71 |
+
|
72 |
+
def replace_numbers_with_words(text):
|
73 |
+
global engine
|
74 |
+
|
75 |
+
if engine is None:
|
76 |
+
engine = inflect.engine()
|
77 |
+
|
78 |
+
# Function to convert a number match to words
|
79 |
+
def number_to_words(match):
|
80 |
+
number = match.group()
|
81 |
+
return engine.number_to_words(number) + " "
|
82 |
+
|
83 |
+
# Replace digits with their word equivalents
|
84 |
+
return re.sub(r"\d+", number_to_words, text)
|
85 |
+
|
86 |
+
|
87 |
+
valid_non_speech = ["breath", "sigh", "laugh", "tut", "hesitate"]
|
88 |
+
valid_non_speech = [f"[{v}]" for v in valid_non_speech]
|
89 |
+
|
90 |
+
|
91 |
+
def remove_all_invalid_non_speech(txt):
|
92 |
+
"""
|
93 |
+
Remove all non-speech markers that are not in the valid_non_speech list.
|
94 |
+
Only keeps valid non-speech markers like [breath], [sigh], etc.
|
95 |
+
"""
|
96 |
+
# Find all text within square brackets
|
97 |
+
bracket_pattern = r"\[([^\]]+)\]"
|
98 |
+
brackets = re.findall(bracket_pattern, txt)
|
99 |
+
|
100 |
+
# For each bracketed text, check if it's in our valid list
|
101 |
+
for bracket in brackets:
|
102 |
+
bracket_with_brackets = f"[{bracket}]"
|
103 |
+
if bracket_with_brackets not in valid_non_speech and bracket != "pause":
|
104 |
+
# If not valid, remove it from the text
|
105 |
+
txt = txt.replace(bracket_with_brackets, "")
|
106 |
+
|
107 |
+
return txt
|
108 |
+
|
109 |
+
|
110 |
+
def simple_clean(text):
|
111 |
+
text = re.sub(r"(\d+)am", r"\1 AM", text)
|
112 |
+
text = re.sub(r"(\d+)pm", r"\1 PM", text)
|
113 |
+
text = replace_numbers_with_words(text)
|
114 |
+
text = ensure_spaces_around_tags(text)
|
115 |
+
text = remove_all_invalid_non_speech(text)
|
116 |
+
|
117 |
+
text = text.replace('"', "")
|
118 |
+
text = text.replace("”", "")
|
119 |
+
text = text.replace("“", "")
|
120 |
+
text = text.replace("’", "'")
|
121 |
+
text = text.replace("%", " percent")
|
122 |
+
text = text.replace("*", "")
|
123 |
+
text = text.replace("(", "")
|
124 |
+
text = text.replace(")", "")
|
125 |
+
text = text.replace(";", "")
|
126 |
+
text = text.replace("–", " ")
|
127 |
+
text = text.replace("—", "")
|
128 |
+
text = text.replace(":", "")
|
129 |
+
text = text.replace("…", "...")
|
130 |
+
text = text.replace("s...", "s")
|
131 |
+
|
132 |
+
# replace repeating \n with just one \n
|
133 |
+
text = re.sub(r"\n+", "\n", text)
|
134 |
+
ntxt = re.sub(r" +", " ", text)
|
135 |
+
|
136 |
+
# Ensure that ntxt ends with . or ?
|
137 |
+
ntxt = ntxt.strip()
|
138 |
+
if not ntxt.endswith(".") or ntxt.endswith("?"):
|
139 |
+
ntxt += "."
|
140 |
+
ntxt += " [pause]"
|
141 |
+
return ntxt
|
142 |
+
|
143 |
+
|
144 |
+
@torch.inference_mode()
|
145 |
+
def generate(
|
146 |
+
self: Vui,
|
147 |
+
text: str,
|
148 |
+
prompt_codes: Tensor | None = None,
|
149 |
+
temperature: float = 0.5,
|
150 |
+
top_k: int | None = 150,
|
151 |
+
top_p: float | None = None,
|
152 |
+
max_gen_len: int = int(120 * 21.53),
|
153 |
+
):
|
154 |
+
text = simple_clean(text)
|
155 |
+
with (
|
156 |
+
torch.autocast("cuda", torch.bfloat16, True),
|
157 |
+
sdpa_kernel([SDPBackend.MATH]),
|
158 |
+
timer("generate"),
|
159 |
+
):
|
160 |
+
t1 = time.perf_counter()
|
161 |
+
batch_size = 1
|
162 |
+
device = self.device
|
163 |
+
self.dtype
|
164 |
+
self.decoder.allocate_inference_cache(batch_size, device, torch.bfloat16)
|
165 |
+
|
166 |
+
texts = [text]
|
167 |
+
|
168 |
+
encoded = self.tokenizer(
|
169 |
+
texts,
|
170 |
+
padding="longest",
|
171 |
+
return_tensors="pt",
|
172 |
+
)
|
173 |
+
|
174 |
+
input_ids = encoded.input_ids.to(device)
|
175 |
+
text_embeddings = self.token_emb(input_ids)
|
176 |
+
|
177 |
+
B = batch_size
|
178 |
+
Q = self.config.model.n_quantizers
|
179 |
+
|
180 |
+
if prompt_codes is None:
|
181 |
+
prompt_codes = torch.zeros(
|
182 |
+
(batch_size, Q, 0), dtype=torch.int64, device=device
|
183 |
+
)
|
184 |
+
else:
|
185 |
+
prompt_codes = prompt_codes[:, :Q].repeat(batch_size, 1, 1)
|
186 |
+
|
187 |
+
start_offset = prompt_codes.size(-1)
|
188 |
+
|
189 |
+
pattern = self.pattern_provider.get_pattern(max_gen_len)
|
190 |
+
# this token is used as default value for codes that are not generated yet
|
191 |
+
unknown_token = -1
|
192 |
+
special_token_id = self.config.model.special_token_id
|
193 |
+
|
194 |
+
# we generate codes up to the max_gen_len that will be mapped to the pattern sequence
|
195 |
+
codes = torch.full(
|
196 |
+
(B, Q, max_gen_len), unknown_token, dtype=torch.int64, device=device
|
197 |
+
)
|
198 |
+
|
199 |
+
codes[:, :, :start_offset] = prompt_codes
|
200 |
+
|
201 |
+
sequence, indexes, mask = pattern.build_pattern_sequence(
|
202 |
+
codes, special_token_id
|
203 |
+
)
|
204 |
+
# retrieve the start_offset in the sequence:
|
205 |
+
# it is the first sequence step that contains the `start_offset` timestep
|
206 |
+
start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
|
207 |
+
assert start_offset_sequence is not None
|
208 |
+
|
209 |
+
prev_offset = 0
|
210 |
+
S = sequence.size(-1)
|
211 |
+
|
212 |
+
do_prefill = True
|
213 |
+
eos = self.config.model.audio_eos_id
|
214 |
+
|
215 |
+
for offset in range(start_offset_sequence, S):
|
216 |
+
# print(f"{prev_offset}:{offset}")
|
217 |
+
curr_sequence = sequence[..., prev_offset:offset]
|
218 |
+
audio_embeddings = (
|
219 |
+
sum([self.audio_embeddings[q](curr_sequence[:, q]) for q in range(Q)])
|
220 |
+
/ Q
|
221 |
+
)
|
222 |
+
|
223 |
+
if do_prefill:
|
224 |
+
embeddings = torch.cat((text_embeddings, audio_embeddings), dim=1)
|
225 |
+
T = embeddings.size(1)
|
226 |
+
input_pos = torch.arange(0, T, device=device)
|
227 |
+
do_prefill = False
|
228 |
+
else:
|
229 |
+
embeddings = audio_embeddings
|
230 |
+
input_pos = torch.tensor([T], device=device)
|
231 |
+
T += 1
|
232 |
+
|
233 |
+
out = self.decoder(embeddings, input_pos)
|
234 |
+
|
235 |
+
if offset == 15:
|
236 |
+
print("TTFB", time.perf_counter() - t1)
|
237 |
+
|
238 |
+
logits = torch.stack(
|
239 |
+
[self.audio_heads[q](out[:, -1]) for q in range(Q)], dim=1
|
240 |
+
)
|
241 |
+
|
242 |
+
repetition_penalty = 1.4
|
243 |
+
history_window = 12
|
244 |
+
|
245 |
+
# Get the history of generated tokens for each quantizer
|
246 |
+
for q in range(Q):
|
247 |
+
# Extract the history window for this quantizer
|
248 |
+
history_start = max(0, offset - history_window)
|
249 |
+
token_history = sequence[0, q, history_start:offset]
|
250 |
+
|
251 |
+
# Only apply penalty to tokens that appear in the history
|
252 |
+
unique_tokens = torch.unique(token_history)
|
253 |
+
unique_tokens = unique_tokens[unique_tokens != special_token_id]
|
254 |
+
unique_tokens = unique_tokens[unique_tokens != eos]
|
255 |
+
unique_tokens = unique_tokens[unique_tokens != unknown_token]
|
256 |
+
|
257 |
+
if len(unique_tokens) > 0:
|
258 |
+
# Apply penalty by dividing the logits for tokens that have appeared recently
|
259 |
+
logits[0, q, unique_tokens] = (
|
260 |
+
logits[0, q, unique_tokens] / repetition_penalty
|
261 |
+
)
|
262 |
+
|
263 |
+
if offset < 24.53 * 4:
|
264 |
+
logits[..., eos] = -float("inf")
|
265 |
+
|
266 |
+
probs = F.softmax(logits / temperature, dim=-1)
|
267 |
+
|
268 |
+
# print(probs.shape)
|
269 |
+
if top_p is not None and top_k is not None:
|
270 |
+
next_codes = sample_top_p_top_k(probs, top_p, top_k)
|
271 |
+
elif top_p is not None and top_p > 0:
|
272 |
+
next_codes = sample_top_p(probs, top_p)
|
273 |
+
elif top_k is not None and top_k > 0:
|
274 |
+
next_codes = sample_top_k(probs, top_k)
|
275 |
+
else:
|
276 |
+
next_codes = multinomial(probs, num_samples=1)
|
277 |
+
|
278 |
+
next_codes = next_codes.repeat(batch_size, 1, 1)
|
279 |
+
|
280 |
+
if (probs[..., eos] > 0.95).any():
|
281 |
+
print("breaking at", offset)
|
282 |
+
break
|
283 |
+
|
284 |
+
valid_mask = mask[..., offset : offset + 1].expand(B, -1, -1)
|
285 |
+
next_codes[~valid_mask] = special_token_id
|
286 |
+
|
287 |
+
sequence[..., offset : offset + 1] = torch.where(
|
288 |
+
sequence[..., offset : offset + 1] == unknown_token,
|
289 |
+
next_codes,
|
290 |
+
sequence[..., offset : offset + 1],
|
291 |
+
)
|
292 |
+
|
293 |
+
prev_offset = offset
|
294 |
+
|
295 |
+
# print(sequence.shape)
|
296 |
+
out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(
|
297 |
+
sequence, special_token=unknown_token
|
298 |
+
)
|
299 |
+
|
300 |
+
# sanity checks over the returned codes and corresponding masks
|
301 |
+
# assert (out_codes[..., :max_gen_len] != unknown_token).all()
|
302 |
+
# assert (out_mask[..., :max_gen_len] == 1).all()
|
303 |
+
out_codes = out_codes[..., prompt_codes.shape[-1] : offset]
|
304 |
+
return out_codes[[0]]
|
305 |
+
|
306 |
+
|
307 |
+
@torch.inference_mode()
|
308 |
+
def render(
|
309 |
+
self: Vui,
|
310 |
+
text: str,
|
311 |
+
prompt_codes: Tensor | None = None,
|
312 |
+
temperature: float = 0.5,
|
313 |
+
top_k: int | None = 100,
|
314 |
+
top_p: float | None = None,
|
315 |
+
max_secs: int = 100,
|
316 |
+
):
|
317 |
+
"""
|
318 |
+
Render audio from text. Uses generate for text < 1000 characters,
|
319 |
+
otherwise breaks text into sections and uses chunking with context.
|
320 |
+
"""
|
321 |
+
text = remove_all_invalid_non_speech(text)
|
322 |
+
text = simple_clean(text)
|
323 |
+
SR = self.codec.config.sample_rate
|
324 |
+
HZ = self.codec.hz
|
325 |
+
max_gen_len = int(HZ * max_secs)
|
326 |
+
|
327 |
+
if len(text) < 1000:
|
328 |
+
codes = generate(
|
329 |
+
self, text, prompt_codes, temperature, top_k, top_p, max_gen_len
|
330 |
+
)
|
331 |
+
codes = codes[..., :-10]
|
332 |
+
audio = self.codec.from_indices(codes)
|
333 |
+
paudio = torchaudio.functional.resample(audio[0], 22050, 16000)
|
334 |
+
results = vad(paudio)
|
335 |
+
|
336 |
+
if len(results):
|
337 |
+
# Cut the audio based on VAD results, add 200ms silence at end
|
338 |
+
s, e = results[0][0], results[-1][1]
|
339 |
+
return audio[..., int(s * SR) : int((e + 0.2) * SR)].cpu()
|
340 |
+
|
341 |
+
raise Exception("Failed to render")
|
342 |
+
|
343 |
+
# Otherwise we have to do some clever chaining!
|
344 |
+
|
345 |
+
orig_codes = prompt_codes
|
346 |
+
|
347 |
+
lines = text.split("\n")
|
348 |
+
audios = []
|
349 |
+
prev_codes = prompt_codes
|
350 |
+
prev_text = ""
|
351 |
+
|
352 |
+
for i, line in enumerate(lines):
|
353 |
+
run = True
|
354 |
+
while run:
|
355 |
+
current_text = prev_text + "\n" + line if prev_text else line
|
356 |
+
current_text = current_text.strip()
|
357 |
+
current_text = current_text.replace("...", "")
|
358 |
+
current_text = current_text + " [pause]"
|
359 |
+
|
360 |
+
# Calculate max length based on text length
|
361 |
+
maxlen = int(HZ * int(60 * len(current_text) / 500))
|
362 |
+
|
363 |
+
try:
|
364 |
+
print("rendering", current_text)
|
365 |
+
with (
|
366 |
+
torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH),
|
367 |
+
torch.autocast("cuda", dtype=torch.bfloat16, enabled=True),
|
368 |
+
):
|
369 |
+
codes = generate(
|
370 |
+
self,
|
371 |
+
current_text,
|
372 |
+
prompt_codes=prev_codes,
|
373 |
+
temperature=temperature,
|
374 |
+
top_k=top_k,
|
375 |
+
top_p=top_p,
|
376 |
+
max_gen_len=maxlen,
|
377 |
+
)
|
378 |
+
|
379 |
+
codes = codes[..., :-10]
|
380 |
+
audio = self.codec.from_indices(codes)
|
381 |
+
# Resample for VAD
|
382 |
+
paudio = torchaudio.functional.resample(audio[0], 22050, 16000)
|
383 |
+
|
384 |
+
results = vad(paudio)
|
385 |
+
run = len(results) == 0
|
386 |
+
|
387 |
+
if len(results):
|
388 |
+
prev_text = line
|
389 |
+
# Cut the audio based on VAD results, add 200ms silence at end
|
390 |
+
s, e = results[0][0], results[0][1]
|
391 |
+
codes = codes[..., int(s * HZ) : int(e * HZ)]
|
392 |
+
prev_codes = codes
|
393 |
+
audio = audio[..., int(s * SR) : int((e + 0.2) * SR)].cpu()
|
394 |
+
audios.append(audio)
|
395 |
+
else:
|
396 |
+
prev_codes = orig_codes
|
397 |
+
prev_text = ""
|
398 |
+
except KeyboardInterrupt:
|
399 |
+
break
|
400 |
+
except RuntimeError as e:
|
401 |
+
prev_codes = orig_codes
|
402 |
+
prev_text = ""
|
403 |
+
print(e)
|
404 |
+
|
405 |
+
return torch.cat(audios, dim=-1)
|
src/vui/model.py
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from torch import Tensor
|
9 |
+
from transformers import AutoTokenizer
|
10 |
+
|
11 |
+
from vui.fluac import Fluac
|
12 |
+
from vui.utils import load_what_you_can
|
13 |
+
|
14 |
+
from .config import Config
|
15 |
+
from .patterns import DelayedPatternProvider
|
16 |
+
from .rope import apply_rotary_emb, precompute_freqs_cis
|
17 |
+
|
18 |
+
|
19 |
+
class KVCache(nn.Module):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
batch_size: int,
|
23 |
+
max_seqlen: int,
|
24 |
+
n_kv_heads: int,
|
25 |
+
head_dim: int,
|
26 |
+
dtype: torch.dtype = torch.bfloat16,
|
27 |
+
):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
cache_shape = (batch_size, n_kv_heads, max_seqlen, head_dim)
|
31 |
+
|
32 |
+
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
|
33 |
+
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
|
34 |
+
|
35 |
+
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
|
36 |
+
# input_pos: (T,), k_val: (B, nh, T, d)
|
37 |
+
assert input_pos.size(0) == k_val.size(-2)
|
38 |
+
|
39 |
+
k_out = self.k_cache
|
40 |
+
v_out = self.v_cache
|
41 |
+
k_out[:, :, input_pos] = k_val
|
42 |
+
v_out[:, :, input_pos] = v_val
|
43 |
+
|
44 |
+
return k_out, v_out
|
45 |
+
|
46 |
+
|
47 |
+
def repeat_kv(x: torch.Tensor, n_reps: int) -> torch.Tensor:
|
48 |
+
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
49 |
+
bs, n_kv_heads, T, head_dim = x.shape
|
50 |
+
|
51 |
+
return (
|
52 |
+
x[:, :, :, None, :]
|
53 |
+
.expand(bs, n_kv_heads, n_reps, T, head_dim)
|
54 |
+
.reshape(bs, n_kv_heads * n_reps, T, head_dim)
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
class MHA(nn.Module):
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
dim: int,
|
62 |
+
n_heads: int,
|
63 |
+
n_kv_heads: int,
|
64 |
+
*,
|
65 |
+
block_idx: int,
|
66 |
+
bias: bool = False,
|
67 |
+
dropout: float = 0.0,
|
68 |
+
causal: bool = False,
|
69 |
+
use_rotary_emb: bool = True,
|
70 |
+
):
|
71 |
+
super().__init__()
|
72 |
+
|
73 |
+
head_dim = dim // n_heads
|
74 |
+
|
75 |
+
self.use_rotary_emb = use_rotary_emb
|
76 |
+
self.block_idx = block_idx
|
77 |
+
self.dim = dim
|
78 |
+
self.n_heads = n_heads
|
79 |
+
self.n_kv_heads = n_kv_heads
|
80 |
+
self.head_dim = head_dim
|
81 |
+
self.dropout = dropout
|
82 |
+
self.causal = causal
|
83 |
+
self.n_reps = n_kv_heads // n_heads
|
84 |
+
qkv_dim = (n_heads + 2 * n_kv_heads) * head_dim
|
85 |
+
self.Wqkv = nn.Linear(dim, qkv_dim, bias=bias)
|
86 |
+
self.out_proj = nn.Linear(dim, dim, bias=bias)
|
87 |
+
self.kv_cache = None
|
88 |
+
|
89 |
+
def forward(
|
90 |
+
self,
|
91 |
+
x: Tensor,
|
92 |
+
freqs_cis: Tensor | None = None,
|
93 |
+
input_pos: Tensor | None = None,
|
94 |
+
attn_mask: Tensor | None = None,
|
95 |
+
):
|
96 |
+
B, T, d = x.size()
|
97 |
+
x.dtype
|
98 |
+
|
99 |
+
dropout_p = self.dropout if self.training else 0.0
|
100 |
+
|
101 |
+
qkv = self.Wqkv(x)
|
102 |
+
if self.n_heads == self.n_kv_heads:
|
103 |
+
qkv = rearrange(
|
104 |
+
qkv, "B T (three h d) -> B three h T d", three=3, h=self.n_heads
|
105 |
+
)
|
106 |
+
q, k, v = qkv.unbind(dim=1) # (B, h, T, d)
|
107 |
+
else:
|
108 |
+
q, k, v = torch.split(
|
109 |
+
qkv,
|
110 |
+
[
|
111 |
+
self.head_dim * self.n_heads,
|
112 |
+
self.head_dim * self.n_kv_heads,
|
113 |
+
self.head_dim * self.n_kv_heads,
|
114 |
+
],
|
115 |
+
dim=1,
|
116 |
+
)
|
117 |
+
q, k, v = map(lambda t: rearrange(t, "B T (h d) -> B h T d"), (q, k, v))
|
118 |
+
|
119 |
+
if self.use_rotary_emb:
|
120 |
+
q = apply_rotary_emb(freqs_cis, q)
|
121 |
+
k = apply_rotary_emb(freqs_cis, k)
|
122 |
+
|
123 |
+
if self.kv_cache is not None:
|
124 |
+
k, v = self.kv_cache.update(input_pos, k, v)
|
125 |
+
|
126 |
+
if self.n_reps > 1:
|
127 |
+
k = repeat_kv(k, self.n_reps)
|
128 |
+
v = repeat_kv(v, self.n_reps)
|
129 |
+
|
130 |
+
is_causal = self.causal and self.kv_cache is None
|
131 |
+
|
132 |
+
out = F.scaled_dot_product_attention(
|
133 |
+
q,
|
134 |
+
k,
|
135 |
+
v,
|
136 |
+
dropout_p=dropout_p,
|
137 |
+
is_causal=is_causal,
|
138 |
+
attn_mask=attn_mask,
|
139 |
+
)
|
140 |
+
|
141 |
+
out = self.out_proj(rearrange(out, "B h T d -> B T (h d)"))
|
142 |
+
|
143 |
+
return out
|
144 |
+
|
145 |
+
|
146 |
+
class MLP(nn.Module):
|
147 |
+
def __init__(
|
148 |
+
self, *, d_model: int, bias: bool, dropout: float, act=nn.GELU, **kwargs
|
149 |
+
):
|
150 |
+
super().__init__()
|
151 |
+
self.fc1 = nn.Linear(d_model, 4 * d_model, bias=bias)
|
152 |
+
self.act = act()
|
153 |
+
self.fc2 = nn.Linear(4 * d_model, d_model, bias=bias)
|
154 |
+
self.dropout = nn.Dropout(dropout)
|
155 |
+
|
156 |
+
def forward(self, x):
|
157 |
+
return self.dropout(self.fc2(self.act(self.fc1(x))))
|
158 |
+
|
159 |
+
|
160 |
+
class LlamaMLP(nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self, *, d_model: int, multiple_of: int = 256, bias: bool = False, **kwargs
|
163 |
+
) -> None:
|
164 |
+
super().__init__()
|
165 |
+
hidden_dim = 4 * d_model
|
166 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
167 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
168 |
+
self.w1 = nn.Linear(d_model, hidden_dim, bias=bias)
|
169 |
+
self.w3 = nn.Linear(d_model, hidden_dim, bias=bias)
|
170 |
+
self.w2 = nn.Linear(hidden_dim, d_model, bias=bias)
|
171 |
+
|
172 |
+
def forward(self, x: Tensor) -> Tensor:
|
173 |
+
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
174 |
+
|
175 |
+
|
176 |
+
class RMSNorm(nn.Module):
|
177 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
178 |
+
super().__init__()
|
179 |
+
self.eps = eps
|
180 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
181 |
+
|
182 |
+
def _norm(self, x):
|
183 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
184 |
+
|
185 |
+
def forward(self, x: Tensor):
|
186 |
+
output = self._norm(x.float()).type_as(x)
|
187 |
+
return output * self.weight
|
188 |
+
|
189 |
+
|
190 |
+
class Block(nn.Module):
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
*,
|
194 |
+
d_model: int,
|
195 |
+
n_heads: int,
|
196 |
+
n_kv_heads: int,
|
197 |
+
block_idx: int,
|
198 |
+
bias: bool,
|
199 |
+
dropout: float,
|
200 |
+
norm_eps: float = 1e-5, # use 1e-6 for rms
|
201 |
+
use_rotary_emb: bool = True,
|
202 |
+
):
|
203 |
+
super().__init__()
|
204 |
+
|
205 |
+
self.block_idx = block_idx
|
206 |
+
self.n_heads = n_heads
|
207 |
+
self.n_kv_heads = n_kv_heads
|
208 |
+
self.head_dim = d_model // n_heads
|
209 |
+
|
210 |
+
self.attn_norm = RMSNorm(d_model, eps=norm_eps)
|
211 |
+
self.attn = MHA(
|
212 |
+
d_model,
|
213 |
+
n_heads,
|
214 |
+
n_kv_heads,
|
215 |
+
block_idx=block_idx,
|
216 |
+
bias=bias,
|
217 |
+
dropout=dropout,
|
218 |
+
causal=True,
|
219 |
+
use_rotary_emb=use_rotary_emb,
|
220 |
+
)
|
221 |
+
self.mlp_norm = RMSNorm(d_model, eps=norm_eps)
|
222 |
+
self.mlp = LlamaMLP(d_model=d_model, bias=bias, dropout=dropout)
|
223 |
+
|
224 |
+
def forward(
|
225 |
+
self,
|
226 |
+
x: Tensor,
|
227 |
+
freqs_cis: Tensor | None = None,
|
228 |
+
input_pos: Tensor | None = None,
|
229 |
+
attn_mask: Tensor | None = None,
|
230 |
+
):
|
231 |
+
x = x + self.attn(
|
232 |
+
self.attn_norm(x),
|
233 |
+
freqs_cis=freqs_cis,
|
234 |
+
input_pos=input_pos,
|
235 |
+
attn_mask=attn_mask,
|
236 |
+
)
|
237 |
+
x = x + self.mlp(self.mlp_norm(x))
|
238 |
+
|
239 |
+
return x
|
240 |
+
|
241 |
+
|
242 |
+
class Decoder(nn.Module):
|
243 |
+
def __init__(
|
244 |
+
self,
|
245 |
+
*,
|
246 |
+
n_layers: int,
|
247 |
+
d_model: int,
|
248 |
+
n_heads: int,
|
249 |
+
n_kv_heads: int,
|
250 |
+
bias: bool,
|
251 |
+
dropout: float,
|
252 |
+
max_seqlen: int = 4096,
|
253 |
+
rope_theta: float = 10000.0,
|
254 |
+
rope_theta_rescale_factor: float = 1.0,
|
255 |
+
norm_eps: float = 1e-5,
|
256 |
+
use_rotary_emb: bool = True,
|
257 |
+
rope_dim: int | None = None,
|
258 |
+
):
|
259 |
+
super().__init__()
|
260 |
+
assert d_model % n_heads == 0
|
261 |
+
|
262 |
+
self.use_rotary_emb = use_rotary_emb
|
263 |
+
|
264 |
+
self.max_seqlen = max_seqlen
|
265 |
+
self.blocks = nn.ModuleList(
|
266 |
+
[
|
267 |
+
Block(
|
268 |
+
d_model=d_model,
|
269 |
+
n_heads=n_heads,
|
270 |
+
n_kv_heads=n_kv_heads,
|
271 |
+
block_idx=block_idx,
|
272 |
+
bias=bias,
|
273 |
+
dropout=dropout,
|
274 |
+
norm_eps=norm_eps,
|
275 |
+
use_rotary_emb=use_rotary_emb,
|
276 |
+
)
|
277 |
+
for block_idx in range(n_layers)
|
278 |
+
]
|
279 |
+
)
|
280 |
+
self.norm = RMSNorm(d_model, eps=norm_eps)
|
281 |
+
|
282 |
+
self.attn_mask = None
|
283 |
+
|
284 |
+
head_dim = d_model // n_heads
|
285 |
+
|
286 |
+
rope_dim = rope_dim or head_dim
|
287 |
+
|
288 |
+
assert rope_dim <= head_dim # apply RoPE to a fraction of embeddings
|
289 |
+
|
290 |
+
freqs_cis = precompute_freqs_cis(
|
291 |
+
rope_dim,
|
292 |
+
max_seqlen,
|
293 |
+
theta=rope_theta,
|
294 |
+
theta_rescale_factor=rope_theta_rescale_factor,
|
295 |
+
)
|
296 |
+
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
297 |
+
|
298 |
+
def allocate_inference_cache(
|
299 |
+
self, batch_size: int, device: str, dtype=torch.bfloat16
|
300 |
+
):
|
301 |
+
for block in self.blocks:
|
302 |
+
block.attn.kv_cache = KVCache(
|
303 |
+
batch_size, self.max_seqlen, block.n_kv_heads, block.head_dim, dtype
|
304 |
+
).to(device)
|
305 |
+
|
306 |
+
# I don't understand why this is needed
|
307 |
+
self.attn_mask = torch.tril(
|
308 |
+
torch.ones(
|
309 |
+
self.max_seqlen, self.max_seqlen, dtype=torch.bool, device=device
|
310 |
+
)
|
311 |
+
)
|
312 |
+
|
313 |
+
def deallocate_kv_cache(self):
|
314 |
+
for block in self.blocks:
|
315 |
+
block.attn.kv_cache = None
|
316 |
+
|
317 |
+
self.attn_mask = None
|
318 |
+
|
319 |
+
def forward(self, x: Tensor, input_pos: Tensor):
|
320 |
+
if self.use_rotary_emb:
|
321 |
+
freqs_cis = self.freqs_cis[input_pos]
|
322 |
+
else:
|
323 |
+
freqs_cis = None
|
324 |
+
|
325 |
+
attn_mask = (
|
326 |
+
self.attn_mask[None, None, input_pos]
|
327 |
+
if self.attn_mask is not None
|
328 |
+
else None
|
329 |
+
)
|
330 |
+
|
331 |
+
for block in self.blocks:
|
332 |
+
x = block(x, freqs_cis=freqs_cis, input_pos=input_pos, attn_mask=attn_mask)
|
333 |
+
|
334 |
+
x = self.norm(x)
|
335 |
+
|
336 |
+
return x
|
337 |
+
|
338 |
+
|
339 |
+
class Vui(nn.Module):
|
340 |
+
BASE = "vui-100m-base.pt"
|
341 |
+
COHOST = "vui-cohost-100m.pt"
|
342 |
+
ABRAHAM = "vui-abraham-100m.pt"
|
343 |
+
|
344 |
+
def __init__(self, config: Config = Config()):
|
345 |
+
super().__init__()
|
346 |
+
self.codec = Fluac.from_pretrained()
|
347 |
+
self.config = config
|
348 |
+
cfg = config.model
|
349 |
+
self.tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
|
350 |
+
self.use_rotary_emb = cfg.use_rotary_emb
|
351 |
+
self.token_emb = nn.Embedding(self.tokenizer.vocab_size, cfg.d_model)
|
352 |
+
|
353 |
+
self.pattern_provider = DelayedPatternProvider(n_q=cfg.n_quantizers)
|
354 |
+
|
355 |
+
self.audio_embeddings = nn.ModuleList(
|
356 |
+
[
|
357 |
+
nn.Embedding(cfg.codebook_size + 8, cfg.d_model)
|
358 |
+
for _ in range(cfg.n_quantizers)
|
359 |
+
]
|
360 |
+
)
|
361 |
+
|
362 |
+
n_kv_heads = cfg.n_heads
|
363 |
+
|
364 |
+
max_seqlen = cfg.max_text_tokens + cfg.max_audio_tokens
|
365 |
+
self.decoder = Decoder(
|
366 |
+
n_layers=cfg.n_layers,
|
367 |
+
d_model=cfg.d_model,
|
368 |
+
n_heads=cfg.n_heads,
|
369 |
+
n_kv_heads=n_kv_heads,
|
370 |
+
bias=cfg.bias,
|
371 |
+
dropout=cfg.dropout,
|
372 |
+
max_seqlen=max_seqlen + cfg.n_quantizers,
|
373 |
+
rope_dim=cfg.rope_dim,
|
374 |
+
rope_theta=cfg.rope_theta,
|
375 |
+
rope_theta_rescale_factor=cfg.rope_theta_rescale_factor,
|
376 |
+
)
|
377 |
+
|
378 |
+
self.audio_heads = nn.ModuleList(
|
379 |
+
[
|
380 |
+
nn.Linear(cfg.d_model, cfg.codebook_size + 8, bias=cfg.bias)
|
381 |
+
for _ in range(cfg.n_quantizers)
|
382 |
+
]
|
383 |
+
)
|
384 |
+
|
385 |
+
self.apply(self._init_weights)
|
386 |
+
|
387 |
+
for pn, p in self.named_parameters():
|
388 |
+
if pn.endswith("out_proj.weight"):
|
389 |
+
torch.nn.init.normal_(
|
390 |
+
p, mean=0.0, std=0.02 / math.sqrt(2 * cfg.n_layers)
|
391 |
+
)
|
392 |
+
|
393 |
+
def _init_weights(self, module):
|
394 |
+
if isinstance(module, nn.Linear):
|
395 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
396 |
+
if module.bias is not None:
|
397 |
+
torch.nn.init.zeros_(module.bias)
|
398 |
+
elif isinstance(module, nn.Embedding):
|
399 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
400 |
+
|
401 |
+
@staticmethod
|
402 |
+
def from_pretrained(
|
403 |
+
checkpoint_path: str | dict = ABRAHAM,
|
404 |
+
**config_kwargs,
|
405 |
+
):
|
406 |
+
if isinstance(checkpoint_path, dict):
|
407 |
+
checkpoint = checkpoint_path
|
408 |
+
else:
|
409 |
+
if not os.path.exists(checkpoint_path):
|
410 |
+
from huggingface_hub import hf_hub_download
|
411 |
+
|
412 |
+
checkpoint_path = hf_hub_download(
|
413 |
+
"fluxions/vui",
|
414 |
+
checkpoint_path,
|
415 |
+
)
|
416 |
+
checkpoint = torch.load(
|
417 |
+
checkpoint_path, map_location="cpu", weights_only=True
|
418 |
+
)
|
419 |
+
|
420 |
+
config = {**checkpoint["config"], **config_kwargs}
|
421 |
+
config = Config(**config)
|
422 |
+
state_dict = checkpoint["model"]
|
423 |
+
|
424 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
425 |
+
state_dict = {
|
426 |
+
k.replace("text_embedding.", "token_emb."): v for k, v in state_dict.items()
|
427 |
+
}
|
428 |
+
model = Vui(config)
|
429 |
+
load_what_you_can(state_dict, model)
|
430 |
+
return model
|
431 |
+
|
432 |
+
@staticmethod
|
433 |
+
def from_pretrained_inf(
|
434 |
+
checkpoint_path: str | dict,
|
435 |
+
**config_kwargs,
|
436 |
+
):
|
437 |
+
return Vui.from_pretrained(checkpoint_path, **config_kwargs).eval()
|
438 |
+
|
439 |
+
@property
|
440 |
+
def device(self):
|
441 |
+
return next(self.parameters()).device
|
442 |
+
|
443 |
+
@property
|
444 |
+
def dtype(self):
|
445 |
+
return next(self.parameters()).dtype
|
src/vui/notebook.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def play(audio: torch.Tensor | np.ndarray | str, sr=16000, autoplay=True):
|
6 |
+
import torchaudio
|
7 |
+
from IPython.display import Audio, display
|
8 |
+
|
9 |
+
if isinstance(audio, str):
|
10 |
+
audio = torchaudio.load(audio)
|
11 |
+
if isinstance(audio, np.ndarray):
|
12 |
+
audio = torch.from_numpy(audio)
|
13 |
+
|
14 |
+
assert audio.numel() > 100, "play() needs a non empty audio array"
|
15 |
+
|
16 |
+
audio = audio.flatten()
|
17 |
+
if audio.dim() < 2:
|
18 |
+
audio = audio[None]
|
19 |
+
|
20 |
+
# Sum Channels
|
21 |
+
if audio.shape[0] > 1:
|
22 |
+
audio = audio.sum(dim=0)
|
23 |
+
|
24 |
+
display(Audio(audio.cpu().detach(), rate=sr, autoplay=autoplay, normalize=True))
|
25 |
+
|
26 |
+
|
27 |
+
def plot_mel_spec(mel_spec: torch.Tensor | np.ndarray, title: str = None):
|
28 |
+
import matplotlib.pyplot as plt
|
29 |
+
|
30 |
+
mel_spec = mel_spec.squeeze()
|
31 |
+
if isinstance(mel_spec, torch.Tensor):
|
32 |
+
mel_spec = mel_spec.cpu().numpy()
|
33 |
+
|
34 |
+
fig, ax = plt.subplots(figsize=(16, 4))
|
35 |
+
im = ax.imshow(mel_spec, aspect="auto", origin="lower", interpolation="none")
|
36 |
+
fig.colorbar(im, ax=ax)
|
37 |
+
ax.set_xlabel("frames")
|
38 |
+
ax.set_ylabel("channels")
|
39 |
+
|
40 |
+
if title is not None:
|
41 |
+
ax.set_title(title)
|
src/vui/patterns.py
ADDED
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from collections import namedtuple
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from functools import lru_cache
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
def apply_delay_pattern(codes: torch.Tensor, mask_token: int):
|
12 |
+
codes = F.pad(codes, (0, codes.shape[1] + 1), value=mask_token)
|
13 |
+
return torch.stack([codes[:, k].roll(k + 1) for k in range(codes.shape[1])], dim=1)
|
14 |
+
|
15 |
+
|
16 |
+
def revert_delay_pattern(codes: torch.Tensor):
|
17 |
+
_, n_q, seq_len = codes.shape
|
18 |
+
return torch.stack(
|
19 |
+
[codes[:, k, k + 1 : seq_len - n_q + k] for k in range(n_q)], dim=1
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
LayoutCoord = namedtuple("LayoutCoord", ["t", "q"]) # (timestep, codebook index)
|
24 |
+
PatternLayout = list[list[LayoutCoord]] # Sequence of coordinates
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class Pattern:
|
30 |
+
"""Base implementation of a pattern over a sequence with multiple codebooks.
|
31 |
+
|
32 |
+
The codebook pattern consists in a layout, defining for each sequence step
|
33 |
+
the list of coordinates of each codebook timestep in the resulting interleaved sequence.
|
34 |
+
The first item of the pattern is always an empty list in order to properly insert a special token
|
35 |
+
to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
|
36 |
+
and ``timesteps`` the number of timesteps corresponding to the original sequence.
|
37 |
+
|
38 |
+
The pattern provides convenient methods to build and revert interleaved sequences from it:
|
39 |
+
``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
|
40 |
+
to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
|
41 |
+
K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
|
42 |
+
for the output sequence. The unfilled positions are replaced with a special token and the built sequence
|
43 |
+
is returned along with a mask indicating valid tokens.
|
44 |
+
``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
|
45 |
+
of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
|
46 |
+
to fill and specify invalid positions if needed.
|
47 |
+
See the dedicated methods for more details.
|
48 |
+
"""
|
49 |
+
|
50 |
+
# Pattern layout, for each sequence step, we have a list of coordinates
|
51 |
+
# corresponding to the original codebook timestep and position.
|
52 |
+
# The first list is always an empty list in order to properly insert
|
53 |
+
# a special token to start with.
|
54 |
+
layout: PatternLayout
|
55 |
+
timesteps: int
|
56 |
+
n_q: int
|
57 |
+
|
58 |
+
def __post_init__(self):
|
59 |
+
assert len(self.layout) > 0
|
60 |
+
self._validate_layout()
|
61 |
+
self._build_reverted_sequence_scatter_indexes = lru_cache(100)(
|
62 |
+
self._build_reverted_sequence_scatter_indexes
|
63 |
+
)
|
64 |
+
self._build_pattern_sequence_scatter_indexes = lru_cache(100)(
|
65 |
+
self._build_pattern_sequence_scatter_indexes
|
66 |
+
)
|
67 |
+
logger.info(
|
68 |
+
"New pattern, time steps: %d, sequence steps: %d",
|
69 |
+
self.timesteps,
|
70 |
+
len(self.layout),
|
71 |
+
)
|
72 |
+
|
73 |
+
def _validate_layout(self):
|
74 |
+
"""Runs checks on the layout to ensure a valid pattern is defined.
|
75 |
+
A pattern is considered invalid if:
|
76 |
+
- Multiple timesteps for a same codebook are defined in the same sequence step
|
77 |
+
- The timesteps for a given codebook are not in ascending order as we advance in the sequence
|
78 |
+
(this would mean that we have future timesteps before past timesteps).
|
79 |
+
"""
|
80 |
+
q_timesteps = {q: 0 for q in range(self.n_q)}
|
81 |
+
for s, seq_coords in enumerate(self.layout):
|
82 |
+
if len(seq_coords) > 0:
|
83 |
+
qs = set()
|
84 |
+
for coord in seq_coords:
|
85 |
+
qs.add(coord.q)
|
86 |
+
last_q_timestep = q_timesteps[coord.q]
|
87 |
+
assert (
|
88 |
+
coord.t >= last_q_timestep
|
89 |
+
), f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
|
90 |
+
q_timesteps[coord.q] = coord.t
|
91 |
+
# each sequence step contains at max 1 coordinate per codebook
|
92 |
+
assert len(qs) == len(
|
93 |
+
seq_coords
|
94 |
+
), f"Multiple entries for a same codebook are found at step {s}"
|
95 |
+
|
96 |
+
@property
|
97 |
+
def num_sequence_steps(self):
|
98 |
+
return len(self.layout) - 1
|
99 |
+
|
100 |
+
@property
|
101 |
+
def max_delay(self):
|
102 |
+
max_t_in_seq_coords = 0
|
103 |
+
for seq_coords in self.layout[1:]:
|
104 |
+
for coords in seq_coords:
|
105 |
+
max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
|
106 |
+
return max_t_in_seq_coords - self.timesteps
|
107 |
+
|
108 |
+
@property
|
109 |
+
def valid_layout(self):
|
110 |
+
valid_step = len(self.layout) - self.max_delay
|
111 |
+
return self.layout[:valid_step]
|
112 |
+
|
113 |
+
def starts_with_special_token(self):
|
114 |
+
return self.layout[0] == []
|
115 |
+
|
116 |
+
def get_sequence_coords_with_timestep(self, t: int, q: int | None = None):
|
117 |
+
"""Get codebook coordinates in the layout that corresponds to the specified timestep t
|
118 |
+
and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
|
119 |
+
and the actual codebook coordinates.
|
120 |
+
"""
|
121 |
+
assert (
|
122 |
+
t <= self.timesteps
|
123 |
+
), "provided timesteps is greater than the pattern's number of timesteps"
|
124 |
+
if q is not None:
|
125 |
+
assert (
|
126 |
+
q <= self.n_q
|
127 |
+
), "provided number of codebooks is greater than the pattern's number of codebooks"
|
128 |
+
coords = []
|
129 |
+
for s, seq_codes in enumerate(self.layout):
|
130 |
+
for code in seq_codes:
|
131 |
+
if code.t == t and (q is None or code.q == q):
|
132 |
+
coords.append((s, code))
|
133 |
+
return coords
|
134 |
+
|
135 |
+
def get_steps_with_timestep(self, t: int, q: int | None = None) -> list[int]:
|
136 |
+
return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
|
137 |
+
|
138 |
+
def get_first_step_with_timesteps(self, t: int, q: int | None = None) -> int | None:
|
139 |
+
steps_with_timesteps = self.get_steps_with_timestep(t, q)
|
140 |
+
return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
|
141 |
+
|
142 |
+
def _build_pattern_sequence_scatter_indexes(
|
143 |
+
self,
|
144 |
+
timesteps: int,
|
145 |
+
n_q: int,
|
146 |
+
keep_only_valid_steps: bool,
|
147 |
+
device: torch.device | str = "cpu",
|
148 |
+
):
|
149 |
+
"""Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
timesteps (int): Maximum number of timesteps steps to consider.
|
153 |
+
keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
|
154 |
+
device (torch.device or str): Device for created tensors.
|
155 |
+
Returns:
|
156 |
+
indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
|
157 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
|
158 |
+
"""
|
159 |
+
assert (
|
160 |
+
n_q == self.n_q
|
161 |
+
), f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
|
162 |
+
assert (
|
163 |
+
timesteps <= self.timesteps
|
164 |
+
), "invalid number of timesteps used to build the sequence from the pattern"
|
165 |
+
# use the proper layout based on whether we limit ourselves to valid steps only or not,
|
166 |
+
# note that using the valid_layout will result in a truncated sequence up to the valid steps
|
167 |
+
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
168 |
+
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
169 |
+
indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
|
170 |
+
mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
|
171 |
+
# fill indexes with last sequence step value that will correspond to our special token
|
172 |
+
# the last value is n_q * timesteps as we have flattened z and append special token as the last token
|
173 |
+
# which will correspond to the index: n_q * timesteps
|
174 |
+
indexes[:] = n_q * timesteps
|
175 |
+
# iterate over the pattern and fill scattered indexes and mask
|
176 |
+
for s, sequence_coords in enumerate(ref_layout):
|
177 |
+
for coords in sequence_coords:
|
178 |
+
if coords.t < timesteps:
|
179 |
+
indexes[coords.q, s] = coords.t + coords.q * timesteps
|
180 |
+
mask[coords.q, s] = 1
|
181 |
+
indexes = torch.from_numpy(indexes).to(device)
|
182 |
+
mask = torch.from_numpy(mask).to(device)
|
183 |
+
return indexes, mask
|
184 |
+
|
185 |
+
def build_pattern_sequence(
|
186 |
+
self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False
|
187 |
+
):
|
188 |
+
"""Build sequence corresponding to the pattern from the input tensor z.
|
189 |
+
The sequence is built using up to sequence_steps if specified, and non-pattern
|
190 |
+
coordinates are filled with the special token.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
|
194 |
+
special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
|
195 |
+
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
196 |
+
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
197 |
+
Returns:
|
198 |
+
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
|
199 |
+
corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
|
200 |
+
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
|
201 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
|
202 |
+
"""
|
203 |
+
B, K, T = z.shape
|
204 |
+
indexes, mask = self._build_pattern_sequence_scatter_indexes(
|
205 |
+
T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
|
206 |
+
)
|
207 |
+
z = z.reshape(B, -1)
|
208 |
+
# we append the special token as the last index of our flattened z tensor
|
209 |
+
z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
|
210 |
+
values = z[:, indexes.view(-1)]
|
211 |
+
values = values.view(B, K, indexes.shape[-1])
|
212 |
+
return values, indexes, mask
|
213 |
+
|
214 |
+
def _build_reverted_sequence_scatter_indexes(
|
215 |
+
self,
|
216 |
+
sequence_steps: int,
|
217 |
+
n_q: int,
|
218 |
+
keep_only_valid_steps: bool = False,
|
219 |
+
is_model_output: bool = False,
|
220 |
+
device: torch.device | str = "cpu",
|
221 |
+
):
|
222 |
+
"""Builds scatter indexes required to retrieve the original multi-codebook sequence
|
223 |
+
from interleaving pattern.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
sequence_steps (int): Sequence steps.
|
227 |
+
n_q (int): Number of codebooks.
|
228 |
+
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
229 |
+
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
230 |
+
is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
|
231 |
+
device (torch.device or str): Device for created tensors.
|
232 |
+
Returns:
|
233 |
+
indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
|
234 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
235 |
+
"""
|
236 |
+
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
237 |
+
timesteps = self.timesteps
|
238 |
+
assert (
|
239 |
+
n_q == self.n_q
|
240 |
+
), f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
|
241 |
+
assert sequence_steps <= len(
|
242 |
+
ref_layout
|
243 |
+
), f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
|
244 |
+
|
245 |
+
# ensure we take the appropriate indexes to keep the model output from the first special token as well
|
246 |
+
if is_model_output and self.starts_with_special_token():
|
247 |
+
ref_layout = ref_layout[1:]
|
248 |
+
|
249 |
+
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
250 |
+
indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
|
251 |
+
mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
|
252 |
+
# fill indexes with last sequence step value that will correspond to our special token
|
253 |
+
indexes[:] = n_q * sequence_steps
|
254 |
+
for s, sequence_codes in enumerate(ref_layout):
|
255 |
+
if s < sequence_steps:
|
256 |
+
for code in sequence_codes:
|
257 |
+
if code.t < timesteps:
|
258 |
+
indexes[code.q, code.t] = s + code.q * sequence_steps
|
259 |
+
mask[code.q, code.t] = 1
|
260 |
+
indexes = torch.from_numpy(indexes).to(device)
|
261 |
+
mask = torch.from_numpy(mask).to(device)
|
262 |
+
return indexes, mask
|
263 |
+
|
264 |
+
def revert_pattern_sequence(
|
265 |
+
self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False
|
266 |
+
):
|
267 |
+
"""Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
|
268 |
+
The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
|
269 |
+
are filled with the special token.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
|
273 |
+
special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
|
274 |
+
Returns:
|
275 |
+
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
|
276 |
+
corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
|
277 |
+
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
|
278 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
279 |
+
"""
|
280 |
+
B, K, S = s.shape
|
281 |
+
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
282 |
+
S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
|
283 |
+
)
|
284 |
+
s = s.view(B, -1)
|
285 |
+
# we append the special token as the last index of our flattened z tensor
|
286 |
+
s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
|
287 |
+
values = s[:, indexes.view(-1)]
|
288 |
+
values = values.view(B, K, indexes.shape[-1])
|
289 |
+
return values, indexes, mask
|
290 |
+
|
291 |
+
def revert_pattern_logits(
|
292 |
+
self,
|
293 |
+
logits: torch.Tensor,
|
294 |
+
special_token: float,
|
295 |
+
keep_only_valid_steps: bool = False,
|
296 |
+
):
|
297 |
+
"""Revert model logits obtained on a sequence built from the pattern
|
298 |
+
back to a tensor matching the original sequence.
|
299 |
+
|
300 |
+
This method is similar to ``revert_pattern_sequence`` with the following specificities:
|
301 |
+
1. It is designed to work with the extra cardinality dimension
|
302 |
+
2. We return the logits for the first sequence item that matches the special_token and
|
303 |
+
which matching target in the original sequence is the first item of the sequence,
|
304 |
+
while we skip the last logits as there is no matching target
|
305 |
+
"""
|
306 |
+
B, n, Q, S = logits.shape
|
307 |
+
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
308 |
+
S, Q, keep_only_valid_steps, is_model_output=True, device=logits.device
|
309 |
+
)
|
310 |
+
logits = logits.reshape(B, n, -1)
|
311 |
+
# we append the special token as the last index of our flattened z tensor
|
312 |
+
logits = torch.cat(
|
313 |
+
[logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1
|
314 |
+
) # [B, card, K x S]
|
315 |
+
values = logits[:, :, indexes.view(-1)]
|
316 |
+
values = values.view(B, n, Q, indexes.shape[-1])
|
317 |
+
return values, indexes, mask
|
318 |
+
|
319 |
+
|
320 |
+
class CodebooksPatternProvider(ABC):
|
321 |
+
"""Abstraction around providing pattern for interleaving codebooks.
|
322 |
+
|
323 |
+
The CodebooksPatternProvider abstraction allows to implement various strategies to
|
324 |
+
define interleaving pattern of sequences composed of multiple codebooks. For a given
|
325 |
+
number of codebooks `n_q`, the pattern provider can generate a specified pattern
|
326 |
+
corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
|
327 |
+
can be used to construct a new sequence from the original codes respecting the specified
|
328 |
+
pattern. The pattern is defined as a list of list of code coordinates, code coordinate
|
329 |
+
being a tuple with the original timestep and codebook to build the new sequence.
|
330 |
+
Note that all patterns must start with an empty list that is then used to insert a first
|
331 |
+
sequence step of special tokens in the newly generated sequence.
|
332 |
+
|
333 |
+
Args:
|
334 |
+
n_q (int): number of codebooks.
|
335 |
+
cached (bool): if True, patterns for a given length are cached. In general
|
336 |
+
that should be true for efficiency reason to avoid synchronization points.
|
337 |
+
"""
|
338 |
+
|
339 |
+
def __init__(self, n_q: int, cached: bool = True):
|
340 |
+
assert n_q > 0
|
341 |
+
self.n_q = n_q
|
342 |
+
self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
|
343 |
+
|
344 |
+
@abstractmethod
|
345 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
346 |
+
"""Builds pattern with specific interleaving between codebooks.
|
347 |
+
|
348 |
+
Args:
|
349 |
+
timesteps (int): Total number of timesteps.
|
350 |
+
"""
|
351 |
+
raise NotImplementedError()
|
352 |
+
|
353 |
+
|
354 |
+
class DelayedPatternProvider(CodebooksPatternProvider):
|
355 |
+
"""Provider for delayed pattern across delayed codebooks.
|
356 |
+
Codebooks are delayed in the sequence and sequence steps will contain codebooks
|
357 |
+
from different timesteps.
|
358 |
+
|
359 |
+
Example:
|
360 |
+
Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
|
361 |
+
[[1, 2, 3, 4],
|
362 |
+
[1, 2, 3, 4],
|
363 |
+
[1, 2, 3, 4]]
|
364 |
+
The resulting sequence obtained from the returned pattern is:
|
365 |
+
[[S, 1, 2, 3, 4],
|
366 |
+
[S, S, 1, 2, 3],
|
367 |
+
[S, S, S, 1, 2]]
|
368 |
+
(with S being a special token)
|
369 |
+
|
370 |
+
Args:
|
371 |
+
n_q (int): Number of codebooks.
|
372 |
+
delays (list of int, optional): Delay for each of the codebooks.
|
373 |
+
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
374 |
+
flatten_first (int): Flatten the first N timesteps.
|
375 |
+
empty_initial (int): Prepend with N empty list of coordinates.
|
376 |
+
"""
|
377 |
+
|
378 |
+
def __init__(
|
379 |
+
self,
|
380 |
+
n_q: int,
|
381 |
+
delays: list[int] | None = None,
|
382 |
+
flatten_first: int = 0,
|
383 |
+
empty_initial: int = 0,
|
384 |
+
):
|
385 |
+
super().__init__(n_q)
|
386 |
+
if delays is None:
|
387 |
+
delays = list(range(n_q))
|
388 |
+
self.delays = delays
|
389 |
+
self.flatten_first = flatten_first
|
390 |
+
self.empty_initial = empty_initial
|
391 |
+
assert len(self.delays) == self.n_q
|
392 |
+
assert sorted(self.delays) == self.delays
|
393 |
+
|
394 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
395 |
+
omit_special_token = self.empty_initial < 0
|
396 |
+
out: PatternLayout = [] if omit_special_token else [[]]
|
397 |
+
max_delay = max(self.delays)
|
398 |
+
if self.empty_initial:
|
399 |
+
out += [[] for _ in range(self.empty_initial)]
|
400 |
+
if self.flatten_first:
|
401 |
+
for t in range(min(timesteps, self.flatten_first)):
|
402 |
+
for q in range(self.n_q):
|
403 |
+
out.append([LayoutCoord(t, q)])
|
404 |
+
for t in range(self.flatten_first, timesteps + max_delay):
|
405 |
+
v = []
|
406 |
+
for q, delay in enumerate(self.delays):
|
407 |
+
t_for_q = t - delay
|
408 |
+
if t_for_q >= self.flatten_first:
|
409 |
+
v.append(LayoutCoord(t_for_q, q))
|
410 |
+
out.append(v)
|
411 |
+
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
|
412 |
+
|
413 |
+
|
414 |
+
if __name__ == "__main__":
|
415 |
+
# Tried to use the simple patterns to train and something very odd happened.
|
416 |
+
Q = 4
|
417 |
+
|
418 |
+
codes = torch.randint(0, 1000, (1, Q, 100))
|
419 |
+
pcodes = apply_delay_pattern(codes, 1001)
|
420 |
+
provider = DelayedPatternProvider(Q)
|
421 |
+
provider = provider.get_pattern(100)
|
422 |
+
pcodes2 = provider.build_pattern_sequence(codes, 1001)
|
423 |
+
breakpoint()
|
src/vui/rope.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange, repeat
|
3 |
+
from torch import Tensor
|
4 |
+
from torch.amp import autocast
|
5 |
+
|
6 |
+
|
7 |
+
def rotate_half(x):
|
8 |
+
"""Also known as "interleaved" style or GPT-J style."""
|
9 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
10 |
+
x1, x2 = x.unbind(dim=-1)
|
11 |
+
x = torch.stack((-x2, x1), dim=-1)
|
12 |
+
return rearrange(x, "... d r -> ... (d r)")
|
13 |
+
|
14 |
+
|
15 |
+
@autocast("cuda", enabled=False)
|
16 |
+
def apply_rotary_emb(
|
17 |
+
freqs: Tensor, t: Tensor, start_index: int = 0, scale: float = 1.0, seq_dim=-2
|
18 |
+
):
|
19 |
+
dtype = t.dtype
|
20 |
+
|
21 |
+
if t.ndim == 3:
|
22 |
+
seq_len = t.shape[seq_dim]
|
23 |
+
freqs = freqs[-seq_len:]
|
24 |
+
|
25 |
+
rot_dim = freqs.shape[-1]
|
26 |
+
end_index = start_index + rot_dim
|
27 |
+
|
28 |
+
assert (
|
29 |
+
rot_dim <= t.shape[-1]
|
30 |
+
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
31 |
+
|
32 |
+
t_left, t, t_right = (
|
33 |
+
t[..., :start_index],
|
34 |
+
t[..., start_index:end_index],
|
35 |
+
t[..., end_index:],
|
36 |
+
)
|
37 |
+
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
38 |
+
out = torch.cat((t_left, t, t_right), dim=-1)
|
39 |
+
return out.to(dtype)
|
40 |
+
|
41 |
+
|
42 |
+
def precompute_freqs_cis(
|
43 |
+
dim: int,
|
44 |
+
max_seqlen: int,
|
45 |
+
theta: float = 10_000.0,
|
46 |
+
theta_rescale_factor: float = 1.0,
|
47 |
+
dtype: torch.dtype = torch.float32,
|
48 |
+
):
|
49 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
50 |
+
pos = torch.arange(max_seqlen, dtype=dtype)
|
51 |
+
inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype) / dim))
|
52 |
+
freqs = torch.einsum("..., f -> ... f", pos.to(inv_freqs.dtype), inv_freqs)
|
53 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
54 |
+
return freqs
|
src/vui/sampling.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import Tensor
|
3 |
+
|
4 |
+
|
5 |
+
def multinomial(input: Tensor, num_samples: int, replacement=False, *, generator=None):
|
6 |
+
input_ = input.reshape(-1, input.shape[-1])
|
7 |
+
output_ = torch.multinomial(
|
8 |
+
input_, num_samples=num_samples, replacement=replacement, generator=generator
|
9 |
+
)
|
10 |
+
output = output_.reshape(*list(input.shape[:-1]), -1)
|
11 |
+
return output
|
12 |
+
|
13 |
+
|
14 |
+
def sample_top_k(probs: Tensor, k: int) -> Tensor:
|
15 |
+
top_k_value, _ = torch.topk(probs, k, dim=-1)
|
16 |
+
min_value_top_k = top_k_value[..., [-1]]
|
17 |
+
probs *= (probs >= min_value_top_k).float()
|
18 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
19 |
+
next_token = multinomial(probs, num_samples=1)
|
20 |
+
return next_token
|
21 |
+
|
22 |
+
|
23 |
+
def sample_top_p(probs: Tensor, p: float) -> Tensor:
|
24 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
25 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
26 |
+
mask = probs_sum - probs_sort > p
|
27 |
+
probs_sort *= (~mask).float()
|
28 |
+
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
29 |
+
next_token = multinomial(probs_sort, num_samples=1)
|
30 |
+
next_token = torch.gather(probs_idx, -1, next_token)
|
31 |
+
|
32 |
+
return next_token
|
33 |
+
|
34 |
+
|
35 |
+
def sample_top_p_top_k(probs: Tensor, p: float, top_k: int):
|
36 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
37 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
38 |
+
mask = probs_sum - probs_sort > p
|
39 |
+
probs_sort *= (~mask).float()
|
40 |
+
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
41 |
+
next_token = sample_top_k(probs_sort, top_k)
|
42 |
+
next_token = torch.gather(probs_idx, -1, next_token)
|
43 |
+
return next_token
|
src/vui/tok.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import ByT5Tokenizer
|
3 |
+
|
4 |
+
|
5 |
+
class CustomByT5Tokenizer(ByT5Tokenizer):
|
6 |
+
def encode(self, text, add_special_tokens=False, **kwargs):
|
7 |
+
"""
|
8 |
+
Override the encode method.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
text (str): Input text
|
12 |
+
add_special_tokens (bool): Whether to add BOS/EOS tokens
|
13 |
+
"""
|
14 |
+
# Use the parent class's encode method
|
15 |
+
tokens = super().encode(text, add_special_tokens=add_special_tokens, **kwargs)
|
16 |
+
return torch.tensor(tokens)
|
17 |
+
|
18 |
+
|
19 |
+
tok = CustomByT5Tokenizer.from_pretrained("google/byt5-small")
|
src/vui/utils.py
ADDED
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import time
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
|
10 |
+
def load_what_you_can(checkpoint: dict, model: torch.nn.Module):
|
11 |
+
"""
|
12 |
+
This method takes a checkpoint and loads as many weights from it as possible:
|
13 |
+
|
14 |
+
If they are the same shape, there's nothing to do
|
15 |
+
|
16 |
+
Will load the smallest shape otherwise.
|
17 |
+
"""
|
18 |
+
import torch
|
19 |
+
|
20 |
+
model_state_dict = model.state_dict()
|
21 |
+
checkpoint_state_dict = checkpoint
|
22 |
+
|
23 |
+
for name, param in checkpoint_state_dict.items():
|
24 |
+
if name not in model_state_dict:
|
25 |
+
print(f"Ignoring parameter '{name}' because it is not found in the model")
|
26 |
+
continue
|
27 |
+
|
28 |
+
model_state = model_state_dict[name]
|
29 |
+
mshape = model_state.shape
|
30 |
+
pshape = param.shape
|
31 |
+
|
32 |
+
if pshape == mshape:
|
33 |
+
model_state.copy_(param)
|
34 |
+
continue
|
35 |
+
|
36 |
+
if len(pshape) != len(mshape):
|
37 |
+
# Completely different shapes so probably unwise to merge
|
38 |
+
continue
|
39 |
+
|
40 |
+
min_shape = [
|
41 |
+
min(param.shape[i], model_state.shape[i]) for i in range(len(param.shape))
|
42 |
+
]
|
43 |
+
print(name, "model:", mshape, "chkpt:", pshape, "loading:", min_shape)
|
44 |
+
idxs = torch.meshgrid(*[torch.arange(s) for s in min_shape])
|
45 |
+
model_state[tuple(idxs)].copy_(param[tuple(idxs)])
|
46 |
+
|
47 |
+
return model.load_state_dict(model_state_dict)
|
48 |
+
|
49 |
+
|
50 |
+
def multimap(
|
51 |
+
items: list, func: callable, workers=4, desc=None, thread=False, chunk_size=128
|
52 |
+
) -> list:
|
53 |
+
"""
|
54 |
+
Quick and dirty multiprocessing that will return the result of func if it returns None
|
55 |
+
"""
|
56 |
+
from tqdm.contrib.concurrent import process_map, thread_map
|
57 |
+
|
58 |
+
m = thread_map if thread else process_map
|
59 |
+
length = None
|
60 |
+
try:
|
61 |
+
length = len(items)
|
62 |
+
except Exception as e:
|
63 |
+
print(e, "getting length")
|
64 |
+
|
65 |
+
results = m(
|
66 |
+
func,
|
67 |
+
items,
|
68 |
+
leave=False,
|
69 |
+
desc=desc,
|
70 |
+
max_workers=workers,
|
71 |
+
total=length,
|
72 |
+
chunksize=chunk_size,
|
73 |
+
)
|
74 |
+
return list(filter(lambda x: x is not None, results))
|
75 |
+
|
76 |
+
|
77 |
+
def round_up(num: float, factor: int):
|
78 |
+
return factor * math.ceil(num / factor)
|
79 |
+
|
80 |
+
|
81 |
+
def left_padding_mask(lengths, max_len, device=None, dtype=None):
|
82 |
+
masks = []
|
83 |
+
if not max_len:
|
84 |
+
max_len = max(lengths)
|
85 |
+
for l in lengths:
|
86 |
+
mask = torch.empty(l, l, device=device, dtype=dtype).fill_(-torch.inf).triu_(1)
|
87 |
+
diff = max_len - l
|
88 |
+
mask = F.pad(mask, (diff, 0, diff, 0), value=-torch.inf)
|
89 |
+
masks.append(mask)
|
90 |
+
|
91 |
+
masks = torch.stack(masks)
|
92 |
+
return masks[:, None]
|
93 |
+
|
94 |
+
|
95 |
+
def seed_all(seed: int):
|
96 |
+
import random
|
97 |
+
|
98 |
+
import numpy as np
|
99 |
+
import torch
|
100 |
+
|
101 |
+
torch.manual_seed(seed)
|
102 |
+
np.random.seed(seed)
|
103 |
+
random.seed(seed)
|
104 |
+
|
105 |
+
|
106 |
+
def split_bucket_path(url: str) -> tuple[str, str]:
|
107 |
+
url = url.replace("s3://", "")
|
108 |
+
url = url.replace("sj://", "")
|
109 |
+
url = url.replace("r2://", "")
|
110 |
+
bucket = url.split("/")[0]
|
111 |
+
path = "/".join(url.split("/")[1:])
|
112 |
+
return bucket, path
|
113 |
+
|
114 |
+
|
115 |
+
def prob_mask_like(shape, prob: float, device):
|
116 |
+
import torch
|
117 |
+
|
118 |
+
if prob == 1:
|
119 |
+
return torch.ones(shape, device=device, dtype=torch.bool)
|
120 |
+
elif prob == 0:
|
121 |
+
return torch.zeros(shape, device=device, dtype=torch.bool)
|
122 |
+
else:
|
123 |
+
return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
|
124 |
+
|
125 |
+
|
126 |
+
def round_up_to_multiple(n: int, multiple: int) -> int:
|
127 |
+
if n % multiple != 0:
|
128 |
+
n += multiple - (n % multiple)
|
129 |
+
|
130 |
+
return n
|
131 |
+
|
132 |
+
|
133 |
+
def warmup_then_cosine_decay(
|
134 |
+
step: int, *, warmup_steps: int, steps: int, min_lr: float, max_lr: float
|
135 |
+
):
|
136 |
+
eps = 1e-9
|
137 |
+
cooldown_steps = warmup_steps
|
138 |
+
if step < warmup_steps:
|
139 |
+
return min_lr + step * (max_lr - min_lr) / (warmup_steps)
|
140 |
+
elif step > steps:
|
141 |
+
return min_lr
|
142 |
+
elif step < steps - cooldown_steps:
|
143 |
+
decay_ratio = (step - warmup_steps) / (steps - warmup_steps - cooldown_steps)
|
144 |
+
# assert 0 <= decay_ratio <= 1
|
145 |
+
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
146 |
+
return min_lr + coeff * (max_lr - min_lr)
|
147 |
+
else:
|
148 |
+
# decay from min_lr to 0
|
149 |
+
return min_lr * (steps - step) / cooldown_steps + eps
|
150 |
+
|
151 |
+
|
152 |
+
def decay_to_zero(step: int, *, decay_steps: int, steps: int, max_lr: float):
|
153 |
+
if step > steps:
|
154 |
+
return 0.0
|
155 |
+
else:
|
156 |
+
gradient = -max_lr / decay_steps
|
157 |
+
|
158 |
+
return max_lr + gradient * step
|
159 |
+
|
160 |
+
|
161 |
+
def cross_entropy_loss(logits, mask, targets):
|
162 |
+
import torch
|
163 |
+
import torch.nn.functional as F
|
164 |
+
|
165 |
+
B, Q, T, _ = logits.size()
|
166 |
+
assert logits.shape[:-1] == targets.shape
|
167 |
+
assert mask.shape == targets.shape
|
168 |
+
loss = torch.zeros([], device=targets.device)
|
169 |
+
codebook_losses = []
|
170 |
+
for q in range(Q):
|
171 |
+
logits_q = (
|
172 |
+
logits[:, q, ...].contiguous().view(-1, logits.size(-1))
|
173 |
+
) # [B x T, card]
|
174 |
+
targets_q = targets[:, q, ...].contiguous().view(-1) # [B x T]
|
175 |
+
mask_q = mask[:, q, ...].contiguous().view(-1) # [B x T]
|
176 |
+
ce_targets = targets_q[mask_q]
|
177 |
+
ce_logits = logits_q[mask_q]
|
178 |
+
q_ce = F.cross_entropy(ce_logits, ce_targets)
|
179 |
+
loss += q_ce
|
180 |
+
codebook_losses.append(q_ce.detach())
|
181 |
+
# average cross entropy across codebooks
|
182 |
+
loss = loss / Q
|
183 |
+
return loss, codebook_losses
|
184 |
+
|
185 |
+
|
186 |
+
def build_optimizer(
|
187 |
+
module, *, weight_decay: float, lr: float, betas: tuple[float, float]
|
188 |
+
):
|
189 |
+
import torch
|
190 |
+
|
191 |
+
param_dict = {pn: p for pn, p in module.named_parameters() if p.requires_grad}
|
192 |
+
|
193 |
+
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
|
194 |
+
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
|
195 |
+
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
|
196 |
+
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
|
197 |
+
optim_groups = [
|
198 |
+
{"params": decay_params, "weight_decay": weight_decay},
|
199 |
+
{"params": nodecay_params, "weight_decay": 0.0},
|
200 |
+
]
|
201 |
+
# num_decay_params = sum(p.numel() for p in decay_params)
|
202 |
+
# num_nodecay_params = sum(p.numel() for p in nodecay_params)
|
203 |
+
# print(
|
204 |
+
# f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
|
205 |
+
# )
|
206 |
+
# print(
|
207 |
+
# f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
|
208 |
+
# )
|
209 |
+
optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=True)
|
210 |
+
|
211 |
+
return optimizer
|
212 |
+
|
213 |
+
|
214 |
+
def pad_or_cut_right(t: Tensor, padlen: int, value=0) -> Tensor:
|
215 |
+
current_len = t.shape[-1]
|
216 |
+
|
217 |
+
if current_len == padlen:
|
218 |
+
return t
|
219 |
+
|
220 |
+
if current_len < padlen:
|
221 |
+
# Need to pad
|
222 |
+
pad_size = (0, padlen - current_len)
|
223 |
+
return F.pad(t, pad_size, value=value)
|
224 |
+
# Need to cut
|
225 |
+
return t[:padlen]
|
226 |
+
|
227 |
+
|
228 |
+
def pad_or_cut_left(t: Tensor, value: int) -> Tensor:
|
229 |
+
dims = t.ndim
|
230 |
+
current_len = t.shape[0]
|
231 |
+
|
232 |
+
if current_len == value:
|
233 |
+
return t
|
234 |
+
|
235 |
+
if current_len < value:
|
236 |
+
# Need to pad
|
237 |
+
pad_size = (0,) * (2 * (dims - 1)) + (value - current_len, 0)
|
238 |
+
return F.pad(t, pad_size)
|
239 |
+
# Need to cut
|
240 |
+
return t[-value:]
|
241 |
+
|
242 |
+
|
243 |
+
def dl_pt(orig: str):
|
244 |
+
from os.path import exists
|
245 |
+
|
246 |
+
import torch
|
247 |
+
|
248 |
+
from vui.storage import s3, split_bucket_path
|
249 |
+
|
250 |
+
if not orig.endswith(".pt"):
|
251 |
+
orig = orig + ".pt"
|
252 |
+
|
253 |
+
load = partial(torch.load, weights_only=True)
|
254 |
+
if exists(orig):
|
255 |
+
return load(orig)
|
256 |
+
url = "/data/" + orig
|
257 |
+
|
258 |
+
if exists(url):
|
259 |
+
return load(url)
|
260 |
+
url = "s3://fluxions/" + orig
|
261 |
+
|
262 |
+
bucket, key = split_bucket_path(url)
|
263 |
+
response = s3.get_object(Bucket=bucket, Key=key)
|
264 |
+
return load(response["Body"])
|
265 |
+
|
266 |
+
|
267 |
+
def dl_ogg(url: str, start=0, end=-1, sr=None):
|
268 |
+
import re
|
269 |
+
from os.path import exists
|
270 |
+
|
271 |
+
import soundfile as sf
|
272 |
+
import torch
|
273 |
+
|
274 |
+
search_sr = re.search(r"(\d+)/", url)
|
275 |
+
if search_sr:
|
276 |
+
sr = int(search_sr.group(1))
|
277 |
+
|
278 |
+
local_file = exists(url)
|
279 |
+
|
280 |
+
if exists("/data/audio/" + url):
|
281 |
+
local_file = True
|
282 |
+
url = "/data/audio/" + url
|
283 |
+
|
284 |
+
if not local_file:
|
285 |
+
from vui.storage import s3
|
286 |
+
|
287 |
+
url = "s3://fluxions/" + url
|
288 |
+
b, p = split_bucket_path(url)
|
289 |
+
url = s3.get_object(Bucket=b, Key=p)["Body"]
|
290 |
+
|
291 |
+
if sr is None:
|
292 |
+
if local_file:
|
293 |
+
sr = sf.info(url).samplerate
|
294 |
+
else:
|
295 |
+
sr = sf.info(url.read()).samplerate
|
296 |
+
|
297 |
+
start_frame = int(start * sr)
|
298 |
+
num_frames = int(end * sr) - start_frame
|
299 |
+
wav, _ = sf.read(url, frames=num_frames, start=start_frame, always_2d=True)
|
300 |
+
wav = torch.from_numpy(wav).float()
|
301 |
+
wav = wav.T.mean(0, keepdim=True)
|
302 |
+
return wav, sr
|
303 |
+
|
304 |
+
|
305 |
+
class timer:
|
306 |
+
def __init__(self, name=""):
|
307 |
+
self.name = name
|
308 |
+
|
309 |
+
def __enter__(self):
|
310 |
+
self.t = time.perf_counter()
|
311 |
+
return self
|
312 |
+
|
313 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
314 |
+
elapsed = time.perf_counter() - self.t
|
315 |
+
print(f"{self.name} {elapsed:.4f}")
|
316 |
+
|
317 |
+
|
318 |
+
@torch.inference_mode()
|
319 |
+
def decode_audio_from_indices(model, indices, chunk_size=64):
|
320 |
+
"""
|
321 |
+
Decodes audio from indices in batches to avoid memory issues.
|
322 |
+
|
323 |
+
Args:
|
324 |
+
model: Codec
|
325 |
+
indices: Tensor of shape (1, n_quantizers, sequence_length)
|
326 |
+
chunk_size: Number of samples to process at once
|
327 |
+
|
328 |
+
Returns:
|
329 |
+
Tensor of reconstructed audio
|
330 |
+
"""
|
331 |
+
device = model.device
|
332 |
+
indices = indices.to(device)
|
333 |
+
_, _, seq_len = indices.shape
|
334 |
+
chunks = seq_len // chunk_size + (1 if seq_len % chunk_size != 0 else 0)
|
335 |
+
|
336 |
+
audio_chunks = []
|
337 |
+
for i in range(chunks):
|
338 |
+
start_idx = i * chunk_size
|
339 |
+
end_idx = min(start_idx + chunk_size, seq_len)
|
340 |
+
chunk_indices = indices[:, :, start_idx:end_idx]
|
341 |
+
chunk_audio = model.from_indices(chunk_indices)
|
342 |
+
audio_chunks.append(chunk_audio.cpu())
|
343 |
+
|
344 |
+
full_audio = torch.cat(audio_chunks, dim=-1)
|
345 |
+
return full_audio.flatten()
|
346 |
+
|
347 |
+
|
348 |
+
def normalize_loudness(waveform, sample_rate: int, lufs: float = -12.0):
|
349 |
+
"""
|
350 |
+
Normalize the loudness of an audio tensor using torchaudio.transforms.Loudness.
|
351 |
+
|
352 |
+
Args:
|
353 |
+
audio_tensor (torch.Tensor): Input audio tensor of shape (channels, samples)
|
354 |
+
sample_rate (int): Sampling rate of the audio
|
355 |
+
target_loudness (float): Target loudness in LUFS (default: -16.0 LUFS)
|
356 |
+
|
357 |
+
Returns:
|
358 |
+
torch.Tensor: Loudness-normalized audio tensor
|
359 |
+
"""
|
360 |
+
import torchaudio
|
361 |
+
|
362 |
+
# Ensure the input tensor is 2D (add channel dimension if it's 1D)
|
363 |
+
if waveform.ndim == 1:
|
364 |
+
waveform = waveform.unsqueeze(0)
|
365 |
+
|
366 |
+
# Create a Loudness transform
|
367 |
+
loudness_transform = torchaudio.transforms.Loudness(sample_rate)
|
368 |
+
|
369 |
+
# Measure the current loudness
|
370 |
+
current_loudness = loudness_transform(waveform)
|
371 |
+
|
372 |
+
# Calculate the required gain
|
373 |
+
gain_db = lufs - current_loudness
|
374 |
+
|
375 |
+
# Convert gain from dB to linear scale
|
376 |
+
gain_linear = torch.pow(10, gain_db / 20)
|
377 |
+
|
378 |
+
# Apply the gain to normalize loudness
|
379 |
+
normalized_audio = waveform * gain_linear
|
380 |
+
|
381 |
+
return normalized_audio
|
382 |
+
|
383 |
+
|
384 |
+
def get_basename_without_extension(file_path):
|
385 |
+
from pathlib import Path
|
386 |
+
|
387 |
+
p = Path(file_path)
|
388 |
+
return p.stem
|
389 |
+
|
390 |
+
|
391 |
+
def ollama(prompt, MODEL=None):
|
392 |
+
import os
|
393 |
+
|
394 |
+
import requests
|
395 |
+
|
396 |
+
OLLAMA_HOST = "http://localhost:11434"
|
397 |
+
API = f"{OLLAMA_HOST}/api/generate"
|
398 |
+
|
399 |
+
if MODEL is None:
|
400 |
+
MODEL = os.environ.get("OLLAMA_MODEL", "gemma:1b")
|
401 |
+
|
402 |
+
payload = {
|
403 |
+
"model": MODEL,
|
404 |
+
"prompt": prompt,
|
405 |
+
"stream": False,
|
406 |
+
"options": {"temperature": 0.9, "top_p": 0.9, "max_tokens": 1000},
|
407 |
+
}
|
408 |
+
|
409 |
+
try:
|
410 |
+
response = requests.post(API, json=payload)
|
411 |
+
response.raise_for_status() # Raise exception for HTTP errors
|
412 |
+
result = response.json()
|
413 |
+
return result.get("response", "")
|
414 |
+
except requests.exceptions.RequestException as e:
|
415 |
+
print(f"Error calling Ollama API: {e}")
|
416 |
+
return ""
|
417 |
+
|
418 |
+
|
419 |
+
def decompile_state_dict(state_dict):
|
420 |
+
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
|
421 |
+
# state_dict = convert_old_weight_norm_to_new(state_dict)
|
422 |
+
return {k.replace("module.", ""): v for k, v in state_dict.items()}
|
src/vui/vad.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
from collections.abc import Callable
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import torch
|
9 |
+
from pyannote.audio import Model, Pipeline
|
10 |
+
from pyannote.audio.core.io import AudioFile
|
11 |
+
from pyannote.audio.pipelines import VoiceActivityDetection
|
12 |
+
from pyannote.audio.pipelines.utils import PipelineModel
|
13 |
+
from pyannote.core import Annotation, Segment, SlidingWindowFeature
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
|
17 |
+
|
18 |
+
|
19 |
+
pipeline = None
|
20 |
+
pipeline_name = "pyannote/voice-activity-detection"
|
21 |
+
|
22 |
+
|
23 |
+
@torch.autocast("cuda", enabled=False)
|
24 |
+
def detect_voice_activity(waveform, pipe=None):
|
25 |
+
"""16khz"""
|
26 |
+
waveform = waveform.flatten().float()[None]
|
27 |
+
global pipeline
|
28 |
+
|
29 |
+
if pipe is not None:
|
30 |
+
pipeline = pipe
|
31 |
+
elif pipeline is None:
|
32 |
+
pipeline = Pipeline.from_pretrained(pipeline_name)
|
33 |
+
initial_params = {
|
34 |
+
"onset": 0.8,
|
35 |
+
"offset": 0.5,
|
36 |
+
"min_duration_on": 0,
|
37 |
+
"min_duration_off": 0.0,
|
38 |
+
}
|
39 |
+
pipeline.instantiate(initial_params)
|
40 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
41 |
+
pipeline = pipeline.to(device)
|
42 |
+
|
43 |
+
vad = pipeline({"waveform": waveform, "sample_rate": 16000})
|
44 |
+
segments = [
|
45 |
+
(segment.start, segment.end) for segment in vad.get_timeline().support()
|
46 |
+
]
|
47 |
+
|
48 |
+
return segments
|
49 |
+
|
50 |
+
|
51 |
+
def load_vad_model(
|
52 |
+
device,
|
53 |
+
vad_onset=0.500,
|
54 |
+
vad_offset=0.363,
|
55 |
+
use_auth_token=None,
|
56 |
+
model_fp=None,
|
57 |
+
batch_size=32,
|
58 |
+
):
|
59 |
+
model_dir = torch.hub._get_torch_home()
|
60 |
+
os.makedirs(model_dir, exist_ok=True)
|
61 |
+
if model_fp is None:
|
62 |
+
model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
|
63 |
+
if os.path.exists(model_fp) and not os.path.isfile(model_fp):
|
64 |
+
raise RuntimeError(f"{model_fp} exists and is not a regular file")
|
65 |
+
|
66 |
+
if not os.path.isfile(model_fp):
|
67 |
+
with (
|
68 |
+
urllib.request.urlopen(VAD_SEGMENTATION_URL) as source,
|
69 |
+
open(model_fp, "wb") as output,
|
70 |
+
):
|
71 |
+
with tqdm(
|
72 |
+
total=int(source.info().get("Content-Length")),
|
73 |
+
ncols=80,
|
74 |
+
unit="iB",
|
75 |
+
unit_scale=True,
|
76 |
+
unit_divisor=1024,
|
77 |
+
) as loop:
|
78 |
+
while True:
|
79 |
+
buffer = source.read(8192)
|
80 |
+
if not buffer:
|
81 |
+
break
|
82 |
+
|
83 |
+
output.write(buffer)
|
84 |
+
loop.update(len(buffer))
|
85 |
+
|
86 |
+
model_bytes = open(model_fp, "rb").read()
|
87 |
+
if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split("/")[-2]:
|
88 |
+
raise RuntimeError(
|
89 |
+
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
90 |
+
)
|
91 |
+
|
92 |
+
vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
|
93 |
+
hyperparameters = {
|
94 |
+
"onset": vad_onset,
|
95 |
+
"offset": vad_offset,
|
96 |
+
"min_duration_on": 0.1,
|
97 |
+
"min_duration_off": 0.1,
|
98 |
+
}
|
99 |
+
vad_pipeline = VoiceActivitySegmentation(
|
100 |
+
segmentation=vad_model, device=torch.device(device), batch_size=batch_size
|
101 |
+
)
|
102 |
+
vad_pipeline.instantiate(hyperparameters)
|
103 |
+
|
104 |
+
return vad_pipeline
|
105 |
+
|
106 |
+
|
107 |
+
class Binarize:
|
108 |
+
"""Binarize detection scores using hysteresis thresholding, with min-cut operation
|
109 |
+
to ensure not segments are longer than max_duration.
|
110 |
+
|
111 |
+
Parameters
|
112 |
+
----------
|
113 |
+
onset : float, optional
|
114 |
+
Onset threshold. Defaults to 0.5.
|
115 |
+
offset : float, optional
|
116 |
+
Offset threshold. Defaults to `onset`.
|
117 |
+
min_duration_on : float, optional
|
118 |
+
Remove active regions shorter than that many seconds. Defaults to 0s.
|
119 |
+
min_duration_off : float, optional
|
120 |
+
Fill inactive regions shorter than that many seconds. Defaults to 0s.
|
121 |
+
pad_onset : float, optional
|
122 |
+
Extend active regions by moving their start time by that many seconds.
|
123 |
+
Defaults to 0s.
|
124 |
+
pad_offset : float, optional
|
125 |
+
Extend active regions by moving their end time by that many seconds.
|
126 |
+
Defaults to 0s.
|
127 |
+
max_duration: float
|
128 |
+
The maximum length of an active segment, divides segment at timestamp with lowest score.
|
129 |
+
Reference
|
130 |
+
---------
|
131 |
+
Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
|
132 |
+
RNN-based Voice Activity Detection", InterSpeech 2015.
|
133 |
+
|
134 |
+
Modified by Max Bain to include WhisperX's min-cut operation
|
135 |
+
https://arxiv.org/abs/2303.00747
|
136 |
+
|
137 |
+
Pyannote-audio
|
138 |
+
"""
|
139 |
+
|
140 |
+
def __init__(
|
141 |
+
self,
|
142 |
+
onset: float = 0.5,
|
143 |
+
offset: float | None = None,
|
144 |
+
min_duration_on: float = 0.0,
|
145 |
+
min_duration_off: float = 0.0,
|
146 |
+
pad_onset: float = 0.0,
|
147 |
+
pad_offset: float = 0.0,
|
148 |
+
max_duration: float = float("inf"),
|
149 |
+
):
|
150 |
+
super().__init__()
|
151 |
+
|
152 |
+
self.onset = onset
|
153 |
+
self.offset = offset or onset
|
154 |
+
|
155 |
+
self.pad_onset = pad_onset
|
156 |
+
self.pad_offset = pad_offset
|
157 |
+
|
158 |
+
self.min_duration_on = min_duration_on
|
159 |
+
self.min_duration_off = min_duration_off
|
160 |
+
|
161 |
+
self.max_duration = max_duration
|
162 |
+
|
163 |
+
def __call__(self, scores: SlidingWindowFeature) -> Annotation:
|
164 |
+
"""Binarize detection scores
|
165 |
+
Parameters
|
166 |
+
----------
|
167 |
+
scores : SlidingWindowFeature
|
168 |
+
Detection scores.
|
169 |
+
Returns
|
170 |
+
-------
|
171 |
+
active : Annotation
|
172 |
+
Binarized scores.
|
173 |
+
"""
|
174 |
+
|
175 |
+
num_frames, num_classes = scores.data.shape
|
176 |
+
frames = scores.sliding_window
|
177 |
+
timestamps = [frames[i].middle for i in range(num_frames)]
|
178 |
+
|
179 |
+
# annotation meant to store 'active' regions
|
180 |
+
active = Annotation()
|
181 |
+
for k, k_scores in enumerate(scores.data.T):
|
182 |
+
label = k if scores.labels is None else scores.labels[k]
|
183 |
+
|
184 |
+
# initial state
|
185 |
+
start = timestamps[0]
|
186 |
+
is_active = k_scores[0] > self.onset
|
187 |
+
curr_scores = [k_scores[0]]
|
188 |
+
curr_timestamps = [start]
|
189 |
+
t = start
|
190 |
+
for t, y in zip(timestamps[1:], k_scores[1:], strict=False):
|
191 |
+
# currently active
|
192 |
+
if is_active:
|
193 |
+
curr_duration = t - start
|
194 |
+
if curr_duration > self.max_duration:
|
195 |
+
search_after = len(curr_scores) // 2
|
196 |
+
# divide segment
|
197 |
+
min_score_div_idx = search_after + np.argmin(
|
198 |
+
curr_scores[search_after:]
|
199 |
+
)
|
200 |
+
min_score_t = curr_timestamps[min_score_div_idx]
|
201 |
+
region = Segment(
|
202 |
+
start - self.pad_onset, min_score_t + self.pad_offset
|
203 |
+
)
|
204 |
+
active[region, k] = label
|
205 |
+
start = curr_timestamps[min_score_div_idx]
|
206 |
+
curr_scores = curr_scores[min_score_div_idx + 1 :]
|
207 |
+
curr_timestamps = curr_timestamps[min_score_div_idx + 1 :]
|
208 |
+
# switching from active to inactive
|
209 |
+
elif y < self.offset:
|
210 |
+
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
211 |
+
active[region, k] = label
|
212 |
+
start = t
|
213 |
+
is_active = False
|
214 |
+
curr_scores = []
|
215 |
+
curr_timestamps = []
|
216 |
+
curr_scores.append(y)
|
217 |
+
curr_timestamps.append(t)
|
218 |
+
# currently inactive
|
219 |
+
else:
|
220 |
+
# switching from inactive to active
|
221 |
+
if y > self.onset:
|
222 |
+
start = t
|
223 |
+
is_active = True
|
224 |
+
|
225 |
+
# if active at the end, add final region
|
226 |
+
if is_active:
|
227 |
+
region = Segment(start - self.pad_onset, t + self.pad_offset)
|
228 |
+
active[region, k] = label
|
229 |
+
|
230 |
+
# because of padding, some active regions might be overlapping: merge them.
|
231 |
+
# also: fill same speaker gaps shorter than min_duration_off
|
232 |
+
if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
|
233 |
+
if self.max_duration < float("inf"):
|
234 |
+
raise NotImplementedError("This would break current max_duration param")
|
235 |
+
active = active.support(collar=self.min_duration_off)
|
236 |
+
|
237 |
+
# remove tracks shorter than min_duration_on
|
238 |
+
if self.min_duration_on > 0:
|
239 |
+
for segment, track in list(active.itertracks()):
|
240 |
+
if segment.duration < self.min_duration_on:
|
241 |
+
del active[segment, track]
|
242 |
+
|
243 |
+
return active
|
244 |
+
|
245 |
+
|
246 |
+
class VoiceActivitySegmentation(VoiceActivityDetection):
|
247 |
+
def __init__(
|
248 |
+
self,
|
249 |
+
segmentation: PipelineModel = "pyannote/segmentation",
|
250 |
+
fscore: bool = False,
|
251 |
+
use_auth_token: str | None = None,
|
252 |
+
**inference_kwargs,
|
253 |
+
):
|
254 |
+
super().__init__(
|
255 |
+
segmentation=segmentation,
|
256 |
+
fscore=fscore,
|
257 |
+
use_auth_token=use_auth_token,
|
258 |
+
**inference_kwargs,
|
259 |
+
)
|
260 |
+
|
261 |
+
def apply(self, file: AudioFile, hook: Callable | None = None) -> Annotation:
|
262 |
+
"""Apply voice activity detection
|
263 |
+
|
264 |
+
Parameters
|
265 |
+
----------
|
266 |
+
file : AudioFile
|
267 |
+
Processed file.
|
268 |
+
hook : callable, optional
|
269 |
+
Hook called after each major step of the pipeline with the following
|
270 |
+
signature: hook("step_name", step_artefact, file=file)
|
271 |
+
|
272 |
+
Returns
|
273 |
+
-------
|
274 |
+
speech : Annotation
|
275 |
+
Speech regions.
|
276 |
+
"""
|
277 |
+
|
278 |
+
# setup hook (e.g. for debugging purposes)
|
279 |
+
hook = self.setup_hook(file, hook=hook)
|
280 |
+
|
281 |
+
# apply segmentation model (only if needed)
|
282 |
+
# output shape is (num_chunks, num_frames, 1)
|
283 |
+
if self.training:
|
284 |
+
if self.CACHED_SEGMENTATION in file:
|
285 |
+
segmentations = file[self.CACHED_SEGMENTATION]
|
286 |
+
else:
|
287 |
+
segmentations = self._segmentation(file)
|
288 |
+
file[self.CACHED_SEGMENTATION] = segmentations
|
289 |
+
else:
|
290 |
+
segmentations: SlidingWindowFeature = self._segmentation(file)
|
291 |
+
|
292 |
+
return segmentations
|
293 |
+
|
294 |
+
|
295 |
+
def merge_vad(
|
296 |
+
vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0
|
297 |
+
):
|
298 |
+
active = Annotation()
|
299 |
+
for k, vad_t in enumerate(vad_arr):
|
300 |
+
region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
|
301 |
+
active[region, k] = 1
|
302 |
+
|
303 |
+
if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
|
304 |
+
active = active.support(collar=min_duration_off)
|
305 |
+
|
306 |
+
# remove tracks shorter than min_duration_on
|
307 |
+
if min_duration_on > 0:
|
308 |
+
for segment, track in list(active.itertracks()):
|
309 |
+
if segment.duration < min_duration_on:
|
310 |
+
del active[segment, track]
|
311 |
+
|
312 |
+
active = active.for_json()
|
313 |
+
active_segs = pd.DataFrame([x["segment"] for x in active["content"]])
|
314 |
+
return active_segs
|
315 |
+
|
316 |
+
|
317 |
+
def merge_chunks(
|
318 |
+
segments,
|
319 |
+
chunk_size,
|
320 |
+
onset: float = 0.5,
|
321 |
+
offset: float | None = None,
|
322 |
+
):
|
323 |
+
"""
|
324 |
+
Merge operation described in paper
|
325 |
+
"""
|
326 |
+
curr_end = 0
|
327 |
+
merged_segments = []
|
328 |
+
seg_idxs = []
|
329 |
+
|
330 |
+
assert chunk_size > 0
|
331 |
+
binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
|
332 |
+
segments = binarize(segments)
|
333 |
+
segments_list = []
|
334 |
+
for speech_turn in segments.get_timeline():
|
335 |
+
segments_list.append(Segment(speech_turn.start, speech_turn.end))
|
336 |
+
|
337 |
+
if len(segments_list) == 0:
|
338 |
+
print("No active speech found in audio")
|
339 |
+
return []
|
340 |
+
# assert segments_list, "segments_list is empty."
|
341 |
+
# Make sur the starting point is the start of the segment.
|
342 |
+
curr_start = segments_list[0].start
|
343 |
+
|
344 |
+
for seg in segments_list:
|
345 |
+
if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
|
346 |
+
merged_segments.append(
|
347 |
+
{
|
348 |
+
"start": curr_start,
|
349 |
+
"end": curr_end,
|
350 |
+
}
|
351 |
+
)
|
352 |
+
curr_start = seg.start
|
353 |
+
seg_idxs = []
|
354 |
+
curr_end = seg.end
|
355 |
+
seg_idxs.append((seg.start, seg.end))
|
356 |
+
|
357 |
+
merged_segments.append(
|
358 |
+
{
|
359 |
+
"start": curr_start,
|
360 |
+
"end": curr_end,
|
361 |
+
}
|
362 |
+
)
|
363 |
+
return merged_segments
|