Commit 26f400e
Parent(s): 2b694f4
Add file

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- README.md +6 -6
- app.py +137 -0
- audios/1.wav +3 -0
- audios/2.wav +3 -0
- audios/3.wav +3 -0
- audios/4.wav +3 -0
- audios/5.wav +3 -0
- infer.py +175 -0
- models.py +299 -0
- requirements.txt +4 -0
- speechtokenizer/__init__.py +3 -0
- speechtokenizer/model.py +215 -0
- speechtokenizer/modules/__init__.py +21 -0
- speechtokenizer/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- speechtokenizer/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- speechtokenizer/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- speechtokenizer/modules/__pycache__/conv.cpython-310.pyc +0 -0
- speechtokenizer/modules/__pycache__/conv.cpython-311.pyc +0 -0
- speechtokenizer/modules/__pycache__/conv.cpython-39.pyc +0 -0
- speechtokenizer/modules/__pycache__/lstm.cpython-310.pyc +0 -0
- speechtokenizer/modules/__pycache__/lstm.cpython-311.pyc +0 -0
- speechtokenizer/modules/__pycache__/lstm.cpython-39.pyc +0 -0
- speechtokenizer/modules/__pycache__/norm.cpython-310.pyc +0 -0
- speechtokenizer/modules/__pycache__/norm.cpython-311.pyc +0 -0
- speechtokenizer/modules/__pycache__/norm.cpython-39.pyc +0 -0
- speechtokenizer/modules/__pycache__/seanet.cpython-310.pyc +0 -0
- speechtokenizer/modules/__pycache__/seanet.cpython-311.pyc +0 -0
- speechtokenizer/modules/__pycache__/seanet.cpython-39.pyc +0 -0
- speechtokenizer/modules/conv.py +252 -0
- speechtokenizer/modules/lstm.py +31 -0
- speechtokenizer/modules/norm.py +28 -0
- speechtokenizer/modules/seanet.py +408 -0
- speechtokenizer/pretrained_model/SpeechTokenizer.pt +3 -0
- speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json +47 -0
- speechtokenizer/quantization/__init__.py +8 -0
- speechtokenizer/quantization/__pycache__/__init__.cpython-310.pyc +0 -0
- speechtokenizer/quantization/__pycache__/__init__.cpython-311.pyc +0 -0
- speechtokenizer/quantization/__pycache__/__init__.cpython-39.pyc +0 -0
- speechtokenizer/quantization/__pycache__/core_vq.cpython-310.pyc +0 -0
- speechtokenizer/quantization/__pycache__/core_vq.cpython-311.pyc +0 -0
- speechtokenizer/quantization/__pycache__/core_vq.cpython-39.pyc +0 -0
- speechtokenizer/quantization/__pycache__/distrib.cpython-310.pyc +0 -0
- speechtokenizer/quantization/__pycache__/distrib.cpython-311.pyc +0 -0
- speechtokenizer/quantization/__pycache__/distrib.cpython-39.pyc +0 -0
- speechtokenizer/quantization/__pycache__/vq.cpython-310.pyc +0 -0
- speechtokenizer/quantization/__pycache__/vq.cpython-311.pyc +0 -0
- speechtokenizer/quantization/__pycache__/vq.cpython-39.pyc +0 -0
- speechtokenizer/quantization/ac.py +292 -0
- speechtokenizer/quantization/core_vq.py +385 -0
- speechtokenizer/quantization/distrib.py +126 -0
README.md
CHANGED
@@ -1,14 +1,14 @@
 ---
 title: VoiceMark
-emoji:
+emoji: 📊
-colorFrom:
+colorFrom: purple
-colorTo:
+colorTo: green
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.16.0
 app_file: app.py
 pinned: false
-license:
+license: gpl
-short_description: Zero-Shot Voice Cloning-Resistant Watermarking
+short_description: Zero-Shot Voice Cloning-Resistant Watermarking
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,137 @@
import gradio as gr
import json
import os
import torchaudio
from infer import (
    WatermarkSolver,
    hamming_distance
)

# Predefined watermarks (instead of loading from a JSON file)
watermarks = {
    "VoiceMark": "1000010101010011",
    "Voice Cloning": "1111111001000010",
    "Speech Security": "1011101100001110",
    "Audio Watermarking": "0110110011100010",
    "Deep Learning": "0000100111111000",
    "Artificial Intelligence": "0010000100011111",
    "Hello World": "0001111101110001",
    "Happy New Year": "1101011011011101",
    "World Peace": "0011110010011110",
    "Good Morning": "0000001011000010",
}

# Initialize WatermarkSolver model
solver = WatermarkSolver()
solver.load_model(checkpoint_dir="./", checkpoint_name="voicemark.pth", strict=True)

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(
        "# VoiceMark: Zero-Shot Voice Cloning-Resistant Watermarking Approach Leveraging Speaker-Specific Latents"
    )
    with gr.Column():
        gr.Image(
            value="voicemark_overview.png",
            width=925,
            height=487,
            elem_id="overview_image",
            label="overview"
        )
    # Step 1: Upload audio and select watermark
    gr.HTML("<h3 style='text-align: center;'>The overall architecture of our proposed VoiceMark</h3>")

    # Step 1: Upload audio and select watermark
    gr.Markdown(
        """
        **Step 1**: Upload an audio file or select one from the provided samples, choose a watermark, and generate the watermarked audio.
        """
    )

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio", type="filepath")

            gr.Examples(
                examples=[
                    ["audios/1.wav"],
                    ["audios/2.wav"],
                    ["audios/3.wav"],
                    ["audios/4.wav"],
                    ["audios/5.wav"],
                ],
                inputs=audio_input,
                label="Sample Audios (Click to Use)"
            )

        with gr.Column():
            audio_output = gr.Audio(label="Watermarked Audio", type="filepath")
            watermark_list = gr.Dropdown(
                label="Select Watermark", choices=list(watermarks.keys()), interactive=True
            )
            add_watermark_button = gr.Button("Add Watermark to Audio")

    # Step 2: TTS tools demo links
    gr.Markdown(
        """
        **Step 2**: Download the generated watermarked audio, then use Zero-Shot Voice Cloning tools to generate the cloned audio. Some available tools are:
        - [CosyVoice2: Scalable Streaming Speech Synthesis with Large Language Models](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B)
        - [F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
        - [MaskGCT: Zero-Shot Text-to-Speech with Masked Generative Codec Transformer](https://huggingface.co/spaces/amphion/maskgct)
        """
    )

    # Step 3: Upload cloned audio to decode watermark
    gr.Markdown(
        """
        **Step 3**: Upload the cloned audio and decode your watermark.
        """
    )

    with gr.Row():
        decode_audio_input = gr.Audio(label="Upload Cloned Audio", type="filepath")
        with gr.Column():
            decoded_watermark_output = gr.Textbox(label="Decoded Watermark")
            decode_button = gr.Button("Decode Watermark")

    def process_audio(audio_path, watermark_text):
        if not audio_path:
            return "No audio selected. Please upload or select a sample."
        try:
            watermarked_audio = solver.infer_for_ui(
                audio_path, watermarks[watermark_text]
            )
            return watermarked_audio
        except ValueError as e:
            return str(e)

    add_watermark_button.click(
        process_audio,
        inputs=[audio_input, watermark_list],
        outputs=audio_output
    )

    def decode_watermark(audio_path):
        try:
            detect_prob, decoded_id = solver.decode_for_ui(audio_path)
            if detect_prob < 1e-2:
                return "No matching watermark found"
            closest_match = None
            min_distance = float("inf")
            for text, id_bin in watermarks.items():
                distance = hamming_distance(decoded_id, id_bin, base=16)
                if distance < min_distance:
                    closest_match = text
                    min_distance = distance
            if min_distance < 10:
                return closest_match
            return "No matching watermark found"
        except ValueError as e:
            return str(e)

    decode_button.click(
        decode_watermark, inputs=decode_audio_input, outputs=decoded_watermark_output
    )

# Launch the Gradio app
demo.launch()
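
For reference, the same embed-then-decode flow that the UI wires together can be exercised headlessly. This is a minimal sketch, not part of the committed files, assuming voicemark.pth sits in the working directory and audios/1.wav is available; the chosen watermark code is illustrative:

```python
# Headless sketch of Step 1 / Step 3 above (paths and watermark value are illustrative).
from infer import WatermarkSolver, hamming_distance

solver = WatermarkSolver()
solver.load_model(checkpoint_dir="./", checkpoint_name="voicemark.pth", strict=True)

wm_path = solver.infer_for_ui("audios/1.wav", "1000010101010011")  # embed the "VoiceMark" code
detect_prob, decoded_id = solver.decode_for_ui(wm_path)            # in the demo this runs on cloned audio
print(detect_prob, decoded_id,
      hamming_distance(decoded_id, "1000010101010011", base=16))
```
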
audios/1.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb97a4d6a89ce3cfff4d504de24ef4d79658cc68cc765da7c905b30ecb5e9f1d
size 160078

audios/2.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d5cf698739bad50c02ab96375bed49594a0b1a4421406b6ed10edc432f09dc94
size 152078

audios/3.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4673c2f94dec181e7603df7564af0778d6589a7d7f4878f30a9c42bdfc2324f7
size 160078

audios/4.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10c63ef8e2866b7e4b98d15cbbc084da667b350ab4d084f338b656661b619ceb
size 98750

audios/5.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:049b7061915306d541d2d8c059c6c9599502997562bc32120f16a873a55e499e
size 63072
infer.py
ADDED
@@ -0,0 +1,175 @@
from models import SBW
import torchaudio
import torch
import os

def hamming_distance(s1, s2, base=2):
    """
    Calculate the absolute difference between two strings interpreted in chunks of a specified base.

    Args:
        s1 (str): The first binary string.
        s2 (str): The second binary string.
        base (int): The base to interpret the binary strings (e.g., 2 for binary, 16 for hexadecimal).

    Returns:
        int: The sum of absolute differences for corresponding chunks.

    Raises:
        ValueError: If the strings are not of equal length or invalid for the given base.
    """
    if len(s1) != len(s2):
        raise ValueError("Both strings must be of equal length")

    # Determine the chunk size for the given base
    import math

    chunk_size = int(math.log(base, 2))

    if len(s1) % chunk_size != 0 or len(s2) % chunk_size != 0:
        raise ValueError(
            f"Binary strings must be a multiple of {chunk_size} bits for base {base}"
        )

    # Split binary strings into chunks
    def to_chunks(binary_str):
        return [
            binary_str[i : i + chunk_size]
            for i in range(0, len(binary_str), chunk_size)
        ]

    chunks_s1 = to_chunks(s1)
    chunks_s2 = to_chunks(s2)

    # Convert chunks to integers and calculate absolute difference
    def absolute_difference_chunks(chunk1, chunk2):
        int1 = int(chunk1, 2)
        int2 = int(chunk2, 2)
        return abs(int1 - int2)

    return sum(
        absolute_difference_chunks(c1, c2) for c1, c2 in zip(chunks_s1, chunks_s2)
    )

def random_message(nbits: int, batch_size: int) -> torch.Tensor:
    """Return random message as 0/1 tensor."""
    if nbits == 0:
        return torch.tensor([])
    return torch.randint(0, 2, (batch_size, nbits))


def string_to_message(message_str: str, batch_size: int) -> torch.Tensor:
    """
    Convert a binary string to a message tensor.

    Args:
        message_str (str): A string of '0's and '1's.
        batch_size (int): The batch size for the output tensor.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, nbits) where nbits is the length of message_str.
    """
    # Convert the string to a list of integers (0 and 1)
    message_list = [int(bit) for bit in message_str]
    nbits = len(message_list)
    # Convert the list to a tensor and repeat it for the batch size
    message_tensor = torch.tensor(message_list).unsqueeze(0).repeat(batch_size, 1)
    return message_tensor


class WatermarkSolver:
    def __init__(self):
        self.device = 'cpu'
        self.sample_rate = 16000
        self.nbits = 16
        self.model = SBW()

    def load_model(
        self, checkpoint_dir, checkpoint_name, strict=False
    ):
        """
        Load the latest model weights from the checkpoint directory.
        """
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]
        if not checkpoint_files:
            print("No checkpoint files found in the directory.")
            return

        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        print(f"Loading model weights from {checkpoint_path}...")

        # Load the checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)

        # Load model state dict with strict=False to allow missing keys (semantic_encoder)
        model_state_dict = checkpoint["model_state_dict"]

        new_state_dict = {}
        for k, v in model_state_dict.items():
            new_key = k.replace("module.", "")
            new_state_dict[new_key] = v
        if not strict:
            new_state_dict = {
                k: v
                for k, v in new_state_dict.items()
                if k in self.model.state_dict()
                and self.model.state_dict()[k].shape == v.shape
            }
        self.model.load_state_dict(new_state_dict, strict=False)

        print("Model state dict loaded successfully.")
        self.epoch = checkpoint["epoch"]
        print(f"Checkpoint loaded from epoch {checkpoint['epoch']}.")

    def infer_for_ui(self, input_audio_path, watermark, output_audio_path="tmp"):
        message_str = watermark
        self.model.eval()

        waveform, sample_rate = torchaudio.load(input_audio_path)

        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        waveform = waveform[:1].to(self.device).unsqueeze(1)
        message = string_to_message(message_str=message_str, batch_size=1).to(
            self.device
        )

        with torch.no_grad():
            output = self.model(
                waveform,
                message=message,
            )
            y_wm = output["recon_wm"]

        os.makedirs(output_audio_path, exist_ok=True)
        watermarked_audio_path = os.path.join(
            output_audio_path,
            f"{os.path.splitext(os.path.basename(input_audio_path))[0]}_watermarked.wav",
        )
        torchaudio.save(watermarked_audio_path, y_wm.squeeze(1).cpu(), self.sample_rate)
        return watermarked_audio_path

    def decode_for_ui(self, input_audio_path):
        self.model.eval()

        waveform, sample_rate = torchaudio.load(input_audio_path)

        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        waveform = waveform[:1].to(self.device).unsqueeze(1)

        with torch.no_grad():
            detect_prob, detected_message_tensor, _ = self.model.detect_watermark(waveform)
            detected_id = "".join(
                map(str, detected_message_tensor.squeeze(0).cpu().numpy().tolist())
            )

        return detect_prob, detected_id
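
To make the chunked distance concrete: with base=16 the 16-bit strings are split into 4-bit nibbles and compared as integers, so a flipped low-order bit inside a nibble costs 1 while a flipped high-order bit costs 8. A small worked example (illustrative values, taken from the watermark table in app.py):

```python
# Worked example of hamming_distance with base=16 versus base=2.
from infer import hamming_distance

a = "1000010101010011"   # "VoiceMark" code from app.py
b = "1000010101010010"   # last bit flipped
c = "0000010101010011"   # first bit flipped

print(hamming_distance(a, b, base=16))  # nibble 0011 vs 0010 -> |3 - 2| = 1
print(hamming_distance(a, c, base=16))  # nibble 1000 vs 0000 -> |8 - 0| = 8
print(hamming_distance(a, b, base=2))   # plain per-bit Hamming distance -> 1
```
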
models.py
ADDED
@@ -0,0 +1,299 @@
import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple

import torch.nn.functional as F
import torch
import torch.nn as nn


class WMDetector(nn.Module):
    """
    Detect watermarks in an audio signal using a Transformer architecture,
    where the watermark bits are split into bytes (8 bits each).
    We assume nbits is a multiple of 8.
    """

    def __init__(
        self, input_channels: int, nbits: int, nchunk_size: int, d_model: int = 512
    ):
        """
        Args:
            input_channels (int): Number of input channels in the audio feature (e.g., mel channels).
            nbits (int): Total number of bits in the watermark, must be a multiple of 8.
            d_model (int): Embedding dimension for the Transformer.
        """
        super().__init__()
        self.nchunk_size = nchunk_size
        assert nbits % nchunk_size == 0, "nbits must be a multiple of 8!"
        self.nbits = nbits
        self.d_model = d_model
        # Number of bytes
        self.nchunks = nbits // nchunk_size

        # 1D convolution to map the input channels to d_model
        self.embedding = nn.Conv1d(input_channels, d_model, kernel_size=1)

        # Transformer encoder block
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=1,
                dim_feedforward=d_model * 2,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=8,
        )

        # A linear head for watermark presence detection (binary)
        self.watermark_head = nn.Linear(d_model, 1)

        # For each byte, we perform a 256-way classification
        self.message_heads = nn.ModuleList(
            nn.Linear(d_model, 2**nchunk_size) for _ in range(self.nchunks)
        )

        # Learnable embeddings for each byte chunk (instead of per bit)
        # Shape: [nchunks, d_model]
        self.nchunk_embeddings = nn.Parameter(torch.randn(self.nchunks, d_model))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the detector.

        Returns:
            logits (torch.Tensor): Watermark detection logits of shape [batch, seq_len].
            chunk_logits (torch.Tensor): Byte-level classification logits of shape [batch, nchunks, 256].
        """
        batch_size, input_channels, time_steps = x.shape

        # 1) Map [batch, in_channels, time_steps] → [batch, time_steps, d_model]
        x = self.embedding(x).permute(0, 2, 1)  # [batch, time_steps, d_model]

        # 2) Prepend chunk embeddings at the beginning of the sequence
        #    [nchunks, d_model] → [1, nchunks, d_model] → [batch, nchunks, d_model]
        nchunk_embeds = self.nchunk_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        # Concatenate along the time dimension: [batch, nchunks + time_steps, d_model]
        x = torch.cat([nchunk_embeds, x], dim=1)

        # 3) Pass through the Transformer
        x = self.transformer(x)
        # x has shape [batch, nchunks + time_steps, d_model]

        # (a) Watermark presence detection: skip the first nchunks
        detection_part = x[:, self.nchunks :]  # [batch, time_steps, d_model]
        logits = self.watermark_head(detection_part).squeeze(-1)  # [batch, time_steps]

        # (b) Message decoding: use the first nchunks
        message_part = x[:, : self.nchunks]  # [batch, nchunks, d_model]
        chunk_logits_list = []
        for i, head in enumerate(self.message_heads):
            # message_part[:, i, :] has shape [batch, d_model]
            # each head outputs [batch, 256]
            chunk_vec = message_part[:, i, :]
            chunk_logits_list.append(head(chunk_vec).unsqueeze(1))  # [batch, 1, 256]

        # Concatenate along the 'nchunks' dimension → [batch, nchunks, 256]
        chunk_logits = torch.cat(chunk_logits_list, dim=1)

        return logits, chunk_logits

    def detect_watermark(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        threshold: float = 0.5,
    ) -> Tuple[float, torch.Tensor, torch.Tensor]:
        """
        A convenience function for inference.

        Returns:
            detect_prob (float): Probability that the audio is watermarked.
            binary_message (torch.Tensor): The recovered message of shape [batch, nbits] (binary).
            detected (torch.Tensor): The sigmoid values of the per-timestep watermark detection.
        """
        logits, chunk_logits = self.forward(x)
        # logits: [batch, seq_len] → raw logits for watermark presence detection
        # chunk_logits: [batch, nchunks, 256] → classification logits for each byte

        # (1) Compute watermark detection probability
        detected = torch.sigmoid(logits)  # [batch, seq_len]
        detect_prob = detected.mean(dim=-1).cpu().item()

        # (2) Decode the message: chunk_logits has shape [batch, nchunks, 256]
        chunk_probs = F.softmax(chunk_logits, dim=-1)  # [batch, nchunks, 256]
        chunk_indices = torch.argmax(
            chunk_probs, dim=-1
        )  # [batch, nchunks], each in [0..255]
        # (3) Convert each byte back to 8 bits
        #     Finally, assemble into a [batch, nbits] binary tensor
        binary_message = []
        for i in range(self.nchunks):
            chunk_val = chunk_indices[:, i]  # [batch]
            # Extract 8 bits from the integer (0..255)
            chunk_bits = []
            for b in range(self.nchunk_size):
                bit_b = (chunk_val >> b) & 1  # get bit b
                chunk_bits.append(bit_b.unsqueeze(-1))
            # Concatenate bits to shape [batch, 8]
            chunk_bits = torch.cat(chunk_bits, dim=-1)
            binary_message.append(chunk_bits)

        # Concatenate all bytes → [batch, nbits]
        binary_message = torch.cat(binary_message, dim=-1)

        return detect_prob, binary_message, detected


class WMEmbedder(nn.Module):
    """
    A class that takes a secret message, processes it into chunk embeddings
    (as a small sequence), and uses a TransformerDecoder to do cross-attention
    between the original hidden (target) and the watermark tokens (memory).
    """

    def __init__(
        self,
        nbits: int,  # total bits in the secret message
        input_dim: int,  # the input dimension (e.g. audio feature dimension)
        nchunk_size: int,
        hidden_dim: int = 256,
        num_heads: int = 1,
        num_layers: int = 4,
    ):
        super().__init__()
        self.nchunk_size = nchunk_size
        assert nbits % nchunk_size == 0, "nbits must be a multiple of nchunk_size!"
        self.nbits = nbits
        self.nchunks = nbits // nchunk_size  # how many chunks

        # Each chunk (0..2^nchunk_size - 1) maps to an embedding of size [hidden_dim]
        self.msg_embeddings = nn.ModuleList(
            nn.Embedding(2**nchunk_size, hidden_dim) for _ in range(self.nchunks)
        )

        # Linear to project [input_dim] -> [hidden_dim]
        self.input_projection = nn.Linear(input_dim, hidden_dim)

        # TransformerDecoder for cross-attention
        # d_model=hidden_dim, so the decoder expects [b, seq_len, hidden_dim] as tgt
        # and [b, memory_len, hidden_dim] as memory
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=2 * hidden_dim,
            activation="gelu",
            batch_first=True,  # so shape is [batch, seq, feature]
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_layers
        )

        # Project [hidden_dim] -> [input_dim]
        # self.output_projection1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.output_projection = nn.Linear(hidden_dim, input_dim)

    def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden: [batch, input_dim, seq_len]
            msg: [batch, nbits]
        Returns:
            A tensor [batch, input_dim, seq_len] with watermark injected.
        """
        b, in_dim, seq_len = hidden.shape

        # 1) Project input features to [b, seq_len, hidden_dim]
        hidden_projected = self.input_projection(
            hidden.permute(0, 2, 1)
        )  # => [b, seq_len, hidden_dim]

        # 2) Convert the msg bits into a sequence of chunk embeddings
        #    We keep each chunk as one token => [b, nchunks, hidden_dim]
        chunk_emb_list = []
        for i in range(self.nchunks):
            # msg[:, i*nchunk_size : (i+1)*nchunk_size] => shape [b, nchunk_size]
            chunk_bits = msg[:, i * self.nchunk_size : (i + 1) * self.nchunk_size]
            chunk_val = torch.zeros_like(chunk_bits[:, 0])  # shape [b]
            for bit_idx in range(self.nchunk_size):
                # shift bits
                chunk_val += chunk_bits[:, bit_idx] << bit_idx

            # embedding => [b, hidden_dim]
            chunk_emb = self.msg_embeddings[i](chunk_val)
            chunk_emb_list.append(chunk_emb.unsqueeze(1))  # => [b,1,hidden_dim]

        # Concat => [b, nchunks, hidden_dim]
        chunk_emb_seq = torch.cat(chunk_emb_list, dim=1)  # [b, nchunks, hidden_dim]

        # 3) Use chunk_emb_seq as memory, hidden_projected as target for TransformerDecoder
        #
        # TransformerDecoder forward signature:
        #   transformer_decoder(tgt, memory, ...)
        # => [b, seq_len, hidden_dim]
        x_decoded = self.transformer_decoder(
            tgt=hidden_projected,  # [b, seq_len, hidden_dim]
            memory=chunk_emb_seq,  # [b, nchunks, hidden_dim]
        )

        # 4) Project back to input_dim => [b, seq_len, input_dim]
        x_output = self.output_projection(x_decoded)

        # 5) permute back to [b, input_dim, seq_len]
        x_output = x_output.permute(0, 2, 1)  # => [b, input_dim, seq_len]

        # 6) (Optional) Residual with original hidden
        x_output = x_output + hidden

        return x_output


from speechtokenizer import SpeechTokenizer


class SBW(nn.Module):
    def __init__(self):
        super().__init__()
        self.nbits = 16
        config_path = (
            "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json"
        )
        ckpt_path = "speechtokenizer/pretrained_model/SpeechTokenizer.pt"
        self.st_model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
        self.msg_processor = WMEmbedder(
            nbits=16,
            input_dim=1024,
            nchunk_size=4,
        )
        self.detector = WMDetector(
            1024,
            16,
            nchunk_size=4,
        )

    def detect_watermark(
        self, x: torch.Tensor, return_logits=False
    ) -> Tuple[float, torch.Tensor]:
        embedding = self.st_model.forward_feature(x)
        if return_logits:
            return self.detector(embedding)
        return self.detector.detect_watermark(embedding)

    def forward(
        self,
        speech_input: torch.Tensor,
        message: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        recon, recon_wm, acoustic, acoustic_wm = self.st_model(
            speech_input, msg_processor=self.msg_processor, message=message
        )
        wav_length = min(speech_input.size(-1), recon_wm.size(-1))
        speech_input = speech_input[..., :wav_length]
        recon = recon[..., :wav_length]
        recon_wm = recon_wm[..., :wav_length]
        return {
            "recon": recon,
            "recon_wm": recon_wm,
        }
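
WMEmbedder and WMDetector agree on a little-endian chunk layout: the embedder packs each group of nchunk_size bits with `chunk_val += bit << bit_idx`, and the detector unpacks with `(chunk_val >> b) & 1`. With the nchunk_size=4 used by SBW, a 16-bit message becomes four 4-bit symbols, each indexing a 16-way embedding on the embed side and a 16-way classifier head on the detect side. A standalone round-trip sketch of that packing (plain Python, illustrative only):

```python
# Round-trip sketch of the 4-bit chunk packing shared by WMEmbedder and WMDetector.
bits = [1, 0, 0, 0,  0, 1, 0, 1,  0, 1, 0, 1,  0, 0, 1, 1]  # the "VoiceMark" code
nchunk_size = 4

chunk_vals = []
for i in range(0, len(bits), nchunk_size):
    chunk = bits[i:i + nchunk_size]
    chunk_vals.append(sum(bit << idx for idx, bit in enumerate(chunk)))  # pack, little-endian

recovered = []
for val in chunk_vals:
    recovered.extend((val >> b) & 1 for b in range(nchunk_size))         # unpack

assert recovered == bits
print(chunk_vals)  # [1, 10, 10, 12]
```
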
requirements.txt
ADDED
@@ -0,0 +1,4 @@
torchaudio
beartype
einops
omegaconf
speechtokenizer/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .model import SpeechTokenizer

__version__ = '1.0.0'
speechtokenizer/model.py
ADDED
@@ -0,0 +1,215 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 15:47:55 2023
@author: zhangxin
"""

import random
from models import WMEmbedder
from .modules.seanet import SEANetEncoder, SEANetDecoder
from .quantization import ResidualVectorQuantizer
import torch.nn as nn
from einops import rearrange
import torch
import numpy as np


class SpeechTokenizer(nn.Module):
    def __init__(self, config):
        """

        Parameters
        ----------
        config : json
            Model Config.

        """
        super().__init__()
        self.encoder = SEANetEncoder(
            n_filters=config.get("n_filters"),
            dimension=config.get("dimension"),
            ratios=config.get("strides"),
            lstm=config.get("lstm_layers"),
            bidirectional=config.get("bidirectional"),
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )
        self.sample_rate = config.get("sample_rate")
        self.n_q = config.get("n_q")
        self.downsample_rate = np.prod(config.get("strides"))
        if config.get("dimension") != config.get("semantic_dimension"):
            self.transform = nn.Linear(
                config.get("dimension"), config.get("semantic_dimension")
            )
        else:
            self.transform = nn.Identity()
        self.quantizer = ResidualVectorQuantizer(
            dimension=config.get("dimension"),
            n_q=config.get("n_q"),
            bins=config.get("codebook_size"),
        )
        self.decoder = SEANetDecoder(
            n_filters=config.get("n_filters"),
            dimension=config.get("dimension"),
            ratios=config.get("strides"),
            lstm=config.get("lstm_layers"),
            bidirectional=False,
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
        """

        Parameters
        ----------
        config_path : str
            Path of model configuration file.
        ckpt_path : str
            Path of model checkpoint.

        Returns
        -------
        model : SpeechTokenizer
            SpeechTokenizer model.

        """
        import json

        with open(config_path) as f:
            cfg = json.load(f)
        model = cls(cfg)
        params = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(params)
        return model

    def forward(
        self,
        x: torch.tensor,
        n_q: int = None,
        layers: list = [0],
        msg_processor: WMEmbedder = None,
        message: torch.Tensor = None,
    ):
        """

        Parameters
        ----------
        x : torch.tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default is all layers.
        layers : list[int], optional
            Layers of RVQ should return quantized result. The default is the first layer.

        Returns
        -------
        o : torch.tensor
            Output wavs. Shape: (batch, channels, timesteps).
        commit_loss : torch.tensor
            Commitment loss from residual vector quantizers.
        feature : torch.tensor
            Output of RVQ's first layer. Shape: (batch, timesteps, dimension)

        """
        with torch.no_grad():
            e = self.encoder(x)
            quantized_full, _, _, quantized_list = self.quantizer(
                e, n_q=n_q, layers=[0, 1, 2, 3, 4, 5, 6, 7], st=0
            )
            # semantic, _, _, _ = self.quantizer(e, n_q=1, st=0)
            # acoustic = e - semantic
            o = self.decoder(quantized_full)

        subset = quantized_list[1:]

        # half_len = len(subset) // 2
        # selected_for_processing = random.sample(subset, half_len)

        # selected_ids = set(id(x) for x in selected_for_processing)
        # acoustic_wm = sum(
        #     msg_processor(x, message) if id(x) in selected_ids else x for x in subset
        # )

        acoustic_wm = sum(msg_processor(x, message) for x in subset)

        acoustic = e - quantized_list[0]

        # e_wm = acoustic_wm
        e_wm = quantized_list[0] + acoustic_wm

        o_wm = self.decoder(e_wm)

        return (o, o_wm, acoustic, acoustic_wm)

    def forward_feature(self, x: torch.tensor, layers: list = None):
        """

        Parameters
        ----------
        x : torch.tensor
            Input wavs. Shape should be (batch, channels, timesteps).
        layers : list[int], optional
            Layers of RVQ should return quantized result. The default is all layers.

        Returns
        -------
        quantized_list : list[torch.tensor]
            Quantized of required layers.

        """
        e = self.encoder(x)
        with torch.no_grad():
            semantic, _, _, _ = self.quantizer(e, st=0, n_q=1)
        acoustic = e - semantic
        return acoustic

    def encode(self, x: torch.tensor, n_q: int = None, st: int = None):
        """

        Parameters
        ----------
        x : torch.tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default is all layers.
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        codes : torch.tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps)

        """
        e = self.encoder(x)
        if st is None:
            st = 0
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.tensor, st: int = 0):
        """

        Parameters
        ----------
        codes : torch.tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        o : torch.tensor
            Reconstruct wavs from codes. Shape: (batch, channels, timesteps)

        """
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o
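
A minimal sketch of loading the tokenizer and running a plain RVQ encode/decode round trip (no watermarking). The config and checkpoint paths are the ones hard-coded in models.py; importing models first is assumed here so that the circular dependency between models.py and this package (models.py imports SpeechTokenizer, while model.py imports WMEmbedder from models) resolves the same way it does for app.py:

```python
# Sketch: plain encode/decode with SpeechTokenizer (16 kHz mono input assumed).
import models  # noqa: F401  -- imported first so the models <-> speechtokenizer cycle resolves
import torch
import torchaudio
from speechtokenizer import SpeechTokenizer

st = SpeechTokenizer.load_from_checkpoint(
    "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json",
    "speechtokenizer/pretrained_model/SpeechTokenizer.pt",
)
st.eval()

wav, sr = torchaudio.load("audios/1.wav")
wav = wav[:1].unsqueeze(0)        # (batch=1, channels=1, timesteps)
with torch.no_grad():
    codes = st.encode(wav)        # (n_q, batch, frames)
    recon = st.decode(codes)      # (batch, channels, timesteps)
print(codes.shape, recon.shape)
```
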
speechtokenizer/modules/__init__.py
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch modules."""

# flake8: noqa
from .conv import (
    pad1d,
    unpad1d,
    NormConv1d,
    NormConvTranspose1d,
    NormConv2d,
    NormConvTranspose2d,
    SConv1d,
    SConvTranspose1d,
)
from .lstm import SLSTM
from .seanet import SEANetEncoder, SEANetDecoder
speechtokenizer/modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (500 Bytes)

speechtokenizer/modules/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (654 Bytes)

speechtokenizer/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (494 Bytes)

speechtokenizer/modules/__pycache__/conv.cpython-310.pyc
ADDED
Binary file (9.2 kB)

speechtokenizer/modules/__pycache__/conv.cpython-311.pyc
ADDED
Binary file (15.3 kB)

speechtokenizer/modules/__pycache__/conv.cpython-39.pyc
ADDED
Binary file (9.43 kB)

speechtokenizer/modules/__pycache__/lstm.cpython-310.pyc
ADDED
Binary file (1.14 kB)

speechtokenizer/modules/__pycache__/lstm.cpython-311.pyc
ADDED
Binary file (1.8 kB)

speechtokenizer/modules/__pycache__/lstm.cpython-39.pyc
ADDED
Binary file (1.13 kB)

speechtokenizer/modules/__pycache__/norm.cpython-310.pyc
ADDED
Binary file (1.15 kB)

speechtokenizer/modules/__pycache__/norm.cpython-311.pyc
ADDED
Binary file (1.69 kB)

speechtokenizer/modules/__pycache__/norm.cpython-39.pyc
ADDED
Binary file (1.15 kB)

speechtokenizer/modules/__pycache__/seanet.cpython-310.pyc
ADDED
Binary file (11 kB)

speechtokenizer/modules/__pycache__/seanet.cpython-311.pyc
ADDED
Binary file (16.6 kB)

speechtokenizer/modules/__pycache__/seanet.cpython-39.pyc
ADDED
Binary file (10.5 kB)
speechtokenizer/modules/conv.py
ADDED
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Convolutional layers wrappers and utilities."""

import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm

from .norm import ConvLayerNorm


CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    elif norm == 'spectral_norm':
        return spectral_norm(module)
    else:
        # We already check was in CONV_NORMALIZATION, so any other choice
        # doesn't need reparametrization.
        return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`.
    """
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happen.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        padding_total = (kernel_size - 1) * dilation - (stride - 1)
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)


class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
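
To make the padding arithmetic in get_extra_padding_for_conv1d / pad_for_conv1d concrete: for a length-5 signal with kernel_size=4, stride=2 and no other padding, the last window is incomplete (1.5 frames), so one extra sample is appended on the right. A standalone worked example of the same formula (illustrative values):

```python
# Worked example of the extra-padding formula from conv.py.
import math

length, kernel_size, stride, padding_total = 5, 4, 2, 0
n_frames = (length - kernel_size + padding_total) / stride + 1                     # 1.5
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)  # 6
print(ideal_length - length)  # 1 extra sample -> padded length 6 gives 2 full frames
```
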
speechtokenizer/modules/lstm.py
ADDED
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""LSTM layers module."""

from torch import nn


class SLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional: bool = False):
        super().__init__()
        self.bidirectional = bidirectional
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.bidirectional:
            x = x.repeat(1, 1, 2)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y
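
SLSTM keeps the convolutional (batch, channels, time) layout: it permutes to (time, batch, channels), runs the LSTM, adds a residual skip, and permutes back (channels double when bidirectional because of the concatenated directions). A standalone shape sketch of the unidirectional path, mirroring the forward above:

```python
# Standalone sketch of SLSTM's layout handling (skip=True, bidirectional=False).
import torch
from torch import nn

dim = 8
x = torch.randn(2, dim, 50)          # (batch, channels=dim, time)
lstm = nn.LSTM(dim, dim, num_layers=2)

h = x.permute(2, 0, 1)               # (time, batch, dim)
y, _ = lstm(h)
y = (y + h).permute(1, 2, 0)         # residual skip, back to (batch, dim, time)
print(y.shape)                       # torch.Size([2, 8, 50])
```
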
speechtokenizer/modules/norm.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Normalization modules."""

import typing as tp

import einops
import torch
from torch import nn


class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, 'b ... t -> b t ...')
        x = super().forward(x)
        x = einops.rearrange(x, 'b t ... -> b ... t')
        return x
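
For a (batch, channels, time) tensor, ConvLayerNorm amounts to moving channels last, applying LayerNorm over them, and moving them back. A standalone equivalent sketch:

```python
# Standalone equivalent of ConvLayerNorm for a (b, c, t) input.
import torch
import einops
from torch import nn

x = torch.randn(2, 8, 50)            # (batch, channels, time)
ln = nn.LayerNorm(8)
y = einops.rearrange(ln(einops.rearrange(x, 'b c t -> b t c')), 'b t c -> b c t')
print(y.shape)                       # torch.Size([2, 8, 50])
```
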
speechtokenizer/modules/seanet.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""Encodec SEANet-based encoder and decoder implementation."""
|
8 |
+
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from . import SConv1d, SConvTranspose1d, SLSTM
|
16 |
+
|
17 |
+
|
18 |
+
@torch.jit.script
|
19 |
+
def snake(x, alpha):
|
20 |
+
shape = x.shape
|
21 |
+
x = x.reshape(shape[0], shape[1], -1)
|
22 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
23 |
+
x = x.reshape(shape)
|
24 |
+
return x
|
25 |
+
|
26 |
+
|
27 |
+
class Snake1d(nn.Module):
|
28 |
+
def __init__(self, channels):
|
29 |
+
super().__init__()
|
30 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
return snake(x, self.alpha)
|
34 |
+
|
35 |
+
|
36 |
+
class SEANetResnetBlock(nn.Module):
|
37 |
+
"""Residual block from SEANet model.
|
38 |
+
Args:
|
39 |
+
dim (int): Dimension of the input/output
|
40 |
+
kernel_sizes (list): List of kernel sizes for the convolutions.
|
41 |
+
dilations (list): List of dilations for the convolutions.
|
42 |
+
activation (str): Activation function.
|
43 |
+
activation_params (dict): Parameters to provide to the activation function
|
44 |
+
norm (str): Normalization method.
|
45 |
+
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
46 |
+
causal (bool): Whether to use fully causal convolution.
|
47 |
+
pad_mode (str): Padding mode for the convolutions.
|
48 |
+
compress (int): Reduced dimensionality in residual branches (from Demucs v3)
|
49 |
+
true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
dim: int,
|
55 |
+
kernel_sizes: tp.List[int] = [3, 1],
|
56 |
+
dilations: tp.List[int] = [1, 1],
|
57 |
+
activation: str = "ELU",
|
58 |
+
activation_params: dict = {"alpha": 1.0},
|
59 |
+
norm: str = "weight_norm",
|
60 |
+
norm_params: tp.Dict[str, tp.Any] = {},
|
61 |
+
causal: bool = False,
|
62 |
+
pad_mode: str = "reflect",
|
63 |
+
compress: int = 2,
|
64 |
+
true_skip: bool = True,
|
65 |
+
):
|
66 |
+
super().__init__()
|
67 |
+
assert len(kernel_sizes) == len(
|
68 |
+
dilations
|
69 |
+
), "Number of kernel sizes should match number of dilations"
|
70 |
+
act = getattr(nn, activation) if activation != "Snake" else Snake1d
|
71 |
+
hidden = dim // compress
|
72 |
+
block = []
|
73 |
+
for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
|
74 |
+
in_chs = dim if i == 0 else hidden
|
75 |
+
out_chs = dim if i == len(kernel_sizes) - 1 else hidden
|
76 |
+
block += [
|
77 |
+
act(**activation_params) if activation != "Snake" else act(in_chs),
|
78 |
+
SConv1d(
|
79 |
+
in_chs,
|
80 |
+
out_chs,
|
81 |
+
kernel_size=kernel_size,
|
82 |
+
dilation=dilation,
|
83 |
+
norm=norm,
|
84 |
+
norm_kwargs=norm_params,
|
85 |
+
causal=causal,
|
86 |
+
pad_mode=pad_mode,
|
87 |
+
),
|
88 |
+
]
|
89 |
+
self.block = nn.Sequential(*block)
|
90 |
+
self.shortcut: nn.Module
|
91 |
+
if true_skip:
|
92 |
+
self.shortcut = nn.Identity()
|
93 |
+
else:
|
94 |
+
self.shortcut = SConv1d(
|
95 |
+
dim,
|
96 |
+
dim,
|
97 |
+
kernel_size=1,
|
98 |
+
norm=norm,
|
99 |
+
norm_kwargs=norm_params,
|
100 |
+
causal=causal,
|
101 |
+
pad_mode=pad_mode,
|
102 |
+
)
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
return self.shortcut(x) + self.block(x)
|
106 |
+
|
107 |
+
|
108 |
+
class SEANetEncoder(nn.Module):
|
109 |
+
"""SEANet encoder.
|
110 |
+
Args:
|
111 |
+
channels (int): Audio channels.
|
112 |
+
dimension (int): Intermediate representation dimension.
|
113 |
+
n_filters (int): Base width for the model.
|
114 |
+
n_residual_layers (int): nb of residual layers.
|
115 |
+
ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
|
116 |
+
upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
|
117 |
+
that must match the decoder order
|
118 |
+
activation (str): Activation function.
|
119 |
+
activation_params (dict): Parameters to provide to the activation function
|
120 |
+
norm (str): Normalization method.
|
121 |
+
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
122 |
+
kernel_size (int): Kernel size for the initial convolution.
|
123 |
+
last_kernel_size (int): Kernel size for the initial convolution.
|
124 |
+
residual_kernel_size (int): Kernel size for the residual layers.
|
125 |
+
dilation_base (int): How much to increase the dilation with each layer.
|
126 |
+
causal (bool): Whether to use fully causal convolution.
|
127 |
+
pad_mode (str): Padding mode for the convolutions.
|
128 |
+
true_skip (bool): Whether to use true skip connection or a simple
|
129 |
+
(streamable) convolution as the skip connection in the residual network blocks.
|
130 |
+
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
131 |
+
lstm (int): Number of LSTM layers at the end of the encoder.
|
132 |
+
"""
|
133 |
+
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
channels: int = 1,
|
137 |
+
dimension: int = 128,
|
138 |
+
n_filters: int = 32,
|
139 |
+
n_residual_layers: int = 1,
|
140 |
+
ratios: tp.List[int] = [8, 5, 4, 2],
|
141 |
+
activation: str = "ELU",
|
142 |
+
activation_params: dict = {"alpha": 1.0},
|
143 |
+
norm: str = "weight_norm",
|
144 |
+
norm_params: tp.Dict[str, tp.Any] = {},
|
145 |
+
kernel_size: int = 7,
|
146 |
+
last_kernel_size: int = 7,
|
147 |
+
residual_kernel_size: int = 3,
|
148 |
+
dilation_base: int = 2,
|
149 |
+
causal: bool = False,
|
150 |
+
pad_mode: str = "reflect",
|
151 |
+
true_skip: bool = False,
|
152 |
+
compress: int = 2,
|
153 |
+
lstm: int = 2,
|
154 |
+
bidirectional: bool = False,
|
155 |
+
):
|
156 |
+
super().__init__()
|
157 |
+
self.channels = channels
|
158 |
+
self.dimension = dimension
|
159 |
+
self.n_filters = n_filters
|
160 |
+
self.ratios = list(reversed(ratios))
|
161 |
+
del ratios
|
162 |
+
self.n_residual_layers = n_residual_layers
|
163 |
+
self.hop_length = np.prod(self.ratios) # 计算乘积
|
164 |
+
|
165 |
+
act = getattr(nn, activation) if activation != "Snake" else Snake1d
|
166 |
+
mult = 1
|
167 |
+
model: tp.List[nn.Module] = [
|
168 |
+
SConv1d(
|
169 |
+
channels,
|
170 |
+
mult * n_filters,
|
171 |
+
kernel_size,
|
172 |
+
norm=norm,
|
173 |
+
norm_kwargs=norm_params,
|
174 |
+
causal=causal,
|
175 |
+
pad_mode=pad_mode,
|
176 |
+
)
|
177 |
+
]
|
178 |
+
# Downsample to raw audio scale
|
179 |
+
for i, ratio in enumerate(self.ratios):
|
180 |
+
# Add residual layers
|
181 |
+
for j in range(n_residual_layers):
|
182 |
+
model += [
|
183 |
+
SEANetResnetBlock(
|
184 |
+
mult * n_filters,
|
185 |
+
kernel_sizes=[residual_kernel_size, 1],
|
186 |
+
dilations=[dilation_base**j, 1],
|
187 |
+
norm=norm,
|
188 |
+
norm_params=norm_params,
|
189 |
+
activation=activation,
|
190 |
+
activation_params=activation_params,
|
191 |
+
causal=causal,
|
192 |
+
pad_mode=pad_mode,
|
193 |
+
compress=compress,
|
194 |
+
true_skip=true_skip,
|
195 |
+
)
|
196 |
+
]
|
197 |
+
|
198 |
+
# Add downsampling layers
|
199 |
+
model += [
|
200 |
+
(
|
201 |
+
act(**activation_params)
|
202 |
+
if activation != "Snake"
|
203 |
+
else act(mult * n_filters)
|
204 |
+
),
|
205 |
+
SConv1d(
|
206 |
+
mult * n_filters,
|
207 |
+
mult * n_filters * 2,
|
208 |
+
kernel_size=ratio * 2,
|
209 |
+
stride=ratio,
|
210 |
+
norm=norm,
|
211 |
+
norm_kwargs=norm_params,
|
212 |
+
causal=causal,
|
213 |
+
pad_mode=pad_mode,
|
214 |
+
),
|
215 |
+
]
|
216 |
+
mult *= 2
|
217 |
+
|
218 |
+
if lstm:
|
219 |
+
model += [
|
220 |
+
SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
|
221 |
+
]
|
222 |
+
|
223 |
+
mult = mult * 2 if bidirectional else mult
|
224 |
+
model += [
|
225 |
+
(
|
226 |
+
act(**activation_params)
|
227 |
+
if activation != "Snake"
|
228 |
+
else act(mult * n_filters)
|
229 |
+
),
|
230 |
+
SConv1d(
|
231 |
+
mult * n_filters,
|
232 |
+
dimension,
|
233 |
+
last_kernel_size,
|
234 |
+
norm=norm,
|
235 |
+
norm_kwargs=norm_params,
|
236 |
+
causal=causal,
|
237 |
+
pad_mode=pad_mode,
|
238 |
+
),
|
239 |
+
]
|
240 |
+
|
241 |
+
self.model = nn.Sequential(*model)
|
242 |
+
|
243 |
+
def forward(self, x):
|
244 |
+
return self.model(x)
|
245 |
+
|
246 |
+
|
247 |
+
class SEANetDecoder(nn.Module):
|
248 |
+
"""SEANet decoder.
|
249 |
+
Args:
|
250 |
+
channels (int): Audio channels.
|
251 |
+
dimension (int): Intermediate representation dimension.
|
252 |
+
n_filters (int): Base width for the model.
|
253 |
+
n_residual_layers (int): nb of residual layers.
|
254 |
+
ratios (Sequence[int]): kernel size and stride ratios
|
255 |
+
activation (str): Activation function.
|
256 |
+
activation_params (dict): Parameters to provide to the activation function
|
257 |
+
final_activation (str): Final activation function after all convolutions.
|
258 |
+
final_activation_params (dict): Parameters to provide to the activation function
|
259 |
+
norm (str): Normalization method.
|
260 |
+
norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
|
261 |
+
kernel_size (int): Kernel size for the initial convolution.
|
262 |
+
last_kernel_size (int): Kernel size for the initial convolution.
|
263 |
+
residual_kernel_size (int): Kernel size for the residual layers.
|
264 |
+
dilation_base (int): How much to increase the dilation with each layer.
|
265 |
+
causal (bool): Whether to use fully causal convolution.
|
266 |
+
pad_mode (str): Padding mode for the convolutions.
|
267 |
+
true_skip (bool): Whether to use true skip connection or a simple
|
268 |
+
(streamable) convolution as the skip connection in the residual network blocks.
|
269 |
+
compress (int): Reduced dimensionality in residual branches (from Demucs v3).
|
270 |
+
lstm (int): Number of LSTM layers at the end of the encoder.
|
271 |
+
trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
|
272 |
+
If equal to 1.0, it means that all the trimming is done at the right.
|
273 |
+
"""
|
274 |
+
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
channels: int = 1,
|
278 |
+
dimension: int = 128,
|
279 |
+
n_filters: int = 32,
|
280 |
+
n_residual_layers: int = 1,
|
281 |
+
ratios: tp.List[int] = [8, 5, 4, 2],
|
282 |
+
activation: str = "ELU",
|
283 |
+
activation_params: dict = {"alpha": 1.0},
|
284 |
+
final_activation: tp.Optional[str] = None,
|
285 |
+
final_activation_params: tp.Optional[dict] = None,
|
286 |
+
norm: str = "weight_norm",
|
287 |
+
norm_params: tp.Dict[str, tp.Any] = {},
|
288 |
+
kernel_size: int = 7,
|
289 |
+
last_kernel_size: int = 7,
|
290 |
+
residual_kernel_size: int = 3,
|
291 |
+
dilation_base: int = 2,
|
292 |
+
causal: bool = False,
|
293 |
+
pad_mode: str = "reflect",
|
294 |
+
true_skip: bool = False,
|
295 |
+
compress: int = 2,
|
296 |
+
lstm: int = 2,
|
297 |
+
trim_right_ratio: float = 1.0,
|
298 |
+
bidirectional: bool = False,
|
299 |
+
):
|
300 |
+
super().__init__()
|
301 |
+
self.dimension = dimension
|
302 |
+
self.channels = channels
|
303 |
+
self.n_filters = n_filters
|
304 |
+
self.ratios = ratios
|
305 |
+
del ratios
|
306 |
+
self.n_residual_layers = n_residual_layers
|
307 |
+
self.hop_length = np.prod(self.ratios)
|
308 |
+
|
309 |
+
act = getattr(nn, activation) if activation != "Snake" else Snake1d
|
310 |
+
mult = int(2 ** len(self.ratios))
|
311 |
+
model: tp.List[nn.Module] = [
|
312 |
+
SConv1d(
|
313 |
+
dimension,
|
314 |
+
mult * n_filters,
|
315 |
+
kernel_size,
|
316 |
+
norm=norm,
|
317 |
+
norm_kwargs=norm_params,
|
318 |
+
causal=causal,
|
319 |
+
pad_mode=pad_mode,
|
320 |
+
)
|
321 |
+
]
|
322 |
+
|
323 |
+
if lstm:
|
324 |
+
model += [
|
325 |
+
SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
|
326 |
+
]
|
327 |
+
|
328 |
+
# Upsample to raw audio scale
|
329 |
+
for i, ratio in enumerate(self.ratios):
|
330 |
+
# Add upsampling layers
|
331 |
+
model += [
|
332 |
+
(
|
333 |
+
act(**activation_params)
|
334 |
+
if activation != "Snake"
|
335 |
+
else act(mult * n_filters)
|
336 |
+
),
|
337 |
+
SConvTranspose1d(
|
338 |
+
mult * n_filters,
|
339 |
+
mult * n_filters // 2,
|
340 |
+
kernel_size=ratio * 2,
|
341 |
+
stride=ratio,
|
342 |
+
norm=norm,
|
343 |
+
norm_kwargs=norm_params,
|
344 |
+
causal=causal,
|
345 |
+
trim_right_ratio=trim_right_ratio,
|
346 |
+
),
|
347 |
+
]
|
348 |
+
# Add residual layers
|
349 |
+
for j in range(n_residual_layers):
|
350 |
+
model += [
|
351 |
+
SEANetResnetBlock(
|
352 |
+
mult * n_filters // 2,
|
353 |
+
kernel_sizes=[residual_kernel_size, 1],
|
354 |
+
dilations=[dilation_base**j, 1],
|
355 |
+
activation=activation,
|
356 |
+
activation_params=activation_params,
|
357 |
+
norm=norm,
|
358 |
+
norm_params=norm_params,
|
359 |
+
causal=causal,
|
360 |
+
pad_mode=pad_mode,
|
361 |
+
compress=compress,
|
362 |
+
true_skip=true_skip,
|
363 |
+
)
|
364 |
+
]
|
365 |
+
|
366 |
+
mult //= 2
|
367 |
+
|
368 |
+
# Add final layers
|
369 |
+
model += [
|
370 |
+
act(**activation_params) if activation != "Snake" else act(n_filters),
|
371 |
+
SConv1d(
|
372 |
+
n_filters,
|
373 |
+
channels,
|
374 |
+
last_kernel_size,
|
375 |
+
norm=norm,
|
376 |
+
norm_kwargs=norm_params,
|
377 |
+
causal=causal,
|
378 |
+
pad_mode=pad_mode,
|
379 |
+
),
|
380 |
+
]
|
381 |
+
# Add optional final activation to decoder (eg. tanh)
|
382 |
+
if final_activation is not None:
|
383 |
+
final_act = getattr(nn, final_activation)
|
384 |
+
final_activation_params = final_activation_params or {}
|
385 |
+
model += [final_act(**final_activation_params)]
|
386 |
+
self.model = nn.Sequential(*model)
|
387 |
+
|
388 |
+
def forward(self, z):
|
389 |
+
y = self.model(z)
|
390 |
+
return y
|
391 |
+
|
392 |
+
|
393 |
+
def test():
|
394 |
+
import torch
|
395 |
+
|
396 |
+
encoder = SEANetEncoder()
|
397 |
+
decoder = SEANetDecoder()
|
398 |
+
x = torch.randn(1, 1, 24000)
|
399 |
+
z = encoder(x)
|
400 |
+
print("z ", z.shape)
|
401 |
+
assert 1 == 2
|
402 |
+
assert list(z.shape) == [1, 128, 75], z.shape
|
403 |
+
y = decoder(z)
|
404 |
+
assert y.shape == x.shape, (x.shape, y.shape)
|
405 |
+
|
406 |
+
|
407 |
+
if __name__ == "__main__":
|
408 |
+
test()
|
speechtokenizer/pretrained_model/SpeechTokenizer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d04593b6c9a4b475f91ca481141a6ef5b23e6ac112f347dd2b2717f193c1c728
|
3 |
+
size 481906997
|
speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"resblock": "1",
|
3 |
+
"num_gpus": 3,
|
4 |
+
"batch_size": 60,
|
5 |
+
"learning_rate": 0.0001,
|
6 |
+
"adam_b1": 0.5,
|
7 |
+
"adam_b2": 0.9,
|
8 |
+
"lr_decay": 0.98,
|
9 |
+
"seed": 1234,
|
10 |
+
"lambda_distill": 0.15,
|
11 |
+
|
12 |
+
"n_filters": 64,
|
13 |
+
"strides": [8,5,4,2],
|
14 |
+
"dimension": 1024,
|
15 |
+
"semantic_dimension": 768,
|
16 |
+
"bidirectional": true,
|
17 |
+
"dilation_base": 2,
|
18 |
+
"residual_kernel_size": 3,
|
19 |
+
"n_residual_layers": 1,
|
20 |
+
"lstm_layers": 2,
|
21 |
+
"activation": "ELU",
|
22 |
+
|
23 |
+
"segment_size": 48000,
|
24 |
+
"num_mels": 80,
|
25 |
+
"num_freq": 1025,
|
26 |
+
"n_fft": 1024,
|
27 |
+
"hop_size": 240,
|
28 |
+
"win_size": 1024,
|
29 |
+
|
30 |
+
"sampling_rate": 16000,
|
31 |
+
"sample_rate": 16000,
|
32 |
+
|
33 |
+
"codebook_size": 1024,
|
34 |
+
"n_q": 8,
|
35 |
+
|
36 |
+
"fmin": 0,
|
37 |
+
"fmax": 8000,
|
38 |
+
"fmax_for_loss": null,
|
39 |
+
|
40 |
+
"num_workers": 12,
|
41 |
+
|
42 |
+
"dist_config": {
|
43 |
+
"dist_backend": "nccl",
|
44 |
+
"dist_url": "tcp://localhost:54322",
|
45 |
+
"world_size": 1
|
46 |
+
}
|
47 |
+
}
|
speechtokenizer/quantization/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# flake8: noqa
|
8 |
+
from .vq import QuantizedResult, ResidualVectorQuantizer
|
speechtokenizer/quantization/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (249 Bytes). View file
|
|
speechtokenizer/quantization/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (283 Bytes). View file
|
|
speechtokenizer/quantization/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (243 Bytes). View file
|
|
speechtokenizer/quantization/__pycache__/core_vq.cpython-310.pyc
ADDED
Binary file (11.2 kB). View file
|
|
speechtokenizer/quantization/__pycache__/core_vq.cpython-311.pyc
ADDED
Binary file (18.9 kB). View file
|
|
speechtokenizer/quantization/__pycache__/core_vq.cpython-39.pyc
ADDED
Binary file (11.1 kB). View file
|
|
speechtokenizer/quantization/__pycache__/distrib.cpython-310.pyc
ADDED
Binary file (3.76 kB). View file
|
|
speechtokenizer/quantization/__pycache__/distrib.cpython-311.pyc
ADDED
Binary file (6.95 kB). View file
|
|
speechtokenizer/quantization/__pycache__/distrib.cpython-39.pyc
ADDED
Binary file (3.78 kB). View file
|
|
speechtokenizer/quantization/__pycache__/vq.cpython-310.pyc
ADDED
Binary file (4.62 kB). View file
|
|
speechtokenizer/quantization/__pycache__/vq.cpython-311.pyc
ADDED
Binary file (6.27 kB). View file
|
|
speechtokenizer/quantization/__pycache__/vq.cpython-39.pyc
ADDED
Binary file (4.57 kB). View file
|
|
speechtokenizer/quantization/ac.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""Arithmetic coder."""
|
8 |
+
|
9 |
+
import io
|
10 |
+
import math
|
11 |
+
import random
|
12 |
+
import typing as tp
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from ..binary import BitPacker, BitUnpacker
|
16 |
+
|
17 |
+
|
18 |
+
def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
|
19 |
+
roundoff: float = 1e-8, min_range: int = 2,
|
20 |
+
check: bool = True) -> torch.Tensor:
|
21 |
+
"""Turn the given PDF into a quantized CDF that splits
|
22 |
+
[0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
|
23 |
+
to the PDF.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
pdf (torch.Tensor): probability distribution, shape should be `[N]`.
|
27 |
+
total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
|
28 |
+
during the coding process is `[0, 2 ** total_range_bits - 1]`.
|
29 |
+
roundoff (float): will round the pdf up to that level to remove difference coming
|
30 |
+
from e.g. evaluating the Language Model on different architectures.
|
31 |
+
min_range (int): minimum range width. Should always be at least 2 for numerical
|
32 |
+
stability. Use this to avoid pathological behavior is a value
|
33 |
+
that is expected to be rare actually happens in real life.
|
34 |
+
check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
|
35 |
+
"""
|
36 |
+
pdf = pdf.detach()
|
37 |
+
if roundoff:
|
38 |
+
pdf = (pdf / roundoff).floor() * roundoff
|
39 |
+
# interpolate with uniform distribution to achieve desired minimum probability.
|
40 |
+
total_range = 2 ** total_range_bits
|
41 |
+
cardinality = len(pdf)
|
42 |
+
alpha = min_range * cardinality / total_range
|
43 |
+
assert alpha <= 1, "you must reduce min_range"
|
44 |
+
ranges = (((1 - alpha) * total_range) * pdf).floor().long()
|
45 |
+
ranges += min_range
|
46 |
+
quantized_cdf = torch.cumsum(ranges, dim=-1)
|
47 |
+
if min_range < 2:
|
48 |
+
raise ValueError("min_range must be at least 2.")
|
49 |
+
if check:
|
50 |
+
assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
|
51 |
+
if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
|
52 |
+
raise ValueError("You must increase your total_range_bits.")
|
53 |
+
return quantized_cdf
|
54 |
+
|
55 |
+
|
56 |
+
class ArithmeticCoder:
|
57 |
+
"""ArithmeticCoder,
|
58 |
+
Let us take a distribution `p` over `N` symbols, and assume we have a stream
|
59 |
+
of random variables `s_t` sampled from `p`. Let us assume that we have a budget
|
60 |
+
of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
|
61 |
+
corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
|
62 |
+
sequence `(s_t)` by doing the following:
|
63 |
+
|
64 |
+
1) Initialize the current range to` [0 ** 2 B - 1]`.
|
65 |
+
2) For each time step t, split the current range into contiguous chunks,
|
66 |
+
one for each possible outcome, with size roughly proportional to `p`.
|
67 |
+
For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
|
68 |
+
would be `{[0, 2], [3, 3]}`.
|
69 |
+
3) Select the chunk corresponding to `s_t`, and replace the current range with this.
|
70 |
+
4) When done encoding all the values, just select any value remaining in the range.
|
71 |
+
|
72 |
+
You will notice that this procedure can fail: for instance if at any point in time
|
73 |
+
the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
|
74 |
+
possible outcome. Intuitively, the more likely a value is, the less the range width
|
75 |
+
will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
|
76 |
+
coding scheme, likely outcomes would take less bits, and more of them can be coded
|
77 |
+
with a fixed budget.
|
78 |
+
|
79 |
+
In practice, we do not know `B` ahead of time, but we have a way to inject new bits
|
80 |
+
when the current range decreases below a given limit (given by `total_range_bits`), without
|
81 |
+
having to redo all the computations. If we encode mostly likely values, we will seldom
|
82 |
+
need to inject new bits, but a single rare value can deplete our stock of entropy!
|
83 |
+
|
84 |
+
In this explanation, we assumed that the distribution `p` was constant. In fact, the present
|
85 |
+
code works for any sequence `(p_t)` possibly different for each timestep.
|
86 |
+
We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
|
87 |
+
the KL between the true distribution and `p_t`, the most efficient the coding will be.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
fo (IO[bytes]): file-like object to which the bytes will be written to.
|
91 |
+
total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
|
92 |
+
Any time the current range width fall under this limit, new bits will
|
93 |
+
be injected to rescale the initial range.
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
|
97 |
+
assert total_range_bits <= 30
|
98 |
+
self.total_range_bits = total_range_bits
|
99 |
+
self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
|
100 |
+
self.low: int = 0
|
101 |
+
self.high: int = 0
|
102 |
+
self.max_bit: int = -1
|
103 |
+
self._dbg: tp.List[tp.Any] = []
|
104 |
+
self._dbg2: tp.List[tp.Any] = []
|
105 |
+
|
106 |
+
@property
|
107 |
+
def delta(self) -> int:
|
108 |
+
"""Return the current range width."""
|
109 |
+
return self.high - self.low + 1
|
110 |
+
|
111 |
+
def _flush_common_prefix(self):
|
112 |
+
# If self.low and self.high start with the sames bits,
|
113 |
+
# those won't change anymore as we always just increase the range
|
114 |
+
# by powers of 2, and we can flush them out to the bit stream.
|
115 |
+
assert self.high >= self.low, (self.low, self.high)
|
116 |
+
assert self.high < 2 ** (self.max_bit + 1)
|
117 |
+
while self.max_bit >= 0:
|
118 |
+
b1 = self.low >> self.max_bit
|
119 |
+
b2 = self.high >> self.max_bit
|
120 |
+
if b1 == b2:
|
121 |
+
self.low -= (b1 << self.max_bit)
|
122 |
+
self.high -= (b1 << self.max_bit)
|
123 |
+
assert self.high >= self.low, (self.high, self.low, self.max_bit)
|
124 |
+
assert self.low >= 0
|
125 |
+
self.max_bit -= 1
|
126 |
+
self.packer.push(b1)
|
127 |
+
else:
|
128 |
+
break
|
129 |
+
|
130 |
+
def push(self, symbol: int, quantized_cdf: torch.Tensor):
|
131 |
+
"""Push the given symbol on the stream, flushing out bits
|
132 |
+
if possible.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
symbol (int): symbol to encode with the AC.
|
136 |
+
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
|
137 |
+
to build this from your pdf estimate.
|
138 |
+
"""
|
139 |
+
while self.delta < 2 ** self.total_range_bits:
|
140 |
+
self.low *= 2
|
141 |
+
self.high = self.high * 2 + 1
|
142 |
+
self.max_bit += 1
|
143 |
+
|
144 |
+
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
|
145 |
+
range_high = quantized_cdf[symbol].item() - 1
|
146 |
+
effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
|
147 |
+
effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
|
148 |
+
assert self.low <= self.high
|
149 |
+
self.high = self.low + effective_high
|
150 |
+
self.low = self.low + effective_low
|
151 |
+
assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
|
152 |
+
self._dbg.append((self.low, self.high))
|
153 |
+
self._dbg2.append((self.low, self.high))
|
154 |
+
outs = self._flush_common_prefix()
|
155 |
+
assert self.low <= self.high
|
156 |
+
assert self.max_bit >= -1
|
157 |
+
assert self.max_bit <= 61, self.max_bit
|
158 |
+
return outs
|
159 |
+
|
160 |
+
def flush(self):
|
161 |
+
"""Flush the remaining information to the stream.
|
162 |
+
"""
|
163 |
+
while self.max_bit >= 0:
|
164 |
+
b1 = (self.low >> self.max_bit) & 1
|
165 |
+
self.packer.push(b1)
|
166 |
+
self.max_bit -= 1
|
167 |
+
self.packer.flush()
|
168 |
+
|
169 |
+
|
170 |
+
class ArithmeticDecoder:
|
171 |
+
"""ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
|
172 |
+
|
173 |
+
Note that this must be called with **exactly** the same parameters and sequence
|
174 |
+
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
|
175 |
+
|
176 |
+
If the AC encoder current range is [L, H], with `L` and `H` having the some common
|
177 |
+
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
|
178 |
+
For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
|
179 |
+
`[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
|
180 |
+
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
|
181 |
+
At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
|
182 |
+
and we will need to read new bits from the stream and repeat the process.
|
183 |
+
|
184 |
+
"""
|
185 |
+
def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
|
186 |
+
self.total_range_bits = total_range_bits
|
187 |
+
self.low: int = 0
|
188 |
+
self.high: int = 0
|
189 |
+
self.current: int = 0
|
190 |
+
self.max_bit: int = -1
|
191 |
+
self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
|
192 |
+
# Following is for debugging
|
193 |
+
self._dbg: tp.List[tp.Any] = []
|
194 |
+
self._dbg2: tp.List[tp.Any] = []
|
195 |
+
self._last: tp.Any = None
|
196 |
+
|
197 |
+
@property
|
198 |
+
def delta(self) -> int:
|
199 |
+
return self.high - self.low + 1
|
200 |
+
|
201 |
+
def _flush_common_prefix(self):
|
202 |
+
# Given the current range [L, H], if both have a common prefix,
|
203 |
+
# we know we can remove it from our representation to avoid handling large numbers.
|
204 |
+
while self.max_bit >= 0:
|
205 |
+
b1 = self.low >> self.max_bit
|
206 |
+
b2 = self.high >> self.max_bit
|
207 |
+
if b1 == b2:
|
208 |
+
self.low -= (b1 << self.max_bit)
|
209 |
+
self.high -= (b1 << self.max_bit)
|
210 |
+
self.current -= (b1 << self.max_bit)
|
211 |
+
assert self.high >= self.low
|
212 |
+
assert self.low >= 0
|
213 |
+
self.max_bit -= 1
|
214 |
+
else:
|
215 |
+
break
|
216 |
+
|
217 |
+
def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
|
218 |
+
"""Pull a symbol, reading as many bits from the stream as required.
|
219 |
+
This returns `None` when the stream has been exhausted.
|
220 |
+
|
221 |
+
Args:
|
222 |
+
quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
|
223 |
+
to build this from your pdf estimate. This must be **exatly**
|
224 |
+
the same cdf as the one used at encoding time.
|
225 |
+
"""
|
226 |
+
while self.delta < 2 ** self.total_range_bits:
|
227 |
+
bit = self.unpacker.pull()
|
228 |
+
if bit is None:
|
229 |
+
return None
|
230 |
+
self.low *= 2
|
231 |
+
self.high = self.high * 2 + 1
|
232 |
+
self.current = self.current * 2 + bit
|
233 |
+
self.max_bit += 1
|
234 |
+
|
235 |
+
def bin_search(low_idx: int, high_idx: int):
|
236 |
+
# Binary search is not just for coding interviews :)
|
237 |
+
if high_idx < low_idx:
|
238 |
+
raise RuntimeError("Binary search failed")
|
239 |
+
mid = (low_idx + high_idx) // 2
|
240 |
+
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
|
241 |
+
range_high = quantized_cdf[mid].item() - 1
|
242 |
+
effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
|
243 |
+
effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
|
244 |
+
low = effective_low + self.low
|
245 |
+
high = effective_high + self.low
|
246 |
+
if self.current >= low:
|
247 |
+
if self.current <= high:
|
248 |
+
return (mid, low, high, self.current)
|
249 |
+
else:
|
250 |
+
return bin_search(mid + 1, high_idx)
|
251 |
+
else:
|
252 |
+
return bin_search(low_idx, mid - 1)
|
253 |
+
|
254 |
+
self._last = (self.low, self.high, self.current, self.max_bit)
|
255 |
+
sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
|
256 |
+
self._dbg.append((self.low, self.high, self.current))
|
257 |
+
self._flush_common_prefix()
|
258 |
+
self._dbg2.append((self.low, self.high, self.current))
|
259 |
+
|
260 |
+
return sym
|
261 |
+
|
262 |
+
|
263 |
+
def test():
|
264 |
+
torch.manual_seed(1234)
|
265 |
+
random.seed(1234)
|
266 |
+
for _ in range(4):
|
267 |
+
pdfs = []
|
268 |
+
cardinality = random.randrange(4000)
|
269 |
+
steps = random.randrange(100, 500)
|
270 |
+
fo = io.BytesIO()
|
271 |
+
encoder = ArithmeticCoder(fo)
|
272 |
+
symbols = []
|
273 |
+
for step in range(steps):
|
274 |
+
pdf = torch.softmax(torch.randn(cardinality), dim=0)
|
275 |
+
pdfs.append(pdf)
|
276 |
+
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
|
277 |
+
symbol = torch.multinomial(pdf, 1).item()
|
278 |
+
symbols.append(symbol)
|
279 |
+
encoder.push(symbol, q_cdf)
|
280 |
+
encoder.flush()
|
281 |
+
|
282 |
+
fo.seek(0)
|
283 |
+
decoder = ArithmeticDecoder(fo)
|
284 |
+
for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
|
285 |
+
q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
|
286 |
+
decoded_symbol = decoder.pull(q_cdf)
|
287 |
+
assert decoded_symbol == symbol, idx
|
288 |
+
assert decoder.pull(torch.zeros(1)) is None
|
289 |
+
|
290 |
+
|
291 |
+
if __name__ == "__main__":
|
292 |
+
test()
|
speechtokenizer/quantization/core_vq.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
# This implementation is inspired from
|
8 |
+
# https://github.com/lucidrains/vector-quantize-pytorch
|
9 |
+
# which is released under MIT License. Hereafter, the original license:
|
10 |
+
# MIT License
|
11 |
+
#
|
12 |
+
# Copyright (c) 2020 Phil Wang
|
13 |
+
#
|
14 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
15 |
+
# of this software and associated documentation files (the "Software"), to deal
|
16 |
+
# in the Software without restriction, including without limitation the rights
|
17 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
18 |
+
# copies of the Software, and to permit persons to whom the Software is
|
19 |
+
# furnished to do so, subject to the following conditions:
|
20 |
+
#
|
21 |
+
# The above copyright notice and this permission notice shall be included in all
|
22 |
+
# copies or substantial portions of the Software.
|
23 |
+
#
|
24 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
25 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
26 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
27 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
28 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
29 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
30 |
+
# SOFTWARE.
|
31 |
+
|
32 |
+
"""Core vector quantization implementation."""
|
33 |
+
import typing as tp
|
34 |
+
|
35 |
+
from einops import rearrange, repeat
|
36 |
+
import torch
|
37 |
+
from torch import nn
|
38 |
+
import torch.nn.functional as F
|
39 |
+
|
40 |
+
from .distrib import broadcast_tensors, rank
|
41 |
+
|
42 |
+
|
43 |
+
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
44 |
+
return val if val is not None else d
|
45 |
+
|
46 |
+
|
47 |
+
def ema_inplace(moving_avg, new, decay: float):
|
48 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
49 |
+
|
50 |
+
|
51 |
+
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
52 |
+
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
53 |
+
|
54 |
+
|
55 |
+
def uniform_init(*shape: int):
|
56 |
+
t = torch.empty(shape)
|
57 |
+
nn.init.kaiming_uniform_(t)
|
58 |
+
return t
|
59 |
+
|
60 |
+
|
61 |
+
def sample_vectors(samples, num: int):
|
62 |
+
num_samples, device = samples.shape[0], samples.device
|
63 |
+
|
64 |
+
if num_samples >= num:
|
65 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
66 |
+
else:
|
67 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
68 |
+
|
69 |
+
return samples[indices]
|
70 |
+
|
71 |
+
|
72 |
+
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
73 |
+
dim, dtype = samples.shape[-1], samples.dtype
|
74 |
+
|
75 |
+
means = sample_vectors(samples, num_clusters)
|
76 |
+
|
77 |
+
for _ in range(num_iters):
|
78 |
+
diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
|
79 |
+
dists = -(diffs**2).sum(dim=-1)
|
80 |
+
|
81 |
+
buckets = dists.max(dim=-1).indices
|
82 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
83 |
+
zero_mask = bins == 0
|
84 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
85 |
+
|
86 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
87 |
+
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
88 |
+
new_means = new_means / bins_min_clamped[..., None]
|
89 |
+
|
90 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
91 |
+
|
92 |
+
return means, bins
|
93 |
+
|
94 |
+
|
95 |
+
class EuclideanCodebook(nn.Module):
|
96 |
+
"""Codebook with Euclidean distance.
|
97 |
+
Args:
|
98 |
+
dim (int): Dimension.
|
99 |
+
codebook_size (int): Codebook size.
|
100 |
+
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
101 |
+
If set to true, run the k-means algorithm on the first training batch and use
|
102 |
+
the learned centroids as initialization.
|
103 |
+
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
104 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
105 |
+
epsilon (float): Epsilon value for numerical stability.
|
106 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
107 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
108 |
+
randomly selected vector from the current batch.
|
109 |
+
"""
|
110 |
+
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
dim: int,
|
114 |
+
codebook_size: int,
|
115 |
+
kmeans_init: int = False,
|
116 |
+
kmeans_iters: int = 10,
|
117 |
+
decay: float = 0.99,
|
118 |
+
epsilon: float = 1e-5,
|
119 |
+
threshold_ema_dead_code: int = 2,
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.decay = decay
|
123 |
+
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
|
124 |
+
uniform_init if not kmeans_init else torch.zeros
|
125 |
+
)
|
126 |
+
embed = init_fn(codebook_size, dim)
|
127 |
+
|
128 |
+
self.codebook_size = codebook_size
|
129 |
+
|
130 |
+
self.kmeans_iters = kmeans_iters
|
131 |
+
self.epsilon = epsilon
|
132 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
133 |
+
|
134 |
+
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
135 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
136 |
+
self.register_buffer("embed", embed)
|
137 |
+
self.register_buffer("embed_avg", embed.clone())
|
138 |
+
|
139 |
+
@torch.jit.ignore
|
140 |
+
def init_embed_(self, data):
|
141 |
+
if self.inited:
|
142 |
+
return
|
143 |
+
|
144 |
+
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
145 |
+
self.embed.data.copy_(embed)
|
146 |
+
self.embed_avg.data.copy_(embed.clone())
|
147 |
+
self.cluster_size.data.copy_(cluster_size)
|
148 |
+
self.inited.data.copy_(torch.Tensor([True]))
|
149 |
+
# Make sure all buffers across workers are in sync after initialization
|
150 |
+
# broadcast_tensors(self.buffers())
|
151 |
+
|
152 |
+
def replace_(self, samples, mask):
|
153 |
+
modified_codebook = torch.where(
|
154 |
+
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
155 |
+
)
|
156 |
+
self.embed.data.copy_(modified_codebook)
|
157 |
+
|
158 |
+
def expire_codes_(self, batch_samples):
|
159 |
+
if self.threshold_ema_dead_code == 0:
|
160 |
+
return
|
161 |
+
|
162 |
+
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
163 |
+
if not torch.any(expired_codes):
|
164 |
+
return
|
165 |
+
|
166 |
+
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
167 |
+
self.replace_(batch_samples, mask=expired_codes)
|
168 |
+
# broadcast_tensors(self.buffers())
|
169 |
+
|
170 |
+
def preprocess(self, x):
|
171 |
+
x = rearrange(x, "... d -> (...) d")
|
172 |
+
return x
|
173 |
+
|
174 |
+
def quantize(self, x):
|
175 |
+
embed = self.embed.t()
|
176 |
+
dist = -(
|
177 |
+
x.pow(2).sum(1, keepdim=True)
|
178 |
+
- 2 * x @ embed
|
179 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
180 |
+
)
|
181 |
+
embed_ind = dist.max(dim=-1).indices
|
182 |
+
return embed_ind
|
183 |
+
|
184 |
+
def postprocess_emb(self, embed_ind, shape):
|
185 |
+
return embed_ind.view(*shape[:-1])
|
186 |
+
|
187 |
+
def dequantize(self, embed_ind):
|
188 |
+
quantize = F.embedding(embed_ind, self.embed)
|
189 |
+
return quantize
|
190 |
+
|
191 |
+
def encode(self, x):
|
192 |
+
shape = x.shape
|
193 |
+
# pre-process
|
194 |
+
x = self.preprocess(x)
|
195 |
+
# quantize
|
196 |
+
embed_ind = self.quantize(x)
|
197 |
+
# post-process
|
198 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
199 |
+
return embed_ind
|
200 |
+
|
201 |
+
def decode(self, embed_ind):
|
202 |
+
quantize = self.dequantize(embed_ind)
|
203 |
+
return quantize
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
shape, dtype = x.shape, x.dtype
|
207 |
+
x = self.preprocess(x)
|
208 |
+
|
209 |
+
# self.init_embed_(x)
|
210 |
+
|
211 |
+
embed_ind = self.quantize(x)
|
212 |
+
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
213 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
214 |
+
quantize = self.dequantize(embed_ind)
|
215 |
+
|
216 |
+
# if self.training:
|
217 |
+
# # We do the expiry of code at that point as buffers are in sync
|
218 |
+
# # and all the workers will take the same decision.
|
219 |
+
# self.expire_codes_(x)
|
220 |
+
# ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
221 |
+
# embed_sum = x.t() @ embed_onehot
|
222 |
+
# ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
223 |
+
# cluster_size = (
|
224 |
+
# laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
225 |
+
# * self.cluster_size.sum()
|
226 |
+
# )
|
227 |
+
# embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
228 |
+
# self.embed.data.copy_(embed_normalized)
|
229 |
+
|
230 |
+
return quantize, embed_ind
|
231 |
+
|
232 |
+
|
233 |
+
class VectorQuantization(nn.Module):
|
234 |
+
"""Vector quantization implementation.
|
235 |
+
Currently supports only euclidean distance.
|
236 |
+
Args:
|
237 |
+
dim (int): Dimension
|
238 |
+
codebook_size (int): Codebook size
|
239 |
+
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
240 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
241 |
+
epsilon (float): Epsilon value for numerical stability.
|
242 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
243 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
244 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
245 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
246 |
+
randomly selected vector from the current batch.
|
247 |
+
commitment_weight (float): Weight for commitment loss.
|
248 |
+
"""
|
249 |
+
|
250 |
+
def __init__(
|
251 |
+
self,
|
252 |
+
dim: int,
|
253 |
+
codebook_size: int,
|
254 |
+
codebook_dim: tp.Optional[int] = None,
|
255 |
+
decay: float = 0.99,
|
256 |
+
epsilon: float = 1e-5,
|
257 |
+
kmeans_init: bool = True,
|
258 |
+
kmeans_iters: int = 50,
|
259 |
+
threshold_ema_dead_code: int = 2,
|
260 |
+
commitment_weight: float = 1.0,
|
261 |
+
):
|
262 |
+
super().__init__()
|
263 |
+
_codebook_dim: int = default(codebook_dim, dim)
|
264 |
+
|
265 |
+
requires_projection = _codebook_dim != dim
|
266 |
+
self.project_in = (
|
267 |
+
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
268 |
+
)
|
269 |
+
self.project_out = (
|
270 |
+
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
271 |
+
)
|
272 |
+
|
273 |
+
self.epsilon = epsilon
|
274 |
+
self.commitment_weight = commitment_weight
|
275 |
+
|
276 |
+
self._codebook = EuclideanCodebook(
|
277 |
+
dim=_codebook_dim,
|
278 |
+
codebook_size=codebook_size,
|
279 |
+
kmeans_init=kmeans_init,
|
280 |
+
kmeans_iters=kmeans_iters,
|
281 |
+
decay=decay,
|
282 |
+
epsilon=epsilon,
|
283 |
+
threshold_ema_dead_code=threshold_ema_dead_code,
|
284 |
+
)
|
285 |
+
self.codebook_size = codebook_size
|
286 |
+
|
287 |
+
@property
|
288 |
+
def codebook(self):
|
289 |
+
return self._codebook.embed
|
290 |
+
|
291 |
+
def encode(self, x):
|
292 |
+
x = rearrange(x, "b d n -> b n d")
|
293 |
+
x = self.project_in(x)
|
294 |
+
embed_in = self._codebook.encode(x)
|
295 |
+
return embed_in
|
296 |
+
|
297 |
+
def decode(self, embed_ind):
|
298 |
+
quantize = self._codebook.decode(embed_ind)
|
299 |
+
quantize = self.project_out(quantize)
|
300 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
301 |
+
return quantize
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
device = x.device
|
305 |
+
x = rearrange(x, "b d n -> b n d")
|
306 |
+
x = self.project_in(x)
|
307 |
+
|
308 |
+
quantize, embed_ind = self._codebook(x)
|
309 |
+
|
310 |
+
if self.training:
|
311 |
+
quantize = x + (quantize - x).detach()
|
312 |
+
|
313 |
+
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
314 |
+
|
315 |
+
if self.training:
|
316 |
+
if self.commitment_weight > 0:
|
317 |
+
commit_loss = F.mse_loss(quantize.detach(), x)
|
318 |
+
loss = loss + commit_loss * self.commitment_weight
|
319 |
+
|
320 |
+
quantize = self.project_out(quantize)
|
321 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
322 |
+
return quantize, embed_ind, loss
|
323 |
+
|
324 |
+
|
325 |
+
class ResidualVectorQuantization(nn.Module):
|
326 |
+
"""Residual vector quantization implementation.
|
327 |
+
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
328 |
+
"""
|
329 |
+
|
330 |
+
def __init__(self, *, num_quantizers, **kwargs):
|
331 |
+
super().__init__()
|
332 |
+
self.layers = nn.ModuleList(
|
333 |
+
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
334 |
+
)
|
335 |
+
|
336 |
+
def forward(
|
337 |
+
self,
|
338 |
+
x,
|
339 |
+
n_q: tp.Optional[int] = None,
|
340 |
+
layers: tp.Optional[list] = None,
|
341 |
+
st: tp.Optional[int] = None,
|
342 |
+
):
|
343 |
+
quantized_out = 0.0
|
344 |
+
residual = x
|
345 |
+
|
346 |
+
all_losses = []
|
347 |
+
all_indices = []
|
348 |
+
out_quantized = []
|
349 |
+
n_q = n_q or len(self.layers)
|
350 |
+
st = st or 0
|
351 |
+
for i, layer in enumerate(self.layers[st:n_q]):
|
352 |
+
quantized, indices, loss = layer(residual)
|
353 |
+
residual = residual - quantized
|
354 |
+
quantized_out = quantized_out + quantized
|
355 |
+
|
356 |
+
all_indices.append(indices)
|
357 |
+
all_losses.append(loss)
|
358 |
+
if layers and i in layers:
|
359 |
+
out_quantized.append(quantized)
|
360 |
+
|
361 |
+
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
362 |
+
return quantized_out, out_indices, out_losses, out_quantized
|
363 |
+
|
364 |
+
def encode(
|
365 |
+
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
366 |
+
) -> torch.Tensor:
|
367 |
+
residual = x
|
368 |
+
all_indices = []
|
369 |
+
n_q = n_q or len(self.layers)
|
370 |
+
st = st or 0
|
371 |
+
for layer in self.layers[st:n_q]:
|
372 |
+
indices = layer.encode(residual)
|
373 |
+
quantized = layer.decode(indices)
|
374 |
+
residual = residual - quantized
|
375 |
+
all_indices.append(indices)
|
376 |
+
out_indices = torch.stack(all_indices)
|
377 |
+
return out_indices
|
378 |
+
|
379 |
+
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
|
380 |
+
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
381 |
+
for i, indices in enumerate(q_indices):
|
382 |
+
layer = self.layers[st + i]
|
383 |
+
quantized = layer.decode(indices)
|
384 |
+
quantized_out = quantized_out + quantized
|
385 |
+
return quantized_out
|
speechtokenizer/quantization/distrib.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""Torch distributed utilities."""
|
8 |
+
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
def rank():
|
15 |
+
if torch.distributed.is_initialized():
|
16 |
+
return torch.distributed.get_rank()
|
17 |
+
else:
|
18 |
+
return 0
|
19 |
+
|
20 |
+
|
21 |
+
def world_size():
|
22 |
+
if torch.distributed.is_initialized():
|
23 |
+
return torch.distributed.get_world_size()
|
24 |
+
else:
|
25 |
+
return 1
|
26 |
+
|
27 |
+
|
28 |
+
def is_distributed():
|
29 |
+
return world_size() > 1
|
30 |
+
|
31 |
+
|
32 |
+
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
|
33 |
+
if is_distributed():
|
34 |
+
return torch.distributed.all_reduce(tensor, op)
|
35 |
+
|
36 |
+
|
37 |
+
def _is_complex_or_float(tensor):
|
38 |
+
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
39 |
+
|
40 |
+
|
41 |
+
def _check_number_of_params(params: tp.List[torch.Tensor]):
|
42 |
+
# utility function to check that the number of params in all workers is the same,
|
43 |
+
# and thus avoid a deadlock with distributed all reduce.
|
44 |
+
if not is_distributed() or not params:
|
45 |
+
return
|
46 |
+
#print('params[0].device ', params[0].device)
|
47 |
+
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
|
48 |
+
all_reduce(tensor)
|
49 |
+
if tensor.item() != len(params) * world_size():
|
50 |
+
# If not all the workers have the same number, for at least one of them,
|
51 |
+
# this inequality will be verified.
|
52 |
+
raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
|
53 |
+
"at least one worker has a different one.")
|
54 |
+
|
55 |
+
|
56 |
+
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
|
57 |
+
"""Broadcast the tensors from the given parameters to all workers.
|
58 |
+
This can be used to ensure that all workers have the same model to start with.
|
59 |
+
"""
|
60 |
+
if not is_distributed():
|
61 |
+
return
|
62 |
+
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
|
63 |
+
_check_number_of_params(tensors)
|
64 |
+
handles = []
|
65 |
+
for tensor in tensors:
|
66 |
+
# src = int(rank()) # added code
|
67 |
+
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
|
68 |
+
handles.append(handle)
|
69 |
+
for handle in handles:
|
70 |
+
handle.wait()
|
71 |
+
|
72 |
+
|
73 |
+
def sync_buffer(buffers, average=True):
|
74 |
+
"""
|
75 |
+
Sync grad for buffers. If average is False, broadcast instead of averaging.
|
76 |
+
"""
|
77 |
+
if not is_distributed():
|
78 |
+
return
|
79 |
+
handles = []
|
80 |
+
for buffer in buffers:
|
81 |
+
if torch.is_floating_point(buffer.data):
|
82 |
+
if average:
|
83 |
+
handle = torch.distributed.all_reduce(
|
84 |
+
buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
|
85 |
+
else:
|
86 |
+
handle = torch.distributed.broadcast(
|
87 |
+
buffer.data, src=0, async_op=True)
|
88 |
+
handles.append((buffer, handle))
|
89 |
+
for buffer, handle in handles:
|
90 |
+
handle.wait()
|
91 |
+
if average:
|
92 |
+
buffer.data /= world_size
|
93 |
+
|
94 |
+
|
95 |
+
def sync_grad(params):
|
96 |
+
"""
|
97 |
+
Simpler alternative to DistributedDataParallel, that doesn't rely
|
98 |
+
on any black magic. For simple models it can also be as fast.
|
99 |
+
Just call this on your model parameters after the call to backward!
|
100 |
+
"""
|
101 |
+
if not is_distributed():
|
102 |
+
return
|
103 |
+
handles = []
|
104 |
+
for p in params:
|
105 |
+
if p.grad is not None:
|
106 |
+
handle = torch.distributed.all_reduce(
|
107 |
+
p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
|
108 |
+
handles.append((p, handle))
|
109 |
+
for p, handle in handles:
|
110 |
+
handle.wait()
|
111 |
+
p.grad.data /= world_size()
|
112 |
+
|
113 |
+
|
114 |
+
def average_metrics(metrics: tp.Dict[str, float], count=1.):
|
115 |
+
"""Average a dictionary of metrics across all workers, using the optional
|
116 |
+
`count` as unormalized weight.
|
117 |
+
"""
|
118 |
+
if not is_distributed():
|
119 |
+
return metrics
|
120 |
+
keys, values = zip(*metrics.items())
|
121 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
122 |
+
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
|
123 |
+
tensor *= count
|
124 |
+
all_reduce(tensor)
|
125 |
+
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
|
126 |
+
return dict(zip(keys, averaged))
|