Spaces: Running on Zero
Upload tts playground and serving engine
- .gitignore +10 -0
- README.md +1 -1
- app.py +528 -4
- higgs_audio/__init__.py +1 -0
- higgs_audio/audio_processing/LICENSE +51 -0
- higgs_audio/audio_processing/descriptaudiocodec/__init__.py +0 -0
- higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py +286 -0
- higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py +365 -0
- higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py +33 -0
- higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py +251 -0
- higgs_audio/audio_processing/higgs_audio_tokenizer.py +341 -0
- higgs_audio/audio_processing/quantization/__init__.py +8 -0
- higgs_audio/audio_processing/quantization/ac.py +301 -0
- higgs_audio/audio_processing/quantization/core_vq.py +360 -0
- higgs_audio/audio_processing/quantization/core_vq_lsx_version.py +431 -0
- higgs_audio/audio_processing/quantization/ddp_utils.py +197 -0
- higgs_audio/audio_processing/quantization/distrib.py +123 -0
- higgs_audio/audio_processing/quantization/vq.py +116 -0
- higgs_audio/audio_processing/semantic_module.py +310 -0
- higgs_audio/constants.py +3 -0
- higgs_audio/data_collator/__init__.py +0 -0
- higgs_audio/data_collator/higgs_audio_collator.py +583 -0
- higgs_audio/data_types.py +38 -0
- higgs_audio/dataset/__init__.py +0 -0
- higgs_audio/dataset/chatml_dataset.py +554 -0
- higgs_audio/model/__init__.py +9 -0
- higgs_audio/model/audio_head.py +139 -0
- higgs_audio/model/common.py +27 -0
- higgs_audio/model/configuration_higgs_audio.py +235 -0
- higgs_audio/model/cuda_graph_runner.py +129 -0
- higgs_audio/model/custom_modules.py +155 -0
- higgs_audio/model/modeling_higgs_audio.py +0 -0
- higgs_audio/model/utils.py +778 -0
- higgs_audio/serve/serve_engine.py +424 -0
- higgs_audio/serve/utils.py +254 -0
- pyproject.toml +100 -0
- requirements.txt +17 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+*.pyw
+*.pyz
+*.pywz
+*.pyzw
+*.pyzwz
+.ruff_cache/
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Higgs Audio Demo
-emoji:
+emoji: 🎤
 colorFrom: green
 colorTo: purple
 sdk: gradio
app.py CHANGED
@@ -1,7 +1,531 @@
+"""
+Gradio UI for Text-to-Speech using HiggsAudioServeEngine
+"""
+
+import argparse
+import base64
+import os
+import uuid
+import json
+from typing import Optional
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-
-
+from loguru import logger
+import numpy as np
+import time
+from functools import lru_cache
+import re
+import spaces
+
+
+# Import HiggsAudio components
+from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
+from higgs_audio.data_types import ChatMLSample, AudioContent, Message
+
+# Global engine instance
+engine = None
+
+# Set up default paths and resources
+EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "examples")
+os.makedirs(EXAMPLES_DIR, exist_ok=True)
+
+# Default model configuration
+DEFAULT_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-staging"
+DEFAULT_AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer-staging"
+SAMPLE_RATE = 24000
+
+DEFAULT_SYSTEM_PROMPT = (
+    "Generate audio following instruction.\n\n"
+    "<|scene_desc_start|>\n"
+    "Audio is recorded from a quiet room.\n"
+    "<|scene_desc_end|>"
+)
+
+DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]
+
+# Predefined examples for system and input messages
+PREDEFINED_EXAMPLES = {
+    "None": {"system_prompt": "", "input_text": "", "description": "Default example"},
+    "multispeaker-interleave": {
+        "system_prompt": "Generate audio following instruction.\n\n"
+        "<|scene_desc_start|>\n"
+        "SPEAKER0: vocal fry;feminism;slightly fast\n"
+        "SPEAKER1: masculine;moderate;moderate pitch;monotone;mature\n"
+        "In this scene, a group of adventurers is debating whether to investigate a potentially dangerous situation.\n"
+        "<|scene_desc_end|>",
+        "input_text": "<|generation_instruction_start|>\nGenerate interleaved transcript and audio that lasts for around 10 seconds.\n<|generation_instruction_end|>",
+        "description": "Multispeaker interleave example",
+    },
+    "single-speaker": {
+        "system_prompt": "Generate audio following instruction.\n\n"
+        "<|scene_desc_start|>\n"
+        "SPEAKER0: british accent\n"
+        "<|scene_desc_end|>",
+        "input_text": "Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
+        "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
+        "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.\n"
+        "\n"
+        "So here's the big question: Do you want to understand how deep learning works?\n"
+        "How to use it to build powerful models that can predict, automate, and transform industries?\n"
+        "Well, today, I've got some exciting news for you.\n"
+        "\n"
+        "We're going to talk about a course that I highly recommend: Dive into Deep Learning.\n"
+        "It's not just another course; it's an entire experience that will take you from a beginner to someone who is well-versed in deep learning techniques.",
+        "description": "Single speaker example",
+    },
+    "single-speaker-zh": {
+        "system_prompt": "Generate audio following instruction.\n\n"
+        "<|scene_desc_start|>\n"
+        "\nAudio is recorded from a quiet room.\n"
+        "\nSPEAKER0: feminine\n"
+        "<|scene_desc_end|>",
+        "input_text": "大家好, 欢迎收听本期的跟李沐学AI. 今天沐哥在忙着洗数据, 所以由我, 希格斯主播代替他讲这期视频.\n"
+        "今天我们要聊的是一个你绝对不能忽视的话题: 多模态学习.\n"
+        "无论你是开发者, 数据科学爱好者, 还是只是对人工智能感兴趣的人都一定听说过这个词. 它已经成为AI时代的一个研究热点.\n"
+        "那么, 问题来了, 你真的了解多模态吗? 你知道如何自己动手构建多模态大模型吗.\n"
+        "或者说, 你能察觉到我其实是个机器人吗?",
+        "description": "Single speaker with Chinese text",
+    },
+}
+
+
+@lru_cache(maxsize=20)
+def encode_audio_file(file_path):
+    """Encode an audio file to base64."""
+    with open(file_path, "rb") as audio_file:
+        return base64.b64encode(audio_file.read()).decode("utf-8")
+
+
+def load_voice_presets():
+    """Load the voice presets from the voice_examples directory."""
+    try:
+        with open(
+            os.path.join(os.path.dirname(__file__), "voice_examples", "config.json"),
+            "r",
+        ) as f:
+            voice_dict = json.load(f)
+        voice_presets = {k: v["transcript"] for k, v in voice_dict.items()}
+        voice_presets["EMPTY"] = "No reference voice"
+        logger.info(f"Loaded voice presets: {list(voice_presets.keys())}")
+        return voice_presets
+    except FileNotFoundError:
+        logger.warning("Voice examples config file not found. Using empty voice presets.")
+        return {"EMPTY": "No reference voice"}
+    except Exception as e:
+        logger.error(f"Error loading voice presets: {e}")
+        return {"EMPTY": "No reference voice"}
+
+
+def get_voice_present(voice_preset):
+    """Get the voice path and text for a given voice preset."""
+    voice_path = os.path.join(os.path.dirname(__file__), "voice_examples", f"{voice_preset}.wav")
+    if not os.path.exists(voice_path):
+        logger.warning(f"Voice preset file not found: {voice_path}")
+        return None, "Voice preset not found"
+
+    text = VOICE_PRESETS.get(voice_preset, "No transcript available")
+    return voice_path, text
+
+
+@spaces.GPU
+def initialize_engine(model_path, audio_tokenizer_path, device="cuda") -> bool:
+    """Initialize the HiggsAudioServeEngine."""
+    global engine
+    try:
+        engine = HiggsAudioServeEngine(
+            model_name_or_path=model_path,
+            audio_tokenizer_name_or_path=audio_tokenizer_path,
+            device=device,
+        )
+        logger.info(f"Successfully initialized HiggsAudioServeEngine with model: {model_path}")
+        return True
+    except Exception as e:
+        logger.error(f"Failed to initialize engine: {e}")
+        return False
+
+
+def check_return_audio(audio_wv: np.ndarray):
+    # check if the audio returned is all silent
+    if np.all(audio_wv == 0):
+        logger.warning("Audio is silent, returning None")
+
+
+def process_text_output(text_output: str):
+    # remove all the continuous <|AUDIO_OUT|> tokens with a single <|AUDIO_OUT|>
+    text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
+    return text_output
+
+
+def prepare_chatml_sample(
+    voice_present: str,
+    text: str,
+    reference_audio: Optional[str] = None,
+    reference_text: Optional[str] = None,
+    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+):
+    """Prepare a ChatMLSample for the HiggsAudioServeEngine."""
+    messages = []
+
+    # Add system message if provided
+    if len(system_prompt) > 0:
+        messages.append(Message(role="system", content=system_prompt))
+
+    # Add reference audio if provided
+    audio_base64 = None
+    ref_text = ""
+
+    if reference_audio:
+        # Custom reference audio
+        audio_base64 = encode_audio_file(reference_audio)
+        ref_text = reference_text or ""
+    elif voice_present != "EMPTY":
+        # Voice preset
+        voice_path, ref_text = get_voice_present(voice_present)
+        if voice_path is None:
+            logger.warning(f"Voice preset {voice_present} not found, skipping reference audio")
+        else:
+            audio_base64 = encode_audio_file(voice_path)
+
+    # Only add reference audio if we have it
+    if audio_base64 is not None:
+        # Add user message with reference text
+        messages.append(Message(role="user", content=ref_text))
+
+        # Add assistant message with audio content
+        audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
+        messages.append(Message(role="assistant", content=[audio_content]))
+
+    # Add the main user message
+    messages.append(Message(role="user", content=text))
+
+    return ChatMLSample(messages=messages)
+
+
+@spaces.GPU(duration=500)
+def text_to_speech(
+    text,
+    voice_preset,
+    reference_audio=None,
+    reference_text=None,
+    max_completion_tokens=1024,
+    temperature=1.0,
+    top_p=0.95,
+    top_k=50,
+    system_prompt=DEFAULT_SYSTEM_PROMPT,
+    stop_strings=None,
+):
+    """Convert text to speech using HiggsAudioServeEngine."""
+    global engine
+
+    if engine is None:
+        error_msg = "Engine not initialized. Please load a model first."
+        logger.error(error_msg)
+        gr.Error(error_msg)
+        return f"❌ {error_msg}", None
+
+    try:
+        # Prepare ChatML sample
+        chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt)
+
+        # Convert stop strings format
+        if stop_strings is None:
+            stop_list = DEFAULT_STOP_STRINGS
+        else:
+            stop_list = [s for s in stop_strings["stops"] if s.strip()]
+
+        request_id = f"tts-playground-{str(uuid.uuid4())}"
+        logger.info(
+            f"{request_id}: Generating speech for text: {text[:100]}..., \n"
+            f"with parameters: temperature={temperature}, top_p={top_p}, top_k={top_k}, stop_list={stop_list}"
+        )
+        start_time = time.time()
+
+        # Generate using the engine
+        response = engine.generate(
+            chat_ml_sample=chatml_sample,
+            max_new_tokens=max_completion_tokens,
+            temperature=temperature,
+            top_k=top_k if top_k > 0 else None,
+            top_p=top_p,
+            stop_strings=stop_list,
+        )
+
+        generation_time = time.time() - start_time
+        logger.info(f"{request_id}: Generated audio in {generation_time:.3f} seconds")
+        gr.Info(f"Generated audio in {generation_time:.3f} seconds")
+
+        # Process the response
+        text_output = process_text_output(response.generated_text)
+
+        if response.audio is not None:
+            # Convert to int16 for Gradio
+            audio_data = (response.audio * 32767).astype(np.int16)
+            check_return_audio(audio_data)
+            return text_output, (response.sampling_rate, audio_data)
+        else:
+            logger.warning("No audio generated")
+            return text_output, None
+
+    except Exception as e:
+        error_msg = f"Error generating speech: {e}"
+        logger.error(error_msg)
+        gr.Error(error_msg)
+        return f"❌ {error_msg}", None
+
+
+def create_ui():
+    my_theme = "JohnSmith9982/small_and_pretty"
+
+    # Add custom CSS to disable focus highlighting on textboxes
+    custom_css = """
+    .gradio-container input:focus,
+    .gradio-container textarea:focus,
+    .gradio-container select:focus,
+    .gradio-container .gr-input:focus,
+    .gradio-container .gr-textarea:focus,
+    .gradio-container .gr-textbox:focus,
+    .gradio-container .gr-textbox:focus-within,
+    .gradio-container .gr-form:focus-within,
+    .gradio-container *:focus {
+        box-shadow: none !important;
+        border-color: var(--border-color-primary) !important;
+        outline: none !important;
+        background-color: var(--input-background-fill) !important;
+    }
+
+    /* Override any hover effects as well */
+    .gradio-container input:hover,
+    .gradio-container textarea:hover,
+    .gradio-container select:hover,
+    .gradio-container .gr-input:hover,
+    .gradio-container .gr-textarea:hover,
+    .gradio-container .gr-textbox:hover {
+        border-color: var(--border-color-primary) !important;
+        background-color: var(--input-background-fill) !important;
+    }
+
+    /* Style for checked checkbox */
+    .gradio-container input[type="checkbox"]:checked {
+        background-color: var(--primary-500) !important;
+        border-color: var(--primary-500) !important;
+    }
+    """
+
+    """Create the Gradio UI."""
+    with gr.Blocks(theme=my_theme, css=custom_css) as demo:
+        gr.Markdown("# Higgs Audio Text-to-Speech Playground")
+
+        # Main UI section
+        with gr.Row():
+            with gr.Column(scale=2):
+                # Template selection dropdown
+                template_dropdown = gr.Dropdown(
+                    label="Message examples",
+                    choices=list(PREDEFINED_EXAMPLES.keys()),
+                    value="None",
+                    info="Select a predefined example for system and input messages. Voice preset will be set to EMPTY when a example is selected.",
+                )
+
+                system_prompt = gr.TextArea(
+                    label="System Prompt",
+                    placeholder="Enter system prompt to guide the model...",
+                    value=DEFAULT_SYSTEM_PROMPT,
+                    lines=2,
+                )
+
+                input_text = gr.TextArea(
+                    label="Input Text",
+                    placeholder="Type the text you want to convert to speech...",
+                    lines=5,
+                )
+
+                voice_preset = gr.Dropdown(
+                    label="Voice Preset",
+                    choices=list(VOICE_PRESETS.keys()),
+                    value="EMPTY",
+                )
+
+                with gr.Accordion("Custom Reference (Optional)", open=False):
+                    reference_audio = gr.Audio(label="Reference Audio", type="filepath")
+                    reference_text = gr.TextArea(
+                        label="Reference Text (transcript of the reference audio)",
+                        placeholder="Enter the transcript of your reference audio...",
+                        lines=3,
+                    )
+
+                with gr.Accordion("Advanced Parameters", open=False):
+                    max_completion_tokens = gr.Slider(
+                        minimum=128,
+                        maximum=4096,
+                        value=1024,
+                        step=10,
+                        label="Max Completion Tokens",
+                    )
+                    temperature = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.5,
+                        value=1.0,
+                        step=0.1,
+                        label="Temperature",
+                    )
+                    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P")
+                    top_k = gr.Slider(minimum=-1, maximum=100, value=50, step=1, label="Top K")
+                    # Add stop strings component
+                    stop_strings = gr.Dataframe(
+                        label="Stop Strings",
+                        headers=["stops"],
+                        datatype=["str"],
+                        value=[[s] for s in DEFAULT_STOP_STRINGS],
+                        interactive=True,
+                        col_count=(1, "fixed"),
+                    )
+
+                submit_btn = gr.Button("Generate Speech", variant="primary", scale=1)
+
+            with gr.Column(scale=2):
+                output_text = gr.TextArea(label="Model Response", lines=2)
+
+                # Audio output
+                output_audio = gr.Audio(label="Generated Audio", interactive=False, autoplay=True)
+
+                stop_btn = gr.Button("Stop Playback", variant="primary")
+
+                # Example voice
+                with gr.Row():
+                    voice_samples_table = gr.Dataframe(
+                        headers=["Voice Preset", "Sample Text"],
+                        datatype=["str", "str"],
+                        value=[[preset, text] for preset, text in VOICE_PRESETS.items() if preset != "EMPTY"],
+                        interactive=False,
+                    )
+                    sample_audio = gr.Audio(label="Voice Sample", visible=True)
+
+                # Function to play voice sample when clicking on a row
+                def play_voice_sample(evt: gr.SelectData):
+                    try:
+                        # Get the preset name from the clicked row
+                        preset_names = [preset for preset in VOICE_PRESETS.keys() if preset != "EMPTY"]
+                        if evt.index[0] < len(preset_names):
+                            preset = preset_names[evt.index[0]]
+                            voice_path, _ = get_voice_present(preset)
+                            if voice_path and os.path.exists(voice_path):
+                                return voice_path
+                            else:
+                                gr.Warning(f"Voice sample file not found for preset: {preset}")
+                                return None
+                        else:
+                            gr.Warning("Invalid voice preset selection")
+                            return None
+                    except Exception as e:
+                        logger.error(f"Error playing voice sample: {e}")
+                        gr.Error(f"Error playing voice sample: {e}")
+                        return None
+
+                voice_samples_table.select(fn=play_voice_sample, outputs=[sample_audio])
+
+        # Function to handle template selection
+        def apply_template(template_name):
+            if template_name in PREDEFINED_EXAMPLES:
+                template = PREDEFINED_EXAMPLES[template_name]
+                return (
+                    template["system_prompt"],  # system_prompt
+                    template["input_text"],  # input_text
+                    "EMPTY",  # voice_preset (always set to EMPTY for examples)
+                )
+            else:
+                return (
+                    gr.update(),
+                    gr.update(),
+                    gr.update(),
+                )  # No change if template not found
+
+        # Set up event handlers
+
+        # Connect template dropdown to handler
+        template_dropdown.change(
+            fn=apply_template,
+            inputs=[template_dropdown],
+            outputs=[system_prompt, input_text, voice_preset],
+        )
+
+        # Connect submit button to the TTS function
+        submit_btn.click(
+            fn=text_to_speech,
+            inputs=[
+                input_text,
+                voice_preset,
+                reference_audio,
+                reference_text,
+                max_completion_tokens,
+                temperature,
+                top_p,
+                top_k,
+                system_prompt,
+                stop_strings,
+            ],
+            outputs=[output_text, output_audio],
+            api_name="generate_speech",
+        )
+
+        # Stop button functionality
+        stop_btn.click(
+            fn=lambda: None,
+            inputs=[],
+            outputs=[output_audio],
+            js="() => {const audio = document.querySelector('audio'); if(audio) audio.pause(); return null;}",
+        )
+
+    return demo
+
+
+def main():
+    """Main function to parse arguments and launch the UI."""
+    global DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH, VOICE_PRESETS
+
+    parser = argparse.ArgumentParser(description="Gradio UI for Text-to-Speech using HiggsAudioServeEngine")
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default=DEFAULT_MODEL_PATH,
+        help="Path to the Higgs Audio model.",
+    )
+    parser.add_argument(
+        "--audio-tokenizer-path",
+        type=str,
+        default=DEFAULT_AUDIO_TOKENIZER_PATH,
+        help="Path to the audio tokenizer.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda", "cpu"],
+        help="Device to run the model on.",
+    )
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host for the Gradio interface.")
+    parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio interface.")
+
+    args = parser.parse_args()
+
+    # Update default values if provided via command line
+    DEFAULT_MODEL_PATH = args.model_path
+    DEFAULT_AUDIO_TOKENIZER_PATH = args.audio_tokenizer_path
+    VOICE_PRESETS = load_voice_presets()
+
+    # Load model on startup
+    logger.info("Loading model...")
+    result = initialize_engine(args.model_path, args.audio_tokenizer_path, args.device)
+
+    # Exit if model loading failed
+    if not result:
+        logger.error("Failed to load model. Exiting.")
+        return
+
+    logger.info(f"Model loaded: {DEFAULT_MODEL_PATH}")
+
+    # Create and launch the UI
+    demo = create_ui()
+    demo.launch(server_name=args.host, server_port=args.port)
+
+
+if __name__ == "__main__":
+    main()
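For context, a minimal sketch of driving the serve engine outside the Gradio UI, using only the calls that app.py above makes (HiggsAudioServeEngine, ChatMLSample, Message, engine.generate). The model and tokenizer IDs are the defaults from this commit; the int16 conversion at the end mirrors text_to_speech rather than any documented API, so treat it as an assumption.

import numpy as np

from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
from higgs_audio.data_types import ChatMLSample, Message

# Same defaults as app.py in this commit.
engine = HiggsAudioServeEngine(
    model_name_or_path="bosonai/higgs-audio-v2-generation-3B-staging",
    audio_tokenizer_name_or_path="bosonai/higgs-audio-v2-tokenizer-staging",
    device="cuda",
)

sample = ChatMLSample(
    messages=[
        Message(role="system", content="Generate audio following instruction."),
        Message(role="user", content="Hello from the Higgs Audio TTS playground."),
    ]
)

response = engine.generate(
    chat_ml_sample=sample,
    max_new_tokens=1024,
    temperature=1.0,
    top_p=0.95,
    top_k=50,
    stop_strings=["<|end_of_text|>", "<|eot_id|>"],
)

print(response.generated_text)
if response.audio is not None:
    # app.py converts the float waveform to int16 before handing it to Gradio.
    pcm = (response.audio * 32767).astype(np.int16)
    print(f"{pcm.shape[0] / response.sampling_rate:.2f} s of audio at {response.sampling_rate} Hz")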
higgs_audio/__init__.py ADDED
@@ -0,0 +1 @@
+from .model import HiggsAudioConfig, HiggsAudioModel
higgs_audio/audio_processing/LICENSE ADDED
@@ -0,0 +1,51 @@
+Third-Party License Attribution for Audio Processing Module
+===========================================================
+
+This directory contains code derived from multiple open-source projects.
+The following sections detail the licenses and attributions for third-party code.
+
+## XCodec Repository
+The code in this directory is derived from:
+https://github.com/zhenye234/xcodec
+
+## Individual File Attributions
+
+### Quantization Module (quantization/)
+- Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository
+- Individual files contain their own license headers where applicable
+- The vector-quantize-pytorch portions are licensed under the MIT License
+
+## License Terms
+
+### MIT License (for applicable portions)
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+## Attribution Requirements
+When using this code, please ensure proper attribution to:
+1. The original xcodec repository: https://github.com/zhenye234/xcodec
+2. Any other repositories mentioned in individual file headers
+3. This derivative work and its modifications
+
+## Disclaimer
+This directory contains modified versions of the original code. Please refer to
+the original repositories for the canonical implementations and their specific
+license terms.
+
+For any questions about licensing or attribution, please check the individual
+file headers and the original source repositories.
higgs_audio/audio_processing/descriptaudiocodec/__init__.py ADDED (empty file)
higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py ADDED
@@ -0,0 +1,286 @@
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import torch
+import tqdm
+from audiotools import AudioSignal
+from torch import nn
+
+SUPPORTED_VERSIONS = ["1.0.0"]
+
+
+@dataclass
+class DACFile:
+    codes: torch.Tensor
+
+    # Metadata
+    chunk_length: int
+    original_length: int
+    input_db: float
+    channels: int
+    sample_rate: int
+    padding: bool
+    dac_version: str
+
+    def save(self, path):
+        artifacts = {
+            "codes": self.codes.numpy().astype(np.uint16),
+            "metadata": {
+                "input_db": self.input_db.numpy().astype(np.float32),
+                "original_length": self.original_length,
+                "sample_rate": self.sample_rate,
+                "chunk_length": self.chunk_length,
+                "channels": self.channels,
+                "padding": self.padding,
+                "dac_version": SUPPORTED_VERSIONS[-1],
+            },
+        }
+        path = Path(path).with_suffix(".dac")
+        with open(path, "wb") as f:
+            np.save(f, artifacts)
+        return path
+
+    @classmethod
+    def load(cls, path):
+        artifacts = np.load(path, allow_pickle=True)[()]
+        codes = torch.from_numpy(artifacts["codes"].astype(int))
+        if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
+            raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.")
+        return cls(codes=codes, **artifacts["metadata"])
+
+
+class CodecMixin:
+    @property
+    def padding(self):
+        if not hasattr(self, "_padding"):
+            self._padding = True
+        return self._padding
+
+    @padding.setter
+    def padding(self, value):
+        assert isinstance(value, bool)
+
+        layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))]
+
+        for layer in layers:
+            if value:
+                if hasattr(layer, "original_padding"):
+                    layer.padding = layer.original_padding
+            else:
+                layer.original_padding = layer.padding
+                layer.padding = tuple(0 for _ in range(len(layer.padding)))
+
+        self._padding = value
+
+    def get_delay(self):
+        # Any number works here, delay is invariant to input length
+        l_out = self.get_output_length(0)
+        L = l_out
+
+        layers = []
+        for layer in self.modules():
+            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+                layers.append(layer)
+
+        for layer in reversed(layers):
+            d = layer.dilation[0]
+            k = layer.kernel_size[0]
+            s = layer.stride[0]
+
+            if isinstance(layer, nn.ConvTranspose1d):
+                L = ((L - d * (k - 1) - 1) / s) + 1
+            elif isinstance(layer, nn.Conv1d):
+                L = (L - 1) * s + d * (k - 1) + 1
+
+            L = math.ceil(L)
+
+        l_in = L
+
+        return (l_in - l_out) // 2
+
+    def get_output_length(self, input_length):
+        L = input_length
+        # Calculate output length
+        for layer in self.modules():
+            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+                d = layer.dilation[0]
+                k = layer.kernel_size[0]
+                s = layer.stride[0]
+
+                if isinstance(layer, nn.Conv1d):
+                    L = ((L - d * (k - 1) - 1) / s) + 1
+                elif isinstance(layer, nn.ConvTranspose1d):
+                    L = (L - 1) * s + d * (k - 1) + 1
+
+                L = math.floor(L)
+        return L
+
+    @torch.no_grad()
+    def compress(
+        self,
+        audio_path_or_signal: Union[str, Path, AudioSignal],
+        win_duration: float = 1.0,
+        verbose: bool = False,
+        normalize_db: float = -16,
+        n_quantizers: int = None,
+    ) -> DACFile:
+        """Processes an audio signal from a file or AudioSignal object into
+        discrete codes. This function processes the signal in short windows,
+        using constant GPU memory.
+
+        Parameters
+        ----------
+        audio_path_or_signal : Union[str, Path, AudioSignal]
+            audio signal to reconstruct
+        win_duration : float, optional
+            window duration in seconds, by default 5.0
+        verbose : bool, optional
+            by default False
+        normalize_db : float, optional
+            normalize db, by default -16
+
+        Returns
+        -------
+        DACFile
+            Object containing compressed codes and metadata
+            required for decompression
+        """
+        audio_signal = audio_path_or_signal
+        if isinstance(audio_signal, (str, Path)):
+            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
+
+        self.eval()
+        original_padding = self.padding
+        original_device = audio_signal.device
+
+        audio_signal = audio_signal.clone()
+        original_sr = audio_signal.sample_rate
+
+        resample_fn = audio_signal.resample
+        loudness_fn = audio_signal.loudness
+
+        # If audio is > 10 minutes long, use the ffmpeg versions
+        if audio_signal.signal_duration >= 10 * 60 * 60:
+            resample_fn = audio_signal.ffmpeg_resample
+            loudness_fn = audio_signal.ffmpeg_loudness
+
+        original_length = audio_signal.signal_length
+        resample_fn(self.sample_rate)
+        input_db = loudness_fn()
+
+        if normalize_db is not None:
+            audio_signal.normalize(normalize_db)
+        audio_signal.ensure_max_of_audio()
+
+        nb, nac, nt = audio_signal.audio_data.shape
+        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
+        win_duration = audio_signal.signal_duration if win_duration is None else win_duration
+
+        if audio_signal.signal_duration <= win_duration:
+            # Unchunked compression (used if signal length < win duration)
+            self.padding = True
+            n_samples = nt
+            hop = nt
+        else:
+            # Chunked inference
+            self.padding = False
+            # Zero-pad signal on either side by the delay
+            audio_signal.zero_pad(self.delay, self.delay)
+            n_samples = int(win_duration * self.sample_rate)
+            # Round n_samples to nearest hop length multiple
+            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
+            hop = self.get_output_length(n_samples)
+
+        codes = []
+        range_fn = range if not verbose else tqdm.trange
+
+        for i in range_fn(0, nt, hop):
+            x = audio_signal[..., i : i + n_samples]
+            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
+
+            audio_data = x.audio_data.to(self.device)
+            audio_data = self.preprocess(audio_data, self.sample_rate)
+            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
+            codes.append(c.to(original_device))
+            chunk_length = c.shape[-1]
+
+        codes = torch.cat(codes, dim=-1)
+
+        dac_file = DACFile(
+            codes=codes,
+            chunk_length=chunk_length,
+            original_length=original_length,
+            input_db=input_db,
+            channels=nac,
+            sample_rate=original_sr,
+            padding=self.padding,
+            dac_version=SUPPORTED_VERSIONS[-1],
+        )
+
+        if n_quantizers is not None:
+            codes = codes[:, :n_quantizers, :]
+
+        self.padding = original_padding
+        return dac_file
+
+    @torch.no_grad()
+    def decompress(
+        self,
+        obj: Union[str, Path, DACFile],
+        verbose: bool = False,
+    ) -> AudioSignal:
+        """Reconstruct audio from a given .dac file
+
+        Parameters
+        ----------
+        obj : Union[str, Path, DACFile]
+            .dac file location or corresponding DACFile object.
+        verbose : bool, optional
+            Prints progress if True, by default False
+
+        Returns
+        -------
+        AudioSignal
+            Object with the reconstructed audio
+        """
+        self.eval()
+        if isinstance(obj, (str, Path)):
+            obj = DACFile.load(obj)
+
+        original_padding = self.padding
+        self.padding = obj.padding
+
+        range_fn = range if not verbose else tqdm.trange
+        codes = obj.codes
+        original_device = codes.device
+        chunk_length = obj.chunk_length
+        recons = []
+
+        for i in range_fn(0, codes.shape[-1], chunk_length):
+            c = codes[..., i : i + chunk_length].to(self.device)
+            z = self.quantizer.from_codes(c)[0]
+            r = self.decode(z)
+            recons.append(r.to(original_device))
+
+        recons = torch.cat(recons, dim=-1)
+        recons = AudioSignal(recons, self.sample_rate)
+
+        resample_fn = recons.resample
+        loudness_fn = recons.loudness
+
+        # If audio is > 10 minutes long, use the ffmpeg versions
+        if recons.signal_duration >= 10 * 60 * 60:
+            resample_fn = recons.ffmpeg_resample
+            loudness_fn = recons.ffmpeg_loudness
+
+        recons.normalize(obj.input_db)
+        resample_fn(obj.sample_rate)
+        recons = recons[..., : obj.original_length]
+        loudness_fn()
+        recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
+
+        self.padding = original_padding
+        return recons
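A brief usage sketch for the mixin above: compress chunks a file into a DACFile, decompress reconstructs an AudioSignal. It assumes the DAC model defined in the next file, that the vendored descript code is importable as the `dac` package (as dac.py's own imports suggest), and that audiotools' AudioSignal.write is available; the file paths are placeholders.

import torch

from dac.model.dac import DAC  # import path assumed; dac.py itself imports from the `dac` package

model = DAC()
if torch.cuda.is_available():
    model = model.to("cuda")

# Chunked, constant-memory encoding of a file into discrete codes plus metadata.
dac_file = model.compress("speech.wav", win_duration=5.0, verbose=True)
dac_file.save("speech.dac")        # DACFile.save writes a .dac numpy archive

# Round trip: reload the codes and reconstruct audio at the original sample rate.
recons = model.decompress("speech.dac", verbose=True)
recons.write("speech_recon.wav")   # AudioSignal.write, assumed from audiotools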
higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py ADDED
@@ -0,0 +1,365 @@
+import math
+from typing import List
+from typing import Union
+
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from audiotools.ml import BaseModel
+from torch import nn
+
+from .base import CodecMixin
+from dac.nn.layers import Snake1d
+from dac.nn.layers import WNConv1d
+from dac.nn.layers import WNConvTranspose1d
+from dac.nn.quantize import ResidualVectorQuantize
+
+
+def init_weights(m):
+    if isinstance(m, nn.Conv1d):
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        nn.init.constant_(m.bias, 0)
+
+
+class ResidualUnit(nn.Module):
+    def __init__(self, dim: int = 16, dilation: int = 1):
+        super().__init__()
+        pad = ((7 - 1) * dilation) // 2
+        self.block = nn.Sequential(
+            Snake1d(dim),
+            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+            Snake1d(dim),
+            WNConv1d(dim, dim, kernel_size=1),
+        )
+
+    def forward(self, x):
+        y = self.block(x)
+        pad = (x.shape[-1] - y.shape[-1]) // 2
+        if pad > 0:
+            x = x[..., pad:-pad]
+        return x + y
+
+
+class EncoderBlock(nn.Module):
+    def __init__(self, dim: int = 16, stride: int = 1):
+        super().__init__()
+        self.block = nn.Sequential(
+            ResidualUnit(dim // 2, dilation=1),
+            ResidualUnit(dim // 2, dilation=3),
+            ResidualUnit(dim // 2, dilation=9),
+            Snake1d(dim // 2),
+            WNConv1d(
+                dim // 2,
+                dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+            ),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        d_model: int = 64,
+        strides: list = [2, 4, 8, 8],
+        d_latent: int = 256,
+    ):
+        super().__init__()
+        # Create first convolution
+        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+        # Create EncoderBlocks that double channels as they downsample by `stride`
+        for stride in strides:
+            d_model *= 2
+            self.block += [EncoderBlock(d_model, stride=stride)]
+
+        # Create last convolution
+        self.block += [
+            Snake1d(d_model),
+            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
+        ]
+
+        # Wrap black into nn.Sequential
+        self.block = nn.Sequential(*self.block)
+        self.enc_dim = d_model
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0):
+        super().__init__()
+        self.block = nn.Sequential(
+            Snake1d(input_dim),
+            WNConvTranspose1d(
+                input_dim,
+                output_dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+                output_padding=stride % 2,  # out_pad,
+            ),
+            ResidualUnit(output_dim, dilation=1),
+            ResidualUnit(output_dim, dilation=3),
+            ResidualUnit(output_dim, dilation=9),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        input_channel,
+        channels,
+        rates,
+        d_out: int = 1,
+    ):
+        super().__init__()
+
+        # Add first conv layer
+        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
+
+        # Add upsampling + MRF blocks
+        for i, stride in enumerate(rates):
+            input_dim = channels // 2**i
+            output_dim = channels // 2 ** (i + 1)
+            if i == 1:
+                out_pad = 1
+            else:
+                out_pad = 0
+            layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)]
+
+        # Add final conv layer
+        layers += [
+            Snake1d(output_dim),
+            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
+            # nn.Tanh(),
+        ]
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class DAC(BaseModel, CodecMixin):
+    def __init__(
+        self,
+        encoder_dim: int = 64,
+        encoder_rates: List[int] = [2, 4, 8, 8],
+        latent_dim: int = None,
+        decoder_dim: int = 1536,
+        decoder_rates: List[int] = [8, 8, 4, 2],
+        n_codebooks: int = 9,
+        codebook_size: int = 1024,
+        codebook_dim: Union[int, list] = 8,
+        quantizer_dropout: bool = False,
+        sample_rate: int = 44100,
+    ):
+        super().__init__()
+
+        self.encoder_dim = encoder_dim
+        self.encoder_rates = encoder_rates
+        self.decoder_dim = decoder_dim
+        self.decoder_rates = decoder_rates
+        self.sample_rate = sample_rate
+
+        if latent_dim is None:
+            latent_dim = encoder_dim * (2 ** len(encoder_rates))
+
+        self.latent_dim = latent_dim
+
+        self.hop_length = np.prod(encoder_rates)
+        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
+
+        self.n_codebooks = n_codebooks
+        self.codebook_size = codebook_size
+        self.codebook_dim = codebook_dim
+        self.quantizer = ResidualVectorQuantize(
+            input_dim=latent_dim,
+            n_codebooks=n_codebooks,
+            codebook_size=codebook_size,
+            codebook_dim=codebook_dim,
+            quantizer_dropout=quantizer_dropout,
+        )
+
+        self.decoder = Decoder(
+            latent_dim,
+            decoder_dim,
+            decoder_rates,
+        )
+        self.sample_rate = sample_rate
+        self.apply(init_weights)
+
+        self.delay = self.get_delay()
+
+    def preprocess(self, audio_data, sample_rate):
+        if sample_rate is None:
+            sample_rate = self.sample_rate
+        assert sample_rate == self.sample_rate
+
+        length = audio_data.shape[-1]
+        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
+        audio_data = nn.functional.pad(audio_data, (0, right_pad))
+
+        return audio_data
+
+    def encode(
+        self,
+        audio_data: torch.Tensor,
+        n_quantizers: int = None,
+    ):
+        """Encode given audio data and return quantized latent codes
+
+        Parameters
+        ----------
+        audio_data : Tensor[B x 1 x T]
+            Audio data to encode
+        n_quantizers : int, optional
+            Number of quantizers to use, by default None
+            If None, all quantizers are used.
+
+        Returns
+        -------
+        dict
+            A dictionary with the following keys:
+            "z" : Tensor[B x D x T]
+                Quantized continuous representation of input
+            "codes" : Tensor[B x N x T]
+                Codebook indices for each codebook
+                (quantized discrete representation of input)
+            "latents" : Tensor[B x N*D x T]
+                Projected latents (continuous representation of input before quantization)
+            "vq/commitment_loss" : Tensor[1]
+                Commitment loss to train encoder to predict vectors closer to codebook
+                entries
+            "vq/codebook_loss" : Tensor[1]
+                Codebook loss to update the codebook
+            "length" : int
+                Number of samples in input audio
+        """
+        z = self.encoder(audio_data)
+        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
+        return z, codes, latents, commitment_loss, codebook_loss
+
+    def decode(self, z: torch.Tensor):
+        """Decode given latent codes and return audio data
+
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+            Quantized continuous representation of input
+        length : int, optional
+            Number of samples in output audio, by default None
+
+        Returns
+        -------
+        dict
+            A dictionary with the following keys:
+            "audio" : Tensor[B x 1 x length]
+                Decoded audio data.
+        """
+        return self.decoder(z)
+
+    def forward(
+        self,
+        audio_data: torch.Tensor,
+        sample_rate: int = None,
+        n_quantizers: int = None,
+    ):
+        """Model forward pass
+
+        Parameters
+        ----------
+        audio_data : Tensor[B x 1 x T]
+            Audio data to encode
+        sample_rate : int, optional
+            Sample rate of audio data in Hz, by default None
+            If None, defaults to `self.sample_rate`
+        n_quantizers : int, optional
+            Number of quantizers to use, by default None.
+            If None, all quantizers are used.
+
+        Returns
+        -------
+        dict
+            A dictionary with the following keys:
+            "z" : Tensor[B x D x T]
+                Quantized continuous representation of input
+            "codes" : Tensor[B x N x T]
+                Codebook indices for each codebook
+                (quantized discrete representation of input)
+            "latents" : Tensor[B x N*D x T]
+                Projected latents (continuous representation of input before quantization)
+            "vq/commitment_loss" : Tensor[1]
+                Commitment loss to train encoder to predict vectors closer to codebook
+                entries
+            "vq/codebook_loss" : Tensor[1]
+                Codebook loss to update the codebook
+            "length" : int
+                Number of samples in input audio
+            "audio" : Tensor[B x 1 x length]
+                Decoded audio data.
+        """
+        length = audio_data.shape[-1]
+        audio_data = self.preprocess(audio_data, sample_rate)
+        z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
+
+        x = self.decode(z)
+        return {
+            "audio": x[..., :length],
+            "z": z,
+            "codes": codes,
+            "latents": latents,
+            "vq/commitment_loss": commitment_loss,
+            "vq/codebook_loss": codebook_loss,
+        }
+
+
+if __name__ == "__main__":
+    import numpy as np
+    from functools import partial
+
+    model = DAC().to("cpu")
+
+    for n, m in model.named_modules():
+        o = m.extra_repr()
+        p = sum([np.prod(p.size()) for p in m.parameters()])
+        fn = lambda o, p: o + f" {p / 1e6:<.3f}M params."
+        setattr(m, "extra_repr", partial(fn, o=o, p=p))
+    print(model)
+    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
+
+    length = 88200 * 2
+    x = torch.randn(1, 1, length).to(model.device)
+    x.requires_grad_(True)
+    x.retain_grad()
+
+    # Make a forward pass
+    out = model(x)["audio"]
+    print("Input shape:", x.shape)
+    print("Output shape:", out.shape)
+
+    # Create gradient variable
+    grad = torch.zeros_like(out)
+    grad[:, :, grad.shape[-1] // 2] = 1
+
+    # Make a backward pass
+    out.backward(grad)
+
+    # Check non-zero values
+    gradmap = x.grad.squeeze(0)
+    gradmap = (gradmap != 0).sum(0)  # sum across features
+    rf = (gradmap != 0).sum()
+
+    print(f"Receptive field: {rf.item()}")
+
+    x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
+    model.decompress(model.compress(x, verbose=True), verbose=True)
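As a quick check on the shapes these defaults imply: the encoder downsamples by the product of encoder_rates, preprocess pads the input up to a multiple of that hop length, and encode returns one code per codebook per hop. A small arithmetic sketch (plain Python, no model instantiation; the numbers follow from the defaults in dac.py above):

import math

encoder_rates = [2, 4, 8, 8]
hop_length = math.prod(encoder_rates)                  # 512, matches DAC.hop_length
sample_rate = 44100
n_codebooks = 9

length = sample_rate * 2                               # two seconds of input samples
padded = math.ceil(length / hop_length) * hop_length   # what preprocess() pads to
frames = padded // hop_length                          # latent / code frames per item

print(hop_length, padded, frames)                      # 512 88576 173
print("codes shape:", (1, n_codebooks, frames))        # (B, N, T) as returned by encode()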
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+    shape = x.shape
+    x = x.reshape(shape[0], shape[1], -1)
+    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+    x = x.reshape(shape)
+    return x
+
+
+class Snake1d(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+    def forward(self, x):
+        return snake(x, self.alpha)
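The snake activation defined above is x + (1/alpha) * sin^2(alpha * x), applied with one learnable alpha per channel. A tiny check of the scripted function against that closed form (the import path is an assumption, mirroring how dac.py imports from the `dac` package):

import torch

from dac.nn.layers import snake  # import path assumed

x = torch.randn(2, 4, 16)           # (batch, channels, time)
alpha = torch.ones(1, 4, 1)         # one alpha per channel, as in Snake1d

y = snake(x, alpha)
expected = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
print(torch.allclose(y, expected))  # True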
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from dac.nn.layers import WNConv1d


class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantized the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        Tensor[1]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = z_e + (z_q - z_e).detach()  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices


class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantized the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
                when in training mode, and a random number of quantizers is used.
        Returns
        -------
        dict
            A dictionary with the following keys:

            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)

            # Create mask to apply quantizer dropout
            mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation
        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input
        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)


if __name__ == "__main__":
    rvq = ResidualVectorQuantize(quantizer_dropout=True)
    x = torch.randn(16, 512, 80)
    y = rvq(x)
    print(y["latents"].shape)
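
Editor's note: the `__main__` demo above indexes the result as a dict, but `ResidualVectorQuantize.forward` returns a tuple. A hedged sketch of the intended round trip, with illustrative shapes (not code from the repo):

import torch

rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024, codebook_dim=8)
x = torch.randn(16, 512, 80)                               # (batch, input_dim, time)
z_q, codes, latents, commit_loss, codebook_loss = rvq(x)   # tuple, not dict
print(codes.shape)                                         # (16, 9, 80): one index per codebook per frame
z_q_from_codes, z_p, _ = rvq.from_codes(codes)             # rebuild the continuous representation from codes
print(z_q_from_codes.shape)                                # (16, 512, 80)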
higgs_audio/audio_processing/higgs_audio_tokenizer.py
ADDED
@@ -0,0 +1,341 @@
# Based on code from: https://github.com/zhenye234/xcodec
# Licensed under MIT License
# Modifications by BosonAI

import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Sequence
import numpy as np
from transformers import AutoModel
import torchaudio
import json
import librosa
from huggingface_hub import snapshot_download

from vector_quantize_pytorch import ResidualFSQ
from .descriptaudiocodec.dac.model import dac as dac2
from .quantization.vq import ResidualVectorQuantizer
from .semantic_module import Encoder, Decoder


class EncodedResult:
    def __init__(self, audio_codes):
        self.audio_codes = audio_codes


class HiggsAudioFeatureExtractor(nn.Module):
    def __init__(self, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate

    def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
        # Convert from librosa to torch
        audio_signal = torch.tensor(raw_audio)
        audio_signal = audio_signal.unsqueeze(0)
        if len(audio_signal.shape) < 3:
            audio_signal = audio_signal.unsqueeze(0)
        return {"input_values": audio_signal}


class HiggsAudioTokenizer(nn.Module):
    def __init__(
        self,
        n_filters: int = 32,
        D: int = 128,
        target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
        ratios: Sequence[int] = [8, 5, 4, 2],  # downsampling by 320
        sample_rate: int = 16000,
        bins: int = 1024,
        n_q: int = 8,
        codebook_dim: int = None,
        normalize: bool = False,
        causal: bool = False,
        semantic_techer: str = "hubert_base_general",
        last_layer_semantic: bool = True,
        merge_mode: str = "concat",
        downsample_mode: str = "step_down",
        semantic_mode: str = "classic",
        vq_scale: int = 1,
        semantic_sample_rate: int = None,
        device: str = "cuda",
    ):
        super().__init__()
        self.hop_length = np.prod(ratios)
        self.semantic_techer = semantic_techer

        self.frame_rate = math.ceil(sample_rate / np.prod(ratios))  # 50 Hz

        self.target_bandwidths = target_bandwidths
        self.n_q = n_q
        self.sample_rate = sample_rate
        self.encoder = dac2.Encoder(64, ratios, D)

        self.decoder_2 = dac2.Decoder(D, 1024, ratios)
        self.last_layer_semantic = last_layer_semantic
        self.device = device
        if semantic_techer == "hubert_base":
            self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768

        elif semantic_techer == "wavlm_base_plus":
            self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768

        elif semantic_techer == "hubert_base_general":
            self.semantic_model = AutoModel.from_pretrained("ZhenYe234/hubert_base_general_audio")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768

        # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
        if semantic_sample_rate is not None:
            self.semantic_sample_rate = semantic_sample_rate

        self.semantic_model.eval()

        # make the semantic model parameters do not need gradient
        for param in self.semantic_model.parameters():
            param.requires_grad = False

        self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)

        self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
        self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
        self.decoder_semantic = Decoder(
            code_dim=self.encoder_semantic_dim,
            output_channels=self.semantic_dim,
            decode_channels=self.semantic_dim,
        )

        # out_D=D+768
        if isinstance(bins, int):  # RVQ
            self.quantizer = ResidualVectorQuantizer(
                dimension=self.quantizer_dim,
                codebook_dim=codebook_dim,
                n_q=n_q,
                bins=bins,
            )
            self.quantizer_type = "RVQ"
        else:  # RFSQ
            self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
            self.quantizer_type = "RFSQ"

        self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
        self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
        self.fc_post2 = nn.Linear(self.quantizer_dim, D)

        self.downsample_mode = downsample_mode
        if downsample_mode == "avg":
            self.semantic_pooling = nn.AvgPool1d(
                kernel_size=self.semantic_downsample_factor,
                stride=self.semantic_downsample_factor,
            )

        self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)

    @property
    def tps(self):
        return self.frame_rate

    @property
    def sampling_rate(self):
        return self.sample_rate

    @property
    def num_codebooks(self):
        return self.n_q

    @property
    def codebook_size(self):
        return self.quantizer_dim

    def get_last_layer(self):
        return self.decoder.layers[-1].weight

    def calculate_rec_loss(self, rec, target):
        target = target / target.norm(dim=-1, keepdim=True)
        rec = rec / rec.norm(dim=-1, keepdim=True)
        rec_loss = (1 - (target * rec).sum(-1)).mean()

        return rec_loss

    @torch.no_grad()
    def get_regress_target(self, x):
        x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)

        if (
            self.semantic_techer == "hubert_base"
            or self.semantic_techer == "hubert_base_general"
            or self.semantic_techer == "wavlm_base_plus"
        ):
            x = x[:, 0, :]
            x = F.pad(x, (160, 160))
            target = self.semantic_model(x, output_hidden_states=True).hidden_states
            target = torch.stack(target, dim=1)  # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)

            # average for all layers
            target = target.mean(1)
            # target = target[9]
            # if self.hop_length > 320:
            #     target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)

        elif self.semantic_techer == "w2v_bert2":
            target = self.semantic_model(x)

        elif self.semantic_techer.startswith("whisper"):
            if self.last_layer_semantic:
                target = self.semantic_model(x, avg_layers=False)
            else:
                target = self.semantic_model(x, avg_layers=True)

        elif self.semantic_techer.startswith("mert_music"):
            if self.last_layer_semantic:
                target = self.semantic_model(x, avg_layers=False)
            else:
                target = self.semantic_model(x, avg_layers=True)

        elif self.semantic_techer.startswith("qwen_audio_omni"):
            target = self.semantic_model(x)

        if self.downsample_mode == "step_down":
            if self.semantic_downsample_factor > 1:
                target = target[:, :: self.semantic_downsample_factor, :]

        elif self.downsample_mode == "avg":
            target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
        return target

    def forward(self, x: torch.Tensor, bw: int):
        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        e = torch.cat([e_acoustic, e_semantic], dim=1)

        e = self.fc_prior(e.transpose(1, 2))

        if self.quantizer_type == "RVQ":
            e = e.transpose(1, 2)
            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
            quantized = quantized.transpose(1, 2)
        else:
            quantized, codes = self.quantizer(e)
            commit_loss = torch.tensor(0.0)

        quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)

        o = self.decoder_2(quantized_acoustic)

        o_semantic = self.decoder_semantic(quantized_semantic)
        semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)

        return o, commit_loss, semantic_recon_loss, None

    def encode(
        self,
        audio_path_or_wv,
        sr=None,
        loudness_normalize=False,
        loudness_threshold=-23.0,
    ):
        if isinstance(audio_path_or_wv, str):
            wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
        else:
            wv = audio_path_or_wv
            assert sr is not None
        if loudness_normalize:
            import pyloudnorm as pyln

            meter = pyln.Meter(sr)
            l = meter.integrated_loudness(wv)
            wv = pyln.normalize.loudness(wv, l, loudness_threshold)
        if sr != self.sampling_rate:
            wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
        if self.audio_tokenizer_feature_extractor is not None:
            inputs = self.audio_tokenizer_feature_extractor(
                raw_audio=wv,
                sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate,
                return_tensors="pt",
            )
            input_values = inputs["input_values"].to(self.device)
        else:
            input_values = torch.from_numpy(wv).float().unsqueeze(0)
        with torch.no_grad():
            encoder_outputs = self._xcodec_encode(input_values)
            vq_code = encoder_outputs.audio_codes[0]
        return vq_code

    def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor:
        bw = target_bw

        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        if e_acoustic.shape[2] != e_semantic.shape[2]:
            pad_size = 160 * self.semantic_downsample_factor
            e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))

        if e_acoustic.shape[2] != e_semantic.shape[2]:
            if e_acoustic.shape[2] > e_semantic.shape[2]:
                e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
            else:
                e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]

        e = torch.cat([e_acoustic, e_semantic], dim=1)

        e = self.fc_prior(e.transpose(1, 2))

        if self.quantizer_type == "RVQ":
            e = e.transpose(1, 2)
            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
            codes = codes.permute(1, 0, 2)
        else:
            quantized, codes = self.quantizer(e)
            codes = codes.permute(0, 2, 1)

        # return codes
        return EncodedResult(codes)

    def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
        if self.quantizer_type == "RVQ":
            vq_code = vq_code.permute(1, 0, 2)
            quantized = self.quantizer.decode(vq_code)
            quantized = quantized.transpose(1, 2)
        else:
            vq_code = vq_code.permute(0, 2, 1)
            quantized = self.quantizer.get_output_from_indices(vq_code)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)

        o = self.decoder_2(quantized_acoustic)
        return o.cpu().numpy()


def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
    is_local = os.path.exists(tokenizer_name_or_path)
    if not is_local:
        tokenizer_path = snapshot_download(tokenizer_name_or_path)
    else:
        tokenizer_path = tokenizer_name_or_path
    config_path = os.path.join(tokenizer_path, "config.json")
    model_path = os.path.join(tokenizer_path, "model.pth")
    config = json.load(open(config_path))
    model = HiggsAudioTokenizer(
        **config,
        device=device,
    )
    parameter_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(parameter_dict, strict=False)
    model.to(device)
    model.eval()
    return model
higgs_audio/audio_processing/quantization/__init__.py
ADDED
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
from .vq import QuantizedResult, ResidualVectorQuantizer
higgs_audio/audio_processing/quantization/ac.py
ADDED
@@ -0,0 +1,301 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Arithmetic coder."""

import io
import math
import random
import typing as tp
import torch

from ..binary import BitPacker, BitUnpacker


def build_stable_quantized_cdf(
    pdf: torch.Tensor,
    total_range_bits: int,
    roundoff: float = 1e-8,
    min_range: int = 2,
    check: bool = True,
) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): will round the pdf up to that level to remove difference coming
            from e.g. evaluating the Language Model on different architectures.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior is a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
    """
    pdf = pdf.detach()
    if roundoff:
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2**total_range_bits
    cardinality = len(pdf)
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    if check:
        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf


class ArithmeticCoder:
    """ArithmeticCoder,
    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
    sequence `(s_t)` by doing the following:

    1) Initialize the current range to` [0 ** 2 B - 1]`.
    2) For each time step t, split the current range into contiguous chunks,
        one for each possible outcome, with size roughly proportional to `p`.
        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
        would be `{[0, 2], [3, 3]}`.
    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
    4) When done encoding all the values, just select any value remaining in the range.

    You will notice that this procedure can fail: for instance if at any point in time
    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
    possible outcome. Intuitively, the more likely a value is, the less the range width
    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
    coding scheme, likely outcomes would take less bits, and more of them can be coded
    with a fixed budget.

    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
    when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we encode mostly likely values, we will seldom
    need to inject new bits, but a single rare value can deplete our stock of entropy!

    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
    code works for any sequence `(p_t)` possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
    the KL between the true distribution and `p_t`, the most efficient the coding will be.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written to.
        total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
            Any time the current range width fall under this limit, new bits will
            be injected to rescale the initial range.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
        self.low: int = 0
        self.high: int = 0
        self.max_bit: int = -1
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # If self.low and self.high start with the sames bits,
        # those won't change anymore as we always just increase the range
        # by powers of 2, and we can flush them out to the bit stream.
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2 ** (self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                assert self.high >= self.low, (self.high, self.low, self.max_bit)
                assert self.low >= 0
                self.max_bit -= 1
                self.packer.push(b1)
            else:
                break

    def push(self, symbol: int, quantized_cdf: torch.Tensor):
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.
        """
        while self.delta < 2**self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
        effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
        assert self.low <= self.high
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (
            effective_low,
            effective_high,
            range_low,
            range_high,
        )
        self._dbg.append((self.low, self.high))
        self._dbg2.append((self.low, self.high))
        outs = self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit
        return outs

    def flush(self):
        """Flush the remaining information to the stream."""
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()


class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` having the some common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols and a binary-search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.

    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        self.total_range_bits = total_range_bits
        self.low: int = 0
        self.high: int = 0
        self.current: int = 0
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
        # Following is for debugging
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # Given the current range [L, H], if both have a common prefix,
        # we know we can remove it from our representation to avoid handling large numbers.
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                self.current -= b1 << self.max_bit
                assert self.high >= self.low
                assert self.low >= 0
                self.max_bit -= 1
            else:
                break

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exatly**
                the same cdf as the one used at encoding time.
        """
        while self.delta < 2**self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        def bin_search(low_idx: int, high_idx: int):
            # Binary search is not just for coding interviews :)
            if high_idx < low_idx:
                raise RuntimeError("Binary search failed")
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
            effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
            low = effective_low + self.low
            high = effective_high + self.low
            if self.current >= low:
                if self.current <= high:
                    return (mid, low, high, self.current)
                else:
                    return bin_search(mid + 1, high_idx)
            else:
                return bin_search(low_idx, mid - 1)

        self._last = (self.low, self.high, self.current, self.max_bit)
        sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
        self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        self._dbg2.append((self.low, self.high, self.current))

        return sym


def test():
    torch.manual_seed(1234)
    random.seed(1234)
    for _ in range(4):
        pdfs = []
        cardinality = random.randrange(4000)
        steps = random.randrange(100, 500)
        fo = io.BytesIO()
        encoder = ArithmeticCoder(fo)
        symbols = []
        for step in range(steps):
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            pdfs.append(pdf)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            symbols.append(symbol)
            encoder.push(symbol, q_cdf)
        encoder.flush()

        fo.seek(0)
        decoder = ArithmeticDecoder(fo)
        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            decoded_symbol = decoder.pull(q_cdf)
            assert decoded_symbol == symbol, idx
        assert decoder.pull(torch.zeros(1)) is None


if __name__ == "__main__":
    test()
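
Editor's note: a compact sketch of the round trip that `test()` above exercises, assuming the `..binary` module (BitPacker/BitUnpacker) imported at the top of this file is available in the package. The toy distribution and symbol stream are illustrative:

import io
import torch

pdf = torch.tensor([0.7, 0.2, 0.1])                 # fixed toy distribution over 3 symbols
buf = io.BytesIO()
coder = ArithmeticCoder(buf)
cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
for symbol in [0, 0, 2, 1, 0]:
    coder.push(symbol, cdf)
coder.flush()

buf.seek(0)
decoder = ArithmeticDecoder(buf)
decoded = [decoder.pull(cdf) for _ in range(5)]     # replaying the same CDFs recovers [0, 0, 2, 1, 0]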
higgs_audio/audio_processing/quantization/core_vq.py
ADDED
@@ -0,0 +1,360 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Core vector quantization implementation."""

import typing as tp

from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F

from xcodec.quantization.distrib import broadcast_tensors, rank


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if val is not None else d


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
        dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: int = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)  # get embedding based on index
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)  # get index based on Euclidean distance
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x):
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])

    def forward(self, x, n_q: tp.Optional[int] = None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or len(self.layers)

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
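
Editor's note: a shape-level sketch of the residual VQ defined above, assuming the `xcodec.quantization.distrib` import at the top of this file resolves in your environment. The sizes are illustrative, not taken from a config in this repo:

import torch

rvq = ResidualVectorQuantization(num_quantizers=8, dim=128, codebook_size=1024, kmeans_init=False)
rvq.eval()                              # skip the EMA / code-expiry updates that only run in training mode

x = torch.randn(4, 128, 50)             # (batch, dim, time)
quantized, indices, losses = rvq(x)     # indices: (8, 4, 50), one code per layer and frame
codes = rvq.encode(x, n_q=4)            # restrict to the first 4 quantizers
recon = rvq.decode(codes)               # (4, 128, 50)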
higgs_audio/audio_processing/quantization/core_vq_lsx_version.py
ADDED
@@ -0,0 +1,431 @@
# Copyright (c)
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# This implementation is inspired from
# https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and
# https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Core vector quantization implementation."""

import typing as tp

from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from .distrib import broadcast_tensors, is_distributed
from .ddp_utils import SyncFunction


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if val is not None else d


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(
    samples,
    num_clusters: int,
    num_iters: int = 10,
    frames_to_use: int = 10_000,
    batch_size: int = 64,
):
    """
    Memory-efficient K-means clustering.
    Args:
        samples (tensor): shape [N, D]
        num_clusters (int): number of centroids.
        num_iters (int): number of iterations.
        frames_to_use (int): subsample size from total samples.
        batch_size (int): batch size used in distance computation.
    Returns:
        means: [num_clusters, D]
        bins: [num_clusters] (number of points per cluster)
    """
    N, D = samples.shape
    dtype, device = samples.dtype, samples.device

    if frames_to_use < N:
        indices = torch.randperm(N, device=device)[:frames_to_use]
        samples = samples[indices]

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Store cluster assignments
        all_assignments = []

        for i in range(0, samples.shape[0], batch_size):
            batch = samples[i : i + batch_size]  # [B, D]
            dists = torch.cdist(batch, means, p=2)  # [B, C]
            assignments = dists.argmin(dim=1)  # [B]
            all_assignments.append(assignments)

        buckets = torch.cat(all_assignments, dim=0)  # [N]
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        # Compute new means
        new_means = torch.zeros_like(means)
        for i in range(num_clusters):
            mask = buckets == i
            if mask.any():
                new_means[i] = samples[mask].mean(dim=0)

        means = torch.where(zero_mask[:, None], means, new_means)

    return means, bins


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: int = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # Flag variable to indicate whether the codebook is initialized
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        # Running EMA cluster size/count: N_i^t in eq. (6) in vqvae paper
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        # Codebook
        self.register_buffer("embed", embed)
        # EMA codebook: eq. (7) in vqvae paper
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize codebook.
        Args:
            data (tensor): [B * T, D].
        """
        if self.inited:
            return

        ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
        if dist.is_available() and dist.is_initialized():
            # [B * T * world_size, D]
            data = SyncFunction.apply(data)

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
        if is_distributed():
            # [B * T * world_size, D]
            batch_samples = SyncFunction.apply(batch_samples)

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)  # [B, T, D] -> [B*T, D]
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        # shape: [B, T, D]
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)  # [B, T, D] -> [B*T, D]

        # Initialize codebook
        self.init_embed_(x)

        embed_ind = self.quantize(x)  # [B*T,]
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)  # [B*T, cb-size]
        embed_ind = self.postprocess_emb(embed_ind, shape)  # [B, T]
        quantize = self.dequantize(embed_ind)  # [B, T, D]

        if self.training:
            ### Update codebook by EMA
            embed_onehot_sum = embed_onehot.sum(0)  # [cb-size,]
            embed_sum = x.t() @ embed_onehot  # [D, cb-size]
            if is_distributed():
                dist.all_reduce(embed_onehot_sum)
                dist.all_reduce(embed_sum)
            # Update ema cluster count N_i^t, eq. (6) in vqvae paper
            self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
            # Update ema embed: eq. (7) in vqvae paper
            self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
            # apply laplace smoothing
            n = self.cluster_size.sum()
            cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
            # Update ema embed: eq. (8) in vqvae paper
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)

        return quantize, embed_ind


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x):
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        device = x.device
        x = x.transpose(1, 2).contiguous()  # [b d n] -> [b n d]
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = quantize.transpose(1, 2).contiguous()  # [b n d] -> [b d n]
        return quantize, embed_ind, loss


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])

    def forward(self, x, n_q: tp.Optional[int] = None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or len(self.layers)

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
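For reference, the EMA codebook update performed inside `EuclideanCodebook.forward` can be reproduced on plain tensors. A minimal sketch of eqs. (6)–(8) from the VQ-VAE paper with toy sizes and no distributed reduction (not this module's API, just the arithmetic it implements):

```python
import torch
import torch.nn.functional as F

decay, eps = 0.99, 1e-5
codebook_size, dim = 8, 4
embed = torch.randn(codebook_size, dim)
embed_avg = embed.clone()
cluster_size = torch.zeros(codebook_size)

x = torch.randn(32, dim)                           # flattened [B*T, D] batch
ind = torch.cdist(x, embed).argmin(dim=1)          # nearest-code assignment
onehot = F.one_hot(ind, codebook_size).float()

cluster_size.mul_(decay).add_(onehot.sum(0), alpha=1 - decay)    # eq. (6): EMA cluster counts
embed_avg.mul_(decay).add_(onehot.t() @ x, alpha=1 - decay)      # eq. (7): EMA sum of assigned vectors
n = cluster_size.sum()
smoothed = (cluster_size + eps) / (n + codebook_size * eps) * n  # Laplace smoothing of counts
embed = embed_avg / smoothed.unsqueeze(1)                        # eq. (8): normalized codebook
```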
higgs_audio/audio_processing/quantization/ddp_utils.py
ADDED
@@ -0,0 +1,197 @@
import logging
import random
import subprocess
from datetime import datetime

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
import torch.optim
import torch.utils.data
from packaging import version
from omegaconf import OmegaConf


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def is_logging_process():
    return not dist.is_initialized() or dist.get_rank() == 0


def get_logger(cfg, name=None):
    # log_file_path is used when unit testing
    if is_logging_process():
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
        return logging.getLogger(name)


# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
class SyncFunction(torch.autograd.Function):
    @staticmethod
    # @torch.no_grad()
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]

        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]

        torch.distributed.all_gather(gathered_tensor, tensor)
        gathered_tensor = torch.cat(gathered_tensor, 0)

        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)

        idx_from = torch.distributed.get_rank() * ctx.batch_size
        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]


def get_timestamp():
    return datetime.now().strftime("%y%m%d-%H%M%S")


def get_commit_hash():
    message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
    return message.strip().decode("utf-8")


class DDP(DistributedDataParallel):
    """
    Override the forward call in lightning so it goes to training and validation step respectively
    """

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        if version.parse(torch.__version__[:6]) < version.parse("1.11"):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            assert len(self.device_ids) == 1
            if self.module.training:
                output = self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                output = self.module.test_step(*inputs[0], **kwargs[0])
            else:
                output = self.module.validation_step(*inputs[0], **kwargs[0])
            if torch.is_grad_enabled():
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        else:
            from torch.nn.parallel.distributed import (
                logging,
                Join,
                _DDPSink,
                _tree_flatten_with_rref,
                _tree_unflatten_with_rref,
            )

            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this process has not joined, if
                # needed
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)

                # Calling _rebuild_buckets before forward computation,
                # It may allocate new buckets before deallocating old buckets
                # inside _rebuild_buckets. To save peak memory usage,
                # call _rebuild_buckets before the peak memory usage increases
                # during forward computation.
                # This should be called only once during whole training period.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    logging.info("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                buffer_hook_registered = hasattr(self, "buffer_hook")
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    # Notify joined ranks whether they should sync in backwards pass or not.
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # We'll return the output object verbatim since it is a freeform
                    # object. We need to find any tensors in this object, though,
                    # because we need to figure out which parameters were used during
                    # this forward pass, to ensure we short circuit reduction for any
                    # unused parameters. Only if `find_unused_parameters` is set.
                    if self.find_unused_parameters and not self.static_graph:
                        # Do not need to populate this for static graph.
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

                # TODO: DDPSink is currently enabled for unused parameter detection and
                # static graph training for first iteration.
                if (self.find_unused_parameters and not self.static_graph) or (
                    self.static_graph and self.num_iterations == 1
                ):
                    state_dict = {
                        "static_graph": self.static_graph,
                        "num_iterations": self.num_iterations,
                    }

                    output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
                    output_placeholders = [None for _ in range(len(output_tensor_list))]
                    # Do not touch tensors that have no grad_fn, which can cause issues
                    # such as https://github.com/pytorch/pytorch/issues/60733
                    for i, output in enumerate(output_tensor_list):
                        if torch.is_tensor(output) and output.grad_fn is None:
                            output_placeholders[i] = output

                    # When find_unused_parameters=True, makes tensors which require grad
                    # run through the DDPSink backward pass. When not all outputs are
                    # used in loss, this makes those corresponding tensors receive
                    # undefined gradient which the reducer then handles to ensure
                    # param.grad field is not touched and we don't error out.
                    passthrough_tensor_list = _DDPSink.apply(
                        self.reducer,
                        state_dict,
                        *output_tensor_list,
                    )
                    for i in range(len(output_placeholders)):
                        if output_placeholders[i] is None:
                            output_placeholders[i] = passthrough_tensor_list[i]

                    # Reconstruct output data structure.
                    output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
        return output
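`SyncFunction` is an all-gather that stays differentiable: the backward pass all-reduces the incoming gradient and hands back only this rank's slice. A hedged usage sketch, assuming `SyncFunction` is imported from this module; it is only meaningful once `torch.distributed` has been initialized (e.g. under `torchrun`), and the helper name below is hypothetical:

```python
import torch
import torch.distributed as dist

def gather_for_codebook_stats(features: torch.Tensor) -> torch.Tensor:
    """Gather [B, D] features from all ranks before k-means init / dead-code expiry."""
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # differentiable all-gather: result is [B * world_size, D]
        return SyncFunction.apply(features)
    return features  # single-process fallback: nothing to gather
```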
higgs_audio/audio_processing/quantization/distrib.py
ADDED
@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch distributed utilities."""

import typing as tp

import torch


def rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    # print('params[0].device ', params[0].device)
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If not all the workers have the same number, for at least one of them,
        # this inequality will be verified.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
        )


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()


def sync_buffer(buffers, average=True):
    """
    Sync grad for buffers. If average is False, broadcast instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size()  # call world_size(); dividing by the function object would raise


def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()


def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
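These helpers degrade to no-ops in single-process runs, so calling code does not need to branch on `world_size()`. A small sketch with hypothetical metric names (with no process group initialized, `average_metrics` simply returns its input unchanged):

```python
from higgs_audio.audio_processing.quantization.distrib import average_metrics

metrics = {"commit_loss": 0.12, "recon_loss": 0.34}   # hypothetical per-rank metrics
averaged = average_metrics(metrics, count=1.0)        # identical dict when not distributed
print(averaged)
```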
higgs_audio/audio_processing/quantization/vq.py
ADDED
@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Residual vector quantizer implementation."""

from dataclasses import dataclass, field
import math
import typing as tp

import torch
from torch import nn

# from .core_vq import ResidualVectorQuantization
from .core_vq_lsx_version import ResidualVectorQuantization


@dataclass
class QuantizedResult:
    quantized: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)


class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.
    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        codebook_dim: int = None,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.n_q = n_q
        self.dimension = dimension
        self.codebook_dim = codebook_dim
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_dim=self.codebook_dim,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )

    def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None):  # -> QuantizedResult:
        """Residual vector quantization on the given input tensor.
        Args:
            x (torch.Tensor): Input tensor.
            sample_rate (int): Sample rate of the input tensor.
            bandwidth (float): Target bandwidth.
        Returns:
            QuantizedResult:
                The quantized (or approximately quantized) representation with
                the associated bandwidth and any penalty term for the loss.
        """
        bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        bw = torch.tensor(n_q * bw_per_q).to(x)
        return quantized, codes, bw, torch.mean(commit_loss)
        # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))

    def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
        """Return n_q based on specified target bandwidth."""
        bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
        n_q = self.n_q
        if bandwidth and bandwidth > 0.0:
            n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
        return n_q

    def get_bandwidth_per_quantizer(self, sample_rate: int):
        """Return bandwidth per quantizer for a given input sample rate."""
        return math.log2(self.bins) * sample_rate / 1000

    def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizers to use
        and returns indices for each quantizer.
        """
        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
        codes = self.vq.encode(x, n_q=n_q)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        quantized = self.vq.decode(codes)
        return quantized
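The bandwidth bookkeeping above is plain arithmetic: each quantizer emits `log2(bins)` bits per latent frame, so at `sample_rate` latent frames per second one quantizer costs `log2(bins) * sample_rate / 1000` kbps, and the target bandwidth is floored to a whole number of quantizers. A worked example with the default `bins=1024` (10-bit codes) and an assumed 50 Hz latent frame rate:

```python
import math

bins, frame_rate = 1024, 50                       # 10-bit codes, 50 latent frames per second (assumption)
bw_per_q = math.log2(bins) * frame_rate / 1000    # 0.5 kbps per quantizer
for target_kbps in (0.5, 1.5, 4.0):
    n_q = int(max(1, math.floor(target_kbps / bw_per_q)))
    print(target_kbps, "kbps ->", n_q, "quantizers")   # 1, 3, 8
```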
higgs_audio/audio_processing/semantic_module.py
ADDED
@@ -0,0 +1,310 @@
# Based on code from: https://github.com/zhenye234/xcodec
# Licensed under MIT License
# Modifications by BosonAI

import torch
import torch.nn as nn


class Conv1d1x1(nn.Conv1d):
    """1x1 Conv1d."""

    def __init__(self, in_channels, out_channels, bias=True):
        super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)


class Conv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = -1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        if padding < 0:
            padding = (kernel_size - 1) // 2 * dilation
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C, T).
        """
        x = self.conv(x)
        return x


class ResidualUnit(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        dilation=1,
        bias=False,
        nonlinear_activation="ELU",
        nonlinear_activation_params={},
    ):
        super().__init__()
        self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
        self.conv1 = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        y = self.conv1(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y


class ConvTranspose1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding=-1,
        output_padding=-1,
        groups=1,
        bias=True,
    ):
        super().__init__()
        if padding < 0:
            padding = (stride + 1) // 2
        if output_padding < 0:
            output_padding = 1 if stride % 2 else 0
        self.deconv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        x = self.deconv(x)
        return x


class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.res_units = torch.nn.ModuleList()
        for dilation in dilations:
            self.res_units += [
                ResidualUnit(
                    in_channels,
                    in_channels,
                    kernel_size=unit_kernel_size,
                    dilation=dilation,
                )
            ]
        self.num_res = len(self.res_units)

        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3 if stride == 1 else (2 * stride),  # special case: stride=1, do not use kernel=2
            stride=stride,
            bias=bias,
        )

    def forward(self, x):
        for idx in range(self.num_res):
            x = self.res_units[idx](x)
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        input_channels: int,
        encode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv = Conv1d(
            in_channels=input_channels,
            out_channels=encode_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=False,
        )
        self.conv_blocks = torch.nn.ModuleList()
        in_channels = encode_channels
        for idx, stride in enumerate(strides):
            out_channels = int(encode_channels * channel_ratios[idx])  # could be float
            self.conv_blocks += [
                EncoderBlock(
                    in_channels,
                    out_channels,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                )
            ]
            in_channels = out_channels
        self.num_blocks = len(self.conv_blocks)
        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv(x)
        for i in range(self.num_blocks):
            x = self.conv_blocks[i](x)
        return x


class DecoderBlock(nn.Module):
    """Decoder block (no up-sampling)"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True,
    ):
        super().__init__()

        if stride == 1:
            self.conv = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,  # fix kernel=3 when stride=1 for unchanged shape
                stride=stride,
                bias=bias,
            )
        else:
            self.conv = ConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )

        self.res_units = torch.nn.ModuleList()
        for idx, dilation in enumerate(dilations):
            self.res_units += [
                ResidualUnit(
                    out_channels,
                    out_channels,
                    kernel_size=unit_kernel_size,
                    dilation=dilation,
                )
            ]
        self.num_res = len(self.res_units)

    def forward(self, x):
        x = self.conv(x)
        for idx in range(self.num_res):
            x = self.res_units[idx](x)
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        code_dim: int,
        output_channels: int,
        decode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv1 = Conv1d(
            in_channels=code_dim,
            out_channels=int(decode_channels * channel_ratios[0]),
            kernel_size=kernel_size,
            stride=1,
            bias=False,
        )

        self.conv_blocks = torch.nn.ModuleList()
        for idx, stride in enumerate(strides):
            in_channels = int(decode_channels * channel_ratios[idx])
            if idx < (len(channel_ratios) - 1):
                out_channels = int(decode_channels * channel_ratios[idx + 1])
            else:
                out_channels = decode_channels
            self.conv_blocks += [
                DecoderBlock(
                    in_channels,
                    out_channels,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                )
            ]
        self.num_blocks = len(self.conv_blocks)

        self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)

    def forward(self, z):
        x = self.conv1(z)
        for i in range(self.num_blocks):
            x = self.conv_blocks[i](x)
        x = self.conv2(x)
        return x
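A quick shape check for the semantic encoder/decoder pair above; a sketch assuming the module is importable as `higgs_audio.audio_processing.semantic_module` and that the caller picks the channel sizes (the 768/256/50 numbers below are illustrative, not taken from the model config). With the default `strides=(1, 1)`, the time axis is preserved end to end:

```python
import torch
from higgs_audio.audio_processing.semantic_module import Encoder, Decoder

enc = Encoder(input_channels=768, encode_channels=256)
dec = Decoder(code_dim=256, output_channels=768, decode_channels=256)

x = torch.randn(2, 768, 50)   # (B, C, T) semantic features
z = enc(x)                    # -> (2, 256, 50): strides of 1 keep T unchanged
y = dec(z)                    # -> (2, 768, 50)
print(z.shape, y.shape)
```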
higgs_audio/constants.py
ADDED
@@ -0,0 +1,3 @@
AUDIO_IN_TOKEN = "<|AUDIO|>"
AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"
EOS_TOKEN = "<|end_of_text|>"
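These placeholder strings mark where audio is spliced into the ChatML-style text stream: `<|AUDIO|>` stands for an audio input (e.g. a voice-reference clip) and `<|AUDIO_OUT|>` for audio the model should generate. A hedged illustration of how a caller might assemble such a turn; the surrounding message wording is an assumption, only the token strings come from this file:

```python
AUDIO_IN_TOKEN = "<|AUDIO|>"
AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"

# A user turn that provides one reference clip and asks for generated speech (illustrative only).
user_turn = f"Speak in the reference voice: {AUDIO_IN_TOKEN}\nText: Hello there!"
assistant_turn = AUDIO_OUT_TOKEN  # the collator expands this placeholder into audio codebook ids
print(user_turn, assistant_turn, sep="\n")
```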
higgs_audio/data_collator/__init__.py
ADDED
File without changes
|
higgs_audio/data_collator/higgs_audio_collator.py
ADDED
@@ -0,0 +1,583 @@
import librosa
import torch
import torch.nn.functional as F
import math
import numpy as np
from typing import List, Tuple, Dict

from dataclasses import dataclass
from typing import List, Optional
from transformers.models.whisper.processing_whisper import WhisperProcessor

from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple
from ..model.utils import build_delay_pattern_mask


def _ceil_to_nearest(n, round_to):
    return (n + round_to - 1) // round_to * round_to


@dataclass
class HiggsAudioBatchInput:
    input_ids: torch.LongTensor  # shape (bsz, seq_len).
    attention_mask: torch.Tensor  # shape (bsz, seq_len).
    audio_features: Optional[torch.Tensor]  # shape (num_audio_in, feature_dim, max_mel_seq_len).
    audio_feature_attention_mask: Optional[torch.Tensor]  # shape (num_audio_in, max_mel_seq_len).
    audio_out_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    audio_out_ids_start: Optional[torch.LongTensor]  # shape (num_audio_out,)
    # audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover the group
    # location in the batch for each audio segment.
    # Currently, we concatenate audio segments along dim 0 to handle variable audio segment lengths. However, in the
    # alignment stage, we need the location information.
    # For example,
    # audio_out_ids_start = [0, 2, 4, 8]; the first two audio segments come from the same sample in the batch, and the
    # other two come from different samples.
    # This is a batch of 3 samples, so the group locations are:
    # audio_out_ids_start_group_loc = [0, 0, 1, 2]
    audio_out_ids_start_group_loc: Optional[
        torch.LongTensor
    ]  # shape (num_audio_out,), specifies a sample's group location in the batch
    audio_in_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_in_total_length)
    audio_in_ids_start: Optional[torch.LongTensor]  # shape (num_audio_in,)
    label_ids: Optional[torch.LongTensor]  # shape (bsz, seq_len)
    label_audio_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    reward: Optional[float] = None


class HiggsAudioSampleCollator:
    """Sample collator for Higgs-Audio model.

    Args:
        whisper_processor (WhisperProcessor): The whisper processor.
        audio_in_token_id (int): The token id for audio-in.
        audio_out_token_id (int): The token id for audio-out.
        pad_token_id (int): The token id for padding.
        audio_stream_bos_id (int): The token id for audio-stream beginning of sentence.
        audio_stream_eos_id (int): The token id for audio-stream end of sentence.
        round_to (int): The round-to value.
        pad_left (bool): Whether to pad left.
        return_audio_in_tokens (bool): Whether to return audio-in tokens.
        use_delay_pattern (bool): Whether to use delay pattern.
        disable_audio_codes_transform (bool): Whether to add bos and eos tokens to audio codes.
        chunk_size_seconds (int): The chunk size in seconds.
        add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
        mask_audio_out_token_label (bool): Whether to always mask the label associated with the <|AUDIO_OUT|> token.
            Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>.

    """

    def __init__(
        self,
        whisper_processor: WhisperProcessor,
        audio_in_token_id,
        audio_out_token_id,
        pad_token_id,
        audio_stream_bos_id,
        audio_stream_eos_id,
        round_to=8,
        pad_left=False,
        encode_whisper_embed=True,
        return_audio_in_tokens=True,
        audio_num_codebooks=None,
        use_delay_pattern=False,
        disable_audio_codes_transform=False,
        chunk_size_seconds=30,  # Maximum duration for each chunk
        add_new_bos_eos_for_long_chunk=True,
        mask_audio_out_token_label=True,
    ):
        self.whisper_processor = whisper_processor
        self.round_to = round_to
        self.pad_left = pad_left
        self.audio_in_token_id = audio_in_token_id
        self.audio_out_token_id = audio_out_token_id
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.pad_token_id = pad_token_id
        self.encode_whisper_embed = encode_whisper_embed
        self.return_audio_in_tokens = return_audio_in_tokens
        self.audio_num_codebooks = audio_num_codebooks
        self.use_delay_pattern = use_delay_pattern
        if encode_whisper_embed:
            self.chunk_size_seconds = chunk_size_seconds
            self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
        else:
            self.chunk_size_seconds = None
            self.chunk_size_samples = None
        self.disable_audio_codes_transform = disable_audio_codes_transform
        self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
        self.mask_audio_out_token_label = mask_audio_out_token_label

    def _process_and_duplicate_audio_tokens(
        self,
        input_ids: torch.Tensor,
        audio_idx: int,
        wv: torch.Tensor,
        sr: int,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Process long audio and duplicate corresponding audio tokens.

        Args:
            input_ids: Input token ids
            audio_idx: Index of the audio token in the sequence
            wv: Audio waveform
            sr: Sample rate
            labels: Optional label ids to be duplicated alongside input ids

        Returns:
            Tuple of:
                - New input ids with duplicated audio tokens
                - New label ids (if labels were provided) or None
                - Number of chunks created
        """
        # Calculate number of chunks needed
        total_samples = len(wv)
        num_chunks = math.ceil(total_samples / self.chunk_size_samples)

        if num_chunks <= 1:
            return input_ids, labels, 1

        # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
        audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
        # Duplicate sequence for each chunk
        duplicated_sequence = audio_token_seq.repeat(num_chunks)

        # Create new input_ids with duplicated tokens
        new_input_ids = torch.cat(
            [
                input_ids[: audio_idx - 1],
                duplicated_sequence,
                input_ids[audio_idx + 2 :],
            ]
        )

        # If labels are provided, duplicate them as well
        new_labels = None
        if labels is not None:
            label_seq = labels[audio_idx - 1 : audio_idx + 2]
            duplicated_labels = label_seq.repeat(num_chunks)
            new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])

        return new_input_ids, new_labels, num_chunks

    def __call__(self, batch: List[ChatMLDatasetSample]):
        """Collate the input data with support for long audio processing."""

        label_ids = None
        label_audio_ids = None
        if all([ele.label_ids is None for ele in batch]):
            return_labels = False
        else:
            return_labels = True

        if self.encode_whisper_embed:
            # Process each sample in the batch to handle long audio
            # TODO(?) The implementation here can be optimized.
            processed_batch = []
            for i in range(len(batch)):
                sample = batch[i]
                audio_in_mask = sample.input_ids == self.audio_in_token_id
                audio_in_indices = torch.where(audio_in_mask)[0]
                audio_out_mask = sample.input_ids == self.audio_out_token_id

                # Process each audio token and duplicate if needed
                modified_input_ids = sample.input_ids
                modified_labels = sample.label_ids if return_labels else None
                modified_waveforms_concat = []
                modified_waveforms_start = []
                modified_sample_rate = []
                offset = 0  # Track position changes from duplicating tokens
                curr_wv_offset = 0

                # Process input audio tokens
                for idx, audio_idx in enumerate(audio_in_indices):
                    # Get the audio for this token
                    wv, sr = sample.get_wv(idx)  # Use idx since we want the original audio index
                    if sr != self.whisper_processor.feature_extractor.sampling_rate:
                        resampled_wv = librosa.resample(
                            wv.cpu().numpy(),
                            orig_sr=sr,
                            target_sr=self.whisper_processor.feature_extractor.sampling_rate,
                        )
                    else:
                        resampled_wv = wv.cpu().numpy()
                    wv = torch.tensor(resampled_wv, device=wv.device)
                    sr = self.whisper_processor.feature_extractor.sampling_rate

                    # Process and duplicate tokens if necessary
                    token_pos = audio_idx + offset
                    modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
                        modified_input_ids, token_pos, wv, sr, modified_labels
+
)
|
208 |
+
|
209 |
+
# Update audio data
|
210 |
+
for chunk_idx in range(num_chunks):
|
211 |
+
chunk_start = chunk_idx * self.chunk_size_samples
|
212 |
+
chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
|
213 |
+
chunk_wv = wv[chunk_start:chunk_end]
|
214 |
+
modified_waveforms_concat.append(chunk_wv)
|
215 |
+
modified_waveforms_start.append(curr_wv_offset)
|
216 |
+
curr_wv_offset += len(chunk_wv)
|
217 |
+
modified_sample_rate.append(sr)
|
218 |
+
|
219 |
+
# Update offset for next iteration
|
220 |
+
offset += (num_chunks - 1) * 3 # Each new chunk adds 3 more tokens
|
221 |
+
|
222 |
+
# Create new sample with modified tokens and audio data
|
223 |
+
processed_sample = ChatMLDatasetSample(
|
224 |
+
input_ids=modified_input_ids,
|
225 |
+
label_ids=modified_labels if return_labels else sample.label_ids,
|
226 |
+
audio_ids_concat=sample.audio_ids_concat,
|
227 |
+
audio_ids_start=sample.audio_ids_start,
|
228 |
+
audio_waveforms_concat=torch.cat(modified_waveforms_concat)
|
229 |
+
if modified_waveforms_concat
|
230 |
+
else sample.audio_waveforms_concat,
|
231 |
+
audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
|
232 |
+
if modified_waveforms_start
|
233 |
+
else sample.audio_waveforms_start,
|
234 |
+
audio_sample_rate=torch.tensor(modified_sample_rate)
|
235 |
+
if modified_sample_rate
|
236 |
+
else sample.audio_sample_rate,
|
237 |
+
audio_speaker_indices=torch.tensor([]),
|
238 |
+
# FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
|
239 |
+
audio_label_ids_concat=sample.audio_label_ids_concat,
|
240 |
+
)
|
241 |
+
# audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
|
242 |
+
# assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
|
243 |
+
processed_batch.append(processed_sample)
|
244 |
+
else:
|
245 |
+
processed_batch = batch
|
246 |
+
|
247 |
+
# Get the max sequence length based on processed batch
|
248 |
+
max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)
|
249 |
+
|
250 |
+
# Get the ids for audio-in and audio-out for each batch
|
251 |
+
audio_in_wv_l = []
|
252 |
+
audio_in_ids_l = []
|
253 |
+
audio_out_ids_l = []
|
254 |
+
audio_out_ids_group_loc_l = []
|
255 |
+
audio_in_label_ids_l = None
|
256 |
+
audio_out_label_ids_l = None
|
257 |
+
reward_l = []
|
258 |
+
|
259 |
+
if return_labels:
|
260 |
+
audio_out_no_train_flag = [] # Whether the audio-out data should be trained on or not.
|
261 |
+
|
262 |
+
# Process the audio inputs and outputs
|
263 |
+
for i in range(len(processed_batch)):
|
264 |
+
audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
|
265 |
+
audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
|
266 |
+
audio_ids = torch.ones_like(processed_batch[i].input_ids)
|
267 |
+
audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
|
268 |
+
audio_in_ids = audio_ids[audio_in_mask]
|
269 |
+
audio_out_ids = audio_ids[audio_out_mask]
|
270 |
+
|
271 |
+
if return_labels:
|
272 |
+
audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
|
273 |
+
if self.mask_audio_out_token_label:
|
274 |
+
processed_batch[i].label_ids[audio_out_mask] = -100
|
275 |
+
|
276 |
+
# Process audio inputs
|
277 |
+
if self.return_audio_in_tokens:
|
278 |
+
audio_in_ids_l.extend(
|
279 |
+
[processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
|
280 |
+
)
|
281 |
+
if processed_batch[i].audio_label_ids_concat is not None:
|
282 |
+
if audio_in_label_ids_l is None:
|
283 |
+
audio_in_label_ids_l = []
|
284 |
+
audio_in_label_ids_l.extend(
|
285 |
+
[
|
286 |
+
processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
|
287 |
+
for idx in audio_in_ids
|
288 |
+
]
|
289 |
+
)
|
290 |
+
|
291 |
+
audio_out_ids_l.extend(
|
292 |
+
[processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
|
293 |
+
)
|
294 |
+
audio_out_ids_group_loc_l.append(i)
|
295 |
+
if processed_batch[i].reward is not None:
|
296 |
+
reward_l.append(processed_batch[i].reward)
|
297 |
+
|
298 |
+
if processed_batch[i].audio_label_ids_concat is not None:
|
299 |
+
if audio_out_label_ids_l is None:
|
300 |
+
audio_out_label_ids_l = []
|
301 |
+
audio_out_label_ids_l.extend(
|
302 |
+
[
|
303 |
+
processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
|
304 |
+
for idx in audio_out_ids
|
305 |
+
]
|
306 |
+
)
|
307 |
+
|
308 |
+
if self.encode_whisper_embed:
|
309 |
+
for idx in audio_in_ids:
|
310 |
+
wv, sr = processed_batch[i].get_wv(idx)
|
311 |
+
resampled_wv = wv.cpu().numpy()
|
312 |
+
# Split long audio into chunks
|
313 |
+
total_samples = len(resampled_wv)
|
314 |
+
for chunk_start in range(0, total_samples, self.chunk_size_samples):
|
315 |
+
chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
|
316 |
+
chunk = resampled_wv[chunk_start:chunk_end]
|
317 |
+
audio_in_wv_l.append(chunk)
|
318 |
+
# assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
|
319 |
+
# f"Assertion failed: Mismatch in number of audios. " \
|
320 |
+
# f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."
|
321 |
+
|
322 |
+
if return_labels:
|
323 |
+
audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
|
324 |
+
|
325 |
+
# Process all audio features
|
326 |
+
if len(audio_in_wv_l) > 0:
|
327 |
+
feature_ret = self.whisper_processor.feature_extractor(
|
328 |
+
audio_in_wv_l,
|
329 |
+
sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
|
330 |
+
return_attention_mask=True,
|
331 |
+
padding="max_length",
|
332 |
+
)
|
333 |
+
audio_features = torch.from_numpy(feature_ret["input_features"])
|
334 |
+
audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
|
335 |
+
else:
|
336 |
+
if self.encode_whisper_embed:
|
337 |
+
audio_features = torch.zeros(
|
338 |
+
(
|
339 |
+
0,
|
340 |
+
self.whisper_processor.feature_extractor.feature_size,
|
341 |
+
self.whisper_processor.feature_extractor.nb_max_frames,
|
342 |
+
),
|
343 |
+
dtype=torch.float32,
|
344 |
+
)
|
345 |
+
audio_feature_attention_mask = torch.zeros(
|
346 |
+
(0, self.whisper_processor.feature_extractor.nb_max_frames),
|
347 |
+
dtype=torch.int32,
|
348 |
+
)
|
349 |
+
else:
|
350 |
+
audio_features = None
|
351 |
+
audio_feature_attention_mask = None
|
352 |
+
|
353 |
+
# Process audio input tokens
|
354 |
+
if len(audio_in_ids_l) > 0:
|
355 |
+
# Append audio-stream-bos and eos tokens
|
356 |
+
new_audio_in_ids_l = []
|
357 |
+
for ele in audio_in_ids_l:
|
358 |
+
if self.disable_audio_codes_transform:
|
359 |
+
# Do not add audio-stream-bos or eos tokens.
|
360 |
+
# This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
|
361 |
+
audio_codes = ele
|
362 |
+
else:
|
363 |
+
audio_codes = torch.cat(
|
364 |
+
[
|
365 |
+
torch.full(
|
366 |
+
(ele.shape[0], 1),
|
367 |
+
self.audio_stream_bos_id,
|
368 |
+
dtype=torch.long,
|
369 |
+
),
|
370 |
+
ele,
|
371 |
+
torch.full(
|
372 |
+
(ele.shape[0], 1),
|
373 |
+
self.audio_stream_eos_id,
|
374 |
+
dtype=torch.long,
|
375 |
+
),
|
376 |
+
],
|
377 |
+
dim=1,
|
378 |
+
)
|
379 |
+
if self.use_delay_pattern:
|
380 |
+
audio_codes = build_delay_pattern_mask(
|
381 |
+
audio_codes.unsqueeze(0),
|
382 |
+
bos_token_id=self.audio_stream_bos_id,
|
383 |
+
pad_token_id=self.audio_stream_eos_id,
|
384 |
+
)[0].squeeze(0)
|
385 |
+
new_audio_in_ids_l.append(audio_codes)
|
386 |
+
audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
|
387 |
+
audio_in_ids_start = torch.cumsum(
|
388 |
+
torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]),
|
389 |
+
dim=0,
|
390 |
+
)
|
391 |
+
else:
|
392 |
+
audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
|
393 |
+
audio_in_ids_start = torch.zeros(0, dtype=torch.long)
|
394 |
+
|
395 |
+
# Process audio output tokens
|
396 |
+
audio_out_ids_start_group_loc = None
|
397 |
+
if len(audio_out_ids_l) > 0:
|
398 |
+
new_audio_out_ids_l = []
|
399 |
+
label_audio_ids_l = []
|
400 |
+
for idx, ele in enumerate(audio_out_ids_l):
|
401 |
+
if self.disable_audio_codes_transform:
|
402 |
+
# Do not add audio-stream-bos or eos tokens.
|
403 |
+
# This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
|
404 |
+
audio_codes = ele
|
405 |
+
if return_labels:
|
406 |
+
label_audio_ids = audio_out_label_ids_l[idx]
|
407 |
+
else:
|
408 |
+
audio_codes = torch.cat(
|
409 |
+
[
|
410 |
+
torch.full(
|
411 |
+
(ele.shape[0], 1),
|
412 |
+
self.audio_stream_bos_id,
|
413 |
+
dtype=torch.long,
|
414 |
+
),
|
415 |
+
ele,
|
416 |
+
torch.full(
|
417 |
+
(ele.shape[0], 1),
|
418 |
+
self.audio_stream_eos_id,
|
419 |
+
dtype=torch.long,
|
420 |
+
),
|
421 |
+
],
|
422 |
+
dim=1,
|
423 |
+
)
|
424 |
+
if return_labels:
|
425 |
+
label_audio_ids = torch.cat(
|
426 |
+
[
|
427 |
+
torch.full((ele.shape[0], 1), -100, dtype=torch.long),
|
428 |
+
ele,
|
429 |
+
torch.full(
|
430 |
+
(ele.shape[0], 1),
|
431 |
+
self.audio_stream_eos_id,
|
432 |
+
dtype=torch.long,
|
433 |
+
),
|
434 |
+
],
|
435 |
+
dim=1,
|
436 |
+
)
|
437 |
+
if self.use_delay_pattern:
|
438 |
+
audio_codes = build_delay_pattern_mask(
|
439 |
+
audio_codes.unsqueeze(0),
|
440 |
+
bos_token_id=self.audio_stream_bos_id,
|
441 |
+
pad_token_id=self.audio_stream_eos_id,
|
442 |
+
)[0].squeeze(0)
|
443 |
+
if return_labels:
|
444 |
+
label_audio_ids = build_delay_pattern_mask(
|
445 |
+
label_audio_ids.unsqueeze(0),
|
446 |
+
bos_token_id=-100,
|
447 |
+
pad_token_id=-100,
|
448 |
+
)[0].squeeze(0)
|
449 |
+
new_audio_out_ids_l.append(audio_codes)
|
450 |
+
|
451 |
+
if return_labels:
|
452 |
+
if audio_out_no_train_flag[idx]:
|
453 |
+
label_audio_ids[:] = -100
|
454 |
+
label_audio_ids_l.append(label_audio_ids)
|
455 |
+
|
456 |
+
audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
|
457 |
+
if return_labels:
|
458 |
+
label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
|
459 |
+
audio_out_ids_start = torch.cumsum(
|
460 |
+
torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]),
|
461 |
+
dim=0,
|
462 |
+
)
|
463 |
+
audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
|
464 |
+
else:
|
465 |
+
audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
|
466 |
+
audio_out_ids_start = torch.zeros(0, dtype=torch.long)
|
467 |
+
if return_labels:
|
468 |
+
label_audio_ids = torch.zeros((0, 0), dtype=torch.long)
|
469 |
+
|
470 |
+
reward = torch.tensor(reward_l, dtype=torch.float32)
|
471 |
+
|
472 |
+
# Handle padding for input ids and attention mask
|
473 |
+
if self.pad_left:
|
474 |
+
input_ids = torch.stack(
|
475 |
+
[
|
476 |
+
F.pad(
|
477 |
+
ele.input_ids,
|
478 |
+
(max_seq_length - len(ele.input_ids), 0),
|
479 |
+
value=self.pad_token_id,
|
480 |
+
)
|
481 |
+
for ele in processed_batch
|
482 |
+
]
|
483 |
+
)
|
484 |
+
if return_labels:
|
485 |
+
label_ids = torch.stack(
|
486 |
+
[
|
487 |
+
F.pad(
|
488 |
+
ele.label_ids,
|
489 |
+
(max_seq_length - len(ele.label_ids), 0),
|
490 |
+
value=-100,
|
491 |
+
)
|
492 |
+
for ele in processed_batch
|
493 |
+
]
|
494 |
+
)
|
495 |
+
attention_mask = torch.stack(
|
496 |
+
[
|
497 |
+
F.pad(
|
498 |
+
torch.ones_like(ele.input_ids),
|
499 |
+
(max_seq_length - len(ele.input_ids), 0),
|
500 |
+
value=0,
|
501 |
+
)
|
502 |
+
for ele in processed_batch
|
503 |
+
]
|
504 |
+
)
|
505 |
+
else:
|
506 |
+
input_ids = torch.stack(
|
507 |
+
[
|
508 |
+
F.pad(
|
509 |
+
ele.input_ids,
|
510 |
+
(0, max_seq_length - len(ele.input_ids)),
|
511 |
+
value=self.pad_token_id,
|
512 |
+
)
|
513 |
+
for ele in processed_batch
|
514 |
+
]
|
515 |
+
)
|
516 |
+
if return_labels:
|
517 |
+
label_ids = torch.stack(
|
518 |
+
[
|
519 |
+
F.pad(
|
520 |
+
ele.label_ids,
|
521 |
+
(0, max_seq_length - len(ele.label_ids)),
|
522 |
+
value=-100,
|
523 |
+
)
|
524 |
+
for ele in processed_batch
|
525 |
+
]
|
526 |
+
)
|
527 |
+
attention_mask = torch.stack(
|
528 |
+
[
|
529 |
+
F.pad(
|
530 |
+
torch.ones_like(ele.input_ids),
|
531 |
+
(0, max_seq_length - len(ele.input_ids)),
|
532 |
+
value=0,
|
533 |
+
)
|
534 |
+
for ele in processed_batch
|
535 |
+
]
|
536 |
+
)
|
537 |
+
|
538 |
+
if not self.return_audio_in_tokens:
|
539 |
+
audio_in_ids = None
|
540 |
+
audio_in_ids_start = None
|
541 |
+
|
542 |
+
# Apply audio_num_codebooks limit if specified
|
543 |
+
if self.audio_num_codebooks is not None:
|
544 |
+
if audio_in_ids is not None:
|
545 |
+
audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
|
546 |
+
if audio_out_ids is not None:
|
547 |
+
audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
|
548 |
+
if label_audio_ids is not None:
|
549 |
+
label_audio_ids = label_audio_ids[: self.audio_num_codebooks]
|
550 |
+
|
551 |
+
return HiggsAudioBatchInput(
|
552 |
+
input_ids=input_ids,
|
553 |
+
attention_mask=attention_mask,
|
554 |
+
audio_features=audio_features,
|
555 |
+
audio_feature_attention_mask=audio_feature_attention_mask,
|
556 |
+
audio_out_ids=audio_out_ids,
|
557 |
+
audio_out_ids_start=audio_out_ids_start,
|
558 |
+
audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
|
559 |
+
audio_in_ids=audio_in_ids,
|
560 |
+
audio_in_ids_start=audio_in_ids_start,
|
561 |
+
label_ids=label_ids,
|
562 |
+
label_audio_ids=label_audio_ids,
|
563 |
+
reward=reward,
|
564 |
+
)
|
565 |
+
|
566 |
+
|
567 |
+
class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator):
|
568 |
+
def __init__(self, *args, **kwargs):
|
569 |
+
super().__init__(*args, **kwargs)
|
570 |
+
|
571 |
+
def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput:
|
572 |
+
# flatten ranked chatml samples
|
573 |
+
chosen = []
|
574 |
+
rejected = []
|
575 |
+
|
576 |
+
for sample in batch:
|
577 |
+
chosen.append(sample.max_score_sample())
|
578 |
+
rejected.append(sample.min_score_sample())
|
579 |
+
|
580 |
+
merged = chosen
|
581 |
+
merged.extend(rejected)
|
582 |
+
|
583 |
+
return super().__call__(batch=merged)
|
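The long-audio handling in the collator above is easiest to see with concrete numbers. The following is a minimal standalone sketch (not using the collator itself), assuming the Whisper feature extractor's 16 kHz sampling rate and the default `chunk_size_seconds=30`: a 70-second clip spans 1,120,000 samples, so it is split into 3 chunks and the single `<|audio_bos|><|AUDIO|><|audio_eos|>` triple is duplicated 3 times, shifting later token positions by (3 - 1) * 3 = 6.

```python
import math

# Assumed values mirroring the collator defaults above (illustrative sketch only).
WHISPER_SR = 16_000           # whisper_processor.feature_extractor.sampling_rate
CHUNK_SECONDS = 30            # chunk_size_seconds
chunk_size_samples = CHUNK_SECONDS * WHISPER_SR


def chunking_plan(num_samples: int):
    """Return (num_chunks, extra_tokens) for one <|audio_bos|><|AUDIO|><|audio_eos|> triple."""
    num_chunks = math.ceil(num_samples / chunk_size_samples)
    # Each additional chunk duplicates the 3-token audio triple once.
    extra_tokens = (num_chunks - 1) * 3
    return num_chunks, extra_tokens


print(chunking_plan(70 * WHISPER_SR))  # -> (3, 6)
print(chunking_plan(12 * WHISPER_SR))  # -> (1, 0)
```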
higgs_audio/data_types.py
ADDED
@@ -0,0 +1,38 @@
+"""Basic data types for multimodal ChatML format."""
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Union
+
+
+@dataclass
+class AudioContent:
+    audio_url: str
+    # Base64 encoded audio bytes
+    raw_audio: Optional[str] = None
+    offset: Optional[float] = None
+    duration: Optional[float] = None
+    row_id: Optional[int] = None
+    type: str = "audio"
+
+
+@dataclass
+class TextContent:
+    text: str
+    type: str = "text"
+
+
+@dataclass
+class Message:
+    role: str
+    content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
+    recipient: Optional[str] = None
+
+
+@dataclass
+class ChatMLSample:
+    """Dataclass to hold multimodal ChatML data."""
+
+    messages: List[Message]
+    start_index: Optional[int] = None  # We will mask the messages[:start_index] when finetuning the LLM.
+    misc: Optional[Dict] = None
+    speaker: Optional[str] = None
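A hedged sketch of how these dataclasses compose into a two-turn sample mixing text and audio; the audio URL and speaker id are placeholders, not values defined anywhere in this repo:

```python
# Hypothetical usage sketch of the dataclasses defined above.
sample = ChatMLSample(
    messages=[
        Message(
            role="user",
            content=[TextContent(text="Please read this sentence aloud.")],
        ),
        Message(
            role="assistant",
            content=AudioContent(audio_url="placeholder://reply.wav"),
        ),
    ],
    speaker="speaker_0",
)
```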
higgs_audio/dataset/__init__.py
ADDED
File without changes
|
higgs_audio/dataset/chatml_dataset.py
ADDED
@@ -0,0 +1,554 @@
1 |
+
import dacite
|
2 |
+
import pandas as pd
|
3 |
+
import torch
|
4 |
+
import json
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import multiprocessing as mp
|
8 |
+
|
9 |
+
from dataclasses import dataclass, fields
|
10 |
+
from abc import ABC, abstractmethod
|
11 |
+
from typing import Union, List, Dict, Optional
|
12 |
+
|
13 |
+
from ..data_types import ChatMLSample, TextContent, AudioContent
|
14 |
+
from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN
|
15 |
+
|
16 |
+
from loguru import logger
|
17 |
+
|
18 |
+
# Whisper processor, 30 sec -> 3000 features
|
19 |
+
# Then we downsample by a factor of 4 in the audio tower, reducing 3000 features to 750, which gives 25 Hz
|
20 |
+
WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class ChatMLDatasetSample:
|
25 |
+
input_ids: torch.LongTensor # Shape (seq_len,): The input text tokens.
|
26 |
+
label_ids: torch.LongTensor # Shape (seq_len,): The label ids.
|
27 |
+
audio_ids_concat: torch.LongTensor # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
|
28 |
+
# Here `audio_seq_len` is the length of the concatenated audio tokens.
|
29 |
+
audio_ids_start: (
|
30 |
+
torch.LongTensor
|
31 |
+
) # Shape (num_audios,): The start index of each audio token in the concatenated audio tokens.
|
32 |
+
audio_waveforms_concat: (
|
33 |
+
torch.Tensor
|
34 |
+
) # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features.
|
35 |
+
audio_waveforms_start: (
|
36 |
+
torch.LongTensor
|
37 |
+
) # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms.
|
38 |
+
audio_sample_rate: torch.Tensor # Shape (num_audios,): The sampling rate of the audio waveforms.
|
39 |
+
audio_speaker_indices: (
|
40 |
+
torch.LongTensor
|
41 |
+
) # Shape (num_audios,) -1 means unknown speaker: The speaker indices for each audio.
|
42 |
+
audio_label_ids_concat: Optional[torch.LongTensor] = (
|
43 |
+
None # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
|
44 |
+
)
|
45 |
+
# Here `audio_seq_len` is the length of the concatenated audio tokens.
|
46 |
+
reward: Optional[float] = None
|
47 |
+
|
48 |
+
def num_audios(self):
|
49 |
+
return max(len(self.audio_waveforms_start), len(self.audio_ids_start))
|
50 |
+
|
51 |
+
def get_audio_codes(self, idx):
|
52 |
+
code_start = self.audio_ids_start[idx]
|
53 |
+
if idx < len(self.audio_ids_start) - 1:
|
54 |
+
code_end = self.audio_ids_start[idx + 1]
|
55 |
+
else:
|
56 |
+
code_end = self.audio_ids_concat.shape[-1]
|
57 |
+
|
58 |
+
return self.audio_ids_concat[:, code_start:code_end]
|
59 |
+
|
60 |
+
def get_audio_codes_labels(self, idx):
|
61 |
+
if self.audio_label_ids_concat is None:
|
62 |
+
return None
|
63 |
+
code_start = self.audio_ids_start[idx]
|
64 |
+
if idx < len(self.audio_ids_start) - 1:
|
65 |
+
code_end = self.audio_ids_start[idx + 1]
|
66 |
+
else:
|
67 |
+
code_end = self.audio_ids_concat.shape[-1]
|
68 |
+
|
69 |
+
return self.audio_label_ids_concat[:, code_start:code_end]
|
70 |
+
|
71 |
+
def get_wv(self, idx):
|
72 |
+
wv_start = self.audio_waveforms_start[idx]
|
73 |
+
sr = self.audio_sample_rate[idx]
|
74 |
+
if idx < len(self.audio_waveforms_start) - 1:
|
75 |
+
wv_end = self.audio_waveforms_start[idx + 1]
|
76 |
+
else:
|
77 |
+
wv_end = self.audio_waveforms_concat.shape[-1]
|
78 |
+
return self.audio_waveforms_concat[wv_start:wv_end], sr
|
79 |
+
|
80 |
+
def cal_num_tokens(
|
81 |
+
self,
|
82 |
+
encode_whisper_embed: bool = True,
|
83 |
+
encode_audio_in_tokens: bool = False,
|
84 |
+
encode_audio_out_tokens: bool = True,
|
85 |
+
audio_in_token_id: int = 128015,
|
86 |
+
audio_out_token_id: int = 128016,
|
87 |
+
) -> int:
|
88 |
+
# We first exclude <|AUDIO|> and <|AUDIO_OUT|> because we do late merging and replace those positions with the actual audio features and audio token ids
|
89 |
+
# It's assumed that we always have audio_ids when audio_waveforms are there (but not vice-versa)
|
90 |
+
num_tokens = len(self.input_ids) - len(self.audio_ids_start)
|
91 |
+
|
92 |
+
if encode_whisper_embed and len(self.audio_waveforms_concat) > 0:
|
93 |
+
audio_lengths = torch.diff(self.audio_waveforms_start)
|
94 |
+
if len(audio_lengths):
|
95 |
+
# Sum before calling .item()
|
96 |
+
num_tokens += (
|
97 |
+
(
|
98 |
+
np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1])
|
99 |
+
).sum()
|
100 |
+
).item()
|
101 |
+
# add the last audio's token estimation
|
102 |
+
num_tokens += (
|
103 |
+
np.ceil(
|
104 |
+
WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC
|
105 |
+
* (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1])
|
106 |
+
/ self.audio_sample_rate[-1]
|
107 |
+
)
|
108 |
+
).item()
|
109 |
+
|
110 |
+
if self.audio_ids_concat.size(1) > 0:
|
111 |
+
audio_io_ids = self.input_ids[
|
112 |
+
(self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id)
|
113 |
+
]
|
114 |
+
audio_io_id_lengths = torch.concat(
|
115 |
+
[
|
116 |
+
torch.diff(self.audio_ids_start),
|
117 |
+
torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]),
|
118 |
+
]
|
119 |
+
)
|
120 |
+
if encode_audio_in_tokens:
|
121 |
+
num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item()
|
122 |
+
|
123 |
+
if encode_audio_out_tokens:
|
124 |
+
num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item()
|
125 |
+
|
126 |
+
return int(num_tokens)
|
127 |
+
|
128 |
+
@classmethod
|
129 |
+
def merge(
|
130 |
+
cls,
|
131 |
+
samples: List["ChatMLDatasetSample"],
|
132 |
+
eos_token_id: int,
|
133 |
+
ignore_index: int,
|
134 |
+
padding_size: Optional[int] = None,
|
135 |
+
) -> "ChatMLDatasetSample":
|
136 |
+
"""Merges a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them, and adjusting offsets for audio_ids_start and audio_waveforms_start.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
samples (List[ChatMLDatasetSample]): List of samples to merge.
|
140 |
+
eos_token_id (int): Tokens to be inserted into input_ids between samples.
|
141 |
+
ignore_index (int): Default label for padding.
|
142 |
+
padding_size (Optional[int]): If provided, append this many padding tokens (eos_token_id for input_ids, ignore_index for label_ids) to the merged sequence.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
ChatMLDatasetSample: Merged and potentially padded sample.
|
146 |
+
"""
|
147 |
+
if not samples:
|
148 |
+
logger.fatal("The samples list is empty and cannot be merged.")
|
149 |
+
raise ValueError("The samples list is empty and cannot be merged.")
|
150 |
+
|
151 |
+
# Initialize empty lists for concatenation
|
152 |
+
input_ids_list = []
|
153 |
+
label_ids_list = []
|
154 |
+
audio_ids_concat_list = []
|
155 |
+
audio_ids_start_list = []
|
156 |
+
audio_waveforms_concat_list = []
|
157 |
+
audio_waveforms_start_list = []
|
158 |
+
audio_sample_rate_list = []
|
159 |
+
audio_speaker_indices_list = []
|
160 |
+
|
161 |
+
# Track offsets
|
162 |
+
audio_ids_offset = 0
|
163 |
+
audio_waveforms_offset = 0
|
164 |
+
|
165 |
+
for sample in samples:
|
166 |
+
# Add input_ids and label_ids with padding
|
167 |
+
if input_ids_list:
|
168 |
+
input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long))
|
169 |
+
label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long))
|
170 |
+
input_ids_list.append(sample.input_ids)
|
171 |
+
label_ids_list.append(sample.label_ids)
|
172 |
+
|
173 |
+
# Add audio_ids_concat and handle empty audio ids
|
174 |
+
if sample.audio_ids_concat.size(1) > 0:
|
175 |
+
audio_ids_concat_list.append(sample.audio_ids_concat)
|
176 |
+
|
177 |
+
# Offset and add audio_ids_start
|
178 |
+
audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset)
|
179 |
+
audio_ids_offset += sample.audio_ids_concat.size(
|
180 |
+
1
|
181 |
+
) # (num_codebooks, seq_len): Update offset by audio_seq_len
|
182 |
+
|
183 |
+
# Add audio_waveforms_concat
|
184 |
+
if sample.audio_waveforms_concat.size(0) > 0:
|
185 |
+
# Check dimensions of the audio waveform to ensure consistency
|
186 |
+
if (
|
187 |
+
audio_waveforms_concat_list
|
188 |
+
and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim()
|
189 |
+
):
|
190 |
+
logger.warning(
|
191 |
+
f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D"
|
192 |
+
)
|
193 |
+
continue
|
194 |
+
|
195 |
+
audio_waveforms_concat_list.append(sample.audio_waveforms_concat)
|
196 |
+
audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset)
|
197 |
+
audio_waveforms_offset += sample.audio_waveforms_concat.size(0)
|
198 |
+
|
199 |
+
# Add audio_sample_rate and audio_speaker_indices
|
200 |
+
audio_sample_rate_list.append(sample.audio_sample_rate)
|
201 |
+
|
202 |
+
audio_speaker_indices_list.append(sample.audio_speaker_indices)
|
203 |
+
|
204 |
+
# Concatenate all tensors
|
205 |
+
input_ids = torch.cat(input_ids_list, dim=0)
|
206 |
+
label_ids = torch.cat(label_ids_list, dim=0)
|
207 |
+
|
208 |
+
# Apply padding if padding_size is specified
|
209 |
+
if padding_size is not None and padding_size > 0:
|
210 |
+
input_ids = torch.cat(
|
211 |
+
[
|
212 |
+
input_ids,
|
213 |
+
torch.full((padding_size,), eos_token_id, dtype=torch.long),
|
214 |
+
],
|
215 |
+
dim=0,
|
216 |
+
)
|
217 |
+
label_ids = torch.cat(
|
218 |
+
[
|
219 |
+
label_ids,
|
220 |
+
torch.full((padding_size,), ignore_index, dtype=torch.long),
|
221 |
+
],
|
222 |
+
dim=0,
|
223 |
+
)
|
224 |
+
|
225 |
+
# Safely concatenate audio tensors with proper error handling
|
226 |
+
try:
|
227 |
+
audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]])
|
228 |
+
audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([])
|
229 |
+
|
230 |
+
# Check for dimensional consistency in audio waveforms
|
231 |
+
if audio_waveforms_concat_list:
|
232 |
+
dims = [t.dim() for t in audio_waveforms_concat_list]
|
233 |
+
if not all(d == dims[0] for d in dims):
|
234 |
+
# If dimensions don't match, log warning and filter out the problematic tensors
|
235 |
+
logger.warning(
|
236 |
+
f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones."
|
237 |
+
)
|
238 |
+
expected_dim = max(set(dims), key=dims.count) # Most common dimension
|
239 |
+
audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim]
|
240 |
+
|
241 |
+
# Recalculate audio_waveforms_start with the filtered list
|
242 |
+
if audio_waveforms_concat_list:
|
243 |
+
audio_waveforms_offset = 0
|
244 |
+
audio_waveforms_start_list = []
|
245 |
+
for waveform in audio_waveforms_concat_list:
|
246 |
+
audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset]))
|
247 |
+
audio_waveforms_offset += waveform.size(0)
|
248 |
+
|
249 |
+
audio_waveforms_concat = (
|
250 |
+
torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([])
|
251 |
+
)
|
252 |
+
audio_waveforms_start = (
|
253 |
+
torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([])
|
254 |
+
)
|
255 |
+
audio_sample_rate = (
|
256 |
+
torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([])
|
257 |
+
)
|
258 |
+
audio_speaker_indices = (
|
259 |
+
torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([])
|
260 |
+
)
|
261 |
+
|
262 |
+
except RuntimeError as e:
|
263 |
+
logger.error(f"Error during tensor concatenation: {str(e)}")
|
264 |
+
logger.warning("Falling back to empty audio tensors")
|
265 |
+
# Fall back to empty tensors
|
266 |
+
audio_ids_concat = torch.tensor([[]])
|
267 |
+
audio_ids_start = torch.tensor([])
|
268 |
+
audio_waveforms_concat = torch.tensor([])
|
269 |
+
audio_waveforms_start = torch.tensor([])
|
270 |
+
audio_sample_rate = torch.tensor([])
|
271 |
+
audio_speaker_indices = torch.tensor([])
|
272 |
+
|
273 |
+
# Create the merged sample
|
274 |
+
merged_sample = cls(
|
275 |
+
input_ids=input_ids,
|
276 |
+
label_ids=label_ids,
|
277 |
+
audio_ids_concat=audio_ids_concat,
|
278 |
+
audio_ids_start=audio_ids_start,
|
279 |
+
audio_waveforms_concat=audio_waveforms_concat,
|
280 |
+
audio_waveforms_start=audio_waveforms_start,
|
281 |
+
audio_sample_rate=audio_sample_rate,
|
282 |
+
audio_speaker_indices=audio_speaker_indices,
|
283 |
+
)
|
284 |
+
|
285 |
+
return merged_sample
|
286 |
+
|
287 |
+
|
288 |
+
@dataclass
|
289 |
+
class RankedChatMLDatasetSampleTuple:
|
290 |
+
samples: List[ChatMLDatasetSample]
|
291 |
+
scores: List[float]
|
292 |
+
|
293 |
+
def max_score_sample(self) -> ChatMLDatasetSample:
|
294 |
+
idx = self.scores.index(max(self.scores))
|
295 |
+
self.samples[idx].reward = self.scores[idx]
|
296 |
+
return self.samples[idx]
|
297 |
+
|
298 |
+
def min_score_sample(self) -> ChatMLDatasetSample:
|
299 |
+
idx = self.scores.index(min(self.scores))
|
300 |
+
self.samples[idx].reward = self.scores[idx]
|
301 |
+
return self.samples[idx]
|
302 |
+
|
303 |
+
|
304 |
+
@dataclass
|
305 |
+
class ChatMLDatasetStorageSample:
|
306 |
+
input_tokens: torch.LongTensor
|
307 |
+
label_tokens: torch.LongTensor
|
308 |
+
audio_bytes_cache_dir_index: int
|
309 |
+
audio_codes_cache_dir_index: int
|
310 |
+
audio_bytes_indices: torch.LongTensor
|
311 |
+
audio_codes_indices: torch.LongTensor
|
312 |
+
speaker_indices: torch.LongTensor
|
313 |
+
file_index: int
|
314 |
+
original_sample_index: int
|
315 |
+
|
316 |
+
|
317 |
+
# TODO(sxjscience): We need to revisit the logic for parsing speaker ids.
|
318 |
+
# Currently, we assume that the speaker id is stored at the "misc" field in ChatMLSample.
|
319 |
+
def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
|
320 |
+
"""Preprocess the ChatML sample to get the tokens for the text part.
|
321 |
+
|
322 |
+
Args:
|
323 |
+
sample (ChatMLSample): The ChatML sample to preprocess.
|
324 |
+
tokenizer: The tokenizer to use for encoding the text.
|
325 |
+
|
326 |
+
"""
|
327 |
+
|
328 |
+
try:
|
329 |
+
if not isinstance(sample, ChatMLSample):
|
330 |
+
# Handle all fields that could be NaN
|
331 |
+
if "speaker" in sample and pd.isna(sample["speaker"]):
|
332 |
+
sample["speaker"] = None
|
333 |
+
if "start_index" in sample and pd.isna(sample["start_index"]):
|
334 |
+
sample["start_index"] = None
|
335 |
+
if "content" in sample and pd.isna(sample["content"]):
|
336 |
+
sample["content"] = ""
|
337 |
+
|
338 |
+
# Convert any other potential NaN values in nested structures
|
339 |
+
def convert_nan_to_none(obj):
|
340 |
+
import numpy as np
|
341 |
+
|
342 |
+
if isinstance(obj, (pd.Series, np.ndarray)):
|
343 |
+
return obj.tolist()
|
344 |
+
elif pd.api.types.is_scalar(obj) and pd.isna(obj):
|
345 |
+
return None
|
346 |
+
elif isinstance(obj, dict):
|
347 |
+
return {k: convert_nan_to_none(v) for k, v in obj.items()}
|
348 |
+
elif isinstance(obj, (list, tuple)): # Fixed: Handle both list and tuple
|
349 |
+
return [convert_nan_to_none(item) for item in obj]
|
350 |
+
return obj
|
351 |
+
|
352 |
+
# Clean the sample data
|
353 |
+
clean_sample = convert_nan_to_none(sample)
|
354 |
+
|
355 |
+
val_keys = []
|
356 |
+
for field in fields(ChatMLSample):
|
357 |
+
if field.name in clean_sample:
|
358 |
+
val_keys.append(field.name)
|
359 |
+
clean_sample = {k: clean_sample[k] for k in val_keys}
|
360 |
+
|
361 |
+
try:
|
362 |
+
sample = dacite.from_dict(
|
363 |
+
data_class=ChatMLSample,
|
364 |
+
data=clean_sample,
|
365 |
+
config=dacite.Config(strict=True, check_types=True),
|
366 |
+
)
|
367 |
+
except Exception as e:
|
368 |
+
print(f"Failed to convert to ChatMLSample: {e}")
|
369 |
+
print(f"Clean sample: {json.dumps(clean_sample, indent=2)}")
|
370 |
+
return None, None, None, None
|
371 |
+
|
372 |
+
input_tokens = []
|
373 |
+
label_tokens = []
|
374 |
+
audio_contents = []
|
375 |
+
speaker_id = None
|
376 |
+
if sample.speaker is not None:
|
377 |
+
speaker_id = sample.speaker
|
378 |
+
elif sample.misc is not None:
|
379 |
+
if "speaker" in sample.misc:
|
380 |
+
speaker_id = sample.misc["speaker"]
|
381 |
+
|
382 |
+
total_m = len(sample.messages)
|
383 |
+
for turn_id, message in enumerate(sample.messages):
|
384 |
+
role = message.role
|
385 |
+
recipient = message.recipient
|
386 |
+
content = message.content
|
387 |
+
content_l = []
|
388 |
+
|
389 |
+
if isinstance(content, str):
|
390 |
+
content_l.append(TextContent(text=content))
|
391 |
+
elif isinstance(content, TextContent):
|
392 |
+
content_l.append(content)
|
393 |
+
elif isinstance(content, AudioContent):
|
394 |
+
content_l.append(content)
|
395 |
+
elif isinstance(content, list):
|
396 |
+
for ele in content:
|
397 |
+
if isinstance(ele, str):
|
398 |
+
content_l.append(TextContent(text=ele))
|
399 |
+
else:
|
400 |
+
content_l.append(ele)
|
401 |
+
if turn_id == 0:
|
402 |
+
prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
|
403 |
+
else:
|
404 |
+
prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
|
405 |
+
eot_postfix = "<|eot_id|>"
|
406 |
+
eom_postfix = "<|eom_id|>"
|
407 |
+
|
408 |
+
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
|
409 |
+
input_tokens.extend(prefix_tokens)
|
410 |
+
label_tokens.extend([-100 for _ in prefix_tokens])
|
411 |
+
|
412 |
+
if recipient:
|
413 |
+
assert role == "assistant", "Recipient is only available for assistant role."
|
414 |
+
recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
|
415 |
+
input_tokens.extend(recipient_tokens)
|
416 |
+
label_tokens.extend(recipient_tokens)
|
417 |
+
|
418 |
+
for content in content_l:
|
419 |
+
if content.type == "text":
|
420 |
+
text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
|
421 |
+
input_tokens.extend(text_tokens)
|
422 |
+
if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
|
423 |
+
label_tokens.extend(text_tokens)
|
424 |
+
else:
|
425 |
+
label_tokens.extend([-100 for _ in text_tokens])
|
426 |
+
|
427 |
+
elif content.type == "audio":
|
428 |
+
# Generate the text-part of the audio tokens
|
429 |
+
audio_contents.append(content)
|
430 |
+
if role == "user" or role == "system":
|
431 |
+
# Add the text tokens
|
432 |
+
text_tokens = tokenizer.encode(
|
433 |
+
f"<|audio_bos|><|AUDIO|><|audio_eos|>",
|
434 |
+
add_special_tokens=False,
|
435 |
+
)
|
436 |
+
input_tokens.extend(text_tokens)
|
437 |
+
label_tokens.extend([-100 for _ in text_tokens])
|
438 |
+
elif role == "assistant":
|
439 |
+
# Add the text tokens for audio-out part.
|
440 |
+
text_tokens = tokenizer.encode(
|
441 |
+
f"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
|
442 |
+
add_special_tokens=False,
|
443 |
+
)
|
444 |
+
input_tokens.extend(text_tokens)
|
445 |
+
if sample.start_index is None or turn_id >= sample.start_index:
|
446 |
+
label_tokens.extend(text_tokens)
|
447 |
+
else:
|
448 |
+
label_tokens.extend([-100 for _ in text_tokens])
|
449 |
+
next_id = turn_id + 1
|
450 |
+
if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
|
451 |
+
postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
|
452 |
+
input_tokens.extend(postfix_tokens)
|
453 |
+
else:
|
454 |
+
postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
|
455 |
+
input_tokens.extend(postfix_tokens)
|
456 |
+
if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
|
457 |
+
label_tokens.extend(postfix_tokens)
|
458 |
+
else:
|
459 |
+
label_tokens.extend([-100 for _ in postfix_tokens])
|
460 |
+
|
461 |
+
return input_tokens, label_tokens, audio_contents, speaker_id
|
462 |
+
|
463 |
+
except Exception as e:
|
464 |
+
print(f"Error in prepare_chatml_sample: {str(e)}")
|
465 |
+
print(f"Sample data: {json.dumps(sample, indent=2)}")
|
466 |
+
return None, None, None, None
|
467 |
+
|
468 |
+
|
469 |
+
def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
|
470 |
+
"""Extract the generation prompt and reference answer from the input tokens.
|
471 |
+
|
472 |
+
For example:
|
473 |
+
|
474 |
+
Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
|
475 |
+
What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
|
476 |
+
<|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>'
|
477 |
+
|
478 |
+
-->
|
479 |
+
|
480 |
+
Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
|
481 |
+
What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
|
482 |
+
<|start_header_id|>assistant<|end_header_id|>\n\n',
|
483 |
+
Reference = 'At first they went by quick, too quick to even get.'
|
484 |
+
|
485 |
+
Args:
|
486 |
+
input_tokens: The input tokens.
|
487 |
+
audio_contents: The audio contents.
|
488 |
+
tokenizer: The tokenizer to use for decoding the text.
|
489 |
+
|
490 |
+
Returns:
|
491 |
+
prompt_tokens: The tokens for the prompt.
|
492 |
+
reference_answer: The reference answer.
|
493 |
+
num_audios_in_reference: The number of audios in the reference answer.
|
494 |
+
|
495 |
+
"""
|
496 |
+
input_text = tokenizer.decode(input_tokens)
|
497 |
+
generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
498 |
+
postfix = "<|eot_id|>"
|
499 |
+
assert generation_prefix in input_text
|
500 |
+
generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix)
|
501 |
+
generation_prompt = input_text[:generation_prompt_end_loc]
|
502 |
+
reference_answer = input_text[generation_prompt_end_loc : input_text.find(postfix, generation_prompt_end_loc)]
|
503 |
+
num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN)
|
504 |
+
return (
|
505 |
+
tokenizer.encode(generation_prompt, add_special_tokens=False),
|
506 |
+
reference_answer,
|
507 |
+
num_audios_in_reference,
|
508 |
+
)
|
509 |
+
|
510 |
+
|
511 |
+
def prepare_chatml_dataframe_single_process(df, tokenizer):
|
512 |
+
"""Prepare the ChatML DataFrame."""
|
513 |
+
ret = []
|
514 |
+
for _, row in df.iterrows():
|
515 |
+
input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer)
|
516 |
+
ret.append((input_tokens, label_tokens, audio_contents, speaker_id))
|
517 |
+
return ret
|
518 |
+
|
519 |
+
|
520 |
+
def prepare_chatml_dataframe(df, tokenizer, num_process=16):
|
521 |
+
if num_process is None:
|
522 |
+
return prepare_chatml_dataframe_single_process(df, tokenizer)
|
523 |
+
else:
|
524 |
+
num_process = max(min(len(df) // 1000, num_process), 1)
|
525 |
+
workloads = np.array_split(df, num_process)
|
526 |
+
with mp.Pool(num_process) as pool:
|
527 |
+
ret = pool.starmap(
|
528 |
+
prepare_chatml_dataframe_single_process,
|
529 |
+
[(workload, tokenizer) for workload in workloads],
|
530 |
+
)
|
531 |
+
return sum(ret, [])
|
532 |
+
|
533 |
+
|
534 |
+
class DatasetInterface(ABC):
|
535 |
+
@abstractmethod
|
536 |
+
def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
|
537 |
+
"""Retrieve a dataset sample by index."""
|
538 |
+
raise NotImplementedError
|
539 |
+
|
540 |
+
|
541 |
+
class IterableDatasetInterface(ABC):
|
542 |
+
@abstractmethod
|
543 |
+
def __iter__(
|
544 |
+
self,
|
545 |
+
) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
|
546 |
+
"""Retrieve a sample by iterating through the dataset."""
|
547 |
+
raise NotImplementedError
|
548 |
+
|
549 |
+
|
550 |
+
@dataclass
|
551 |
+
class DatasetInfo:
|
552 |
+
dataset_type: str
|
553 |
+
group_type: Optional[str] = None
|
554 |
+
mask_text: Optional[bool] = None # Whether to mask the text tokens for pretraining samples.
|
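For orientation, a hedged sketch of driving `prepare_chatml_sample` end to end. The tokenizer checkpoint is an assumption (any Llama-3.1-style tokenizer that defines the `<|begin_of_text|>` / `<|start_header_id|>` / `<|eot_id|>` special tokens should behave the same, and the official repo is gated), and the sample dict only uses fields of `ChatMLSample`:

```python
# Hedged usage sketch; the checkpoint name below is an assumed stand-in.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

raw_sample = {
    "messages": [
        {"role": "user", "content": "Transcribe the audio."},
        {"role": "assistant", "content": "At first they went by quick."},
    ],
    "speaker": "speaker_0",
}

input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(raw_sample, tokenizer)
# input_tokens / label_tokens are plain Python lists of token ids; label ids are -100
# outside the assistant turns, so only assistant text contributes to the loss.
```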
higgs_audio/model/__init__.py
ADDED
@@ -0,0 +1,9 @@
+from transformers import AutoConfig, AutoModel
+
+from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig
+from .modeling_higgs_audio import HiggsAudioModel
+
+
+AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig)
+AutoConfig.register("higgs_audio", HiggsAudioConfig)
+AutoModel.register(HiggsAudioConfig, HiggsAudioModel)
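With these registrations in place, checkpoints whose config declares `model_type: "higgs_audio"` should be loadable through the standard Auto classes. A hedged sketch follows; the checkpoint name is an assumed placeholder rather than something defined in this diff:

```python
# Hedged sketch: "bosonai/higgs-audio-v2-generation-3B-base" is an assumed checkpoint name.
from transformers import AutoConfig, AutoModel

import higgs_audio.model  # importing the package runs the register() calls above

config = AutoConfig.from_pretrained("bosonai/higgs-audio-v2-generation-3B-base")
model = AutoModel.from_config(config)  # dispatches to HiggsAudioModel via the registration
```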
higgs_audio/model/audio_head.py
ADDED
@@ -0,0 +1,139 @@
1 |
+
"""Projector that maps hidden states from the LLM component to multimodal logits."""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import Optional, Tuple
|
8 |
+
|
9 |
+
from .common import HiggsAudioPreTrainedModel
|
10 |
+
from .configuration_higgs_audio import HiggsAudioConfig
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass
|
14 |
+
class HiggsAudioDecoderLayerOutput:
|
15 |
+
logits: torch.FloatTensor
|
16 |
+
audio_logits: torch.FloatTensor
|
17 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
18 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
19 |
+
|
20 |
+
|
21 |
+
class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
|
22 |
+
"""Projection layers that map hidden states from the LLM component to audio / text logits.
|
23 |
+
|
24 |
+
We support two type of audio head:
|
25 |
+
- Basic Audio Head:
|
26 |
+
Directly map the hidden states to audio logits for all the codebooks.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None):
|
30 |
+
super().__init__(config)
|
31 |
+
self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
32 |
+
self.audio_lm_head = nn.Linear(
|
33 |
+
config.text_config.hidden_size,
|
34 |
+
config.audio_num_codebooks * (config.audio_codebook_size + 2),
|
35 |
+
bias=False,
|
36 |
+
)
|
37 |
+
|
38 |
+
# Initialize weights and apply final processing
|
39 |
+
self.post_init()
|
40 |
+
|
41 |
+
def forward(
|
42 |
+
self,
|
43 |
+
hidden_states,
|
44 |
+
audio_out_mask,
|
45 |
+
label_audio_ids=None,
|
46 |
+
attention_mask=None,
|
47 |
+
position_ids=None,
|
48 |
+
past_key_values=None,
|
49 |
+
use_cache=None,
|
50 |
+
output_attentions=None,
|
51 |
+
output_hidden_states=None,
|
52 |
+
output_audio_hidden_states=False,
|
53 |
+
cache_position=None,
|
54 |
+
):
|
55 |
+
"""
|
56 |
+
Args:
|
57 |
+
hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
|
58 |
+
Hidden states from the LLM component
|
59 |
+
audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
60 |
+
Mask for identifying the audio out tokens.
|
61 |
+
label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`):
|
62 |
+
Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used.
|
63 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
64 |
+
Mask to avoid performing attention on padding token indices
|
65 |
+
position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
66 |
+
Position ids for the input tokens
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`):
|
70 |
+
Logits for text tokens
|
71 |
+
audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`):
|
72 |
+
Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len`
|
73 |
+
"""
|
74 |
+
logits = self.text_lm_head(hidden_states)
|
75 |
+
|
76 |
+
all_hidden_states = () if output_hidden_states else None
|
77 |
+
all_self_attns = () if output_attentions else None
|
78 |
+
next_decoder_cache = None
|
79 |
+
|
80 |
+
# TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input.
|
81 |
+
if self.config.audio_decoder_proj_num_layers > 0:
|
82 |
+
# create position embeddings to be shared across the decoder layers
|
83 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
84 |
+
for decoder_layer in self.transformer_layers:
|
85 |
+
if output_hidden_states:
|
86 |
+
all_hidden_states += (hidden_states,)
|
87 |
+
|
88 |
+
if self.gradient_checkpointing and self.training:
|
89 |
+
layer_outputs = self._gradient_checkpointing_func(
|
90 |
+
decoder_layer.__call__,
|
91 |
+
hidden_states,
|
92 |
+
attention_mask,
|
93 |
+
position_ids,
|
94 |
+
past_key_values,
|
95 |
+
output_attentions,
|
96 |
+
use_cache,
|
97 |
+
cache_position,
|
98 |
+
position_embeddings,
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
layer_outputs = decoder_layer(
|
102 |
+
hidden_states,
|
103 |
+
attention_mask=attention_mask,
|
104 |
+
position_ids=position_ids,
|
105 |
+
past_key_value=past_key_values,
|
106 |
+
output_attentions=output_attentions,
|
107 |
+
use_cache=use_cache,
|
108 |
+
cache_position=cache_position,
|
109 |
+
position_embeddings=position_embeddings,
|
110 |
+
)
|
111 |
+
hidden_states = layer_outputs[0]
|
112 |
+
hidden_states = self.norm(hidden_states)
|
113 |
+
|
114 |
+
if output_hidden_states:
|
115 |
+
all_hidden_states += (hidden_states,)
|
116 |
+
|
117 |
+
if output_attentions:
|
118 |
+
all_self_attns += (layer_outputs[1],)
|
119 |
+
|
120 |
+
if use_cache:
|
121 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
122 |
+
|
123 |
+
next_cache = next_decoder_cache if use_cache else None
|
124 |
+
|
125 |
+
audio_logits = self.audio_lm_head(hidden_states[audio_out_mask])
|
126 |
+
|
127 |
+
if output_audio_hidden_states:
|
128 |
+
audio_hidden_states = hidden_states[audio_out_mask]
|
129 |
+
else:
|
130 |
+
audio_hidden_states = None
|
131 |
+
|
132 |
+
return (
|
133 |
+
logits,
|
134 |
+
audio_logits,
|
135 |
+
all_self_attns,
|
136 |
+
all_hidden_states,
|
137 |
+
audio_hidden_states,
|
138 |
+
next_cache,
|
139 |
+
)
|
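Since `audio_lm_head` emits one flat vector per audio-out position, here is a hedged sketch of how that vector can be viewed per codebook. The sizes follow the defaults quoted in the configuration docstring below (12 codebooks, codebook size 1024, plus 2 slots for the audio-stream bos/eos); the variable names and the greedy argmax are illustrative only, not the model's decoding procedure:

```python
import torch

num_codebooks, codebook_size = 12, 1024
num_audio_out_tokens, hidden_size = 7, 4096  # arbitrary illustrative sizes

# Same shape contract as the audio head above: hidden_size -> num_codebooks * (codebook_size + 2).
audio_lm_head = torch.nn.Linear(hidden_size, num_codebooks * (codebook_size + 2), bias=False)
audio_hidden = torch.randn(num_audio_out_tokens, hidden_size)

flat_logits = audio_lm_head(audio_hidden)                                         # (7, 12 * 1026)
per_codebook = flat_logits.view(num_audio_out_tokens, num_codebooks, codebook_size + 2)
next_codes = per_codebook.argmax(dim=-1)                                          # greedy code per codebook
print(next_codes.shape)                                                           # torch.Size([7, 12])
```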
higgs_audio/model/common.py
ADDED
@@ -0,0 +1,27 @@
+from torch import nn
+
+from transformers.modeling_utils import PreTrainedModel
+
+from .configuration_higgs_audio import HiggsAudioConfig
+
+
+class HiggsAudioPreTrainedModel(PreTrainedModel):
+    config_class = HiggsAudioConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = []
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+
+    def _init_weights(self, module):
+        std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_encoder_config.init_std
+
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
higgs_audio/model/configuration_higgs_audio.py
ADDED
@@ -0,0 +1,235 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING


class HiggsAudioEncoderConfig(PretrainedConfig):
    """Configuration of the Audio encoder in Higgs-Audio."""

    model_type = "higgs_audio_encoder"

    def __init__(
        self,
        num_mel_bins=128,
        encoder_layers=32,
        encoder_attention_heads=20,
        encoder_ffn_dim=5120,
        encoder_layerdrop=0.0,
        d_model=1280,
        dropout=0.0,
        attention_dropout=0.0,
        activation_function="gelu",
        activation_dropout=0.0,
        scale_embedding=False,
        init_std=0.02,
        max_source_positions=1500,
        pad_token_id=128001,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_mel_bins = num_mel_bins
        self.d_model = d_model
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_function = activation_function
        self.activation_dropout = activation_dropout
        self.encoder_layerdrop = encoder_layerdrop
        self.num_hidden_layers = encoder_layers
        self.init_std = init_std
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.max_source_positions = max_source_positions
        self.pad_token_id = pad_token_id


class HiggsAudioConfig(PretrainedConfig):
    r"""
    This is the configuration class for the HiggsAudioModel.

    Args:
        text_config (`Union[AutoConfig, dict]`):
            The config object or dictionary of the text backbone.
        audio_encoder_config (`Union[AutoConfig, dict]`):
            The config object or dictionary of the whisper encoder.
            The audio encoder will be bidirectional and will be only available for audio understanding.
        audio_tokenizer_config
            The config object or dictionary of the audio tokenizer.
        audio_adapter_type
            The type of audio adapter to use. We support three types of adapters:
            - stack:
                We stack additional Transformer layers after the main LLM backbone for audio generation.
            - dual_ffn:
                For selected parts of the LLM backbone, we replace the text FFN with a dual FFN architecture
                that contains an additional audio FFN. The audio FFN will be triggered when the location is marked for audio tokens.
            - dual_ffn_fast_forward:
                We pick a few layers in the LLM backbone to plug in the audio FFN. For the remaining layers,
                the audio hidden states will be directly fast-forwarded to the next layer.
                This reduces the computational cost for audio generation.
        audio_embed_avg (`bool`, *optional*, defaults to False):
            Whether to average the audio embeddings before sending them to the text attention layer.
        audio_ffn_hidden_size
            The hidden size of the audio feedforward network in dual-path FFN
        audio_ffn_intermediate_size
            The intermediate size of the audio feedforward network in dual-path FFN
        audio_dual_ffn_layers
            The layers in the LLM backbone to plug in the dual FFN layer (mixture of audio FFN and text FFN).
        audio_decoder_proj_num_layers (`int`, *optional*, defaults to 0):
            The number of layers in the audio decoder projection module.
        use_delay_pattern (`bool`, *optional*, defaults to False):
            Whether to use delay pattern in the audio decoder.
        skip_audio_tower (`bool`, *optional*, defaults to False):
            Whether to skip the audio tower in the audio encoder.
        use_audio_out_embed_projector (`bool`, *optional*, defaults to False):
            Whether to use an embedding projector to map audio out embeddings.
        use_audio_out_self_attention (`bool`, *optional*, defaults to False):
            Whether to use self-attention to aggregate information from audio-tokens before sending to the text attention layer.
        audio_num_codebooks (`int`, *optional*, defaults to 12):
            The number of codebooks in RVQGAN.
        audio_codebook_size (`int`, *optional*, defaults to 1024):
            The size of each codebook in RVQGAN.
        audio_stream_bos_id
            The id of the bos in the audio stream
        audio_stream_eos_id
            The id of the eos in the audio stream
        audio_bos_token (`str`, *optional*, defaults to "<|audio_bos|>"):
            The special `<|audio_bos|>` token. In Higgs-Audio, it is mapped to 128011,
            which is the index of `<|reserved_special_token_3|>` in Llama-3.1-8B-Instruct's tokenizer.
        audio_eos_token (`str`, *optional*, defaults to "<|audio_eos|>"):
            The special `<|audio_eos|>` token. We use 128012 as the default value,
            which is the index of `<|reserved_special_token_4|>` in Llama-3.1-8B-Instruct's tokenizer.
        audio_out_bos_token (`str`, *optional*, defaults to "<|audio_out_bos|>"):
            The special `<|audio_out_bos|>` token. We use 128013 as the default value,
            which is the index of `<|reserved_special_token_5|>` in Llama-3.1-8B-Instruct's tokenizer.
        audio_in_token (`str`, *optional*, defaults to "<|AUDIO|>"):
            The special `<|AUDIO|>` token. We use 128015 as the default value,
            which is the index of `<|reserved_special_token_7|>` in Llama-3.1-8B-Instruct's tokenizer.
            This token indicates that the location should be filled in with whisper features.
        audio_out_token (`str`, *optional*, defaults to "<|AUDIO_OUT|>"):
            The special `<|AUDIO_OUT|>` token. We use 128016 as the default value,
            which is the index of `<|reserved_special_token_8|>` in Llama-3.1-8B-Instruct's tokenizer.
            This token indicates that the location should be filled in with audio tokens extracted via audio tokenizer.
    """

    model_type = "higgs_audio"
    is_composition = True

    def __init__(
        self,
        text_config=None,
        audio_encoder_config=None,
        audio_tokenizer_config=None,
        audio_adapter_type="stack",
        audio_embed_avg=False,
        audio_ffn_hidden_size=4096,
        audio_ffn_intermediate_size=14336,
        audio_dual_ffn_layers=None,
        audio_decoder_proj_num_layers=0,
        encode_whisper_embed=True,
        encode_audio_in_tokens=False,
        use_delay_pattern=False,
        skip_audio_tower=False,
        use_audio_out_embed_projector=False,
        use_audio_out_self_attention=False,
        use_rq_transformer=False,
        rq_transformer_hidden_size=None,
        rq_transformer_intermediate_size=None,
        rq_transformer_num_attention_heads=None,
        rq_transformer_num_key_value_heads=None,
        rq_transformer_num_hidden_layers=3,
        audio_num_codebooks=12,
        audio_codebook_size=1024,
        audio_stream_bos_id=1024,
        audio_stream_eos_id=1025,
        audio_bos_token="<|audio_bos|>",
        audio_eos_token="<|audio_eos|>",
        audio_out_bos_token="<|audio_out_bos|>",
        audio_in_token="<|AUDIO|>",
        audio_out_token="<|AUDIO_OUT|>",
        audio_in_token_idx=128015,
        audio_out_token_idx=128016,
        pad_token_id=128001,
        audio_out_bos_token_id=128013,
        audio_eos_token_id=128012,
        **kwargs,
    ):
        if isinstance(audio_encoder_config, dict):
            audio_encoder_config["model_type"] = (
                audio_encoder_config["model_type"] if "model_type" in audio_encoder_config else "higgs_audio_encoder"
            )
            audio_encoder_config = CONFIG_MAPPING[audio_encoder_config["model_type"]](**audio_encoder_config)
        elif audio_encoder_config is None:
            audio_encoder_config = HiggsAudioEncoderConfig()

        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        assert audio_adapter_type in [
            "stack",
            "dual_ffn",
            "dual_ffn_fast_forward",
        ], f"Invalid audio adapter type: {audio_adapter_type}"
        if audio_adapter_type.startswith("dual_ffn"):
            assert audio_dual_ffn_layers is not None, (
                "audio_dual_ffn_layers must be specified when using dual_ffn adapter."
            )
        self.text_config = text_config
        self.audio_encoder_config = audio_encoder_config
        self.audio_tokenizer_config = audio_tokenizer_config
        self.audio_adapter_type = audio_adapter_type
        self.audio_embed_avg = audio_embed_avg
        self.audio_ffn_hidden_size = audio_ffn_hidden_size
        self.audio_ffn_intermediate_size = audio_ffn_intermediate_size
        self.audio_dual_ffn_layers = audio_dual_ffn_layers
        self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers
        self.encode_whisper_embed = encode_whisper_embed
        self.encode_audio_in_tokens = encode_audio_in_tokens
        self.use_delay_pattern = use_delay_pattern
        self.skip_audio_tower = skip_audio_tower
        self.use_audio_out_embed_projector = use_audio_out_embed_projector
        self.use_audio_out_self_attention = use_audio_out_self_attention

        self.use_rq_transformer = use_rq_transformer

        if self.use_rq_transformer:
            assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!"
        self.rq_transformer_hidden_size = rq_transformer_hidden_size
        self.rq_transformer_intermediate_size = rq_transformer_intermediate_size
        self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads
        self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads
        self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers

        if use_rq_transformer:
            # For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified.
            if self.rq_transformer_hidden_size is None:
                self.rq_transformer_hidden_size = text_config.hidden_size
            assert self.rq_transformer_hidden_size % 128 == 0
            if self.rq_transformer_intermediate_size is None:
                self.rq_transformer_intermediate_size = text_config.intermediate_size
            if self.rq_transformer_num_attention_heads is None:
                self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128
            if self.rq_transformer_num_key_value_heads is None:
                self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4
            assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0
            assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0

        self.audio_num_codebooks = audio_num_codebooks
        self.audio_codebook_size = audio_codebook_size
        self.audio_bos_token = audio_bos_token
        self.audio_eos_token = audio_eos_token
        self.audio_out_bos_token = audio_out_bos_token
        self.audio_in_token = audio_in_token
        self.audio_out_token = audio_out_token
        self.audio_in_token_idx = audio_in_token_idx
        self.audio_out_token_idx = audio_out_token_idx
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.audio_out_bos_token_id = audio_out_bos_token_id
        self.audio_eos_token_id = audio_eos_token_id

        super().__init__(**kwargs)
        self.pad_token_id = pad_token_id
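
For reference, a minimal usage sketch of the configuration above (not part of the uploaded file; the printed values simply restate the defaults from the signature, and the text backbone defaults to whatever `CONFIG_MAPPING["llama"]()` returns):

    from higgs_audio.model.configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig

    # Build the composite config from its defaults: a Llama text backbone plus
    # the Whisper-style HiggsAudioEncoderConfig defined above.
    config = HiggsAudioConfig(audio_encoder_config=HiggsAudioEncoderConfig())
    print(config.audio_adapter_type)   # "stack"
    print(config.audio_num_codebooks)  # 12
    print(config.audio_in_token_idx)   # 128015 (<|AUDIO|>)
    print(config.audio_out_token_idx)  # 128016 (<|AUDIO_OUT|>)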
higgs_audio/model/cuda_graph_runner.py
ADDED
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
from typing import Optional, List, Dict, Tuple, Union
import gc

from transformers.cache_utils import Cache


_NUM_WARMUP_ITERS = 2


class CUDAGraphRunner(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Union[Cache, List[torch.FloatTensor]],
        use_cache: bool,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        is_decoding_audio_token: Optional[bool] = None,
        is_using_cuda_graph: Optional[bool] = False,
        stream: torch.cuda.Stream = None,
        memory_pool: Optional[Tuple[int, int]] = None,
    ):
        assert self._graph is None
        # Run warmup iterations
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                hidden_states=hidden_states,
                causal_mask=causal_mask,
                position_ids=position_ids,
                audio_discrete_codes_mask=audio_discrete_codes_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                use_cache=use_cache,
                audio_attention_mask=audio_attention_mask,
                fast_forward_attention_mask=fast_forward_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                is_decoding_audio_token=is_decoding_audio_token,
                is_using_cuda_graph=is_using_cuda_graph,
            )

        torch.cuda.synchronize()

        # Capture the graph
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            out_hidden_states, all_hidden_states, all_self_attns = self.model(
                hidden_states=hidden_states,
                causal_mask=causal_mask,
                position_ids=position_ids,
                audio_discrete_codes_mask=audio_discrete_codes_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                use_cache=use_cache,
                audio_attention_mask=audio_attention_mask,
                fast_forward_attention_mask=fast_forward_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                is_decoding_audio_token=is_decoding_audio_token,
                is_using_cuda_graph=is_using_cuda_graph,
            )
            # hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0])
            # del outputs
            gc.collect()
        torch.cuda.synchronize()

        # Save input and output buffers
        self.input_buffers = {
            "hidden_states": hidden_states,
            "causal_mask": causal_mask,
            "position_ids": position_ids,
            "audio_discrete_codes_mask": audio_discrete_codes_mask,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "audio_attention_mask": audio_attention_mask,
            "fast_forward_attention_mask": fast_forward_attention_mask,
        }
        self.output_buffers = {
            "hidden_states": out_hidden_states,
            "all_hidden_states": all_hidden_states,
            "all_self_attns": all_self_attns,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # Copy input tensors to buffers
        self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True)
        self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True)
        self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True)
        self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True)
        self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True)
        self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True)
        self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True)

        # Run the captured graph
        self.graph.replay()

        return self.output_buffers["hidden_states"], None, None
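
The runner follows the standard capture-then-replay recipe: allocate static buffers, warm up, capture one forward pass inside `torch.cuda.graph`, then copy fresh inputs into the same buffers and replay. A minimal self-contained sketch of that recipe, independent of the model above (the tiny MLP and tensor shapes are made up for illustration and a CUDA device is assumed; this is not code from the repo):

    import torch

    device = "cuda"
    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU()).to(device).eval()

    static_in = torch.zeros(1, 64, device=device)   # persistent input buffer
    for _ in range(2):                              # warmup, as in capture()
        model(static_in)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = model(static_in)               # output buffer recorded once

    new_input = torch.randn(1, 64, device=device)
    static_in.copy_(new_input)                      # analogous to CUDAGraphRunner.forward()
    graph.replay()                                  # static_out now holds the new result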
higgs_audio/model/custom_modules.py
ADDED
@@ -0,0 +1,155 @@
import torch
import torch.nn as nn


class PartiallyFrozenEmbedding(nn.Module):
    """Wrapper around an existing `nn.Embedding` module that splits the embedding into:

    - A frozen embedding for indices [0..freeze_until_idx).
    - A trainable embedding for indices [freeze_until_idx..vocab_size-1].

    This should work with both Zero-2 and Zero-3 seamlessly.
    """

    def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
        """
        :param original_embedding: An instance of nn.Embedding (the original embedding layer).
        :param freeze_until_idx: The index up to which the embedding is frozen (exclusive); the row at `freeze_until_idx` itself is not frozen.
        """
        super().__init__()
        self.freeze_until_idx = freeze_until_idx
        self.original_vocab_size = original_embedding.num_embeddings
        self.embedding_dim = original_embedding.embedding_dim

        # Split the original embedding into frozen and trainable parts
        self.embedding_frozen = nn.Embedding(
            freeze_until_idx,
            self.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )
        self.embedding_trainable = nn.Embedding(
            self.original_vocab_size - freeze_until_idx,
            self.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )

        # Copy weights from the original embedding into the frozen and trainable parts
        with torch.no_grad():
            self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx])
            self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:])

        # Freeze the frozen embedding
        self.embedding_frozen.weight.requires_grad = False

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the split embedding wrapper.
        :param input_ids: Tensor of shape [batch_size, seq_len] with indices in [0..original_vocab_size-1].
        """
        # Masks to separate frozen and trainable indices
        # (bsz, seq_len)
        mask_frozen = input_ids < self.freeze_until_idx
        mask_trainable = ~mask_frozen

        # Output tensor for embedding results
        batch_size, seq_len = input_ids.shape
        embeddings = torch.zeros(
            batch_size,
            seq_len,
            self.embedding_dim,
            device=input_ids.device,
            dtype=self.embedding_frozen.weight.dtype,
        )

        # Handle frozen embedding
        if mask_frozen.any():
            frozen_ids = input_ids[mask_frozen]
            frozen_emb = self.embedding_frozen(frozen_ids)
            embeddings[mask_frozen] = frozen_emb

        # Handle trainable embedding
        if mask_trainable.any():
            # Adjust trainable IDs to the local index space of the trainable embedding
            trainable_ids = input_ids[mask_trainable] - (self.freeze_until_idx)
            trainable_emb = self.embedding_trainable(trainable_ids)
            embeddings[mask_trainable] = trainable_emb

        return embeddings

    def to_unsplit(self) -> nn.Embedding:
        unsplit_embedding = nn.Embedding(
            self.original_vocab_size,
            self.embedding_dim,
            dtype=self.embedding_frozen.weight.dtype,
            device=self.embedding_frozen.weight.device,
        )

        with torch.no_grad():
            unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
            unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)

        return unsplit_embedding


class PartiallyFrozenLinear(nn.Module):
    """A wrapper around nn.Linear to partially freeze part of the weight matrix."""

    def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
        """
        :param original_linear: The original nn.Linear layer.
        :param freeze_until_idx: The index up to which the rows of the weight matrix are frozen.
        """
        super().__init__()
        assert original_linear.bias is None, "Currently only support linear module without bias"

        self.freeze_until_idx = freeze_until_idx
        self.input_dim = original_linear.in_features
        self.output_dim = original_linear.out_features

        # Create frozen and trainable linear layers
        self.linear_frozen = nn.Linear(
            self.input_dim,
            freeze_until_idx,
            bias=False,
            dtype=original_linear.weight.dtype,
            device=original_linear.weight.device,
        )
        self.linear_trainable = nn.Linear(
            self.input_dim,
            self.output_dim - freeze_until_idx,
            bias=False,
            dtype=original_linear.weight.dtype,
            device=original_linear.weight.device,
        )

        # Copy weights from the original linear layer
        with torch.no_grad():
            self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx])
            self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:])

        # Freeze the frozen linear layer
        self.linear_frozen.weight.requires_grad = False

    def forward(self, input_tensor):
        # input_tensor: (bsz, seq_len, hidden_state_dim)
        frozen_output = self.linear_frozen(input_tensor)
        trainable_output = self.linear_trainable(input_tensor)
        return torch.cat((frozen_output, trainable_output), dim=-1)

    def to_unsplit(self) -> nn.Linear:
        unsplit_linear = nn.Linear(
            self.input_dim,
            self.output_dim,
            bias=False,
            dtype=self.linear_frozen.weight.dtype,
            device=self.linear_frozen.weight.device,
        )

        # Copy weights from the frozen and trainable layers into the unsplit linear layer
        with torch.no_grad():
            unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
            unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)

        return unsplit_linear
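
A small sanity-check sketch for `PartiallyFrozenEmbedding` (the vocabulary and embedding sizes are hypothetical, chosen only for illustration; the same pattern applies to `PartiallyFrozenLinear`):

    import torch
    import torch.nn as nn
    from higgs_audio.model.custom_modules import PartiallyFrozenEmbedding

    # Freeze the original 16 vocabulary rows and leave 4 newly appended rows trainable.
    base = nn.Embedding(20, 8)
    split = PartiallyFrozenEmbedding(base, freeze_until_idx=16)

    ids = torch.tensor([[1, 5, 17, 19]])
    out = split(ids)                      # shape (1, 4, 8)
    out.sum().backward()

    # Gradients only reach the trainable split.
    assert split.embedding_frozen.weight.grad is None
    assert split.embedding_trainable.weight.grad is not None

    merged = split.to_unsplit()           # back to a single nn.Embedding(20, 8)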
higgs_audio/model/modeling_higgs_audio.py
ADDED
The diff for this file is too large to render.
See raw diff
higgs_audio/model/utils.py
ADDED
@@ -0,0 +1,778 @@
1 |
+
import contextlib
|
2 |
+
from contextlib import contextmanager
|
3 |
+
from functools import wraps
|
4 |
+
import torch
|
5 |
+
from transformers.integrations import is_deepspeed_available
|
6 |
+
|
7 |
+
if is_deepspeed_available():
|
8 |
+
from deepspeed.utils import groups as deepspeed_groups
|
9 |
+
from deepspeed.sequence.layer import _SeqAllToAll
|
10 |
+
else:
|
11 |
+
deepspeed_groups = None
|
12 |
+
_SeqAllToAll = None
|
13 |
+
|
14 |
+
|
15 |
+
def _ceil_to_nearest(n, round_to):
|
16 |
+
return (n + round_to - 1) // round_to * round_to
|
17 |
+
|
18 |
+
|
19 |
+
def count_parameters(model, trainable_only=True):
|
20 |
+
if trainable_only:
|
21 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
22 |
+
else:
|
23 |
+
return sum(p.numel() for p in model.parameters())
|
24 |
+
|
25 |
+
|
26 |
+
# TODO(sxjscience) Consider to move the function to audio_processing/utils.py
|
27 |
+
def build_delay_pattern_mask(
|
28 |
+
input_ids: torch.LongTensor,
|
29 |
+
bos_token_id: int,
|
30 |
+
pad_token_id: int,
|
31 |
+
):
|
32 |
+
"""Implement the delay pattern proposed in "Simple and Controllable Music Generation", https://arxiv.org/pdf/2306.05284
|
33 |
+
|
34 |
+
In the delay pattern, each codebook is offset by the previous codebook by
|
35 |
+
one. We insert a special delay token at the start of the sequence if its delayed, and append pad token once the sequence finishes.
|
36 |
+
|
37 |
+
Take the example where there are 4 codebooks and audio sequence length=5. After shifting, the output should have length seq_len + num_codebooks - 1
|
38 |
+
|
39 |
+
- [ *, *, *, *, *, P, P, P]
|
40 |
+
- [ B, *, *, *, *, *, P, P]
|
41 |
+
- [ B, B, *, *, *, *, *, P]
|
42 |
+
- [ B, B, B, *, *, *, *, *]
|
43 |
+
|
44 |
+
where B indicates the delay token id, P is the special padding token id and `*` indicates that the original audio token.
|
45 |
+
|
46 |
+
Now let's consider the case where we have a sequence of audio tokens to condition on.
|
47 |
+
The audio tokens were originally in the following non-delayed form:
|
48 |
+
|
49 |
+
- [a, b]
|
50 |
+
- [c, d]
|
51 |
+
- [e, f]
|
52 |
+
- [g, h]
|
53 |
+
|
54 |
+
After conversion, we get the following delayed form:
|
55 |
+
- [a, b, -1, -1, -1]
|
56 |
+
- [B, c, d, -1, -1]
|
57 |
+
- [B, B, e, f, -1]
|
58 |
+
- [B, B, B, g, h]
|
59 |
+
|
60 |
+
Note that we have a special token `-1` that indicates it should be replaced by a new token we see in the generation phase.
|
61 |
+
In that case, we should override the `-1` tokens in auto-regressive generation.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
input_ids (:obj:`torch.LongTensor`):
|
65 |
+
The input ids of the prompt. It will have shape (bsz, num_codebooks, seq_len).
|
66 |
+
bos_token_id (:obj:`int`):
|
67 |
+
The id of the special delay token
|
68 |
+
pad_token_id (:obj:`int`):
|
69 |
+
The id of the padding token. Should be the same as eos_token_id.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
input_ids (:obj:`torch.LongTensor`):
|
73 |
+
The transformed input ids with delay pattern applied. It will have shape (bsz, num_codebooks, seq_len + num_codebooks - 1).
|
74 |
+
input_ids_with_gen_mask (:obj:`torch.LongTensor`):
|
75 |
+
The transformed input ids with delay pattern applied. The -1 in the output indicates new tokens that should be generated.
|
76 |
+
|
77 |
+
"""
|
78 |
+
bsz, num_codebooks, seq_len = input_ids.shape
|
79 |
+
|
80 |
+
new_seq_len = seq_len + num_codebooks - 1
|
81 |
+
input_ids_with_gen_mask = torch.ones((bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device)
|
82 |
+
bos_mask = torch.tril(input_ids_with_gen_mask, -1) > 0
|
83 |
+
eos_mask = torch.triu(input_ids_with_gen_mask, seq_len) > 0
|
84 |
+
input_ids_with_gen_mask[bos_mask] = bos_token_id
|
85 |
+
input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1)
|
86 |
+
input_ids = input_ids_with_gen_mask.clone()
|
87 |
+
input_ids[eos_mask] = pad_token_id
|
88 |
+
input_ids_with_gen_mask[eos_mask] = -1
|
89 |
+
return input_ids, input_ids_with_gen_mask
|
90 |
+
|
91 |
+
|
92 |
+
def revert_delay_pattern(data):
|
93 |
+
"""Convert samples encoded with delay pattern back to the original form.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
data (:obj:`torch.Tensor`):
|
97 |
+
The data with delay pattern applied. It will have shape (num_codebooks, seq_len + num_codebooks - 1).
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
ret (:obj:`torch.Tensor`):
|
101 |
+
Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len).
|
102 |
+
"""
|
103 |
+
assert len(data.shape) == 2
|
104 |
+
out_l = []
|
105 |
+
num_codebooks = data.shape[0]
|
106 |
+
for i in range(num_codebooks):
|
107 |
+
out_l.append(data[i : (i + 1), i : (data.shape[1] - num_codebooks + 1 + i)])
|
108 |
+
return torch.cat(out_l, dim=0)
|
109 |
+
|
110 |
+
|
111 |
+
def merge_input_ids_with_audio_features(
|
112 |
+
audio_features_embed,
|
113 |
+
audio_features_length,
|
114 |
+
audio_in_embed,
|
115 |
+
audio_in_ids_start,
|
116 |
+
audio_out_embed,
|
117 |
+
audio_out_ids_start,
|
118 |
+
audio_in_token_idx,
|
119 |
+
audio_out_token_idx,
|
120 |
+
inputs_embeds,
|
121 |
+
input_ids,
|
122 |
+
attention_mask,
|
123 |
+
label_ids,
|
124 |
+
pad_token_id,
|
125 |
+
ignore_index=-100,
|
126 |
+
round_to=8,
|
127 |
+
left_padding=True,
|
128 |
+
):
|
129 |
+
"""
|
130 |
+
Merge input_ids with audio features into final embeddings.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
|
134 |
+
Encoded vectors of all audios in the batch (obtained from the semantic encoder)
|
135 |
+
audio_features_length (`torch.LongTensor` of shape `(num_audios,)`):
|
136 |
+
The length of audio embeddings of each audio as stacked in `audio_features_embed`
|
137 |
+
audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`):
|
138 |
+
The embeddings of audio-in tokens
|
139 |
+
audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
|
140 |
+
The start index of the audio-in tokens for each audio
|
141 |
+
audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`):
|
142 |
+
The embeddings of audio-out tokens
|
143 |
+
audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
|
144 |
+
The start index of the audio-out tokens for each audio
|
145 |
+
audio_in_token_idx
|
146 |
+
The index of the audio-in token in the vocabulary
|
147 |
+
audio_out_token_idx
|
148 |
+
The index of the audio-out token in the vocabulary
|
149 |
+
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
|
150 |
+
Token embeddings before merging with audio embeddings
|
151 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
152 |
+
Input_ids of tokens, possibly filled with audio token
|
153 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
154 |
+
Mask to avoid performing attention on padding token indices.
|
155 |
+
label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
|
156 |
+
labels need to be recalculated to support training (if provided)
|
157 |
+
pad_token_id (`int`):
|
158 |
+
The index of the pad token in the vocabulary
|
159 |
+
ignore_index
|
160 |
+
The index to ignore in the loss calculation
|
161 |
+
round_to
|
162 |
+
The number to round to for padding
|
163 |
+
left_padding
|
164 |
+
Whether to apply left padding
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
final_embedding
|
168 |
+
The final embeddings after merging audio embeddings with text embeddings.
|
169 |
+
final_attention_mask
|
170 |
+
The final attention mask after merging audio embeddings with text embeddings.
|
171 |
+
final_labels
|
172 |
+
The labels for the text stream
|
173 |
+
position_ids
|
174 |
+
Positional ids for the merged data
|
175 |
+
final_input_ids
|
176 |
+
The final input_ids after merging audio embeddings with text embeddings.
|
177 |
+
final_audio_in_mask
|
178 |
+
Mask for audio-in embeddings
|
179 |
+
final_audio_in_discrete_codes_mask
|
180 |
+
Mask for audio-in discrete tokens
|
181 |
+
final_audio_out_mask
|
182 |
+
Mask for audio-out embeddings
|
183 |
+
|
184 |
+
Explanation:
|
185 |
+
each audio has variable length embeddings, with length specified by
|
186 |
+
- audio_features_length
|
187 |
+
- audio_in_ids_start
|
188 |
+
- audio_out_ids_start
|
189 |
+
|
190 |
+
Task:
|
191 |
+
- fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks)
|
192 |
+
- fill each <|AUDIO_OUT|> with the audio-out embeddings
|
193 |
+
|
194 |
+
Example:
|
195 |
+
<|AUDIO_OUT|>: X (5 tokens), Y (3 tokens)
|
196 |
+
<|AUDIO|>: Z (8 tokens)
|
197 |
+
|
198 |
+
X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding).
|
199 |
+
if right padding
|
200 |
+
input_ids: [
|
201 |
+
a b c d e f X g h i j k Y l m
|
202 |
+
o p q r Z s t u v _ _ _ _ _ _
|
203 |
+
]
|
204 |
+
input_ids should be: [
|
205 |
+
a b c d e f X X X X X g h i j k Y Y Y l m
|
206 |
+
o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
|
207 |
+
]
|
208 |
+
labels should be: [
|
209 |
+
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
|
210 |
+
o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
|
211 |
+
]
|
212 |
+
elif left padding
|
213 |
+
input_ids: [
|
214 |
+
a b c d e f X g h i j k Y l m
|
215 |
+
_ _ _ _ _ _ o p q r Z s t u v
|
216 |
+
]
|
217 |
+
input_ids should be: [
|
218 |
+
a b c d e f X X X X X g h i j k Y Y Y l m
|
219 |
+
_ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
|
220 |
+
]
|
221 |
+
labels should be: [
|
222 |
+
a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
|
223 |
+
_ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
|
224 |
+
]
|
225 |
+
|
226 |
+
"""
|
227 |
+
if label_ids is None:
|
228 |
+
skip_labels = True
|
229 |
+
else:
|
230 |
+
skip_labels = False
|
231 |
+
if audio_features_embed is not None and audio_features_embed.shape[0] == 0:
|
232 |
+
audio_features_embed = None
|
233 |
+
if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
|
234 |
+
audio_in_embed = None
|
235 |
+
if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
|
236 |
+
audio_out_embed = None
|
237 |
+
|
238 |
+
batch_size, sequence_length, embed_dim = inputs_embeds.shape
|
239 |
+
|
240 |
+
target_device = inputs_embeds.device
|
241 |
+
if left_padding is None:
|
242 |
+
left_padding = torch.any(attention_mask[:, 0] == 0)
|
243 |
+
|
244 |
+
audio_in_token_mask = input_ids == audio_in_token_idx
|
245 |
+
audio_out_token_mask = input_ids == audio_out_token_idx
|
246 |
+
text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)
|
247 |
+
|
248 |
+
# 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]).
|
249 |
+
token_placeholder_num = torch.ones_like(input_ids)
|
250 |
+
|
251 |
+
if audio_features_embed is not None:
|
252 |
+
num_audios, max_audio_tokens, _ = audio_features_embed.shape
|
253 |
+
audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
|
254 |
+
audio_features_length.device
|
255 |
+
) < audio_features_length.unsqueeze(1)
|
256 |
+
masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim)
|
257 |
+
token_placeholder_num[audio_in_token_mask] = audio_features_length.long()
|
258 |
+
|
259 |
+
if audio_in_embed is not None:
|
260 |
+
audio_in_codes_length = torch.concat(
|
261 |
+
[
|
262 |
+
audio_in_ids_start[1:] - audio_in_ids_start[:-1],
|
263 |
+
torch.tensor(
|
264 |
+
[audio_in_embed.shape[0] - audio_in_ids_start[-1]],
|
265 |
+
device=audio_in_ids_start.device,
|
266 |
+
dtype=torch.long,
|
267 |
+
),
|
268 |
+
],
|
269 |
+
dim=0,
|
270 |
+
)
|
271 |
+
if audio_features_embed is not None:
|
272 |
+
token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long()
|
273 |
+
else:
|
274 |
+
token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long()
|
275 |
+
|
276 |
+
if audio_out_embed is not None:
|
277 |
+
audio_out_codes_length = torch.concat(
|
278 |
+
[
|
279 |
+
audio_out_ids_start[1:] - audio_out_ids_start[:-1],
|
280 |
+
torch.tensor(
|
281 |
+
[audio_out_embed.shape[0] - audio_out_ids_start[-1]],
|
282 |
+
device=audio_out_ids_start.device,
|
283 |
+
dtype=torch.long,
|
284 |
+
),
|
285 |
+
],
|
286 |
+
dim=0,
|
287 |
+
)
|
288 |
+
token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long()
|
289 |
+
|
290 |
+
new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
|
291 |
+
max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to)
|
292 |
+
nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
|
293 |
+
|
294 |
+
if left_padding:
|
295 |
+
new_token_positions += nb_audio_pad[:, None] # offset for left padding
|
296 |
+
|
297 |
+
# 2. Create the full embedding, already padded to the maximum position
|
298 |
+
final_embedding = torch.zeros(
|
299 |
+
(batch_size, max_token_num, embed_dim),
|
300 |
+
dtype=inputs_embeds.dtype,
|
301 |
+
device=inputs_embeds.device,
|
302 |
+
)
|
303 |
+
final_attention_mask = torch.zeros(
|
304 |
+
(batch_size, max_token_num),
|
305 |
+
dtype=attention_mask.dtype,
|
306 |
+
device=inputs_embeds.device,
|
307 |
+
)
|
308 |
+
final_input_ids = torch.full(
|
309 |
+
(batch_size, max_token_num),
|
310 |
+
pad_token_id,
|
311 |
+
dtype=input_ids.dtype,
|
312 |
+
device=inputs_embeds.device,
|
313 |
+
)
|
314 |
+
if skip_labels:
|
315 |
+
final_labels = None
|
316 |
+
else:
|
317 |
+
final_labels = torch.full(
|
318 |
+
(batch_size, max_token_num),
|
319 |
+
ignore_index,
|
320 |
+
dtype=label_ids.dtype,
|
321 |
+
device=inputs_embeds.device,
|
322 |
+
)
|
323 |
+
|
324 |
+
final_audio_in_mask = torch.full(
|
325 |
+
(batch_size, max_token_num),
|
326 |
+
False,
|
327 |
+
dtype=torch.bool,
|
328 |
+
device=inputs_embeds.device,
|
329 |
+
)
|
330 |
+
final_audio_in_discrete_codes_mask = torch.full(
|
331 |
+
(batch_size, max_token_num),
|
332 |
+
False,
|
333 |
+
dtype=torch.bool,
|
334 |
+
device=inputs_embeds.device,
|
335 |
+
)
|
336 |
+
final_audio_out_mask = torch.full(
|
337 |
+
(batch_size, max_token_num),
|
338 |
+
False,
|
339 |
+
dtype=torch.bool,
|
340 |
+
device=inputs_embeds.device,
|
341 |
+
)
|
342 |
+
# 3. Get the audio-in token positions and audio-out token positions
|
343 |
+
batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length)
|
344 |
+
audio_in_batch_id = batch_id[audio_in_token_mask] # Shape (num_audio_in,)
|
345 |
+
audio_out_batch_id = batch_id[audio_out_token_mask] # Shape (num_audio_out,)
|
346 |
+
audio_features_token_ends = new_token_positions[audio_in_token_mask] # Shape (num_audio_in,)
|
347 |
+
audio_out_embed_ends = new_token_positions[audio_out_token_mask] # Shape (num_audio_out,)
|
348 |
+
|
349 |
+
if audio_in_embed is not None:
|
350 |
+
# Fill in the audio-in embeddings
|
351 |
+
seq_indices = (
|
352 |
+
torch.arange(max_token_num, device=target_device)
|
353 |
+
.unsqueeze(0)
|
354 |
+
.expand(audio_in_ids_start.shape[0], max_token_num)
|
355 |
+
)
|
356 |
+
audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1
|
357 |
+
batch_indices, col_indices = torch.where(
|
358 |
+
(seq_indices >= audio_in_embed_token_starts.unsqueeze(1))
|
359 |
+
& (seq_indices <= audio_features_token_ends.unsqueeze(1))
|
360 |
+
)
|
361 |
+
batch_indices = audio_in_batch_id[batch_indices]
|
362 |
+
final_embedding[batch_indices, col_indices] = audio_in_embed
|
363 |
+
final_input_ids[batch_indices, col_indices] = audio_in_token_idx
|
364 |
+
if not skip_labels:
|
365 |
+
final_labels[batch_indices, col_indices] = ignore_index
|
366 |
+
final_audio_in_mask[batch_indices, col_indices] = True
|
367 |
+
final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True
|
368 |
+
audio_features_token_ends = audio_features_token_ends - audio_in_codes_length
|
369 |
+
|
370 |
+
if audio_features_embed is not None:
|
371 |
+
# Fill in the audio features
|
372 |
+
seq_indices = (
|
373 |
+
torch.arange(max_token_num, device=target_device)
|
374 |
+
.unsqueeze(0)
|
375 |
+
.expand(audio_features_embed.shape[0], max_token_num)
|
376 |
+
)
|
377 |
+
audio_features_token_starts = audio_features_token_ends - audio_features_length + 1
|
378 |
+
batch_indices, col_indices = torch.where(
|
379 |
+
(seq_indices >= audio_features_token_starts.unsqueeze(1))
|
380 |
+
& (seq_indices <= audio_features_token_ends.unsqueeze(1))
|
381 |
+
)
|
382 |
+
batch_indices = audio_in_batch_id[batch_indices]
|
383 |
+
final_embedding[batch_indices, col_indices] = masked_audio_in_features
|
384 |
+
final_input_ids[batch_indices, col_indices] = audio_in_token_idx
|
385 |
+
if not skip_labels:
|
386 |
+
final_labels[batch_indices, col_indices] = ignore_index
|
387 |
+
final_audio_in_mask[batch_indices, col_indices] = True
|
388 |
+
|
389 |
+
if audio_out_embed is not None:
|
390 |
+
# Fill in the audio-out embeddings
|
391 |
+
seq_indices = (
|
392 |
+
torch.arange(max_token_num, device=target_device)
|
393 |
+
.unsqueeze(0)
|
394 |
+
.expand(audio_out_ids_start.shape[0], max_token_num)
|
395 |
+
)
|
396 |
+
audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1
|
397 |
+
batch_indices, col_indices = torch.where(
|
398 |
+
(seq_indices >= audio_out_embed_token_starts.unsqueeze(1))
|
399 |
+
& (seq_indices <= audio_out_embed_ends.unsqueeze(1))
|
400 |
+
)
|
401 |
+
batch_indices = audio_out_batch_id[batch_indices]
|
402 |
+
final_embedding[batch_indices, col_indices] = audio_out_embed
|
403 |
+
final_input_ids[batch_indices, col_indices] = audio_out_token_idx
|
404 |
+
if not skip_labels:
|
405 |
+
final_labels[batch_indices, col_indices] = ignore_index
|
406 |
+
final_audio_out_mask[batch_indices, col_indices] = True
|
407 |
+
|
408 |
+
# Fill in the original text embeddings and labels
|
409 |
+
batch_indices, non_audio_indices = torch.where(text_token_mask)
|
410 |
+
text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
|
411 |
+
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
|
412 |
+
if not skip_labels:
|
413 |
+
final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices]
|
414 |
+
final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
|
415 |
+
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
|
416 |
+
final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask
|
417 |
+
|
418 |
+
# Trim the tensor if there are redundant padding tokens
|
419 |
+
if left_padding:
|
420 |
+
first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
|
421 |
+
first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
|
422 |
+
if first_non_zero_loc > 0:
|
423 |
+
final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
|
424 |
+
final_embedding = final_embedding[:, first_non_zero_loc:]
|
425 |
+
if not skip_labels:
|
426 |
+
final_labels = final_labels[:, first_non_zero_loc:]
|
427 |
+
final_input_ids = final_input_ids[:, first_non_zero_loc:]
|
428 |
+
final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
|
429 |
+
final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
|
430 |
+
final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
|
431 |
+
else:
|
432 |
+
# We have done right padding, so we need to trim the mask
|
433 |
+
last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
|
434 |
+
last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
|
435 |
+
if last_non_zero_loc < max_token_num:
|
436 |
+
final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
|
437 |
+
final_embedding = final_embedding[:, :last_non_zero_loc]
|
438 |
+
if not skip_labels:
|
439 |
+
final_labels = final_labels[:, :last_non_zero_loc]
|
440 |
+
final_input_ids = final_input_ids[:, :last_non_zero_loc]
|
441 |
+
final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
|
442 |
+
final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
|
443 |
+
final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]
|
444 |
+
|
445 |
+
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
446 |
+
return (
|
447 |
+
final_embedding,
|
448 |
+
final_attention_mask,
|
449 |
+
final_labels,
|
450 |
+
position_ids,
|
451 |
+
final_input_ids,
|
452 |
+
final_audio_in_mask,
|
453 |
+
final_audio_in_discrete_codes_mask,
|
454 |
+
final_audio_out_mask,
|
455 |
+
)
|
456 |
+
|
457 |
+
|
458 |
+
def is_deepspeed_ulysses_enabled():
|
459 |
+
if deepspeed_groups is None:
|
460 |
+
return False
|
461 |
+
|
462 |
+
"""Check if sequence parallelism is enabled."""
|
463 |
+
return deepspeed_groups._get_sequence_parallel_world_size() > 1
|
464 |
+
|
465 |
+
|
466 |
+
def support_deepspeed_ulysses(module):
|
467 |
+
"""A decorator around Pytorch module. It is needed for the module that needs access to sequence parallel info."""
|
468 |
+
module._sp_size = None
|
469 |
+
module._sp_rank = None
|
470 |
+
module._sp_group = None
|
471 |
+
|
472 |
+
@property
|
473 |
+
def sp_size(self):
|
474 |
+
if self._sp_size is None:
|
475 |
+
self._sp_size = 1
|
476 |
+
if is_deepspeed_ulysses_enabled():
|
477 |
+
self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
|
478 |
+
return self._sp_size
|
479 |
+
|
480 |
+
@property
|
481 |
+
def sp_rank(self):
|
482 |
+
if self._sp_rank is None:
|
483 |
+
self._sp_rank = 0
|
484 |
+
if is_deepspeed_ulysses_enabled():
|
485 |
+
self._sp_rank = deepspeed_groups._get_sequence_parallel_rank()
|
486 |
+
return self._sp_rank
|
487 |
+
|
488 |
+
@property
|
489 |
+
def sp_group(self):
|
490 |
+
if self._sp_group is None and is_deepspeed_ulysses_enabled():
|
491 |
+
self._sp_group = deepspeed_groups._get_sequence_parallel_group()
|
492 |
+
return self._sp_group
|
493 |
+
|
494 |
+
module.sp_size = sp_size
|
495 |
+
module.sp_rank = sp_rank
|
496 |
+
module.sp_group = sp_group
|
497 |
+
|
498 |
+
return module
|
499 |
+
|
500 |
+
|
501 |
+
def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
|
502 |
+
"""Perform all-to-all before and after the attention function."""
|
503 |
+
|
504 |
+
def attention_decorator(attn_func=None):
|
505 |
+
def wrapped(*args, **kwargs):
|
506 |
+
if is_deepspeed_ulysses_enabled():
|
507 |
+
sp_group = deepspeed_groups._get_sequence_parallel_group()
|
508 |
+
scatter_idx = head_dim # Scatter on num_heads dimension
|
509 |
+
gather_idx = seq_dim # Gather on seq_len dimension
|
510 |
+
batch_dim_idx = 0
|
511 |
+
args = list(args)
|
512 |
+
args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx)
|
513 |
+
args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx)
|
514 |
+
args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx)
|
515 |
+
args = tuple(args)
|
516 |
+
|
517 |
+
attn_output = attn_func(*args, **kwargs)
|
518 |
+
|
519 |
+
if is_deepspeed_ulysses_enabled():
|
520 |
+
scatter_idx = seq_dim # Scatter back on seq_len dimension
|
521 |
+
gather_idx = head_dim # Gather on num_heads dimension
|
522 |
+
batch_dim_idx = 0
|
523 |
+
attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx)
|
524 |
+
|
525 |
+
return attn_output
|
526 |
+
|
527 |
+
return wrapped
|
528 |
+
|
529 |
+
return attention_decorator
|
530 |
+
|
531 |
+
|
532 |
+
def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
|
533 |
+
"""Slice the corresponding cos and sin chunks for rope."""
|
534 |
+
|
535 |
+
def rope_decorator(rope_func=None):
|
536 |
+
def wrapped(*args, **kwargs):
|
537 |
+
if is_deepspeed_ulysses_enabled():
|
538 |
+
sp_rank = deepspeed_groups._get_sequence_parallel_rank()
|
539 |
+
args = list(args)
|
540 |
+
seq_chunk_size = args[0].size(state_seq_dim)
|
541 |
+
args[2] = torch.narrow(args[2], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
|
542 |
+
args[3] = torch.narrow(args[3], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
|
543 |
+
args = tuple(args)
|
544 |
+
|
545 |
+
return rope_func(*args, **kwargs)
|
546 |
+
|
547 |
+
return wrapped
|
548 |
+
|
549 |
+
return rope_decorator
|
550 |
+
|
551 |
+
|
552 |
+
def _gather_tensors(input_, group=None):
|
553 |
+
"""Gather tensors and concatenate them along a dimension."""
|
554 |
+
input_ = input_.contiguous()
|
555 |
+
world_size = torch.distributed.get_world_size(group)
|
556 |
+
if world_size == 1:
|
557 |
+
return input_
|
558 |
+
tensor_shapes = [
|
559 |
+
torch.empty(len(input_.size()), dtype=torch.int64, device=input_.device) for _ in range(world_size)
|
560 |
+
]
|
561 |
+
input_size = torch.tensor(input_.size(), dtype=torch.int64, device=input_.device)
|
562 |
+
torch.distributed.all_gather(tensor_shapes, input_size, group=group)
|
563 |
+
gathered_buffers = [
|
564 |
+
torch.empty(tensor_shapes[i].tolist(), dtype=input_.dtype, device=input_.device) for i in range(world_size)
|
565 |
+
]
|
566 |
+
torch.distributed.all_gather(gathered_buffers, input_, group=group)
|
567 |
+
return gathered_buffers
|
568 |
+
|
569 |
+
|
570 |
+
def _scatter_tensors(input_, group=None):
|
571 |
+
"""Scatter tensors."""
|
572 |
+
world_size = torch.distributed.get_world_size(group)
|
573 |
+
if world_size == 1:
|
574 |
+
return input_
|
575 |
+
rank = torch.distributed.get_rank(group)
|
576 |
+
return input_[rank]
|
577 |
+
|
578 |
+
|
579 |
+
class _GatherTensors(torch.autograd.Function):
|
580 |
+
"""All gather tensors among the ranks."""
|
581 |
+
|
582 |
+
@staticmethod
|
583 |
+
def symbolic(graph, input_, group):
|
584 |
+
return _gather_tensors(input_, group)
|
585 |
+
|
586 |
+
@staticmethod
|
587 |
+
def forward(ctx, input_, group):
|
588 |
+
ctx.group = group
|
589 |
+
return torch.nested.as_nested_tensor(_gather_tensors(input_, group), layout=torch.jagged)
|
590 |
+
|
591 |
+
@staticmethod
|
592 |
+
def backward(ctx, grad_output):
|
593 |
+
return _scatter_tensors(grad_output, ctx.group), None
|
594 |
+
|
595 |
+
|
596 |
+
def all_gather_tensors(input_, size=None, dim=0, group=None):
|
597 |
+
if torch.distributed.get_world_size(group) == 1:
|
598 |
+
# no sequence parallelism
|
599 |
+
return input_
|
600 |
+
gathered_tensors = _GatherTensors.apply(input_, group)
|
601 |
+
|
602 |
+
if size:
|
603 |
+
split_gathered_tensors = []
|
604 |
+
for s, gathered_tensor in zip(size, gathered_tensors):
|
605 |
+
split_gathered_tensor = torch.split(gathered_tensor, s.tolist())
|
606 |
+
split_gathered_tensors.append(split_gathered_tensor)
|
607 |
+
|
608 |
+
gathered_tensors = [y for x in zip(*split_gathered_tensors) for y in x]
|
609 |
+
|
610 |
+
return torch.cat(gathered_tensors, dim).contiguous()
|
611 |
+
|
612 |
+
|
613 |
+
def get_sequence_data_parallel_world_size():
|
614 |
+
return torch.distributed.get_world_size()
|
615 |
+
|
616 |
+
|
617 |
+
def get_sequence_data_parallel_rank():
|
618 |
+
return torch.distributed.get_rank()
|
619 |
+
|
620 |
+
|
621 |
+
def get_sequence_data_parallel_group():
|
622 |
+
return torch.distributed.group.WORLD
|
623 |
+
|
624 |
+
|
625 |
+
if is_deepspeed_available():
|
626 |
+
deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size
deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank
deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group


def _gather_tokens(input_, dim=0, group=None):
    """Gather tensors and concatenate them along a dimension."""
    input_ = input_.contiguous()
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_

    gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
    torch.distributed.all_gather_into_tensor(gather_buffer, input_, group=group)
    if dim == 0:
        shape = list(input_.size())
        shape[0] = shape[0] * world_size
        output = gather_buffer.view(shape)
    else:
        tensor_list = [
            gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
        ]
        # Note: torch.cat already creates a contiguous tensor.
        output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


def _drop_tokens(input_, dim=0, group=None):
    """Divide a tensor among the sequence parallel ranks."""
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_
    this_rank = torch.distributed.get_rank(group)
    assert input_.shape[dim] % world_size == 0, (
        f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})"
    )
    chunk_size = input_.shape[dim] // world_size

    return torch.narrow(input_, dim, this_rank * chunk_size, chunk_size)


class _DropTokens(torch.autograd.Function):
    "Divide tokens equally among the sequence parallel ranks"

    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            grad_input /= ctx.grad_scale
        return grad_input, None, None, None


class _GatherTokens(torch.autograd.Function):
    "Gather tokens among the sequence parallel ranks"

    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            grad_input *= ctx.grad_scale
        return grad_input, None, None, None


def drop_tokens(input_, dim=0, group=None, grad_scale=1):
    if torch.distributed.get_world_size(group) == 1:
        # no sequence parallelism
        return input_
    return _DropTokens.apply(input_, dim, group, grad_scale)


def gather_tokens(input_, dim=0, group=None, grad_scale=1):
    if torch.distributed.get_world_size(group) == 1:
        # no sequence parallelism
        return input_
    return _GatherTokens.apply(input_, dim, group, grad_scale)


def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
    """
    Slice the inputs to create chunks for the given sequence parallel rank. This is used for context-parallel training.

    Args:
        sp_size (`int`):
            Sequence parallel size.
        sp_rank (`int`):
            Sequence parallel rank for the current process.
        dim (`int`):
            The dimension to slice.
    """
    if sp_size == 1:
        return args[0] if len(args) == 1 else args

    seq_length = args[0].size(dim)
    for arg in args[1:]:
        assert arg.size(dim) == seq_length, (
            f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}"
        )
    assert seq_length % sp_size == 0, (
        f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})"
    )

    sub_seq_length = seq_length // sp_size
    sub_seq_start = sp_rank * sub_seq_length

    output = []
    for ind in args:
        ind = torch.narrow(ind, dim, sub_seq_start, sub_seq_length)
        output.append(ind)

    return tuple(output) if len(output) > 1 else output[0]


@contextmanager
def disable_deepspeed_ulysses():
    """Disable DeepSpeed Ulysses (sequence parallelism) if it is enabled."""
    if is_deepspeed_ulysses_enabled():
        _old_get_sequence_parallel_world_size = deepspeed_groups._get_sequence_parallel_world_size

        def _get_sequence_parallel_world_size():
            return 1

        deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size
        try:
            yield
        finally:
            deepspeed_groups._get_sequence_parallel_world_size = _old_get_sequence_parallel_world_size
    else:
        context = contextlib.nullcontext
        with context():
            yield
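A minimal usage sketch of the drop_tokens / gather_tokens pair defined above, assuming higgs_audio.model.utils imports cleanly in your environment (its optional DeepSpeed dependency is not pinned in requirements.txt) and that a single-process gloo group is enough for illustration; the MASTER_ADDR/MASTER_PORT values below are placeholders, not part of the library. With N sequence-parallel ranks, drop_tokens keeps this rank's 1/N slice of the chosen dimension and gather_tokens reassembles the full tensor; with one rank both are no-ops.

# Illustration only: a single-process "gloo" group so the helpers can be called locally.
import os

import torch
import torch.distributed as dist

from higgs_audio.model.utils import drop_tokens, gather_tokens

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

x = torch.randn(2, 8, 16)            # (batch, seq, hidden)
shard = drop_tokens(x, dim=1)        # this rank's slice of the sequence dimension
full = gather_tokens(shard, dim=1)   # reassembled full sequence
assert torch.equal(full, x)          # trivially true with a single rank

dist.destroy_process_group()

Under sequence parallelism the two calls are inverses of each other, which is what the autograd Functions above rely on: the backward of _DropTokens is a gather, and the backward of _GatherTokens is a drop.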
higgs_audio/serve/serve_engine.py
ADDED
@@ -0,0 +1,424 @@
import asyncio
import base64
import torch
import numpy as np
from io import BytesIO
from dataclasses import dataclass
from typing import List, Optional, Union
from copy import deepcopy
from transformers import AutoTokenizer, AutoProcessor
from transformers.cache_utils import StaticCache
from transformers.generation.streamers import BaseStreamer
from transformers.generation.stopping_criteria import StoppingCriteria
from dataclasses import asdict
from loguru import logger
import threading
import librosa


from ..dataset.chatml_dataset import (
    ChatMLSample,
    ChatMLDatasetSample,
    prepare_chatml_sample,
)
from ..model import HiggsAudioModel
from ..model.utils import revert_delay_pattern
from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator
from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer


@dataclass
class HiggsAudioStreamerDelta:
    """Represents a chunk of generated content, either text or audio tokens."""

    text: Optional[str] = None
    text_tokens: Optional[torch.Tensor] = None
    audio_tokens: Optional[torch.Tensor] = None
    finish_reason: Optional[str] = None


class AsyncHiggsAudioStreamer(BaseStreamer):
    """
    Async streamer that handles both text and audio token generation from the Higgs-Audio model.
    Stores chunks in a queue to be consumed by downstream applications.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode text tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt tokens in generation.
        timeout (`float`, *optional*):
            The timeout for the queue. If `None`, the queue will block indefinitely.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:
        ```python
        >>> from transformers import AutoTokenizer
        >>> from threading import Thread
        >>> import asyncio

        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer")
        >>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model")
        >>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt")

        >>> async def main():
        ...     streamer = AsyncHiggsAudioStreamer(tokenizer)
        ...     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        ...     thread = Thread(target=model.generate, kwargs=generation_kwargs)
        ...     thread.start()
        ...
        ...     async for delta in streamer:
        ...         if delta.text is not None:
        ...             print("Text:", delta.text)
        ...         if delta.audio_tokens is not None:
        ...             print("Audio tokens shape:", delta.audio_tokens.shape)
        >>> asyncio.run(main())
        ```
    """

    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        audio_num_codebooks: int = 1,
        **decode_kwargs,
    ):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.timeout = timeout
        self.decode_kwargs = decode_kwargs
        self.audio_num_codebooks = audio_num_codebooks

        # Queue to store generated chunks
        self.queue = asyncio.Queue()
        self.stop_signal = None

        # Get running event loop
        self.loop = asyncio.get_running_loop()
        self.has_asyncio_timeout = hasattr(asyncio, "timeout")

        # State tracking
        self.next_tokens_are_prompt = True

    def put(self, value: torch.Tensor):
        """
        Receives tokens and processes them as either text or audio tokens.
        Text tokens are decoded and queued as text deltas.
        Audio tokens are queued directly.
        """
        if value.shape[0] > 1 and not self.next_tokens_are_prompt:
            # This is likely audio tokens (shape: [audio_num_codebooks])
            assert value.shape[0] == self.audio_num_codebooks, "Number of codebooks mismatch"
            delta = HiggsAudioStreamerDelta(audio_tokens=value)
            self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
            return

        # Skip prompt tokens if configured
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Process as text tokens
        if len(value.shape) > 1:
            value = value[0]

        text = self.tokenizer.decode(value, **self.decode_kwargs)
        delta = HiggsAudioStreamerDelta(text=text, text_tokens=value)
        self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)

    def end(self):
        """Signals the end of generation."""
        self.next_tokens_are_prompt = True
        self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            if self.has_asyncio_timeout:
                async with asyncio.timeout(self.timeout):
                    value = await self.queue.get()
            else:
                value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
        except asyncio.TimeoutError:
            raise TimeoutError()
        else:
            if value == self.stop_signal:
                raise StopAsyncIteration()
            else:
                return value


class AsyncStoppingCriteria(StoppingCriteria):
    """
    Stopping criteria that checks for a stop signal from a threading event.

    Args:
        stop_signal (threading.Event): Event that will receive stop signals
    """

    def __init__(self, stop_signal: threading.Event):
        self.stop_signal = stop_signal

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if self.stop_signal.is_set():
            logger.info("Stop signal received. This can be caused by a client disconnection.")
            return True
        return False


@dataclass
class HiggsAudioResponse:
    audio: Optional[np.ndarray] = None
    generated_audio_tokens: Optional[np.ndarray] = None
    sampling_rate: Optional[int] = None
    generated_text: str = ""
    generated_text_tokens: np.ndarray = np.array([])
    usage: Optional[dict] = None


class HiggsAudioServeEngine:
    def __init__(
        self,
        model_name_or_path: str,
        audio_tokenizer_name_or_path: str,
        tokenizer_name_or_path: Optional[str] = None,
        device: str = "cuda",
        torch_dtype: Union[torch.dtype, str] = "auto",
        kv_cache_lengths: List[int] = [1024, 4096, 8192],  # Multiple KV cache sizes
    ):
        """
        Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel.
        The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local.

        Args:
            model_name_or_path (str):
                The name or path of the model to load.
            audio_tokenizer_name_or_path (str):
                The name or path of the audio tokenizer to load.
            tokenizer_name_or_path (str):
                The name or path of the tokenizer to load.
            device (str):
                The device to use for the model.
            kv_cache_lengths (List[int]):
                The lengths of the KV caches to use for the model. Used for CUDA graph capture when the device is cuda.
            torch_dtype (Union[torch.dtype, str]):
                The dtype to use for the model.
        """
        self.device = device
        self.model_name_or_path = model_name_or_path
        self.torch_dtype = torch_dtype

        # Initialize model and tokenizer
        self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device)
        logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")

        if tokenizer_name_or_path is None:
            tokenizer_name_or_path = model_name_or_path
        logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        logger.info("Initializing Higgs Audio Tokenizer")
        self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)

        self.audio_num_codebooks = self.model.config.audio_num_codebooks
        self.audio_codebook_size = self.model.config.audio_codebook_size
        self.audio_tokenizer_tps = self.audio_tokenizer.tps
        self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
        self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token
        # Set the audio special tokens
        self.model.set_audio_special_tokens(self.tokenizer)

        # Prepare KV caches for different lengths
        cache_config = deepcopy(self.model.config.text_config)
        cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers
        if self.model.config.audio_dual_ffn_layers:
            cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)
        # A list of KV caches for different lengths
        self.kv_caches = {
            length: StaticCache(
                config=cache_config,
                max_batch_size=1,
                max_cache_len=length,
                device=self.model.device,
                dtype=self.model.dtype,
            )
            for length in sorted(kv_cache_lengths)
        }

        if self.model.config.encode_whisper_embed:
            logger.info("Loading whisper processor")
            whisper_processor = AutoProcessor.from_pretrained(
                "openai/whisper-large-v3-turbo",
                trust_remote_code=True,
                device=self.device,
            )
        else:
            whisper_processor = None

        # Reuse collator to prepare inference samples
        self.collator = HiggsAudioSampleCollator(
            whisper_processor=whisper_processor,
            encode_whisper_embed=self.model.config.encode_whisper_embed,
            audio_in_token_id=self.model.config.audio_in_token_idx,
            audio_out_token_id=self.model.config.audio_out_token_idx,
            audio_stream_bos_id=self.model.config.audio_stream_bos_id,
            audio_stream_eos_id=self.model.config.audio_stream_eos_id,
            pad_token_id=self.model.config.pad_token_id,
            return_audio_in_tokens=False,
            use_delay_pattern=self.model.config.use_delay_pattern,
            audio_num_codebooks=self.model.config.audio_num_codebooks,
            round_to=1,
        )

        # Lock to prevent multiple generations from happening at the same time
        self.generate_lock = threading.Lock()

        # Capture CUDA graphs for each KV cache length
        if device == "cuda":
            logger.info("Capturing CUDA graphs for each KV cache length")
            self.model.capture_model(self.kv_caches.values())

    def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
        input_tokens, _, audio_contents, _ = prepare_chatml_sample(
            chat_ml_sample,
            self.tokenizer,
        )

        postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        if force_audio_gen:
            postfix += "<|audio_out_bos|>"
        postfix = self.tokenizer.encode(postfix, add_special_tokens=False)
        input_tokens.extend(postfix)

        # Configure the audio inputs
        audio_ids_l = []
        for audio_content in audio_contents:
            if audio_content.audio_url not in ["placeholder", ""]:
                raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate)
            elif audio_content.raw_audio is not None:
                raw_audio, _ = librosa.load(
                    BytesIO(base64.b64decode(audio_content.raw_audio)),
                    sr=self.audio_tokenizer.sampling_rate,
                )
            else:
                raw_audio = None

            if raw_audio is not None:
                audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
                audio_ids_l.append(audio_ids.squeeze(0).cpu())

        if len(audio_ids_l) > 0:
            audio_ids_start = torch.tensor(
                np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
                dtype=torch.long,
                device=self.device,
            )[0:-1]
            audio_ids_concat = torch.cat(audio_ids_l, dim=1)
        else:
            audio_ids_start = None
            audio_ids_concat = None

        sample = ChatMLDatasetSample(
            input_ids=torch.LongTensor(input_tokens),
            label_ids=None,
            audio_ids_concat=audio_ids_concat,
            audio_ids_start=audio_ids_start,
            audio_waveforms_concat=None,
            audio_waveforms_start=None,
            audio_sample_rate=None,
            audio_speaker_indices=None,
        )
        data = self.collator([sample])
        inputs = asdict(data)
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.model.device)

        return inputs

    def _prepare_kv_caches(self):
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

    def generate(
        self,
        chat_ml_sample: ChatMLSample,
        max_new_tokens: int,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: float = 0.95,
        stop_strings: Optional[List[str]] = None,
        force_audio_gen: bool = False,
        ras_win_len: Optional[int] = None,
        ras_win_max_num_repeat: int = 2,
    ):
        """
        Generate text and audio from a ChatML sample.

        Args:
            chat_ml_sample: A ChatML sample.
            max_new_tokens: The maximum number of new tokens to generate.
            temperature: The temperature to use for the generation.
            top_p: The top-p value to use for the generation.
        Returns:
            A HiggsAudioResponse whose fields include:
                audio: The generated audio.
                sampling_rate: The sampling rate of the generated audio.
        """
        # Default stop strings
        if stop_strings is None:
            stop_strings = ["<|end_of_text|>", "<|eot_id|>"]

        with torch.no_grad(), self.generate_lock:
            inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
            prompt_token_ids = inputs["input_ids"][0].cpu().numpy()

            self._prepare_kv_caches()

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                stop_strings=stop_strings,
                tokenizer=self.tokenizer,
                do_sample=False if temperature == 0.0 else True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                past_key_values_buckets=self.kv_caches,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
            )

            if len(outputs[1]) > 0:
                wv_list = []
                for output_audio in outputs[1]:
                    vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
                    wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
                    wv_list.append(wv_numpy)
                wv_numpy = np.concatenate(wv_list)
            else:
                wv_numpy = None

            # We only support one request at a time now
            generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :]
            generated_text = self.tokenizer.decode(generated_text_tokens)
            generated_audio_tokens = outputs[1][0].cpu().numpy()
            return HiggsAudioResponse(
                audio=wv_numpy,
                generated_audio_tokens=generated_audio_tokens,
                sampling_rate=self.audio_tokenizer.sampling_rate,
                generated_text=generated_text,
                generated_text_tokens=generated_text_tokens,
                usage={
                    "prompt_tokens": prompt_token_ids.shape[0],
                    "completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1],
                    "total_tokens": (
                        prompt_token_ids.shape[0] + generated_text_tokens.shape[0] + generated_audio_tokens.shape[1]
                    ),
                    "cached_tokens": 0,
                },
            )
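A hedged end-to-end sketch of driving HiggsAudioServeEngine.generate above. The checkpoint ids and the ChatMLSample/Message constructors are assumptions (the dataclasses live in higgs_audio/data_types.py, which is not shown in this diff); substitute your own paths and message types if they differ.

# Hypothetical checkpoint ids and message dataclasses; adjust to your setup.
import torch
import torchaudio

from higgs_audio.data_types import ChatMLSample, Message  # assumed location of the ChatML dataclasses
from higgs_audio.serve.serve_engine import HiggsAudioServeEngine

engine = HiggsAudioServeEngine(
    model_name_or_path="bosonai/higgs-audio-v2-generation-3B-base",   # assumed Hub id
    audio_tokenizer_name_or_path="bosonai/higgs-audio-v2-tokenizer",  # assumed Hub id
    device="cuda" if torch.cuda.is_available() else "cpu",
)

sample = ChatMLSample(messages=[Message(role="user", content="Please read this sentence out loud.")])
response = engine.generate(
    chat_ml_sample=sample,
    max_new_tokens=1024,
    temperature=0.7,
    force_audio_gen=True,  # append <|audio_out_bos|> so the model starts an audio turn
)

if response.audio is not None:
    torchaudio.save("output.wav", torch.from_numpy(response.audio).unsqueeze(0), response.sampling_rate)
print(response.generated_text, response.usage)

Because generate() holds generate_lock and resets every StaticCache bucket before decoding, one engine instance serves one request at a time; run multiple instances or queue requests if you need concurrency.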
higgs_audio/serve/utils.py
ADDED
@@ -0,0 +1,254 @@
import uuid
import base64
import re
import regex
from typing import AsyncGenerator, Union
import io
from pydub import AudioSegment
import torch
import numpy as np
from functools import lru_cache

from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer


def random_uuid() -> str:
    return str(uuid.uuid4().hex)


async def async_generator_wrap(first_element, gen: AsyncGenerator):
    """Wrap an async generator with the first element."""
    yield first_element
    async for item in gen:
        yield item


@lru_cache(maxsize=50)
def encode_base64_content_from_file(file_path: str) -> str:
    """Encode the content of a local file to base64 format."""
    # Read the audio file as binary and encode it directly to Base64
    with open(file_path, "rb") as audio_file:
        audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
    return audio_base64


def pcm16_to_target_format(
    np_audio: np.ndarray,
    sample_rate: int,
    bit_depth: int,
    channels: int,
    format: str,
    target_rate: int,
):
    wav_audio = AudioSegment(
        np_audio.tobytes(),
        frame_rate=sample_rate,
        sample_width=bit_depth // 8,
        channels=channels,
    )
    if target_rate is not None and target_rate != sample_rate:
        wav_audio = wav_audio.set_frame_rate(target_rate)

    # Export to the target container format
    target_io = io.BytesIO()
    wav_audio.export(target_io, format=format)
    target_io.seek(0)

    return target_io


chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")


def contains_chinese(text: str):
    return bool(chinese_char_pattern.search(text))


# Remove blanks between Chinese characters
def replace_blank(text: str):
    out_str = []
    for i, c in enumerate(text):
        if c == " ":
            if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
                out_str.append(c)
        else:
            out_str.append(c)
    return "".join(out_str)


def replace_corner_mark(text: str):
    text = text.replace("²", "平方")
    text = text.replace("³", "立方")
    return text


# Remove meaningless symbols
def remove_bracket(text: str):
    text = text.replace("(", "").replace(")", "")
    text = text.replace("【", "").replace("】", "")
    text = text.replace("`", "").replace("`", "")
    text = text.replace("——", " ")
    return text


# Paragraph-splitting logic:
# 1. Each sentence group has max length token_max_n and min length token_min_n; merge the last group if it is shorter than merge_len.
# 2. Sentence length is calculated according to lang.
# 3. Sentences are split on punctuation.
def split_paragraph(
    text: str,
    tokenize,
    lang="zh",
    token_max_n=80,
    token_min_n=60,
    merge_len=20,
    comma_split=False,
):
    def calc_utt_length(_text: str):
        if lang == "zh":
            return len(_text)
        else:
            return len(tokenize(_text))

    def should_merge(_text: str):
        if lang == "zh":
            return len(_text) < merge_len
        else:
            return len(tokenize(_text)) < merge_len

    if lang == "zh":
        pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
    else:
        pounc = [".", "?", "!", ";", ":"]
    if comma_split:
        pounc.extend([",", ","])

    if text[-1] not in pounc:
        if lang == "zh":
            text += "。"
        else:
            text += "."

    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st:i]) > 0:
                utts.append(text[st:i] + c)
            if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
                tmp = utts.pop(-1)
                utts.append(tmp + text[i + 1])
                st = i + 2
            else:
                st = i + 1

    final_utts = []
    cur_utt = ""
    for utt in utts:
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        if should_merge(cur_utt) and len(final_utts) != 0:
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts


def is_only_punctuation(text: str):
    # Regular expression: match strings that consist only of punctuation marks or are empty.
    punctuation_pattern = r"^[\p{P}\p{S}]*$"
    return bool(regex.fullmatch(punctuation_pattern, text))


# Spell out Arabic numerals
def spell_out_number(text: str, inflect_parser):
    new_text = []
    st = None
    for i, c in enumerate(text):
        if not c.isdigit():
            if st is not None:
                num_str = inflect_parser.number_to_words(text[st:i])
                new_text.append(num_str)
                st = None
            new_text.append(c)
        else:
            if st is None:
                st = i
    if st is not None and st < len(text):
        num_str = inflect_parser.number_to_words(text[st:])
        new_text.append(num_str)
    return "".join(new_text)


def remove_emoji(text: str):
    # Pattern to match emojis and their modifiers
    # - Standard emoji range
    # - Zero-width joiners (U+200D)
    # - Variation selectors (U+FE0F, U+FE0E)
    # - Skin tone modifiers (U+1F3FB to U+1F3FF)
    emoji_pattern = re.compile(
        r"["
        r"\U00010000-\U0010FFFF"  # Standard emoji range
        r"\u200D"  # Zero-width joiner
        r"\uFE0F\uFE0E"  # Variation selectors
        r"\U0001F3FB-\U0001F3FF"  # Skin tone modifiers
        r"]+",
        flags=re.UNICODE,
    )
    return emoji_pattern.sub(r"", text)


def remove_repeated_punctuations(text, punctuations):
    if len(punctuations) == 0:
        return text
    pattern = f"[{re.escape(''.join(punctuations))}]"  # Create regex pattern for the given punctuations
    return re.sub(rf"({pattern})\1+", r"\1", text)


def full_to_half_width(text: str) -> str:
    """Convert full-width punctuation to half-width in a given string."""
    full_width = "!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
    half_width = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
    trans_table = str.maketrans(full_width, half_width)
    return text.translate(trans_table)


def split_interleaved_delayed_audios(
    audio_data: Union[list[list[int]], torch.Tensor],
    audio_tokenizer: HiggsAudioTokenizer,
    audio_stream_eos_id: int,
) -> list[tuple[list[list[int]], torch.Tensor]]:
    separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks

    # Convert the separator to a tensor if audio_data is a torch.Tensor
    if isinstance(audio_data, torch.Tensor):
        audio_data = audio_data.transpose(1, 0)
        separator = torch.tensor(separator)
        # Find the indices where the rows equal the separator
        split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0]
        start = 0
        groups = []
        for idx in split_indices:
            groups.append(audio_data[start:idx].transpose(1, 0))
            start = idx + 1
        if start < len(audio_data):
            groups.append(audio_data[start:].transpose(1, 0))
    else:
        groups = []
        current = []
        for row in audio_data:
            current.append(row)

            if row == separator:
                groups.append(current)
                current = []

        # Don't forget the last group if there's no trailing separator
        if current:
            groups.append(current)

    return groups
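A small sketch of how the text-normalization helpers above can be composed before synthesis. The input string and chunking parameters are made up for illustration; the English branch of split_paragraph measures length with whatever tokenizer you pass in (plain whitespace splitting here).

from higgs_audio.serve.utils import (
    full_to_half_width,
    remove_repeated_punctuations,
    split_paragraph,
)

raw = "Higgs Audio speaks!!! It can stream text and audio together??  Try the playground."
text = remove_repeated_punctuations(full_to_half_width(raw), ["!", "?"])  # "!!!" -> "!", "??" -> "?"
chunks = split_paragraph(
    text,
    tokenize=str.split,   # word-level length for the non-Chinese branch
    lang="en",
    token_max_n=12,
    token_min_n=4,
    merge_len=3,
)
print(chunks)  # a list of sentence groups, each kept to roughly token_max_n tokens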
pyproject.toml
ADDED
@@ -0,0 +1,100 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.ruff]
line-length = 119
target-version = "py310"
indent-width = 4
exclude = [
    ".bzr",
    ".direnv",
    ".eggs",
    ".git",
    ".git-rewrite",
    ".hg",
    ".ipynb_checkpoints",
    ".mypy_cache",
    ".nox",
    ".pants.d",
    ".pyenv",
    ".pytest_cache",
    ".pytype",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    ".vscode",
    "__pypackages__",
    "_build",
    "buck-out",
    "build",
    "dist",
    "node_modules",
    "site-packages",
    "venv",
    "external",
    "third_party",
]

[tool.ruff.lint]
preview = true
ignore-init-module-imports = true
extend-select = [
    "B009",   # static getattr
    "B010",   # static setattr
    "CPY",    # Copyright
    "E",      # PEP8 errors
    "F",      # Pyflakes
    "I",      # Import sorting
    "TID251", # Banned API
    "UP",     # Pyupgrade
    "W",      # PEP8 warnings
]
ignore = [
    "E501",  # Line length (handled by ruff-format)
    "E741",  # Ambiguous variable name
    "W605",  # Invalid escape sequence
    "UP007", # X | Y type annotations
]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = [
    "F401", # Ignore seemingly unused imports (they're meant for re-export)
]

[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["character_tuning"]

[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = false

# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"os.getenv".msg = "Use os.environ instead"
"os.putenv".msg = "Use os.environ instead"
"os.unsetenv".msg = "Use os.environ instead"
requirements.txt
ADDED
@@ -0,0 +1,17 @@
descript-audio-codec
torch==2.5.1
torchaudio==2.5.1
transformers>=4.45.1,<4.47.0
librosa
dacite
boto3==1.35.36
s3fs
json_repair
pandas
pydantic
vector_quantize_pytorch
loguru
pydub
ruff==0.12.2
omegaconf
click