ordinaryaccount committed
Commit 26f400e · 1 Parent(s): 2b694f4
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. README.md +6 -6
  2. app.py +137 -0
  3. audios/1.wav +3 -0
  4. audios/2.wav +3 -0
  5. audios/3.wav +3 -0
  6. audios/4.wav +3 -0
  7. audios/5.wav +3 -0
  8. infer.py +175 -0
  9. models.py +299 -0
  10. requirements.txt +4 -0
  11. speechtokenizer/__init__.py +3 -0
  12. speechtokenizer/model.py +215 -0
  13. speechtokenizer/modules/__init__.py +21 -0
  14. speechtokenizer/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  15. speechtokenizer/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  16. speechtokenizer/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  17. speechtokenizer/modules/__pycache__/conv.cpython-310.pyc +0 -0
  18. speechtokenizer/modules/__pycache__/conv.cpython-311.pyc +0 -0
  19. speechtokenizer/modules/__pycache__/conv.cpython-39.pyc +0 -0
  20. speechtokenizer/modules/__pycache__/lstm.cpython-310.pyc +0 -0
  21. speechtokenizer/modules/__pycache__/lstm.cpython-311.pyc +0 -0
  22. speechtokenizer/modules/__pycache__/lstm.cpython-39.pyc +0 -0
  23. speechtokenizer/modules/__pycache__/norm.cpython-310.pyc +0 -0
  24. speechtokenizer/modules/__pycache__/norm.cpython-311.pyc +0 -0
  25. speechtokenizer/modules/__pycache__/norm.cpython-39.pyc +0 -0
  26. speechtokenizer/modules/__pycache__/seanet.cpython-310.pyc +0 -0
  27. speechtokenizer/modules/__pycache__/seanet.cpython-311.pyc +0 -0
  28. speechtokenizer/modules/__pycache__/seanet.cpython-39.pyc +0 -0
  29. speechtokenizer/modules/conv.py +252 -0
  30. speechtokenizer/modules/lstm.py +31 -0
  31. speechtokenizer/modules/norm.py +28 -0
  32. speechtokenizer/modules/seanet.py +408 -0
  33. speechtokenizer/pretrained_model/SpeechTokenizer.pt +3 -0
  34. speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json +47 -0
  35. speechtokenizer/quantization/__init__.py +8 -0
  36. speechtokenizer/quantization/__pycache__/__init__.cpython-310.pyc +0 -0
  37. speechtokenizer/quantization/__pycache__/__init__.cpython-311.pyc +0 -0
  38. speechtokenizer/quantization/__pycache__/__init__.cpython-39.pyc +0 -0
  39. speechtokenizer/quantization/__pycache__/core_vq.cpython-310.pyc +0 -0
  40. speechtokenizer/quantization/__pycache__/core_vq.cpython-311.pyc +0 -0
  41. speechtokenizer/quantization/__pycache__/core_vq.cpython-39.pyc +0 -0
  42. speechtokenizer/quantization/__pycache__/distrib.cpython-310.pyc +0 -0
  43. speechtokenizer/quantization/__pycache__/distrib.cpython-311.pyc +0 -0
  44. speechtokenizer/quantization/__pycache__/distrib.cpython-39.pyc +0 -0
  45. speechtokenizer/quantization/__pycache__/vq.cpython-310.pyc +0 -0
  46. speechtokenizer/quantization/__pycache__/vq.cpython-311.pyc +0 -0
  47. speechtokenizer/quantization/__pycache__/vq.cpython-39.pyc +0 -0
  48. speechtokenizer/quantization/ac.py +292 -0
  49. speechtokenizer/quantization/core_vq.py +385 -0
  50. speechtokenizer/quantization/distrib.py +126 -0
README.md CHANGED
@@ -1,14 +1,14 @@
 ---
 title: VoiceMark
-emoji: 📚
-colorFrom: blue
-colorTo: blue
+emoji: 📊
+colorFrom: purple
+colorTo: green
 sdk: gradio
-sdk_version: 5.30.0
+sdk_version: 5.16.0
 app_file: app.py
 pinned: false
-license: apache-2.0
-short_description: Zero-Shot Voice Cloning-Resistant Watermarking Approach Leve
+license: gpl
+short_description: Zero-Shot Voice Cloning-Resistant Watermarking
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,137 @@
import gradio as gr
import json
import os
import torchaudio
from infer import (
    WatermarkSolver,
    hamming_distance
)

# Predefined watermarks (instead of loading from a JSON file)
watermarks = {
    "VoiceMark": "1000010101010011",
    "Voice Cloning": "1111111001000010",
    "Speech Security": "1011101100001110",
    "Audio Watermarking": "0110110011100010",
    "Deep Learning": "0000100111111000",
    "Artificial Intelligence": "0010000100011111",
    "Hello World": "0001111101110001",
    "Happy New Year": "1101011011011101",
    "World Peace": "0011110010011110",
    "Good Morning": "0000001011000010",
}

# Initialize WatermarkSolver model
solver = WatermarkSolver()
solver.load_model(checkpoint_dir="./", checkpoint_name="voicemark.pth", strict=True)

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(
        "# VoiceMark: Zero-Shot Voice Cloning-Resistant Watermarking Approach Leveraging Speaker-Specific Latents"
    )
    with gr.Column():
        gr.Image(
            value="voicemark_overview.png",
            width=925,
            height=487,
            elem_id="overview_image",
            label="overview"
        )
        # Overview figure caption
        gr.HTML("<h3 style='text-align: center;'>The overall architecture of our proposed VoiceMark</h3>")

    # Step 1: Upload audio and select watermark
    gr.Markdown(
        """
        **Step 1**: Upload an audio file or select one from the provided samples, choose a watermark, and generate the watermarked audio.
        """
    )

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio", type="filepath")

            gr.Examples(
                examples=[
                    ["audios/1.wav"],
                    ["audios/2.wav"],
                    ["audios/3.wav"],
                    ["audios/4.wav"],
                    ["audios/5.wav"],
                ],
                inputs=audio_input,
                label="Sample Audios (Click to Use)"
            )

        with gr.Column():
            audio_output = gr.Audio(label="Watermarked Audio", type="filepath")
            watermark_list = gr.Dropdown(
                label="Select Watermark", choices=list(watermarks.keys()), interactive=True
            )
            add_watermark_button = gr.Button("Add Watermark to Audio")

    # Step 2: TTS tools demo links
    gr.Markdown(
        """
        **Step 2**: Download the generated watermarked audio, then use Zero-Shot Voice Cloning tools to generate the cloned audio. Some available tools are:
        - [CosyVoice2: Scalable Streaming Speech Synthesis with Large Language Models](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B)
        - [F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
        - [MaskGCT: Zero-Shot Text-to-Speech with Masked Generative Codec Transformer](https://huggingface.co/spaces/amphion/maskgct)
        """
    )

    # Step 3: Upload cloned audio to decode watermark
    gr.Markdown(
        """
        **Step 3**: Upload the cloned audio and decode your watermark.
        """
    )

    with gr.Row():
        decode_audio_input = gr.Audio(label="Upload Cloned Audio", type="filepath")
        with gr.Column():
            decoded_watermark_output = gr.Textbox(label="Decoded Watermark")
            decode_button = gr.Button("Decode Watermark")

    def process_audio(audio_path, watermark_text):
        if not audio_path:
            return "No audio selected. Please upload or select a sample."
        try:
            watermarked_audio = solver.infer_for_ui(
                audio_path, watermarks[watermark_text]
            )
            return watermarked_audio
        except ValueError as e:
            return str(e)

    add_watermark_button.click(
        process_audio,
        inputs=[audio_input, watermark_list],
        outputs=audio_output
    )

    def decode_watermark(audio_path):
        try:
            detect_prob, decoded_id = solver.decode_for_ui(audio_path)
            if detect_prob < 1e-2:
                return "No matching watermark found"
            closest_match = None
            min_distance = float("inf")
            for text, id_bin in watermarks.items():
                distance = hamming_distance(decoded_id, id_bin, base=16)
                if distance < min_distance:
                    closest_match = text
                    min_distance = distance
            if min_distance < 10:
                return closest_match
            return "No matching watermark found"
        except ValueError as e:
            return str(e)

    decode_button.click(
        decode_watermark, inputs=decode_audio_input, outputs=decoded_watermark_output
    )

# Launch the Gradio app
demo.launch()
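For readers who want the same round trip without the Gradio UI, here is a minimal sketch that uses only the functions app.py already imports from infer.py; the checkpoint location and sample path are assumed to match the Space layout shown above.

# Minimal embed-then-decode round trip outside the UI (assumes voicemark.pth and
# audios/1.wav are present in the working directory, as in this Space).
from infer import WatermarkSolver, hamming_distance

solver = WatermarkSolver()
solver.load_model(checkpoint_dir="./", checkpoint_name="voicemark.pth", strict=True)

message = "1000010101010011"  # the "VoiceMark" entry from the table in app.py
wm_path = solver.infer_for_ui("audios/1.wav", message)

detect_prob, decoded = solver.decode_for_ui(wm_path)
print(detect_prob, decoded, hamming_distance(decoded, message, base=16))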
audios/1.wav ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb97a4d6a89ce3cfff4d504de24ef4d79658cc68cc765da7c905b30ecb5e9f1d
size 160078
audios/2.wav ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d5cf698739bad50c02ab96375bed49594a0b1a4421406b6ed10edc432f09dc94
size 152078
audios/3.wav ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4673c2f94dec181e7603df7564af0778d6589a7d7f4878f30a9c42bdfc2324f7
size 160078
audios/4.wav ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10c63ef8e2866b7e4b98d15cbbc084da667b350ab4d084f338b656661b619ceb
size 98750
audios/5.wav ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:049b7061915306d541d2d8c059c6c9599502997562bc32120f16a873a55e499e
size 63072
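These five sample clips are stored as Git LFS pointers, so only their hashes and sizes appear in the diff. The demo feeds them through the same mono, 16 kHz path that WatermarkSolver uses; a small sketch of that preprocessing (the file name is assumed to be one of the samples above):

# Load a sample clip and match the solver's expected input format:
# first channel only, resampled to 16 kHz (see WatermarkSolver.infer_for_ui below).
import torchaudio

waveform, sr = torchaudio.load("audios/1.wav")
if sr != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)
waveform = waveform[:1].unsqueeze(1)  # shape: (1, 1, timesteps)
print(waveform.shape)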
infer.py ADDED
@@ -0,0 +1,175 @@
from models import SBW
import torchaudio
import torch
import os
import math


def hamming_distance(s1, s2, base=2):
    """
    Calculate the absolute difference between two strings interpreted in chunks of a specified base.
    Note: despite its name, this returns the sum of absolute differences between chunk values,
    not a strict bit-level Hamming distance.

    Args:
        s1 (str): The first binary string.
        s2 (str): The second binary string.
        base (int): The base to interpret the binary strings (e.g., 2 for binary, 16 for hexadecimal).

    Returns:
        int: The sum of absolute differences for corresponding chunks.

    Raises:
        ValueError: If the strings are not of equal length or invalid for the given base.
    """
    if len(s1) != len(s2):
        raise ValueError("Both strings must be of equal length")

    # Determine the chunk size for the given base
    chunk_size = int(math.log(base, 2))

    if len(s1) % chunk_size != 0 or len(s2) % chunk_size != 0:
        raise ValueError(
            f"Binary strings must be a multiple of {chunk_size} bits for base {base}"
        )

    # Split binary strings into chunks
    def to_chunks(binary_str):
        return [
            binary_str[i : i + chunk_size]
            for i in range(0, len(binary_str), chunk_size)
        ]

    chunks_s1 = to_chunks(s1)
    chunks_s2 = to_chunks(s2)

    # Convert chunks to integers and calculate absolute difference
    def absolute_difference_chunks(chunk1, chunk2):
        int1 = int(chunk1, 2)
        int2 = int(chunk2, 2)
        return abs(int1 - int2)

    return sum(
        absolute_difference_chunks(c1, c2) for c1, c2 in zip(chunks_s1, chunks_s2)
    )


def random_message(nbits: int, batch_size: int) -> torch.Tensor:
    """Return random message as 0/1 tensor."""
    if nbits == 0:
        return torch.tensor([])
    return torch.randint(0, 2, (batch_size, nbits))


def string_to_message(message_str: str, batch_size: int) -> torch.Tensor:
    """
    Convert a binary string to a message tensor.

    Args:
        message_str (str): A string of '0's and '1's.
        batch_size (int): The batch size for the output tensor.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, nbits) where nbits is the length of message_str.
    """
    # Convert the string to a list of integers (0 and 1)
    message_list = [int(bit) for bit in message_str]
    nbits = len(message_list)
    # Convert the list to a tensor and repeat it for the batch size
    message_tensor = torch.tensor(message_list).unsqueeze(0).repeat(batch_size, 1)
    return message_tensor


class WatermarkSolver:
    def __init__(self):
        self.device = 'cpu'
        self.sample_rate = 16000
        self.nbits = 16
        self.model = SBW()

    def load_model(self, checkpoint_dir, checkpoint_name, strict=False):
        """
        Load model weights from the named checkpoint in the checkpoint directory.
        """
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]
        if not checkpoint_files:
            print("No checkpoint files found in the directory.")
            return

        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        print(f"Loading model weights from {checkpoint_path}...")

        # Load the checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)

        # Load model state dict with strict=False to allow missing keys (semantic_encoder)
        model_state_dict = checkpoint["model_state_dict"]

        new_state_dict = {}
        for k, v in model_state_dict.items():
            new_key = k.replace("module.", "")
            new_state_dict[new_key] = v
        if not strict:
            new_state_dict = {
                k: v
                for k, v in new_state_dict.items()
                if k in self.model.state_dict()
                and self.model.state_dict()[k].shape == v.shape
            }
        self.model.load_state_dict(new_state_dict, strict=False)

        print("Model state dict loaded successfully.")
        self.epoch = checkpoint["epoch"]
        print(f"Checkpoint loaded from epoch {checkpoint['epoch']}.")

    def infer_for_ui(self, input_audio_path, watermark, output_audio_path="tmp"):
        message_str = watermark
        self.model.eval()

        waveform, sample_rate = torchaudio.load(input_audio_path)

        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        waveform = waveform[:1].to(self.device).unsqueeze(1)
        message = string_to_message(message_str=message_str, batch_size=1).to(
            self.device
        )

        with torch.no_grad():
            output = self.model(
                waveform,
                message=message,
            )
            y_wm = output["recon_wm"]

        os.makedirs(output_audio_path, exist_ok=True)
        watermarked_audio_path = os.path.join(
            output_audio_path,
            f"{os.path.splitext(os.path.basename(input_audio_path))[0]}_watermarked.wav",
        )
        torchaudio.save(watermarked_audio_path, y_wm.squeeze(1).cpu(), self.sample_rate)
        return watermarked_audio_path

    def decode_for_ui(self, input_audio_path):
        self.model.eval()

        waveform, sample_rate = torchaudio.load(input_audio_path)

        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        waveform = waveform[:1].to(self.device).unsqueeze(1)

        with torch.no_grad():
            detect_prob, detected_message_tensor, _ = self.model.detect_watermark(waveform)
            detected_id = "".join(
                map(str, detected_message_tensor.squeeze(0).cpu().numpy().tolist())
            )

        return detect_prob, detected_id
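The demo compares the decoded ID against each stored watermark with hamming_distance in base 16, i.e. in 4-bit chunks. A small worked example of what that metric returns (values chosen only for illustration):

# Chunk-wise distance used by the demo (base=16 -> 4-bit chunks).
from infer import hamming_distance

a = "1000010101010011"  # the "VoiceMark" ID
b = "1000010101010000"  # same ID with the last chunk flipped from 0011 to 0000
# chunks of a: 1000 0101 0101 0011 -> 8, 5, 5, 3
# chunks of b: 1000 0101 0101 0000 -> 8, 5, 5, 0
print(hamming_distance(a, b, base=16))  # |8-8| + |5-5| + |5-5| + |3-0| = 3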
models.py ADDED
@@ -0,0 +1,299 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Dict, Optional, Tuple
4
+
5
+ import torch.nn.functional as F
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class WMDetector(nn.Module):
11
+ """
12
+ Detect watermarks in an audio signal using a Transformer architecture,
13
+ where the watermark bits are split into chunks of nchunk_size bits each.
14
+ We assume nbits is a multiple of nchunk_size.
15
+ """
16
+
17
+ def __init__(
18
+ self, input_channels: int, nbits: int, nchunk_size: int, d_model: int = 512
19
+ ):
20
+ """
21
+ Args:
22
+ input_channels (int): Number of input channels in the audio feature (e.g., mel channels).
23
+ nbits (int): Total number of bits in the watermark, must be a multiple of nchunk_size.
24
+ d_model (int): Embedding dimension for the Transformer.
25
+ """
26
+ super().__init__()
27
+ self.nchunk_size = nchunk_size
28
+ assert nbits % nchunk_size == 0, "nbits must be a multiple of nchunk_size!"
29
+ self.nbits = nbits
30
+ self.d_model = d_model
31
+ # Number of chunks
32
+ self.nchunks = nbits // nchunk_size
33
+
34
+ # 1D convolution to map the input channels to d_model
35
+ self.embedding = nn.Conv1d(input_channels, d_model, kernel_size=1)
36
+
37
+ # Transformer encoder block
38
+ self.transformer = nn.TransformerEncoder(
39
+ nn.TransformerEncoderLayer(
40
+ d_model=d_model,
41
+ nhead=1,
42
+ dim_feedforward=d_model * 2,
43
+ activation="gelu",
44
+ batch_first=True,
45
+ ),
46
+ num_layers=8,
47
+ )
48
+
49
+ # A linear head for watermark presence detection (binary)
50
+ self.watermark_head = nn.Linear(d_model, 1)
51
+
52
+ # For each chunk, we perform a 2**nchunk_size-way classification
53
+ self.message_heads = nn.ModuleList(
54
+ nn.Linear(d_model, 2**nchunk_size) for _ in range(self.nchunks)
55
+ )
56
+
57
+ # Learnable embeddings for each byte chunk (instead of per bit)
58
+ # Shape: [nchunks, d_model]
59
+ self.nchunk_embeddings = nn.Parameter(torch.randn(self.nchunks, d_model))
60
+
61
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
62
+ """
63
+ Forward pass of the detector.
64
+
65
+ Returns:
66
+ logits (torch.Tensor): Watermark detection logits of shape [batch, seq_len].
67
+ chunk_logits (torch.Tensor): Byte-level classification logits of shape [batch, nchunks, 256].
68
+ """
69
+ batch_size, input_channels, time_steps = x.shape
70
+
71
+ # 1) Map [batch, in_channels, time_steps] → [batch, time_steps, d_model]
72
+ x = self.embedding(x).permute(0, 2, 1) # [batch, time_steps, d_model]
73
+
74
+ # 2) Prepend chunk embeddings at the beginning of the sequence
75
+ # [nchunks, d_model] → [1, nchunks, d_model] → [batch, nchunks, d_model]
76
+ nchunk_embeds = self.nchunk_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
77
+ # Concatenate along the time dimension: [batch, nchunks + time_steps, d_model]
78
+ x = torch.cat([nchunk_embeds, x], dim=1)
79
+
80
+ # 3) Pass through the Transformer
81
+ x = self.transformer(x)
82
+ # x has shape [batch, nchunks + time_steps, d_model]
83
+
84
+ # (a) Watermark presence detection: skip the first nchunks
85
+ detection_part = x[:, self.nchunks :] # [batch, time_steps, d_model]
86
+ logits = self.watermark_head(detection_part).squeeze(-1) # [batch, time_steps]
87
+
88
+ # (b) Message decoding: use the first nchunks
89
+ message_part = x[:, : self.nchunks] # [batch, nchunks, d_model]
90
+ chunk_logits_list = []
91
+ for i, head in enumerate(self.message_heads):
92
+ # message_part[:, i, :] has shape [batch, d_model]
93
+ # each head outputs [batch, 256]
94
+ chunk_vec = message_part[:, i, :]
95
+ chunk_logits_list.append(head(chunk_vec).unsqueeze(1)) # [batch, 1, 256]
96
+
97
+ # Concatenate along the 'nchunks' dimension → [batch, nchunks, 256]
98
+ chunk_logits = torch.cat(chunk_logits_list, dim=1)
99
+
100
+ return logits, chunk_logits
101
+
102
+ def detect_watermark(
103
+ self,
104
+ x: torch.Tensor,
105
+ sample_rate: Optional[int] = None,
106
+ threshold: float = 0.5,
107
+ ) -> Tuple[float, torch.Tensor, torch.Tensor]:
108
+ """
109
+ A convenience function for inference.
110
+
111
+ Returns:
112
+ detect_prob (float): Probability that the audio is watermarked.
113
+ binary_message (torch.Tensor): The recovered message of shape [batch, nbits] (binary).
114
+ detected (torch.Tensor): The sigmoid values of the per-timestep watermark detection.
115
+ """
116
+ logits, chunk_logits = self.forward(x)
117
+ # logits: [batch, seq_len] → raw logits for watermark presence detection
118
+ # chunk_logits: [batch, nchunks, 256] → classification logits for each byte
119
+
120
+ # (1) Compute watermark detection probability
121
+ detected = torch.sigmoid(logits) # [batch, seq_len]
122
+ detect_prob = detected.mean(dim=-1).cpu().item()
123
+
124
+ # (2) Decode the message: chunk_logits has shape [batch, nchunks, 256]
125
+ chunk_probs = F.softmax(chunk_logits, dim=-1) # [batch, nchunks, 256]
126
+ chunk_indices = torch.argmax(
127
+ chunk_probs, dim=-1
128
+ ) # [batch, nchunks], each in [0..255]
129
+ # (3) Convert each byte back to 8 bits
130
+ # Finally, assemble into a [batch, nbits] binary tensor
131
+ binary_message = []
132
+ for i in range(self.nchunks):
133
+ chunk_val = chunk_indices[:, i] # [batch]
134
+ # Extract 8 bits from the integer (0..255)
135
+ chunk_bits = []
136
+ for b in range(self.nchunk_size):
137
+ bit_b = (chunk_val >> b) & 1 # get bit b
138
+ chunk_bits.append(bit_b.unsqueeze(-1))
139
+ # Concatenate bits to shape [batch, 8]
140
+ chunk_bits = torch.cat(chunk_bits, dim=-1)
141
+ binary_message.append(chunk_bits)
142
+
143
+ # Concatenate all bytes → [batch, nbits]
144
+ binary_message = torch.cat(binary_message, dim=-1)
145
+
146
+ return detect_prob, binary_message, detected
147
+
148
+
149
+
150
+ class WMEmbedder(nn.Module):
151
+ """
152
+ A class that takes a secret message, processes it into chunk embeddings
153
+ (as a small sequence), and uses a TransformerDecoder to do cross-attention
154
+ between the original hidden (target) and the watermark tokens (memory).
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ nbits: int, # total bits in the secret message
160
+ input_dim: int, # the input dimension (e.g. audio feature dimension)
161
+ nchunk_size: int,
162
+ hidden_dim: int = 256,
163
+ num_heads: int = 1,
164
+ num_layers: int = 4,
165
+ ):
166
+ super().__init__()
167
+ self.nchunk_size = nchunk_size
168
+ assert nbits % nchunk_size == 0, "nbits must be a multiple of nchunk_size!"
169
+ self.nbits = nbits
170
+ self.nchunks = nbits // nchunk_size # how many chunks
171
+
172
+ # Each chunk (0..2^nchunk_size - 1) maps to an embedding of size [hidden_dim]
173
+ self.msg_embeddings = nn.ModuleList(
174
+ nn.Embedding(2**nchunk_size, hidden_dim) for _ in range(self.nchunks)
175
+ )
176
+
177
+ # Linear to project [input_dim] -> [hidden_dim]
178
+ self.input_projection = nn.Linear(input_dim, hidden_dim)
179
+
180
+ # TransformerDecoder for cross-attention
181
+ # d_model=hidden_dim, so the decoder expects [b, seq_len, hidden_dim] as tgt
182
+ # and [b, memory_len, hidden_dim] as memory
183
+ decoder_layer = nn.TransformerDecoderLayer(
184
+ d_model=hidden_dim,
185
+ nhead=num_heads,
186
+ dim_feedforward=2 * hidden_dim,
187
+ activation="gelu",
188
+ batch_first=True, # so shape is [batch, seq, feature]
189
+ )
190
+ self.transformer_decoder = nn.TransformerDecoder(
191
+ decoder_layer, num_layers=num_layers
192
+ )
193
+
194
+ # Project [hidden_dim] -> [input_dim]
195
+ # self.output_projection1 = nn.Linear(hidden_dim * 2, hidden_dim)
196
+ self.output_projection = nn.Linear(hidden_dim, input_dim)
197
+
198
+ def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
199
+ """
200
+ Args:
201
+ hidden: [batch, input_dim, seq_len]
202
+ msg: [batch, nbits]
203
+ Returns:
204
+ A tensor [batch, input_dim, seq_len] with watermark injected.
205
+ """
206
+ b, in_dim, seq_len = hidden.shape
207
+
208
+ # 1) Project input features to [b, seq_len, hidden_dim]
209
+ hidden_projected = self.input_projection(
210
+ hidden.permute(0, 2, 1)
211
+ ) # => [b, seq_len, hidden_dim]
212
+
213
+ # 2) Convert the msg bits into a sequence of chunk embeddings
214
+ # We keep each chunk as one token => [b, nchunks, hidden_dim]
215
+ chunk_emb_list = []
216
+ for i in range(self.nchunks):
217
+ # msg[:, i*nchunk_size : (i+1)*nchunk_size] => shape [b, nchunk_size]
218
+ chunk_bits = msg[:, i * self.nchunk_size : (i + 1) * self.nchunk_size]
219
+ chunk_val = torch.zeros_like(chunk_bits[:, 0]) # shape [b]
220
+ for bit_idx in range(self.nchunk_size):
221
+ # shift bits
222
+ chunk_val += chunk_bits[:, bit_idx] << bit_idx
223
+
224
+ # embedding => [b, hidden_dim]
225
+ chunk_emb = self.msg_embeddings[i](chunk_val)
226
+ chunk_emb_list.append(chunk_emb.unsqueeze(1)) # => [b,1,hidden_dim]
227
+
228
+ # Concat => [b, nchunks, hidden_dim]
229
+ chunk_emb_seq = torch.cat(chunk_emb_list, dim=1) # [b, nchunks, hidden_dim]
230
+
231
+ # 3) Use chunk_emb_seq as memory, hidden_projected as target for TransformerDecoder
232
+ #
233
+ # TransformerDecoder forward signature:
234
+ # transformer_decoder(tgt, memory, ...)
235
+ # => [b, seq_len, hidden_dim]
236
+ x_decoded = self.transformer_decoder(
237
+ tgt=hidden_projected, # [b, seq_len, hidden_dim]
238
+ memory=chunk_emb_seq, # [b, nchunks, hidden_dim]
239
+ )
240
+
241
+ # 4) Project back to input_dim => [b, seq_len, input_dim]
242
+ x_output = self.output_projection(x_decoded)
243
+
244
+ # 5) permute back to [b, input_dim, seq_len]
245
+ x_output = x_output.permute(0, 2, 1) # => [b, input_dim, seq_len]
246
+
247
+ # 6) (Optional) Residual with original hidden
248
+ x_output = x_output + hidden
249
+
250
+ return x_output
251
+
252
+
253
+ from speechtokenizer import SpeechTokenizer
254
+
255
+
256
+ class SBW(nn.Module):
257
+ def __init__(self):
258
+ super().__init__()
259
+ self.nbits = 16
260
+ config_path = (
261
+ "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json"
262
+ )
263
+ ckpt_path = "speechtokenizer/pretrained_model/SpeechTokenizer.pt"
264
+ self.st_model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
265
+ self.msg_processor = WMEmbedder(
266
+ nbits=16,
267
+ input_dim=1024,
268
+ nchunk_size=4,
269
+ )
270
+ self.detector = WMDetector(
271
+ 1024,
272
+ 16,
273
+ nchunk_size=4,
274
+ )
275
+
276
+ def detect_watermark(
277
+ self, x: torch.Tensor, return_logits=False
278
+ ) -> Tuple[float, torch.Tensor, torch.Tensor]:
279
+ embedding = self.st_model.forward_feature(x)
280
+ if return_logits:
281
+ return self.detector(embedding)
282
+ return self.detector.detect_watermark(embedding)
283
+
284
+ def forward(
285
+ self,
286
+ speech_input: torch.Tensor,
287
+ message: Optional[torch.Tensor] = None,
288
+ ) -> Dict[str, torch.Tensor]:
289
+ recon, recon_wm, acoustic, acoustic_wm = self.st_model(
290
+ speech_input, msg_processor=self.msg_processor, message=message
291
+ )
292
+ wav_length = min(speech_input.size(-1), recon_wm.size(-1))
293
+ speech_input = speech_input[..., :wav_length]
294
+ recon = recon[..., :wav_length]
295
+ recon_wm = recon_wm[..., :wav_length]
296
+ return {
297
+ "recon": recon,
298
+ "recon_wm": recon_wm,
299
+ }
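WMEmbedder packs the 16-bit message into 4-bit chunk indices (least-significant bit first) and WMDetector unpacks the predicted indices back into bits the same way. A standalone sketch of that packing and unpacking, mirroring the loops above; this is illustrative code, not part of the repository:

# Pack a 16-bit message into 4-bit chunk values (LSB-first, as in WMEmbedder.forward)
# and unpack them again (as in WMDetector.detect_watermark).
import torch

nbits, nchunk_size = 16, 4
msg = torch.tensor([[1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1]])  # [batch=1, nbits]

chunk_vals = []
for i in range(nbits // nchunk_size):
    bits = msg[:, i * nchunk_size:(i + 1) * nchunk_size]
    val = torch.zeros_like(bits[:, 0])
    for b in range(nchunk_size):
        val += bits[:, b] << b          # bit b contributes 2**b
    chunk_vals.append(val)              # one integer in [0, 15] per chunk

recovered = []
for val in chunk_vals:
    recovered += [(val >> b) & 1 for b in range(nchunk_size)]
recovered = torch.stack(recovered, dim=-1)

print(torch.equal(recovered, msg))  # True: packing and unpacking are inverses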
requirements.txt ADDED
@@ -0,0 +1,4 @@
torchaudio
beartype
einops
omegaconf
speechtokenizer/__init__.py ADDED
@@ -0,0 +1,3 @@
from .model import SpeechTokenizer

__version__ = '1.0.0'
speechtokenizer/model.py ADDED
@@ -0,0 +1,215 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Wed Aug 30 15:47:55 2023
4
+ @author: zhangxin
5
+ """
6
+
7
+ import random
8
+ from models import WMEmbedder
9
+ from .modules.seanet import SEANetEncoder, SEANetDecoder
10
+ from .quantization import ResidualVectorQuantizer
11
+ import torch.nn as nn
12
+ from einops import rearrange
13
+ import torch
14
+ import numpy as np
15
+
16
+
17
+ class SpeechTokenizer(nn.Module):
18
+ def __init__(self, config):
19
+ """
20
+
21
+ Parameters
22
+ ----------
23
+ config : json
24
+ Model Config.
25
+
26
+ """
27
+ super().__init__()
28
+ self.encoder = SEANetEncoder(
29
+ n_filters=config.get("n_filters"),
30
+ dimension=config.get("dimension"),
31
+ ratios=config.get("strides"),
32
+ lstm=config.get("lstm_layers"),
33
+ bidirectional=config.get("bidirectional"),
34
+ dilation_base=config.get("dilation_base"),
35
+ residual_kernel_size=config.get("residual_kernel_size"),
36
+ n_residual_layers=config.get("n_residual_layers"),
37
+ activation=config.get("activation"),
38
+ )
39
+ self.sample_rate = config.get("sample_rate")
40
+ self.n_q = config.get("n_q")
41
+ self.downsample_rate = np.prod(config.get("strides"))
42
+ if config.get("dimension") != config.get("semantic_dimension"):
43
+ self.transform = nn.Linear(
44
+ config.get("dimension"), config.get("semantic_dimension")
45
+ )
46
+ else:
47
+ self.transform = nn.Identity()
48
+ self.quantizer = ResidualVectorQuantizer(
49
+ dimension=config.get("dimension"),
50
+ n_q=config.get("n_q"),
51
+ bins=config.get("codebook_size"),
52
+ )
53
+ self.decoder = SEANetDecoder(
54
+ n_filters=config.get("n_filters"),
55
+ dimension=config.get("dimension"),
56
+ ratios=config.get("strides"),
57
+ lstm=config.get("lstm_layers"),
58
+ bidirectional=False,
59
+ dilation_base=config.get("dilation_base"),
60
+ residual_kernel_size=config.get("residual_kernel_size"),
61
+ n_residual_layers=config.get("n_residual_layers"),
62
+ activation=config.get("activation"),
63
+ )
64
+
65
+ @classmethod
66
+ def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
67
+ """
68
+
69
+ Parameters
70
+ ----------
71
+ config_path : str
72
+ Path of model configuration file.
73
+ ckpt_path : str
74
+ Path of model checkpoint.
75
+
76
+ Returns
77
+ -------
78
+ model : SpeechTokenizer
79
+ SpeechTokenizer model.
80
+
81
+ """
82
+ import json
83
+
84
+ with open(config_path) as f:
85
+ cfg = json.load(f)
86
+ model = cls(cfg)
87
+ params = torch.load(ckpt_path, map_location="cpu")
88
+ model.load_state_dict(params)
89
+ return model
90
+
91
+ def forward(
92
+ self,
93
+ x: torch.tensor,
94
+ n_q: int = None,
95
+ layers: list = [0],
96
+ msg_processor: WMEmbedder = None,
97
+ message: torch.Tensor = None,
98
+ ):
99
+ """
100
+
101
+ Parameters
102
+ ----------
103
+ x : torch.tensor
104
+ Input wavs. Shape: (batch, channels, timesteps).
105
+ n_q : int, optional
106
+ Number of quantizers in RVQ used to encode. The default is all layers.
107
+ layers : list[int], optional
108
+ Layers of RVQ should return quantized result. The default is the first layer.
109
+
110
+ Returns
111
+ -------
112
+ o : torch.tensor
113
+ Output wavs. Shape: (batch, channels, timesteps).
114
+ commit_loss : torch.tensor
115
+ Commitment loss from residual vector quantizers.
116
+ feature : torch.tensor
117
+ Output of RVQ's first layer. Shape: (batch, timesteps, dimension)
118
+
119
+ """
120
+ with torch.no_grad():
121
+ e = self.encoder(x)
122
+ quantized_full, _, _, quantized_list = self.quantizer(
123
+ e, n_q=n_q, layers=[0, 1, 2, 3, 4, 5, 6, 7], st=0
124
+ )
125
+ # semantic, _, _, _ = self.quantizer(e, n_q=1, st=0)
126
+ # acoustic = e - semantic
127
+ o = self.decoder(quantized_full)
128
+
129
+ subset = quantized_list[1:]
130
+
131
+ # half_len = len(subset) // 2
132
+ # selected_for_processing = random.sample(subset, half_len)
133
+
134
+ # selected_ids = set(id(x) for x in selected_for_processing)
135
+ # acoustic_wm = sum(
136
+ # msg_processor(x, message) if id(x) in selected_ids else x for x in subset
137
+ # )
138
+
139
+ acoustic_wm = sum(msg_processor(x, message) for x in subset)
140
+
141
+ acoustic = e - quantized_list[0]
142
+
143
+ # e_wm = acoustic_wm
144
+ e_wm = quantized_list[0] + acoustic_wm
145
+
146
+ o_wm = self.decoder(e_wm)
147
+
148
+ return (o, o_wm, acoustic, acoustic_wm)
149
+
150
+ def forward_feature(self, x: torch.tensor, layers: list = None):
151
+ """
152
+
153
+ Parameters
154
+ ----------
155
+ x : torch.tensor
156
+ Input wavs. Shape should be (batch, channels, timesteps).
157
+ layers : list[int], optional
158
+ Layers of RVQ should return quantized result. The default is all layers.
159
+
160
+ Returns
161
+ -------
162
+ quantized_list : list[torch.tensor]
163
+ Quantized of required layers.
164
+
165
+ """
166
+ e = self.encoder(x)
167
+ with torch.no_grad():
168
+ semantic, _, _, _ = self.quantizer(e, st=0, n_q=1)
169
+ acoustic = e - semantic
170
+ return acoustic
171
+
172
+ def encode(self, x: torch.tensor, n_q: int = None, st: int = None):
173
+ """
174
+
175
+ Parameters
176
+ ----------
177
+ x : torch.tensor
178
+ Input wavs. Shape: (batch, channels, timesteps).
179
+ n_q : int, optional
180
+ Number of quantizers in RVQ used to encode. The default is all layers.
181
+ st : int, optional
182
+ Start quantizer index in RVQ. The default is 0.
183
+
184
+ Returns
185
+ -------
186
+ codes : torch.tensor
187
+ Output indices for each quantizer. Shape: (n_q, batch, timesteps)
188
+
189
+ """
190
+ e = self.encoder(x)
191
+ if st is None:
192
+ st = 0
193
+ n_q = n_q if n_q else self.n_q
194
+ codes = self.quantizer.encode(e, n_q=n_q, st=st)
195
+ return codes
196
+
197
+ def decode(self, codes: torch.tensor, st: int = 0):
198
+ """
199
+
200
+ Parameters
201
+ ----------
202
+ codes : torch.tensor
203
+ Indices for each quantizer. Shape: (n_q, batch, timesteps).
204
+ st : int, optional
205
+ Start quantizer index in RVQ. The default is 0.
206
+
207
+ Returns
208
+ -------
209
+ o : torch.tensor
210
+ Reconstruct wavs from codes. Shape: (batch, channels, timesteps)
211
+
212
+ """
213
+ quantized = self.quantizer.decode(codes, st=st)
214
+ o = self.decoder(quantized)
215
+ return o
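For reference, a sketch of loading this SpeechTokenizer wrapper directly and round-tripping a waveform through encode/decode, using the config and checkpoint paths referenced in models.py; the random input is a stand-in for a real mono 16 kHz clip.

# Load the tokenizer from the bundled config/checkpoint and reconstruct from RVQ codes.
import torch
from speechtokenizer import SpeechTokenizer

config_path = "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json"
ckpt_path = "speechtokenizer/pretrained_model/SpeechTokenizer.pt"
st = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path).eval()

wav = torch.randn(1, 1, 16000)            # (batch, channels, timesteps), 1 s at 16 kHz
with torch.no_grad():
    codes = st.encode(wav)                # (n_q, batch, frames) discrete indices
    recon = st.decode(codes)              # (batch, channels, timesteps) reconstruction
print(codes.shape, recon.shape)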
speechtokenizer/modules/__init__.py ADDED
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch modules."""

# flake8: noqa
from .conv import (
    pad1d,
    unpad1d,
    NormConv1d,
    NormConvTranspose1d,
    NormConv2d,
    NormConvTranspose2d,
    SConv1d,
    SConvTranspose1d,
)
from .lstm import SLSTM
from .seanet import SEANetEncoder, SEANetDecoder
speechtokenizer/modules/__pycache__/__init__.cpython-310.pyc ADDED (binary file, 500 Bytes)
speechtokenizer/modules/__pycache__/__init__.cpython-311.pyc ADDED (binary file, 654 Bytes)
speechtokenizer/modules/__pycache__/__init__.cpython-39.pyc ADDED (binary file, 494 Bytes)
speechtokenizer/modules/__pycache__/conv.cpython-310.pyc ADDED (binary file, 9.2 kB)
speechtokenizer/modules/__pycache__/conv.cpython-311.pyc ADDED (binary file, 15.3 kB)
speechtokenizer/modules/__pycache__/conv.cpython-39.pyc ADDED (binary file, 9.43 kB)
speechtokenizer/modules/__pycache__/lstm.cpython-310.pyc ADDED (binary file, 1.14 kB)
speechtokenizer/modules/__pycache__/lstm.cpython-311.pyc ADDED (binary file, 1.8 kB)
speechtokenizer/modules/__pycache__/lstm.cpython-39.pyc ADDED (binary file, 1.13 kB)
speechtokenizer/modules/__pycache__/norm.cpython-310.pyc ADDED (binary file, 1.15 kB)
speechtokenizer/modules/__pycache__/norm.cpython-311.pyc ADDED (binary file, 1.69 kB)
speechtokenizer/modules/__pycache__/norm.cpython-39.pyc ADDED (binary file, 1.15 kB)
speechtokenizer/modules/__pycache__/seanet.cpython-310.pyc ADDED (binary file, 11 kB)
speechtokenizer/modules/__pycache__/seanet.cpython-311.pyc ADDED (binary file, 16.6 kB)
speechtokenizer/modules/__pycache__/seanet.cpython-39.pyc ADDED (binary file, 10.5 kB)
speechtokenizer/modules/conv.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Convolutional layers wrappers and utilities."""
8
+
9
+ import math
10
+ import typing as tp
11
+ import warnings
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import spectral_norm, weight_norm
17
+
18
+ from .norm import ConvLayerNorm
19
+
20
+
21
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
22
+ 'time_layer_norm', 'layer_norm', 'time_group_norm'])
23
+
24
+
25
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
26
+ assert norm in CONV_NORMALIZATIONS
27
+ if norm == 'weight_norm':
28
+ return weight_norm(module)
29
+ elif norm == 'spectral_norm':
30
+ return spectral_norm(module)
31
+ else:
32
+ # We already check was in CONV_NORMALIZATION, so any other choice
33
+ # doesn't need reparametrization.
34
+ return module
35
+
36
+
37
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
38
+ """Return the proper normalization module. If causal is True, this will ensure the returned
39
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
40
+ """
41
+ assert norm in CONV_NORMALIZATIONS
42
+ if norm == 'layer_norm':
43
+ assert isinstance(module, nn.modules.conv._ConvNd)
44
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
45
+ elif norm == 'time_group_norm':
46
+ if causal:
47
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
48
+ assert isinstance(module, nn.modules.conv._ConvNd)
49
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
50
+ else:
51
+ return nn.Identity()
52
+
53
+
54
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
55
+ padding_total: int = 0) -> int:
56
+ """See `pad_for_conv1d`.
57
+ """
58
+ length = x.shape[-1]
59
+ n_frames = (length - kernel_size + padding_total) / stride + 1
60
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
61
+ return ideal_length - length
62
+
63
+
64
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
65
+ """Pad for a convolution to make sure that the last window is full.
66
+ Extra padding is added at the end. This is required to ensure that we can rebuild
67
+ an output of the same length, as otherwise, even with padding, some time steps
68
+ might get removed.
69
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
70
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
71
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
72
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
73
+ 1 2 3 4 # once you removed padding, we are missing one time step !
74
+ """
75
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
76
+ return F.pad(x, (0, extra_padding))
77
+
78
+
79
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
80
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
81
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
82
+ """
83
+ length = x.shape[-1]
84
+ padding_left, padding_right = paddings
85
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
86
+ if mode == 'reflect':
87
+ max_pad = max(padding_left, padding_right)
88
+ extra_pad = 0
89
+ if length <= max_pad:
90
+ extra_pad = max_pad - length + 1
91
+ x = F.pad(x, (0, extra_pad))
92
+ padded = F.pad(x, paddings, mode, value)
93
+ end = padded.shape[-1] - extra_pad
94
+ return padded[..., :end]
95
+ else:
96
+ return F.pad(x, paddings, mode, value)
97
+
98
+
99
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
100
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
101
+ padding_left, padding_right = paddings
102
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
103
+ assert (padding_left + padding_right) <= x.shape[-1]
104
+ end = x.shape[-1] - padding_right
105
+ return x[..., padding_left: end]
106
+
107
+
108
+ class NormConv1d(nn.Module):
109
+ """Wrapper around Conv1d and normalization applied to this conv
110
+ to provide a uniform interface across normalization approaches.
111
+ """
112
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
113
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
114
+ super().__init__()
115
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
116
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
117
+ self.norm_type = norm
118
+
119
+ def forward(self, x):
120
+ x = self.conv(x)
121
+ x = self.norm(x)
122
+ return x
123
+
124
+
125
+ class NormConv2d(nn.Module):
126
+ """Wrapper around Conv2d and normalization applied to this conv
127
+ to provide a uniform interface across normalization approaches.
128
+ """
129
+ def __init__(self, *args, norm: str = 'none',
130
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
131
+ super().__init__()
132
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
133
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
134
+ self.norm_type = norm
135
+
136
+ def forward(self, x):
137
+ x = self.conv(x)
138
+ x = self.norm(x)
139
+ return x
140
+
141
+
142
+ class NormConvTranspose1d(nn.Module):
143
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
144
+ to provide a uniform interface across normalization approaches.
145
+ """
146
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
147
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
148
+ super().__init__()
149
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
150
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
151
+ self.norm_type = norm
152
+
153
+ def forward(self, x):
154
+ x = self.convtr(x)
155
+ x = self.norm(x)
156
+ return x
157
+
158
+
159
+ class NormConvTranspose2d(nn.Module):
160
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
161
+ to provide a uniform interface across normalization approaches.
162
+ """
163
+ def __init__(self, *args, norm: str = 'none',
164
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
165
+ super().__init__()
166
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
167
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
168
+
169
+ def forward(self, x):
170
+ x = self.convtr(x)
171
+ x = self.norm(x)
172
+ return x
173
+
174
+
175
+ class SConv1d(nn.Module):
176
+ """Conv1d with some builtin handling of asymmetric or causal padding
177
+ and normalization.
178
+ """
179
+ def __init__(self, in_channels: int, out_channels: int,
180
+ kernel_size: int, stride: int = 1, dilation: int = 1,
181
+ groups: int = 1, bias: bool = True, causal: bool = False,
182
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
183
+ pad_mode: str = 'reflect'):
184
+ super().__init__()
185
+ # warn user on unusual setup between dilation and stride
186
+ if stride > 1 and dilation > 1:
187
+ warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
188
+ f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
189
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
190
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
191
+ norm=norm, norm_kwargs=norm_kwargs)
192
+ self.causal = causal
193
+ self.pad_mode = pad_mode
194
+
195
+ def forward(self, x):
196
+ B, C, T = x.shape
197
+ kernel_size = self.conv.conv.kernel_size[0]
198
+ stride = self.conv.conv.stride[0]
199
+ dilation = self.conv.conv.dilation[0]
200
+ padding_total = (kernel_size - 1) * dilation - (stride - 1)
201
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
202
+ if self.causal:
203
+ # Left padding for causal
204
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
205
+ else:
206
+ # Asymmetric padding required for odd strides
207
+ padding_right = padding_total // 2
208
+ padding_left = padding_total - padding_right
209
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
210
+ return self.conv(x)
211
+
212
+
213
+ class SConvTranspose1d(nn.Module):
214
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
215
+ and normalization.
216
+ """
217
+ def __init__(self, in_channels: int, out_channels: int,
218
+ kernel_size: int, stride: int = 1, causal: bool = False,
219
+ norm: str = 'none', trim_right_ratio: float = 1.,
220
+ norm_kwargs: tp.Dict[str, tp.Any] = {}):
221
+ super().__init__()
222
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
223
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
224
+ self.causal = causal
225
+ self.trim_right_ratio = trim_right_ratio
226
+ assert self.causal or self.trim_right_ratio == 1., \
227
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
228
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
229
+
230
+ def forward(self, x):
231
+ kernel_size = self.convtr.convtr.kernel_size[0]
232
+ stride = self.convtr.convtr.stride[0]
233
+ padding_total = kernel_size - stride
234
+
235
+ y = self.convtr(x)
236
+
237
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
238
+ # removed at the very end, when keeping only the right length for the output,
239
+ # as removing it here would require also passing the length at the matching layer
240
+ # in the encoder.
241
+ if self.causal:
242
+ # Trim the padding on the right according to the specified ratio
243
+ # if trim_right_ratio = 1.0, trim everything from right
244
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
245
+ padding_left = padding_total - padding_right
246
+ y = unpad1d(y, (padding_left, padding_right))
247
+ else:
248
+ # Asymmetric padding required for odd strides
249
+ padding_right = padding_total // 2
250
+ padding_left = padding_total - padding_right
251
+ y = unpad1d(y, (padding_left, padding_right))
252
+ return y
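SConv1d computes its padding from kernel size, stride and dilation so that the final window is always full. A short sketch of that arithmetic for one configuration, reusing the helper defined in this file; the numbers are only illustrative.

# Reproduce SConv1d's padding arithmetic for kernel_size=7, stride=2, dilation=1
# on a 100-sample input, using the helper defined above.
import torch
from speechtokenizer.modules.conv import get_extra_padding_for_conv1d

x = torch.randn(1, 1, 100)
kernel_size, stride, dilation = 7, 2, 1
padding_total = (kernel_size - 1) * dilation - (stride - 1)   # 5
extra = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
# Non-causal branch splits the fixed padding asymmetrically:
padding_right = padding_total // 2                            # 2
padding_left = padding_total - padding_right                  # 3
print(padding_total, extra, (padding_left, padding_right))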
speechtokenizer/modules/lstm.py ADDED
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""LSTM layers module."""

from torch import nn


class SLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional: bool = False):
        super().__init__()
        self.bidirectional = bidirectional
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)

    def forward(self, x):
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.bidirectional:
            x = x.repeat(1, 1, 2)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y
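SLSTM accepts the convolutional (batch, channels, time) layout and handles the permutations internally; a one-line shape check as a usage sketch.

# SLSTM keeps the convolutional (batch, channels, time) layout at its interface.
import torch
from speechtokenizer.modules.lstm import SLSTM

lstm = SLSTM(dimension=64, num_layers=2, skip=True)
x = torch.randn(2, 64, 50)      # (batch, channels, time)
print(lstm(x).shape)            # torch.Size([2, 64, 50])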
speechtokenizer/modules/norm.py ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Normalization modules."""

import typing as tp

import einops
import torch
from torch import nn


class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, 'b ... t -> b t ...')
        x = super().forward(x)
        x = einops.rearrange(x, 'b t ... -> b ... t')
        return x
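ConvLayerNorm normalizes over the channel dimension of a (batch, channels, time) tensor by temporarily moving time ahead of the channels; a small shape sketch of its use.

# ConvLayerNorm applies LayerNorm over the channel dimension of (batch, channels, time) input.
import torch
from speechtokenizer.modules.norm import ConvLayerNorm

norm = ConvLayerNorm(normalized_shape=32)
x = torch.randn(4, 32, 100)     # (batch, channels, time)
print(norm(x).shape)            # torch.Size([4, 32, 100])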
speechtokenizer/modules/seanet.py ADDED
@@ -0,0 +1,408 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Encodec SEANet-based encoder and decoder implementation."""
8
+
9
+ import typing as tp
10
+
11
+ import numpy as np
12
+ import torch.nn as nn
13
+ import torch
14
+
15
+ from . import SConv1d, SConvTranspose1d, SLSTM
16
+
17
+
18
+ @torch.jit.script
19
+ def snake(x, alpha):
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
34
+
35
+
36
+ class SEANetResnetBlock(nn.Module):
37
+ """Residual block from SEANet model.
38
+ Args:
39
+ dim (int): Dimension of the input/output
40
+ kernel_sizes (list): List of kernel sizes for the convolutions.
41
+ dilations (list): List of dilations for the convolutions.
42
+ activation (str): Activation function.
43
+ activation_params (dict): Parameters to provide to the activation function
44
+ norm (str): Normalization method.
45
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
46
+ causal (bool): Whether to use fully causal convolution.
47
+ pad_mode (str): Padding mode for the convolutions.
48
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3)
49
+ true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ dim: int,
55
+ kernel_sizes: tp.List[int] = [3, 1],
56
+ dilations: tp.List[int] = [1, 1],
57
+ activation: str = "ELU",
58
+ activation_params: dict = {"alpha": 1.0},
59
+ norm: str = "weight_norm",
60
+ norm_params: tp.Dict[str, tp.Any] = {},
61
+ causal: bool = False,
62
+ pad_mode: str = "reflect",
63
+ compress: int = 2,
64
+ true_skip: bool = True,
65
+ ):
66
+ super().__init__()
67
+ assert len(kernel_sizes) == len(
68
+ dilations
69
+ ), "Number of kernel sizes should match number of dilations"
70
+ act = getattr(nn, activation) if activation != "Snake" else Snake1d
71
+ hidden = dim // compress
72
+ block = []
73
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
74
+ in_chs = dim if i == 0 else hidden
75
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
76
+ block += [
77
+ act(**activation_params) if activation != "Snake" else act(in_chs),
78
+ SConv1d(
79
+ in_chs,
80
+ out_chs,
81
+ kernel_size=kernel_size,
82
+ dilation=dilation,
83
+ norm=norm,
84
+ norm_kwargs=norm_params,
85
+ causal=causal,
86
+ pad_mode=pad_mode,
87
+ ),
88
+ ]
89
+ self.block = nn.Sequential(*block)
90
+ self.shortcut: nn.Module
91
+ if true_skip:
92
+ self.shortcut = nn.Identity()
93
+ else:
94
+ self.shortcut = SConv1d(
95
+ dim,
96
+ dim,
97
+ kernel_size=1,
98
+ norm=norm,
99
+ norm_kwargs=norm_params,
100
+ causal=causal,
101
+ pad_mode=pad_mode,
102
+ )
103
+
104
+ def forward(self, x):
105
+ return self.shortcut(x) + self.block(x)
106
+
107
+
108
+ class SEANetEncoder(nn.Module):
109
+ """SEANet encoder.
110
+ Args:
111
+ channels (int): Audio channels.
112
+ dimension (int): Intermediate representation dimension.
113
+ n_filters (int): Base width for the model.
114
+ n_residual_layers (int): nb of residual layers.
115
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
116
+ upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
117
+ that must match the decoder order
118
+ activation (str): Activation function.
119
+ activation_params (dict): Parameters to provide to the activation function
120
+ norm (str): Normalization method.
121
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
122
+ kernel_size (int): Kernel size for the initial convolution.
123
+ last_kernel_size (int): Kernel size for the initial convolution.
124
+ residual_kernel_size (int): Kernel size for the residual layers.
125
+ dilation_base (int): How much to increase the dilation with each layer.
126
+ causal (bool): Whether to use fully causal convolution.
127
+ pad_mode (str): Padding mode for the convolutions.
128
+ true_skip (bool): Whether to use true skip connection or a simple
129
+ (streamable) convolution as the skip connection in the residual network blocks.
130
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
131
+ lstm (int): Number of LSTM layers at the end of the encoder.
132
+ """
133
+
134
+ def __init__(
135
+ self,
136
+ channels: int = 1,
137
+ dimension: int = 128,
138
+ n_filters: int = 32,
139
+ n_residual_layers: int = 1,
140
+ ratios: tp.List[int] = [8, 5, 4, 2],
141
+ activation: str = "ELU",
142
+ activation_params: dict = {"alpha": 1.0},
143
+ norm: str = "weight_norm",
144
+ norm_params: tp.Dict[str, tp.Any] = {},
145
+ kernel_size: int = 7,
146
+ last_kernel_size: int = 7,
147
+ residual_kernel_size: int = 3,
148
+ dilation_base: int = 2,
149
+ causal: bool = False,
150
+ pad_mode: str = "reflect",
151
+ true_skip: bool = False,
152
+ compress: int = 2,
153
+ lstm: int = 2,
154
+ bidirectional: bool = False,
155
+ ):
156
+ super().__init__()
157
+ self.channels = channels
158
+ self.dimension = dimension
159
+ self.n_filters = n_filters
160
+ self.ratios = list(reversed(ratios))
161
+ del ratios
162
+ self.n_residual_layers = n_residual_layers
163
+ self.hop_length = np.prod(self.ratios)  # product of the stride ratios
164
+
165
+ act = getattr(nn, activation) if activation != "Snake" else Snake1d
166
+ mult = 1
167
+ model: tp.List[nn.Module] = [
168
+ SConv1d(
169
+ channels,
170
+ mult * n_filters,
171
+ kernel_size,
172
+ norm=norm,
173
+ norm_kwargs=norm_params,
174
+ causal=causal,
175
+ pad_mode=pad_mode,
176
+ )
177
+ ]
178
+ # Downsample to raw audio scale
179
+ for i, ratio in enumerate(self.ratios):
180
+ # Add residual layers
181
+ for j in range(n_residual_layers):
182
+ model += [
183
+ SEANetResnetBlock(
184
+ mult * n_filters,
185
+ kernel_sizes=[residual_kernel_size, 1],
186
+ dilations=[dilation_base**j, 1],
187
+ norm=norm,
188
+ norm_params=norm_params,
189
+ activation=activation,
190
+ activation_params=activation_params,
191
+ causal=causal,
192
+ pad_mode=pad_mode,
193
+ compress=compress,
194
+ true_skip=true_skip,
195
+ )
196
+ ]
197
+
198
+ # Add downsampling layers
199
+ model += [
200
+ (
201
+ act(**activation_params)
202
+ if activation != "Snake"
203
+ else act(mult * n_filters)
204
+ ),
205
+ SConv1d(
206
+ mult * n_filters,
207
+ mult * n_filters * 2,
208
+ kernel_size=ratio * 2,
209
+ stride=ratio,
210
+ norm=norm,
211
+ norm_kwargs=norm_params,
212
+ causal=causal,
213
+ pad_mode=pad_mode,
214
+ ),
215
+ ]
216
+ mult *= 2
217
+
218
+ if lstm:
219
+ model += [
220
+ SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
221
+ ]
222
+
223
+ mult = mult * 2 if bidirectional else mult
224
+ model += [
225
+ (
226
+ act(**activation_params)
227
+ if activation != "Snake"
228
+ else act(mult * n_filters)
229
+ ),
230
+ SConv1d(
231
+ mult * n_filters,
232
+ dimension,
233
+ last_kernel_size,
234
+ norm=norm,
235
+ norm_kwargs=norm_params,
236
+ causal=causal,
237
+ pad_mode=pad_mode,
238
+ ),
239
+ ]
240
+
241
+ self.model = nn.Sequential(*model)
242
+
243
+ def forward(self, x):
244
+ return self.model(x)
245
+
246
+
247
+ class SEANetDecoder(nn.Module):
248
+ """SEANet decoder.
249
+ Args:
250
+ channels (int): Audio channels.
251
+ dimension (int): Intermediate representation dimension.
252
+ n_filters (int): Base width for the model.
253
+ n_residual_layers (int): nb of residual layers.
254
+ ratios (Sequence[int]): kernel size and stride ratios
255
+ activation (str): Activation function.
256
+ activation_params (dict): Parameters to provide to the activation function
257
+ final_activation (str): Final activation function after all convolutions.
258
+ final_activation_params (dict): Parameters to provide to the final activation function
259
+ norm (str): Normalization method.
260
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
261
+ kernel_size (int): Kernel size for the initial convolution.
262
+ last_kernel_size (int): Kernel size for the last convolution.
263
+ residual_kernel_size (int): Kernel size for the residual layers.
264
+ dilation_base (int): How much to increase the dilation with each layer.
265
+ causal (bool): Whether to use fully causal convolution.
266
+ pad_mode (str): Padding mode for the convolutions.
267
+ true_skip (bool): Whether to use true skip connection or a simple
268
+ (streamable) convolution as the skip connection in the residual network blocks.
269
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
270
+ lstm (int): Number of LSTM layers at the start of the decoder.
271
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
272
+ If equal to 1.0, it means that all the trimming is done at the right.
273
+ """
274
+
275
+ def __init__(
276
+ self,
277
+ channels: int = 1,
278
+ dimension: int = 128,
279
+ n_filters: int = 32,
280
+ n_residual_layers: int = 1,
281
+ ratios: tp.List[int] = [8, 5, 4, 2],
282
+ activation: str = "ELU",
283
+ activation_params: dict = {"alpha": 1.0},
284
+ final_activation: tp.Optional[str] = None,
285
+ final_activation_params: tp.Optional[dict] = None,
286
+ norm: str = "weight_norm",
287
+ norm_params: tp.Dict[str, tp.Any] = {},
288
+ kernel_size: int = 7,
289
+ last_kernel_size: int = 7,
290
+ residual_kernel_size: int = 3,
291
+ dilation_base: int = 2,
292
+ causal: bool = False,
293
+ pad_mode: str = "reflect",
294
+ true_skip: bool = False,
295
+ compress: int = 2,
296
+ lstm: int = 2,
297
+ trim_right_ratio: float = 1.0,
298
+ bidirectional: bool = False,
299
+ ):
300
+ super().__init__()
301
+ self.dimension = dimension
302
+ self.channels = channels
303
+ self.n_filters = n_filters
304
+ self.ratios = ratios
305
+ del ratios
306
+ self.n_residual_layers = n_residual_layers
307
+ self.hop_length = np.prod(self.ratios)
308
+
309
+ act = getattr(nn, activation) if activation != "Snake" else Snake1d
310
+ mult = int(2 ** len(self.ratios))
311
+ model: tp.List[nn.Module] = [
312
+ SConv1d(
313
+ dimension,
314
+ mult * n_filters,
315
+ kernel_size,
316
+ norm=norm,
317
+ norm_kwargs=norm_params,
318
+ causal=causal,
319
+ pad_mode=pad_mode,
320
+ )
321
+ ]
322
+
323
+ if lstm:
324
+ model += [
325
+ SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
326
+ ]
327
+
328
+ # Upsample to raw audio scale
329
+ for i, ratio in enumerate(self.ratios):
330
+ # Add upsampling layers
331
+ model += [
332
+ (
333
+ act(**activation_params)
334
+ if activation != "Snake"
335
+ else act(mult * n_filters)
336
+ ),
337
+ SConvTranspose1d(
338
+ mult * n_filters,
339
+ mult * n_filters // 2,
340
+ kernel_size=ratio * 2,
341
+ stride=ratio,
342
+ norm=norm,
343
+ norm_kwargs=norm_params,
344
+ causal=causal,
345
+ trim_right_ratio=trim_right_ratio,
346
+ ),
347
+ ]
348
+ # Add residual layers
349
+ for j in range(n_residual_layers):
350
+ model += [
351
+ SEANetResnetBlock(
352
+ mult * n_filters // 2,
353
+ kernel_sizes=[residual_kernel_size, 1],
354
+ dilations=[dilation_base**j, 1],
355
+ activation=activation,
356
+ activation_params=activation_params,
357
+ norm=norm,
358
+ norm_params=norm_params,
359
+ causal=causal,
360
+ pad_mode=pad_mode,
361
+ compress=compress,
362
+ true_skip=true_skip,
363
+ )
364
+ ]
365
+
366
+ mult //= 2
367
+
368
+ # Add final layers
369
+ model += [
370
+ act(**activation_params) if activation != "Snake" else act(n_filters),
371
+ SConv1d(
372
+ n_filters,
373
+ channels,
374
+ last_kernel_size,
375
+ norm=norm,
376
+ norm_kwargs=norm_params,
377
+ causal=causal,
378
+ pad_mode=pad_mode,
379
+ ),
380
+ ]
381
+ # Add optional final activation to decoder (eg. tanh)
382
+ if final_activation is not None:
383
+ final_act = getattr(nn, final_activation)
384
+ final_activation_params = final_activation_params or {}
385
+ model += [final_act(**final_activation_params)]
386
+ self.model = nn.Sequential(*model)
387
+
388
+ def forward(self, z):
389
+ y = self.model(z)
390
+ return y
391
+
392
+
393
+ def test():
394
+ import torch
395
+
396
+ encoder = SEANetEncoder()
397
+ decoder = SEANetDecoder()
398
+ x = torch.randn(1, 1, 24000)
399
+ z = encoder(x)
400
+ print("z ", z.shape)
402
+ assert list(z.shape) == [1, 128, 75], z.shape
403
+ y = decoder(z)
404
+ assert y.shape == x.shape, (x.shape, y.shape)
405
+
406
+
407
+ if __name__ == "__main__":
408
+ test()
speechtokenizer/pretrained_model/SpeechTokenizer.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d04593b6c9a4b475f91ca481141a6ef5b23e6ac112f347dd2b2717f193c1c728
3
+ size 481906997
speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json ADDED
@@ -0,0 +1,47 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 3,
4
+ "batch_size": 60,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.5,
7
+ "adam_b2": 0.9,
8
+ "lr_decay": 0.98,
9
+ "seed": 1234,
10
+ "lambda_distill": 0.15,
11
+
12
+ "n_filters": 64,
13
+ "strides": [8,5,4,2],
14
+ "dimension": 1024,
15
+ "semantic_dimension": 768,
16
+ "bidirectional": true,
17
+ "dilation_base": 2,
18
+ "residual_kernel_size": 3,
19
+ "n_residual_layers": 1,
20
+ "lstm_layers": 2,
21
+ "activation": "ELU",
22
+
23
+ "segment_size": 48000,
24
+ "num_mels": 80,
25
+ "num_freq": 1025,
26
+ "n_fft": 1024,
27
+ "hop_size": 240,
28
+ "win_size": 1024,
29
+
30
+ "sampling_rate": 16000,
31
+ "sample_rate": 16000,
32
+
33
+ "codebook_size": 1024,
34
+ "n_q": 8,
35
+
36
+ "fmin": 0,
37
+ "fmax": 8000,
38
+ "fmax_for_loss": null,
39
+
40
+ "num_workers": 12,
41
+
42
+ "dist_config": {
43
+ "dist_backend": "nccl",
44
+ "dist_url": "tcp://localhost:54322",
45
+ "world_size": 1
46
+ }
47
+ }
speechtokenizer/quantization/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .vq import QuantizedResult, ResidualVectorQuantizer
speechtokenizer/quantization/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (249 Bytes).
speechtokenizer/quantization/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (283 Bytes).
speechtokenizer/quantization/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (243 Bytes).
speechtokenizer/quantization/__pycache__/core_vq.cpython-310.pyc ADDED
Binary file (11.2 kB).
speechtokenizer/quantization/__pycache__/core_vq.cpython-311.pyc ADDED
Binary file (18.9 kB).
speechtokenizer/quantization/__pycache__/core_vq.cpython-39.pyc ADDED
Binary file (11.1 kB).
speechtokenizer/quantization/__pycache__/distrib.cpython-310.pyc ADDED
Binary file (3.76 kB).
speechtokenizer/quantization/__pycache__/distrib.cpython-311.pyc ADDED
Binary file (6.95 kB).
speechtokenizer/quantization/__pycache__/distrib.cpython-39.pyc ADDED
Binary file (3.78 kB).
speechtokenizer/quantization/__pycache__/vq.cpython-310.pyc ADDED
Binary file (4.62 kB).
speechtokenizer/quantization/__pycache__/vq.cpython-311.pyc ADDED
Binary file (6.27 kB).
speechtokenizer/quantization/__pycache__/vq.cpython-39.pyc ADDED
Binary file (4.57 kB).
speechtokenizer/quantization/ac.py ADDED
@@ -0,0 +1,292 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Arithmetic coder."""
8
+
9
+ import io
10
+ import math
11
+ import random
12
+ import typing as tp
13
+ import torch
14
+
15
+ from ..binary import BitPacker, BitUnpacker
16
+
17
+
18
+ def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
19
+ roundoff: float = 1e-8, min_range: int = 2,
20
+ check: bool = True) -> torch.Tensor:
21
+ """Turn the given PDF into a quantized CDF that splits
22
+ [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
23
+ to the PDF.
24
+
25
+ Args:
26
+ pdf (torch.Tensor): probability distribution, shape should be `[N]`.
27
+ total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
28
+ during the coding process is `[0, 2 ** total_range_bits - 1]`.
29
+ roundoff (float): will round the pdf up to that level to remove difference coming
30
+ from e.g. evaluating the Language Model on different architectures.
31
+ min_range (int): minimum range width. Should always be at least 2 for numerical
32
+ stability. Use this to avoid pathological behavior is a value
33
+ that is expected to be rare actually happens in real life.
34
+ check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
35
+ """
36
+ pdf = pdf.detach()
37
+ if roundoff:
38
+ pdf = (pdf / roundoff).floor() * roundoff
39
+ # interpolate with uniform distribution to achieve desired minimum probability.
40
+ total_range = 2 ** total_range_bits
41
+ cardinality = len(pdf)
42
+ alpha = min_range * cardinality / total_range
43
+ assert alpha <= 1, "you must reduce min_range"
44
+ ranges = (((1 - alpha) * total_range) * pdf).floor().long()
45
+ ranges += min_range
46
+ quantized_cdf = torch.cumsum(ranges, dim=-1)
47
+ if min_range < 2:
48
+ raise ValueError("min_range must be at least 2.")
49
+ if check:
50
+ assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
51
+ if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
52
+ raise ValueError("You must increase your total_range_bits.")
53
+ return quantized_cdf
54
+
55
+
56
+ class ArithmeticCoder:
57
+ """ArithmeticCoder,
58
+ Let us take a distribution `p` over `N` symbols, and assume we have a stream
59
+ of random variables `s_t` sampled from `p`. Let us assume that we have a budget
60
+ of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
61
+ corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
62
+ sequence `(s_t)` by doing the following:
63
+
64
+ 1) Initialize the current range to `[0, 2 ** B - 1]`.
65
+ 2) For each time step t, split the current range into contiguous chunks,
66
+ one for each possible outcome, with size roughly proportional to `p`.
67
+ For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
68
+ would be `{[0, 2], [3, 3]}`.
69
+ 3) Select the chunk corresponding to `s_t`, and replace the current range with this.
70
+ 4) When done encoding all the values, just select any value remaining in the range.
71
+
72
+ You will notice that this procedure can fail: for instance if at any point in time
73
+ the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
74
+ possible outcome. Intuitively, the more likely a value is, the less the range width
75
+ will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
76
+ coding scheme, likely outcomes would take less bits, and more of them can be coded
77
+ with a fixed budget.
78
+
79
+ In practice, we do not know `B` ahead of time, but we have a way to inject new bits
80
+ when the current range decreases below a given limit (given by `total_range_bits`), without
81
+ having to redo all the computations. If we encode mostly likely values, we will seldom
82
+ need to inject new bits, but a single rare value can deplete our stock of entropy!
83
+
84
+ In this explanation, we assumed that the distribution `p` was constant. In fact, the present
85
+ code works for any sequence `(p_t)` possibly different for each timestep.
86
+ We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
87
+ the KL between the true distribution and `p_t`, the more efficient the coding will be.
88
+
89
+ Args:
90
+ fo (IO[bytes]): file-like object to which the bytes will be written.
91
+ total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
93
+ Any time the current range width falls under this limit, new bits will
93
+ be injected to rescale the initial range.
94
+ """
95
+
96
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
97
+ assert total_range_bits <= 30
98
+ self.total_range_bits = total_range_bits
99
+ self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
100
+ self.low: int = 0
101
+ self.high: int = 0
102
+ self.max_bit: int = -1
103
+ self._dbg: tp.List[tp.Any] = []
104
+ self._dbg2: tp.List[tp.Any] = []
105
+
106
+ @property
107
+ def delta(self) -> int:
108
+ """Return the current range width."""
109
+ return self.high - self.low + 1
110
+
111
+ def _flush_common_prefix(self):
112
+ # If self.low and self.high start with the same bits,
113
+ # those won't change anymore as we always just increase the range
114
+ # by powers of 2, and we can flush them out to the bit stream.
115
+ assert self.high >= self.low, (self.low, self.high)
116
+ assert self.high < 2 ** (self.max_bit + 1)
117
+ while self.max_bit >= 0:
118
+ b1 = self.low >> self.max_bit
119
+ b2 = self.high >> self.max_bit
120
+ if b1 == b2:
121
+ self.low -= (b1 << self.max_bit)
122
+ self.high -= (b1 << self.max_bit)
123
+ assert self.high >= self.low, (self.high, self.low, self.max_bit)
124
+ assert self.low >= 0
125
+ self.max_bit -= 1
126
+ self.packer.push(b1)
127
+ else:
128
+ break
129
+
130
+ def push(self, symbol: int, quantized_cdf: torch.Tensor):
131
+ """Push the given symbol on the stream, flushing out bits
132
+ if possible.
133
+
134
+ Args:
135
+ symbol (int): symbol to encode with the AC.
136
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
137
+ to build this from your pdf estimate.
138
+ """
139
+ while self.delta < 2 ** self.total_range_bits:
140
+ self.low *= 2
141
+ self.high = self.high * 2 + 1
142
+ self.max_bit += 1
143
+
144
+ range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
145
+ range_high = quantized_cdf[symbol].item() - 1
146
+ effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
147
+ effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
148
+ assert self.low <= self.high
149
+ self.high = self.low + effective_high
150
+ self.low = self.low + effective_low
151
+ assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
152
+ self._dbg.append((self.low, self.high))
153
+ self._dbg2.append((self.low, self.high))
154
+ outs = self._flush_common_prefix()
155
+ assert self.low <= self.high
156
+ assert self.max_bit >= -1
157
+ assert self.max_bit <= 61, self.max_bit
158
+ return outs
159
+
160
+ def flush(self):
161
+ """Flush the remaining information to the stream.
162
+ """
163
+ while self.max_bit >= 0:
164
+ b1 = (self.low >> self.max_bit) & 1
165
+ self.packer.push(b1)
166
+ self.max_bit -= 1
167
+ self.packer.flush()
168
+
169
+
170
+ class ArithmeticDecoder:
171
+ """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
172
+
173
+ Note that this must be called with **exactly** the same parameters and sequence
174
+ of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
175
+
176
+ If the AC encoder current range is [L, H], with `L` and `H` having the same common
177
+ prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
178
+ For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
179
+ `[b1 b2 b3 0 ... 0 b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
180
+ for a specific sequence of symbols and a binary-search allows us to decode those symbols.
181
+ At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
182
+ and we will need to read new bits from the stream and repeat the process.
183
+
184
+ """
185
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
186
+ self.total_range_bits = total_range_bits
187
+ self.low: int = 0
188
+ self.high: int = 0
189
+ self.current: int = 0
190
+ self.max_bit: int = -1
191
+ self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
192
+ # Following is for debugging
193
+ self._dbg: tp.List[tp.Any] = []
194
+ self._dbg2: tp.List[tp.Any] = []
195
+ self._last: tp.Any = None
196
+
197
+ @property
198
+ def delta(self) -> int:
199
+ return self.high - self.low + 1
200
+
201
+ def _flush_common_prefix(self):
202
+ # Given the current range [L, H], if both have a common prefix,
203
+ # we know we can remove it from our representation to avoid handling large numbers.
204
+ while self.max_bit >= 0:
205
+ b1 = self.low >> self.max_bit
206
+ b2 = self.high >> self.max_bit
207
+ if b1 == b2:
208
+ self.low -= (b1 << self.max_bit)
209
+ self.high -= (b1 << self.max_bit)
210
+ self.current -= (b1 << self.max_bit)
211
+ assert self.high >= self.low
212
+ assert self.low >= 0
213
+ self.max_bit -= 1
214
+ else:
215
+ break
216
+
217
+ def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
218
+ """Pull a symbol, reading as many bits from the stream as required.
219
+ This returns `None` when the stream has been exhausted.
220
+
221
+ Args:
222
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
223
+ to build this from your pdf estimate. This must be **exactly**
224
+ the same cdf as the one used at encoding time.
225
+ """
226
+ while self.delta < 2 ** self.total_range_bits:
227
+ bit = self.unpacker.pull()
228
+ if bit is None:
229
+ return None
230
+ self.low *= 2
231
+ self.high = self.high * 2 + 1
232
+ self.current = self.current * 2 + bit
233
+ self.max_bit += 1
234
+
235
+ def bin_search(low_idx: int, high_idx: int):
236
+ # Binary search is not just for coding interviews :)
237
+ if high_idx < low_idx:
238
+ raise RuntimeError("Binary search failed")
239
+ mid = (low_idx + high_idx) // 2
240
+ range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
241
+ range_high = quantized_cdf[mid].item() - 1
242
+ effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
243
+ effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
244
+ low = effective_low + self.low
245
+ high = effective_high + self.low
246
+ if self.current >= low:
247
+ if self.current <= high:
248
+ return (mid, low, high, self.current)
249
+ else:
250
+ return bin_search(mid + 1, high_idx)
251
+ else:
252
+ return bin_search(low_idx, mid - 1)
253
+
254
+ self._last = (self.low, self.high, self.current, self.max_bit)
255
+ sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
256
+ self._dbg.append((self.low, self.high, self.current))
257
+ self._flush_common_prefix()
258
+ self._dbg2.append((self.low, self.high, self.current))
259
+
260
+ return sym
261
+
262
+
263
+ def test():
264
+ torch.manual_seed(1234)
265
+ random.seed(1234)
266
+ for _ in range(4):
267
+ pdfs = []
268
+ cardinality = random.randrange(4000)
269
+ steps = random.randrange(100, 500)
270
+ fo = io.BytesIO()
271
+ encoder = ArithmeticCoder(fo)
272
+ symbols = []
273
+ for step in range(steps):
274
+ pdf = torch.softmax(torch.randn(cardinality), dim=0)
275
+ pdfs.append(pdf)
276
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
277
+ symbol = torch.multinomial(pdf, 1).item()
278
+ symbols.append(symbol)
279
+ encoder.push(symbol, q_cdf)
280
+ encoder.flush()
281
+
282
+ fo.seek(0)
283
+ decoder = ArithmeticDecoder(fo)
284
+ for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
285
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
286
+ decoded_symbol = decoder.pull(q_cdf)
287
+ assert decoded_symbol == symbol, idx
288
+ assert decoder.pull(torch.zeros(1)) is None
289
+
290
+
291
+ if __name__ == "__main__":
292
+ test()
speechtokenizer/quantization/core_vq.py ADDED
@@ -0,0 +1,385 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+ import typing as tp
34
+
35
+ from einops import rearrange, repeat
36
+ import torch
37
+ from torch import nn
38
+ import torch.nn.functional as F
39
+
40
+ from .distrib import broadcast_tensors, rank
41
+
42
+
43
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
44
+ return val if val is not None else d
45
+
46
+
47
+ def ema_inplace(moving_avg, new, decay: float):
48
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
49
+
50
+
51
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
52
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
53
+
54
+
55
+ def uniform_init(*shape: int):
56
+ t = torch.empty(shape)
57
+ nn.init.kaiming_uniform_(t)
58
+ return t
59
+
60
+
61
+ def sample_vectors(samples, num: int):
62
+ num_samples, device = samples.shape[0], samples.device
63
+
64
+ if num_samples >= num:
65
+ indices = torch.randperm(num_samples, device=device)[:num]
66
+ else:
67
+ indices = torch.randint(0, num_samples, (num,), device=device)
68
+
69
+ return samples[indices]
70
+
71
+
72
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
73
+ dim, dtype = samples.shape[-1], samples.dtype
74
+
75
+ means = sample_vectors(samples, num_clusters)
76
+
77
+ for _ in range(num_iters):
78
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
79
+ dists = -(diffs**2).sum(dim=-1)
80
+
81
+ buckets = dists.max(dim=-1).indices
82
+ bins = torch.bincount(buckets, minlength=num_clusters)
83
+ zero_mask = bins == 0
84
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
85
+
86
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
87
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
88
+ new_means = new_means / bins_min_clamped[..., None]
89
+
90
+ means = torch.where(zero_mask[..., None], means, new_means)
91
+
92
+ return means, bins
93
+
94
+
95
+ class EuclideanCodebook(nn.Module):
96
+ """Codebook with Euclidean distance.
97
+ Args:
98
+ dim (int): Dimension.
99
+ codebook_size (int): Codebook size.
100
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
101
+ If set to true, run the k-means algorithm on the first training batch and use
102
+ the learned centroids as initialization.
103
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
104
+ decay (float): Decay for exponential moving average over the codebooks.
105
+ epsilon (float): Epsilon value for numerical stability.
106
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
107
+ that have an exponential moving average cluster size less than the specified threshold with
108
+ randomly selected vector from the current batch.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ dim: int,
114
+ codebook_size: int,
115
+ kmeans_init: bool = False,
116
+ kmeans_iters: int = 10,
117
+ decay: float = 0.99,
118
+ epsilon: float = 1e-5,
119
+ threshold_ema_dead_code: int = 2,
120
+ ):
121
+ super().__init__()
122
+ self.decay = decay
123
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
124
+ uniform_init if not kmeans_init else torch.zeros
125
+ )
126
+ embed = init_fn(codebook_size, dim)
127
+
128
+ self.codebook_size = codebook_size
129
+
130
+ self.kmeans_iters = kmeans_iters
131
+ self.epsilon = epsilon
132
+ self.threshold_ema_dead_code = threshold_ema_dead_code
133
+
134
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
135
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
136
+ self.register_buffer("embed", embed)
137
+ self.register_buffer("embed_avg", embed.clone())
138
+
139
+ @torch.jit.ignore
140
+ def init_embed_(self, data):
141
+ if self.inited:
142
+ return
143
+
144
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
145
+ self.embed.data.copy_(embed)
146
+ self.embed_avg.data.copy_(embed.clone())
147
+ self.cluster_size.data.copy_(cluster_size)
148
+ self.inited.data.copy_(torch.Tensor([True]))
149
+ # Make sure all buffers across workers are in sync after initialization
150
+ # broadcast_tensors(self.buffers())
151
+
152
+ def replace_(self, samples, mask):
153
+ modified_codebook = torch.where(
154
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
155
+ )
156
+ self.embed.data.copy_(modified_codebook)
157
+
158
+ def expire_codes_(self, batch_samples):
159
+ if self.threshold_ema_dead_code == 0:
160
+ return
161
+
162
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
163
+ if not torch.any(expired_codes):
164
+ return
165
+
166
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
167
+ self.replace_(batch_samples, mask=expired_codes)
168
+ # broadcast_tensors(self.buffers())
169
+
170
+ def preprocess(self, x):
171
+ x = rearrange(x, "... d -> (...) d")
172
+ return x
173
+
174
+ def quantize(self, x):
175
+ embed = self.embed.t()
176
+ dist = -(
177
+ x.pow(2).sum(1, keepdim=True)
178
+ - 2 * x @ embed
179
+ + embed.pow(2).sum(0, keepdim=True)
180
+ )
181
+ embed_ind = dist.max(dim=-1).indices
182
+ return embed_ind
183
+
184
+ def postprocess_emb(self, embed_ind, shape):
185
+ return embed_ind.view(*shape[:-1])
186
+
187
+ def dequantize(self, embed_ind):
188
+ quantize = F.embedding(embed_ind, self.embed)
189
+ return quantize
190
+
191
+ def encode(self, x):
192
+ shape = x.shape
193
+ # pre-process
194
+ x = self.preprocess(x)
195
+ # quantize
196
+ embed_ind = self.quantize(x)
197
+ # post-process
198
+ embed_ind = self.postprocess_emb(embed_ind, shape)
199
+ return embed_ind
200
+
201
+ def decode(self, embed_ind):
202
+ quantize = self.dequantize(embed_ind)
203
+ return quantize
204
+
205
+ def forward(self, x):
206
+ shape, dtype = x.shape, x.dtype
207
+ x = self.preprocess(x)
208
+
209
+ # self.init_embed_(x)
210
+
211
+ embed_ind = self.quantize(x)
212
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
213
+ embed_ind = self.postprocess_emb(embed_ind, shape)
214
+ quantize = self.dequantize(embed_ind)
215
+
216
+ # if self.training:
217
+ # # We do the expiry of code at that point as buffers are in sync
218
+ # # and all the workers will take the same decision.
219
+ # self.expire_codes_(x)
220
+ # ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
221
+ # embed_sum = x.t() @ embed_onehot
222
+ # ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
223
+ # cluster_size = (
224
+ # laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
225
+ # * self.cluster_size.sum()
226
+ # )
227
+ # embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
228
+ # self.embed.data.copy_(embed_normalized)
229
+
230
+ return quantize, embed_ind
231
+
232
+
233
+ class VectorQuantization(nn.Module):
234
+ """Vector quantization implementation.
235
+ Currently supports only euclidean distance.
236
+ Args:
237
+ dim (int): Dimension
238
+ codebook_size (int): Codebook size
239
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
240
+ decay (float): Decay for exponential moving average over the codebooks.
241
+ epsilon (float): Epsilon value for numerical stability.
242
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
243
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
244
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
245
+ that have an exponential moving average cluster size less than the specified threshold with
246
+ randomly selected vector from the current batch.
247
+ commitment_weight (float): Weight for commitment loss.
248
+ """
249
+
250
+ def __init__(
251
+ self,
252
+ dim: int,
253
+ codebook_size: int,
254
+ codebook_dim: tp.Optional[int] = None,
255
+ decay: float = 0.99,
256
+ epsilon: float = 1e-5,
257
+ kmeans_init: bool = True,
258
+ kmeans_iters: int = 50,
259
+ threshold_ema_dead_code: int = 2,
260
+ commitment_weight: float = 1.0,
261
+ ):
262
+ super().__init__()
263
+ _codebook_dim: int = default(codebook_dim, dim)
264
+
265
+ requires_projection = _codebook_dim != dim
266
+ self.project_in = (
267
+ nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
268
+ )
269
+ self.project_out = (
270
+ nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
271
+ )
272
+
273
+ self.epsilon = epsilon
274
+ self.commitment_weight = commitment_weight
275
+
276
+ self._codebook = EuclideanCodebook(
277
+ dim=_codebook_dim,
278
+ codebook_size=codebook_size,
279
+ kmeans_init=kmeans_init,
280
+ kmeans_iters=kmeans_iters,
281
+ decay=decay,
282
+ epsilon=epsilon,
283
+ threshold_ema_dead_code=threshold_ema_dead_code,
284
+ )
285
+ self.codebook_size = codebook_size
286
+
287
+ @property
288
+ def codebook(self):
289
+ return self._codebook.embed
290
+
291
+ def encode(self, x):
292
+ x = rearrange(x, "b d n -> b n d")
293
+ x = self.project_in(x)
294
+ embed_in = self._codebook.encode(x)
295
+ return embed_in
296
+
297
+ def decode(self, embed_ind):
298
+ quantize = self._codebook.decode(embed_ind)
299
+ quantize = self.project_out(quantize)
300
+ quantize = rearrange(quantize, "b n d -> b d n")
301
+ return quantize
302
+
303
+ def forward(self, x):
304
+ device = x.device
305
+ x = rearrange(x, "b d n -> b n d")
306
+ x = self.project_in(x)
307
+
308
+ quantize, embed_ind = self._codebook(x)
309
+
310
+ if self.training:
311
+ quantize = x + (quantize - x).detach()
312
+
313
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
314
+
315
+ if self.training:
316
+ if self.commitment_weight > 0:
317
+ commit_loss = F.mse_loss(quantize.detach(), x)
318
+ loss = loss + commit_loss * self.commitment_weight
319
+
320
+ quantize = self.project_out(quantize)
321
+ quantize = rearrange(quantize, "b n d -> b d n")
322
+ return quantize, embed_ind, loss
323
+
324
+
325
+ class ResidualVectorQuantization(nn.Module):
326
+ """Residual vector quantization implementation.
327
+ Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf
328
+ """
329
+
330
+ def __init__(self, *, num_quantizers, **kwargs):
331
+ super().__init__()
332
+ self.layers = nn.ModuleList(
333
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
334
+ )
335
+
336
+ def forward(
337
+ self,
338
+ x,
339
+ n_q: tp.Optional[int] = None,
340
+ layers: tp.Optional[list] = None,
341
+ st: tp.Optional[int] = None,
342
+ ):
343
+ quantized_out = 0.0
344
+ residual = x
345
+
346
+ all_losses = []
347
+ all_indices = []
348
+ out_quantized = []
349
+ n_q = n_q or len(self.layers)
350
+ st = st or 0
351
+ for i, layer in enumerate(self.layers[st:n_q]):
352
+ quantized, indices, loss = layer(residual)
353
+ residual = residual - quantized
354
+ quantized_out = quantized_out + quantized
355
+
356
+ all_indices.append(indices)
357
+ all_losses.append(loss)
358
+ if layers and i in layers:
359
+ out_quantized.append(quantized)
360
+
361
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
362
+ return quantized_out, out_indices, out_losses, out_quantized
363
+
364
+ def encode(
365
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
366
+ ) -> torch.Tensor:
367
+ residual = x
368
+ all_indices = []
369
+ n_q = n_q or len(self.layers)
370
+ st = st or 0
371
+ for layer in self.layers[st:n_q]:
372
+ indices = layer.encode(residual)
373
+ quantized = layer.decode(indices)
374
+ residual = residual - quantized
375
+ all_indices.append(indices)
376
+ out_indices = torch.stack(all_indices)
377
+ return out_indices
378
+
379
+ def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
380
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
381
+ for i, indices in enumerate(q_indices):
382
+ layer = self.layers[st + i]
383
+ quantized = layer.decode(indices)
384
+ quantized_out = quantized_out + quantized
385
+ return quantized_out
speechtokenizer/quantization/distrib.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Torch distributed utilities."""
8
+
9
+ import typing as tp
10
+
11
+ import torch
12
+
13
+
14
+ def rank():
15
+ if torch.distributed.is_initialized():
16
+ return torch.distributed.get_rank()
17
+ else:
18
+ return 0
19
+
20
+
21
+ def world_size():
22
+ if torch.distributed.is_initialized():
23
+ return torch.distributed.get_world_size()
24
+ else:
25
+ return 1
26
+
27
+
28
+ def is_distributed():
29
+ return world_size() > 1
30
+
31
+
32
+ def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
33
+ if is_distributed():
34
+ return torch.distributed.all_reduce(tensor, op)
35
+
36
+
37
+ def _is_complex_or_float(tensor):
38
+ return torch.is_floating_point(tensor) or torch.is_complex(tensor)
39
+
40
+
41
+ def _check_number_of_params(params: tp.List[torch.Tensor]):
42
+ # utility function to check that the number of params in all workers is the same,
43
+ # and thus avoid a deadlock with distributed all reduce.
44
+ if not is_distributed() or not params:
45
+ return
46
+ #print('params[0].device ', params[0].device)
47
+ tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
48
+ all_reduce(tensor)
49
+ if tensor.item() != len(params) * world_size():
50
+ # If not all the workers have the same number, for at least one of them,
51
+ # this inequality will be verified.
52
+ raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
53
+ "at least one worker has a different one.")
54
+
55
+
56
+ def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
57
+ """Broadcast the tensors from the given parameters to all workers.
58
+ This can be used to ensure that all workers have the same model to start with.
59
+ """
60
+ if not is_distributed():
61
+ return
62
+ tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
63
+ _check_number_of_params(tensors)
64
+ handles = []
65
+ for tensor in tensors:
66
+ # src = int(rank()) # added code
67
+ handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
68
+ handles.append(handle)
69
+ for handle in handles:
70
+ handle.wait()
71
+
72
+
73
+ def sync_buffer(buffers, average=True):
74
+ """
75
+ Sync grad for buffers. If average is False, broadcast instead of averaging.
76
+ """
77
+ if not is_distributed():
78
+ return
79
+ handles = []
80
+ for buffer in buffers:
81
+ if torch.is_floating_point(buffer.data):
82
+ if average:
83
+ handle = torch.distributed.all_reduce(
84
+ buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
85
+ else:
86
+ handle = torch.distributed.broadcast(
87
+ buffer.data, src=0, async_op=True)
88
+ handles.append((buffer, handle))
89
+ for buffer, handle in handles:
90
+ handle.wait()
91
+ if average:
92
+ buffer.data /= world_size
93
+
94
+
95
+ def sync_grad(params):
96
+ """
97
+ Simpler alternative to DistributedDataParallel, that doesn't rely
98
+ on any black magic. For simple models it can also be as fast.
99
+ Just call this on your model parameters after the call to backward!
100
+ """
101
+ if not is_distributed():
102
+ return
103
+ handles = []
104
+ for p in params:
105
+ if p.grad is not None:
106
+ handle = torch.distributed.all_reduce(
107
+ p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
108
+ handles.append((p, handle))
109
+ for p, handle in handles:
110
+ handle.wait()
111
+ p.grad.data /= world_size()
112
+
113
+
114
+ def average_metrics(metrics: tp.Dict[str, float], count=1.):
115
+ """Average a dictionary of metrics across all workers, using the optional
116
+ `count` as unormalized weight.
117
+ """
118
+ if not is_distributed():
119
+ return metrics
120
+ keys, values = zip(*metrics.items())
121
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
122
+ tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
123
+ tensor *= count
124
+ all_reduce(tensor)
125
+ averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
126
+ return dict(zip(keys, averaged))