zachzzc committed
Commit 07f1f64 · 1 Parent(s): 91a3092

Upload tts playground and serving engine

Files changed (37)
  1. .gitignore +10 -0
  2. README.md +1 -1
  3. app.py +528 -4
  4. higgs_audio/__init__.py +1 -0
  5. higgs_audio/audio_processing/LICENSE +51 -0
  6. higgs_audio/audio_processing/descriptaudiocodec/__init__.py +0 -0
  7. higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py +286 -0
  8. higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py +365 -0
  9. higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py +33 -0
  10. higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py +251 -0
  11. higgs_audio/audio_processing/higgs_audio_tokenizer.py +341 -0
  12. higgs_audio/audio_processing/quantization/__init__.py +8 -0
  13. higgs_audio/audio_processing/quantization/ac.py +301 -0
  14. higgs_audio/audio_processing/quantization/core_vq.py +360 -0
  15. higgs_audio/audio_processing/quantization/core_vq_lsx_version.py +431 -0
  16. higgs_audio/audio_processing/quantization/ddp_utils.py +197 -0
  17. higgs_audio/audio_processing/quantization/distrib.py +123 -0
  18. higgs_audio/audio_processing/quantization/vq.py +116 -0
  19. higgs_audio/audio_processing/semantic_module.py +310 -0
  20. higgs_audio/constants.py +3 -0
  21. higgs_audio/data_collator/__init__.py +0 -0
  22. higgs_audio/data_collator/higgs_audio_collator.py +583 -0
  23. higgs_audio/data_types.py +38 -0
  24. higgs_audio/dataset/__init__.py +0 -0
  25. higgs_audio/dataset/chatml_dataset.py +554 -0
  26. higgs_audio/model/__init__.py +9 -0
  27. higgs_audio/model/audio_head.py +139 -0
  28. higgs_audio/model/common.py +27 -0
  29. higgs_audio/model/configuration_higgs_audio.py +235 -0
  30. higgs_audio/model/cuda_graph_runner.py +129 -0
  31. higgs_audio/model/custom_modules.py +155 -0
  32. higgs_audio/model/modeling_higgs_audio.py +0 -0
  33. higgs_audio/model/utils.py +778 -0
  34. higgs_audio/serve/serve_engine.py +424 -0
  35. higgs_audio/serve/utils.py +254 -0
  36. pyproject.toml +100 -0
  37. requirements.txt +17 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ *.pyw
+ *.pyz
+ *.pywz
+ *.pyzw
+ *.pyzwz
+ .ruff_cache/
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: Higgs Audio Demo
- emoji: 🏢
+ emoji: 🎤
  colorFrom: green
  colorTo: purple
  sdk: gradio
app.py CHANGED
@@ -1,7 +1,531 @@
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
 
1
+ """
2
+ Gradio UI for Text-to-Speech using HiggsAudioServeEngine
3
+ """
4
+
5
+ import argparse
6
+ import base64
7
+ import os
8
+ import uuid
9
+ import json
10
+ from typing import Optional
11
  import gradio as gr
12
+ from loguru import logger
13
+ import numpy as np
14
+ import time
15
+ from functools import lru_cache
16
+ import re
17
+ import spaces
18
+
19
+
20
+ # Import HiggsAudio components
21
+ from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
22
+ from higgs_audio.data_types import ChatMLSample, AudioContent, Message
23
+
24
+ # Global engine instance
25
+ engine = None
26
+
27
+ # Set up default paths and resources
28
+ EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "examples")
29
+ os.makedirs(EXAMPLES_DIR, exist_ok=True)
30
+
31
+ # Default model configuration
32
+ DEFAULT_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-staging"
33
+ DEFAULT_AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer-staging"
34
+ SAMPLE_RATE = 24000
35
+
36
+ DEFAULT_SYSTEM_PROMPT = (
37
+ "Generate audio following instruction.\n\n"
38
+ "<|scene_desc_start|>\n"
39
+ "Audio is recorded from a quiet room.\n"
40
+ "<|scene_desc_end|>"
41
+ )
42
+
43
+ DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]
44
+
45
+ # Predefined examples for system and input messages
46
+ PREDEFINED_EXAMPLES = {
47
+ "None": {"system_prompt": "", "input_text": "", "description": "Default example"},
48
+ "multispeaker-interleave": {
49
+ "system_prompt": "Generate audio following instruction.\n\n"
50
+ "<|scene_desc_start|>\n"
51
+ "SPEAKER0: vocal fry;feminism;slightly fast\n"
52
+ "SPEAKER1: masculine;moderate;moderate pitch;monotone;mature\n"
53
+ "In this scene, a group of adventurers is debating whether to investigate a potentially dangerous situation.\n"
54
+ "<|scene_desc_end|>",
55
+ "input_text": "<|generation_instruction_start|>\nGenerate interleaved transcript and audio that lasts for around 10 seconds.\n<|generation_instruction_end|>",
56
+ "description": "Multispeaker interleave example",
57
+ },
58
+ "single-speaker": {
59
+ "system_prompt": "Generate audio following instruction.\n\n"
60
+ "<|scene_desc_start|>\n"
61
+ "SPEAKER0: british accent\n"
62
+ "<|scene_desc_end|>",
63
+ "input_text": "Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
64
+ "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
65
+ "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.\n"
66
+ "\n"
67
+ "So here's the big question: Do you want to understand how deep learning works?\n"
68
+ "How to use it to build powerful models that can predict, automate, and transform industries?\n"
69
+ "Well, today, I've got some exciting news for you.\n"
70
+ "\n"
71
+ "We're going to talk about a course that I highly recommend: Dive into Deep Learning.\n"
72
+ "It's not just another course; it's an entire experience that will take you from a beginner to someone who is well-versed in deep learning techniques.",
73
+ "description": "Single speaker example",
74
+ },
75
+ "single-speaker-zh": {
76
+ "system_prompt": "Generate audio following instruction.\n\n"
77
+ "<|scene_desc_start|>\n"
78
+ "\nAudio is recorded from a quiet room.\n"
79
+ "\nSPEAKER0: feminine\n"
80
+ "<|scene_desc_end|>",
81
+ "input_text": "大家好, 欢迎收听本期的跟李沐学AI. 今天沐哥在忙着洗数据, 所以由我, 希格斯主播代替他讲这期视频.\n"
82
+ "今天我们要聊的是一个你绝对不能忽视的话题: 多模态学习.\n"
83
+ "无论你是开发者, 数据科学爱好者, 还是只是对人工智能感兴趣的人都一定听说过这个词. 它已经成为AI时代的一个研究热点.\n"
84
+ "那么, 问题来了, 你真的了解多模态吗? 你知道如何自己动手构建多模态大模型吗.\n"
85
+ "或者说, 你能察觉到我其实是个机器人吗?",
86
+ "description": "Single speaker with Chinese text",
87
+ },
88
+ }
89
+
90
+
91
+ @lru_cache(maxsize=20)
92
+ def encode_audio_file(file_path):
93
+ """Encode an audio file to base64."""
94
+ with open(file_path, "rb") as audio_file:
95
+ return base64.b64encode(audio_file.read()).decode("utf-8")
96
+
97
+
98
+ def load_voice_presets():
99
+ """Load the voice presets from the voice_examples directory."""
100
+ try:
101
+ with open(
102
+ os.path.join(os.path.dirname(__file__), "voice_examples", "config.json"),
103
+ "r",
104
+ ) as f:
105
+ voice_dict = json.load(f)
106
+ voice_presets = {k: v["transcript"] for k, v in voice_dict.items()}
107
+ voice_presets["EMPTY"] = "No reference voice"
108
+ logger.info(f"Loaded voice presets: {list(voice_presets.keys())}")
109
+ return voice_presets
110
+ except FileNotFoundError:
111
+ logger.warning("Voice examples config file not found. Using empty voice presets.")
112
+ return {"EMPTY": "No reference voice"}
113
+ except Exception as e:
114
+ logger.error(f"Error loading voice presets: {e}")
115
+ return {"EMPTY": "No reference voice"}
116
+
117
+
118
+ def get_voice_preset(voice_preset):
119
+ """Get the voice path and text for a given voice preset."""
120
+ voice_path = os.path.join(os.path.dirname(__file__), "voice_examples", f"{voice_preset}.wav")
121
+ if not os.path.exists(voice_path):
122
+ logger.warning(f"Voice preset file not found: {voice_path}")
123
+ return None, "Voice preset not found"
124
+
125
+ text = VOICE_PRESETS.get(voice_preset, "No transcript available")
126
+ return voice_path, text
127
+
128
+
129
+ @spaces.GPU
130
+ def initialize_engine(model_path, audio_tokenizer_path, device="cuda") -> bool:
131
+ """Initialize the HiggsAudioServeEngine."""
132
+ global engine
133
+ try:
134
+ engine = HiggsAudioServeEngine(
135
+ model_name_or_path=model_path,
136
+ audio_tokenizer_name_or_path=audio_tokenizer_path,
137
+ device=device,
138
+ )
139
+ logger.info(f"Successfully initialized HiggsAudioServeEngine with model: {model_path}")
140
+ return True
141
+ except Exception as e:
142
+ logger.error(f"Failed to initialize engine: {e}")
143
+ return False
144
+
145
+
146
+ def check_return_audio(audio_wv: np.ndarray):
147
+ # check if the audio returned is all silent
148
+ if np.all(audio_wv == 0):
149
+ logger.warning("Audio is silent, returning None")
150
+
151
+
152
+ def process_text_output(text_output: str):
153
+ # Replace runs of consecutive <|AUDIO_OUT|> tokens with a single <|AUDIO_OUT|>
154
+ text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
155
+ return text_output
156
+
157
+
158
+ def prepare_chatml_sample(
159
+ voice_preset: str,
160
+ text: str,
161
+ reference_audio: Optional[str] = None,
162
+ reference_text: Optional[str] = None,
163
+ system_prompt: str = DEFAULT_SYSTEM_PROMPT,
164
+ ):
165
+ """Prepare a ChatMLSample for the HiggsAudioServeEngine."""
166
+ messages = []
167
+
168
+ # Add system message if provided
169
+ if len(system_prompt) > 0:
170
+ messages.append(Message(role="system", content=system_prompt))
171
+
172
+ # Add reference audio if provided
173
+ audio_base64 = None
174
+ ref_text = ""
175
+
176
+ if reference_audio:
177
+ # Custom reference audio
178
+ audio_base64 = encode_audio_file(reference_audio)
179
+ ref_text = reference_text or ""
180
+ elif voice_preset != "EMPTY":
181
+ # Voice preset
182
+ voice_path, ref_text = get_voice_preset(voice_preset)
183
+ if voice_path is None:
184
+ logger.warning(f"Voice preset {voice_present} not found, skipping reference audio")
185
+ else:
186
+ audio_base64 = encode_audio_file(voice_path)
187
+
188
+ # Only add reference audio if we have it
189
+ if audio_base64 is not None:
190
+ # Add user message with reference text
191
+ messages.append(Message(role="user", content=ref_text))
192
+
193
+ # Add assistant message with audio content
194
+ audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
195
+ messages.append(Message(role="assistant", content=[audio_content]))
196
+
197
+ # Add the main user message
198
+ messages.append(Message(role="user", content=text))
199
+
200
+ return ChatMLSample(messages=messages)
201
+
202
+
203
+ @spaces.GPU(duration=500)
204
+ def text_to_speech(
205
+ text,
206
+ voice_preset,
207
+ reference_audio=None,
208
+ reference_text=None,
209
+ max_completion_tokens=1024,
210
+ temperature=1.0,
211
+ top_p=0.95,
212
+ top_k=50,
213
+ system_prompt=DEFAULT_SYSTEM_PROMPT,
214
+ stop_strings=None,
215
+ ):
216
+ """Convert text to speech using HiggsAudioServeEngine."""
217
+ global engine
218
+
219
+ if engine is None:
220
+ error_msg = "Engine not initialized. Please load a model first."
221
+ logger.error(error_msg)
222
+ gr.Error(error_msg)
223
+ return f"❌ {error_msg}", None
224
+
225
+ try:
226
+ # Prepare ChatML sample
227
+ chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt)
228
+
229
+ # Convert stop strings format
230
+ if stop_strings is None:
231
+ stop_list = DEFAULT_STOP_STRINGS
232
+ else:
233
+ stop_list = [s for s in stop_strings["stops"] if s.strip()]
234
+
235
+ request_id = f"tts-playground-{str(uuid.uuid4())}"
236
+ logger.info(
237
+ f"{request_id}: Generating speech for text: {text[:100]}..., \n"
238
+ f"with parameters: temperature={temperature}, top_p={top_p}, top_k={top_k}, stop_list={stop_list}"
239
+ )
240
+ start_time = time.time()
241
+
242
+ # Generate using the engine
243
+ response = engine.generate(
244
+ chat_ml_sample=chatml_sample,
245
+ max_new_tokens=max_completion_tokens,
246
+ temperature=temperature,
247
+ top_k=top_k if top_k > 0 else None,
248
+ top_p=top_p,
249
+ stop_strings=stop_list,
250
+ )
251
+
252
+ generation_time = time.time() - start_time
253
+ logger.info(f"{request_id}: Generated audio in {generation_time:.3f} seconds")
254
+ gr.Info(f"Generated audio in {generation_time:.3f} seconds")
255
+
256
+ # Process the response
257
+ text_output = process_text_output(response.generated_text)
258
+
259
+ if response.audio is not None:
260
+ # Convert to int16 for Gradio
261
+ audio_data = (response.audio * 32767).astype(np.int16)
262
+ check_return_audio(audio_data)
263
+ return text_output, (response.sampling_rate, audio_data)
264
+ else:
265
+ logger.warning("No audio generated")
266
+ return text_output, None
267
+
268
+ except Exception as e:
269
+ error_msg = f"Error generating speech: {e}"
270
+ logger.error(error_msg)
271
+ gr.Error(error_msg)
272
+ return f"❌ {error_msg}", None
273
+
274
+
275
+ def create_ui():
276
+ my_theme = "JohnSmith9982/small_and_pretty"
277
+
278
+ # Add custom CSS to disable focus highlighting on textboxes
279
+ custom_css = """
280
+ .gradio-container input:focus,
281
+ .gradio-container textarea:focus,
282
+ .gradio-container select:focus,
283
+ .gradio-container .gr-input:focus,
284
+ .gradio-container .gr-textarea:focus,
285
+ .gradio-container .gr-textbox:focus,
286
+ .gradio-container .gr-textbox:focus-within,
287
+ .gradio-container .gr-form:focus-within,
288
+ .gradio-container *:focus {
289
+ box-shadow: none !important;
290
+ border-color: var(--border-color-primary) !important;
291
+ outline: none !important;
292
+ background-color: var(--input-background-fill) !important;
293
+ }
294
+
295
+ /* Override any hover effects as well */
296
+ .gradio-container input:hover,
297
+ .gradio-container textarea:hover,
298
+ .gradio-container select:hover,
299
+ .gradio-container .gr-input:hover,
300
+ .gradio-container .gr-textarea:hover,
301
+ .gradio-container .gr-textbox:hover {
302
+ border-color: var(--border-color-primary) !important;
303
+ background-color: var(--input-background-fill) !important;
304
+ }
305
+
306
+ /* Style for checked checkbox */
307
+ .gradio-container input[type="checkbox"]:checked {
308
+ background-color: var(--primary-500) !important;
309
+ border-color: var(--primary-500) !important;
310
+ }
311
+ """
312
+
313
+ """Create the Gradio UI."""
314
+ with gr.Blocks(theme=my_theme, css=custom_css) as demo:
315
+ gr.Markdown("# Higgs Audio Text-to-Speech Playground")
316
+
317
+ # Main UI section
318
+ with gr.Row():
319
+ with gr.Column(scale=2):
320
+ # Template selection dropdown
321
+ template_dropdown = gr.Dropdown(
322
+ label="Message examples",
323
+ choices=list(PREDEFINED_EXAMPLES.keys()),
324
+ value="None",
325
+ info="Select a predefined example for system and input messages. Voice preset will be set to EMPTY when a example is selected.",
326
+ )
327
+
328
+ system_prompt = gr.TextArea(
329
+ label="System Prompt",
330
+ placeholder="Enter system prompt to guide the model...",
331
+ value=DEFAULT_SYSTEM_PROMPT,
332
+ lines=2,
333
+ )
334
+
335
+ input_text = gr.TextArea(
336
+ label="Input Text",
337
+ placeholder="Type the text you want to convert to speech...",
338
+ lines=5,
339
+ )
340
+
341
+ voice_preset = gr.Dropdown(
342
+ label="Voice Preset",
343
+ choices=list(VOICE_PRESETS.keys()),
344
+ value="EMPTY",
345
+ )
346
+
347
+ with gr.Accordion("Custom Reference (Optional)", open=False):
348
+ reference_audio = gr.Audio(label="Reference Audio", type="filepath")
349
+ reference_text = gr.TextArea(
350
+ label="Reference Text (transcript of the reference audio)",
351
+ placeholder="Enter the transcript of your reference audio...",
352
+ lines=3,
353
+ )
354
+
355
+ with gr.Accordion("Advanced Parameters", open=False):
356
+ max_completion_tokens = gr.Slider(
357
+ minimum=128,
358
+ maximum=4096,
359
+ value=1024,
360
+ step=10,
361
+ label="Max Completion Tokens",
362
+ )
363
+ temperature = gr.Slider(
364
+ minimum=0.0,
365
+ maximum=1.5,
366
+ value=1.0,
367
+ step=0.1,
368
+ label="Temperature",
369
+ )
370
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P")
371
+ top_k = gr.Slider(minimum=-1, maximum=100, value=50, step=1, label="Top K")
372
+ # Add stop strings component
373
+ stop_strings = gr.Dataframe(
374
+ label="Stop Strings",
375
+ headers=["stops"],
376
+ datatype=["str"],
377
+ value=[[s] for s in DEFAULT_STOP_STRINGS],
378
+ interactive=True,
379
+ col_count=(1, "fixed"),
380
+ )
381
+
382
+ submit_btn = gr.Button("Generate Speech", variant="primary", scale=1)
383
+
384
+ with gr.Column(scale=2):
385
+ output_text = gr.TextArea(label="Model Response", lines=2)
386
+
387
+ # Audio output
388
+ output_audio = gr.Audio(label="Generated Audio", interactive=False, autoplay=True)
389
+
390
+ stop_btn = gr.Button("Stop Playback", variant="primary")
391
+
392
+ # Example voice
393
+ with gr.Row():
394
+ voice_samples_table = gr.Dataframe(
395
+ headers=["Voice Preset", "Sample Text"],
396
+ datatype=["str", "str"],
397
+ value=[[preset, text] for preset, text in VOICE_PRESETS.items() if preset != "EMPTY"],
398
+ interactive=False,
399
+ )
400
+ sample_audio = gr.Audio(label="Voice Sample", visible=True)
401
+
402
+ # Function to play voice sample when clicking on a row
403
+ def play_voice_sample(evt: gr.SelectData):
404
+ try:
405
+ # Get the preset name from the clicked row
406
+ preset_names = [preset for preset in VOICE_PRESETS.keys() if preset != "EMPTY"]
407
+ if evt.index[0] < len(preset_names):
408
+ preset = preset_names[evt.index[0]]
409
+ voice_path, _ = get_voice_preset(preset)
410
+ if voice_path and os.path.exists(voice_path):
411
+ return voice_path
412
+ else:
413
+ gr.Warning(f"Voice sample file not found for preset: {preset}")
414
+ return None
415
+ else:
416
+ gr.Warning("Invalid voice preset selection")
417
+ return None
418
+ except Exception as e:
419
+ logger.error(f"Error playing voice sample: {e}")
420
+ gr.Error(f"Error playing voice sample: {e}")
421
+ return None
422
+
423
+ voice_samples_table.select(fn=play_voice_sample, outputs=[sample_audio])
424
+
425
+ # Function to handle template selection
426
+ def apply_template(template_name):
427
+ if template_name in PREDEFINED_EXAMPLES:
428
+ template = PREDEFINED_EXAMPLES[template_name]
429
+ return (
430
+ template["system_prompt"], # system_prompt
431
+ template["input_text"], # input_text
432
+ "EMPTY", # voice_preset (always set to EMPTY for examples)
433
+ )
434
+ else:
435
+ return (
436
+ gr.update(),
437
+ gr.update(),
438
+ gr.update(),
439
+ ) # No change if template not found
440
+
441
+ # Set up event handlers
442
+
443
+ # Connect template dropdown to handler
444
+ template_dropdown.change(
445
+ fn=apply_template,
446
+ inputs=[template_dropdown],
447
+ outputs=[system_prompt, input_text, voice_preset],
448
+ )
449
+
450
+ # Connect submit button to the TTS function
451
+ submit_btn.click(
452
+ fn=text_to_speech,
453
+ inputs=[
454
+ input_text,
455
+ voice_preset,
456
+ reference_audio,
457
+ reference_text,
458
+ max_completion_tokens,
459
+ temperature,
460
+ top_p,
461
+ top_k,
462
+ system_prompt,
463
+ stop_strings,
464
+ ],
465
+ outputs=[output_text, output_audio],
466
+ api_name="generate_speech",
467
+ )
468
+
469
+ # Stop button functionality
470
+ stop_btn.click(
471
+ fn=lambda: None,
472
+ inputs=[],
473
+ outputs=[output_audio],
474
+ js="() => {const audio = document.querySelector('audio'); if(audio) audio.pause(); return null;}",
475
+ )
476
+
477
+ return demo
478
+
479
+
480
+ def main():
481
+ """Main function to parse arguments and launch the UI."""
482
+ global DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH, VOICE_PRESETS
483
+
484
+ parser = argparse.ArgumentParser(description="Gradio UI for Text-to-Speech using HiggsAudioServeEngine")
485
+ parser.add_argument(
486
+ "--model-path",
487
+ type=str,
488
+ default=DEFAULT_MODEL_PATH,
489
+ help="Path to the Higgs Audio model.",
490
+ )
491
+ parser.add_argument(
492
+ "--audio-tokenizer-path",
493
+ type=str,
494
+ default=DEFAULT_AUDIO_TOKENIZER_PATH,
495
+ help="Path to the audio tokenizer.",
496
+ )
497
+ parser.add_argument(
498
+ "--device",
499
+ type=str,
500
+ default="cuda",
501
+ choices=["cuda", "cpu"],
502
+ help="Device to run the model on.",
503
+ )
504
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host for the Gradio interface.")
505
+ parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio interface.")
506
+
507
+ args = parser.parse_args()
508
+
509
+ # Update default values if provided via command line
510
+ DEFAULT_MODEL_PATH = args.model_path
511
+ DEFAULT_AUDIO_TOKENIZER_PATH = args.audio_tokenizer_path
512
+ VOICE_PRESETS = load_voice_presets()
513
+
514
+ # Load model on startup
515
+ logger.info("Loading model...")
516
+ result = initialize_engine(args.model_path, args.audio_tokenizer_path, args.device)
517
+
518
+ # Exit if model loading failed
519
+ if not result:
520
+ logger.error("Failed to load model. Exiting.")
521
+ return
522
+
523
+ logger.info(f"Model loaded: {DEFAULT_MODEL_PATH}")
524
+
525
+ # Create and launch the UI
526
+ demo = create_ui()
527
+ demo.launch(server_name=args.host, server_port=args.port)
528
 
 
 
529
 
530
+ if __name__ == "__main__":
531
+ main()
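For reference, here is a minimal sketch of driving the serving engine outside Gradio, assembled only from the calls app.py makes above; the soundfile writer and the output filename are assumptions, not part of this commit.

import numpy as np
import soundfile as sf  # assumption: any WAV writer works here; soundfile is not part of this commit

from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
from higgs_audio.data_types import ChatMLSample, Message

# Same defaults app.py uses
engine = HiggsAudioServeEngine(
    model_name_or_path="bosonai/higgs-audio-v2-generation-3B-staging",
    audio_tokenizer_name_or_path="bosonai/higgs-audio-v2-tokenizer-staging",
    device="cuda",
)

sample = ChatMLSample(
    messages=[
        Message(role="system", content="Generate audio following instruction."),
        Message(role="user", content="Hey, everyone! Welcome back to Tech Talk Tuesdays."),
    ]
)

response = engine.generate(
    chat_ml_sample=sample,
    max_new_tokens=1024,
    temperature=1.0,
    top_k=50,
    top_p=0.95,
    stop_strings=["<|end_of_text|>", "<|eot_id|>"],
)

print(response.generated_text)
if response.audio is not None:
    # app.py scales the float waveform to int16 before handing it to Gradio
    sf.write("output.wav", (response.audio * 32767).astype(np.int16), response.sampling_rate)

The playground UI itself is launched by running app.py directly; the --model-path, --audio-tokenizer-path, --device, --host, and --port flags above override the defaults.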
higgs_audio/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import HiggsAudioConfig, HiggsAudioModel
higgs_audio/audio_processing/LICENSE ADDED
@@ -0,0 +1,51 @@
+ Third-Party License Attribution for Audio Processing Module
+ ===========================================================
+
+ This directory contains code derived from multiple open-source projects.
+ The following sections detail the licenses and attributions for third-party code.
+
+ ## XCodec Repository
+ The code in this directory is derived from:
+ https://github.com/zhenye234/xcodec
+
+ ## Individual File Attributions
+
+ ### Quantization Module (quantization/)
+ - Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository
+ - Individual files contain their own license headers where applicable
+ - The vector-quantize-pytorch portions are licensed under the MIT License
+
+ ## License Terms
+
+ ### MIT License (for applicable portions)
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ## Attribution Requirements
+ When using this code, please ensure proper attribution to:
+ 1. The original xcodec repository: https://github.com/zhenye234/xcodec
+ 2. Any other repositories mentioned in individual file headers
+ 3. This derivative work and its modifications
+
+ ## Disclaimer
+ This directory contains modified versions of the original code. Please refer to
+ the original repositories for the canonical implementations and their specific
+ license terms.
+
+ For any questions about licensing or attribution, please check the individual
+ file headers and the original source repositories.
higgs_audio/audio_processing/descriptaudiocodec/__init__.py ADDED
File without changes
higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py ADDED
@@ -0,0 +1,286 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.")
52
+ return cls(codes=codes, **artifacts["metadata"])
53
+
54
+
55
+ class CodecMixin:
56
+ @property
57
+ def padding(self):
58
+ if not hasattr(self, "_padding"):
59
+ self._padding = True
60
+ return self._padding
61
+
62
+ @padding.setter
63
+ def padding(self, value):
64
+ assert isinstance(value, bool)
65
+
66
+ layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))]
67
+
68
+ for layer in layers:
69
+ if value:
70
+ if hasattr(layer, "original_padding"):
71
+ layer.padding = layer.original_padding
72
+ else:
73
+ layer.original_padding = layer.padding
74
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
75
+
76
+ self._padding = value
77
+
78
+ def get_delay(self):
79
+ # Any number works here, delay is invariant to input length
80
+ l_out = self.get_output_length(0)
81
+ L = l_out
82
+
83
+ layers = []
84
+ for layer in self.modules():
85
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
86
+ layers.append(layer)
87
+
88
+ for layer in reversed(layers):
89
+ d = layer.dilation[0]
90
+ k = layer.kernel_size[0]
91
+ s = layer.stride[0]
92
+
93
+ if isinstance(layer, nn.ConvTranspose1d):
94
+ L = ((L - d * (k - 1) - 1) / s) + 1
95
+ elif isinstance(layer, nn.Conv1d):
96
+ L = (L - 1) * s + d * (k - 1) + 1
97
+
98
+ L = math.ceil(L)
99
+
100
+ l_in = L
101
+
102
+ return (l_in - l_out) // 2
103
+
104
+ def get_output_length(self, input_length):
105
+ L = input_length
106
+ # Calculate output length
107
+ for layer in self.modules():
108
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
109
+ d = layer.dilation[0]
110
+ k = layer.kernel_size[0]
111
+ s = layer.stride[0]
112
+
113
+ if isinstance(layer, nn.Conv1d):
114
+ L = ((L - d * (k - 1) - 1) / s) + 1
115
+ elif isinstance(layer, nn.ConvTranspose1d):
116
+ L = (L - 1) * s + d * (k - 1) + 1
117
+
118
+ L = math.floor(L)
119
+ return L
120
+
121
+ @torch.no_grad()
122
+ def compress(
123
+ self,
124
+ audio_path_or_signal: Union[str, Path, AudioSignal],
125
+ win_duration: float = 1.0,
126
+ verbose: bool = False,
127
+ normalize_db: float = -16,
128
+ n_quantizers: int = None,
129
+ ) -> DACFile:
130
+ """Processes an audio signal from a file or AudioSignal object into
131
+ discrete codes. This function processes the signal in short windows,
132
+ using constant GPU memory.
133
+
134
+ Parameters
135
+ ----------
136
+ audio_path_or_signal : Union[str, Path, AudioSignal]
137
+ audio signal to compress
138
+ win_duration : float, optional
139
+ window duration in seconds, by default 1.0
140
+ verbose : bool, optional
141
+ by default False
142
+ normalize_db : float, optional
143
+ normalize db, by default -16
144
+
145
+ Returns
146
+ -------
147
+ DACFile
148
+ Object containing compressed codes and metadata
149
+ required for decompression
150
+ """
151
+ audio_signal = audio_path_or_signal
152
+ if isinstance(audio_signal, (str, Path)):
153
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
154
+
155
+ self.eval()
156
+ original_padding = self.padding
157
+ original_device = audio_signal.device
158
+
159
+ audio_signal = audio_signal.clone()
160
+ original_sr = audio_signal.sample_rate
161
+
162
+ resample_fn = audio_signal.resample
163
+ loudness_fn = audio_signal.loudness
164
+
165
+ # If audio is > 10 minutes long, use the ffmpeg versions
166
+ if audio_signal.signal_duration >= 10 * 60 * 60:
167
+ resample_fn = audio_signal.ffmpeg_resample
168
+ loudness_fn = audio_signal.ffmpeg_loudness
169
+
170
+ original_length = audio_signal.signal_length
171
+ resample_fn(self.sample_rate)
172
+ input_db = loudness_fn()
173
+
174
+ if normalize_db is not None:
175
+ audio_signal.normalize(normalize_db)
176
+ audio_signal.ensure_max_of_audio()
177
+
178
+ nb, nac, nt = audio_signal.audio_data.shape
179
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
180
+ win_duration = audio_signal.signal_duration if win_duration is None else win_duration
181
+
182
+ if audio_signal.signal_duration <= win_duration:
183
+ # Unchunked compression (used if signal length < win duration)
184
+ self.padding = True
185
+ n_samples = nt
186
+ hop = nt
187
+ else:
188
+ # Chunked inference
189
+ self.padding = False
190
+ # Zero-pad signal on either side by the delay
191
+ audio_signal.zero_pad(self.delay, self.delay)
192
+ n_samples = int(win_duration * self.sample_rate)
193
+ # Round n_samples to nearest hop length multiple
194
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
195
+ hop = self.get_output_length(n_samples)
196
+
197
+ codes = []
198
+ range_fn = range if not verbose else tqdm.trange
199
+
200
+ for i in range_fn(0, nt, hop):
201
+ x = audio_signal[..., i : i + n_samples]
202
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
203
+
204
+ audio_data = x.audio_data.to(self.device)
205
+ audio_data = self.preprocess(audio_data, self.sample_rate)
206
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
207
+ codes.append(c.to(original_device))
208
+ chunk_length = c.shape[-1]
209
+
210
+ codes = torch.cat(codes, dim=-1)
211
+
212
+ dac_file = DACFile(
213
+ codes=codes,
214
+ chunk_length=chunk_length,
215
+ original_length=original_length,
216
+ input_db=input_db,
217
+ channels=nac,
218
+ sample_rate=original_sr,
219
+ padding=self.padding,
220
+ dac_version=SUPPORTED_VERSIONS[-1],
221
+ )
222
+
223
+ if n_quantizers is not None:
224
+ codes = codes[:, :n_quantizers, :]
225
+
226
+ self.padding = original_padding
227
+ return dac_file
228
+
229
+ @torch.no_grad()
230
+ def decompress(
231
+ self,
232
+ obj: Union[str, Path, DACFile],
233
+ verbose: bool = False,
234
+ ) -> AudioSignal:
235
+ """Reconstruct audio from a given .dac file
236
+
237
+ Parameters
238
+ ----------
239
+ obj : Union[str, Path, DACFile]
240
+ .dac file location or corresponding DACFile object.
241
+ verbose : bool, optional
242
+ Prints progress if True, by default False
243
+
244
+ Returns
245
+ -------
246
+ AudioSignal
247
+ Object with the reconstructed audio
248
+ """
249
+ self.eval()
250
+ if isinstance(obj, (str, Path)):
251
+ obj = DACFile.load(obj)
252
+
253
+ original_padding = self.padding
254
+ self.padding = obj.padding
255
+
256
+ range_fn = range if not verbose else tqdm.trange
257
+ codes = obj.codes
258
+ original_device = codes.device
259
+ chunk_length = obj.chunk_length
260
+ recons = []
261
+
262
+ for i in range_fn(0, codes.shape[-1], chunk_length):
263
+ c = codes[..., i : i + chunk_length].to(self.device)
264
+ z = self.quantizer.from_codes(c)[0]
265
+ r = self.decode(z)
266
+ recons.append(r.to(original_device))
267
+
268
+ recons = torch.cat(recons, dim=-1)
269
+ recons = AudioSignal(recons, self.sample_rate)
270
+
271
+ resample_fn = recons.resample
272
+ loudness_fn = recons.loudness
273
+
274
+ # If audio is > 10 minutes long, use the ffmpeg versions
275
+ if recons.signal_duration >= 10 * 60 * 60:
276
+ resample_fn = recons.ffmpeg_resample
277
+ loudness_fn = recons.ffmpeg_loudness
278
+
279
+ recons.normalize(obj.input_db)
280
+ resample_fn(obj.sample_rate)
281
+ recons = recons[..., : obj.original_length]
282
+ loudness_fn()
283
+ recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
284
+
285
+ self.padding = original_padding
286
+ return recons
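A minimal sketch of the persistence round trip that DACFile and CodecMixin above support, assuming the DAC model class defined in dac.py below; the import paths are assumptions based on this commit's directory layout (the vendored files import the package as a top-level dac, so the paths may need adjusting in practice).

import torch
from audiotools import AudioSignal

# Import paths are assumptions based on this commit's layout
from higgs_audio.audio_processing.descriptaudiocodec.dac.model.dac import DAC
from higgs_audio.audio_processing.descriptaudiocodec.dac.model.base import DACFile

model = DAC()  # defaults: 44.1 kHz sample rate, 9 codebooks
signal = AudioSignal(torch.randn(1, 1, 44100), 44100)

dac_file = model.compress(signal)   # DACFile with codes plus metadata
path = dac_file.save("example")     # codes stored as uint16 via np.save, as example.dac
restored = DACFile.load(path)       # raises RuntimeError on a dac_version mismatch
audio = model.decompress(restored)  # AudioSignal trimmed back to original_length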
higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py ADDED
@@ -0,0 +1,365 @@
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from dac.nn.layers import Snake1d
13
+ from dac.nn.layers import WNConv1d
14
+ from dac.nn.layers import WNConvTranspose1d
15
+ from dac.nn.quantize import ResidualVectorQuantize
16
+
17
+
18
+ def init_weights(m):
19
+ if isinstance(m, nn.Conv1d):
20
+ nn.init.trunc_normal_(m.weight, std=0.02)
21
+ nn.init.constant_(m.bias, 0)
22
+
23
+
24
+ class ResidualUnit(nn.Module):
25
+ def __init__(self, dim: int = 16, dilation: int = 1):
26
+ super().__init__()
27
+ pad = ((7 - 1) * dilation) // 2
28
+ self.block = nn.Sequential(
29
+ Snake1d(dim),
30
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
31
+ Snake1d(dim),
32
+ WNConv1d(dim, dim, kernel_size=1),
33
+ )
34
+
35
+ def forward(self, x):
36
+ y = self.block(x)
37
+ pad = (x.shape[-1] - y.shape[-1]) // 2
38
+ if pad > 0:
39
+ x = x[..., pad:-pad]
40
+ return x + y
41
+
42
+
43
+ class EncoderBlock(nn.Module):
44
+ def __init__(self, dim: int = 16, stride: int = 1):
45
+ super().__init__()
46
+ self.block = nn.Sequential(
47
+ ResidualUnit(dim // 2, dilation=1),
48
+ ResidualUnit(dim // 2, dilation=3),
49
+ ResidualUnit(dim // 2, dilation=9),
50
+ Snake1d(dim // 2),
51
+ WNConv1d(
52
+ dim // 2,
53
+ dim,
54
+ kernel_size=2 * stride,
55
+ stride=stride,
56
+ padding=math.ceil(stride / 2),
57
+ ),
58
+ )
59
+
60
+ def forward(self, x):
61
+ return self.block(x)
62
+
63
+
64
+ class Encoder(nn.Module):
65
+ def __init__(
66
+ self,
67
+ d_model: int = 64,
68
+ strides: list = [2, 4, 8, 8],
69
+ d_latent: int = 256,
70
+ ):
71
+ super().__init__()
72
+ # Create first convolution
73
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
74
+
75
+ # Create EncoderBlocks that double channels as they downsample by `stride`
76
+ for stride in strides:
77
+ d_model *= 2
78
+ self.block += [EncoderBlock(d_model, stride=stride)]
79
+
80
+ # Create last convolution
81
+ self.block += [
82
+ Snake1d(d_model),
83
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
84
+ ]
85
+
86
+ # Wrap block into nn.Sequential
87
+ self.block = nn.Sequential(*self.block)
88
+ self.enc_dim = d_model
89
+
90
+ def forward(self, x):
91
+ return self.block(x)
92
+
93
+
94
+ class DecoderBlock(nn.Module):
95
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0):
96
+ super().__init__()
97
+ self.block = nn.Sequential(
98
+ Snake1d(input_dim),
99
+ WNConvTranspose1d(
100
+ input_dim,
101
+ output_dim,
102
+ kernel_size=2 * stride,
103
+ stride=stride,
104
+ padding=math.ceil(stride / 2),
105
+ output_padding=stride % 2, # out_pad,
106
+ ),
107
+ ResidualUnit(output_dim, dilation=1),
108
+ ResidualUnit(output_dim, dilation=3),
109
+ ResidualUnit(output_dim, dilation=9),
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.block(x)
114
+
115
+
116
+ class Decoder(nn.Module):
117
+ def __init__(
118
+ self,
119
+ input_channel,
120
+ channels,
121
+ rates,
122
+ d_out: int = 1,
123
+ ):
124
+ super().__init__()
125
+
126
+ # Add first conv layer
127
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
128
+
129
+ # Add upsampling + MRF blocks
130
+ for i, stride in enumerate(rates):
131
+ input_dim = channels // 2**i
132
+ output_dim = channels // 2 ** (i + 1)
133
+ if i == 1:
134
+ out_pad = 1
135
+ else:
136
+ out_pad = 0
137
+ layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)]
138
+
139
+ # Add final conv layer
140
+ layers += [
141
+ Snake1d(output_dim),
142
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
143
+ # nn.Tanh(),
144
+ ]
145
+
146
+ self.model = nn.Sequential(*layers)
147
+
148
+ def forward(self, x):
149
+ return self.model(x)
150
+
151
+
152
+ class DAC(BaseModel, CodecMixin):
153
+ def __init__(
154
+ self,
155
+ encoder_dim: int = 64,
156
+ encoder_rates: List[int] = [2, 4, 8, 8],
157
+ latent_dim: int = None,
158
+ decoder_dim: int = 1536,
159
+ decoder_rates: List[int] = [8, 8, 4, 2],
160
+ n_codebooks: int = 9,
161
+ codebook_size: int = 1024,
162
+ codebook_dim: Union[int, list] = 8,
163
+ quantizer_dropout: bool = False,
164
+ sample_rate: int = 44100,
165
+ ):
166
+ super().__init__()
167
+
168
+ self.encoder_dim = encoder_dim
169
+ self.encoder_rates = encoder_rates
170
+ self.decoder_dim = decoder_dim
171
+ self.decoder_rates = decoder_rates
172
+ self.sample_rate = sample_rate
173
+
174
+ if latent_dim is None:
175
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
176
+
177
+ self.latent_dim = latent_dim
178
+
179
+ self.hop_length = np.prod(encoder_rates)
180
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
181
+
182
+ self.n_codebooks = n_codebooks
183
+ self.codebook_size = codebook_size
184
+ self.codebook_dim = codebook_dim
185
+ self.quantizer = ResidualVectorQuantize(
186
+ input_dim=latent_dim,
187
+ n_codebooks=n_codebooks,
188
+ codebook_size=codebook_size,
189
+ codebook_dim=codebook_dim,
190
+ quantizer_dropout=quantizer_dropout,
191
+ )
192
+
193
+ self.decoder = Decoder(
194
+ latent_dim,
195
+ decoder_dim,
196
+ decoder_rates,
197
+ )
198
+ self.sample_rate = sample_rate
199
+ self.apply(init_weights)
200
+
201
+ self.delay = self.get_delay()
202
+
203
+ def preprocess(self, audio_data, sample_rate):
204
+ if sample_rate is None:
205
+ sample_rate = self.sample_rate
206
+ assert sample_rate == self.sample_rate
207
+
208
+ length = audio_data.shape[-1]
209
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
210
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
211
+
212
+ return audio_data
213
+
214
+ def encode(
215
+ self,
216
+ audio_data: torch.Tensor,
217
+ n_quantizers: int = None,
218
+ ):
219
+ """Encode given audio data and return quantized latent codes
220
+
221
+ Parameters
222
+ ----------
223
+ audio_data : Tensor[B x 1 x T]
224
+ Audio data to encode
225
+ n_quantizers : int, optional
226
+ Number of quantizers to use, by default None
227
+ If None, all quantizers are used.
228
+
229
+ Returns
230
+ -------
231
+ dict
232
+ A dictionary with the following keys:
233
+ "z" : Tensor[B x D x T]
234
+ Quantized continuous representation of input
235
+ "codes" : Tensor[B x N x T]
236
+ Codebook indices for each codebook
237
+ (quantized discrete representation of input)
238
+ "latents" : Tensor[B x N*D x T]
239
+ Projected latents (continuous representation of input before quantization)
240
+ "vq/commitment_loss" : Tensor[1]
241
+ Commitment loss to train encoder to predict vectors closer to codebook
242
+ entries
243
+ "vq/codebook_loss" : Tensor[1]
244
+ Codebook loss to update the codebook
245
+ "length" : int
246
+ Number of samples in input audio
247
+ """
248
+ z = self.encoder(audio_data)
249
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
250
+ return z, codes, latents, commitment_loss, codebook_loss
251
+
252
+ def decode(self, z: torch.Tensor):
253
+ """Decode given latent codes and return audio data
254
+
255
+ Parameters
256
+ ----------
257
+ z : Tensor[B x D x T]
258
+ Quantized continuous representation of input
259
+ length : int, optional
260
+ Number of samples in output audio, by default None
261
+
262
+ Returns
263
+ -------
264
+ dict
265
+ A dictionary with the following keys:
266
+ "audio" : Tensor[B x 1 x length]
267
+ Decoded audio data.
268
+ """
269
+ return self.decoder(z)
270
+
271
+ def forward(
272
+ self,
273
+ audio_data: torch.Tensor,
274
+ sample_rate: int = None,
275
+ n_quantizers: int = None,
276
+ ):
277
+ """Model forward pass
278
+
279
+ Parameters
280
+ ----------
281
+ audio_data : Tensor[B x 1 x T]
282
+ Audio data to encode
283
+ sample_rate : int, optional
284
+ Sample rate of audio data in Hz, by default None
285
+ If None, defaults to `self.sample_rate`
286
+ n_quantizers : int, optional
287
+ Number of quantizers to use, by default None.
288
+ If None, all quantizers are used.
289
+
290
+ Returns
291
+ -------
292
+ dict
293
+ A dictionary with the following keys:
294
+ "z" : Tensor[B x D x T]
295
+ Quantized continuous representation of input
296
+ "codes" : Tensor[B x N x T]
297
+ Codebook indices for each codebook
298
+ (quantized discrete representation of input)
299
+ "latents" : Tensor[B x N*D x T]
300
+ Projected latents (continuous representation of input before quantization)
301
+ "vq/commitment_loss" : Tensor[1]
302
+ Commitment loss to train encoder to predict vectors closer to codebook
303
+ entries
304
+ "vq/codebook_loss" : Tensor[1]
305
+ Codebook loss to update the codebook
306
+ "length" : int
307
+ Number of samples in input audio
308
+ "audio" : Tensor[B x 1 x length]
309
+ Decoded audio data.
310
+ """
311
+ length = audio_data.shape[-1]
312
+ audio_data = self.preprocess(audio_data, sample_rate)
313
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
314
+
315
+ x = self.decode(z)
316
+ return {
317
+ "audio": x[..., :length],
318
+ "z": z,
319
+ "codes": codes,
320
+ "latents": latents,
321
+ "vq/commitment_loss": commitment_loss,
322
+ "vq/codebook_loss": codebook_loss,
323
+ }
324
+
325
+
326
+ if __name__ == "__main__":
327
+ import numpy as np
328
+ from functools import partial
329
+
330
+ model = DAC().to("cpu")
331
+
332
+ for n, m in model.named_modules():
333
+ o = m.extra_repr()
334
+ p = sum([np.prod(p.size()) for p in m.parameters()])
335
+ fn = lambda o, p: o + f" {p / 1e6:<.3f}M params."
336
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
337
+ print(model)
338
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
339
+
340
+ length = 88200 * 2
341
+ x = torch.randn(1, 1, length).to(model.device)
342
+ x.requires_grad_(True)
343
+ x.retain_grad()
344
+
345
+ # Make a forward pass
346
+ out = model(x)["audio"]
347
+ print("Input shape:", x.shape)
348
+ print("Output shape:", out.shape)
349
+
350
+ # Create gradient variable
351
+ grad = torch.zeros_like(out)
352
+ grad[:, :, grad.shape[-1] // 2] = 1
353
+
354
+ # Make a backward pass
355
+ out.backward(grad)
356
+
357
+ # Check non-zero values
358
+ gradmap = x.grad.squeeze(0)
359
+ gradmap = (gradmap != 0).sum(0) # sum across features
360
+ rf = (gradmap != 0).sum()
361
+
362
+ print(f"Receptive field: {rf.item()}")
363
+
364
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
365
+ model.decompress(model.compress(x, verbose=True), verbose=True)
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch.nn.utils import weight_norm
+
+
+ def WNConv1d(*args, **kwargs):
+     return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+ def WNConvTranspose1d(*args, **kwargs):
+     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+ # Scripting this brings model speed up 1.4x
+ @torch.jit.script
+ def snake(x, alpha):
+     shape = x.shape
+     x = x.reshape(shape[0], shape[1], -1)
+     x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+     x = x.reshape(shape)
+     return x
+
+
+ class Snake1d(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+     def forward(self, x):
+         return snake(x, self.alpha)
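The snake function above applies the Snake activation elementwise, x + (1/alpha) * sin(alpha * x)^2, with a learnable per-channel alpha (the 1e-9 term only guards against division by zero). A quick sanity check of that identity; the import path is an assumption based on this commit's layout.

import torch
# Import path assumed from this commit's layout
from higgs_audio.audio_processing.descriptaudiocodec.dac.nn.layers import Snake1d

layer = Snake1d(channels=4)   # alpha is initialized to ones
x = torch.randn(2, 4, 16)     # (batch, channels, time)
y = layer(x)
# With alpha == 1, snake(x) reduces to x + sin(x)^2 (up to the 1e-9 stabilizer)
assert torch.allclose(y, x + torch.sin(x) ** 2, atol=1e-6)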
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py ADDED
@@ -0,0 +1,251 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from dac.nn.layers import WNConv1d
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+
34
+ def forward(self, z):
35
+ """Quantized the input tensor using a fixed codebook and returns
36
+ the corresponding codebook vectors
37
+
38
+ Parameters
39
+ ----------
40
+ z : Tensor[B x D x T]
41
+
42
+ Returns
43
+ -------
44
+ Tensor[B x D x T]
45
+ Quantized continuous representation of input
46
+ Tensor[1]
47
+ Commitment loss to train encoder to predict vectors closer to codebook
48
+ entries
49
+ Tensor[1]
50
+ Codebook loss to update the codebook
51
+ Tensor[B x T]
52
+ Codebook indices (quantized discrete representation of input)
53
+ Tensor[B x D x T]
54
+ Projected latents (continuous representation of input before quantization)
55
+ """
56
+
57
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
+ z_e = self.in_proj(z) # z_e : (B x D x T)
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
+
64
+ z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
65
+
66
+ z_q = self.out_proj(z_q)
67
+
68
+ return z_q, commitment_loss, codebook_loss, indices, z_e
69
+
70
+ def embed_code(self, embed_id):
71
+ return F.embedding(embed_id, self.codebook.weight)
72
+
73
+ def decode_code(self, embed_id):
74
+ return self.embed_code(embed_id).transpose(1, 2)
75
+
76
+ def decode_latents(self, latents):
77
+ encodings = rearrange(latents, "b d t -> (b t) d")
78
+ codebook = self.codebook.weight # codebook: (N x D)
79
+
80
+ # L2 normalize encodings and codebook (ViT-VQGAN)
81
+ encodings = F.normalize(encodings)
82
+ codebook = F.normalize(codebook)
83
+
84
+ # Compute euclidean distance with codebook
85
+ dist = (
86
+ encodings.pow(2).sum(1, keepdim=True)
87
+ - 2 * encodings @ codebook.t()
88
+ + codebook.pow(2).sum(1, keepdim=True).t()
89
+ )
90
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
91
+ z_q = self.decode_code(indices)
92
+ return z_q, indices
93
+
94
+
95
+ class ResidualVectorQuantize(nn.Module):
96
+ """
97
+ Introduced in SoundStream: An end2end neural audio codec
98
+ https://arxiv.org/abs/2107.03312
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ input_dim: int = 512,
104
+ n_codebooks: int = 9,
105
+ codebook_size: int = 1024,
106
+ codebook_dim: Union[int, list] = 8,
107
+ quantizer_dropout: float = 0.0,
108
+ ):
109
+ super().__init__()
110
+ if isinstance(codebook_dim, int):
111
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
112
+
113
+ self.n_codebooks = n_codebooks
114
+ self.codebook_dim = codebook_dim
115
+ self.codebook_size = codebook_size
116
+
117
+ self.quantizers = nn.ModuleList(
118
+ [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)]
119
+ )
120
+ self.quantizer_dropout = quantizer_dropout
121
+
122
+ def forward(self, z, n_quantizers: int = None):
123
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
124
+ the corresponding codebook vectors
125
+ Parameters
126
+ ----------
127
+ z : Tensor[B x D x T]
128
+ n_quantizers : int, optional
129
+ No. of quantizers to use
130
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
131
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
132
+ when in training mode, and a random number of quantizers is used.
133
+ Returns
134
+ -------
135
+ dict
136
+ A dictionary with the following keys:
137
+
138
+ "z" : Tensor[B x D x T]
139
+ Quantized continuous representation of input
140
+ "codes" : Tensor[B x N x T]
141
+ Codebook indices for each codebook
142
+ (quantized discrete representation of input)
143
+ "latents" : Tensor[B x N*D x T]
144
+ Projected latents (continuous representation of input before quantization)
145
+ "vq/commitment_loss" : Tensor[1]
146
+ Commitment loss to train encoder to predict vectors closer to codebook
147
+ entries
148
+ "vq/codebook_loss" : Tensor[1]
149
+ Codebook loss to update the codebook
150
+ """
151
+ z_q = 0
152
+ residual = z
153
+ commitment_loss = 0
154
+ codebook_loss = 0
155
+
156
+ codebook_indices = []
157
+ latents = []
158
+
159
+ if n_quantizers is None:
160
+ n_quantizers = self.n_codebooks
161
+ if self.training:
162
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
163
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
164
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
165
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
166
+ n_quantizers = n_quantizers.to(z.device)
167
+
168
+ for i, quantizer in enumerate(self.quantizers):
169
+ if self.training is False and i >= n_quantizers:
170
+ break
171
+
172
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
173
+
174
+ # Create mask to apply quantizer dropout
175
+ mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
176
+ z_q = z_q + z_q_i * mask[:, None, None]
177
+ residual = residual - z_q_i
178
+
179
+ # Sum losses
180
+ commitment_loss += (commitment_loss_i * mask).mean()
181
+ codebook_loss += (codebook_loss_i * mask).mean()
182
+
183
+ codebook_indices.append(indices_i)
184
+ latents.append(z_e_i)
185
+
186
+ codes = torch.stack(codebook_indices, dim=1)
187
+ latents = torch.cat(latents, dim=1)
188
+
189
+ return z_q, codes, latents, commitment_loss, codebook_loss
190
+
191
+ def from_codes(self, codes: torch.Tensor):
192
+ """Given the quantized codes, reconstruct the continuous representation
193
+ Parameters
194
+ ----------
195
+ codes : Tensor[B x N x T]
196
+ Quantized discrete representation of input
197
+ Returns
198
+ -------
199
+ Tensor[B x D x T]
200
+ Quantized continuous representation of input
201
+ """
202
+ z_q = 0.0
203
+ z_p = []
204
+ n_codebooks = codes.shape[1]
205
+ for i in range(n_codebooks):
206
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
207
+ z_p.append(z_p_i)
208
+
209
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
210
+ z_q = z_q + z_q_i
211
+ return z_q, torch.cat(z_p, dim=1), codes
212
+
213
+ def from_latents(self, latents: torch.Tensor):
214
+ """Given the unquantized latents, reconstruct the
215
+ continuous representation after quantization.
216
+
217
+ Parameters
218
+ ----------
219
+ latents : Tensor[B x N x T]
220
+ Continuous representation of input after projection
221
+
222
+ Returns
223
+ -------
224
+ Tensor[B x D x T]
225
+ Quantized representation of full-projected space
226
+ Tensor[B x D x T]
227
+ Quantized representation of latent space
228
+ """
229
+ z_q = 0
230
+ z_p = []
231
+ codes = []
232
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
233
+
234
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
235
+ for i in range(n_codebooks):
236
+ j, k = dims[i], dims[i + 1]
237
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
238
+ z_p.append(z_p_i)
239
+ codes.append(codes_i)
240
+
241
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
242
+ z_q = z_q + z_q_i
243
+
244
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
245
+
246
+
247
+ if __name__ == "__main__":
248
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
249
+ x = torch.randn(16, 512, 80)
250
+ y = rvq(x)
251
+ print(y["latents"].shape)
higgs_audio/audio_processing/higgs_audio_tokenizer.py ADDED
@@ -0,0 +1,341 @@
1
+ # Based on code from: https://github.com/zhenye234/xcodec
2
+ # Licensed under MIT License
3
+ # Modifications by BosonAI
4
+
5
+ import math
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Union, Sequence
11
+ import numpy as np
12
+ from transformers import AutoModel
13
+ import torchaudio
14
+ import json
15
+ import librosa
16
+ from huggingface_hub import snapshot_download
17
+
18
+ from vector_quantize_pytorch import ResidualFSQ
19
+ from .descriptaudiocodec.dac.model import dac as dac2
20
+ from .quantization.vq import ResidualVectorQuantizer
21
+ from .semantic_module import Encoder, Decoder
22
+
23
+
24
+ class EncodedResult:
25
+ def __init__(self, audio_codes):
26
+ self.audio_codes = audio_codes
27
+
28
+
29
+ class HiggsAudioFeatureExtractor(nn.Module):
30
+ def __init__(self, sampling_rate=16000):
31
+ super().__init__()
32
+ self.sampling_rate = sampling_rate
33
+
34
+ def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
35
+ # Convert from librosa to torch
36
+ audio_signal = torch.tensor(raw_audio)
37
+ audio_signal = audio_signal.unsqueeze(0)
38
+ if len(audio_signal.shape) < 3:
39
+ audio_signal = audio_signal.unsqueeze(0)
40
+ return {"input_values": audio_signal}
41
+
42
+
43
+ class HiggsAudioTokenizer(nn.Module):
44
+ def __init__(
45
+ self,
46
+ n_filters: int = 32,
47
+ D: int = 128,
48
+ target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
49
+ ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320
50
+ sample_rate: int = 16000,
51
+ bins: int = 1024,
52
+ n_q: int = 8,
53
+ codebook_dim: Optional[int] = None,
54
+ normalize: bool = False,
55
+ causal: bool = False,
56
+ semantic_techer: str = "hubert_base_general",
57
+ last_layer_semantic: bool = True,
58
+ merge_mode: str = "concat",
59
+ downsample_mode: str = "step_down",
60
+ semantic_mode: str = "classic",
61
+ vq_scale: int = 1,
62
+ semantic_sample_rate: int = None,
63
+ device: str = "cuda",
64
+ ):
65
+ super().__init__()
66
+ self.hop_length = np.prod(ratios)
67
+ self.semantic_techer = semantic_techer
68
+
69
+ self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz
70
+
71
+ self.target_bandwidths = target_bandwidths
72
+ self.n_q = n_q
73
+ self.sample_rate = sample_rate
74
+ self.encoder = dac2.Encoder(64, ratios, D)
75
+
76
+ self.decoder_2 = dac2.Decoder(D, 1024, ratios)
77
+ self.last_layer_semantic = last_layer_semantic
78
+ self.device = device
79
+ if semantic_techer == "hubert_base":
80
+ self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
81
+ self.semantic_sample_rate = 16000
82
+ self.semantic_dim = 768
83
+ self.encoder_semantic_dim = 768
84
+
85
+ elif semantic_techer == "wavlm_base_plus":
86
+ self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
87
+ self.semantic_sample_rate = 16000
88
+ self.semantic_dim = 768
89
+ self.encoder_semantic_dim = 768
90
+
91
+ elif semantic_techer == "hubert_base_general":
92
+ self.semantic_model = AutoModel.from_pretrained("ZhenYe234/hubert_base_general_audio")
93
+ self.semantic_sample_rate = 16000
94
+ self.semantic_dim = 768
95
+ self.encoder_semantic_dim = 768
96
+
97
+ # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
98
+ if semantic_sample_rate is not None:
99
+ self.semantic_sample_rate = semantic_sample_rate
100
+
101
+ self.semantic_model.eval()
102
+
103
+ # Freeze the semantic model parameters (no gradients needed)
104
+ for param in self.semantic_model.parameters():
105
+ param.requires_grad = False
106
+
107
+ self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)
108
+
109
+ self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
110
+ self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
111
+ self.decoder_semantic = Decoder(
112
+ code_dim=self.encoder_semantic_dim,
113
+ output_channels=self.semantic_dim,
114
+ decode_channels=self.semantic_dim,
115
+ )
116
+
117
+ # out_D=D+768
118
+ if isinstance(bins, int): # RVQ
119
+ self.quantizer = ResidualVectorQuantizer(
120
+ dimension=self.quantizer_dim,
121
+ codebook_dim=codebook_dim,
122
+ n_q=n_q,
123
+ bins=bins,
124
+ )
125
+ self.quantizer_type = "RVQ"
126
+ else: # RFSQ
127
+ self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
128
+ self.quantizer_type = "RFSQ"
129
+
130
+ self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
131
+ self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
132
+ self.fc_post2 = nn.Linear(self.quantizer_dim, D)
133
+
134
+ self.downsample_mode = downsample_mode
135
+ if downsample_mode == "avg":
136
+ self.semantic_pooling = nn.AvgPool1d(
137
+ kernel_size=self.semantic_downsample_factor,
138
+ stride=self.semantic_downsample_factor,
139
+ )
140
+
141
+ self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)
142
+
143
+ @property
144
+ def tps(self):
145
+ return self.frame_rate
146
+
147
+ @property
148
+ def sampling_rate(self):
149
+ return self.sample_rate
150
+
151
+ @property
152
+ def num_codebooks(self):
153
+ return self.n_q
154
+
155
+ @property
156
+ def codebook_size(self):
157
+ return self.quantizer_dim
158
+
159
+ def get_last_layer(self):
160
+ return self.decoder.layers[-1].weight
161
+
162
+ def calculate_rec_loss(self, rec, target):
163
+ target = target / target.norm(dim=-1, keepdim=True)
164
+ rec = rec / rec.norm(dim=-1, keepdim=True)
165
+ rec_loss = (1 - (target * rec).sum(-1)).mean()
166
+
167
+ return rec_loss
168
+
169
+ @torch.no_grad()
170
+ def get_regress_target(self, x):
171
+ x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)
172
+
173
+ if (
174
+ self.semantic_techer == "hubert_base"
175
+ or self.semantic_techer == "hubert_base_general"
176
+ or self.semantic_techer == "wavlm_base_plus"
177
+ ):
178
+ x = x[:, 0, :]
179
+ x = F.pad(x, (160, 160))
180
+ target = self.semantic_model(x, output_hidden_states=True).hidden_states
181
+ target = torch.stack(target, dim=1) # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)
182
+
183
+ # average for all layers
184
+ target = target.mean(1)
185
+ # target = target[9]
186
+ # if self.hop_length > 320:
187
+ # target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
188
+
189
+ elif self.semantic_techer == "w2v_bert2":
190
+ target = self.semantic_model(x)
191
+
192
+ elif self.semantic_techer.startswith("whisper"):
193
+ if self.last_layer_semantic:
194
+ target = self.semantic_model(x, avg_layers=False)
195
+ else:
196
+ target = self.semantic_model(x, avg_layers=True)
197
+
198
+ elif self.semantic_techer.startswith("mert_music"):
199
+ if self.last_layer_semantic:
200
+ target = self.semantic_model(x, avg_layers=False)
201
+ else:
202
+ target = self.semantic_model(x, avg_layers=True)
203
+
204
+ elif self.semantic_techer.startswith("qwen_audio_omni"):
205
+ target = self.semantic_model(x)
206
+
207
+ if self.downsample_mode == "step_down":
208
+ if self.semantic_downsample_factor > 1:
209
+ target = target[:, :: self.semantic_downsample_factor, :]
210
+
211
+ elif self.downsample_mode == "avg":
212
+ target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
213
+ return target
214
+
215
+ def forward(self, x: torch.Tensor, bw: int):
216
+ e_semantic_input = self.get_regress_target(x).detach()
217
+
218
+ e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
219
+ e_acoustic = self.encoder(x)
220
+
221
+ e = torch.cat([e_acoustic, e_semantic], dim=1)
222
+
223
+ e = self.fc_prior(e.transpose(1, 2))
224
+
225
+ if self.quantizer_type == "RVQ":
226
+ e = e.transpose(1, 2)
227
+ quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
228
+ quantized = quantized.transpose(1, 2)
229
+ else:
230
+ quantized, codes = self.quantizer(e)
231
+ commit_loss = torch.tensor(0.0)
232
+
233
+ quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
234
+ quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
235
+
236
+ o = self.decoder_2(quantized_acoustic)
237
+
238
+ o_semantic = self.decoder_semantic(quantized_semantic)
239
+ semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)
240
+
241
+ return o, commit_loss, semantic_recon_loss, None
242
+
243
+ def encode(
244
+ self,
245
+ audio_path_or_wv,
246
+ sr=None,
247
+ loudness_normalize=False,
248
+ loudness_threshold=-23.0,
249
+ ):
250
+ if isinstance(audio_path_or_wv, str):
251
+ wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
252
+ else:
253
+ wv = audio_path_or_wv
254
+ assert sr is not None
255
+ if loudness_normalize:
256
+ import pyloudnorm as pyln
257
+
258
+ meter = pyln.Meter(sr)
259
+ l = meter.integrated_loudness(wv)
260
+ wv = pyln.normalize.loudness(wv, l, loudness_threshold)
261
+ if sr != self.sampling_rate:
262
+ wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
263
+ if self.audio_tokenizer_feature_extractor is not None:
264
+ inputs = self.audio_tokenizer_feature_extractor(
265
+ raw_audio=wv,
266
+ sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate,
267
+ return_tensors="pt",
268
+ )
269
+ input_values = inputs["input_values"].to(self.device)
270
+ else:
271
+ input_values = torch.from_numpy(wv).float().unsqueeze(0)
272
+ with torch.no_grad():
273
+ encoder_outputs = self._xcodec_encode(input_values)
274
+ vq_code = encoder_outputs.audio_codes[0]
275
+ return vq_code
276
+
277
+ def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor:
278
+ bw = target_bw
279
+
280
+ e_semantic_input = self.get_regress_target(x).detach()
281
+
282
+ e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
283
+ e_acoustic = self.encoder(x)
284
+
285
+ if e_acoustic.shape[2] != e_semantic.shape[2]:
286
+ pad_size = 160 * self.semantic_downsample_factor
287
+ e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))
288
+
289
+ if e_acoustic.shape[2] != e_semantic.shape[2]:
290
+ if e_acoustic.shape[2] > e_semantic.shape[2]:
291
+ e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
292
+ else:
293
+ e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]
294
+
295
+ e = torch.cat([e_acoustic, e_semantic], dim=1)
296
+
297
+ e = self.fc_prior(e.transpose(1, 2))
298
+
299
+ if self.quantizer_type == "RVQ":
300
+ e = e.transpose(1, 2)
301
+ quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
302
+ codes = codes.permute(1, 0, 2)
303
+ else:
304
+ quantized, codes = self.quantizer(e)
305
+ codes = codes.permute(0, 2, 1)
306
+
307
+ # return codes
308
+ return EncodedResult(codes)
309
+
310
+ def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
311
+ if self.quantizer_type == "RVQ":
312
+ vq_code = vq_code.permute(1, 0, 2)
313
+ quantized = self.quantizer.decode(vq_code)
314
+ quantized = quantized.transpose(1, 2)
315
+ else:
316
+ vq_code = vq_code.permute(0, 2, 1)
317
+ quantized = self.quantizer.get_output_from_indices(vq_code)
318
+ quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
319
+
320
+ o = self.decoder_2(quantized_acoustic)
321
+ return o.cpu().numpy()
322
+
323
+
324
+ def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
325
+ is_local = os.path.exists(tokenizer_name_or_path)
326
+ if not is_local:
327
+ tokenizer_path = snapshot_download(tokenizer_name_or_path)
328
+ else:
329
+ tokenizer_path = tokenizer_name_or_path
330
+ config_path = os.path.join(tokenizer_path, "config.json")
331
+ model_path = os.path.join(tokenizer_path, "model.pth")
332
+ config = json.load(open(config_path))
333
+ model = HiggsAudioTokenizer(
334
+ **config,
335
+ device=device,
336
+ )
337
+ parameter_dict = torch.load(model_path, map_location=device)
338
+ model.load_state_dict(parameter_dict, strict=False)
339
+ model.to(device)
340
+ model.eval()
341
+ return model
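A hedged end-to-end sketch of how this tokenizer is meant to be used (not part of the diff; the checkpoint id and file names are placeholders, and the import path assumes the repo root is importable):

import soundfile as sf
from higgs_audio.audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer

# "your-org/your-audio-tokenizer" is a placeholder HF repo id or local directory
# containing the config.json and model.pth expected by load_higgs_audio_tokenizer.
tokenizer = load_higgs_audio_tokenizer("your-org/your-audio-tokenizer", device="cuda")

# Encode a file path (or a mono waveform plus its sample rate) into discrete codes.
vq_code = tokenizer.encode("speech_sample.wav")  # [n_q, T_frames], ~tokenizer.tps frames per second

# Decode the codes back to a waveform at tokenizer.sampling_rate.
wav = tokenizer.decode(vq_code.unsqueeze(0))     # numpy array, shape [1, 1, T_samples]
sf.write("reconstruction.wav", wav[0, 0], tokenizer.sampling_rate)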
higgs_audio/audio_processing/quantization/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .vq import QuantizedResult, ResidualVectorQuantizer
higgs_audio/audio_processing/quantization/ac.py ADDED
@@ -0,0 +1,301 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Arithmetic coder."""
8
+
9
+ import io
10
+ import math
11
+ import random
12
+ import typing as tp
13
+ import torch
14
+
15
+ from ..binary import BitPacker, BitUnpacker
16
+
17
+
18
+ def build_stable_quantized_cdf(
19
+ pdf: torch.Tensor,
20
+ total_range_bits: int,
21
+ roundoff: float = 1e-8,
22
+ min_range: int = 2,
23
+ check: bool = True,
24
+ ) -> torch.Tensor:
25
+ """Turn the given PDF into a quantized CDF that splits
26
+ [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
27
+ to the PDF.
28
+
29
+ Args:
30
+ pdf (torch.Tensor): probability distribution, shape should be `[N]`.
31
+ total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
32
+ during the coding process is `[0, 2 ** total_range_bits - 1]`.
33
+ roundoff (float): will round the pdf up to that level to remove difference coming
34
+ from e.g. evaluating the Language Model on different architectures.
35
+ min_range (int): minimum range width. Should always be at least 2 for numerical
36
+ stability. Use this to avoid pathological behavior if a value
37
+ that is expected to be rare actually happens in real life.
38
+ check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
39
+ """
40
+ pdf = pdf.detach()
41
+ if roundoff:
42
+ pdf = (pdf / roundoff).floor() * roundoff
43
+ # interpolate with uniform distribution to achieve desired minimum probability.
44
+ total_range = 2**total_range_bits
45
+ cardinality = len(pdf)
46
+ alpha = min_range * cardinality / total_range
47
+ assert alpha <= 1, "you must reduce min_range"
48
+ ranges = (((1 - alpha) * total_range) * pdf).floor().long()
49
+ ranges += min_range
50
+ quantized_cdf = torch.cumsum(ranges, dim=-1)
51
+ if min_range < 2:
52
+ raise ValueError("min_range must be at least 2.")
53
+ if check:
54
+ assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
55
+ if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
56
+ raise ValueError("You must increase your total_range_bits.")
57
+ return quantized_cdf
58
+
59
+
60
+ class ArithmeticCoder:
61
+ """ArithmeticCoder,
62
+ Let us take a distribution `p` over `N` symbols, and assume we have a stream
63
+ of random variables `s_t` sampled from `p`. Let us assume that we have a budget
64
+ of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
65
+ corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
66
+ sequence `(s_t)` by doing the following:
67
+
68
+ 1) Initialize the current range to `[0, 2 ** B - 1]`.
69
+ 2) For each time step t, split the current range into contiguous chunks,
70
+ one for each possible outcome, with size roughly proportional to `p`.
71
+ For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
72
+ would be `{[0, 2], [3, 3]}`.
73
+ 3) Select the chunk corresponding to `s_t`, and replace the current range with this.
74
+ 4) When done encoding all the values, just select any value remaining in the range.
75
+
76
+ You will notice that this procedure can fail: for instance if at any point in time
77
+ the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
78
+ possible outcome. Intuitively, the more likely a value is, the less the range width
79
+ will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
80
+ coding scheme, likely outcomes would take less bits, and more of them can be coded
81
+ with a fixed budget.
82
+
83
+ In practice, we do not know `B` ahead of time, but we have a way to inject new bits
84
+ when the current range decreases below a given limit (given by `total_range_bits`), without
85
+ having to redo all the computations. If we encode mostly likely values, we will seldom
86
+ need to inject new bits, but a single rare value can deplete our stock of entropy!
87
+
88
+ In this explanation, we assumed that the distribution `p` was constant. In fact, the present
89
+ code works for any sequence `(p_t)` possibly different for each timestep.
90
+ We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
91
+ the KL between the true distribution and `p_t`, the most efficient the coding will be.
92
+
93
+ Args:
94
+ fo (IO[bytes]): file-like object to which the bytes will be written.
95
+ total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
96
+ Any time the current range width falls under this limit, new bits will
97
+ be injected to rescale the initial range.
98
+ """
99
+
100
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
101
+ assert total_range_bits <= 30
102
+ self.total_range_bits = total_range_bits
103
+ self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
104
+ self.low: int = 0
105
+ self.high: int = 0
106
+ self.max_bit: int = -1
107
+ self._dbg: tp.List[tp.Any] = []
108
+ self._dbg2: tp.List[tp.Any] = []
109
+
110
+ @property
111
+ def delta(self) -> int:
112
+ """Return the current range width."""
113
+ return self.high - self.low + 1
114
+
115
+ def _flush_common_prefix(self):
116
+ # If self.low and self.high start with the same bits,
117
+ # those won't change anymore as we always just increase the range
118
+ # by powers of 2, and we can flush them out to the bit stream.
119
+ assert self.high >= self.low, (self.low, self.high)
120
+ assert self.high < 2 ** (self.max_bit + 1)
121
+ while self.max_bit >= 0:
122
+ b1 = self.low >> self.max_bit
123
+ b2 = self.high >> self.max_bit
124
+ if b1 == b2:
125
+ self.low -= b1 << self.max_bit
126
+ self.high -= b1 << self.max_bit
127
+ assert self.high >= self.low, (self.high, self.low, self.max_bit)
128
+ assert self.low >= 0
129
+ self.max_bit -= 1
130
+ self.packer.push(b1)
131
+ else:
132
+ break
133
+
134
+ def push(self, symbol: int, quantized_cdf: torch.Tensor):
135
+ """Push the given symbol on the stream, flushing out bits
136
+ if possible.
137
+
138
+ Args:
139
+ symbol (int): symbol to encode with the AC.
140
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
141
+ to build this from your pdf estimate.
142
+ """
143
+ while self.delta < 2**self.total_range_bits:
144
+ self.low *= 2
145
+ self.high = self.high * 2 + 1
146
+ self.max_bit += 1
147
+
148
+ range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
149
+ range_high = quantized_cdf[symbol].item() - 1
150
+ effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
151
+ effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
152
+ assert self.low <= self.high
153
+ self.high = self.low + effective_high
154
+ self.low = self.low + effective_low
155
+ assert self.low <= self.high, (
156
+ effective_low,
157
+ effective_high,
158
+ range_low,
159
+ range_high,
160
+ )
161
+ self._dbg.append((self.low, self.high))
162
+ self._dbg2.append((self.low, self.high))
163
+ outs = self._flush_common_prefix()
164
+ assert self.low <= self.high
165
+ assert self.max_bit >= -1
166
+ assert self.max_bit <= 61, self.max_bit
167
+ return outs
168
+
169
+ def flush(self):
170
+ """Flush the remaining information to the stream."""
171
+ while self.max_bit >= 0:
172
+ b1 = (self.low >> self.max_bit) & 1
173
+ self.packer.push(b1)
174
+ self.max_bit -= 1
175
+ self.packer.flush()
176
+
177
+
178
+ class ArithmeticDecoder:
179
+ """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
180
+
181
+ Note that this must be called with **exactly** the same parameters and sequence
182
+ of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
183
+
184
+ If the AC encoder current range is [L, H], with `L` and `H` having the same common
185
+ prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
186
+ For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
187
+ `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
188
+ for a specific sequence of symbols and a binary-search allows us to decode those symbols.
189
+ At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
190
+ and we will need to read new bits from the stream and repeat the process.
191
+
192
+ """
193
+
194
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
195
+ self.total_range_bits = total_range_bits
196
+ self.low: int = 0
197
+ self.high: int = 0
198
+ self.current: int = 0
199
+ self.max_bit: int = -1
200
+ self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
201
+ # Following is for debugging
202
+ self._dbg: tp.List[tp.Any] = []
203
+ self._dbg2: tp.List[tp.Any] = []
204
+ self._last: tp.Any = None
205
+
206
+ @property
207
+ def delta(self) -> int:
208
+ return self.high - self.low + 1
209
+
210
+ def _flush_common_prefix(self):
211
+ # Given the current range [L, H], if both have a common prefix,
212
+ # we know we can remove it from our representation to avoid handling large numbers.
213
+ while self.max_bit >= 0:
214
+ b1 = self.low >> self.max_bit
215
+ b2 = self.high >> self.max_bit
216
+ if b1 == b2:
217
+ self.low -= b1 << self.max_bit
218
+ self.high -= b1 << self.max_bit
219
+ self.current -= b1 << self.max_bit
220
+ assert self.high >= self.low
221
+ assert self.low >= 0
222
+ self.max_bit -= 1
223
+ else:
224
+ break
225
+
226
+ def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
227
+ """Pull a symbol, reading as many bits from the stream as required.
228
+ This returns `None` when the stream has been exhausted.
229
+
230
+ Args:
231
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
232
+ to build this from your pdf estimate. This must be **exactly**
233
+ the same cdf as the one used at encoding time.
234
+ """
235
+ while self.delta < 2**self.total_range_bits:
236
+ bit = self.unpacker.pull()
237
+ if bit is None:
238
+ return None
239
+ self.low *= 2
240
+ self.high = self.high * 2 + 1
241
+ self.current = self.current * 2 + bit
242
+ self.max_bit += 1
243
+
244
+ def bin_search(low_idx: int, high_idx: int):
245
+ # Binary search is not just for coding interviews :)
246
+ if high_idx < low_idx:
247
+ raise RuntimeError("Binary search failed")
248
+ mid = (low_idx + high_idx) // 2
249
+ range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
250
+ range_high = quantized_cdf[mid].item() - 1
251
+ effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
252
+ effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
253
+ low = effective_low + self.low
254
+ high = effective_high + self.low
255
+ if self.current >= low:
256
+ if self.current <= high:
257
+ return (mid, low, high, self.current)
258
+ else:
259
+ return bin_search(mid + 1, high_idx)
260
+ else:
261
+ return bin_search(low_idx, mid - 1)
262
+
263
+ self._last = (self.low, self.high, self.current, self.max_bit)
264
+ sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
265
+ self._dbg.append((self.low, self.high, self.current))
266
+ self._flush_common_prefix()
267
+ self._dbg2.append((self.low, self.high, self.current))
268
+
269
+ return sym
270
+
271
+
272
+ def test():
273
+ torch.manual_seed(1234)
274
+ random.seed(1234)
275
+ for _ in range(4):
276
+ pdfs = []
277
+ cardinality = random.randrange(4000)
278
+ steps = random.randrange(100, 500)
279
+ fo = io.BytesIO()
280
+ encoder = ArithmeticCoder(fo)
281
+ symbols = []
282
+ for step in range(steps):
283
+ pdf = torch.softmax(torch.randn(cardinality), dim=0)
284
+ pdfs.append(pdf)
285
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
286
+ symbol = torch.multinomial(pdf, 1).item()
287
+ symbols.append(symbol)
288
+ encoder.push(symbol, q_cdf)
289
+ encoder.flush()
290
+
291
+ fo.seek(0)
292
+ decoder = ArithmeticDecoder(fo)
293
+ for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
294
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
295
+ decoded_symbol = decoder.pull(q_cdf)
296
+ assert decoded_symbol == symbol, idx
297
+ assert decoder.pull(torch.zeros(1)) is None
298
+
299
+
300
+ if __name__ == "__main__":
301
+ test()
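The test() routine above already exercises a full encode/decode round trip; the sketch below is illustrative only (it assumes the ArithmeticCoder and build_stable_quantized_cdf names defined above, and that the BitPacker/BitUnpacker helpers imported at the top of this file are available). It shows the typical streaming pattern with one model distribution per step and compares the bits written against the entropy of the sampled stream:

import io
import math
import torch

# ArithmeticCoder and build_stable_quantized_cdf are defined above in this module.
def compress(symbols, pdfs, total_range_bits=24):
    """Arithmetic-code symbols[t] ~ pdfs[t] and return the raw bytes."""
    fo = io.BytesIO()
    coder = ArithmeticCoder(fo, total_range_bits=total_range_bits)
    for symbol, pdf in zip(symbols, pdfs):
        q_cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
        coder.push(symbol, q_cdf)
    coder.flush()
    return fo.getvalue()

pdfs = [torch.softmax(torch.randn(256), dim=0) for _ in range(200)]
symbols = [int(torch.multinomial(pdf, 1)) for pdf in pdfs]
blob = compress(symbols, pdfs)
entropy_bits = sum(-math.log2(pdf[s].item()) for pdf, s in zip(pdfs, symbols))
print(f"{len(blob) * 8} bits written vs ~{entropy_bits:.0f} bits of entropy")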
higgs_audio/audio_processing/quantization/core_vq.py ADDED
@@ -0,0 +1,360 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+
34
+ import typing as tp
35
+
36
+ from einops import rearrange, repeat
37
+ import torch
38
+ from torch import nn
39
+ import torch.nn.functional as F
40
+
41
+ from .distrib import broadcast_tensors, rank
42
+
43
+
44
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
45
+ return val if val is not None else d
46
+
47
+
48
+ def ema_inplace(moving_avg, new, decay: float):
49
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
50
+
51
+
52
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
53
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
54
+
55
+
56
+ def uniform_init(*shape: int):
57
+ t = torch.empty(shape)
58
+ nn.init.kaiming_uniform_(t)
59
+ return t
60
+
61
+
62
+ def sample_vectors(samples, num: int):
63
+ num_samples, device = samples.shape[0], samples.device
64
+
65
+ if num_samples >= num:
66
+ indices = torch.randperm(num_samples, device=device)[:num]
67
+ else:
68
+ indices = torch.randint(0, num_samples, (num,), device=device)
69
+
70
+ return samples[indices]
71
+
72
+
73
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
74
+ dim, dtype = samples.shape[-1], samples.dtype
75
+
76
+ means = sample_vectors(samples, num_clusters)
77
+
78
+ for _ in range(num_iters):
79
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
80
+ dists = -(diffs**2).sum(dim=-1)
81
+
82
+ buckets = dists.max(dim=-1).indices
83
+ bins = torch.bincount(buckets, minlength=num_clusters)
84
+ zero_mask = bins == 0
85
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
86
+
87
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
88
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
89
+ new_means = new_means / bins_min_clamped[..., None]
90
+
91
+ means = torch.where(zero_mask[..., None], means, new_means)
92
+
93
+ return means, bins
94
+
95
+
96
+ class EuclideanCodebook(nn.Module):
97
+ """Codebook with Euclidean distance.
98
+ Args:
99
+ dim (int): Dimension.
100
+ codebook_size (int): Codebook size.
101
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
102
+ If set to true, run the k-means algorithm on the first training batch and use
103
+ the learned centroids as initialization.
104
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
105
+ decay (float): Decay for exponential moving average over the codebooks.
106
+ epsilon (float): Epsilon value for numerical stability.
107
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
108
+ that have an exponential moving average cluster size less than the specified threshold with
109
+ randomly selected vector from the current batch.
110
+ """
111
+
112
+ def __init__(
113
+ self,
114
+ dim: int,
115
+ codebook_size: int,
116
+ kmeans_init: int = False,
117
+ kmeans_iters: int = 10,
118
+ decay: float = 0.99,
119
+ epsilon: float = 1e-5,
120
+ threshold_ema_dead_code: int = 2,
121
+ ):
122
+ super().__init__()
123
+ self.decay = decay
124
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
125
+ embed = init_fn(codebook_size, dim)
126
+
127
+ self.codebook_size = codebook_size
128
+
129
+ self.kmeans_iters = kmeans_iters
130
+ self.epsilon = epsilon
131
+ self.threshold_ema_dead_code = threshold_ema_dead_code
132
+
133
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
134
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
135
+ self.register_buffer("embed", embed)
136
+ self.register_buffer("embed_avg", embed.clone())
137
+
138
+ @torch.jit.ignore
139
+ def init_embed_(self, data):
140
+ if self.inited:
141
+ return
142
+
143
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
144
+ self.embed.data.copy_(embed)
145
+ self.embed_avg.data.copy_(embed.clone())
146
+ self.cluster_size.data.copy_(cluster_size)
147
+ self.inited.data.copy_(torch.Tensor([True]))
148
+ # Make sure all buffers across workers are in sync after initialization
149
+ broadcast_tensors(self.buffers())
150
+
151
+ def replace_(self, samples, mask):
152
+ modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
153
+ self.embed.data.copy_(modified_codebook)
154
+
155
+ def expire_codes_(self, batch_samples):
156
+ if self.threshold_ema_dead_code == 0:
157
+ return
158
+
159
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
160
+ if not torch.any(expired_codes):
161
+ return
162
+
163
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
164
+ self.replace_(batch_samples, mask=expired_codes)
165
+ broadcast_tensors(self.buffers())
166
+
167
+ def preprocess(self, x):
168
+ x = rearrange(x, "... d -> (...) d")
169
+ return x
170
+
171
+ def quantize(self, x):
172
+ embed = self.embed.t()
173
+ dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
174
+ embed_ind = dist.max(dim=-1).indices
175
+ return embed_ind
176
+
177
+ def postprocess_emb(self, embed_ind, shape):
178
+ return embed_ind.view(*shape[:-1])
179
+
180
+ def dequantize(self, embed_ind):
181
+ quantize = F.embedding(embed_ind, self.embed) # get embedding based on index
182
+ return quantize
183
+
184
+ def encode(self, x):
185
+ shape = x.shape
186
+ # pre-process
187
+ x = self.preprocess(x)
188
+ # quantize
189
+ embed_ind = self.quantize(x) # get index based on Euclidean distance
190
+ # post-process
191
+ embed_ind = self.postprocess_emb(embed_ind, shape)
192
+ return embed_ind
193
+
194
+ def decode(self, embed_ind):
195
+ quantize = self.dequantize(embed_ind)
196
+ return quantize
197
+
198
+ def forward(self, x):
199
+ shape, dtype = x.shape, x.dtype
200
+ x = self.preprocess(x)
201
+
202
+ self.init_embed_(x)
203
+
204
+ embed_ind = self.quantize(x)
205
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
206
+ embed_ind = self.postprocess_emb(embed_ind, shape)
207
+ quantize = self.dequantize(embed_ind)
208
+
209
+ if self.training:
210
+ # We do the expiry of code at that point as buffers are in sync
211
+ # and all the workers will take the same decision.
212
+ self.expire_codes_(x)
213
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
214
+ embed_sum = x.t() @ embed_onehot
215
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
216
+ cluster_size = (
217
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
218
+ )
219
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
220
+ self.embed.data.copy_(embed_normalized)
221
+
222
+ return quantize, embed_ind
223
+
224
+
225
+ class VectorQuantization(nn.Module):
226
+ """Vector quantization implementation.
227
+ Currently supports only euclidean distance.
228
+ Args:
229
+ dim (int): Dimension
230
+ codebook_size (int): Codebook size
231
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
232
+ decay (float): Decay for exponential moving average over the codebooks.
233
+ epsilon (float): Epsilon value for numerical stability.
234
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
235
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
236
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
237
+ that have an exponential moving average cluster size less than the specified threshold with
238
+ randomly selected vector from the current batch.
239
+ commitment_weight (float): Weight for commitment loss.
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ dim: int,
245
+ codebook_size: int,
246
+ codebook_dim: tp.Optional[int] = None,
247
+ decay: float = 0.99,
248
+ epsilon: float = 1e-5,
249
+ kmeans_init: bool = True,
250
+ kmeans_iters: int = 50,
251
+ threshold_ema_dead_code: int = 2,
252
+ commitment_weight: float = 1.0,
253
+ ):
254
+ super().__init__()
255
+ _codebook_dim: int = default(codebook_dim, dim)
256
+
257
+ requires_projection = _codebook_dim != dim
258
+ self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
259
+ self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
260
+
261
+ self.epsilon = epsilon
262
+ self.commitment_weight = commitment_weight
263
+
264
+ self._codebook = EuclideanCodebook(
265
+ dim=_codebook_dim,
266
+ codebook_size=codebook_size,
267
+ kmeans_init=kmeans_init,
268
+ kmeans_iters=kmeans_iters,
269
+ decay=decay,
270
+ epsilon=epsilon,
271
+ threshold_ema_dead_code=threshold_ema_dead_code,
272
+ )
273
+ self.codebook_size = codebook_size
274
+
275
+ @property
276
+ def codebook(self):
277
+ return self._codebook.embed
278
+
279
+ def encode(self, x):
280
+ x = rearrange(x, "b d n -> b n d")
281
+ x = self.project_in(x)
282
+ embed_in = self._codebook.encode(x)
283
+ return embed_in
284
+
285
+ def decode(self, embed_ind):
286
+ quantize = self._codebook.decode(embed_ind)
287
+ quantize = self.project_out(quantize)
288
+ quantize = rearrange(quantize, "b n d -> b d n")
289
+ return quantize
290
+
291
+ def forward(self, x):
292
+ device = x.device
293
+ x = rearrange(x, "b d n -> b n d")
294
+ x = self.project_in(x)
295
+
296
+ quantize, embed_ind = self._codebook(x)
297
+
298
+ if self.training:
299
+ quantize = x + (quantize - x).detach()
300
+
301
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
302
+
303
+ if self.training:
304
+ if self.commitment_weight > 0:
305
+ commit_loss = F.mse_loss(quantize.detach(), x)
306
+ loss = loss + commit_loss * self.commitment_weight
307
+
308
+ quantize = self.project_out(quantize)
309
+ quantize = rearrange(quantize, "b n d -> b d n")
310
+ return quantize, embed_ind, loss
311
+
312
+
313
+ class ResidualVectorQuantization(nn.Module):
314
+ """Residual vector quantization implementation.
315
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
316
+ """
317
+
318
+ def __init__(self, *, num_quantizers, **kwargs):
319
+ super().__init__()
320
+ self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
321
+
322
+ def forward(self, x, n_q: tp.Optional[int] = None):
323
+ quantized_out = 0.0
324
+ residual = x
325
+
326
+ all_losses = []
327
+ all_indices = []
328
+
329
+ n_q = n_q or len(self.layers)
330
+
331
+ for layer in self.layers[:n_q]:
332
+ quantized, indices, loss = layer(residual)
333
+ residual = residual - quantized
334
+ quantized_out = quantized_out + quantized
335
+
336
+ all_indices.append(indices)
337
+ all_losses.append(loss)
338
+
339
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
340
+ return quantized_out, out_indices, out_losses
341
+
342
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
343
+ residual = x
344
+ all_indices = []
345
+ n_q = n_q or len(self.layers)
346
+ for layer in self.layers[:n_q]:
347
+ indices = layer.encode(residual)
348
+ quantized = layer.decode(indices)
349
+ residual = residual - quantized
350
+ all_indices.append(indices)
351
+ out_indices = torch.stack(all_indices)
352
+ return out_indices
353
+
354
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
355
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
356
+ for i, indices in enumerate(q_indices):
357
+ layer = self.layers[i]
358
+ quantized = layer.decode(indices)
359
+ quantized_out = quantized_out + quantized
360
+ return quantized_out
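A small usage sketch for the residual quantizer above (illustrative only; class names as defined in this module). Passing n_q at encode time trades reconstruction quality for bitrate, which is how the surrounding codec selects a target bandwidth:

import torch

# ResidualVectorQuantization is defined above in this module.
rvq = ResidualVectorQuantization(
    num_quantizers=8, dim=128, codebook_size=1024, kmeans_init=False
)
rvq.eval()

x = torch.randn(4, 128, 50)          # [B, D, T] frames to quantize

codes_coarse = rvq.encode(x, n_q=4)  # first 4 quantizers only -> [4, B, T]
x_coarse = rvq.decode(codes_coarse)  # coarse reconstruction, [B, D, T]

codes_full = rvq.encode(x)           # all 8 quantizers -> [8, B, T]
x_full = rvq.decode(codes_full)      # finer reconstruction
print(codes_coarse.shape, codes_full.shape, x_full.shape)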
higgs_audio/audio_processing/quantization/core_vq_lsx_version.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright (c)
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # This implementation is inspired from
6
+ # https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and
7
+ # https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81
8
+ #
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ # All rights reserved.
11
+ #
12
+ # This source code is licensed under the license found in the
13
+ # LICENSE file in the root directory of this source tree.
14
+ #
15
+ # This implementation is inspired from
16
+ # https://github.com/lucidrains/vector-quantize-pytorch
17
+ # which is released under MIT License. Hereafter, the original license:
18
+ # MIT License
19
+ #
20
+ # Copyright (c) 2020 Phil Wang
21
+ #
22
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
23
+ # of this software and associated documentation files (the "Software"), to deal
24
+ # in the Software without restriction, including without limitation the rights
25
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
26
+ # copies of the Software, and to permit persons to whom the Software is
27
+ # furnished to do so, subject to the following conditions:
28
+ #
29
+ # The above copyright notice and this permission notice shall be included in all
30
+ # copies or substantial portions of the Software.
31
+ #
32
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38
+ # SOFTWARE.
39
+
40
+ """Core vector quantization implementation."""
41
+
42
+ import typing as tp
43
+
44
+ from einops import rearrange
45
+ import torch
46
+ from torch import nn
47
+ import torch.nn.functional as F
48
+ import torch.distributed as dist
49
+
50
+ from .distrib import broadcast_tensors, is_distributed
51
+ from .ddp_utils import SyncFunction
52
+
53
+
54
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
55
+ return val if val is not None else d
56
+
57
+
58
+ def ema_inplace(moving_avg, new, decay: float):
59
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
60
+
61
+
62
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
63
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
64
+
65
+
66
+ def uniform_init(*shape: int):
67
+ t = torch.empty(shape)
68
+ nn.init.kaiming_uniform_(t)
69
+ return t
70
+
71
+
72
+ def sample_vectors(samples, num: int):
73
+ num_samples, device = samples.shape[0], samples.device
74
+
75
+ if num_samples >= num:
76
+ indices = torch.randperm(num_samples, device=device)[:num]
77
+ else:
78
+ indices = torch.randint(0, num_samples, (num,), device=device)
79
+
80
+ return samples[indices]
81
+
82
+
83
+ def kmeans(
84
+ samples,
85
+ num_clusters: int,
86
+ num_iters: int = 10,
87
+ frames_to_use: int = 10_000,
88
+ batch_size: int = 64,
89
+ ):
90
+ """
91
+ Memory-efficient K-means clustering.
92
+ Args:
93
+ samples (tensor): shape [N, D]
94
+ num_clusters (int): number of centroids.
95
+ num_iters (int): number of iterations.
96
+ frames_to_use (int): subsample size from total samples.
97
+ batch_size (int): batch size used in distance computation.
98
+ Returns:
99
+ means: [num_clusters, D]
100
+ bins: [num_clusters] (number of points per cluster)
101
+ """
102
+ N, D = samples.shape
103
+ dtype, device = samples.dtype, samples.device
104
+
105
+ if frames_to_use < N:
106
+ indices = torch.randperm(N, device=device)[:frames_to_use]
107
+ samples = samples[indices]
108
+
109
+ means = sample_vectors(samples, num_clusters)
110
+
111
+ for _ in range(num_iters):
112
+ # Store cluster assignments
113
+ all_assignments = []
114
+
115
+ for i in range(0, samples.shape[0], batch_size):
116
+ batch = samples[i : i + batch_size] # [B, D]
117
+ dists = torch.cdist(batch, means, p=2) # [B, C]
118
+ assignments = dists.argmin(dim=1) # [B]
119
+ all_assignments.append(assignments)
120
+
121
+ buckets = torch.cat(all_assignments, dim=0) # [N]
122
+ bins = torch.bincount(buckets, minlength=num_clusters)
123
+ zero_mask = bins == 0
124
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
125
+
126
+ # Compute new means
127
+ new_means = torch.zeros_like(means)
128
+ for i in range(num_clusters):
129
+ mask = buckets == i
130
+ if mask.any():
131
+ new_means[i] = samples[mask].mean(dim=0)
132
+
133
+ means = torch.where(zero_mask[:, None], means, new_means)
134
+
135
+ return means, bins
136
+
137
+
138
+ class EuclideanCodebook(nn.Module):
139
+ """Codebook with Euclidean distance.
140
+ Args:
141
+ dim (int): Dimension.
142
+ codebook_size (int): Codebook size.
143
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
144
+ If set to true, run the k-means algorithm on the first training batch and use
145
+ the learned centroids as initialization.
146
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
147
+ decay (float): Decay for exponential moving average over the codebooks.
148
+ epsilon (float): Epsilon value for numerical stability.
149
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
150
+ that have an exponential moving average cluster size less than the specified threshold with
151
+ randomly selected vector from the current batch.
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ dim: int,
157
+ codebook_size: int,
158
+ kmeans_init: int = False,
159
+ kmeans_iters: int = 10,
160
+ decay: float = 0.99,
161
+ epsilon: float = 1e-5,
162
+ threshold_ema_dead_code: int = 2,
163
+ ):
164
+ super().__init__()
165
+ self.decay = decay
166
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
167
+ embed = init_fn(codebook_size, dim)
168
+
169
+ self.codebook_size = codebook_size
170
+
171
+ self.kmeans_iters = kmeans_iters
172
+ self.epsilon = epsilon
173
+ self.threshold_ema_dead_code = threshold_ema_dead_code
174
+
175
+ # Flag variable to indicate whether the codebook is initialized
176
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
177
+ # Running EMA cluster size/count: N_i^t in eq. (6) in vqvae paper
178
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
179
+ # Codebook
180
+ self.register_buffer("embed", embed)
181
+ # EMA codebook: eq. (7) in vqvae paper
182
+ self.register_buffer("embed_avg", embed.clone())
183
+
184
+ @torch.jit.ignore
185
+ def init_embed_(self, data):
186
+ """Initialize codebook.
187
+ Args:
188
+ data (tensor): [B * T, D].
189
+ """
190
+ if self.inited:
191
+ return
192
+
193
+ ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
194
+ if dist.is_available() and dist.is_initialized():
195
+ # [B * T * world_size, D]
196
+ data = SyncFunction.apply(data)
197
+
198
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
199
+ self.embed.data.copy_(embed)
200
+ self.embed_avg.data.copy_(embed.clone())
201
+ self.cluster_size.data.copy_(cluster_size)
202
+ self.inited.data.copy_(torch.Tensor([True]))
203
+ # Make sure all buffers across workers are in sync after initialization
204
+ broadcast_tensors(self.buffers())
205
+
206
+ def replace_(self, samples, mask):
207
+ modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
208
+ self.embed.data.copy_(modified_codebook)
209
+
210
+ def expire_codes_(self, batch_samples):
211
+ if self.threshold_ema_dead_code == 0:
212
+ return
213
+
214
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
215
+ if not torch.any(expired_codes):
216
+ return
217
+
218
+ ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
219
+ if is_distributed():
220
+ # [B * T * world_size, D]
221
+ batch_samples = SyncFunction.apply(batch_samples)
222
+
223
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
224
+ self.replace_(batch_samples, mask=expired_codes)
225
+ broadcast_tensors(self.buffers())
226
+
227
+ def preprocess(self, x):
228
+ x = rearrange(x, "... d -> (...) d")
229
+ return x
230
+
231
+ def quantize(self, x):
232
+ embed = self.embed.t()
233
+ dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
234
+ embed_ind = dist.max(dim=-1).indices
235
+ return embed_ind
236
+
237
+ def postprocess_emb(self, embed_ind, shape):
238
+ return embed_ind.view(*shape[:-1])
239
+
240
+ def dequantize(self, embed_ind):
241
+ quantize = F.embedding(embed_ind, self.embed)
242
+ return quantize
243
+
244
+ def encode(self, x):
245
+ shape = x.shape
246
+ # pre-process
247
+ x = self.preprocess(x) # [B, T, D] -> [B*T, D]
248
+ # quantize
249
+ embed_ind = self.quantize(x)
250
+ # post-process
251
+ embed_ind = self.postprocess_emb(embed_ind, shape)
252
+ return embed_ind
253
+
254
+ def decode(self, embed_ind):
255
+ quantize = self.dequantize(embed_ind)
256
+ return quantize
257
+
258
+ def forward(self, x):
259
+ # shape: [B, T, D]
260
+ shape, dtype = x.shape, x.dtype
261
+ x = self.preprocess(x) # [B, T, D] -> [B*T, D]
262
+
263
+ # Initialize codebook
264
+ self.init_embed_(x)
265
+
266
+ embed_ind = self.quantize(x) # [B*T,]
267
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) # [B*T, cb-size]
268
+ embed_ind = self.postprocess_emb(embed_ind, shape) # [B, T]
269
+ quantize = self.dequantize(embed_ind) # [B, T, D]
270
+
271
+ if self.training:
272
+ ### Update codebook by EMA
273
+ embed_onehot_sum = embed_onehot.sum(0) # [cb-size,]
274
+ embed_sum = x.t() @ embed_onehot # [D, cb-size]
275
+ if is_distributed():
276
+ dist.all_reduce(embed_onehot_sum)
277
+ dist.all_reduce(embed_sum)
278
+ # Update ema cluster count N_i^t, eq. (6) in vqvae paper
279
+ self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
280
+ # Update ema embed: eq. (7) in vqvae paper
281
+ self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
282
+ # apply laplace smoothing
283
+ n = self.cluster_size.sum()
284
+ cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
285
+ # Update ema embed: eq. (8) in vqvae paper
286
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
287
+ self.embed.data.copy_(embed_normalized)
288
+
289
+ # We do the expiry of code at that point as buffers are in sync
290
+ # and all the workers will take the same decision.
291
+ self.expire_codes_(x)
292
+
293
+ return quantize, embed_ind
294
+
295
+
296
+ class VectorQuantization(nn.Module):
297
+ """Vector quantization implementation.
298
+ Currently supports only euclidean distance.
299
+ Args:
300
+ dim (int): Dimension
301
+ codebook_size (int): Codebook size
302
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
303
+ decay (float): Decay for exponential moving average over the codebooks.
304
+ epsilon (float): Epsilon value for numerical stability.
305
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
306
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
307
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
308
+ that have an exponential moving average cluster size less than the specified threshold with
309
+ randomly selected vector from the current batch.
310
+ commitment_weight (float): Weight for commitment loss.
311
+ """
312
+
313
+ def __init__(
314
+ self,
315
+ dim: int,
316
+ codebook_size: int,
317
+ codebook_dim: tp.Optional[int] = None,
318
+ decay: float = 0.99,
319
+ epsilon: float = 1e-5,
320
+ kmeans_init: bool = True,
321
+ kmeans_iters: int = 50,
322
+ threshold_ema_dead_code: int = 2,
323
+ commitment_weight: float = 1.0,
324
+ ):
325
+ super().__init__()
326
+ _codebook_dim: int = default(codebook_dim, dim)
327
+
328
+ requires_projection = _codebook_dim != dim
329
+ self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
330
+ self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
331
+
332
+ self.epsilon = epsilon
333
+ self.commitment_weight = commitment_weight
334
+
335
+ self._codebook = EuclideanCodebook(
336
+ dim=_codebook_dim,
337
+ codebook_size=codebook_size,
338
+ kmeans_init=kmeans_init,
339
+ kmeans_iters=kmeans_iters,
340
+ decay=decay,
341
+ epsilon=epsilon,
342
+ threshold_ema_dead_code=threshold_ema_dead_code,
343
+ )
344
+ self.codebook_size = codebook_size
345
+
346
+ @property
347
+ def codebook(self):
348
+ return self._codebook.embed
349
+
350
+ def encode(self, x):
351
+ x = rearrange(x, "b d n -> b n d")
352
+ x = self.project_in(x)
353
+ embed_in = self._codebook.encode(x)
354
+ return embed_in
355
+
356
+ def decode(self, embed_ind):
357
+ quantize = self._codebook.decode(embed_ind)
358
+ quantize = self.project_out(quantize)
359
+ quantize = rearrange(quantize, "b n d -> b d n")
360
+ return quantize
361
+
362
+ def forward(self, x):
363
+ device = x.device
364
+ x = x.transpose(1, 2).contiguous() # [b d n] -> [b n d]
365
+ x = self.project_in(x)
366
+
367
+ quantize, embed_ind = self._codebook(x)
368
+
369
+ if self.training:
370
+ quantize = x + (quantize - x).detach()
371
+
372
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
373
+
374
+ if self.training:
375
+ if self.commitment_weight > 0:
376
+ commit_loss = F.mse_loss(quantize.detach(), x)
377
+ loss = loss + commit_loss * self.commitment_weight
378
+
379
+ quantize = self.project_out(quantize)
380
+ quantize = quantize.transpose(1, 2).contiguous() # [b n d] -> [b d n]
381
+ return quantize, embed_ind, loss
382
+
383
+
384
+ class ResidualVectorQuantization(nn.Module):
385
+ """Residual vector quantization implementation.
386
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
387
+ """
388
+
389
+ def __init__(self, *, num_quantizers, **kwargs):
390
+ super().__init__()
391
+ self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
392
+
393
+ def forward(self, x, n_q: tp.Optional[int] = None):
394
+ quantized_out = 0.0
395
+ residual = x
396
+
397
+ all_losses = []
398
+ all_indices = []
399
+
400
+ n_q = n_q or len(self.layers)
401
+
402
+ for layer in self.layers[:n_q]:
403
+ quantized, indices, loss = layer(residual)
404
+ residual = residual - quantized
405
+ quantized_out = quantized_out + quantized
406
+
407
+ all_indices.append(indices)
408
+ all_losses.append(loss)
409
+
410
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
411
+ return quantized_out, out_indices, out_losses
412
+
413
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
414
+ residual = x
415
+ all_indices = []
416
+ n_q = n_q or len(self.layers)
417
+ for layer in self.layers[:n_q]:
418
+ indices = layer.encode(residual)
419
+ quantized = layer.decode(indices)
420
+ residual = residual - quantized
421
+ all_indices.append(indices)
422
+ out_indices = torch.stack(all_indices)
423
+ return out_indices
424
+
425
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
426
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
427
+ for i, indices in enumerate(q_indices):
428
+ layer = self.layers[i]
429
+ quantized = layer.decode(indices)
430
+ quantized_out = quantized_out + quantized
431
+ return quantized_out
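The EMA bookkeeping in EuclideanCodebook.forward above follows the update rules referenced in its comments (eqs. 6-8 of the VQ-VAE paper); below is a standalone, illustrative sketch of one update step on made-up tensors, not code from the module:

import torch
import torch.nn.functional as F

decay, eps = 0.99, 1e-5
codebook_size, dim = 4, 3
embed = torch.randn(codebook_size, dim)      # current codebook
embed_avg = embed.clone()                    # EMA numerator (eq. 7)
cluster_size = torch.zeros(codebook_size)    # EMA usage counts N_i (eq. 6)

x = torch.randn(32, dim)                     # one flattened batch [B*T, D]
embed_ind = torch.cdist(x, embed).argmin(dim=1)
onehot = F.one_hot(embed_ind, codebook_size).float()

# eq. (6): EMA of per-code usage counts
cluster_size.mul_(decay).add_(onehot.sum(0), alpha=1 - decay)
# eq. (7): EMA of the sum of vectors assigned to each code
embed_avg.mul_(decay).add_(onehot.t() @ x, alpha=1 - decay)
# eq. (8): normalized codebook, with Laplace smoothing of the counts
n = cluster_size.sum()
smoothed = (cluster_size + eps) / (n + codebook_size * eps) * n
embed = embed_avg / smoothed.unsqueeze(1)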
higgs_audio/audio_processing/quantization/ddp_utils.py ADDED
@@ -0,0 +1,197 @@
1
+ import logging
+ import logging.config
2
+ import random
3
+ import subprocess
4
+ from datetime import datetime
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch.nn.parallel import DistributedDataParallel
10
+ from torch.nn.parallel.distributed import _find_tensors
11
+ import torch.optim
12
+ import torch.utils.data
13
+ from packaging import version
14
+ from omegaconf import OmegaConf
15
+
16
+
17
+ def set_random_seed(seed):
18
+ random.seed(seed)
19
+ np.random.seed(seed)
20
+ torch.manual_seed(seed)
21
+ torch.cuda.manual_seed_all(seed)
22
+
23
+
24
+ def is_logging_process():
25
+ return not dist.is_initialized() or dist.get_rank() == 0
26
+
27
+
28
+ def get_logger(cfg, name=None):
29
+ # log_file_path is used when unit testing
30
+ if is_logging_process():
31
+ logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
32
+ return logging.getLogger(name)
33
+
34
+
35
+ # from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
36
+ class SyncFunction(torch.autograd.Function):
37
+ @staticmethod
38
+ # @torch.no_grad()
39
+ def forward(ctx, tensor):
40
+ ctx.batch_size = tensor.shape[0]
41
+
42
+ gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
43
+
44
+ torch.distributed.all_gather(gathered_tensor, tensor)
45
+ gathered_tensor = torch.cat(gathered_tensor, 0)
46
+
47
+ return gathered_tensor
48
+
49
+ @staticmethod
50
+ def backward(ctx, grad_output):
51
+ grad_input = grad_output.clone()
52
+ torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
53
+
54
+ idx_from = torch.distributed.get_rank() * ctx.batch_size
55
+ idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
56
+ return grad_input[idx_from:idx_to]
57
+
58
+
59
+ def get_timestamp():
60
+ return datetime.now().strftime("%y%m%d-%H%M%S")
61
+
62
+
63
+ def get_commit_hash():
64
+ message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
65
+ return message.strip().decode("utf-8")
66
+
67
+
68
+ class DDP(DistributedDataParallel):
69
+ """
70
+ Override the forward call (Lightning-style) so it dispatches to training_step, test_step, or validation_step as appropriate
71
+ """
72
+
73
+ def forward(self, *inputs, **kwargs): # pragma: no cover
74
+ if version.parse(torch.__version__.split("+")[0]) < version.parse("1.11"):
75
+ self._sync_params()
76
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
77
+ assert len(self.device_ids) == 1
78
+ if self.module.training:
79
+ output = self.module.training_step(*inputs[0], **kwargs[0])
80
+ elif self.module.testing:
81
+ output = self.module.test_step(*inputs[0], **kwargs[0])
82
+ else:
83
+ output = self.module.validation_step(*inputs[0], **kwargs[0])
84
+ if torch.is_grad_enabled():
85
+ # We'll return the output object verbatim since it is a freeform
86
+ # object. We need to find any tensors in this object, though,
87
+ # because we need to figure out which parameters were used during
88
+ # this forward pass, to ensure we short circuit reduction for any
89
+ # unused parameters. Only if `find_unused_parameters` is set.
90
+ if self.find_unused_parameters:
91
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
92
+ else:
93
+ self.reducer.prepare_for_backward([])
94
+ else:
95
+ from torch.nn.parallel.distributed import (
96
+ logging,
97
+ Join,
98
+ _DDPSink,
99
+ _tree_flatten_with_rref,
100
+ _tree_unflatten_with_rref,
101
+ )
102
+
103
+ with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
104
+ if torch.is_grad_enabled() and self.require_backward_grad_sync:
105
+ self.logger.set_runtime_stats_and_log()
106
+ self.num_iterations += 1
107
+ self.reducer.prepare_for_forward()
108
+
109
+ # Notify the join context that this process has not joined, if
110
+ # needed
111
+ work = Join.notify_join_context(self)
112
+ if work:
113
+ self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)
114
+
115
+ # Calling _rebuild_buckets before forward computation,
116
+ # It may allocate new buckets before deallocating old buckets
117
+ # inside _rebuild_buckets. To save peak memory usage,
118
+ # call _rebuild_buckets before the peak memory usage increases
119
+ # during forward computation.
120
+ # This should be called only once during whole training period.
121
+ if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
122
+ logging.info("Reducer buckets have been rebuilt in this iteration.")
123
+ self._has_rebuilt_buckets = True
124
+
125
+ # sync params according to location (before/after forward) user
126
+ # specified as part of hook, if hook was specified.
127
+ buffer_hook_registered = hasattr(self, "buffer_hook")
128
+ if self._check_sync_bufs_pre_fwd():
129
+ self._sync_buffers()
130
+
131
+ if self._join_config.enable:
132
+ # Notify joined ranks whether they should sync in backwards pass or not.
133
+ self._check_global_requires_backward_grad_sync(is_joined_rank=False)
134
+
135
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
136
+ if self.module.training:
137
+ output = self.module.training_step(*inputs[0], **kwargs[0])
138
+ elif self.module.testing:
139
+ output = self.module.test_step(*inputs[0], **kwargs[0])
140
+ else:
141
+ output = self.module.validation_step(*inputs[0], **kwargs[0])
142
+
143
+ # sync params according to location (before/after forward) user
144
+ # specified as part of hook, if hook was specified.
145
+ if self._check_sync_bufs_post_fwd():
146
+ self._sync_buffers()
147
+
148
+ if torch.is_grad_enabled() and self.require_backward_grad_sync:
149
+ self.require_forward_param_sync = True
150
+ # We'll return the output object verbatim since it is a freeform
151
+ # object. We need to find any tensors in this object, though,
152
+ # because we need to figure out which parameters were used during
153
+ # this forward pass, to ensure we short circuit reduction for any
154
+ # unused parameters. Only if `find_unused_parameters` is set.
155
+ if self.find_unused_parameters and not self.static_graph:
156
+ # Do not need to populate this for static graph.
157
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
158
+ else:
159
+ self.reducer.prepare_for_backward([])
160
+ else:
161
+ self.require_forward_param_sync = False
162
+
163
+ # TODO: DDPSink is currently enabled for unused parameter detection and
164
+ # static graph training for first iteration.
165
+ if (self.find_unused_parameters and not self.static_graph) or (
166
+ self.static_graph and self.num_iterations == 1
167
+ ):
168
+ state_dict = {
169
+ "static_graph": self.static_graph,
170
+ "num_iterations": self.num_iterations,
171
+ }
172
+
173
+ output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
174
+ output_placeholders = [None for _ in range(len(output_tensor_list))]
175
+ # Do not touch tensors that have no grad_fn, which can cause issues
176
+ # such as https://github.com/pytorch/pytorch/issues/60733
177
+ for i, output in enumerate(output_tensor_list):
178
+ if torch.is_tensor(output) and output.grad_fn is None:
179
+ output_placeholders[i] = output
180
+
181
+ # When find_unused_parameters=True, makes tensors which require grad
182
+ # run through the DDPSink backward pass. When not all outputs are
183
+ # used in loss, this makes those corresponding tensors receive
184
+ # undefined gradient which the reducer then handles to ensure
185
+ # param.grad field is not touched and we don't error out.
186
+ passthrough_tensor_list = _DDPSink.apply(
187
+ self.reducer,
188
+ state_dict,
189
+ *output_tensor_list,
190
+ )
191
+ for i in range(len(output_placeholders)):
192
+ if output_placeholders[i] is None:
193
+ output_placeholders[i] = passthrough_tensor_list[i]
194
+
195
+ # Reconstruct output data structure.
196
+ output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
197
+ return output
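SyncFunction above is the usual trick for making all_gather differentiable: the forward pass gathers activations from every rank, and the backward pass all-reduces the incoming gradient and returns only the slice that belongs to the local batch. A hedged usage sketch, assuming a process group has already been initialized (e.g. via torchrun); model, batch, and contrastive_loss are placeholders, not names from this repository:

import torch
import torch.distributed as dist

def gather_global_features(features: torch.Tensor) -> torch.Tensor:
    # Differentiable all-gather across ranks; no-op when running single-process.
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        return SyncFunction.apply(features)
    return features

# local_features = model(batch)                      # (local_bsz, dim)
# global_features = gather_global_features(local_features)
# loss = contrastive_loss(global_features)           # computed over the global batch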
higgs_audio/audio_processing/quantization/distrib.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Torch distributed utilities."""
8
+
9
+ import typing as tp
10
+
11
+ import torch
12
+
13
+
14
+ def rank():
15
+ if torch.distributed.is_initialized():
16
+ return torch.distributed.get_rank()
17
+ else:
18
+ return 0
19
+
20
+
21
+ def world_size():
22
+ if torch.distributed.is_initialized():
23
+ return torch.distributed.get_world_size()
24
+ else:
25
+ return 1
26
+
27
+
28
+ def is_distributed():
29
+ return world_size() > 1
30
+
31
+
32
+ def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
33
+ if is_distributed():
34
+ return torch.distributed.all_reduce(tensor, op)
35
+
36
+
37
+ def _is_complex_or_float(tensor):
38
+ return torch.is_floating_point(tensor) or torch.is_complex(tensor)
39
+
40
+
41
+ def _check_number_of_params(params: tp.List[torch.Tensor]):
42
+ # utility function to check that the number of params in all workers is the same,
43
+ # and thus avoid a deadlock with distributed all reduce.
44
+ if not is_distributed() or not params:
45
+ return
46
+ # print('params[0].device ', params[0].device)
47
+ tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
48
+ all_reduce(tensor)
49
+ if tensor.item() != len(params) * world_size():
50
+ # If the workers do not all have the same number of params, this check will
51
+ # fail for at least one of them.
52
+ raise RuntimeError(
53
+ f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
54
+ )
55
+
56
+
57
+ def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
58
+ """Broadcast the tensors from the given parameters to all workers.
59
+ This can be used to ensure that all workers have the same model to start with.
60
+ """
61
+ if not is_distributed():
62
+ return
63
+ tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
64
+ _check_number_of_params(tensors)
65
+ handles = []
66
+ for tensor in tensors:
67
+ handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
68
+ handles.append(handle)
69
+ for handle in handles:
70
+ handle.wait()
71
+
72
+
73
+ def sync_buffer(buffers, average=True):
74
+ """
75
+ Sync buffers across workers. If average is False, broadcast from rank 0 instead of averaging.
76
+ """
77
+ if not is_distributed():
78
+ return
79
+ handles = []
80
+ for buffer in buffers:
81
+ if torch.is_floating_point(buffer.data):
82
+ if average:
83
+ handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
84
+ else:
85
+ handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
86
+ handles.append((buffer, handle))
87
+ for buffer, handle in handles:
88
+ handle.wait()
89
+ if average:
90
+ buffer.data /= world_size()
91
+
92
+
93
+ def sync_grad(params):
94
+ """
95
+ Simpler alternative to DistributedDataParallel, that doesn't rely
96
+ on any black magic. For simple models it can also be as fast.
97
+ Just call this on your model parameters after the call to backward!
98
+ """
99
+ if not is_distributed():
100
+ return
101
+ handles = []
102
+ for p in params:
103
+ if p.grad is not None:
104
+ handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
105
+ handles.append((p, handle))
106
+ for p, handle in handles:
107
+ handle.wait()
108
+ p.grad.data /= world_size()
109
+
110
+
111
+ def average_metrics(metrics: tp.Dict[str, float], count=1.0):
112
+ """Average a dictionary of metrics across all workers, using the optional
113
+ `count` as unnormalized weight.
114
+ """
115
+ if not is_distributed():
116
+ return metrics
117
+ keys, values = zip(*metrics.items())
118
+ device = "cuda" if torch.cuda.is_available() else "cpu"
119
+ tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
120
+ tensor *= count
121
+ all_reduce(tensor)
122
+ averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
123
+ return dict(zip(keys, averaged))
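sync_grad above is the "manual DDP" path: every rank runs its own backward pass, then gradients are all-reduced and averaged before the optimizer step. A minimal sketch of a training step built on it; model, optimizer, loss_fn, and batch are placeholders, and only sync_grad comes from this file:

def train_step(model, optimizer, loss_fn, batch):
    loss = loss_fn(model(batch))
    loss.backward()
    # All-reduce gradients and divide by world_size() (no-op when single-process).
    sync_grad(model.parameters())
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()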
higgs_audio/audio_processing/quantization/vq.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Residual vector quantizer implementation."""
8
+
9
+ from dataclasses import dataclass, field
10
+ import math
11
+ import typing as tp
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ # from .core_vq import ResidualVectorQuantization
17
+ from .core_vq_lsx_version import ResidualVectorQuantization
18
+
19
+
20
+ @dataclass
21
+ class QuantizedResult:
22
+ quantized: torch.Tensor
23
+ codes: torch.Tensor
24
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
25
+ penalty: tp.Optional[torch.Tensor] = None
26
+ metrics: dict = field(default_factory=dict)
27
+
28
+
29
+ class ResidualVectorQuantizer(nn.Module):
30
+ """Residual Vector Quantizer.
31
+ Args:
32
+ dimension (int): Dimension of the codebooks.
33
+ n_q (int): Number of residual vector quantizers used.
34
+ bins (int): Codebook size.
35
+ decay (float): Decay for exponential moving average over the codebooks.
36
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
37
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
38
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
39
+ that have an exponential moving average cluster size less than the specified threshold with
40
+ randomly selected vector from the current batch.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dimension: int = 256,
46
+ codebook_dim: int = None,
47
+ n_q: int = 8,
48
+ bins: int = 1024,
49
+ decay: float = 0.99,
50
+ kmeans_init: bool = True,
51
+ kmeans_iters: int = 50,
52
+ threshold_ema_dead_code: int = 2,
53
+ ):
54
+ super().__init__()
55
+ self.n_q = n_q
56
+ self.dimension = dimension
57
+ self.codebook_dim = codebook_dim
58
+ self.bins = bins
59
+ self.decay = decay
60
+ self.kmeans_init = kmeans_init
61
+ self.kmeans_iters = kmeans_iters
62
+ self.threshold_ema_dead_code = threshold_ema_dead_code
63
+ self.vq = ResidualVectorQuantization(
64
+ dim=self.dimension,
65
+ codebook_dim=self.codebook_dim,
66
+ codebook_size=self.bins,
67
+ num_quantizers=self.n_q,
68
+ decay=self.decay,
69
+ kmeans_init=self.kmeans_init,
70
+ kmeans_iters=self.kmeans_iters,
71
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
72
+ )
73
+
74
+ def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None): # -> QuantizedResult:
75
+ """Residual vector quantization on the given input tensor.
76
+ Args:
77
+ x (torch.Tensor): Input tensor.
78
+ sample_rate (int): Sample rate of the input tensor.
79
+ bandwidth (float): Target bandwidth.
80
+ Returns:
81
+ QuantizedResult:
82
+ The quantized (or approximately quantized) representation with
83
+ the associated bandwidth and any penalty term for the loss.
84
+ """
85
+ bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
86
+ n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
87
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
88
+ bw = torch.tensor(n_q * bw_per_q).to(x)
89
+ return quantized, codes, bw, torch.mean(commit_loss)
90
+ # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
91
+
92
+ def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
93
+ """Return n_q based on specified target bandwidth."""
94
+ bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
95
+ n_q = self.n_q
96
+ if bandwidth and bandwidth > 0.0:
97
+ n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
98
+ return n_q
99
+
100
+ def get_bandwidth_per_quantizer(self, sample_rate: int):
101
+ """Return bandwidth per quantizer for a given input sample rate."""
102
+ return math.log2(self.bins) * sample_rate / 1000
103
+
104
+ def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
105
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
106
+ The RVQ encode method sets the appropriate number of quantizer to use
107
+ and returns indices for each quantizer.
108
+ """
109
+ n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
110
+ codes = self.vq.encode(x, n_q=n_q)
111
+ return codes
112
+
113
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
114
+ """Decode the given codes to the quantized representation."""
115
+ quantized = self.vq.decode(codes)
116
+ return quantized
higgs_audio/audio_processing/semantic_module.py ADDED
@@ -0,0 +1,310 @@
1
+ # Based on code from: https://github.com/zhenye234/xcodec
2
+ # Licensed under MIT License
3
+ # Modifications by BosonAI
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class Conv1d1x1(nn.Conv1d):
10
+ """1x1 Conv1d."""
11
+
12
+ def __init__(self, in_channels, out_channels, bias=True):
13
+ super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
14
+
15
+
16
+ class Conv1d(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_channels: int,
20
+ out_channels: int,
21
+ kernel_size: int,
22
+ stride: int = 1,
23
+ padding: int = -1,
24
+ dilation: int = 1,
25
+ groups: int = 1,
26
+ bias: bool = True,
27
+ ):
28
+ super().__init__()
29
+ self.in_channels = in_channels
30
+ self.out_channels = out_channels
31
+ self.kernel_size = kernel_size
32
+ if padding < 0:
33
+ padding = (kernel_size - 1) // 2 * dilation
34
+ self.dilation = dilation
35
+ self.conv = nn.Conv1d(
36
+ in_channels=in_channels,
37
+ out_channels=out_channels,
38
+ kernel_size=kernel_size,
39
+ stride=stride,
40
+ padding=padding,
41
+ dilation=dilation,
42
+ groups=groups,
43
+ bias=bias,
44
+ )
45
+
46
+ def forward(self, x):
47
+ """
48
+ Args:
49
+ x (Tensor): Float tensor variable with the shape (B, C, T).
50
+ Returns:
51
+ Tensor: Float tensor variable with the shape (B, C, T).
52
+ """
53
+ x = self.conv(x)
54
+ return x
55
+
56
+
57
+ class ResidualUnit(nn.Module):
58
+ def __init__(
59
+ self,
60
+ in_channels: int,
61
+ out_channels: int,
62
+ kernel_size=3,
63
+ dilation=1,
64
+ bias=False,
65
+ nonlinear_activation="ELU",
66
+ nonlinear_activation_params={},
67
+ ):
68
+ super().__init__()
69
+ self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
70
+ self.conv1 = Conv1d(
71
+ in_channels=in_channels,
72
+ out_channels=out_channels,
73
+ kernel_size=kernel_size,
74
+ stride=1,
75
+ dilation=dilation,
76
+ bias=bias,
77
+ )
78
+ self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
79
+
80
+ def forward(self, x):
81
+ y = self.conv1(self.activation(x))
82
+ y = self.conv2(self.activation(y))
83
+ return x + y
84
+
85
+
86
+ class ConvTranspose1d(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels: int,
90
+ out_channels: int,
91
+ kernel_size: int,
92
+ stride: int,
93
+ padding=-1,
94
+ output_padding=-1,
95
+ groups=1,
96
+ bias=True,
97
+ ):
98
+ super().__init__()
99
+ if padding < 0:
100
+ padding = (stride + 1) // 2
101
+ if output_padding < 0:
102
+ output_padding = 1 if stride % 2 else 0
103
+ self.deconv = nn.ConvTranspose1d(
104
+ in_channels=in_channels,
105
+ out_channels=out_channels,
106
+ kernel_size=kernel_size,
107
+ stride=stride,
108
+ padding=padding,
109
+ output_padding=output_padding,
110
+ groups=groups,
111
+ bias=bias,
112
+ )
113
+
114
+ def forward(self, x):
115
+ """
116
+ Args:
117
+ x (Tensor): Float tensor variable with the shape (B, C, T).
118
+ Returns:
119
+ Tensor: Float tensor variable with the shape (B, C', T').
120
+ """
121
+ x = self.deconv(x)
122
+ return x
123
+
124
+
125
+ class EncoderBlock(nn.Module):
126
+ def __init__(
127
+ self,
128
+ in_channels: int,
129
+ out_channels: int,
130
+ stride: int,
131
+ dilations=(1, 1),
132
+ unit_kernel_size=3,
133
+ bias=True,
134
+ ):
135
+ super().__init__()
136
+ self.res_units = torch.nn.ModuleList()
137
+ for dilation in dilations:
138
+ self.res_units += [
139
+ ResidualUnit(
140
+ in_channels,
141
+ in_channels,
142
+ kernel_size=unit_kernel_size,
143
+ dilation=dilation,
144
+ )
145
+ ]
146
+ self.num_res = len(self.res_units)
147
+
148
+ self.conv = Conv1d(
149
+ in_channels=in_channels,
150
+ out_channels=out_channels,
151
+ kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2
152
+ stride=stride,
153
+ bias=bias,
154
+ )
155
+
156
+ def forward(self, x):
157
+ for idx in range(self.num_res):
158
+ x = self.res_units[idx](x)
159
+ x = self.conv(x)
160
+ return x
161
+
162
+
163
+ class Encoder(nn.Module):
164
+ def __init__(
165
+ self,
166
+ input_channels: int,
167
+ encode_channels: int,
168
+ channel_ratios=(1, 1),
169
+ strides=(1, 1),
170
+ kernel_size=3,
171
+ bias=True,
172
+ block_dilations=(1, 1),
173
+ unit_kernel_size=3,
174
+ ):
175
+ super().__init__()
176
+ assert len(channel_ratios) == len(strides)
177
+
178
+ self.conv = Conv1d(
179
+ in_channels=input_channels,
180
+ out_channels=encode_channels,
181
+ kernel_size=kernel_size,
182
+ stride=1,
183
+ bias=False,
184
+ )
185
+ self.conv_blocks = torch.nn.ModuleList()
186
+ in_channels = encode_channels
187
+ for idx, stride in enumerate(strides):
188
+ out_channels = int(encode_channels * channel_ratios[idx]) # could be float
189
+ self.conv_blocks += [
190
+ EncoderBlock(
191
+ in_channels,
192
+ out_channels,
193
+ stride,
194
+ dilations=block_dilations,
195
+ unit_kernel_size=unit_kernel_size,
196
+ bias=bias,
197
+ )
198
+ ]
199
+ in_channels = out_channels
200
+ self.num_blocks = len(self.conv_blocks)
201
+ self.out_channels = out_channels
202
+
203
+ def forward(self, x):
204
+ x = self.conv(x)
205
+ for i in range(self.num_blocks):
206
+ x = self.conv_blocks[i](x)
207
+ return x
208
+
209
+
210
+ class DecoderBlock(nn.Module):
211
+ """Decoder block (no up-sampling)"""
212
+
213
+ def __init__(
214
+ self,
215
+ in_channels: int,
216
+ out_channels: int,
217
+ stride: int,
218
+ dilations=(1, 1),
219
+ unit_kernel_size=3,
220
+ bias=True,
221
+ ):
222
+ super().__init__()
223
+
224
+ if stride == 1:
225
+ self.conv = Conv1d(
226
+ in_channels=in_channels,
227
+ out_channels=out_channels,
228
+ kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
229
+ stride=stride,
230
+ bias=bias,
231
+ )
232
+ else:
233
+ self.conv = ConvTranspose1d(
234
+ in_channels=in_channels,
235
+ out_channels=out_channels,
236
+ kernel_size=(2 * stride),
237
+ stride=stride,
238
+ bias=bias,
239
+ )
240
+
241
+ self.res_units = torch.nn.ModuleList()
242
+ for idx, dilation in enumerate(dilations):
243
+ self.res_units += [
244
+ ResidualUnit(
245
+ out_channels,
246
+ out_channels,
247
+ kernel_size=unit_kernel_size,
248
+ dilation=dilation,
249
+ )
250
+ ]
251
+ self.num_res = len(self.res_units)
252
+
253
+ def forward(self, x):
254
+ x = self.conv(x)
255
+ for idx in range(self.num_res):
256
+ x = self.res_units[idx](x)
257
+ return x
258
+
259
+
260
+ class Decoder(nn.Module):
261
+ def __init__(
262
+ self,
263
+ code_dim: int,
264
+ output_channels: int,
265
+ decode_channels: int,
266
+ channel_ratios=(1, 1),
267
+ strides=(1, 1),
268
+ kernel_size=3,
269
+ bias=True,
270
+ block_dilations=(1, 1),
271
+ unit_kernel_size=3,
272
+ ):
273
+ super().__init__()
274
+ assert len(channel_ratios) == len(strides)
275
+
276
+ self.conv1 = Conv1d(
277
+ in_channels=code_dim,
278
+ out_channels=int(decode_channels * channel_ratios[0]),
279
+ kernel_size=kernel_size,
280
+ stride=1,
281
+ bias=False,
282
+ )
283
+
284
+ self.conv_blocks = torch.nn.ModuleList()
285
+ for idx, stride in enumerate(strides):
286
+ in_channels = int(decode_channels * channel_ratios[idx])
287
+ if idx < (len(channel_ratios) - 1):
288
+ out_channels = int(decode_channels * channel_ratios[idx + 1])
289
+ else:
290
+ out_channels = decode_channels
291
+ self.conv_blocks += [
292
+ DecoderBlock(
293
+ in_channels,
294
+ out_channels,
295
+ stride,
296
+ dilations=block_dilations,
297
+ unit_kernel_size=unit_kernel_size,
298
+ bias=bias,
299
+ )
300
+ ]
301
+ self.num_blocks = len(self.conv_blocks)
302
+
303
+ self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
304
+
305
+ def forward(self, z):
306
+ x = self.conv1(z)
307
+ for i in range(self.num_blocks):
308
+ x = self.conv_blocks[i](x)
309
+ x = self.conv2(x)
310
+ return x
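A quick shape sanity check for the Encoder/Decoder pair above: each stride-s EncoderBlock downsamples time by s, and the matching DecoderBlock's transposed convolution upsamples it back. The configuration below is illustrative only, not the one used by the Higgs-Audio tokenizer:

import torch

enc = Encoder(input_channels=1, encode_channels=16, channel_ratios=(2, 4), strides=(2, 2))
dec = Decoder(code_dim=enc.out_channels, output_channels=1, decode_channels=16,
              channel_ratios=(4, 2), strides=(2, 2))

x = torch.randn(1, 1, 1600)     # (B, C, T)
z = enc(x)                      # strides (2, 2) downsample T by 4 -> (1, 64, 400)
y = dec(z)                      # transposed convs upsample back   -> (1, 1, 1600)
print(z.shape, y.shape)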
higgs_audio/constants.py ADDED
@@ -0,0 +1,3 @@
1
+ AUDIO_IN_TOKEN = "<|AUDIO|>"
2
+ AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"
3
+ EOS_TOKEN = "<|end_of_text|>"
higgs_audio/data_collator/__init__.py ADDED
File without changes
higgs_audio/data_collator/higgs_audio_collator.py ADDED
@@ -0,0 +1,583 @@
1
+ import librosa
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import math
5
+ import numpy as np
6
+ from typing import List, Tuple, Dict
7
+
8
+ from dataclasses import dataclass
9
+ from typing import List, Optional
10
+ from transformers.models.whisper.processing_whisper import WhisperProcessor
11
+
12
+ from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple
13
+ from ..model.utils import build_delay_pattern_mask
14
+
15
+
16
+ def _ceil_to_nearest(n, round_to):
17
+ return (n + round_to - 1) // round_to * round_to
18
+
19
+
20
+ @dataclass
21
+ class HiggsAudioBatchInput:
22
+ input_ids: torch.LongTensor # shape (bsz, seq_len).
23
+ attention_mask: torch.Tensor # shape (bsz, seq_len).
24
+ audio_features: Optional[torch.Tensor] # shape (num_audio_in, feature_dim, max_mel_seq_len).
25
+ audio_feature_attention_mask: Optional[torch.Tensor] # shape (num_audio_in, max_mel_seq_len).
26
+ audio_out_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
27
+ audio_out_ids_start: Optional[torch.LongTensor] # shape (num_audio_out,)
28
+ # The audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover the group (sample) location in the batch for each audio segment.
29
+ # Currently, we concatenate audio segments along dim 0 to handle variable-length audio segments. However, in the alignment stage, we need the location information.
30
+ # For example,
31
+ # audio_out_ids_start = [0, 2, 4, 8]; the first two audio segments come from the same sample in the batch, and the other two come from different samples.
32
+ # For this batch of 3 samples, we will then have the group locations:
33
+ # audio_out_ids_start_group_loc = [0, 0, 1, 2]
34
+ audio_out_ids_start_group_loc: Optional[
35
+ torch.LongTensor
36
+ ] # shape (num_audio_out,): specifies each audio segment's sample (group) location in the batch
37
+ audio_in_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_in_total_length)
38
+ audio_in_ids_start: Optional[torch.LongTensor] # shape (num_audio_in,)
39
+ label_ids: Optional[torch.LongTensor] # shape (bsz, seq_len)
40
+ label_audio_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
41
+ reward: Optional[float] = None
42
+
43
+
44
+ class HiggsAudioSampleCollator:
45
+ """Sample collator for Higgs-Audio model.
46
+
47
+ Args:
48
+ whisper_processor (WhisperProcessor): The whisper processor.
49
+ audio_in_token_id (int): The token id for audio-in.
50
+ audio_out_token_id (int): The token id for audio-out.
51
+ pad_token_id (int): The token id for padding.
52
+ audio_stream_bos_id (int): The token id for audio-stream beginning of sentence.
53
+ audio_stream_eos_id (int): The token id for audio-stream end of sentence.
54
+ round_to (int): The round-to value.
55
+ pad_left (bool): Whether to pad left.
56
+ return_audio_in_tokens (bool): Whether to return audio-in tokens.
57
+ use_delay_pattern (bool): Whether to use delay pattern.
58
+ disable_audio_codes_transform (bool): Whether to add bos and eos tokens to audio codes.
59
+ chunk_size_seconds (int): The chunk size in seconds.
60
+ add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
61
+ mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>.
62
+
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ whisper_processor: WhisperProcessor,
68
+ audio_in_token_id,
69
+ audio_out_token_id,
70
+ pad_token_id,
71
+ audio_stream_bos_id,
72
+ audio_stream_eos_id,
73
+ round_to=8,
74
+ pad_left=False,
75
+ encode_whisper_embed=True,
76
+ return_audio_in_tokens=True,
77
+ audio_num_codebooks=None,
78
+ use_delay_pattern=False,
79
+ disable_audio_codes_transform=False,
80
+ chunk_size_seconds=30, # Maximum duration for each chunk
81
+ add_new_bos_eos_for_long_chunk=True,
82
+ mask_audio_out_token_label=True,
83
+ ):
84
+ self.whisper_processor = whisper_processor
85
+ self.round_to = round_to
86
+ self.pad_left = pad_left
87
+ self.audio_in_token_id = audio_in_token_id
88
+ self.audio_out_token_id = audio_out_token_id
89
+ self.audio_stream_bos_id = audio_stream_bos_id
90
+ self.audio_stream_eos_id = audio_stream_eos_id
91
+ self.pad_token_id = pad_token_id
92
+ self.encode_whisper_embed = encode_whisper_embed
93
+ self.return_audio_in_tokens = return_audio_in_tokens
94
+ self.audio_num_codebooks = audio_num_codebooks
95
+ self.use_delay_pattern = use_delay_pattern
96
+ if encode_whisper_embed:
97
+ self.chunk_size_seconds = chunk_size_seconds
98
+ self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
99
+ else:
100
+ self.chunk_size_seconds = None
101
+ self.chunk_size_samples = None
102
+ self.disable_audio_codes_transform = disable_audio_codes_transform
103
+ self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
104
+ self.mask_audio_out_token_label = mask_audio_out_token_label
105
+
106
+ def _process_and_duplicate_audio_tokens(
107
+ self,
108
+ input_ids: torch.Tensor,
109
+ audio_idx: int,
110
+ wv: torch.Tensor,
111
+ sr: int,
112
+ labels: Optional[torch.Tensor] = None,
113
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
114
+ """Process long audio and duplicate corresponding audio tokens.
115
+
116
+ Args:
117
+ input_ids: Input token ids
118
+ audio_idx: Index of the audio token in the sequence
119
+ wv: Audio waveform
120
+ sr: Sample rate
121
+ labels: Optional label ids to be duplicated alongside input ids
122
+
123
+ Returns:
124
+ Tuple of:
125
+ - New input ids with duplicated audio tokens
126
+ - New label ids (if labels were provided) or None
127
+ - Number of chunks created
128
+ """
129
+ # Calculate number of chunks needed
130
+ total_samples = len(wv)
131
+ num_chunks = math.ceil(total_samples / self.chunk_size_samples)
132
+
133
+ if num_chunks <= 1:
134
+ return input_ids, labels, 1
135
+
136
+ # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
137
+ audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
138
+ # Duplicate sequence for each chunk
139
+ duplicated_sequence = audio_token_seq.repeat(num_chunks)
140
+
141
+ # Create new input_ids with duplicated tokens
142
+ new_input_ids = torch.cat(
143
+ [
144
+ input_ids[: audio_idx - 1],
145
+ duplicated_sequence,
146
+ input_ids[audio_idx + 2 :],
147
+ ]
148
+ )
149
+
150
+ # If labels are provided, duplicate them as well
151
+ new_labels = None
152
+ if labels is not None:
153
+ label_seq = labels[audio_idx - 1 : audio_idx + 2]
154
+ duplicated_labels = label_seq.repeat(num_chunks)
155
+ new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])
156
+
157
+ return new_input_ids, new_labels, num_chunks
158
+
159
+ def __call__(self, batch: List[ChatMLDatasetSample]):
160
+ """Collate the input data with support for long audio processing."""
161
+
162
+ label_ids = None
163
+ label_audio_ids = None
164
+ if all([ele.label_ids is None for ele in batch]):
165
+ return_labels = False
166
+ else:
167
+ return_labels = True
168
+
169
+ if self.encode_whisper_embed:
170
+ # Process each sample in the batch to handle long audio
171
+ # TODO(?) The implementation here can be optimized.
172
+ processed_batch = []
173
+ for i in range(len(batch)):
174
+ sample = batch[i]
175
+ audio_in_mask = sample.input_ids == self.audio_in_token_id
176
+ audio_in_indices = torch.where(audio_in_mask)[0]
177
+ audio_out_mask = sample.input_ids == self.audio_out_token_id
178
+
179
+ # Process each audio token and duplicate if needed
180
+ modified_input_ids = sample.input_ids
181
+ modified_labels = sample.label_ids if return_labels else None
182
+ modified_waveforms_concat = []
183
+ modified_waveforms_start = []
184
+ modified_sample_rate = []
185
+ offset = 0 # Track position changes from duplicating tokens
186
+ curr_wv_offset = 0
187
+
188
+ # Process input audio tokens
189
+ for idx, audio_idx in enumerate(audio_in_indices):
190
+ # Get the audio for this token
191
+ wv, sr = sample.get_wv(idx) # Use idx since we want the original audio index
192
+ if sr != self.whisper_processor.feature_extractor.sampling_rate:
193
+ resampled_wv = librosa.resample(
194
+ wv.cpu().numpy(),
195
+ orig_sr=sr,
196
+ target_sr=self.whisper_processor.feature_extractor.sampling_rate,
197
+ )
198
+ else:
199
+ resampled_wv = wv.cpu().numpy()
200
+ wv = torch.tensor(resampled_wv, device=wv.device)
201
+ sr = self.whisper_processor.feature_extractor.sampling_rate
202
+
203
+ # Process and duplicate tokens if necessary
204
+ token_pos = audio_idx + offset
205
+ modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
206
+ modified_input_ids, token_pos, wv, sr, modified_labels
207
+ )
208
+
209
+ # Update audio data
210
+ for chunk_idx in range(num_chunks):
211
+ chunk_start = chunk_idx * self.chunk_size_samples
212
+ chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
213
+ chunk_wv = wv[chunk_start:chunk_end]
214
+ modified_waveforms_concat.append(chunk_wv)
215
+ modified_waveforms_start.append(curr_wv_offset)
216
+ curr_wv_offset += len(chunk_wv)
217
+ modified_sample_rate.append(sr)
218
+
219
+ # Update offset for next iteration
220
+ offset += (num_chunks - 1) * 3 # Each new chunk adds 3 more tokens
221
+
222
+ # Create new sample with modified tokens and audio data
223
+ processed_sample = ChatMLDatasetSample(
224
+ input_ids=modified_input_ids,
225
+ label_ids=modified_labels if return_labels else sample.label_ids,
226
+ audio_ids_concat=sample.audio_ids_concat,
227
+ audio_ids_start=sample.audio_ids_start,
228
+ audio_waveforms_concat=torch.cat(modified_waveforms_concat)
229
+ if modified_waveforms_concat
230
+ else sample.audio_waveforms_concat,
231
+ audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
232
+ if modified_waveforms_start
233
+ else sample.audio_waveforms_start,
234
+ audio_sample_rate=torch.tensor(modified_sample_rate)
235
+ if modified_sample_rate
236
+ else sample.audio_sample_rate,
237
+ audio_speaker_indices=torch.tensor([]),
238
+ # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
239
+ audio_label_ids_concat=sample.audio_label_ids_concat,
240
+ )
241
+ # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
242
+ # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
243
+ processed_batch.append(processed_sample)
244
+ else:
245
+ processed_batch = batch
246
+
247
+ # Get the max sequence length based on processed batch
248
+ max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)
249
+
250
+ # Get the ids for audio-in and audio-out for each batch
251
+ audio_in_wv_l = []
252
+ audio_in_ids_l = []
253
+ audio_out_ids_l = []
254
+ audio_out_ids_group_loc_l = []
255
+ audio_in_label_ids_l = None
256
+ audio_out_label_ids_l = None
257
+ reward_l = []
258
+
259
+ if return_labels:
260
+ audio_out_no_train_flag = [] # Whether the audio-out data should be trained on or not.
261
+
262
+ # Process the audio inputs and outputs
263
+ for i in range(len(processed_batch)):
264
+ audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
265
+ audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
266
+ audio_ids = torch.ones_like(processed_batch[i].input_ids)
267
+ audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
268
+ audio_in_ids = audio_ids[audio_in_mask]
269
+ audio_out_ids = audio_ids[audio_out_mask]
270
+
271
+ if return_labels:
272
+ audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
273
+ if self.mask_audio_out_token_label:
274
+ processed_batch[i].label_ids[audio_out_mask] = -100
275
+
276
+ # Process audio inputs
277
+ if self.return_audio_in_tokens:
278
+ audio_in_ids_l.extend(
279
+ [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
280
+ )
281
+ if processed_batch[i].audio_label_ids_concat is not None:
282
+ if audio_in_label_ids_l is None:
283
+ audio_in_label_ids_l = []
284
+ audio_in_label_ids_l.extend(
285
+ [
286
+ processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
287
+ for idx in audio_in_ids
288
+ ]
289
+ )
290
+
291
+ audio_out_ids_l.extend(
292
+ [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
293
+ )
294
+ audio_out_ids_group_loc_l.append(i)
295
+ if processed_batch[i].reward is not None:
296
+ reward_l.append(processed_batch[i].reward)
297
+
298
+ if processed_batch[i].audio_label_ids_concat is not None:
299
+ if audio_out_label_ids_l is None:
300
+ audio_out_label_ids_l = []
301
+ audio_out_label_ids_l.extend(
302
+ [
303
+ processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
304
+ for idx in audio_out_ids
305
+ ]
306
+ )
307
+
308
+ if self.encode_whisper_embed:
309
+ for idx in audio_in_ids:
310
+ wv, sr = processed_batch[i].get_wv(idx)
311
+ resampled_wv = wv.cpu().numpy()
312
+ # Split long audio into chunks
313
+ total_samples = len(resampled_wv)
314
+ for chunk_start in range(0, total_samples, self.chunk_size_samples):
315
+ chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
316
+ chunk = resampled_wv[chunk_start:chunk_end]
317
+ audio_in_wv_l.append(chunk)
318
+ # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
319
+ # f"Assertion failed: Mismatch in number of audios. " \
320
+ # f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."
321
+
322
+ if return_labels:
323
+ audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
324
+
325
+ # Process all audio features
326
+ if len(audio_in_wv_l) > 0:
327
+ feature_ret = self.whisper_processor.feature_extractor(
328
+ audio_in_wv_l,
329
+ sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
330
+ return_attention_mask=True,
331
+ padding="max_length",
332
+ )
333
+ audio_features = torch.from_numpy(feature_ret["input_features"])
334
+ audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
335
+ else:
336
+ if self.encode_whisper_embed:
337
+ audio_features = torch.zeros(
338
+ (
339
+ 0,
340
+ self.whisper_processor.feature_extractor.feature_size,
341
+ self.whisper_processor.feature_extractor.nb_max_frames,
342
+ ),
343
+ dtype=torch.float32,
344
+ )
345
+ audio_feature_attention_mask = torch.zeros(
346
+ (0, self.whisper_processor.feature_extractor.nb_max_frames),
347
+ dtype=torch.int32,
348
+ )
349
+ else:
350
+ audio_features = None
351
+ audio_feature_attention_mask = None
352
+
353
+ # Process audio input tokens
354
+ if len(audio_in_ids_l) > 0:
355
+ # Append audio-stream-bos and eos tokens
356
+ new_audio_in_ids_l = []
357
+ for ele in audio_in_ids_l:
358
+ if self.disable_audio_codes_transform:
359
+ # Do not add audio-stream-bos or eos tokens.
360
+ # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
361
+ audio_codes = ele
362
+ else:
363
+ audio_codes = torch.cat(
364
+ [
365
+ torch.full(
366
+ (ele.shape[0], 1),
367
+ self.audio_stream_bos_id,
368
+ dtype=torch.long,
369
+ ),
370
+ ele,
371
+ torch.full(
372
+ (ele.shape[0], 1),
373
+ self.audio_stream_eos_id,
374
+ dtype=torch.long,
375
+ ),
376
+ ],
377
+ dim=1,
378
+ )
379
+ if self.use_delay_pattern:
380
+ audio_codes = build_delay_pattern_mask(
381
+ audio_codes.unsqueeze(0),
382
+ bos_token_id=self.audio_stream_bos_id,
383
+ pad_token_id=self.audio_stream_eos_id,
384
+ )[0].squeeze(0)
385
+ new_audio_in_ids_l.append(audio_codes)
386
+ audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
387
+ audio_in_ids_start = torch.cumsum(
388
+ torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]),
389
+ dim=0,
390
+ )
391
+ else:
392
+ audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
393
+ audio_in_ids_start = torch.zeros(0, dtype=torch.long)
394
+
395
+ # Process audio output tokens
396
+ audio_out_ids_start_group_loc = None
397
+ if len(audio_out_ids_l) > 0:
398
+ new_audio_out_ids_l = []
399
+ label_audio_ids_l = []
400
+ for idx, ele in enumerate(audio_out_ids_l):
401
+ if self.disable_audio_codes_transform:
402
+ # Do not add audio-stream-bos or eos tokens.
403
+ # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
404
+ audio_codes = ele
405
+ if return_labels:
406
+ label_audio_ids = audio_out_label_ids_l[idx]
407
+ else:
408
+ audio_codes = torch.cat(
409
+ [
410
+ torch.full(
411
+ (ele.shape[0], 1),
412
+ self.audio_stream_bos_id,
413
+ dtype=torch.long,
414
+ ),
415
+ ele,
416
+ torch.full(
417
+ (ele.shape[0], 1),
418
+ self.audio_stream_eos_id,
419
+ dtype=torch.long,
420
+ ),
421
+ ],
422
+ dim=1,
423
+ )
424
+ if return_labels:
425
+ label_audio_ids = torch.cat(
426
+ [
427
+ torch.full((ele.shape[0], 1), -100, dtype=torch.long),
428
+ ele,
429
+ torch.full(
430
+ (ele.shape[0], 1),
431
+ self.audio_stream_eos_id,
432
+ dtype=torch.long,
433
+ ),
434
+ ],
435
+ dim=1,
436
+ )
437
+ if self.use_delay_pattern:
438
+ audio_codes = build_delay_pattern_mask(
439
+ audio_codes.unsqueeze(0),
440
+ bos_token_id=self.audio_stream_bos_id,
441
+ pad_token_id=self.audio_stream_eos_id,
442
+ )[0].squeeze(0)
443
+ if return_labels:
444
+ label_audio_ids = build_delay_pattern_mask(
445
+ label_audio_ids.unsqueeze(0),
446
+ bos_token_id=-100,
447
+ pad_token_id=-100,
448
+ )[0].squeeze(0)
449
+ new_audio_out_ids_l.append(audio_codes)
450
+
451
+ if return_labels:
452
+ if audio_out_no_train_flag[idx]:
453
+ label_audio_ids[:] = -100
454
+ label_audio_ids_l.append(label_audio_ids)
455
+
456
+ audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
457
+ if return_labels:
458
+ label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
459
+ audio_out_ids_start = torch.cumsum(
460
+ torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]),
461
+ dim=0,
462
+ )
463
+ audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
464
+ else:
465
+ audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
466
+ audio_out_ids_start = torch.zeros(0, dtype=torch.long)
467
+ if return_labels:
468
+ label_audio_ids = torch.zeros((0, 0), dtype=torch.long)
469
+
470
+ reward = torch.tensor(reward_l, dtype=torch.float32)
471
+
472
+ # Handle padding for input ids and attention mask
473
+ if self.pad_left:
474
+ input_ids = torch.stack(
475
+ [
476
+ F.pad(
477
+ ele.input_ids,
478
+ (max_seq_length - len(ele.input_ids), 0),
479
+ value=self.pad_token_id,
480
+ )
481
+ for ele in processed_batch
482
+ ]
483
+ )
484
+ if return_labels:
485
+ label_ids = torch.stack(
486
+ [
487
+ F.pad(
488
+ ele.label_ids,
489
+ (max_seq_length - len(ele.label_ids), 0),
490
+ value=-100,
491
+ )
492
+ for ele in processed_batch
493
+ ]
494
+ )
495
+ attention_mask = torch.stack(
496
+ [
497
+ F.pad(
498
+ torch.ones_like(ele.input_ids),
499
+ (max_seq_length - len(ele.input_ids), 0),
500
+ value=0,
501
+ )
502
+ for ele in processed_batch
503
+ ]
504
+ )
505
+ else:
506
+ input_ids = torch.stack(
507
+ [
508
+ F.pad(
509
+ ele.input_ids,
510
+ (0, max_seq_length - len(ele.input_ids)),
511
+ value=self.pad_token_id,
512
+ )
513
+ for ele in processed_batch
514
+ ]
515
+ )
516
+ if return_labels:
517
+ label_ids = torch.stack(
518
+ [
519
+ F.pad(
520
+ ele.label_ids,
521
+ (0, max_seq_length - len(ele.label_ids)),
522
+ value=-100,
523
+ )
524
+ for ele in processed_batch
525
+ ]
526
+ )
527
+ attention_mask = torch.stack(
528
+ [
529
+ F.pad(
530
+ torch.ones_like(ele.input_ids),
531
+ (0, max_seq_length - len(ele.input_ids)),
532
+ value=0,
533
+ )
534
+ for ele in processed_batch
535
+ ]
536
+ )
537
+
538
+ if not self.return_audio_in_tokens:
539
+ audio_in_ids = None
540
+ audio_in_ids_start = None
541
+
542
+ # Apply audio_num_codebooks limit if specified
543
+ if self.audio_num_codebooks is not None:
544
+ if audio_in_ids is not None:
545
+ audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
546
+ if audio_out_ids is not None:
547
+ audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
548
+ if label_audio_ids is not None:
549
+ label_audio_ids = label_audio_ids[: self.audio_num_codebooks]
550
+
551
+ return HiggsAudioBatchInput(
552
+ input_ids=input_ids,
553
+ attention_mask=attention_mask,
554
+ audio_features=audio_features,
555
+ audio_feature_attention_mask=audio_feature_attention_mask,
556
+ audio_out_ids=audio_out_ids,
557
+ audio_out_ids_start=audio_out_ids_start,
558
+ audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
559
+ audio_in_ids=audio_in_ids,
560
+ audio_in_ids_start=audio_in_ids_start,
561
+ label_ids=label_ids,
562
+ label_audio_ids=label_audio_ids,
563
+ reward=reward,
564
+ )
565
+
566
+
567
+ class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator):
568
+ def __init__(self, *args, **kwargs):
569
+ super().__init__(*args, **kwargs)
570
+
571
+ def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput:
572
+ # flatten ranked chatml samples
573
+ chosen = []
574
+ rejected = []
575
+
576
+ for sample in batch:
577
+ chosen.append(sample.max_score_sample())
578
+ rejected.append(sample.min_score_sample())
579
+
580
+ merged = chosen
581
+ merged.extend(rejected)
582
+
583
+ return super().__call__(batch=merged)
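The long-audio handling in _process_and_duplicate_audio_tokens reduces to simple arithmetic: the waveform is cut into chunk_size_seconds windows, the three-token span <|audio_bos|><|AUDIO|><|audio_eos|> is repeated once per chunk, and the running token offset grows by (num_chunks - 1) * 3. An illustrative calculation (the 16 kHz rate matches Whisper's feature extractor; the clip length is made up):

import math

sampling_rate = 16_000
chunk_size_seconds = 30
chunk_size_samples = chunk_size_seconds * sampling_rate

total_samples = 75 * sampling_rate                           # a 75 s clip
num_chunks = math.ceil(total_samples / chunk_size_samples)   # -> 3
extra_tokens = (num_chunks - 1) * 3                          # -> 6 extra tokens
print(num_chunks, extra_tokens)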
higgs_audio/data_types.py ADDED
@@ -0,0 +1,38 @@
1
+ """Basic data types for multimodal ChatML format."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Optional, Union
5
+
6
+
7
+ @dataclass
8
+ class AudioContent:
9
+ audio_url: str
10
+ # Base64 encoded audio bytes
11
+ raw_audio: Optional[str] = None
12
+ offset: Optional[float] = None
13
+ duration: Optional[float] = None
14
+ row_id: Optional[int] = None
15
+ type: str = "audio"
16
+
17
+
18
+ @dataclass
19
+ class TextContent:
20
+ text: str
21
+ type: str = "text"
22
+
23
+
24
+ @dataclass
25
+ class Message:
26
+ role: str
27
+ content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
28
+ recipient: Optional[str] = None
29
+
30
+
31
+ @dataclass
32
+ class ChatMLSample:
33
+ """Dataclass to hold multimodal ChatML data."""
34
+
35
+ messages: List[Message]
36
+ start_index: Optional[int] = None # We will mask the messages[:start_index] when finetuning the LLM.
37
+ misc: Optional[Dict] = None
38
+ speaker: Optional[str] = None
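A small, hypothetical example of assembling a multimodal ChatML sample with these dataclasses (the URL and text are placeholders, not real data):

sample = ChatMLSample(
    messages=[
        Message(role="system", content="You are a voice assistant."),
        Message(
            role="user",
            content=[
                TextContent(text="Please transcribe the following clip."),
                AudioContent(audio_url="https://example.com/clip.wav"),
            ],
        ),
        Message(role="assistant", content=TextContent(text="Sure, here is the transcript...")),
    ],
    start_index=1,  # mask messages[:1] (the system prompt) when finetuning
)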
higgs_audio/dataset/__init__.py ADDED
File without changes
higgs_audio/dataset/chatml_dataset.py ADDED
@@ -0,0 +1,554 @@
1
+ import dacite
2
+ import pandas as pd
3
+ import torch
4
+ import json
5
+
6
+ import numpy as np
7
+ import multiprocessing as mp
8
+
9
+ from dataclasses import dataclass, fields
10
+ from abc import ABC, abstractmethod
11
+ from typing import Union, List, Dict, Optional
12
+
13
+ from ..data_types import ChatMLSample, TextContent, AudioContent
14
+ from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN
15
+
16
+ from loguru import logger
17
+
18
+ # Whisper processor, 30 sec -> 3000 features
19
+ # Then we downsample by 4 in the audio tokenizer, which reduces the 3000 features to 750, i.e. 25 Hz
20
+ WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25
21
+
22
+
23
+ @dataclass
24
+ class ChatMLDatasetSample:
25
+ input_ids: torch.LongTensor # Shape (seq_len,): The input text tokens.
26
+ label_ids: torch.LongTensor # Shape (seq_len,): The label ids.
27
+ audio_ids_concat: torch.LongTensor # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
28
+ # Here `audio_seq_len` is the length of the concatenated audio tokens.
29
+ audio_ids_start: (
30
+ torch.LongTensor
31
+ ) # Shape (num_audios,): The start index of each audio token in the concatenated audio tokens.
32
+ audio_waveforms_concat: (
33
+ torch.Tensor
34
+ ) # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features.
35
+ audio_waveforms_start: (
36
+ torch.LongTensor
37
+ ) # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms.
38
+ audio_sample_rate: torch.Tensor # Shape (num_audios,): The sampling rate of the audio waveforms.
39
+ audio_speaker_indices: (
40
+ torch.LongTensor
41
+ ) # Shape (num_audios,): The speaker indices for each audio; -1 means unknown speaker.
42
+ audio_label_ids_concat: Optional[torch.LongTensor] = (
43
+ None # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
44
+ )
45
+ # Here `audio_seq_len` is the length of the concatenated audio tokens.
46
+ reward: Optional[float] = None
47
+
48
+ def num_audios(self):
49
+ return max(len(self.audio_waveforms_start), len(self.audio_ids_start))
50
+
51
+ def get_audio_codes(self, idx):
52
+ code_start = self.audio_ids_start[idx]
53
+ if idx < len(self.audio_ids_start) - 1:
54
+ code_end = self.audio_ids_start[idx + 1]
55
+ else:
56
+ code_end = self.audio_ids_concat.shape[-1]
57
+
58
+ return self.audio_ids_concat[:, code_start:code_end]
59
+
60
+ def get_audio_codes_labels(self, idx):
61
+ if self.audio_label_ids_concat is None:
62
+ return None
63
+ code_start = self.audio_ids_start[idx]
64
+ if idx < len(self.audio_ids_start) - 1:
65
+ code_end = self.audio_ids_start[idx + 1]
66
+ else:
67
+ code_end = self.audio_ids_concat.shape[-1]
68
+
69
+ return self.audio_label_ids_concat[:, code_start:code_end]
70
+
71
+ def get_wv(self, idx):
72
+ wv_start = self.audio_waveforms_start[idx]
73
+ sr = self.audio_sample_rate[idx]
74
+ if idx < len(self.audio_waveforms_start) - 1:
75
+ wv_end = self.audio_waveforms_start[idx + 1]
76
+ else:
77
+ wv_end = self.audio_waveforms_concat.shape[-1]
78
+ return self.audio_waveforms_concat[wv_start:wv_end], sr
79
+
80
+ def cal_num_tokens(
81
+ self,
82
+ encode_whisper_embed: bool = True,
83
+ encode_audio_in_tokens: bool = False,
84
+ encode_audio_out_tokens: bool = True,
85
+ audio_in_token_id: int = 128015,
86
+ audio_out_token_id: int = 128016,
87
+ ) -> int:
88
+ # We first exclude <|AUDIO|> and <|AUDIO_OUT|> because we do late merging and replace those positions with the actual audio features and audio token ids
89
+ # It's assumed that we always have audio_ids when audio_waveforms are there (but not vice-versa)
90
+ num_tokens = len(self.input_ids) - len(self.audio_ids_start)
91
+
92
+ if encode_whisper_embed and len(self.audio_waveforms_concat) > 0:
93
+ audio_lengths = torch.diff(self.audio_waveforms_start)
94
+ if len(audio_lengths):
95
+ # Sum before calling .item()
96
+ num_tokens += (
97
+ (
98
+ np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1])
99
+ ).sum()
100
+ ).item()
101
+ # add the last audio's token estimation
102
+ num_tokens += (
103
+ np.ceil(
104
+ WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC
105
+ * (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1])
106
+ / self.audio_sample_rate[-1]
107
+ )
108
+ ).item()
109
+
110
+ if self.audio_ids_concat.size(1) > 0:
111
+ audio_io_ids = self.input_ids[
112
+ (self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id)
113
+ ]
114
+ audio_io_id_lengths = torch.concat(
115
+ [
116
+ torch.diff(self.audio_ids_start),
117
+ torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]),
118
+ ]
119
+ )
120
+ if encode_audio_in_tokens:
121
+ num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item()
122
+
123
+ if encode_audio_out_tokens:
124
+ num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item()
125
+
126
+ return int(num_tokens)
127
+
128
+ @classmethod
129
+ def merge(
130
+ cls,
131
+ samples: List["ChatMLDatasetSample"],
132
+ eos_token_id: int,
133
+ ignore_index: int,
134
+ padding_size: Optional[int] = None,
135
+ ) -> "ChatMLDatasetSample":
136
+ """Merges a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them, and adjusting offsets for audio_ids_start and audio_waveforms_start.
137
+
138
+ Args:
139
+ samples (List[ChatMLDatasetSample]): List of samples to merge.
140
+ eos_token_id (int): Token inserted into input_ids between samples.
141
+ ignore_index (int): Default label for padding.
142
+ padding_size (Optional[int]): If provided, pad the sequence with this many extra tokens.
143
+
144
+ Returns:
145
+ ChatMLDatasetSample: Merged and potentially padded sample.
146
+ """
147
+ if not samples:
148
+ logger.fatal("The samples list is empty and cannot be merged.")
149
+ raise ValueError("The samples list is empty and cannot be merged.")
150
+
151
+ # Initialize empty lists for concatenation
152
+ input_ids_list = []
153
+ label_ids_list = []
154
+ audio_ids_concat_list = []
155
+ audio_ids_start_list = []
156
+ audio_waveforms_concat_list = []
157
+ audio_waveforms_start_list = []
158
+ audio_sample_rate_list = []
159
+ audio_speaker_indices_list = []
160
+
161
+ # Track offsets
162
+ audio_ids_offset = 0
163
+ audio_waveforms_offset = 0
164
+
165
+ for sample in samples:
166
+ # Add input_ids and label_ids with padding
167
+ if input_ids_list:
168
+ input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long))
169
+ label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long))
170
+ input_ids_list.append(sample.input_ids)
171
+ label_ids_list.append(sample.label_ids)
172
+
173
+ # Add audio_ids_concat and handle empty audio ids
174
+ if sample.audio_ids_concat.size(1) > 0:
175
+ audio_ids_concat_list.append(sample.audio_ids_concat)
176
+
177
+ # Offset and add audio_ids_start
178
+ audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset)
179
+ audio_ids_offset += sample.audio_ids_concat.size(
180
+ 1
181
+ ) # (num_codebooks, seq_len): Update offset by audio_seq_len
182
+
183
+ # Add audio_waveforms_concat
184
+ if sample.audio_waveforms_concat.size(0) > 0:
185
+ # Check dimensions of the audio waveform to ensure consistency
186
+ if (
187
+ audio_waveforms_concat_list
188
+ and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim()
189
+ ):
190
+ logger.warning(
191
+ f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D"
192
+ )
193
+ continue
194
+
195
+ audio_waveforms_concat_list.append(sample.audio_waveforms_concat)
196
+ audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset)
197
+ audio_waveforms_offset += sample.audio_waveforms_concat.size(0)
198
+
199
+ # Add audio_sample_rate and audio_speaker_indices
200
+ audio_sample_rate_list.append(sample.audio_sample_rate)
201
+
202
+ audio_speaker_indices_list.append(sample.audio_speaker_indices)
203
+
204
+ # Concatenate all tensors
205
+ input_ids = torch.cat(input_ids_list, dim=0)
206
+ label_ids = torch.cat(label_ids_list, dim=0)
207
+
208
+ # Apply padding if padding_size is specified
209
+ if padding_size is not None and padding_size > 0:
210
+ input_ids = torch.cat(
211
+ [
212
+ input_ids,
213
+ torch.full((padding_size,), eos_token_id, dtype=torch.long),
214
+ ],
215
+ dim=0,
216
+ )
217
+ label_ids = torch.cat(
218
+ [
219
+ label_ids,
220
+ torch.full((padding_size,), ignore_index, dtype=torch.long),
221
+ ],
222
+ dim=0,
223
+ )
224
+
225
+ # Safely concatenate audio tensors with proper error handling
226
+ try:
227
+ audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]])
228
+ audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([])
229
+
230
+ # Check for dimensional consistency in audio waveforms
231
+ if audio_waveforms_concat_list:
232
+ dims = [t.dim() for t in audio_waveforms_concat_list]
233
+ if not all(d == dims[0] for d in dims):
234
+ # If dimensions don't match, log warning and filter out the problematic tensors
235
+ logger.warning(
236
+ f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones."
237
+ )
238
+ expected_dim = max(set(dims), key=dims.count) # Most common dimension
239
+ audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim]
240
+
241
+ # Recalculate audio_waveforms_start with the filtered list
242
+ if audio_waveforms_concat_list:
243
+ audio_waveforms_offset = 0
244
+ audio_waveforms_start_list = []
245
+ for waveform in audio_waveforms_concat_list:
246
+ audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset]))
247
+ audio_waveforms_offset += waveform.size(0)
248
+
249
+ audio_waveforms_concat = (
250
+ torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([])
251
+ )
252
+ audio_waveforms_start = (
253
+ torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([])
254
+ )
255
+ audio_sample_rate = (
256
+ torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([])
257
+ )
258
+ audio_speaker_indices = (
259
+ torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([])
260
+ )
261
+
262
+ except RuntimeError as e:
263
+ logger.error(f"Error during tensor concatenation: {str(e)}")
264
+ logger.warning("Falling back to empty audio tensors")
265
+ # Fall back to empty tensors
266
+ audio_ids_concat = torch.tensor([[]])
267
+ audio_ids_start = torch.tensor([])
268
+ audio_waveforms_concat = torch.tensor([])
269
+ audio_waveforms_start = torch.tensor([])
270
+ audio_sample_rate = torch.tensor([])
271
+ audio_speaker_indices = torch.tensor([])
272
+
273
+ # Create the merged sample
274
+ merged_sample = cls(
275
+ input_ids=input_ids,
276
+ label_ids=label_ids,
277
+ audio_ids_concat=audio_ids_concat,
278
+ audio_ids_start=audio_ids_start,
279
+ audio_waveforms_concat=audio_waveforms_concat,
280
+ audio_waveforms_start=audio_waveforms_start,
281
+ audio_sample_rate=audio_sample_rate,
282
+ audio_speaker_indices=audio_speaker_indices,
283
+ )
284
+
285
+ return merged_sample
286
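The accessors and the `merge` classmethod above rely on a concatenated-storage convention: codes for all audio clips sit side by side in `audio_ids_concat` (shape `(num_codebooks, total_len)`) and are recovered via `audio_ids_start` offsets, which is also why `merge` has to shift those offsets. Below is a minimal sketch of that convention, assuming only the constructor fields used in `merge` (all values are made up):

import torch

sample = ChatMLDatasetSample(
    input_ids=torch.tensor([1, 2, 3]),
    label_ids=torch.tensor([-100, -100, 3]),
    audio_ids_concat=torch.arange(10).reshape(2, 5),  # 2 codebooks; clip A has 3 codes, clip B has 2
    audio_ids_start=torch.tensor([0, 3]),             # clip A starts at column 0, clip B at column 3
    audio_waveforms_concat=torch.tensor([]),
    audio_waveforms_start=torch.tensor([]),
    audio_sample_rate=torch.tensor([]),
    audio_speaker_indices=torch.tensor([]),
)
assert sample.num_audios() == 2
assert sample.get_audio_codes(0).shape == (2, 3)  # columns [0, 3)
assert sample.get_audio_codes(1).shape == (2, 2)  # columns [3, 5)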
+
287
+
288
+ @dataclass
289
+ class RankedChatMLDatasetSampleTuple:
290
+ samples: List[ChatMLDatasetSample]
291
+ scores: List[float]
292
+
293
+ def max_score_sample(self) -> ChatMLDatasetSample:
294
+ idx = self.scores.index(max(self.scores))
295
+ self.samples[idx].reward = self.scores[idx]
296
+ return self.samples[idx]
297
+
298
+ def min_score_sample(self) -> ChatMLDatasetSample:
299
+ idx = self.scores.index(min(self.scores))
300
+ self.samples[idx].reward = self.scores[idx]
301
+ return self.samples[idx]
302
+
303
+
304
+ @dataclass
305
+ class ChatMLDatasetStorageSample:
306
+ input_tokens: torch.LongTensor
307
+ label_tokens: torch.LongTensor
308
+ audio_bytes_cache_dir_index: int
309
+ audio_codes_cache_dir_index: int
310
+ audio_bytes_indices: torch.LongTensor
311
+ audio_codes_indices: torch.LongTensor
312
+ speaker_indices: torch.LongTensor
313
+ file_index: int
314
+ original_sample_index: int
315
+
316
+
317
+ # TODO(sxjscience): We need to revisit the logic for parsing speaker ids.
318
+ # Currently, we assume that the speaker id is stored in the "misc" field of ChatMLSample.
319
+ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
320
+ """Preprocess the ChatML sample to get the tokens for the text part.
321
+
322
+ Args:
323
+ sample (ChatMLSample): The ChatML sample to preprocess.
324
+ tokenizer: The tokenizer to use for encoding the text.
325
+
326
+ """
327
+
328
+ try:
329
+ if not isinstance(sample, ChatMLSample):
330
+ # Handle all fields that could be NaN
331
+ if "speaker" in sample and pd.isna(sample["speaker"]):
332
+ sample["speaker"] = None
333
+ if "start_index" in sample and pd.isna(sample["start_index"]):
334
+ sample["start_index"] = None
335
+ if "content" in sample and pd.isna(sample["content"]):
336
+ sample["content"] = ""
337
+
338
+ # Convert any other potential NaN values in nested structures
339
+ def convert_nan_to_none(obj):
340
+ import numpy as np
341
+
342
+ if isinstance(obj, (pd.Series, np.ndarray)):
343
+ return obj.tolist()
344
+ elif pd.api.types.is_scalar(obj) and pd.isna(obj):
345
+ return None
346
+ elif isinstance(obj, dict):
347
+ return {k: convert_nan_to_none(v) for k, v in obj.items()}
348
+ elif isinstance(obj, (list, tuple)): # Fixed: Handle both list and tuple
349
+ return [convert_nan_to_none(item) for item in obj]
350
+ return obj
351
+
352
+ # Clean the sample data
353
+ clean_sample = convert_nan_to_none(sample)
354
+
355
+ val_keys = []
356
+ for field in fields(ChatMLSample):
357
+ if field.name in clean_sample:
358
+ val_keys.append(field.name)
359
+ clean_sample = {k: clean_sample[k] for k in val_keys}
360
+
361
+ try:
362
+ sample = dacite.from_dict(
363
+ data_class=ChatMLSample,
364
+ data=clean_sample,
365
+ config=dacite.Config(strict=True, check_types=True),
366
+ )
367
+ except Exception as e:
368
+ print(f"Failed to convert to ChatMLSample: {e}")
369
+ print(f"Clean sample: {json.dumps(clean_sample, indent=2)}")
370
+ return None, None, None, None
371
+
372
+ input_tokens = []
373
+ label_tokens = []
374
+ audio_contents = []
375
+ speaker_id = None
376
+ if sample.speaker is not None:
377
+ speaker_id = sample.speaker
378
+ elif sample.misc is not None:
379
+ if "speaker" in sample.misc:
380
+ speaker_id = sample.misc["speaker"]
381
+
382
+ total_m = len(sample.messages)
383
+ for turn_id, message in enumerate(sample.messages):
384
+ role = message.role
385
+ recipient = message.recipient
386
+ content = message.content
387
+ content_l = []
388
+
389
+ if isinstance(content, str):
390
+ content_l.append(TextContent(text=content))
391
+ elif isinstance(content, TextContent):
392
+ content_l.append(content)
393
+ elif isinstance(content, AudioContent):
394
+ content_l.append(content)
395
+ elif isinstance(content, list):
396
+ for ele in content:
397
+ if isinstance(ele, str):
398
+ content_l.append(TextContent(text=ele))
399
+ else:
400
+ content_l.append(ele)
401
+ if turn_id == 0:
402
+ prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
403
+ else:
404
+ prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
405
+ eot_postfix = "<|eot_id|>"
406
+ eom_postfix = "<|eom_id|>"
407
+
408
+ prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
409
+ input_tokens.extend(prefix_tokens)
410
+ label_tokens.extend([-100 for _ in prefix_tokens])
411
+
412
+ if recipient:
413
+ assert role == "assistant", "Recipient is only available for assistant role."
414
+ recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
415
+ input_tokens.extend(recipient_tokens)
416
+ label_tokens.extend(recipient_tokens)
417
+
418
+ for content in content_l:
419
+ if content.type == "text":
420
+ text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
421
+ input_tokens.extend(text_tokens)
422
+ if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
423
+ label_tokens.extend(text_tokens)
424
+ else:
425
+ label_tokens.extend([-100 for _ in text_tokens])
426
+
427
+ elif content.type == "audio":
428
+ # Generate the text-part of the audio tokens
429
+ audio_contents.append(content)
430
+ if role == "user" or role == "system":
431
+ # Add the text tokens
432
+ text_tokens = tokenizer.encode(
433
+ f"<|audio_bos|><|AUDIO|><|audio_eos|>",
434
+ add_special_tokens=False,
435
+ )
436
+ input_tokens.extend(text_tokens)
437
+ label_tokens.extend([-100 for _ in text_tokens])
438
+ elif role == "assistant":
439
+ # Add the text tokens for audio-out part.
440
+ text_tokens = tokenizer.encode(
441
+ f"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
442
+ add_special_tokens=False,
443
+ )
444
+ input_tokens.extend(text_tokens)
445
+ if sample.start_index is None or turn_id >= sample.start_index:
446
+ label_tokens.extend(text_tokens)
447
+ else:
448
+ label_tokens.extend([-100 for _ in text_tokens])
449
+ next_id = turn_id + 1
450
+ if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
451
+ postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
452
+ input_tokens.extend(postfix_tokens)
453
+ else:
454
+ postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
455
+ input_tokens.extend(postfix_tokens)
456
+ if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
457
+ label_tokens.extend(postfix_tokens)
458
+ else:
459
+ label_tokens.extend([-100 for _ in postfix_tokens])
460
+
461
+ return input_tokens, label_tokens, audio_contents, speaker_id
462
+
463
+ except Exception as e:
464
+ print(f"Error in prepare_chatml_sample: {str(e)}")
465
+ print(f"Sample data: {json.dumps(sample, indent=2)}")
466
+ return None, None, None, None
467
+
468
+
469
+ def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
470
+ """Extract the generation prompt and reference answer from the input tokens.
471
+
472
+ For example:
473
+
474
+ Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
475
+ What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
476
+ <|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>'
477
+
478
+ -->
479
+
480
+ Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
481
+ What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
482
+ <|start_header_id|>assistant<|end_header_id|>\n\n',
483
+ Reference = 'At first they went by quick, too quick to even get.'
484
+
485
+ Args:
486
+ input_tokens: The input tokens.
487
+ audio_contents: The audio contents.
488
+ tokenizer: The tokenizer to use for decoding the text.
489
+
490
+ Returns:
491
+ prompt_tokens: The tokens for the prompt.
492
+ reference_answer: The reference answer.
493
+ num_audios_in_reference: The number of audios in the reference answer.
494
+
495
+ """
496
+ input_text = tokenizer.decode(input_tokens)
497
+ generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
498
+ postfix = "<|eot_id|>"
499
+ assert generation_prefix in input_text
500
+ generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix)
501
+ generation_prompt = input_text[:generation_prompt_end_loc]
502
+ reference_answer = input_text[generation_prompt_end_loc : input_text.find(postfix, generation_prompt_end_loc)]
503
+ num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN)
504
+ return (
505
+ tokenizer.encode(generation_prompt, add_special_tokens=False),
506
+ reference_answer,
507
+ num_audios_in_reference,
508
+ )
509
+
510
+
511
+ def prepare_chatml_dataframe_single_process(df, tokenizer):
512
+ """Prepare the ChatML DataFrame."""
513
+ ret = []
514
+ for _, row in df.iterrows():
515
+ input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer)
516
+ ret.append((input_tokens, label_tokens, audio_contents, speaker_id))
517
+ return ret
518
+
519
+
520
+ def prepare_chatml_dataframe(df, tokenizer, num_process=16):
521
+ if num_process is None:
522
+ return prepare_chatml_dataframe_single_process(df, tokenizer)
523
+ else:
524
+ num_process = max(min(len(df) // 1000, num_process), 1)
525
+ workloads = np.array_split(df, num_process)
526
+ with mp.Pool(num_process) as pool:
527
+ ret = pool.starmap(
528
+ prepare_chatml_dataframe_single_process,
529
+ [(workload, tokenizer) for workload in workloads],
530
+ )
531
+ return sum(ret, [])
532
+
533
+
534
+ class DatasetInterface(ABC):
535
+ @abstractmethod
536
+ def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
537
+ """Retrieve a dataset sample by index."""
538
+ raise NotImplementedError
539
+
540
+
541
+ class IterableDatasetInterface(ABC):
542
+ @abstractmethod
543
+ def __iter__(
544
+ self,
545
+ ) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
546
+ """Retrieve a sample by iterating through the dataset."""
547
+ raise NotImplementedError
548
+
549
+
550
+ @dataclass
551
+ class DatasetInfo:
552
+ dataset_type: str
553
+ group_type: Optional[str] = None
554
+ mask_text: Optional[bool] = None # Whether to mask the text tokens for pretraining samples.
higgs_audio/model/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from transformers import AutoConfig, AutoModel
2
+
3
+ from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig
4
+ from .modeling_higgs_audio import HiggsAudioModel
5
+
6
+
7
+ AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig)
8
+ AutoConfig.register("higgs_audio", HiggsAudioConfig)
9
+ AutoModel.register(HiggsAudioConfig, HiggsAudioModel)
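With these registrations in place, the model can be resolved through the standard transformers Auto classes. A small hedged sketch (the checkpoint path is hypothetical):

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("path/to/higgs-audio-checkpoint")  # resolves to HiggsAudioConfig
model = AutoModel.from_config(config)                                  # resolves to HiggsAudioModel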
higgs_audio/model/audio_head.py ADDED
@@ -0,0 +1,139 @@
1
+ """Projector that maps hidden states from the LLM component to multimodal logits."""
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+
9
+ from .common import HiggsAudioPreTrainedModel
10
+ from .configuration_higgs_audio import HiggsAudioConfig
11
+
12
+
13
+ @dataclass
14
+ class HiggsAudioDecoderLayerOutput:
15
+ logits: torch.FloatTensor
16
+ audio_logits: torch.FloatTensor
17
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
18
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
19
+
20
+
21
+ class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
22
+ """Projection layers that map hidden states from the LLM component to audio / text logits.
23
+
24
+ We support two types of audio heads:
25
+ - Basic Audio Head:
26
+ Directly map the hidden states to audio logits for all the codebooks.
27
+ """
28
+
29
+ def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None):
30
+ super().__init__(config)
31
+ self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
32
+ self.audio_lm_head = nn.Linear(
33
+ config.text_config.hidden_size,
34
+ config.audio_num_codebooks * (config.audio_codebook_size + 2),
35
+ bias=False,
36
+ )
37
+
38
+ # Initialize weights and apply final processing
39
+ self.post_init()
40
+
41
+ def forward(
42
+ self,
43
+ hidden_states,
44
+ audio_out_mask,
45
+ label_audio_ids=None,
46
+ attention_mask=None,
47
+ position_ids=None,
48
+ past_key_values=None,
49
+ use_cache=None,
50
+ output_attentions=None,
51
+ output_hidden_states=None,
52
+ output_audio_hidden_states=False,
53
+ cache_position=None,
54
+ ):
55
+ """
56
+ Args:
57
+ hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
58
+ Hidden states from the LLM component
59
+ audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
60
+ Mask for identifying the audio out tokens.
61
+ label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`):
62
+ Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used.
63
+ attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
64
+ Mask to avoid performing attention on padding token indices
65
+ position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
66
+ Position ids for the input tokens
67
+
68
+ Returns:
69
+ logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`):
70
+ Logits for text tokens
71
+ audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`):
72
+ Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len`
73
+ """
74
+ logits = self.text_lm_head(hidden_states)
75
+
76
+ all_hidden_states = () if output_hidden_states else None
77
+ all_self_attns = () if output_attentions else None
78
+ next_decoder_cache = None
79
+
80
+ # TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input.
81
+ if self.config.audio_decoder_proj_num_layers > 0:
82
+ # create position embeddings to be shared across the decoder layers
83
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
84
+ for decoder_layer in self.transformer_layers:
85
+ if output_hidden_states:
86
+ all_hidden_states += (hidden_states,)
87
+
88
+ if self.gradient_checkpointing and self.training:
89
+ layer_outputs = self._gradient_checkpointing_func(
90
+ decoder_layer.__call__,
91
+ hidden_states,
92
+ attention_mask,
93
+ position_ids,
94
+ past_key_values,
95
+ output_attentions,
96
+ use_cache,
97
+ cache_position,
98
+ position_embeddings,
99
+ )
100
+ else:
101
+ layer_outputs = decoder_layer(
102
+ hidden_states,
103
+ attention_mask=attention_mask,
104
+ position_ids=position_ids,
105
+ past_key_value=past_key_values,
106
+ output_attentions=output_attentions,
107
+ use_cache=use_cache,
108
+ cache_position=cache_position,
109
+ position_embeddings=position_embeddings,
110
+ )
111
+ hidden_states = layer_outputs[0]
112
+ hidden_states = self.norm(hidden_states)
113
+
114
+ if output_hidden_states:
115
+ all_hidden_states += (hidden_states,)
116
+
117
+ if output_attentions:
118
+ all_self_attns += (layer_outputs[1],)
119
+
120
+ if use_cache:
121
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
122
+
123
+ next_cache = next_decoder_cache if use_cache else None
124
+
125
+ audio_logits = self.audio_lm_head(hidden_states[audio_out_mask])
126
+
127
+ if output_audio_hidden_states:
128
+ audio_hidden_states = hidden_states[audio_out_mask]
129
+ else:
130
+ audio_hidden_states = None
131
+
132
+ return (
133
+ logits,
134
+ audio_logits,
135
+ all_self_attns,
136
+ all_hidden_states,
137
+ audio_hidden_states,
138
+ next_cache,
139
+ )
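The routing performed by `HiggsAudioDecoderProjector` can be pictured with dummy shapes: every position receives text logits, while audio logits are produced only for the positions selected by `audio_out_mask`. A toy sketch (the sizes are illustrative, not the real config):

import torch
from torch import nn

hidden = torch.randn(2, 6, 16)                        # (batch, seq_len, hidden_size)
audio_out_mask = torch.zeros(2, 6, dtype=torch.bool)
audio_out_mask[0, 2:5] = True                         # three audio-out positions

text_lm_head = nn.Linear(16, 32, bias=False)              # vocab_size = 32
audio_lm_head = nn.Linear(16, 4 * (10 + 2), bias=False)   # 4 codebooks, codebook_size 10, +2 for stream bos/eos

logits = text_lm_head(hidden)                         # (2, 6, 32): text logits for all positions
audio_logits = audio_lm_head(hidden[audio_out_mask])  # (3, 48): audio logits only for masked positions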
higgs_audio/model/common.py ADDED
@@ -0,0 +1,27 @@
1
+ from torch import nn
2
+
3
+ from transformers.modeling_utils import PreTrainedModel
4
+
5
+ from .configuration_higgs_audio import HiggsAudioConfig
6
+
7
+
8
+ class HiggsAudioPreTrainedModel(PreTrainedModel):
9
+ config_class = HiggsAudioConfig
10
+ base_model_prefix = "model"
11
+ supports_gradient_checkpointing = True
12
+ _no_split_modules = []
13
+ _skip_keys_device_placement = "past_key_values"
14
+ _supports_flash_attn_2 = True
15
+ _supports_sdpa = True
16
+
17
+ def _init_weights(self, module):
18
+ std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_encoder_config.init_std
19
+
20
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
21
+ module.weight.data.normal_(mean=0.0, std=std)
22
+ if module.bias is not None:
23
+ module.bias.data.zero_()
24
+ elif isinstance(module, nn.Embedding):
25
+ module.weight.data.normal_(mean=0.0, std=std)
26
+ if module.padding_idx is not None:
27
+ module.weight.data[module.padding_idx].zero_()
higgs_audio/model/configuration_higgs_audio.py ADDED
@@ -0,0 +1,235 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.models.auto import CONFIG_MAPPING
3
+
4
+
5
+ class HiggsAudioEncoderConfig(PretrainedConfig):
6
+ """Configuration of the Audio encoder in Higgs-Audio."""
7
+
8
+ model_type = "higgs_audio_encoder"
9
+
10
+ def __init__(
11
+ self,
12
+ num_mel_bins=128,
13
+ encoder_layers=32,
14
+ encoder_attention_heads=20,
15
+ encoder_ffn_dim=5120,
16
+ encoder_layerdrop=0.0,
17
+ d_model=1280,
18
+ dropout=0.0,
19
+ attention_dropout=0.0,
20
+ activation_function="gelu",
21
+ activation_dropout=0.0,
22
+ scale_embedding=False,
23
+ init_std=0.02,
24
+ max_source_positions=1500,
25
+ pad_token_id=128001,
26
+ **kwargs,
27
+ ):
28
+ super().__init__(**kwargs)
29
+
30
+ self.num_mel_bins = num_mel_bins
31
+ self.d_model = d_model
32
+ self.encoder_layers = encoder_layers
33
+ self.encoder_attention_heads = encoder_attention_heads
34
+ self.encoder_ffn_dim = encoder_ffn_dim
35
+ self.dropout = dropout
36
+ self.attention_dropout = attention_dropout
37
+ self.activation_function = activation_function
38
+ self.activation_dropout = activation_dropout
39
+ self.encoder_layerdrop = encoder_layerdrop
40
+ self.num_hidden_layers = encoder_layers
41
+ self.init_std = init_std
42
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
43
+ self.max_source_positions = max_source_positions
44
+ self.pad_token_id = pad_token_id
45
+
46
+
47
+ class HiggsAudioConfig(PretrainedConfig):
48
+ r"""
49
+ This is the configuration class for the HiggsAudioModel.
50
+
51
+ Args:
52
+ text_config (`Union[AutoConfig, dict]`):
53
+ The config object or dictionary of the text backbone.
54
+ audio_encoder_config (`Union[AutoConfig, dict]`):
55
+ The config object or dictionary of the whisper encoder.
56
+ The audio encoder will be bidirectional and will be only available for audio understanding.
57
+ audio_tokenizer_config
58
+ The config object or dictionary of the audio tokenizer.
59
+ audio_adapter_type
60
+ The type of audio adapter to use. We support three types of adapters:
61
+ - stack:
62
+ We stack additional Transformer layers after the main LLM backbone for audio generation.
63
+ - dual_ffn:
64
+ For selected part of the LLM backbone, we replace the text FFN with a dual FFN architecture
65
+ that contains an additional audio FFN. The audio FFN will be triggered when the location is marked for audio tokens.
66
+ - dual_ffn_fast_forward:
67
+ We pick a few layers in the LLM backbone to plug-in the audio FFN. For the remaining layers,
68
+ the audio hidden states will be directly fast-forward to the next layer.
69
+ This reduces the computational cost for audio generation.
70
+ audio_embed_avg (`bool`, *optional*, defaults to False):
71
+ Whether to average the audio embeddings before sending them to the text attention layer.
72
+ audio_ffn_hidden_size
73
+ The hidden size of the audio feedforward network in dual-path FFN
74
+ audio_ffn_intermediate_size
75
+ The intermediate size of the audio feedforward network in dual-path FFN
76
+ audio_dual_ffn_layers
77
+ The layers in the LLM backbone to plug-in the dual FFN layer (mixture of audio FFN and text FFN).
78
+ audio_decoder_proj_num_layers (`int`, *optional*, defaults to 0):
79
+ The number of Transformer layers stacked in the audio decoder projector.
80
+ use_delay_pattern (`bool`, *optional*, defaults to False):
81
+ Whether to use delay pattern in the audio decoder.
82
+ skip_audio_tower (`bool`, *optional*, defaults to False):
83
+ Whether to skip the audio tower in the audio encoder.
84
+ use_audio_out_embed_projector (`bool`, *optional*, defaults to False):
85
+ Whether to use an embedding projector to map audio out embeddings.
86
+ use_audio_out_self_attention (`bool`, *optional*, defaults to False):
87
+ Whether to use self-attention to aggregate information from audio-tokens before sending to the text attention layer.
88
+ audio_num_codebooks (`int`, *optional*, defaults to 12):
89
+ The number of codebooks in RVQGAN.
90
+ audio_codebook_size (`int`, *optional*, defaults to 1024):
91
+ The size of each codebook in RVQGAN.
92
+ audio_stream_bos_id
93
+ The id of the bos in the audio stream
94
+ audio_stream_eos_id
95
+ The id of the eos in the audio stream
96
+ audio_bos_token (`str`, *optional*, defaults to "<|audio_bos|>"):
97
+ The special `<|audio_bos|>` token. In Higgs-Audio, it is mapped to 128011,
98
+ which is the index of `<|reserved_special_token_3|>` in Llama-3.1-8B-Instruct's tokenizer.
99
+ audio_eos_token (`str`, *optional*, defaults to "<|audio_eos|>"):
100
+ The special `<|audio_eos|>` token. We use 128012 as the default value,
101
+ which is the index of `<|reserved_special_token_4|>` in Llama-3.1-8B-Instruct's tokenizer.
102
+ audio_out_bos_token (`str`, *optional*, defaults to "<|audio_out_bos|>"):
103
+ The special `<|audio_out_bos|>` token. We use 128013 as the default value,
104
+ which is the index of `<|reserved_special_token_5|>` in Llama-3.1-8B-Instruct's tokenizer.
105
+ audio_token (`str`, *optional*, defaults to "<|AUDIO|>"):
106
+ The special `<|AUDIO|>` token. We use 128015 as the default value,
107
+ which is the index of `<|reserved_special_token_7|>` in Llama-3.1-8B-Instruct's tokenizer.
108
+ This token indicates that the location should be filled in with whisper features.
109
+ audio_out_token (`str`, *optional*, defaults to "<|AUDIO_OUT|>"):
110
+ The special `<|AUDIO_OUT|>` token. We use 128016 as the default value,
111
+ which is the index of `<|reserved_special_token_8|>` in Llama-3.1-8B-Instruct's tokenizer.
112
+ This token indicates that the location should be filled in with audio tokens extracted via audio tokenizer.
113
+ """
114
+
115
+ model_type = "higgs_audio"
116
+ is_composition = True
117
+
118
+ def __init__(
119
+ self,
120
+ text_config=None,
121
+ audio_encoder_config=None,
122
+ audio_tokenizer_config=None,
123
+ audio_adapter_type="stack",
124
+ audio_embed_avg=False,
125
+ audio_ffn_hidden_size=4096,
126
+ audio_ffn_intermediate_size=14336,
127
+ audio_dual_ffn_layers=None,
128
+ audio_decoder_proj_num_layers=0,
129
+ encode_whisper_embed=True,
130
+ encode_audio_in_tokens=False,
131
+ use_delay_pattern=False,
132
+ skip_audio_tower=False,
133
+ use_audio_out_embed_projector=False,
134
+ use_audio_out_self_attention=False,
135
+ use_rq_transformer=False,
136
+ rq_transformer_hidden_size=None,
137
+ rq_transformer_intermediate_size=None,
138
+ rq_transformer_num_attention_heads=None,
139
+ rq_transformer_num_key_value_heads=None,
140
+ rq_transformer_num_hidden_layers=3,
141
+ audio_num_codebooks=12,
142
+ audio_codebook_size=1024,
143
+ audio_stream_bos_id=1024,
144
+ audio_stream_eos_id=1025,
145
+ audio_bos_token="<|audio_bos|>",
146
+ audio_eos_token="<|audio_eos|>",
147
+ audio_out_bos_token="<|audio_out_bos|>",
148
+ audio_in_token="<|AUDIO|>",
149
+ audio_out_token="<|AUDIO_OUT|>",
150
+ audio_in_token_idx=128015,
151
+ audio_out_token_idx=128016,
152
+ pad_token_id=128001,
153
+ audio_out_bos_token_id=128013,
154
+ audio_eos_token_id=128012,
155
+ **kwargs,
156
+ ):
157
+ if isinstance(audio_encoder_config, dict):
158
+ audio_encoder_config["model_type"] = (
159
+ audio_encoder_config["model_type"] if "model_type" in audio_encoder_config else "higgs_audio_encoder"
160
+ )
161
+ audio_encoder_config = CONFIG_MAPPING[audio_encoder_config["model_type"]](**audio_encoder_config)
162
+ elif audio_encoder_config is None:
163
+ audio_encoder_config = HiggsAudioEncoderConfig()
164
+
165
+ if isinstance(text_config, dict):
166
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
167
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
168
+ elif text_config is None:
169
+ text_config = CONFIG_MAPPING["llama"]()
170
+
171
+ assert audio_adapter_type in [
172
+ "stack",
173
+ "dual_ffn",
174
+ "dual_ffn_fast_forward",
175
+ ], f"Invalid audio adapter type: {audio_adapter_type}"
176
+ if audio_adapter_type.startswith("dual_ffn"):
177
+ assert audio_dual_ffn_layers is not None, (
178
+ "audio_dual_ffn_layers must be specified when using dual_ffn adapter."
179
+ )
180
+ self.text_config = text_config
181
+ self.audio_encoder_config = audio_encoder_config
182
+ self.audio_tokenizer_config = audio_tokenizer_config
183
+ self.audio_adapter_type = audio_adapter_type
184
+ self.audio_embed_avg = audio_embed_avg
185
+ self.audio_ffn_hidden_size = audio_ffn_hidden_size
186
+ self.audio_ffn_intermediate_size = audio_ffn_intermediate_size
187
+ self.audio_dual_ffn_layers = audio_dual_ffn_layers
188
+ self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers
189
+ self.encode_whisper_embed = encode_whisper_embed
190
+ self.encode_audio_in_tokens = encode_audio_in_tokens
191
+ self.use_delay_pattern = use_delay_pattern
192
+ self.skip_audio_tower = skip_audio_tower
193
+ self.use_audio_out_embed_projector = use_audio_out_embed_projector
194
+ self.use_audio_out_self_attention = use_audio_out_self_attention
195
+
196
+ self.use_rq_transformer = use_rq_transformer
197
+
198
+ if self.use_rq_transformer:
199
+ assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!"
200
+ self.rq_transformer_hidden_size = rq_transformer_hidden_size
201
+ self.rq_transformer_intermediate_size = rq_transformer_intermediate_size
202
+ self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads
203
+ self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads
204
+ self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers
205
+
206
+ if use_rq_transformer:
207
+ # For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified.
208
+ if self.rq_transformer_hidden_size is None:
209
+ self.rq_transformer_hidden_size = text_config.hidden_size
210
+ assert self.rq_transformer_hidden_size % 128 == 0
211
+ if self.rq_transformer_intermediate_size is None:
212
+ self.rq_transformer_intermediate_size = text_config.intermediate_size
213
+ if self.rq_transformer_num_attention_heads is None:
214
+ self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128
215
+ if self.rq_transformer_num_key_value_heads is None:
216
+ self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4
217
+ assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0
218
+ assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0
219
+
220
+ self.audio_num_codebooks = audio_num_codebooks
221
+ self.audio_codebook_size = audio_codebook_size
222
+ self.audio_bos_token = audio_bos_token
223
+ self.audio_eos_token = audio_eos_token
224
+ self.audio_out_bos_token = audio_out_bos_token
225
+ self.audio_in_token = audio_in_token
226
+ self.audio_out_token = audio_out_token
227
+ self.audio_in_token_idx = audio_in_token_idx
228
+ self.audio_out_token_idx = audio_out_token_idx
229
+ self.audio_stream_bos_id = audio_stream_bos_id
230
+ self.audio_stream_eos_id = audio_stream_eos_id
231
+ self.audio_out_bos_token_id = audio_out_bos_token_id
232
+ self.audio_eos_token_id = audio_eos_token_id
233
+
234
+ super().__init__(**kwargs)
235
+ self.pad_token_id = pad_token_id
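A minimal sketch of constructing the config with the dual-FFN adapter; the layer indices below are made up and only illustrate that `audio_dual_ffn_layers` must be provided for the `dual_ffn*` adapter types:

config = HiggsAudioConfig(
    audio_adapter_type="dual_ffn",
    audio_dual_ffn_layers=[8, 16, 24],  # LLM layers that receive an audio FFN (hypothetical choice)
    audio_num_codebooks=12,
    audio_codebook_size=1024,
)
assert config.text_config.model_type == "llama"  # defaults to a Llama text backbone config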
higgs_audio/model/cuda_graph_runner.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, List, Dict, Tuple, Union
4
+ import gc
5
+
6
+ from transformers.cache_utils import Cache
7
+
8
+
9
+ _NUM_WARMUP_ITERS = 2
10
+
11
+
12
+ class CUDAGraphRunner(nn.Module):
13
+ def __init__(self, model):
14
+ super().__init__()
15
+ self.model = model
16
+
17
+ self.input_buffers: Dict[str, torch.Tensor] = {}
18
+ self.output_buffers: Dict[str, torch.Tensor] = {}
19
+
20
+ self._graph: Optional[torch.cuda.CUDAGraph] = None
21
+
22
+ @property
23
+ def graph(self):
24
+ assert self._graph is not None
25
+ return self._graph
26
+
27
+ def capture(
28
+ self,
29
+ hidden_states: torch.Tensor,
30
+ causal_mask: torch.Tensor,
31
+ position_ids: torch.Tensor,
32
+ audio_discrete_codes_mask: torch.Tensor,
33
+ cache_position: torch.Tensor,
34
+ past_key_values: Union[Cache, List[torch.FloatTensor]],
35
+ use_cache: bool,
36
+ audio_attention_mask: torch.Tensor,
37
+ fast_forward_attention_mask: torch.Tensor,
38
+ output_attentions: bool,
39
+ output_hidden_states: bool,
40
+ is_decoding_audio_token: Optional[bool] = None,
41
+ is_using_cuda_graph: Optional[bool] = False,
42
+ stream: torch.cuda.Stream = None,
43
+ memory_pool: Optional[Tuple[int, int]] = None,
44
+ ):
45
+ assert self._graph is None
46
+ # Run warmup iterations
47
+ for _ in range(_NUM_WARMUP_ITERS):
48
+ self.model(
49
+ hidden_states=hidden_states,
50
+ causal_mask=causal_mask,
51
+ position_ids=position_ids,
52
+ audio_discrete_codes_mask=audio_discrete_codes_mask,
53
+ cache_position=cache_position,
54
+ past_key_values=past_key_values,
55
+ use_cache=use_cache,
56
+ audio_attention_mask=audio_attention_mask,
57
+ fast_forward_attention_mask=fast_forward_attention_mask,
58
+ output_attentions=output_attentions,
59
+ output_hidden_states=output_hidden_states,
60
+ is_decoding_audio_token=is_decoding_audio_token,
61
+ is_using_cuda_graph=is_using_cuda_graph,
62
+ )
63
+
64
+ torch.cuda.synchronize()
65
+
66
+ # Capture the graph
67
+ self._graph = torch.cuda.CUDAGraph()
68
+ with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
69
+ out_hidden_states, all_hidden_states, all_self_attns = self.model(
70
+ hidden_states=hidden_states,
71
+ causal_mask=causal_mask,
72
+ position_ids=position_ids,
73
+ audio_discrete_codes_mask=audio_discrete_codes_mask,
74
+ cache_position=cache_position,
75
+ past_key_values=past_key_values,
76
+ use_cache=use_cache,
77
+ audio_attention_mask=audio_attention_mask,
78
+ fast_forward_attention_mask=fast_forward_attention_mask,
79
+ output_attentions=output_attentions,
80
+ output_hidden_states=output_hidden_states,
81
+ is_decoding_audio_token=is_decoding_audio_token,
82
+ is_using_cuda_graph=is_using_cuda_graph,
83
+ )
84
+ # hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0])
85
+ # del outputs
86
+ gc.collect()
87
+ torch.cuda.synchronize()
88
+
89
+ # Save input and output buffers
90
+ self.input_buffers = {
91
+ "hidden_states": hidden_states,
92
+ "causal_mask": causal_mask,
93
+ "position_ids": position_ids,
94
+ "audio_discrete_codes_mask": audio_discrete_codes_mask,
95
+ "cache_position": cache_position,
96
+ "past_key_values": past_key_values,
97
+ "audio_attention_mask": audio_attention_mask,
98
+ "fast_forward_attention_mask": fast_forward_attention_mask,
99
+ }
100
+ self.output_buffers = {
101
+ "hidden_states": out_hidden_states,
102
+ "all_hidden_states": all_hidden_states,
103
+ "all_self_attns": all_self_attns,
104
+ }
105
+
106
+ def forward(
107
+ self,
108
+ hidden_states: torch.Tensor,
109
+ causal_mask: torch.Tensor,
110
+ position_ids: torch.Tensor,
111
+ audio_discrete_codes_mask: torch.Tensor,
112
+ cache_position: torch.Tensor,
113
+ audio_attention_mask: torch.Tensor,
114
+ fast_forward_attention_mask: torch.Tensor,
115
+ **kwargs,
116
+ ) -> torch.Tensor:
117
+ # Copy input tensors to buffers
118
+ self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True)
119
+ self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True)
120
+ self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True)
121
+ self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True)
122
+ self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True)
123
+ self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True)
124
+ self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True)
125
+
126
+ # Run the captured graph
127
+ self.graph.replay()
128
+
129
+ return self.output_buffers["hidden_states"], None, None
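The class above follows the standard PyTorch capture/replay recipe: warm up, capture a graph over static buffers, then copy new values into those buffers and replay. A generic, runnable sketch of that recipe with a toy `nn.Linear` standing in for the decoder core (requires a CUDA device):

import torch
import torch.nn as nn

model = nn.Linear(16, 16).cuda()
static_in = torch.randn(4, 16, device="cuda")

# Warmup on a side stream, mirroring the warmup loop in capture()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        static_out = model(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture once over fixed-shape buffers
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay with new inputs copied into the captured input buffer
static_in.copy_(torch.randn(4, 16, device="cuda"))
graph.replay()  # static_out now holds the result for the new inputs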
higgs_audio/model/custom_modules.py ADDED
@@ -0,0 +1,155 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class PartiallyFrozenEmbedding(nn.Module):
6
+ """Wrap an existing `nn.Embedding` module and split the embedding into:
7
+
8
+ - A frozen embedding for indices [0..freeze_until_idx-1].
9
+ - A trainable embedding for indices [freeze_until_idx..vocab_size-1].
10
+
11
+ This should work with both Zero-2 and Zero-3 seamlessly
12
+ """
13
+
14
+ def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
15
+ """
16
+ :param original_embedding: An instance of nn.Embedding (the original embedding layer).
17
+ :param freeze_until_idx: The index up to which the embedding is frozen (exclusive); the row at freeze_until_idx itself stays trainable.
18
+ """
19
+ super().__init__()
20
+ self.freeze_until_idx = freeze_until_idx
21
+ self.original_vocab_size = original_embedding.num_embeddings
22
+ self.embedding_dim = original_embedding.embedding_dim
23
+
24
+ # Split the original embedding into frozen and trainable parts
25
+ self.embedding_frozen = nn.Embedding(
26
+ freeze_until_idx,
27
+ self.embedding_dim,
28
+ dtype=original_embedding.weight.dtype,
29
+ device=original_embedding.weight.device,
30
+ )
31
+ self.embedding_trainable = nn.Embedding(
32
+ self.original_vocab_size - freeze_until_idx,
33
+ self.embedding_dim,
34
+ dtype=original_embedding.weight.dtype,
35
+ device=original_embedding.weight.device,
36
+ )
37
+
38
+ # Copy weights from the original embedding into the frozen and trainable parts
39
+ with torch.no_grad():
40
+ self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx])
41
+ self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:])
42
+
43
+ # Freeze the frozen embedding
44
+ self.embedding_frozen.weight.requires_grad = False
45
+
46
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ Forward pass for the split embedding wrapper.
49
+ :param input_ids: Tensor of shape [batch_size, seq_len] with indices in [0..original_vocab_size-1].
50
+ """
51
+ # Masks to separate frozen and trainable indices
52
+ # (bsz, seq_len)
53
+ mask_frozen = input_ids < self.freeze_until_idx
54
+ mask_trainable = ~mask_frozen
55
+
56
+ # Output tensor for embedding results
57
+ batch_size, seq_len = input_ids.shape
58
+ embeddings = torch.zeros(
59
+ batch_size,
60
+ seq_len,
61
+ self.embedding_dim,
62
+ device=input_ids.device,
63
+ dtype=self.embedding_frozen.weight.dtype,
64
+ )
65
+
66
+ # Handle frozen embedding
67
+ if mask_frozen.any():
68
+ frozen_ids = input_ids[mask_frozen]
69
+ frozen_emb = self.embedding_frozen(frozen_ids)
70
+ embeddings[mask_frozen] = frozen_emb
71
+
72
+ # Handle trainable embedding
73
+ if mask_trainable.any():
74
+ # Adjust trainable IDs to the local index space of the trainable embedding
75
+ trainable_ids = input_ids[mask_trainable] - (self.freeze_until_idx)
76
+ trainable_emb = self.embedding_trainable(trainable_ids)
77
+ embeddings[mask_trainable] = trainable_emb
78
+
79
+ return embeddings
80
+
81
+ def to_unsplit(self) -> nn.Embedding:
82
+ unsplit_embedding = nn.Embedding(
83
+ self.original_vocab_size,
84
+ self.embedding_dim,
85
+ dtype=self.embedding_frozen.weight.dtype,
86
+ device=self.embedding_frozen.weight.device,
87
+ )
88
+
89
+ with torch.no_grad():
90
+ unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
91
+ unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)
92
+
93
+ return unsplit_embedding
94
+
95
+
96
+ class PartiallyFrozenLinear(nn.Module):
97
+ """A wrapper around nn.Linear to partially freeze part of the weight matrix."""
98
+
99
+ def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
100
+ """
101
+ :param original_linear: The original nn.Linear layer.
102
+ :param freeze_until_idx: The index up to which the rows of the weight matrix are frozen.
103
+ """
104
+ super().__init__()
105
+ assert original_linear.bias is None, "Currently only supports linear modules without bias"
106
+
107
+ self.freeze_until_idx = freeze_until_idx
108
+ self.input_dim = original_linear.in_features
109
+ self.output_dim = original_linear.out_features
110
+
111
+ # Create frozen and trainable linear layers
112
+ self.linear_frozen = nn.Linear(
113
+ self.input_dim,
114
+ freeze_until_idx,
115
+ bias=False,
116
+ dtype=original_linear.weight.dtype,
117
+ device=original_linear.weight.device,
118
+ )
119
+ self.linear_trainable = nn.Linear(
120
+ self.input_dim,
121
+ self.output_dim - freeze_until_idx,
122
+ bias=False,
123
+ dtype=original_linear.weight.dtype,
124
+ device=original_linear.weight.device,
125
+ )
126
+
127
+ # Copy weights from the original linear layer
128
+ with torch.no_grad():
129
+ self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx])
130
+ self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:])
131
+
132
+ # Freeze the frozen linear layer
133
+ self.linear_frozen.weight.requires_grad = False
134
+
135
+ def forward(self, input_tensor):
136
+ # input_tensor: (bsz, seq_len, hidden_state_dim)
137
+ frozen_output = self.linear_frozen(input_tensor)
138
+ trainable_output = self.linear_trainable(input_tensor)
139
+ return torch.cat((frozen_output, trainable_output), dim=-1)
140
+
141
+ def to_unsplit(self) -> nn.Linear:
142
+ unsplit_linear = nn.Linear(
143
+ self.input_dim,
144
+ self.output_dim,
145
+ bias=False,
146
+ dtype=self.linear_frozen.weight.dtype,
147
+ device=self.linear_frozen.weight.device,
148
+ )
149
+
150
+ # Copy weights from the frozen and trainable layers into the unsplit linear layer
151
+ with torch.no_grad():
152
+ unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
153
+ unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)
154
+
155
+ return unsplit_linear
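A small sketch of the partially frozen embedding in action: gradients only reach the trainable tail (rows at index `freeze_until_idx` and above), and `to_unsplit()` reassembles the original table:

import torch
from torch import nn

base = nn.Embedding(10, 4)
split = PartiallyFrozenEmbedding(base, freeze_until_idx=8)

out = split(torch.tensor([[1, 8, 9]]))  # mixes frozen and trainable rows
out.sum().backward()

assert split.embedding_frozen.weight.grad is None        # frozen part receives no gradient
assert split.embedding_trainable.weight.grad is not None # trainable tail does
assert torch.allclose(split.to_unsplit().weight, base.weight)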
higgs_audio/model/modeling_higgs_audio.py ADDED
The diff for this file is too large to render. See raw diff
 
higgs_audio/model/utils.py ADDED
@@ -0,0 +1,778 @@
1
+ import contextlib
2
+ from contextlib import contextmanager
3
+ from functools import wraps
4
+ import torch
5
+ from transformers.integrations import is_deepspeed_available
6
+
7
+ if is_deepspeed_available():
8
+ from deepspeed.utils import groups as deepspeed_groups
9
+ from deepspeed.sequence.layer import _SeqAllToAll
10
+ else:
11
+ deepspeed_groups = None
12
+ _SeqAllToAll = None
13
+
14
+
15
+ def _ceil_to_nearest(n, round_to):
16
+ return (n + round_to - 1) // round_to * round_to
17
+
18
+
19
+ def count_parameters(model, trainable_only=True):
20
+ if trainable_only:
21
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
22
+ else:
23
+ return sum(p.numel() for p in model.parameters())
24
+
25
+
26
+ # TODO(sxjscience) Consider to move the function to audio_processing/utils.py
27
+ def build_delay_pattern_mask(
28
+ input_ids: torch.LongTensor,
29
+ bos_token_id: int,
30
+ pad_token_id: int,
31
+ ):
32
+ """Implement the delay pattern proposed in "Simple and Controllable Music Generation", https://arxiv.org/pdf/2306.05284
33
+
34
+ In the delay pattern, each codebook is offset from the previous codebook by
35
+ one. We insert a special delay token at the start of a sequence if it is delayed, and append a pad token once the sequence finishes.
36
+
37
+ Take the example where there are 4 codebooks and audio sequence length=5. After shifting, the output should have length seq_len + num_codebooks - 1
38
+
39
+ - [ *, *, *, *, *, P, P, P]
40
+ - [ B, *, *, *, *, *, P, P]
41
+ - [ B, B, *, *, *, *, *, P]
42
+ - [ B, B, B, *, *, *, *, *]
43
+
44
+ where B indicates the delay token id, P is the special padding token id, and `*` indicates an original audio token.
45
+
46
+ Now let's consider the case where we have a sequence of audio tokens to condition on.
47
+ The audio tokens were originally in the following non-delayed form:
48
+
49
+ - [a, b]
50
+ - [c, d]
51
+ - [e, f]
52
+ - [g, h]
53
+
54
+ After conversion, we get the following delayed form:
55
+ - [a, b, -1, -1, -1]
56
+ - [B, c, d, -1, -1]
57
+ - [B, B, e, f, -1]
58
+ - [B, B, B, g, h]
59
+
60
+ Note that we have a special token `-1` that indicates it should be replaced by a new token we see in the generation phase.
61
+ In that case, we should override the `-1` tokens in auto-regressive generation.
62
+
63
+ Args:
64
+ input_ids (:obj:`torch.LongTensor`):
65
+ The input ids of the prompt. It will have shape (bsz, num_codebooks, seq_len).
66
+ bos_token_id (:obj:`int`):
67
+ The id of the special delay token
68
+ pad_token_id (:obj:`int`):
69
+ The id of the padding token. Should be the same as eos_token_id.
70
+
71
+ Returns:
72
+ input_ids (:obj:`torch.LongTensor`):
73
+ The transformed input ids with delay pattern applied. It will have shape (bsz, num_codebooks, seq_len + num_codebooks - 1).
74
+ input_ids_with_gen_mask (:obj:`torch.LongTensor`):
75
+ The transformed input ids with delay pattern applied. The -1 in the output indicates new tokens that should be generated.
76
+
77
+ """
78
+ bsz, num_codebooks, seq_len = input_ids.shape
79
+
80
+ new_seq_len = seq_len + num_codebooks - 1
81
+ input_ids_with_gen_mask = torch.ones((bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device)
82
+ bos_mask = torch.tril(input_ids_with_gen_mask, -1) > 0
83
+ eos_mask = torch.triu(input_ids_with_gen_mask, seq_len) > 0
84
+ input_ids_with_gen_mask[bos_mask] = bos_token_id
85
+ input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1)
86
+ input_ids = input_ids_with_gen_mask.clone()
87
+ input_ids[eos_mask] = pad_token_id
88
+ input_ids_with_gen_mask[eos_mask] = -1
89
+ return input_ids, input_ids_with_gen_mask
90
+
91
+
92
+ def revert_delay_pattern(data):
93
+ """Convert samples encoded with delay pattern back to the original form.
94
+
95
+ Args:
96
+ data (:obj:`torch.Tensor`):
97
+ The data with delay pattern applied. It will have shape (num_codebooks, seq_len + num_codebooks - 1).
98
+
99
+ Returns:
100
+ ret (:obj:`torch.Tensor`):
101
+ Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len).
102
+ """
103
+ assert len(data.shape) == 2
104
+ out_l = []
105
+ num_codebooks = data.shape[0]
106
+ for i in range(num_codebooks):
107
+ out_l.append(data[i : (i + 1), i : (data.shape[1] - num_codebooks + 1 + i)])
108
+ return torch.cat(out_l, dim=0)
109
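A round-trip sketch for the two helpers above, using 4 codebooks and a length-5 code sequence; reverting the delayed layout recovers the original codes exactly:

import torch

codes = torch.arange(20).reshape(1, 4, 5)  # (bsz, num_codebooks, seq_len)
delayed, delayed_with_gen_mask = build_delay_pattern_mask(codes, bos_token_id=1024, pad_token_id=1025)

assert delayed.shape == (1, 4, 5 + 4 - 1)                   # seq_len + num_codebooks - 1
assert torch.equal(revert_delay_pattern(delayed[0]), codes[0])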
+
110
+
111
+ def merge_input_ids_with_audio_features(
112
+ audio_features_embed,
113
+ audio_features_length,
114
+ audio_in_embed,
115
+ audio_in_ids_start,
116
+ audio_out_embed,
117
+ audio_out_ids_start,
118
+ audio_in_token_idx,
119
+ audio_out_token_idx,
120
+ inputs_embeds,
121
+ input_ids,
122
+ attention_mask,
123
+ label_ids,
124
+ pad_token_id,
125
+ ignore_index=-100,
126
+ round_to=8,
127
+ left_padding=True,
128
+ ):
129
+ """
130
+ Merge input_ids with audio features into final embeddings.
131
+
132
+ Args:
133
+ audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
134
+ Encoded vectors of all audios in the batch (obtained from the semantic encoder)
135
+ audio_features_length (`torch.LongTensor` of shape `(num_audios,)`):
136
+ The length of audio embeddings of each audio as stacked in `audio_features_embed`
137
+ audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`):
138
+ The embeddings of audio-in tokens
139
+ audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
140
+ The start index of the audio-in tokens for each audio
141
+ audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`):
142
+ The embeddings of audio-out tokens
143
+ audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
144
+ The start index of the audio-out tokens for each audio
145
+ audio_in_token_idx
146
+ The index of the audio-in token in the vocabulary
147
+ audio_out_token_idx
148
+ The index of the audio-out token in the vocabulary
149
+ inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
150
+ Token embeddings before merging with audio embeddings
151
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
152
+ Input_ids of tokens, possibly filled with audio token
153
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
154
+ Mask to avoid performing attention on padding token indices.
155
+ label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
156
+ labels need to be recalculated to support training (if provided)
157
+ pad_token_id (`int`):
158
+ The index of the pad token in the vocabulary
159
+ ignore_index
160
+ The index to ignore in the loss calculation
161
+ round_to
162
+ The number to round to for padding
163
+ left_padding
164
+ Whether to apply left padding
165
+
166
+ Returns:
167
+ final_embedding
168
+ The final embeddings after merging audio embeddings with text embeddings.
169
+ final_attention_mask
170
+ The final attention mask after merging audio embeddings with text embeddings.
171
+ final_labels
172
+ The labels for the text stream
173
+ position_ids
174
+ Positional ids for the merged data
175
+ final_input_ids
176
+ The final input_ids after merging audio embeddings with text embeddings.
177
+ final_audio_in_mask
178
+ Mask for audio-in embeddings
179
+ final_audio_in_discrete_codes_mask
180
+ Mask for audio-in discrete tokens
181
+ final_audio_out_mask
182
+ Mask for audio-out embeddings
183
+
184
+ Explanation:
185
+ each audio has variable length embeddings, with length specified by
186
+ - audio_features_length
187
+ - audio_in_ids_start
188
+ - audio_out_ids_start
189
+
190
+ Task:
191
+ - fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks)
192
+ - fill each <|AUDIO_OUT|> with the audio-out embeddings
193
+
194
+ Example:
195
+ <|AUDIO_OUT|>: X (5 tokens), Y (3 tokens)
196
+ <|AUDIO|>: Z (8 tokens)
197
+
198
+ X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding).
199
+ if right padding
200
+ input_ids: [
201
+ a b c d e f X g h i j k Y l m
202
+ o p q r Z s t u v _ _ _ _ _ _
203
+ ]
204
+ input_ids should be: [
205
+ a b c d e f X X X X X g h i j k Y Y Y l m
206
+ o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
207
+ ]
208
+ labels should be: [
209
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
210
+ o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
211
+ ]
212
+ elif left padding
213
+ input_ids: [
214
+ a b c d e f X g h i j k Y l m
215
+ _ _ _ _ _ _ o p q r Z s t u v
216
+ ]
217
+ input_ids should be: [
218
+ a b c d e f X X X X X g h i j k Y Y Y l m
219
+ _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
220
+ ]
221
+ labels should be: [
222
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
223
+ _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
224
+ ]
225
+
226
+ """
227
+ if label_ids is None:
228
+ skip_labels = True
229
+ else:
230
+ skip_labels = False
231
+ if audio_features_embed is not None and audio_features_embed.shape[0] == 0:
232
+ audio_features_embed = None
233
+ if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
234
+ audio_in_embed = None
235
+ if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
236
+ audio_out_embed = None
237
+
238
+ batch_size, sequence_length, embed_dim = inputs_embeds.shape
239
+
240
+ target_device = inputs_embeds.device
241
+ if left_padding is None:
242
+ left_padding = torch.any(attention_mask[:, 0] == 0)
243
+
244
+ audio_in_token_mask = input_ids == audio_in_token_idx
245
+ audio_out_token_mask = input_ids == audio_out_token_idx
246
+ text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)
247
+
248
+ # 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]).
249
+ token_placeholder_num = torch.ones_like(input_ids)
250
+
251
+ if audio_features_embed is not None:
252
+ num_audios, max_audio_tokens, _ = audio_features_embed.shape
253
+ audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
254
+ audio_features_length.device
255
+ ) < audio_features_length.unsqueeze(1)
256
+ masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim)
257
+ token_placeholder_num[audio_in_token_mask] = audio_features_length.long()
258
+
259
+ if audio_in_embed is not None:
260
+ audio_in_codes_length = torch.concat(
261
+ [
262
+ audio_in_ids_start[1:] - audio_in_ids_start[:-1],
263
+ torch.tensor(
264
+ [audio_in_embed.shape[0] - audio_in_ids_start[-1]],
265
+ device=audio_in_ids_start.device,
266
+ dtype=torch.long,
267
+ ),
268
+ ],
269
+ dim=0,
270
+ )
271
+ if audio_features_embed is not None:
272
+ token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long()
273
+ else:
274
+ token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long()
275
+
276
+ if audio_out_embed is not None:
277
+ audio_out_codes_length = torch.concat(
278
+ [
279
+ audio_out_ids_start[1:] - audio_out_ids_start[:-1],
280
+ torch.tensor(
281
+ [audio_out_embed.shape[0] - audio_out_ids_start[-1]],
282
+ device=audio_out_ids_start.device,
283
+ dtype=torch.long,
284
+ ),
285
+ ],
286
+ dim=0,
287
+ )
288
+ token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long()
289
+
290
+ new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
291
+ max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to)
292
+ nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
293
+
294
+ if left_padding:
295
+ new_token_positions += nb_audio_pad[:, None] # offset for left padding
296
+
297
+ # 2. Create the full embedding, already padded to the maximum position
298
+ final_embedding = torch.zeros(
299
+ (batch_size, max_token_num, embed_dim),
300
+ dtype=inputs_embeds.dtype,
301
+ device=inputs_embeds.device,
302
+ )
303
+ final_attention_mask = torch.zeros(
304
+ (batch_size, max_token_num),
305
+ dtype=attention_mask.dtype,
306
+ device=inputs_embeds.device,
307
+ )
308
+ final_input_ids = torch.full(
309
+ (batch_size, max_token_num),
310
+ pad_token_id,
311
+ dtype=input_ids.dtype,
312
+ device=inputs_embeds.device,
313
+ )
314
+ if skip_labels:
315
+ final_labels = None
316
+ else:
317
+ final_labels = torch.full(
318
+ (batch_size, max_token_num),
319
+ ignore_index,
320
+ dtype=label_ids.dtype,
321
+ device=inputs_embeds.device,
322
+ )
323
+
324
+ final_audio_in_mask = torch.full(
325
+ (batch_size, max_token_num),
326
+ False,
327
+ dtype=torch.bool,
328
+ device=inputs_embeds.device,
329
+ )
330
+ final_audio_in_discrete_codes_mask = torch.full(
331
+ (batch_size, max_token_num),
332
+ False,
333
+ dtype=torch.bool,
334
+ device=inputs_embeds.device,
335
+ )
336
+ final_audio_out_mask = torch.full(
337
+ (batch_size, max_token_num),
338
+ False,
339
+ dtype=torch.bool,
340
+ device=inputs_embeds.device,
341
+ )
342
+ # 3. Get the audio-in token positions and audio-out token positions
343
+ batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length)
344
+ audio_in_batch_id = batch_id[audio_in_token_mask] # Shape (num_audio_in,)
345
+ audio_out_batch_id = batch_id[audio_out_token_mask] # Shape (num_audio_out,)
346
+ audio_features_token_ends = new_token_positions[audio_in_token_mask] # Shape (num_audio_in,)
347
+ audio_out_embed_ends = new_token_positions[audio_out_token_mask] # Shape (num_audio_out,)
348
+
349
+ if audio_in_embed is not None:
350
+ # Fill in the audio-in embeddings
351
+ seq_indices = (
352
+ torch.arange(max_token_num, device=target_device)
353
+ .unsqueeze(0)
354
+ .expand(audio_in_ids_start.shape[0], max_token_num)
355
+ )
356
+ audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1
357
+ batch_indices, col_indices = torch.where(
358
+ (seq_indices >= audio_in_embed_token_starts.unsqueeze(1))
359
+ & (seq_indices <= audio_features_token_ends.unsqueeze(1))
360
+ )
361
+ batch_indices = audio_in_batch_id[batch_indices]
362
+ final_embedding[batch_indices, col_indices] = audio_in_embed
363
+ final_input_ids[batch_indices, col_indices] = audio_in_token_idx
364
+ if not skip_labels:
365
+ final_labels[batch_indices, col_indices] = ignore_index
366
+ final_audio_in_mask[batch_indices, col_indices] = True
367
+ final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True
368
+ audio_features_token_ends = audio_features_token_ends - audio_in_codes_length
369
+
370
+ if audio_features_embed is not None:
371
+ # Fill in the audio features
372
+ seq_indices = (
373
+ torch.arange(max_token_num, device=target_device)
374
+ .unsqueeze(0)
375
+ .expand(audio_features_embed.shape[0], max_token_num)
376
+ )
377
+ audio_features_token_starts = audio_features_token_ends - audio_features_length + 1
378
+ batch_indices, col_indices = torch.where(
379
+ (seq_indices >= audio_features_token_starts.unsqueeze(1))
380
+ & (seq_indices <= audio_features_token_ends.unsqueeze(1))
381
+ )
382
+ batch_indices = audio_in_batch_id[batch_indices]
383
+ final_embedding[batch_indices, col_indices] = masked_audio_in_features
384
+ final_input_ids[batch_indices, col_indices] = audio_in_token_idx
385
+ if not skip_labels:
386
+ final_labels[batch_indices, col_indices] = ignore_index
387
+ final_audio_in_mask[batch_indices, col_indices] = True
388
+
389
+ if audio_out_embed is not None:
390
+ # Fill in the audio-out embeddings
391
+ seq_indices = (
392
+ torch.arange(max_token_num, device=target_device)
393
+ .unsqueeze(0)
394
+ .expand(audio_out_ids_start.shape[0], max_token_num)
395
+ )
396
+ audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1
397
+ batch_indices, col_indices = torch.where(
398
+ (seq_indices >= audio_out_embed_token_starts.unsqueeze(1))
399
+ & (seq_indices <= audio_out_embed_ends.unsqueeze(1))
400
+ )
401
+ batch_indices = audio_out_batch_id[batch_indices]
402
+ final_embedding[batch_indices, col_indices] = audio_out_embed
403
+ final_input_ids[batch_indices, col_indices] = audio_out_token_idx
404
+ if not skip_labels:
405
+ final_labels[batch_indices, col_indices] = ignore_index
406
+ final_audio_out_mask[batch_indices, col_indices] = True
407
+
408
+ # Fill in the original text embeddings and labels
409
+ batch_indices, non_audio_indices = torch.where(text_token_mask)
410
+ text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
411
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
412
+ if not skip_labels:
413
+ final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices]
414
+ final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
415
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
416
+ final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask
417
+
418
+ # Trim the tensor if there are redundant padding tokens
419
+ if left_padding:
420
+ first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
421
+ first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
422
+ if first_non_zero_loc > 0:
423
+ final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
424
+ final_embedding = final_embedding[:, first_non_zero_loc:]
425
+ if not skip_labels:
426
+ final_labels = final_labels[:, first_non_zero_loc:]
427
+ final_input_ids = final_input_ids[:, first_non_zero_loc:]
428
+ final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
429
+ final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
430
+ final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
431
+ else:
432
+ # We have done right padding, so we need to trim the mask
433
+ last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
434
+ last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
435
+ if last_non_zero_loc < max_token_num:
436
+ final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
437
+ final_embedding = final_embedding[:, :last_non_zero_loc]
438
+ if not skip_labels:
439
+ final_labels = final_labels[:, :last_non_zero_loc]
440
+ final_input_ids = final_input_ids[:, :last_non_zero_loc]
441
+ final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
442
+ final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
443
+ final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]
444
+
445
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
446
+ return (
447
+ final_embedding,
448
+ final_attention_mask,
449
+ final_labels,
450
+ position_ids,
451
+ final_input_ids,
452
+ final_audio_in_mask,
453
+ final_audio_in_discrete_codes_mask,
454
+ final_audio_out_mask,
455
+ )
456
+
457
+
458
+ def is_deepspeed_ulysses_enabled():
+ """Check if sequence parallelism (DeepSpeed Ulysses) is enabled."""
+ if deepspeed_groups is None:
+ return False
+
+ return deepspeed_groups._get_sequence_parallel_world_size() > 1
464
+
465
+
466
+ def support_deepspeed_ulysses(module):
467
+ """A decorator around Pytorch module. It is needed for the module that needs access to sequence parallel info."""
468
+ module._sp_size = None
469
+ module._sp_rank = None
470
+ module._sp_group = None
471
+
472
+ @property
473
+ def sp_size(self):
474
+ if self._sp_size is None:
475
+ self._sp_size = 1
476
+ if is_deepspeed_ulysses_enabled():
477
+ self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
478
+ return self._sp_size
479
+
480
+ @property
481
+ def sp_rank(self):
482
+ if self._sp_rank is None:
483
+ self._sp_rank = 0
484
+ if is_deepspeed_ulysses_enabled():
485
+ self._sp_rank = deepspeed_groups._get_sequence_parallel_rank()
486
+ return self._sp_rank
487
+
488
+ @property
489
+ def sp_group(self):
490
+ if self._sp_group is None and is_deepspeed_ulysses_enabled():
491
+ self._sp_group = deepspeed_groups._get_sequence_parallel_group()
492
+ return self._sp_group
493
+
494
+ module.sp_size = sp_size
495
+ module.sp_rank = sp_rank
496
+ module.sp_group = sp_group
497
+
498
+ return module
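+
+ # Usage sketch (illustrative; the class name is hypothetical):
+ #   @support_deepspeed_ulysses
+ #   class MyAttentionBlock(torch.nn.Module):
+ #       ...
+ # Instances of the decorated class then expose `self.sp_size`, `self.sp_rank`, and
+ # `self.sp_group`, which resolve lazily from the DeepSpeed sequence-parallel group.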
499
+
500
+
501
+ def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
502
+ """Perform all-to-all before and after the attention function."""
503
+
504
+ def attention_decorator(attn_func=None):
505
+ def wrapped(*args, **kwargs):
506
+ if is_deepspeed_ulysses_enabled():
507
+ sp_group = deepspeed_groups._get_sequence_parallel_group()
508
+ scatter_idx = head_dim # Scatter on num_heads dimension
509
+ gather_idx = seq_dim # Gather on seq_len dimension
510
+ batch_dim_idx = 0
511
+ args = list(args)
512
+ args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx)
513
+ args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx)
514
+ args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx)
515
+ args = tuple(args)
516
+
517
+ attn_output = attn_func(*args, **kwargs)
518
+
519
+ if is_deepspeed_ulysses_enabled():
520
+ scatter_idx = seq_dim # Scatter back on seq_len dimension
521
+ gather_idx = head_dim # Gather on num_heads dimension
522
+ batch_dim_idx = 0
523
+ attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx)
524
+
525
+ return attn_output
526
+
527
+ return wrapped
528
+
529
+ return attention_decorator
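+
+ # Usage sketch (illustrative): wrap an attention function whose first three positional
+ # arguments are (query, key, value) laid out as (batch, seq_len, num_heads, head_dim).
+ # `plain_attention_fn` is a placeholder name; the all-to-alls are skipped when Ulysses is disabled.
+ #   ulysses_attn = deepspeed_ulysses_attention(seq_dim=1, head_dim=2)(plain_attention_fn)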
530
+
531
+
532
+ def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
533
+ """Slice the corresponding cos and sin chunks for rope."""
534
+
535
+ def rope_decorator(rope_func=None):
536
+ def wrapped(*args, **kwargs):
537
+ if is_deepspeed_ulysses_enabled():
538
+ sp_rank = deepspeed_groups._get_sequence_parallel_rank()
539
+ args = list(args)
540
+ seq_chunk_size = args[0].size(state_seq_dim)
541
+ args[2] = torch.narrow(args[2], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
542
+ args[3] = torch.narrow(args[3], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
543
+ args = tuple(args)
544
+
545
+ return rope_func(*args, **kwargs)
546
+
547
+ return wrapped
548
+
549
+ return rope_decorator
550
+
551
+
552
+ def _gather_tensors(input_, group=None):
553
+ """Gather tensors and concatenate them along a dimension."""
554
+ input_ = input_.contiguous()
555
+ world_size = torch.distributed.get_world_size(group)
556
+ if world_size == 1:
557
+ return input_
558
+ tensor_shapes = [
559
+ torch.empty(len(input_.size()), dtype=torch.int64, device=input_.device) for _ in range(world_size)
560
+ ]
561
+ input_size = torch.tensor(input_.size(), dtype=torch.int64, device=input_.device)
562
+ torch.distributed.all_gather(tensor_shapes, input_size, group=group)
563
+ gathered_buffers = [
564
+ torch.empty(tensor_shapes[i].tolist(), dtype=input_.dtype, device=input_.device) for i in range(world_size)
565
+ ]
566
+ torch.distributed.all_gather(gathered_buffers, input_, group=group)
567
+ return gathered_buffers
568
+
569
+
570
+ def _scatter_tensors(input_, group=None):
571
+ """Scatter tensors."""
572
+ world_size = torch.distributed.get_world_size(group)
573
+ if world_size == 1:
574
+ return input_
575
+ rank = torch.distributed.get_rank(group)
576
+ return input_[rank]
577
+
578
+
579
+ class _GatherTensors(torch.autograd.Function):
580
+ """All gather tensors among the ranks."""
581
+
582
+ @staticmethod
583
+ def symbolic(graph, input_, group):
584
+ return _gather_tensors(input_, group)
585
+
586
+ @staticmethod
587
+ def forward(ctx, input_, group):
588
+ ctx.group = group
589
+ return torch.nested.as_nested_tensor(_gather_tensors(input_, group), layout=torch.jagged)
590
+
591
+ @staticmethod
592
+ def backward(ctx, grad_output):
593
+ return _scatter_tensors(grad_output, ctx.group), None
594
+
595
+
596
+ def all_gather_tensors(input_, size=None, dim=0, group=None):
597
+ if torch.distributed.get_world_size(group) == 1:
598
+ # no sequence parallelism
599
+ return input_
600
+ gathered_tensors = _GatherTensors.apply(input_, group)
601
+
602
+ if size:
603
+ split_gathered_tensors = []
604
+ for s, gathered_tensor in zip(size, gathered_tensors):
605
+ split_gathered_tensor = torch.split(gathered_tensor, s.tolist())
606
+ split_gathered_tensors.append(split_gathered_tensor)
607
+
608
+ gathered_tensors = [y for x in zip(*split_gathered_tensors) for y in x]
609
+
610
+ return torch.cat(gathered_tensors, dim).contiguous()
611
+
612
+
613
+ def get_sequence_data_parallel_world_size():
614
+ return torch.distributed.get_world_size()
615
+
616
+
617
+ def get_sequence_data_parallel_rank():
618
+ return torch.distributed.get_rank()
619
+
620
+
621
+ def get_sequence_data_parallel_group():
622
+ return torch.distributed.group.WORLD
623
+
624
+
625
+ if is_deepspeed_available():
626
+ deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size
627
+ deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank
628
+ deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group
629
+
630
+
631
+ def _gather_tokens(input_, dim=0, group=None):
632
+ """Gather tensors and concatenate them along a dimension"""
633
+ input_ = input_.contiguous()
634
+ world_size = torch.distributed.get_world_size(group)
635
+ if world_size == 1:
636
+ return input_
637
+
638
+ gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
639
+ torch.distributed.all_gather_into_tensor(gather_buffer, input_, group=group)
640
+ if dim == 0:
641
+ shape = list(input_.size())
642
+ shape[0] = shape[0] * world_size
643
+ output = gather_buffer.view(shape)
644
+ else:
645
+ tensor_list = [
646
+ gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
647
+ ]
648
+ # Note: torch.cat already creates a contiguous tensor.
649
+ output = torch.cat(tensor_list, dim=dim).contiguous()
650
+
651
+ return output
652
+
653
+
654
+ def _drop_tokens(input_, dim=0, group=None):
655
+ """Divide a tensor among the sequence parallel ranks"""
656
+ world_size = torch.distributed.get_world_size(group)
657
+ if world_size == 1:
658
+ return input_
659
+ this_rank = torch.distributed.get_rank(group)
660
+ assert input_.shape[dim] % world_size == 0, (
661
+ f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})"
662
+ )
663
+ chunk_size = input_.shape[dim] // world_size
664
+
665
+ return torch.narrow(input_, dim, this_rank * chunk_size, chunk_size)
666
+
667
+
668
+ class _DropTokens(torch.autograd.Function):
669
+ "Divide tokens equally among the sequence parallel ranks"
670
+
671
+ @staticmethod
672
+ def symbolic(graph, input_, dim, group, grad_scale):
673
+ return _drop_tokens(input_, dim, group)
674
+
675
+ @staticmethod
676
+ def forward(ctx, input_, dim, group, grad_scale):
677
+ ctx.dim = dim
678
+ ctx.group = group
679
+ ctx.grad_scale = grad_scale
680
+ return _drop_tokens(input_, dim, group)
681
+
682
+ @staticmethod
683
+ def backward(ctx, grad_output):
684
+ grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group)
685
+ if ctx.grad_scale != 1:
686
+ grad_input /= ctx.grad_scale
687
+ return grad_input, None, None, None
688
+
689
+
690
+ class _GatherTokens(torch.autograd.Function):
691
+ "Gather tokens among the sequence parallel ranks"
692
+
693
+ @staticmethod
694
+ def symbolic(graph, input_, dim, group, grad_scale):
695
+ return _gather_tokens(input_, dim, group)
696
+
697
+ @staticmethod
698
+ def forward(ctx, input_, dim, group, grad_scale):
699
+ ctx.dim = dim
700
+ ctx.group = group
701
+ ctx.grad_scale = grad_scale
702
+ return _gather_tokens(input_, dim, group)
703
+
704
+ @staticmethod
705
+ def backward(ctx, grad_output):
706
+ grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group)
707
+ if ctx.grad_scale != 1:
708
+ grad_input *= ctx.grad_scale
709
+ return grad_input, None, None, None
710
+
711
+
712
+ def drop_tokens(input_, dim=0, group=None, grad_scale=1):
713
+ if torch.distributed.get_world_size(group) == 1:
714
+ # no sequence parallelism
715
+ return input_
716
+ return _DropTokens.apply(input_, dim, group, grad_scale)
717
+
718
+
719
+ def gather_tokens(input_, dim=0, group=None, grad_scale=1):
720
+ if torch.distributed.get_world_size(group) == 1:
721
+ # no sequence parallelism
722
+ return input_
723
+ return _GatherTokens.apply(input_, dim, group, grad_scale)
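+
+ # Usage sketch (illustrative; `hidden_states` and `sp_group` are placeholder names):
+ # shard an activation along the sequence dimension across the sequence-parallel group
+ # and later reassemble it. The autograd functions above make the pair differentiable.
+ #   local_chunk = drop_tokens(hidden_states, dim=1, group=sp_group)
+ #   full_seq = gather_tokens(local_chunk, dim=1, group=sp_group)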
724
+
725
+
726
+ def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
727
+ """
728
+ Slice the inputs into chunks for the current sequence parallel rank. This is used for context parallel training.
729
+
730
+ Args:
731
+ sp_size (`int`):
732
+ Sequence parallel size.
733
+ sp_rank (`int`):
734
+ Sequence parallel rank for the current process.
735
+ dim (`int`):
736
+ The dimension to slice
737
+ """
738
+ if sp_size == 1:
739
+ return args[0] if len(args) == 1 else args
740
+
741
+ seq_length = args[0].size(dim)
742
+ for arg in args[1:]:
743
+ assert arg.size(dim) == seq_length, (
744
+ f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}"
745
+ )
746
+ assert seq_length % sp_size == 0, (
747
+ f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})"
748
+ )
749
+
750
+ sub_seq_length = seq_length // sp_size
751
+ sub_seq_start = sp_rank * sub_seq_length
752
+
753
+ output = []
754
+ for ind in args:
755
+ ind = torch.narrow(ind, dim, sub_seq_start, sub_seq_length)
756
+ output.append(ind)
757
+
758
+ return tuple(output) if len(output) > 1 else output[0]
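+
+ # Illustrative call (tensor names are placeholders): with sp_size=2 and dim=1, each rank keeps
+ # half of the sequence, e.g. a (batch, 8, dim) tensor becomes a (batch, 4, dim) chunk:
+ #   local_hidden, local_mask = sequence_chunking_per_rank(2, sp_rank, hidden_states, attention_mask, dim=1)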
759
+
760
+
761
+ @contextmanager
762
+ def disable_deepspeed_ulysses():
763
+ """Disable deepspeed ulysses (sequence parallelism) if it is enabled"""
764
+ if is_deepspeed_ulysses_enabled():
765
+ _old_get_sequence_parallel_world_size = deepspeed_groups._get_sequence_parallel_world_size
766
+
767
+ def _get_sequence_parallel_world_size():
768
+ return 1
769
+
770
+ deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size
771
+ try:
772
+ yield
773
+ finally:
774
+ deepspeed_groups._get_sequence_parallel_world_size = _old_get_sequence_parallel_world_size
775
+ else:
776
+ context = contextlib.nullcontext
777
+ with context():
778
+ yield
higgs_audio/serve/serve_engine.py ADDED
@@ -0,0 +1,424 @@
1
+ import asyncio
2
+ import base64
3
+ import torch
4
+ import numpy as np
5
+ from io import BytesIO
6
+ from dataclasses import dataclass
7
+ from typing import List, Optional, Union
8
+ from copy import deepcopy
9
+ from transformers import AutoTokenizer, AutoProcessor
10
+ from transformers.cache_utils import StaticCache
11
+ from transformers.generation.streamers import BaseStreamer
12
+ from transformers.generation.stopping_criteria import StoppingCriteria
13
+ from dataclasses import asdict
14
+ from loguru import logger
15
+ import threading
16
+ import librosa
17
+
18
+
19
+ from ..dataset.chatml_dataset import (
20
+ ChatMLSample,
21
+ ChatMLDatasetSample,
22
+ prepare_chatml_sample,
23
+ )
24
+ from ..model import HiggsAudioModel
25
+ from ..model.utils import revert_delay_pattern
26
+ from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator
27
+ from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
28
+
29
+
30
+ @dataclass
31
+ class HiggsAudioStreamerDelta:
32
+ """Represents a chunk of generated content, either text or audio tokens."""
33
+
34
+ text: Optional[str] = None
35
+ text_tokens: Optional[torch.Tensor] = None
36
+ audio_tokens: Optional[torch.Tensor] = None
37
+ finish_reason: Optional[str] = None
38
+
39
+
40
+ class AsyncHiggsAudioStreamer(BaseStreamer):
41
+ """
42
+ Async streamer that handles both text and audio token generation from Higgs-Audio model.
43
+ Stores chunks in a queue to be consumed by downstream applications.
44
+
45
+ Parameters:
46
+ tokenizer (`AutoTokenizer`):
47
+ The tokenizer used to decode text tokens.
48
+ skip_prompt (`bool`, *optional*, defaults to `False`):
49
+ Whether to skip the prompt tokens in generation.
50
+ timeout (`float`, *optional*):
51
+ The timeout for the queue. If `None`, the queue will block indefinitely.
52
+ decode_kwargs (`dict`, *optional*):
53
+ Additional keyword arguments to pass to the tokenizer's `decode` method.
54
+
55
+ Examples:
56
+ ```python
57
+ >>> from transformers import AutoTokenizer
58
+ >>> from threading import Thread
59
+ >>> import asyncio
60
+
61
+ >>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer")
62
+ >>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model")
63
+ >>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt")
64
+
65
+ >>> async def main():
66
+ ... streamer = AsyncHiggsAudioStreamer(tokenizer)
67
+ ... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
68
+ ... thread = Thread(target=model.generate, kwargs=generation_kwargs)
69
+ ... thread.start()
70
+ ...
71
+ ... async for delta in streamer:
72
+ ... if delta.text is not None:
73
+ ... print("Text:", delta.text)
74
+ ... if delta.audio_tokens is not None:
75
+ ... print("Audio tokens shape:", delta.audio_tokens.shape)
76
+ >>> asyncio.run(main())
77
+ ```
78
+ """
79
+
80
+ def __init__(
81
+ self,
82
+ tokenizer: "AutoTokenizer",
83
+ skip_prompt: bool = False,
84
+ timeout: Optional[float] = None,
85
+ audio_num_codebooks: int = 1,
86
+ **decode_kwargs,
87
+ ):
88
+ self.tokenizer = tokenizer
89
+ self.skip_prompt = skip_prompt
90
+ self.timeout = timeout
91
+ self.decode_kwargs = decode_kwargs
92
+ self.audio_num_codebooks = audio_num_codebooks
93
+
94
+ # Queue to store generated chunks
95
+ self.queue = asyncio.Queue()
96
+ self.stop_signal = None
97
+
98
+ # Get running event loop
99
+ self.loop = asyncio.get_running_loop()
100
+ self.has_asyncio_timeout = hasattr(asyncio, "timeout")
101
+
102
+ # State tracking
103
+ self.next_tokens_are_prompt = True
104
+
105
+ def put(self, value: torch.Tensor):
106
+ """
107
+ Receives tokens and processes them as either text or audio tokens.
108
+ For text tokens, decodes and caches them until complete words are formed.
109
+ For audio tokens, directly queues them.
110
+ """
111
+ if value.shape[0] > 1 and not self.next_tokens_are_prompt:
112
+ # This is likely audio tokens (shape: [audio_num_codebooks])
113
+ assert value.shape[0] == self.audio_num_codebooks, "Number of codebooks mismatch"
114
+ delta = HiggsAudioStreamerDelta(audio_tokens=value)
115
+ self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
116
+ return
117
+
118
+ # Skip prompt tokens if configured
119
+ if self.skip_prompt and self.next_tokens_are_prompt:
120
+ self.next_tokens_are_prompt = False
121
+ return
122
+
123
+ # Process as text tokens
124
+ if len(value.shape) > 1:
125
+ value = value[0]
126
+
127
+ text = self.tokenizer.decode(value, **self.decode_kwargs)
128
+ delta = HiggsAudioStreamerDelta(text=text, text_tokens=value)
129
+ self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
130
+
131
+ def end(self):
132
+ """Flushes any remaining text tokens and signals the end of generation."""
133
+ self.next_tokens_are_prompt = True
134
+ self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal)
135
+
136
+ def __aiter__(self):
137
+ return self
138
+
139
+ async def __anext__(self):
140
+ try:
141
+ if self.has_asyncio_timeout:
142
+ async with asyncio.timeout(self.timeout):
143
+ value = await self.queue.get()
144
+ else:
145
+ value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
146
+ except asyncio.TimeoutError:
147
+ raise TimeoutError()
148
+ else:
149
+ if value == self.stop_signal:
150
+ raise StopAsyncIteration()
151
+ else:
152
+ return value
153
+
154
+
155
+ class AsyncStoppingCriteria(StoppingCriteria):
156
+ """
157
+ Stopping criteria that checks for stop signal from a threading event.
158
+
159
+ Args:
160
+ stop_signal (threading.Event): Event that will receive stop signals
161
+ """
162
+
163
+ def __init__(self, stop_signal: threading.Event):
164
+ self.stop_signal = stop_signal
165
+
166
+ def __call__(self, input_ids, scores, **kwargs) -> bool:
167
+ if self.stop_signal.is_set():
168
+ logger.info(f"Stop signal received. Can be caused by client disconnection.")
169
+ return True
170
+ return False
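+
+ # Usage sketch (illustrative): abort a long generation when a client disconnects.
+ # `StoppingCriteriaList` comes from transformers; the variable names are placeholders.
+ #   stop_event = threading.Event()
+ #   model.generate(**inputs, stopping_criteria=StoppingCriteriaList([AsyncStoppingCriteria(stop_event)]))
+ #   # calling stop_event.set() from another thread stops decoding at the next step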
171
+
172
+
173
+ @dataclass
174
+ class HiggsAudioResponse:
175
+ audio: Optional[np.ndarray] = None
176
+ generated_audio_tokens: Optional[np.ndarray] = None
177
+ sampling_rate: Optional[int] = None
178
+ generated_text: str = ""
179
+ generated_text_tokens: np.ndarray = np.array([])
180
+ usage: Optional[dict] = None
181
+
182
+
183
+ class HiggsAudioServeEngine:
184
+ def __init__(
185
+ self,
186
+ model_name_or_path: str,
187
+ audio_tokenizer_name_or_path: str,
188
+ tokenizer_name_or_path: Optional[str] = None,
189
+ device: str = "cuda",
190
+ torch_dtype: Union[torch.dtype, str] = "auto",
191
+ kv_cache_lengths: List[int] = [1024, 4096, 8192], # Multiple KV cache sizes
192
+ ):
193
+ """
194
+ Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel.
195
+ The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local.
196
+
197
+ Args:
198
+ model_name_or_path (str):
199
+ The name or path of the model to load.
200
+ audio_tokenizer_name_or_path (str):
201
+ The name or path of the audio tokenizer to load.
202
+ tokenizer_name_or_path (str):
203
+ The name or path of the tokenizer to load.
204
+ device (str):
205
+ The device to use for the model.
206
+ kv_cache_lengths (List[int]):
207
+ The lengths of the KV caches to use for the model. Used for cuda graph capture when device is cuda.
208
+ torch_dtype (Union[torch.dtype, str]):
209
+ The dtype to use for the model.
210
+ """
211
+ self.device = device
212
+ self.model_name_or_path = model_name_or_path
213
+ self.torch_dtype = torch_dtype
214
+
215
+ # Initialize model and tokenizer
216
+ self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device)
217
+ logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")
218
+
219
+ if tokenizer_name_or_path is None:
220
+ tokenizer_name_or_path = model_name_or_path
221
+ logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
222
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
223
+
224
+ logger.info(f"Initializing Higgs Audio Tokenizer")
225
+ self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)
226
+
227
+ self.audio_num_codebooks = self.model.config.audio_num_codebooks
228
+ self.audio_codebook_size = self.model.config.audio_codebook_size
229
+ self.audio_tokenizer_tps = self.audio_tokenizer.tps
230
+ self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
231
+ self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token
232
+ # Set the audio special tokens
233
+ self.model.set_audio_special_tokens(self.tokenizer)
234
+
235
+ # Prepare KV caches for different lengths
236
+ cache_config = deepcopy(self.model.config.text_config)
237
+ cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers
238
+ if self.model.config.audio_dual_ffn_layers:
239
+ cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)
240
+ # A list of KV caches for different lengths
241
+ self.kv_caches = {
242
+ length: StaticCache(
243
+ config=cache_config,
244
+ max_batch_size=1,
245
+ max_cache_len=length,
246
+ device=self.model.device,
247
+ dtype=self.model.dtype,
248
+ )
249
+ for length in sorted(kv_cache_lengths)
250
+ }
251
+
252
+ if self.model.config.encode_whisper_embed:
253
+ logger.info(f"Loading whisper processor")
254
+ whisper_processor = AutoProcessor.from_pretrained(
255
+ "openai/whisper-large-v3-turbo",
256
+ trust_remote_code=True,
257
+ device=self.device,
258
+ )
259
+ else:
260
+ whisper_processor = None
261
+
262
+ # Reuse collator to prepare inference samples
263
+ self.collator = HiggsAudioSampleCollator(
264
+ whisper_processor=whisper_processor,
265
+ encode_whisper_embed=self.model.config.encode_whisper_embed,
266
+ audio_in_token_id=self.model.config.audio_in_token_idx,
267
+ audio_out_token_id=self.model.config.audio_out_token_idx,
268
+ audio_stream_bos_id=self.model.config.audio_stream_bos_id,
269
+ audio_stream_eos_id=self.model.config.audio_stream_eos_id,
270
+ pad_token_id=self.model.config.pad_token_id,
271
+ return_audio_in_tokens=False,
272
+ use_delay_pattern=self.model.config.use_delay_pattern,
273
+ audio_num_codebooks=self.model.config.audio_num_codebooks,
274
+ round_to=1,
275
+ )
276
+
277
+ # Lock to prevent multiple generations from happening at the same time
278
+ self.generate_lock = threading.Lock()
279
+
280
+ # Capture CUDA graphs for each KV cache length
281
+ if device == "cuda":
282
+ logger.info(f"Capturing CUDA graphs for each KV cache length")
283
+ self.model.capture_model(self.kv_caches.values())
284
+
285
+ def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
286
+ input_tokens, _, audio_contents, _ = prepare_chatml_sample(
287
+ chat_ml_sample,
288
+ self.tokenizer,
289
+ )
290
+
291
+ postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
292
+ if force_audio_gen:
293
+ postfix += "<|audio_out_bos|>"
294
+ postfix = self.tokenizer.encode(postfix, add_special_tokens=False)
295
+ input_tokens.extend(postfix)
296
+
297
+ # Configure the audio inputs
298
+ audio_ids_l = []
299
+ for audio_content in audio_contents:
300
+ if audio_content.audio_url not in ["placeholder", ""]:
301
+ raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate)
302
+ elif audio_content.raw_audio is not None:
303
+ raw_audio, _ = librosa.load(
304
+ BytesIO(base64.b64decode(audio_content.raw_audio)),
305
+ sr=self.audio_tokenizer.sampling_rate,
306
+ )
307
+ else:
308
+ raw_audio = None
309
+
310
+ if raw_audio is not None:
311
+ audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
312
+ audio_ids_l.append(audio_ids.squeeze(0).cpu())
313
+
314
+ if len(audio_ids_l) > 0:
315
+ audio_ids_start = torch.tensor(
316
+ np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
317
+ dtype=torch.long,
318
+ device=self.device,
319
+ )[0:-1]
320
+ audio_ids_concat = torch.cat(audio_ids_l, dim=1)
321
+ else:
322
+ audio_ids_start = None
323
+ audio_ids_concat = None
324
+
325
+ sample = ChatMLDatasetSample(
326
+ input_ids=torch.LongTensor(input_tokens),
327
+ label_ids=None,
328
+ audio_ids_concat=audio_ids_concat,
329
+ audio_ids_start=audio_ids_start,
330
+ audio_waveforms_concat=None,
331
+ audio_waveforms_start=None,
332
+ audio_sample_rate=None,
333
+ audio_speaker_indices=None,
334
+ )
335
+ data = self.collator([sample])
336
+ inputs = asdict(data)
337
+ for k, v in inputs.items():
338
+ if isinstance(v, torch.Tensor):
339
+ inputs[k] = v.to(self.model.device)
340
+
341
+ return inputs
342
+
343
+ def _prepare_kv_caches(self):
344
+ for kv_cache in self.kv_caches.values():
345
+ kv_cache.reset()
346
+
347
+ def generate(
348
+ self,
349
+ chat_ml_sample: ChatMLSample,
350
+ max_new_tokens: int,
351
+ temperature: float = 0.7,
352
+ top_k: Optional[int] = None,
353
+ top_p: float = 0.95,
354
+ stop_strings: Optional[List[str]] = None,
355
+ force_audio_gen: bool = False,
356
+ ras_win_len: Optional[int] = None,
357
+ ras_win_max_num_repeat: int = 2,
358
+ ):
359
+ """
360
+ Generate audio from a chatml sample.
361
+ Args:
362
+ chat_ml_sample: A chatml sample.
363
+ max_new_tokens: The maximum number of new tokens to generate.
364
+ temperature: The temperature to use for the generation.
365
+ top_p: The top p to use for the generation.
366
+ Returns:
367
+ A dictionary with the following keys:
368
+ audio: The generated audio.
369
+ sampling_rate: The sampling rate of the generated audio.
370
+ """
371
+ # Default stop strings
372
+ if stop_strings is None:
373
+ stop_strings = ["<|end_of_text|>", "<|eot_id|>"]
374
+
375
+ with torch.no_grad(), self.generate_lock:
376
+ inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
377
+ prompt_token_ids = inputs["input_ids"][0].cpu().numpy()
378
+
379
+ self._prepare_kv_caches()
380
+
381
+ outputs = self.model.generate(
382
+ **inputs,
383
+ max_new_tokens=max_new_tokens,
384
+ use_cache=True,
385
+ stop_strings=stop_strings,
386
+ tokenizer=self.tokenizer,
387
+ do_sample=False if temperature == 0.0 else True,
388
+ temperature=temperature,
389
+ top_k=top_k,
390
+ top_p=top_p,
391
+ past_key_values_buckets=self.kv_caches,
392
+ ras_win_len=ras_win_len,
393
+ ras_win_max_num_repeat=ras_win_max_num_repeat,
394
+ )
395
+
396
+ if len(outputs[1]) > 0:
397
+ wv_list = []
398
+ for output_audio in outputs[1]:
399
+ vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
400
+ wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
401
+ wv_list.append(wv_numpy)
402
+ wv_numpy = np.concatenate(wv_list)
403
+ else:
404
+ wv_numpy = None
405
+
406
+ # We only support one request at a time now
407
+ generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :]
408
+ generated_text = self.tokenizer.decode(generated_text_tokens)
409
+ generated_audio_tokens = outputs[1][0].cpu().numpy()
410
+ return HiggsAudioResponse(
411
+ audio=wv_numpy,
412
+ generated_audio_tokens=generated_audio_tokens,
413
+ sampling_rate=self.audio_tokenizer.sampling_rate,
414
+ generated_text=generated_text,
415
+ generated_text_tokens=generated_text_tokens,
416
+ usage={
417
+ "prompt_tokens": prompt_token_ids.shape[0],
418
+ "completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1],
419
+ "total_tokens": (
420
+ prompt_token_ids.shape[0] + generated_text_tokens.shape[0] + generated_audio_tokens.shape[1]
421
+ ),
422
+ "cached_tokens": 0,
423
+ },
424
+ )
higgs_audio/serve/utils.py ADDED
@@ -0,0 +1,254 @@
1
+ import uuid
2
+ import base64
3
+ import re
4
+ import regex
5
+ from typing import AsyncGenerator, Union
6
+ import io
7
+ from pydub import AudioSegment
8
+ import torch
9
+ import numpy as np
10
+ from functools import lru_cache
11
+
12
+ from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer
13
+
14
+
15
+ def random_uuid() -> str:
16
+ return str(uuid.uuid4().hex)
17
+
18
+
19
+ async def async_generator_wrap(first_element, gen: AsyncGenerator):
20
+ """Wrap an async generator with the first element."""
21
+ yield first_element
22
+ async for item in gen:
23
+ yield item
24
+
25
+
26
+ @lru_cache(maxsize=50)
27
+ def encode_base64_content_from_file(file_path: str) -> str:
28
+ """Encode a content from a local file to base64 format."""
29
+ # Read the MP3 file as binary and encode it directly to Base64
30
+ with open(file_path, "rb") as audio_file:
31
+ audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
32
+ return audio_base64
33
+
34
+
35
+ def pcm16_to_target_format(
36
+ np_audio: np.ndarray,
37
+ sample_rate: int,
38
+ bit_depth: int,
39
+ channels: int,
40
+ format: str,
41
+ target_rate: int,
42
+ ):
43
+ wav_audio = AudioSegment(
44
+ np_audio.tobytes(),
45
+ frame_rate=sample_rate,
46
+ sample_width=bit_depth // 8,
47
+ channels=channels,
48
+ )
49
+ if target_rate is not None and target_rate != sample_rate:
50
+ wav_audio = wav_audio.set_frame_rate(target_rate)
51
+
52
+ # Convert WAV to MP3
53
+ target_io = io.BytesIO()
54
+ wav_audio.export(target_io, format=format)
55
+ target_io.seek(0)
56
+
57
+ return target_io
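+
+ # Illustrative call (values are arbitrary): wrap a mono 24 kHz int16 PCM buffer as MP3 bytes,
+ # resampled to 44.1 kHz; the returned BytesIO can be streamed or written to disk.
+ #   mp3_io = pcm16_to_target_format(pcm_int16, sample_rate=24000, bit_depth=16, channels=1,
+ #                                   format="mp3", target_rate=44100)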
58
+
59
+
60
+ chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
61
+
62
+
63
+ def contains_chinese(text: str):
64
+ return bool(chinese_char_pattern.search(text))
65
+
66
+
67
+ # Remove spaces between Chinese characters while keeping single spaces between ASCII words
68
+ def replace_blank(text: str):
69
+ out_str = []
70
+ for i, c in enumerate(text):
71
+ if c == " ":
72
+ if 0 < i < len(text) - 1 and (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
73
+ out_str.append(c)
74
+ else:
75
+ out_str.append(c)
76
+ return "".join(out_str)
77
+
78
+
79
+ def replace_corner_mark(text: str):
80
+ text = text.replace("²", "平方")
81
+ text = text.replace("³", "立方")
82
+ return text
83
+
84
+
85
+ # remove meaningless symbol
86
+ def remove_bracket(text: str):
87
+ text = text.replace("(", "").replace(")", "")
88
+ text = text.replace("【", "").replace("】", "")
89
+ text = text.replace("`", "").replace("`", "")
90
+ text = text.replace("——", " ")
91
+ return text
92
+
93
+
94
+ # Paragraph splitting logic:
+ # 1. Each chunk has at most token_max_n and at least token_min_n tokens; a trailing chunk shorter than merge_len is merged into the previous one.
+ # 2. Chunk length is measured per language (characters for Chinese, tokens otherwise).
+ # 3. Sentences are split on punctuation.
98
+ def split_paragraph(
99
+ text: str,
100
+ tokenize,
101
+ lang="zh",
102
+ token_max_n=80,
103
+ token_min_n=60,
104
+ merge_len=20,
105
+ comma_split=False,
106
+ ):
107
+ def calc_utt_length(_text: str):
108
+ if lang == "zh":
109
+ return len(_text)
110
+ else:
111
+ return len(tokenize(_text))
112
+
113
+ def should_merge(_text: str):
114
+ if lang == "zh":
115
+ return len(_text) < merge_len
116
+ else:
117
+ return len(tokenize(_text)) < merge_len
118
+
119
+ if lang == "zh":
120
+ pounc = ["。", "？", "！", "；", "：", "、", ".", "?", "!", ";"]
121
+ else:
122
+ pounc = [".", "?", "!", ";", ":"]
123
+ if comma_split:
124
+ pounc.extend(["，", ","])
125
+
126
+ if text[-1] not in pounc:
127
+ if lang == "zh":
128
+ text += "。"
129
+ else:
130
+ text += "."
131
+
132
+ st = 0
133
+ utts = []
134
+ for i, c in enumerate(text):
135
+ if c in pounc:
136
+ if len(text[st:i]) > 0:
137
+ utts.append(text[st:i] + c)
138
+ if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
139
+ tmp = utts.pop(-1)
140
+ utts.append(tmp + text[i + 1])
141
+ st = i + 2
142
+ else:
143
+ st = i + 1
144
+
145
+ final_utts = []
146
+ cur_utt = ""
147
+ for utt in utts:
148
+ if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
149
+ final_utts.append(cur_utt)
150
+ cur_utt = ""
151
+ cur_utt = cur_utt + utt
152
+ if len(cur_utt) > 0:
153
+ if should_merge(cur_utt) and len(final_utts) != 0:
154
+ final_utts[-1] = final_utts[-1] + cur_utt
155
+ else:
156
+ final_utts.append(cur_utt)
157
+
158
+ return final_utts
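+
+ # Illustrative call (thresholds are arbitrary): split English text into sentence chunks, using a
+ # whitespace tokenizer to measure length; short inputs may end up in a single chunk.
+ #   chunks = split_paragraph("First sentence. Second one! Third?", str.split, lang="en",
+ #                            token_max_n=80, token_min_n=60, merge_len=20)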
159
+
160
+
161
+ def is_only_punctuation(text: str):
162
+ # Regular expression: Match strings that consist only of punctuation marks or are empty.
163
+ punctuation_pattern = r"^[\p{P}\p{S}]*$"
164
+ return bool(regex.fullmatch(punctuation_pattern, text))
165
+
166
+
167
+ # Spell out Arabic numerals using the provided inflect parser
168
+ def spell_out_number(text: str, inflect_parser):
169
+ new_text = []
170
+ st = None
171
+ for i, c in enumerate(text):
172
+ if not c.isdigit():
173
+ if st is not None:
174
+ num_str = inflect_parser.number_to_words(text[st:i])
175
+ new_text.append(num_str)
176
+ st = None
177
+ new_text.append(c)
178
+ else:
179
+ if st is None:
180
+ st = i
181
+ if st is not None and st < len(text):
182
+ num_str = inflect_parser.number_to_words(text[st:])
183
+ new_text.append(num_str)
184
+ return "".join(new_text)
185
+
186
+
187
+ def remove_emoji(text: str):
188
+ # Pattern to match emojis and their modifiers
189
+ # - Standard emoji range
190
+ # - Zero-width joiners (U+200D)
191
+ # - Variation selectors (U+FE0F, U+FE0E)
192
+ # - Skin tone modifiers (U+1F3FB to U+1F3FF)
193
+ emoji_pattern = re.compile(
194
+ r"["
195
+ r"\U00010000-\U0010FFFF" # Standard emoji range
196
+ r"\u200D" # Zero-width joiner
197
+ r"\uFE0F\uFE0E" # Variation selectors
198
+ r"\U0001F3FB-\U0001F3FF" # Skin tone modifiers
199
+ r"]+",
200
+ flags=re.UNICODE,
201
+ )
202
+ return emoji_pattern.sub(r"", text)
203
+
204
+
205
+ def remove_repeated_punctuations(text, punctuations):
206
+ if len(punctuations) == 0:
207
+ return text
208
+ pattern = f"[{re.escape(''.join(punctuations))}]" # Create regex pattern for given punctuations
209
+ return re.sub(rf"({pattern})\1+", r"\1", text)
210
+
211
+
212
+ def full_to_half_width(text: str) -> str:
213
+ """Convert full-width punctuation to half-width in a given string."""
214
+ full_width = "！＂＃＄％＆＇（）＊＋，－．／：；＜＝＞？＠［＼］＾＿｀｛｜｝～"
215
+ half_width = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
216
+ trans_table = str.maketrans(full_width, half_width)
217
+ return text.translate(trans_table)
218
+
219
+
220
+ def split_interleaved_delayed_audios(
221
+ audio_data: Union[list[list[int]], torch.Tensor],
222
+ audio_tokenizer: HiggsAudioTokenizer,
223
+ audio_stream_eos_id: int,
224
+ ) -> list[tuple[list[list[int]], torch.Tensor]]:
225
+ separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks
226
+
227
+ # Convert the separator to a tensor if audio_data is a torch.Tensor
228
+ if isinstance(audio_data, torch.Tensor):
229
+ audio_data = audio_data.transpose(1, 0)
230
+ separator = torch.tensor(separator)
231
+ # Find the indices where the rows equal the separator
232
+ split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0]
233
+ start = 0
234
+ groups = []
235
+ for idx in split_indices:
236
+ groups.append(audio_data[start:idx].transpose(1, 0))
237
+ start = idx + 1
238
+ if start < len(audio_data):
239
+ groups.append(audio_data[start:].transpose(1, 0))
240
+ else:
241
+ groups = []
242
+ current = []
243
+ for row in audio_data:
244
+ current.append(row)
245
+
246
+ if row == separator:
247
+ groups.append(current)
248
+ current = []
249
+
250
+ # Don't forget the last group if there's no trailing separator
251
+ if current:
252
+ groups.append(current)
253
+
254
+ return groups
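+
+ # Illustrative call (`delayed_codes` is a placeholder for the generated audio token matrix):
+ # a (num_codebooks, T) tensor holding several delayed-pattern clips, separated by positions where
+ # every codebook equals `audio_stream_eos_id`, is split back into per-clip tensors.
+ #   clips = split_interleaved_delayed_audios(delayed_codes, audio_tokenizer, audio_stream_eos_id)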
pyproject.toml ADDED
@@ -0,0 +1,100 @@
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [tool.ruff]
6
+ line-length = 119
7
+ target-version = "py310"
8
+ indent-width = 4
9
+ exclude = [
10
+ ".bzr",
11
+ ".direnv",
12
+ ".eggs",
13
+ ".git",
14
+ ".git-rewrite",
15
+ ".hg",
16
+ ".ipynb_checkpoints",
17
+ ".mypy_cache",
18
+ ".nox",
19
+ ".pants.d",
20
+ ".pyenv",
21
+ ".pytest_cache",
22
+ ".pytype",
23
+ ".ruff_cache",
24
+ ".svn",
25
+ ".tox",
26
+ ".venv",
27
+ ".vscode",
28
+ "__pypackages__",
29
+ "_build",
30
+ "buck-out",
31
+ "build",
32
+ "dist",
33
+ "node_modules",
34
+ "site-packages",
35
+ "venv",
36
+ "external",
37
+ "third_party",
38
+ ]
39
+
40
+ [tool.ruff.lint]
41
+ preview = true
42
+ ignore-init-module-imports = true
43
+ extend-select = [
44
+ "B009", # static getattr
45
+ "B010", # static setattr
46
+ "CPY", # Copyright
47
+ "E", # PEP8 errors
48
+ "F", # PEP8 formatting
49
+ "I", # Import sorting
50
+ "TID251", # Banned API
51
+ "UP", # Pyupgrade
52
+ "W", # PEP8 warnings
53
+ ]
54
+ ignore = [
55
+ "E501", # Line length (handled by ruff-format)
56
+ "E741", # Ambiguous variable name
57
+ "W605", # Invalid escape sequence
58
+ "UP007", # X | Y type annotations
59
+ ]
60
+
61
+ [tool.ruff.lint.per-file-ignores]
62
+ "__init__.py" = [
63
+ "F401", # Ignore seemingly unused imports (they're meant for re-export)
64
+ ]
65
+
66
+ [tool.ruff.lint.isort]
67
+ lines-after-imports = 2
68
+ known-first-party = ["character_tuning"]
69
+
70
+ [tool.ruff.format]
71
+ # Like Black, use double quotes for strings.
72
+ quote-style = "double"
73
+
74
+ # Like Black, indent with spaces, rather than tabs.
75
+ indent-style = "space"
76
+
77
+ # Like Black, respect magic trailing commas.
78
+ skip-magic-trailing-comma = false
79
+
80
+ # Like Black, automatically detect the appropriate line ending.
81
+ line-ending = "auto"
82
+
83
+ # Enable auto-formatting of code examples in docstrings. Markdown,
84
+ # reStructuredText code/literal blocks and doctests are all supported.
85
+ #
86
+ # This is currently disabled by default, but it is planned for this
87
+ # to be opt-out in the future.
88
+ docstring-code-format = false
89
+
90
+ # Set the line length limit used when formatting code snippets in
91
+ # docstrings.
92
+ #
93
+ # This only has an effect when the `docstring-code-format` setting is
94
+ # enabled.
95
+ docstring-code-line-length = "dynamic"
96
+
97
+ [tool.ruff.lint.flake8-tidy-imports.banned-api]
98
+ "os.getenv".msg = "Use os.environ instead"
99
+ "os.putenv".msg = "Use os.environ instead"
100
+ "os.unsetenv".msg = "Use os.environ instead"
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ descript-audio-codec
2
+ torch==2.5.1
3
+ torchaudio==2.5.1
4
+ transformers>=4.45.1,<4.47.0
5
+ librosa
6
+ dacite
7
+ boto3==1.35.36
8
+ s3fs
9
+ json_repair
10
+ pandas
11
+ pydantic
12
+ vector_quantize_pytorch
13
+ loguru
14
+ pydub
15
+ ruff==0.12.2
16
+ omegaconf
17
+ click