diff --git a/README.md b/README.md
index 56715545d8657e9074cf9161fa8d984634e29b45..892883c05dde4e31fbd21349642716e77f442232 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,14 @@
---
title: VoiceMark
-emoji: đ
-colorFrom: blue
-colorTo: blue
+emoji: đ
+colorFrom: purple
+colorTo: green
sdk: gradio
-sdk_version: 5.30.0
+sdk_version: 5.16.0
app_file: app.py
pinned: false
-license: apache-2.0
-short_description: Zero-Shot Voice Cloning-Resistant Watermarking Approach Leve
+license: gpl
+short_description: Zero-Shot Voice Cloning-Resistant Watermarking
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..db6a9252be66cbf9a963be863b107dcb2d227ac8
--- /dev/null
+++ b/app.py
@@ -0,0 +1,137 @@
+import gradio as gr
+import json
+import os
+import torchaudio
+from infer import (
+ WatermarkSolver,
+ hamming_distance
+)
+
+# Predefined watermarks (instead of loading from a JSON file)
+watermarks = {
+ "VoiceMark": "1000010101010011",
+ "Voice Cloning": "1111111001000010",
+ "Speech Security": "1011101100001110",
+ "Audio Watermarking": "0110110011100010",
+ "Deep Learning": "0000100111111000",
+ "Artificial Intelligence": "0010000100011111",
+ "Hello World": "0001111101110001",
+ "Happy New Year": "1101011011011101",
+ "World Peace": "0011110010011110",
+ "Good Morning": "0000001011000010",
+}
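+
+# Sanity check (illustrative): each label must map to a 16-bit binary ID, matching the
+# 16-bit payload that WatermarkSolver embeds and decodes.
+assert all(len(bits) == 16 and set(bits) <= {"0", "1"} for bits in watermarks.values())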
+
+# Initialize WatermarkSolver model
+solver = WatermarkSolver()
+solver.load_model(checkpoint_dir="./", checkpoint_name="voicemark.pth", strict=True)
+
+# Gradio interface
+with gr.Blocks() as demo:
+ gr.Markdown(
+ "# VoiceMark: Zero-Shot Voice Cloning-Resistant Watermarking Approach Leveraging Speaker-Specific Latents"
+ )
+ with gr.Column():
+ gr.Image(
+ value="voicemark_overview.png",
+ width=925,
+ height=487,
+ elem_id="overview_image",
+ label="overview"
+ )
+    gr.HTML("The overall architecture of our proposed VoiceMark")
+
+    # Step 1: Upload audio and select watermark
+ gr.Markdown(
+ """
+ **Step 1**: Upload an audio file or select one from the provided samples, choose a watermark, and generate the watermarked audio.
+ """
+ )
+
+ with gr.Row():
+ with gr.Column():
+ audio_input = gr.Audio(label="Upload Audio", type="filepath")
+
+ gr.Examples(
+ examples=[
+ ["audios/1.wav"],
+ ["audios/2.wav"],
+ ["audios/3.wav"],
+ ["audios/4.wav"],
+ ["audios/5.wav"],
+ ],
+ inputs=audio_input,
+ label="Sample Audios (Click to Use)"
+ )
+
+ with gr.Column():
+ audio_output = gr.Audio(label="Watermarked Audio", type="filepath")
+ watermark_list = gr.Dropdown(
+ label="Select Watermark", choices=list(watermarks.keys()), interactive=True
+ )
+ add_watermark_button = gr.Button("Add Watermark to Audio")
+
+ # Step 2: TTS tools demo links
+ gr.Markdown(
+ """
+ **Step 2**: Download the generated watermarked audio, then use Zero-Shot Voice Cloning tools to generate the cloned audio. Some available tools are:
+ - [CosyVoice2: Scalable Streaming Speech Synthesis with Large Language Models](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B)
+ - [F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+ - [MaskGCT: Zero-Shot Text-to-Speech with Masked Generative Codec Transformer](https://huggingface.co/spaces/amphion/maskgct)
+ """
+ )
+
+ # Step 3: Upload cloned audio to decode watermark
+ gr.Markdown(
+ """
+ **Step 3**: Upload the cloned audio and decode your watermark.
+ """
+ )
+
+ with gr.Row():
+ decode_audio_input = gr.Audio(label="Upload Cloned Audio", type="filepath")
+ with gr.Column():
+ decoded_watermark_output = gr.Textbox(label="Decoded Watermark")
+ decode_button = gr.Button("Decode Watermark")
+
+ def process_audio(audio_path, watermark_text):
+ if not audio_path:
+ return "No audio selected. Please upload or select a sample."
+ try:
+ watermarked_audio = solver.infer_for_ui(
+ audio_path, watermarks[watermark_text]
+ )
+ return watermarked_audio
+ except ValueError as e:
+ return str(e)
+
+ add_watermark_button.click(
+ process_audio,
+ inputs=[audio_input, watermark_list],
+ outputs=audio_output
+ )
+
+ def decode_watermark(audio_path):
+ try:
+            detect_prob, decoded_id = solver.decode_for_ui(audio_path)
+            # A very low detection probability means no watermark is present.
+            if detect_prob < 1e-2:
+                return "No matching watermark found"
+ closest_match = None
+ min_distance = float("inf")
+ for text, id_bin in watermarks.items():
+ distance = hamming_distance(decoded_id, id_bin, base=16)
+ if distance < min_distance:
+ closest_match = text
+ min_distance = distance
+ if min_distance < 10:
+ return closest_match
+ return "No matching watermark found"
+ except ValueError as e:
+ return str(e)
+
+ decode_button.click(
+ decode_watermark, inputs=decode_audio_input, outputs=decoded_watermark_output
+ )
+
+# Launch the Gradio app
+demo.launch()
diff --git a/audios/1.wav b/audios/1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..8bc9491ca8b637dce252df5b56025b8e6c76b58f
--- /dev/null
+++ b/audios/1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb97a4d6a89ce3cfff4d504de24ef4d79658cc68cc765da7c905b30ecb5e9f1d
+size 160078
diff --git a/audios/2.wav b/audios/2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..e584955749e912955badac8e7b95eecb7ee5d581
--- /dev/null
+++ b/audios/2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d5cf698739bad50c02ab96375bed49594a0b1a4421406b6ed10edc432f09dc94
+size 152078
diff --git a/audios/3.wav b/audios/3.wav
new file mode 100644
index 0000000000000000000000000000000000000000..a2cb187543ddf7d7ebad1ed521576bfdb5c5d9e8
--- /dev/null
+++ b/audios/3.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4673c2f94dec181e7603df7564af0778d6589a7d7f4878f30a9c42bdfc2324f7
+size 160078
diff --git a/audios/4.wav b/audios/4.wav
new file mode 100644
index 0000000000000000000000000000000000000000..f9bdb6702975b272f60c98cc2c8f209c6a11f640
--- /dev/null
+++ b/audios/4.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10c63ef8e2866b7e4b98d15cbbc084da667b350ab4d084f338b656661b619ceb
+size 98750
diff --git a/audios/5.wav b/audios/5.wav
new file mode 100644
index 0000000000000000000000000000000000000000..650cf3edb5d1b6007b77a980945a6960bcd54498
--- /dev/null
+++ b/audios/5.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:049b7061915306d541d2d8c059c6c9599502997562bc32120f16a873a55e499e
+size 63072
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a5d0cfe9fea48e7e12659d77d5ca5a6673e758d
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,175 @@
+from models import SBW
+import torchaudio
+import torch
+import os
+
+def hamming_distance(s1, s2, base=2):
+ """
+    Calculate the absolute difference between two strings interpreted in chunks of a specified base.
+    Note that for base > 2 this is the sum of per-chunk absolute differences rather than a strict Hamming distance.
+
+ Args:
+ s1 (str): The first binary string.
+ s2 (str): The second binary string.
+ base (int): The base to interpret the binary strings (e.g., 2 for binary, 16 for hexadecimal).
+
+ Returns:
+ int: The sum of absolute differences for corresponding chunks.
+
+ Raises:
+ ValueError: If the strings are not of equal length or invalid for the given base.
+ """
+ if len(s1) != len(s2):
+ raise ValueError("Both strings must be of equal length")
+
+ # Determine the chunk size for the given base
+ import math
+
+ chunk_size = int(math.log(base, 2))
+
+ if len(s1) % chunk_size != 0 or len(s2) % chunk_size != 0:
+ raise ValueError(
+ f"Binary strings must be a multiple of {chunk_size} bits for base {base}"
+ )
+
+ # Split binary strings into chunks
+ def to_chunks(binary_str):
+ return [
+ binary_str[i : i + chunk_size]
+ for i in range(0, len(binary_str), chunk_size)
+ ]
+
+ chunks_s1 = to_chunks(s1)
+ chunks_s2 = to_chunks(s2)
+
+ # Convert chunks to integers and calculate absolute difference
+ def absolute_difference_chunks(chunk1, chunk2):
+ int1 = int(chunk1, 2)
+ int2 = int(chunk2, 2)
+ return abs(int1 - int2)
+
+ return sum(
+ absolute_difference_chunks(c1, c2) for c1, c2 in zip(chunks_s1, chunks_s2)
+ )
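+
+# Worked example (illustrative): with base=16 the strings are read 4 bits at a time, so
+# hamming_distance("00010010", "00010100", base=16) compares chunk 0b0001 with 0b0001 and
+# chunk 0b0010 with 0b0100, giving |1 - 1| + |2 - 4| = 2.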
+
+def random_message(nbits: int, batch_size: int) -> torch.Tensor:
+ """Return random message as 0/1 tensor."""
+ if nbits == 0:
+ return torch.tensor([])
+ return torch.randint(0, 2, (batch_size, nbits))
+
+
+def string_to_message(message_str: str, batch_size: int) -> torch.Tensor:
+ """
+ Convert a binary string to a message tensor.
+
+ Args:
+ message_str (str): A string of '0's and '1's.
+ batch_size (int): The batch size for the output tensor.
+
+ Returns:
+ torch.Tensor: A tensor of shape (batch_size, nbits) where nbits is the length of message_str.
+ """
+ # Convert the string to a list of integers (0 and 1)
+ message_list = [int(bit) for bit in message_str]
+ nbits = len(message_list)
+ # Convert the list to a tensor and repeat it for the batch size
+ message_tensor = torch.tensor(message_list).unsqueeze(0).repeat(batch_size, 1)
+ return message_tensor
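+
+# Example (illustrative): string_to_message("1010", batch_size=2) returns a tensor of
+# shape (2, 4), i.e. tensor([[1, 0, 1, 0], [1, 0, 1, 0]]).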
+
+
+class WatermarkSolver:
+ def __init__(self):
+ self.device = 'cpu'
+ self.sample_rate = 16000
+ self.nbits = 16
+ self.model = SBW()
+
+ def load_model(
+ self, checkpoint_dir, checkpoint_name, strict=False
+ ):
+ """
+ Load the latest model weights from the checkpoint directory.
+ """
+ checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]
+ if not checkpoint_files:
+ print("No checkpoint files found in the directory.")
+ return
+
+ checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
+ print(f"Loading model weights from {checkpoint_path}...")
+
+ # Load the checkpoint
+        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)
+
+ # Load model state dict with strict=False to allow missing keys (semantic_encoder)
+ model_state_dict = checkpoint["model_state_dict"]
+
+ new_state_dict = {}
+ for k, v in model_state_dict.items():
+ new_key = k.replace("module.", "")
+ new_state_dict[new_key] = v
+ if not strict:
+ new_state_dict = {
+ k: v
+ for k, v in new_state_dict.items()
+ if k in self.model.state_dict()
+ and self.model.state_dict()[k].shape == v.shape
+ }
+ self.model.load_state_dict(new_state_dict, strict=False)
+
+ print("Model state dict loaded successfully.")
+ self.epoch = checkpoint["epoch"]
+ print(f"Checkpoint loaded from epoch {checkpoint['epoch']}.")
+
+ def infer_for_ui(self, input_audio_path, watermark, output_audio_path="tmp"):
+ message_str = watermark
+ self.model.eval()
+
+ waveform, sample_rate = torchaudio.load(input_audio_path)
+
+ if sample_rate != self.sample_rate:
+ resampler = torchaudio.transforms.Resample(
+ orig_freq=sample_rate, new_freq=self.sample_rate
+ )
+ waveform = resampler(waveform)
+
+ waveform = waveform[:1].to(self.device).unsqueeze(1)
+ message = string_to_message(message_str=message_str, batch_size=1).to(
+ self.device
+ )
+
+ with torch.no_grad():
+ output = self.model(
+ waveform,
+ message=message,
+ )
+ y_wm = output["recon_wm"]
+
+ os.makedirs(output_audio_path, exist_ok=True)
+ watermarked_audio_path = os.path.join(
+ output_audio_path,
+ f"{os.path.splitext(os.path.basename(input_audio_path))[0]}_watermarked.wav",
+ )
+ torchaudio.save(watermarked_audio_path, y_wm.squeeze(1).cpu(), self.sample_rate)
+ return watermarked_audio_path
+
+ def decode_for_ui(self, input_audio_path):
+ self.model.eval()
+
+ waveform, sample_rate = torchaudio.load(input_audio_path)
+
+ if sample_rate != self.sample_rate:
+ resampler = torchaudio.transforms.Resample(
+ orig_freq=sample_rate, new_freq=self.sample_rate
+ )
+ waveform = resampler(waveform)
+
+ waveform = waveform[:1].to(self.device).unsqueeze(1)
+
+ with torch.no_grad():
+ detect_prob, detected_message_tensor, _ = self.model.detect_watermark(waveform)
+ detected_id = "".join(
+ map(str, detected_message_tensor.squeeze(0).cpu().numpy().tolist())
+ )
+
+        return detect_prob, detected_id
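+
+
+# Minimal usage sketch (illustrative; assumes voicemark.pth sits in the working directory):
+#   solver = WatermarkSolver()
+#   solver.load_model(checkpoint_dir="./", checkpoint_name="voicemark.pth", strict=True)
+#   wm_path = solver.infer_for_ui("audios/1.wav", "1000010101010011")
+#   detect_prob, decoded_bits = solver.decode_for_ui(wm_path)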
diff --git a/models.py b/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..c197b7824e64a92834f53810468c5c67534b62d9
--- /dev/null
+++ b/models.py
@@ -0,0 +1,299 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict, Optional, Tuple
+
+
+class WMDetector(nn.Module):
+ """
+ Detect watermarks in an audio signal using a Transformer architecture,
+    where the watermark bits are split into chunks of nchunk_size bits each.
+    We assume nbits is a multiple of nchunk_size.
+ """
+
+ def __init__(
+ self, input_channels: int, nbits: int, nchunk_size: int, d_model: int = 512
+ ):
+ """
+ Args:
+ input_channels (int): Number of input channels in the audio feature (e.g., mel channels).
+            nbits (int): Total number of bits in the watermark, must be a multiple of nchunk_size.
+            nchunk_size (int): Number of bits per chunk (e.g., 4).
+            d_model (int): Embedding dimension for the Transformer.
+ """
+ super().__init__()
+ self.nchunk_size = nchunk_size
+        assert nbits % nchunk_size == 0, "nbits must be a multiple of nchunk_size!"
+ self.nbits = nbits
+ self.d_model = d_model
+        # Number of chunks
+ self.nchunks = nbits // nchunk_size
+
+ # 1D convolution to map the input channels to d_model
+ self.embedding = nn.Conv1d(input_channels, d_model, kernel_size=1)
+
+ # Transformer encoder block
+ self.transformer = nn.TransformerEncoder(
+ nn.TransformerEncoderLayer(
+ d_model=d_model,
+ nhead=1,
+ dim_feedforward=d_model * 2,
+ activation="gelu",
+ batch_first=True,
+ ),
+ num_layers=8,
+ )
+
+ # A linear head for watermark presence detection (binary)
+ self.watermark_head = nn.Linear(d_model, 1)
+
+        # For each chunk, we perform a 2**nchunk_size-way classification
+ self.message_heads = nn.ModuleList(
+ nn.Linear(d_model, 2**nchunk_size) for _ in range(self.nchunks)
+ )
+
+        # Learnable embeddings for each chunk (instead of per bit)
+ # Shape: [nchunks, d_model]
+ self.nchunk_embeddings = nn.Parameter(torch.randn(self.nchunks, d_model))
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Forward pass of the detector.
+
+ Returns:
+ logits (torch.Tensor): Watermark detection logits of shape [batch, seq_len].
+            chunk_logits (torch.Tensor): Chunk-level classification logits of shape [batch, nchunks, 2**nchunk_size].
+ """
+ batch_size, input_channels, time_steps = x.shape
+
+        # 1) Map [batch, in_channels, time_steps] -> [batch, time_steps, d_model]
+ x = self.embedding(x).permute(0, 2, 1) # [batch, time_steps, d_model]
+
+ # 2) Prepend chunk embeddings at the beginning of the sequence
+        # [nchunks, d_model] -> [1, nchunks, d_model] -> [batch, nchunks, d_model]
+ nchunk_embeds = self.nchunk_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
+ # Concatenate along the time dimension: [batch, nchunks + time_steps, d_model]
+ x = torch.cat([nchunk_embeds, x], dim=1)
+
+ # 3) Pass through the Transformer
+ x = self.transformer(x)
+ # x has shape [batch, nchunks + time_steps, d_model]
+
+ # (a) Watermark presence detection: skip the first nchunks
+ detection_part = x[:, self.nchunks :] # [batch, time_steps, d_model]
+ logits = self.watermark_head(detection_part).squeeze(-1) # [batch, time_steps]
+
+ # (b) Message decoding: use the first nchunks
+ message_part = x[:, : self.nchunks] # [batch, nchunks, d_model]
+ chunk_logits_list = []
+ for i, head in enumerate(self.message_heads):
+            # message_part[:, i, :] has shape [batch, d_model]
+            # each head outputs [batch, 2**nchunk_size]
+            chunk_vec = message_part[:, i, :]
+            chunk_logits_list.append(head(chunk_vec).unsqueeze(1))  # [batch, 1, 2**nchunk_size]
+
+        # Concatenate along the 'nchunks' dimension -> [batch, nchunks, 2**nchunk_size]
+ chunk_logits = torch.cat(chunk_logits_list, dim=1)
+
+ return logits, chunk_logits
+
+ def detect_watermark(
+ self,
+ x: torch.Tensor,
+ sample_rate: Optional[int] = None,
+ threshold: float = 0.5,
+ ) -> Tuple[float, torch.Tensor, torch.Tensor]:
+ """
+ A convenience function for inference.
+
+ Returns:
+ detect_prob (float): Probability that the audio is watermarked.
+ binary_message (torch.Tensor): The recovered message of shape [batch, nbits] (binary).
+ detected (torch.Tensor): The sigmoid values of the per-timestep watermark detection.
+ """
+ logits, chunk_logits = self.forward(x)
+        # logits: [batch, seq_len] -> raw logits for watermark presence detection
+        # chunk_logits: [batch, nchunks, 2**nchunk_size] -> classification logits for each chunk
+
+ # (1) Compute watermark detection probability
+ detected = torch.sigmoid(logits) # [batch, seq_len]
+ detect_prob = detected.mean(dim=-1).cpu().item()
+
+        # (2) Decode the message: chunk_logits has shape [batch, nchunks, 2**nchunk_size]
+        chunk_probs = F.softmax(chunk_logits, dim=-1)  # [batch, nchunks, 2**nchunk_size]
+        chunk_indices = torch.argmax(
+            chunk_probs, dim=-1
+        )  # [batch, nchunks], each in [0 .. 2**nchunk_size - 1]
+        # (3) Convert each chunk index back to nchunk_size bits
+ # Finally, assemble into a [batch, nbits] binary tensor
+ binary_message = []
+ for i in range(self.nchunks):
+ chunk_val = chunk_indices[:, i] # [batch]
+            # Extract nchunk_size bits from the integer, LSB first
+ chunk_bits = []
+ for b in range(self.nchunk_size):
+ bit_b = (chunk_val >> b) & 1 # get bit b
+ chunk_bits.append(bit_b.unsqueeze(-1))
+            # Concatenate bits to shape [batch, nchunk_size]
+ chunk_bits = torch.cat(chunk_bits, dim=-1)
+ binary_message.append(chunk_bits)
+
+        # Concatenate all chunks -> [batch, nbits]
+ binary_message = torch.cat(binary_message, dim=-1)
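+        # Note (illustrative): bits are unpacked LSB-first, e.g. with nchunk_size=4 a
+        # chunk index of 5 (0b0101) becomes the bit sequence [1, 0, 1, 0].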
+
+ return detect_prob, binary_message, detected
+
+
+
+class WMEmbedder(nn.Module):
+ """
+ A class that takes a secret message, processes it into chunk embeddings
+ (as a small sequence), and uses a TransformerDecoder to do cross-attention
+ between the original hidden (target) and the watermark tokens (memory).
+ """
+
+ def __init__(
+ self,
+ nbits: int, # total bits in the secret message
+ input_dim: int, # the input dimension (e.g. audio feature dimension)
+ nchunk_size: int,
+ hidden_dim: int = 256,
+ num_heads: int = 1,
+ num_layers: int = 4,
+ ):
+ super().__init__()
+ self.nchunk_size = nchunk_size
+ assert nbits % nchunk_size == 0, "nbits must be a multiple of nchunk_size!"
+ self.nbits = nbits
+ self.nchunks = nbits // nchunk_size # how many chunks
+
+ # Each chunk (0..2^nchunk_size - 1) maps to an embedding of size [hidden_dim]
+ self.msg_embeddings = nn.ModuleList(
+ nn.Embedding(2**nchunk_size, hidden_dim) for _ in range(self.nchunks)
+ )
+
+ # Linear to project [input_dim] -> [hidden_dim]
+ self.input_projection = nn.Linear(input_dim, hidden_dim)
+
+ # TransformerDecoder for cross-attention
+ # d_model=hidden_dim, so the decoder expects [b, seq_len, hidden_dim] as tgt
+ # and [b, memory_len, hidden_dim] as memory
+ decoder_layer = nn.TransformerDecoderLayer(
+ d_model=hidden_dim,
+ nhead=num_heads,
+ dim_feedforward=2 * hidden_dim,
+ activation="gelu",
+ batch_first=True, # so shape is [batch, seq, feature]
+ )
+ self.transformer_decoder = nn.TransformerDecoder(
+ decoder_layer, num_layers=num_layers
+ )
+
+ # Project [hidden_dim] -> [input_dim]
+ # self.output_projection1 = nn.Linear(hidden_dim * 2, hidden_dim)
+ self.output_projection = nn.Linear(hidden_dim, input_dim)
+
+ def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ hidden: [batch, input_dim, seq_len]
+ msg: [batch, nbits]
+ Returns:
+ A tensor [batch, input_dim, seq_len] with watermark injected.
+ """
+ b, in_dim, seq_len = hidden.shape
+
+ # 1) Project input features to [b, seq_len, hidden_dim]
+ hidden_projected = self.input_projection(
+ hidden.permute(0, 2, 1)
+ ) # => [b, seq_len, hidden_dim]
+
+ # 2) Convert the msg bits into a sequence of chunk embeddings
+ # We keep each chunk as one token => [b, nchunks, hidden_dim]
+ chunk_emb_list = []
+ for i in range(self.nchunks):
+ # msg[:, i*nchunk_size : (i+1)*nchunk_size] => shape [b, nchunk_size]
+ chunk_bits = msg[:, i * self.nchunk_size : (i + 1) * self.nchunk_size]
+ chunk_val = torch.zeros_like(chunk_bits[:, 0]) # shape [b]
+ for bit_idx in range(self.nchunk_size):
+ # shift bits
+ chunk_val += chunk_bits[:, bit_idx] << bit_idx
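+            # Packing example (illustrative): bits [1, 0, 1, 0] give
+            # chunk_val = 1*1 + 0*2 + 1*4 + 0*8 = 5, mirroring the LSB-first
+            # unpacking in WMDetector.detect_watermark.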
+
+ # embedding => [b, hidden_dim]
+ chunk_emb = self.msg_embeddings[i](chunk_val)
+ chunk_emb_list.append(chunk_emb.unsqueeze(1)) # => [b,1,hidden_dim]
+
+ # Concat => [b, nchunks, hidden_dim]
+ chunk_emb_seq = torch.cat(chunk_emb_list, dim=1) # [b, nchunks, hidden_dim]
+
+ # 3) Use chunk_emb_seq as memory, hidden_projected as target for TransformerDecoder
+ #
+ # TransformerDecoder forward signature:
+ # transformer_decoder(tgt, memory, ...)
+ # => [b, seq_len, hidden_dim]
+ x_decoded = self.transformer_decoder(
+ tgt=hidden_projected, # [b, seq_len, hidden_dim]
+ memory=chunk_emb_seq, # [b, nchunks, hidden_dim]
+ )
+
+ # 4) Project back to input_dim => [b, seq_len, input_dim]
+ x_output = self.output_projection(x_decoded)
+
+ # 5) permute back to [b, input_dim, seq_len]
+ x_output = x_output.permute(0, 2, 1) # => [b, input_dim, seq_len]
+
+ # 6) (Optional) Residual with original hidden
+ x_output = x_output + hidden
+
+ return x_output
+
+
+from speechtokenizer import SpeechTokenizer
+
+
+class SBW(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.nbits = 16
+ config_path = (
+ "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json"
+ )
+ ckpt_path = "speechtokenizer/pretrained_model/SpeechTokenizer.pt"
+ self.st_model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
+ self.msg_processor = WMEmbedder(
+ nbits=16,
+ input_dim=1024,
+ nchunk_size=4,
+ )
+ self.detector = WMDetector(
+ 1024,
+ 16,
+ nchunk_size=4,
+ )
+
+ def detect_watermark(
+ self, x: torch.Tensor, return_logits=False
+ ) -> Tuple[float, torch.Tensor]:
+ embedding = self.st_model.forward_feature(x)
+ if return_logits:
+ return self.detector(embedding)
+ return self.detector.detect_watermark(embedding)
+
+ def forward(
+ self,
+ speech_input: torch.Tensor,
+ message: Optional[torch.Tensor] = None,
+ ) -> Dict[str, torch.Tensor]:
+ recon, recon_wm, acoustic, acoustic_wm = self.st_model(
+ speech_input, msg_processor=self.msg_processor, message=message
+ )
+ wav_length = min(speech_input.size(-1), recon_wm.size(-1))
+ speech_input = speech_input[..., :wav_length]
+ recon = recon[..., :wav_length]
+ recon_wm = recon_wm[..., :wav_length]
+ return {
+ "recon": recon,
+ "recon_wm": recon_wm,
+ }
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1220e33b5d43e1e68b2ef3d599f7630434d02b95
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torchaudio
+beartype
+einops
+omegaconf
\ No newline at end of file
diff --git a/speechtokenizer/__init__.py b/speechtokenizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..390bd890dffbf95ce01d141eb5013177dcea468c
--- /dev/null
+++ b/speechtokenizer/__init__.py
@@ -0,0 +1,3 @@
+from .model import SpeechTokenizer
+
+__version__ = '1.0.0'
\ No newline at end of file
diff --git a/speechtokenizer/model.py b/speechtokenizer/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..cec8a020639e1107b283268c70e5305d288da46c
--- /dev/null
+++ b/speechtokenizer/model.py
@@ -0,0 +1,215 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Aug 30 15:47:55 2023
+@author: zhangxin
+"""
+
+import random
+from models import WMEmbedder
+from .modules.seanet import SEANetEncoder, SEANetDecoder
+from .quantization import ResidualVectorQuantizer
+import torch.nn as nn
+from einops import rearrange
+import torch
+import numpy as np
+
+
+class SpeechTokenizer(nn.Module):
+ def __init__(self, config):
+ """
+
+ Parameters
+ ----------
+ config : json
+ Model Config.
+
+ """
+ super().__init__()
+ self.encoder = SEANetEncoder(
+ n_filters=config.get("n_filters"),
+ dimension=config.get("dimension"),
+ ratios=config.get("strides"),
+ lstm=config.get("lstm_layers"),
+ bidirectional=config.get("bidirectional"),
+ dilation_base=config.get("dilation_base"),
+ residual_kernel_size=config.get("residual_kernel_size"),
+ n_residual_layers=config.get("n_residual_layers"),
+ activation=config.get("activation"),
+ )
+ self.sample_rate = config.get("sample_rate")
+ self.n_q = config.get("n_q")
+ self.downsample_rate = np.prod(config.get("strides"))
+ if config.get("dimension") != config.get("semantic_dimension"):
+ self.transform = nn.Linear(
+ config.get("dimension"), config.get("semantic_dimension")
+ )
+ else:
+ self.transform = nn.Identity()
+ self.quantizer = ResidualVectorQuantizer(
+ dimension=config.get("dimension"),
+ n_q=config.get("n_q"),
+ bins=config.get("codebook_size"),
+ )
+ self.decoder = SEANetDecoder(
+ n_filters=config.get("n_filters"),
+ dimension=config.get("dimension"),
+ ratios=config.get("strides"),
+ lstm=config.get("lstm_layers"),
+ bidirectional=False,
+ dilation_base=config.get("dilation_base"),
+ residual_kernel_size=config.get("residual_kernel_size"),
+ n_residual_layers=config.get("n_residual_layers"),
+ activation=config.get("activation"),
+ )
+
+ @classmethod
+ def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
+ """
+
+ Parameters
+ ----------
+ config_path : str
+ Path of model configuration file.
+ ckpt_path : str
+ Path of model checkpoint.
+
+ Returns
+ -------
+ model : SpeechTokenizer
+ SpeechTokenizer model.
+
+ """
+ import json
+
+ with open(config_path) as f:
+ cfg = json.load(f)
+ model = cls(cfg)
+ params = torch.load(ckpt_path, map_location="cpu")
+ model.load_state_dict(params)
+ return model
+
+ def forward(
+ self,
+ x: torch.tensor,
+ n_q: int = None,
+ layers: list = [0],
+ msg_processor: WMEmbedder = None,
+ message: torch.Tensor = None,
+ ):
+ """
+
+ Parameters
+ ----------
+ x : torch.tensor
+ Input wavs. Shape: (batch, channels, timesteps).
+ n_q : int, optional
+ Number of quantizers in RVQ used to encode. The default is all layers.
+ layers : list[int], optional
+            Layers of RVQ should return quantized result. The default is the first layer.
+        msg_processor : WMEmbedder, optional
+            Watermark embedder applied to the acoustic (non-semantic) RVQ layers.
+        message : torch.Tensor, optional
+            Watermark message bits. Shape: (batch, nbits).
+
+        Returns
+        -------
+        o : torch.tensor
+            Reconstructed wavs without watermark. Shape: (batch, channels, timesteps).
+        o_wm : torch.tensor
+            Watermarked reconstructed wavs. Shape: (batch, channels, timesteps).
+        acoustic : torch.tensor
+            Acoustic residual (encoder output minus the first-layer semantic quantization).
+        acoustic_wm : torch.tensor
+            Watermarked acoustic residual.
+
+ """
+ with torch.no_grad():
+ e = self.encoder(x)
+ quantized_full, _, _, quantized_list = self.quantizer(
+ e, n_q=n_q, layers=[0, 1, 2, 3, 4, 5, 6, 7], st=0
+ )
+ # semantic, _, _, _ = self.quantizer(e, n_q=1, st=0)
+ # acoustic = e - semantic
+ o = self.decoder(quantized_full)
+
+ subset = quantized_list[1:]
+
+ # half_len = len(subset) // 2
+ # selected_for_processing = random.sample(subset, half_len)
+
+ # selected_ids = set(id(x) for x in selected_for_processing)
+ # acoustic_wm = sum(
+ # msg_processor(x, message) if id(x) in selected_ids else x for x in subset
+ # )
+
+ acoustic_wm = sum(msg_processor(x, message) for x in subset)
+
+ acoustic = e - quantized_list[0]
+
+ # e_wm = acoustic_wm
+ e_wm = quantized_list[0] + acoustic_wm
+
+ o_wm = self.decoder(e_wm)
+
+ return (o, o_wm, acoustic, acoustic_wm)
+
+ def forward_feature(self, x: torch.tensor, layers: list = None):
+ """
+
+ Parameters
+ ----------
+ x : torch.tensor
+ Input wavs. Shape should be (batch, channels, timesteps).
+ layers : list[int], optional
+ Layers of RVQ should return quantized result. The default is all layers.
+
+ Returns
+ -------
+        acoustic : torch.tensor
+            Acoustic residual: encoder output minus the first-layer (semantic) quantization.
+
+ """
+ e = self.encoder(x)
+ with torch.no_grad():
+ semantic, _, _, _ = self.quantizer(e, st=0, n_q=1)
+ acoustic = e - semantic
+ return acoustic
+
+ def encode(self, x: torch.tensor, n_q: int = None, st: int = None):
+ """
+
+ Parameters
+ ----------
+ x : torch.tensor
+ Input wavs. Shape: (batch, channels, timesteps).
+ n_q : int, optional
+ Number of quantizers in RVQ used to encode. The default is all layers.
+ st : int, optional
+ Start quantizer index in RVQ. The default is 0.
+
+ Returns
+ -------
+ codes : torch.tensor
+ Output indices for each quantizer. Shape: (n_q, batch, timesteps)
+
+ """
+ e = self.encoder(x)
+ if st is None:
+ st = 0
+ n_q = n_q if n_q else self.n_q
+ codes = self.quantizer.encode(e, n_q=n_q, st=st)
+ return codes
+
+ def decode(self, codes: torch.tensor, st: int = 0):
+ """
+
+ Parameters
+ ----------
+ codes : torch.tensor
+ Indices for each quantizer. Shape: (n_q, batch, timesteps).
+ st : int, optional
+ Start quantizer index in RVQ. The default is 0.
+
+ Returns
+ -------
+ o : torch.tensor
+ Reconstruct wavs from codes. Shape: (batch, channels, timesteps)
+
+ """
+ quantized = self.quantizer.decode(codes, st=st)
+ o = self.decoder(quantized)
+ return o
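+
+
+# Usage sketch (illustrative; paths match the files shipped with this repo):
+#   st = SpeechTokenizer.load_from_checkpoint(
+#       "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json",
+#       "speechtokenizer/pretrained_model/SpeechTokenizer.pt",
+#   )
+#   codes = st.encode(wav)    # (n_q, batch, timesteps) codebook indices
+#   recon = st.decode(codes)  # (batch, channels, timesteps) reconstructed waveform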
diff --git a/speechtokenizer/modules/__init__.py b/speechtokenizer/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d61ab031045422c4000cf26388f24d072572089
--- /dev/null
+++ b/speechtokenizer/modules/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch modules."""
+
+# flake8: noqa
+from .conv import (
+ pad1d,
+ unpad1d,
+ NormConv1d,
+ NormConvTranspose1d,
+ NormConv2d,
+ NormConvTranspose2d,
+ SConv1d,
+ SConvTranspose1d,
+)
+from .lstm import SLSTM
+from .seanet import SEANetEncoder, SEANetDecoder
\ No newline at end of file
diff --git a/speechtokenizer/modules/__pycache__/__init__.cpython-310.pyc b/speechtokenizer/modules/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a32d6b7c2de6cbbc18cb783ad1943e6f17caee15
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/__init__.cpython-310.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/__init__.cpython-311.pyc b/speechtokenizer/modules/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..488e7cede536d6a98ece07eeb4e16d1f120eb081
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/__init__.cpython-311.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/__init__.cpython-39.pyc b/speechtokenizer/modules/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9416d9a323d941a0fce79863e2d0b82105666e20
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/__init__.cpython-39.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/conv.cpython-310.pyc b/speechtokenizer/modules/__pycache__/conv.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf35c42c62599415234e4815f176c9bdd83f87d3
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/conv.cpython-310.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/conv.cpython-311.pyc b/speechtokenizer/modules/__pycache__/conv.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef84ffa2942d241c0de6898dab490d0ce0b35f57
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/conv.cpython-311.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/conv.cpython-39.pyc b/speechtokenizer/modules/__pycache__/conv.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e51bc6d346bf5659edb9b16f0243f3f9c7d6c0e7
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/conv.cpython-39.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/lstm.cpython-310.pyc b/speechtokenizer/modules/__pycache__/lstm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8824b269519aa9c34b331b6b63293c3a2a23d97d
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/lstm.cpython-310.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/lstm.cpython-311.pyc b/speechtokenizer/modules/__pycache__/lstm.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..363de3b98fbc8a0a5a3b47f692cf264b364024eb
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/lstm.cpython-311.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/lstm.cpython-39.pyc b/speechtokenizer/modules/__pycache__/lstm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f361e0ff5749098e2d04d067fa85a2d47ba3f57
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/lstm.cpython-39.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/norm.cpython-310.pyc b/speechtokenizer/modules/__pycache__/norm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..731f81f0d438c2c70f1ad5ecedfb212f4544549c
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/norm.cpython-310.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/norm.cpython-311.pyc b/speechtokenizer/modules/__pycache__/norm.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..defcb330a18291545a4885475612955e93ac39be
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/norm.cpython-311.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/norm.cpython-39.pyc b/speechtokenizer/modules/__pycache__/norm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5221b03d1254b82ea0f893b3d45e1e4e475807b8
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/norm.cpython-39.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/seanet.cpython-310.pyc b/speechtokenizer/modules/__pycache__/seanet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8cfdd3816c3dcc8856f66fd61878e964d391f5a
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/seanet.cpython-310.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/seanet.cpython-311.pyc b/speechtokenizer/modules/__pycache__/seanet.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d11fac416dc5ed2647d2ea67ab0891e7c49492c
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/seanet.cpython-311.pyc differ
diff --git a/speechtokenizer/modules/__pycache__/seanet.cpython-39.pyc b/speechtokenizer/modules/__pycache__/seanet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..670f9f718da0ac17d0488413d887e72ba8749892
Binary files /dev/null and b/speechtokenizer/modules/__pycache__/seanet.cpython-39.pyc differ
diff --git a/speechtokenizer/modules/conv.py b/speechtokenizer/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..f37963901118e10bfe009e3c6d27e47be06be7cc
--- /dev/null
+++ b/speechtokenizer/modules/conv.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Convolutional layers wrappers and utilities."""
+
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+from .norm import ConvLayerNorm
+
+
+CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
+ 'time_layer_norm', 'layer_norm', 'time_group_norm'])
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
+ assert norm in CONV_NORMALIZATIONS
+ if norm == 'weight_norm':
+ return weight_norm(module)
+ elif norm == 'spectral_norm':
+ return spectral_norm(module)
+ else:
+        # We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
+ # doesn't need reparametrization.
+ return module
+
+
+def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
+ """Return the proper normalization module. If causal is True, this will ensure the returned
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
+ """
+ assert norm in CONV_NORMALIZATIONS
+ if norm == 'layer_norm':
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
+ elif norm == 'time_group_norm':
+ if causal:
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+ else:
+ return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+ padding_total: int = 0) -> int:
+ """See `pad_for_conv1d`.
+ """
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length - length
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+ """Pad for a convolution to make sure that the last window is full.
+ Extra padding is added at the end. This is required to ensure that we can rebuild
+ an output of the same length, as otherwise, even with padding, some time steps
+ might get removed.
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
+ 1 2 3 4 # once you removed padding, we are missing one time step !
+ """
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ return F.pad(x, (0, extra_padding))
+
+
+def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
+ """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == 'reflect':
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left: end]
+
+
+class NormConv1d(nn.Module):
+ """Wrapper around Conv1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConv2d(nn.Module):
+ """Wrapper around Conv2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose1d(nn.Module):
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose2d(nn.Module):
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class SConv1d(nn.Module):
+ """Conv1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+ def __init__(self, in_channels: int, out_channels: int,
+ kernel_size: int, stride: int = 1, dilation: int = 1,
+ groups: int = 1, bias: bool = True, causal: bool = False,
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
+ pad_mode: str = 'reflect'):
+ super().__init__()
+ # warn user on unusual setup between dilation and stride
+ if stride > 1 and dilation > 1:
+ warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
+ f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
+ norm=norm, norm_kwargs=norm_kwargs)
+ self.causal = causal
+ self.pad_mode = pad_mode
+
+ def forward(self, x):
+ B, C, T = x.shape
+ kernel_size = self.conv.conv.kernel_size[0]
+ stride = self.conv.conv.stride[0]
+ dilation = self.conv.conv.dilation[0]
+ padding_total = (kernel_size - 1) * dilation - (stride - 1)
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ if self.causal:
+ # Left padding for causal
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+ return self.conv(x)
+
+
+class SConvTranspose1d(nn.Module):
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+ def __init__(self, in_channels: int, out_channels: int,
+ kernel_size: int, stride: int = 1, causal: bool = False,
+ norm: str = 'none', trim_right_ratio: float = 1.,
+ norm_kwargs: tp.Dict[str, tp.Any] = {}):
+ super().__init__()
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
+ self.causal = causal
+ self.trim_right_ratio = trim_right_ratio
+ assert self.causal or self.trim_right_ratio == 1., \
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
+
+ def forward(self, x):
+ kernel_size = self.convtr.convtr.kernel_size[0]
+ stride = self.convtr.convtr.stride[0]
+ padding_total = kernel_size - stride
+
+ y = self.convtr(x)
+
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+ # removed at the very end, when keeping only the right length for the output,
+ # as removing it here would require also passing the length at the matching layer
+ # in the encoder.
+ if self.causal:
+ # Trim the padding on the right according to the specified ratio
+ # if trim_right_ratio = 1.0, trim everything from right
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ return y
diff --git a/speechtokenizer/modules/lstm.py b/speechtokenizer/modules/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba669aef7249f84f8944b2f2079bf29ff34235e0
--- /dev/null
+++ b/speechtokenizer/modules/lstm.py
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""LSTM layers module."""
+
+from torch import nn
+
+
+class SLSTM(nn.Module):
+ """
+ LSTM without worrying about the hidden state, nor the layout of the data.
+ Expects input as convolutional layout.
+ """
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional: bool=False):
+ super().__init__()
+ self.bidirectional = bidirectional
+ self.skip = skip
+ self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)
+
+ def forward(self, x):
+ x = x.permute(2, 0, 1)
+ y, _ = self.lstm(x)
+ if self.bidirectional:
+ x = x.repeat(1, 1, 2)
+ if self.skip:
+ y = y + x
+ y = y.permute(1, 2, 0)
+ return y
diff --git a/speechtokenizer/modules/norm.py b/speechtokenizer/modules/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..19970e0a21ea1c10461cb56d776619dd5f64ff36
--- /dev/null
+++ b/speechtokenizer/modules/norm.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Normalization modules."""
+
+import typing as tp
+
+import einops
+import torch
+from torch import nn
+
+
+class ConvLayerNorm(nn.LayerNorm):
+ """
+ Convolution-friendly LayerNorm that moves channels to last dimensions
+ before running the normalization and moves them back to original position right after.
+ """
+ def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
+ super().__init__(normalized_shape, **kwargs)
+
+ def forward(self, x):
+ x = einops.rearrange(x, 'b ... t -> b t ...')
+ x = super().forward(x)
+ x = einops.rearrange(x, 'b t ... -> b ... t')
+        return x
diff --git a/speechtokenizer/modules/seanet.py b/speechtokenizer/modules/seanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2456a2c2fee8de888af6138ef121d3b4dde81c9
--- /dev/null
+++ b/speechtokenizer/modules/seanet.py
@@ -0,0 +1,408 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Encodec SEANet-based encoder and decoder implementation."""
+
+import typing as tp
+
+import numpy as np
+import torch.nn as nn
+import torch
+
+from . import SConv1d, SConvTranspose1d, SLSTM
+
+
+@torch.jit.script
+def snake(x, alpha):
+ shape = x.shape
+ x = x.reshape(shape[0], shape[1], -1)
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+ x = x.reshape(shape)
+ return x
+
+
+class Snake1d(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+ def forward(self, x):
+ return snake(x, self.alpha)
+
+
+class SEANetResnetBlock(nn.Module):
+ """Residual block from SEANet model.
+ Args:
+ dim (int): Dimension of the input/output
+ kernel_sizes (list): List of kernel sizes for the convolutions.
+ dilations (list): List of dilations for the convolutions.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3)
+ true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ kernel_sizes: tp.List[int] = [3, 1],
+ dilations: tp.List[int] = [1, 1],
+ activation: str = "ELU",
+ activation_params: dict = {"alpha": 1.0},
+ norm: str = "weight_norm",
+ norm_params: tp.Dict[str, tp.Any] = {},
+ causal: bool = False,
+ pad_mode: str = "reflect",
+ compress: int = 2,
+ true_skip: bool = True,
+ ):
+ super().__init__()
+ assert len(kernel_sizes) == len(
+ dilations
+ ), "Number of kernel sizes should match number of dilations"
+ act = getattr(nn, activation) if activation != "Snake" else Snake1d
+ hidden = dim // compress
+ block = []
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+ in_chs = dim if i == 0 else hidden
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+ block += [
+ act(**activation_params) if activation != "Snake" else act(in_chs),
+ SConv1d(
+ in_chs,
+ out_chs,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ ),
+ ]
+ self.block = nn.Sequential(*block)
+ self.shortcut: nn.Module
+ if true_skip:
+ self.shortcut = nn.Identity()
+ else:
+ self.shortcut = SConv1d(
+ dim,
+ dim,
+ kernel_size=1,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ )
+
+ def forward(self, x):
+ return self.shortcut(x) + self.block(x)
+
+
+class SEANetEncoder(nn.Module):
+ """SEANet encoder.
+ Args:
+ channels (int): Audio channels.
+ dimension (int): Intermediate representation dimension.
+ n_filters (int): Base width for the model.
+ n_residual_layers (int): nb of residual layers.
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
+ upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
+ that must match the decoder order
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the final convolution.
+ residual_kernel_size (int): Kernel size for the residual layers.
+ dilation_base (int): How much to increase the dilation with each layer.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection in the residual network blocks.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ lstm (int): Number of LSTM layers at the end of the encoder.
+ """
+
+ def __init__(
+ self,
+ channels: int = 1,
+ dimension: int = 128,
+ n_filters: int = 32,
+ n_residual_layers: int = 1,
+ ratios: tp.List[int] = [8, 5, 4, 2],
+ activation: str = "ELU",
+ activation_params: dict = {"alpha": 1.0},
+ norm: str = "weight_norm",
+ norm_params: tp.Dict[str, tp.Any] = {},
+ kernel_size: int = 7,
+ last_kernel_size: int = 7,
+ residual_kernel_size: int = 3,
+ dilation_base: int = 2,
+ causal: bool = False,
+ pad_mode: str = "reflect",
+ true_skip: bool = False,
+ compress: int = 2,
+ lstm: int = 2,
+ bidirectional: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.dimension = dimension
+ self.n_filters = n_filters
+ self.ratios = list(reversed(ratios))
+ del ratios
+ self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)  # product of the strides
+
+ act = getattr(nn, activation) if activation != "Snake" else Snake1d
+ mult = 1
+ model: tp.List[nn.Module] = [
+ SConv1d(
+ channels,
+ mult * n_filters,
+ kernel_size,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ )
+ ]
+ # Downsample to raw audio scale
+ for i, ratio in enumerate(self.ratios):
+ # Add residual layers
+ for j in range(n_residual_layers):
+ model += [
+ SEANetResnetBlock(
+ mult * n_filters,
+ kernel_sizes=[residual_kernel_size, 1],
+ dilations=[dilation_base**j, 1],
+ norm=norm,
+ norm_params=norm_params,
+ activation=activation,
+ activation_params=activation_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ compress=compress,
+ true_skip=true_skip,
+ )
+ ]
+
+ # Add downsampling layers
+ model += [
+ (
+ act(**activation_params)
+ if activation != "Snake"
+ else act(mult * n_filters)
+ ),
+ SConv1d(
+ mult * n_filters,
+ mult * n_filters * 2,
+ kernel_size=ratio * 2,
+ stride=ratio,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ ),
+ ]
+ mult *= 2
+
+ if lstm:
+ model += [
+ SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
+ ]
+
+ mult = mult * 2 if bidirectional else mult
+ model += [
+ (
+ act(**activation_params)
+ if activation != "Snake"
+ else act(mult * n_filters)
+ ),
+ SConv1d(
+ mult * n_filters,
+ dimension,
+ last_kernel_size,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ ),
+ ]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class SEANetDecoder(nn.Module):
+ """SEANet decoder.
+ Args:
+ channels (int): Audio channels.
+ dimension (int): Intermediate representation dimension.
+ n_filters (int): Base width for the model.
+ n_residual_layers (int): nb of residual layers.
+ ratios (Sequence[int]): kernel size and stride ratios
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function
+ final_activation (str): Final activation function after all convolutions.
+ final_activation_params (dict): Parameters to provide to the activation function
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the final convolution.
+ residual_kernel_size (int): Kernel size for the residual layers.
+ dilation_base (int): How much to increase the dilation with each layer.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection in the residual network blocks.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ lstm (int): Number of LSTM layers at the end of the encoder.
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+ If equal to 1.0, it means that all the trimming is done at the right.
+ """
+
+ def __init__(
+ self,
+ channels: int = 1,
+ dimension: int = 128,
+ n_filters: int = 32,
+ n_residual_layers: int = 1,
+ ratios: tp.List[int] = [8, 5, 4, 2],
+ activation: str = "ELU",
+ activation_params: dict = {"alpha": 1.0},
+ final_activation: tp.Optional[str] = None,
+ final_activation_params: tp.Optional[dict] = None,
+ norm: str = "weight_norm",
+ norm_params: tp.Dict[str, tp.Any] = {},
+ kernel_size: int = 7,
+ last_kernel_size: int = 7,
+ residual_kernel_size: int = 3,
+ dilation_base: int = 2,
+ causal: bool = False,
+ pad_mode: str = "reflect",
+ true_skip: bool = False,
+ compress: int = 2,
+ lstm: int = 2,
+ trim_right_ratio: float = 1.0,
+ bidirectional: bool = False,
+ ):
+ super().__init__()
+ self.dimension = dimension
+ self.channels = channels
+ self.n_filters = n_filters
+ self.ratios = ratios
+ del ratios
+ self.n_residual_layers = n_residual_layers
+ self.hop_length = np.prod(self.ratios)
+
+ act = getattr(nn, activation) if activation != "Snake" else Snake1d
+ mult = int(2 ** len(self.ratios))
+ model: tp.List[nn.Module] = [
+ SConv1d(
+ dimension,
+ mult * n_filters,
+ kernel_size,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ )
+ ]
+
+ if lstm:
+ model += [
+ SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional)
+ ]
+
+ # Upsample to raw audio scale
+ for i, ratio in enumerate(self.ratios):
+ # Add upsampling layers
+ model += [
+ (
+ act(**activation_params)
+ if activation != "Snake"
+ else act(mult * n_filters)
+ ),
+ SConvTranspose1d(
+ mult * n_filters,
+ mult * n_filters // 2,
+ kernel_size=ratio * 2,
+ stride=ratio,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ trim_right_ratio=trim_right_ratio,
+ ),
+ ]
+ # Add residual layers
+ for j in range(n_residual_layers):
+ model += [
+ SEANetResnetBlock(
+ mult * n_filters // 2,
+ kernel_sizes=[residual_kernel_size, 1],
+ dilations=[dilation_base**j, 1],
+ activation=activation,
+ activation_params=activation_params,
+ norm=norm,
+ norm_params=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ compress=compress,
+ true_skip=true_skip,
+ )
+ ]
+
+ mult //= 2
+
+ # Add final layers
+ model += [
+ act(**activation_params) if activation != "Snake" else act(n_filters),
+ SConv1d(
+ n_filters,
+ channels,
+ last_kernel_size,
+ norm=norm,
+ norm_kwargs=norm_params,
+ causal=causal,
+ pad_mode=pad_mode,
+ ),
+ ]
+ # Add an optional final activation to the decoder (e.g. tanh)
+ if final_activation is not None:
+ final_act = getattr(nn, final_activation)
+ final_activation_params = final_activation_params or {}
+ model += [final_act(**final_activation_params)]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, z):
+ y = self.model(z)
+ return y
+
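+# A minimal shape sketch (illustrative, assuming the defaults above and that
+# `torch` is imported): with ratios [8, 5, 4, 2] the hop length is 320, so the
+# decoder maps a latent of shape [B, 128, T] to a waveform [B, 1, T * 320].
+#
+# decoder = SEANetDecoder()
+# z = torch.randn(1, 128, 75) # 75 latent frames
+# y = decoder(z) # y.shape == (1, 1, 24000), i.e. 75 * 320 samples
+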
+
+def test():
+ import torch
+
+ encoder = SEANetEncoder()
+ decoder = SEANetDecoder()
+ x = torch.randn(1, 1, 24000)
+ z = encoder(x)
+ print("z ", z.shape)
+ assert list(z.shape) == [1, 128, 75], z.shape
+ y = decoder(z)
+ assert y.shape == x.shape, (x.shape, y.shape)
+
+
+if __name__ == "__main__":
+ test()
diff --git a/speechtokenizer/pretrained_model/SpeechTokenizer.pt b/speechtokenizer/pretrained_model/SpeechTokenizer.pt
new file mode 100644
index 0000000000000000000000000000000000000000..c81992fd5ac3e9726ff58f659f3115bf032aac32
--- /dev/null
+++ b/speechtokenizer/pretrained_model/SpeechTokenizer.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d04593b6c9a4b475f91ca481141a6ef5b23e6ac112f347dd2b2717f193c1c728
+size 481906997
diff --git a/speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json b/speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..fd35d63489d9b631f4bb419fcba5f74d1f76aa40
--- /dev/null
+++ b/speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json
@@ -0,0 +1,47 @@
+{
+ "resblock": "1",
+ "num_gpus": 3,
+ "batch_size": 60,
+ "learning_rate": 0.0001,
+ "adam_b1": 0.5,
+ "adam_b2": 0.9,
+ "lr_decay": 0.98,
+ "seed": 1234,
+ "lambda_distill": 0.15,
+
+ "n_filters": 64,
+ "strides": [8,5,4,2],
+ "dimension": 1024,
+ "semantic_dimension": 768,
+ "bidirectional": true,
+ "dilation_base": 2,
+ "residual_kernel_size": 3,
+ "n_residual_layers": 1,
+ "lstm_layers": 2,
+ "activation": "ELU",
+
+ "segment_size": 48000,
+ "num_mels": 80,
+ "num_freq": 1025,
+ "n_fft": 1024,
+ "hop_size": 240,
+ "win_size": 1024,
+
+ "sampling_rate": 16000,
+ "sample_rate": 16000,
+
+ "codebook_size": 1024,
+ "n_q": 8,
+
+ "fmin": 0,
+ "fmax": 8000,
+ "fmax_for_loss": null,
+
+ "num_workers": 12,
+
+ "dist_config": {
+ "dist_backend": "nccl",
+ "dist_url": "tcp://localhost:54322",
+ "world_size": 1
+ }
+}
diff --git a/speechtokenizer/quantization/__init__.py b/speechtokenizer/quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfabe52b8cb6f260cdda6137b34df2f4736bd02f
--- /dev/null
+++ b/speechtokenizer/quantization/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .vq import QuantizedResult, ResidualVectorQuantizer
diff --git a/speechtokenizer/quantization/__pycache__/__init__.cpython-310.pyc b/speechtokenizer/quantization/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a9aa18ebd773618b79afb090c24d6ed9ca9b716
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/__init__.cpython-310.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/__init__.cpython-311.pyc b/speechtokenizer/quantization/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65723079defbf8b70d4f8b568141253e2af6b02f
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/__init__.cpython-311.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/__init__.cpython-39.pyc b/speechtokenizer/quantization/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76200d4a860c5b09a9f6f794af1ca96d0eaf7e58
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/__init__.cpython-39.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/core_vq.cpython-310.pyc b/speechtokenizer/quantization/__pycache__/core_vq.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c62dc7ad30c7d70ae643cc3ae638980a01ec73d5
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/core_vq.cpython-310.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/core_vq.cpython-311.pyc b/speechtokenizer/quantization/__pycache__/core_vq.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4acb73bc60e0133ae2d3421300d0922e58eda002
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/core_vq.cpython-311.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/core_vq.cpython-39.pyc b/speechtokenizer/quantization/__pycache__/core_vq.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..359211656337bd3cb468b4005292dc9a36bbdc07
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/core_vq.cpython-39.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/distrib.cpython-310.pyc b/speechtokenizer/quantization/__pycache__/distrib.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35d5648fb781d6b57dacce130f3a4da5a4e96af7
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/distrib.cpython-310.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/distrib.cpython-311.pyc b/speechtokenizer/quantization/__pycache__/distrib.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0cdf666fb73d46c7e3beaeb7f55c23583108dd16
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/distrib.cpython-311.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/distrib.cpython-39.pyc b/speechtokenizer/quantization/__pycache__/distrib.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..069fc2a8a74117c9bfeb955243c55c785cc2d16b
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/distrib.cpython-39.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/vq.cpython-310.pyc b/speechtokenizer/quantization/__pycache__/vq.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..537d9ea38286c2ef6ee329b43e421bd988c69388
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/vq.cpython-310.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/vq.cpython-311.pyc b/speechtokenizer/quantization/__pycache__/vq.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96dec9ab5693e6ffdc62cdfe7210c901c530e461
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/vq.cpython-311.pyc differ
diff --git a/speechtokenizer/quantization/__pycache__/vq.cpython-39.pyc b/speechtokenizer/quantization/__pycache__/vq.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9847026c05a9d3f900fb985c58a658d1ebe4a74a
Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/vq.cpython-39.pyc differ
diff --git a/speechtokenizer/quantization/ac.py b/speechtokenizer/quantization/ac.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0f3e5dcd385cd273a145effa3f53ce7ccfdc74c
--- /dev/null
+++ b/speechtokenizer/quantization/ac.py
@@ -0,0 +1,292 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Arithmetic coder."""
+
+import io
+import math
+import random
+import typing as tp
+import torch
+
+from ..binary import BitPacker, BitUnpacker
+
+
+def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
+ roundoff: float = 1e-8, min_range: int = 2,
+ check: bool = True) -> torch.Tensor:
+ """Turn the given PDF into a quantized CDF that splits
+ [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
+ to the PDF.
+
+ Args:
+ pdf (torch.Tensor): probability distribution, shape should be `[N]`.
+ total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
+ during the coding process is `[0, 2 ** total_range_bits - 1]`.
+ roundoff (float): will round the pdf up to that level to remove difference coming
+ from e.g. evaluating the Language Model on different architectures.
+ min_range (int): minimum range width. Should always be at least 2 for numerical
+ stability. Use this to avoid pathological behavior if a value
+ that is expected to be rare actually happens in real life.
+ check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
+ """
+ pdf = pdf.detach()
+ if roundoff:
+ pdf = (pdf / roundoff).floor() * roundoff
+ # interpolate with uniform distribution to achieve desired minimum probability.
+ total_range = 2 ** total_range_bits
+ cardinality = len(pdf)
+ alpha = min_range * cardinality / total_range
+ assert alpha <= 1, "you must reduce min_range"
+ ranges = (((1 - alpha) * total_range) * pdf).floor().long()
+ ranges += min_range
+ quantized_cdf = torch.cumsum(ranges, dim=-1)
+ if min_range < 2:
+ raise ValueError("min_range must be at least 2.")
+ if check:
+ assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
+ if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
+ raise ValueError("You must increase your total_range_bits.")
+ return quantized_cdf
+
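+# Usage sketch (illustrative): the pdf can be any probability vector, e.g. a
+# model's softmax output over the codebook; the result is an integer-valued,
+# monotonically increasing CDF over [0, 2 ** total_range_bits].
+#
+# pdf = torch.softmax(torch.randn(1024), dim=0)
+# cdf = build_stable_quantized_cdf(pdf, total_range_bits=24)
+# assert bool(cdf[-1] <= 2 ** 24)
+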
+
+class ArithmeticCoder:
+ """ArithmeticCoder,
+ Let us take a distribution `p` over `N` symbols, and assume we have a stream
+ of random variables `s_t` sampled from `p`. Let us assume that we have a budget
+ of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
+ corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
+ sequence `(s_t)` by doing the following:
+
+ 1) Initialize the current range to `[0, 2 ** B - 1]`.
+ 2) For each time step t, split the current range into contiguous chunks,
+ one for each possible outcome, with size roughly proportional to `p`.
+ For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
+ would be `{[0, 2], [3, 3]}`.
+ 3) Select the chunk corresponding to `s_t`, and replace the current range with this.
+ 4) When done encoding all the values, just select any value remaining in the range.
+
+ You will notice that this procedure can fail: for instance if at any point in time
+ the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
+ possible outcome. Intuitively, the more likely a value is, the less the range width
+ will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
+ coding scheme, likely outcomes would take fewer bits, and more of them can be coded
+ with a fixed budget.
+
+ In practice, we do not know `B` ahead of time, but we have a way to inject new bits
+ when the current range decreases below a given limit (given by `total_range_bits`), without
+ having to redo all the computations. If we mostly encode likely values, we will seldom
+ need to inject new bits, but a single rare value can deplete our stock of entropy!
+
+ In this explanation, we assumed that the distribution `p` was constant. In fact, the present
+ code works for any sequence `(p_t)` possibly different for each timestep.
+ We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
+ the KL between the true distribution and `p_t`, the more efficient the coding will be.
+
+ Args:
+ fo (IO[bytes]): file-like object to which the bytes will be written.
+ total_range_bits (int): the range described above is `2 ** total_range_bits`.
+ Any time the current range width falls below this limit, new bits will
+ be injected to rescale the initial range.
+ """
+
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+ assert total_range_bits <= 30
+ self.total_range_bits = total_range_bits
+ self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
+ self.low: int = 0
+ self.high: int = 0
+ self.max_bit: int = -1
+ self._dbg: tp.List[tp.Any] = []
+ self._dbg2: tp.List[tp.Any] = []
+
+ @property
+ def delta(self) -> int:
+ """Return the current range width."""
+ return self.high - self.low + 1
+
+ def _flush_common_prefix(self):
+ # If self.low and self.high start with the same bits,
+ # those won't change anymore as we always just increase the range
+ # by powers of 2, and we can flush them out to the bit stream.
+ assert self.high >= self.low, (self.low, self.high)
+ assert self.high < 2 ** (self.max_bit + 1)
+ while self.max_bit >= 0:
+ b1 = self.low >> self.max_bit
+ b2 = self.high >> self.max_bit
+ if b1 == b2:
+ self.low -= (b1 << self.max_bit)
+ self.high -= (b1 << self.max_bit)
+ assert self.high >= self.low, (self.high, self.low, self.max_bit)
+ assert self.low >= 0
+ self.max_bit -= 1
+ self.packer.push(b1)
+ else:
+ break
+
+ def push(self, symbol: int, quantized_cdf: torch.Tensor):
+ """Push the given symbol on the stream, flushing out bits
+ if possible.
+
+ Args:
+ symbol (int): symbol to encode with the AC.
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+ to build this from your pdf estimate.
+ """
+ while self.delta < 2 ** self.total_range_bits:
+ self.low *= 2
+ self.high = self.high * 2 + 1
+ self.max_bit += 1
+
+ range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
+ range_high = quantized_cdf[symbol].item() - 1
+ effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
+ effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
+ assert self.low <= self.high
+ self.high = self.low + effective_high
+ self.low = self.low + effective_low
+ assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
+ self._dbg.append((self.low, self.high))
+ self._dbg2.append((self.low, self.high))
+ outs = self._flush_common_prefix()
+ assert self.low <= self.high
+ assert self.max_bit >= -1
+ assert self.max_bit <= 61, self.max_bit
+ return outs
+
+ def flush(self):
+ """Flush the remaining information to the stream.
+ """
+ while self.max_bit >= 0:
+ b1 = (self.low >> self.max_bit) & 1
+ self.packer.push(b1)
+ self.max_bit -= 1
+ self.packer.flush()
+
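+# Encoding sketch (illustrative): push each symbol with the quantized CDF the
+# model assigns at that step, then flush once at the end of the stream.
+#
+# fo = io.BytesIO()
+# coder = ArithmeticCoder(fo)
+# pdf = torch.softmax(torch.randn(16), dim=0)
+# cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
+# coder.push(5, cdf)
+# coder.flush()
+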
+
+class ArithmeticDecoder:
+ """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
+
+ Note that this must be called with **exactly** the same parameters and sequence
+ of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
+
+ If the AC encoder current range is [L, H], with `L` and `H` having the same common
+ prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
+ For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
+ `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
+ for a specific sequence of symbols and a binary-search allows us to decode those symbols.
+ At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
+ and we will need to read new bits from the stream and repeat the process.
+
+ """
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
+ self.total_range_bits = total_range_bits
+ self.low: int = 0
+ self.high: int = 0
+ self.current: int = 0
+ self.max_bit: int = -1
+ self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
+ # Following is for debugging
+ self._dbg: tp.List[tp.Any] = []
+ self._dbg2: tp.List[tp.Any] = []
+ self._last: tp.Any = None
+
+ @property
+ def delta(self) -> int:
+ return self.high - self.low + 1
+
+ def _flush_common_prefix(self):
+ # Given the current range [L, H], if both have a common prefix,
+ # we know we can remove it from our representation to avoid handling large numbers.
+ while self.max_bit >= 0:
+ b1 = self.low >> self.max_bit
+ b2 = self.high >> self.max_bit
+ if b1 == b2:
+ self.low -= (b1 << self.max_bit)
+ self.high -= (b1 << self.max_bit)
+ self.current -= (b1 << self.max_bit)
+ assert self.high >= self.low
+ assert self.low >= 0
+ self.max_bit -= 1
+ else:
+ break
+
+ def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
+ """Pull a symbol, reading as many bits from the stream as required.
+ This returns `None` when the stream has been exhausted.
+
+ Args:
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+ to build this from your pdf estimate. This must be **exactly**
+ the same cdf as the one used at encoding time.
+ """
+ while self.delta < 2 ** self.total_range_bits:
+ bit = self.unpacker.pull()
+ if bit is None:
+ return None
+ self.low *= 2
+ self.high = self.high * 2 + 1
+ self.current = self.current * 2 + bit
+ self.max_bit += 1
+
+ def bin_search(low_idx: int, high_idx: int):
+ # Binary search is not just for coding interviews :)
+ if high_idx < low_idx:
+ raise RuntimeError("Binary search failed")
+ mid = (low_idx + high_idx) // 2
+ range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
+ range_high = quantized_cdf[mid].item() - 1
+ effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
+ effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
+ low = effective_low + self.low
+ high = effective_high + self.low
+ if self.current >= low:
+ if self.current <= high:
+ return (mid, low, high, self.current)
+ else:
+ return bin_search(mid + 1, high_idx)
+ else:
+ return bin_search(low_idx, mid - 1)
+
+ self._last = (self.low, self.high, self.current, self.max_bit)
+ sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
+ self._dbg.append((self.low, self.high, self.current))
+ self._flush_common_prefix()
+ self._dbg2.append((self.low, self.high, self.current))
+
+ return sym
+
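+# Decoding sketch (illustrative, continuing the encoding sketch above): rewind
+# the stream and pull symbols back with exactly the same sequence of CDFs.
+#
+# fo.seek(0)
+# decoder = ArithmeticDecoder(fo)
+# symbol = decoder.pull(cdf) # -> 5; returns None once the stream is exhausted
+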
+
+def test():
+ torch.manual_seed(1234)
+ random.seed(1234)
+ for _ in range(4):
+ pdfs = []
+ cardinality = random.randrange(4000)
+ steps = random.randrange(100, 500)
+ fo = io.BytesIO()
+ encoder = ArithmeticCoder(fo)
+ symbols = []
+ for step in range(steps):
+ pdf = torch.softmax(torch.randn(cardinality), dim=0)
+ pdfs.append(pdf)
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
+ symbol = torch.multinomial(pdf, 1).item()
+ symbols.append(symbol)
+ encoder.push(symbol, q_cdf)
+ encoder.flush()
+
+ fo.seek(0)
+ decoder = ArithmeticDecoder(fo)
+ for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
+ decoded_symbol = decoder.pull(q_cdf)
+ assert decoded_symbol == symbol, idx
+ assert decoder.pull(torch.zeros(1)) is None
+
+
+if __name__ == "__main__":
+ test()
diff --git a/speechtokenizer/quantization/core_vq.py b/speechtokenizer/quantization/core_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8f32ab303e884a120863f5e360e2af95c98429f
--- /dev/null
+++ b/speechtokenizer/quantization/core_vq.py
@@ -0,0 +1,385 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# This implementation is inspired from
+# https://github.com/lucidrains/vector-quantize-pytorch
+# which is released under MIT License. Hereafter, the original license:
+# MIT License
+#
+# Copyright (c) 2020 Phil Wang
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""Core vector quantization implementation."""
+import typing as tp
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .distrib import broadcast_tensors, rank
+
+
+def default(val: tp.Any, d: tp.Any) -> tp.Any:
+ return val if val is not None else d
+
+
+def ema_inplace(moving_avg, new, decay: float):
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+def uniform_init(*shape: int):
+ t = torch.empty(shape)
+ nn.init.kaiming_uniform_(t)
+ return t
+
+
+def sample_vectors(samples, num: int):
+ num_samples, device = samples.shape[0], samples.device
+
+ if num_samples >= num:
+ indices = torch.randperm(num_samples, device=device)[:num]
+ else:
+ indices = torch.randint(0, num_samples, (num,), device=device)
+
+ return samples[indices]
+
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10):
+ dim, dtype = samples.shape[-1], samples.dtype
+
+ means = sample_vectors(samples, num_clusters)
+
+ for _ in range(num_iters):
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
+ dists = -(diffs**2).sum(dim=-1)
+
+ buckets = dists.max(dim=-1).indices
+ bins = torch.bincount(buckets, minlength=num_clusters)
+ zero_mask = bins == 0
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+ new_means = new_means / bins_min_clamped[..., None]
+
+ means = torch.where(zero_mask[..., None], means, new_means)
+
+ return means, bins
+
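+# k-means sketch (illustrative): cluster 1000 random 8-dim vectors into 4
+# centroids; `bins` holds the final per-cluster sample counts.
+#
+# samples = torch.randn(1000, 8)
+# means, bins = kmeans(samples, num_clusters=4, num_iters=10)
+# assert means.shape == (4, 8) and bins.shape == (4,)
+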
+
+class EuclideanCodebook(nn.Module):
+ """Codebook with Euclidean distance.
+ Args:
+ dim (int): Dimension.
+ codebook_size (int): Codebook size.
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+ If set to true, run the k-means algorithm on the first training batch and use
+ the learned centroids as initialization.
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ a randomly selected vector from the current batch.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ kmeans_init: bool = False,
+ kmeans_iters: int = 10,
+ decay: float = 0.99,
+ epsilon: float = 1e-5,
+ threshold_ema_dead_code: int = 2,
+ ):
+ super().__init__()
+ self.decay = decay
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
+ uniform_init if not kmeans_init else torch.zeros
+ )
+ embed = init_fn(codebook_size, dim)
+
+ self.codebook_size = codebook_size
+
+ self.kmeans_iters = kmeans_iters
+ self.epsilon = epsilon
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
+ self.register_buffer("embed", embed)
+ self.register_buffer("embed_avg", embed.clone())
+
+ @torch.jit.ignore
+ def init_embed_(self, data):
+ if self.inited:
+ return
+
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+ self.embed.data.copy_(embed)
+ self.embed_avg.data.copy_(embed.clone())
+ self.cluster_size.data.copy_(cluster_size)
+ self.inited.data.copy_(torch.Tensor([True]))
+ # Make sure all buffers across workers are in sync after initialization
+ # broadcast_tensors(self.buffers())
+
+ def replace_(self, samples, mask):
+ modified_codebook = torch.where(
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+ )
+ self.embed.data.copy_(modified_codebook)
+
+ def expire_codes_(self, batch_samples):
+ if self.threshold_ema_dead_code == 0:
+ return
+
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
+ if not torch.any(expired_codes):
+ return
+
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
+ self.replace_(batch_samples, mask=expired_codes)
+ # broadcast_tensors(self.buffers())
+
+ def preprocess(self, x):
+ x = rearrange(x, "... d -> (...) d")
+ return x
+
+ def quantize(self, x):
+ embed = self.embed.t()
+ dist = -(
+ x.pow(2).sum(1, keepdim=True)
+ - 2 * x @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+ embed_ind = dist.max(dim=-1).indices
+ return embed_ind
+
+ def postprocess_emb(self, embed_ind, shape):
+ return embed_ind.view(*shape[:-1])
+
+ def dequantize(self, embed_ind):
+ quantize = F.embedding(embed_ind, self.embed)
+ return quantize
+
+ def encode(self, x):
+ shape = x.shape
+ # pre-process
+ x = self.preprocess(x)
+ # quantize
+ embed_ind = self.quantize(x)
+ # post-process
+ embed_ind = self.postprocess_emb(embed_ind, shape)
+ return embed_ind
+
+ def decode(self, embed_ind):
+ quantize = self.dequantize(embed_ind)
+ return quantize
+
+ def forward(self, x):
+ shape, dtype = x.shape, x.dtype
+ x = self.preprocess(x)
+
+ # self.init_embed_(x)
+
+ embed_ind = self.quantize(x)
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+ embed_ind = self.postprocess_emb(embed_ind, shape)
+ quantize = self.dequantize(embed_ind)
+
+ # if self.training:
+ # # We do the expiry of code at that point as buffers are in sync
+ # # and all the workers will take the same decision.
+ # self.expire_codes_(x)
+ # ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+ # embed_sum = x.t() @ embed_onehot
+ # ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+ # cluster_size = (
+ # laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+ # * self.cluster_size.sum()
+ # )
+ # embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+ # self.embed.data.copy_(embed_normalized)
+
+ return quantize, embed_ind
+
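+# Codebook sketch (illustrative): map a batch of vectors to their nearest code
+# indices and back to the corresponding embeddings.
+#
+# codebook = EuclideanCodebook(dim=8, codebook_size=32)
+# x = torch.randn(4, 10, 8) # [batch, time, dim]
+# idx = codebook.encode(x) # [4, 10] integer code indices
+# x_hat = codebook.decode(idx) # [4, 10, 8] selected codebook vectors
+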
+
+class VectorQuantization(nn.Module):
+ """Vector quantization implementation.
+ Currently supports only euclidean distance.
+ Args:
+ dim (int): Dimension
+ codebook_size (int): Codebook size
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ a randomly selected vector from the current batch.
+ commitment_weight (float): Weight for commitment loss.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ codebook_dim: tp.Optional[int] = None,
+ decay: float = 0.99,
+ epsilon: float = 1e-5,
+ kmeans_init: bool = True,
+ kmeans_iters: int = 50,
+ threshold_ema_dead_code: int = 2,
+ commitment_weight: float = 1.0,
+ ):
+ super().__init__()
+ _codebook_dim: int = default(codebook_dim, dim)
+
+ requires_projection = _codebook_dim != dim
+ self.project_in = (
+ nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
+ )
+ self.project_out = (
+ nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
+ )
+
+ self.epsilon = epsilon
+ self.commitment_weight = commitment_weight
+
+ self._codebook = EuclideanCodebook(
+ dim=_codebook_dim,
+ codebook_size=codebook_size,
+ kmeans_init=kmeans_init,
+ kmeans_iters=kmeans_iters,
+ decay=decay,
+ epsilon=epsilon,
+ threshold_ema_dead_code=threshold_ema_dead_code,
+ )
+ self.codebook_size = codebook_size
+
+ @property
+ def codebook(self):
+ return self._codebook.embed
+
+ def encode(self, x):
+ x = rearrange(x, "b d n -> b n d")
+ x = self.project_in(x)
+ embed_in = self._codebook.encode(x)
+ return embed_in
+
+ def decode(self, embed_ind):
+ quantize = self._codebook.decode(embed_ind)
+ quantize = self.project_out(quantize)
+ quantize = rearrange(quantize, "b n d -> b d n")
+ return quantize
+
+ def forward(self, x):
+ device = x.device
+ x = rearrange(x, "b d n -> b n d")
+ x = self.project_in(x)
+
+ quantize, embed_ind = self._codebook(x)
+
+ if self.training:
+ quantize = x + (quantize - x).detach()
+
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+ if self.training:
+ if self.commitment_weight > 0:
+ commit_loss = F.mse_loss(quantize.detach(), x)
+ loss = loss + commit_loss * self.commitment_weight
+
+ quantize = self.project_out(quantize)
+ quantize = rearrange(quantize, "b n d -> b d n")
+ return quantize, embed_ind, loss
+
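+# Single-stage VQ sketch (illustrative): inputs are channel-first [B, D, T];
+# forward returns the quantized tensor, the code indices and the commitment
+# loss (zero outside of training).
+#
+# vq = VectorQuantization(dim=64, codebook_size=512, kmeans_init=False)
+# x = torch.randn(2, 64, 100)
+# quantized, codes, loss = vq(x) # [2, 64, 100], [2, 100], [1]
+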
+
+class ResidualVectorQuantization(nn.Module):
+ """Residual vector quantization implementation.
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+ """
+
+ def __init__(self, *, num_quantizers, **kwargs):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+ )
+
+ def forward(
+ self,
+ x,
+ n_q: tp.Optional[int] = None,
+ layers: tp.Optional[list] = None,
+ st: tp.Optional[int] = None,
+ ):
+ quantized_out = 0.0
+ residual = x
+
+ all_losses = []
+ all_indices = []
+ out_quantized = []
+ n_q = n_q or len(self.layers)
+ st = st or 0
+ for i, layer in enumerate(self.layers[st:n_q]):
+ quantized, indices, loss = layer(residual)
+ residual = residual - quantized
+ quantized_out = quantized_out + quantized
+
+ all_indices.append(indices)
+ all_losses.append(loss)
+ if layers and i in layers:
+ out_quantized.append(quantized)
+
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+ return quantized_out, out_indices, out_losses, out_quantized
+
+ def encode(
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
+ ) -> torch.Tensor:
+ residual = x
+ all_indices = []
+ n_q = n_q or len(self.layers)
+ st = st or 0
+ for layer in self.layers[st:n_q]:
+ indices = layer.encode(residual)
+ quantized = layer.decode(indices)
+ residual = residual - quantized
+ all_indices.append(indices)
+ out_indices = torch.stack(all_indices)
+ return out_indices
+
+ def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
+ for i, indices in enumerate(q_indices):
+ layer = self.layers[st + i]
+ quantized = layer.decode(indices)
+ quantized_out = quantized_out + quantized
+ return quantized_out
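+
+# Residual VQ sketch (illustrative): each stage quantizes the residual left by
+# the previous one, so `encode` stacks one index map per quantizer and `decode`
+# sums the per-stage reconstructions.
+#
+# rvq = ResidualVectorQuantization(dim=64, codebook_size=512,
+# num_quantizers=4, kmeans_init=False)
+# x = torch.randn(2, 64, 100)
+# codes = rvq.encode(x) # [4, 2, 100]
+# x_hat = rvq.decode(codes) # [2, 64, 100]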
diff --git a/speechtokenizer/quantization/distrib.py b/speechtokenizer/quantization/distrib.py
new file mode 100644
index 0000000000000000000000000000000000000000..4edc88a532c9cf842bd2ee9a56f2c0b5dcb7c83d
--- /dev/null
+++ b/speechtokenizer/quantization/distrib.py
@@ -0,0 +1,126 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch distributed utilities."""
+
+import typing as tp
+
+import torch
+
+
+def rank():
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_rank()
+ else:
+ return 0
+
+
+def world_size():
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_world_size()
+ else:
+ return 1
+
+
+def is_distributed():
+ return world_size() > 1
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+ if is_distributed():
+ return torch.distributed.all_reduce(tensor, op)
+
+
+def _is_complex_or_float(tensor):
+ return torch.is_floating_point(tensor) or torch.is_complex(tensor)
+
+
+def _check_number_of_params(params: tp.List[torch.Tensor]):
+ # utility function to check that the number of params in all workers is the same,
+ # and thus avoid a deadlock with distributed all reduce.
+ if not is_distributed() or not params:
+ return
+ #print('params[0].device ', params[0].device)
+ tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
+ all_reduce(tensor)
+ if tensor.item() != len(params) * world_size():
+ # If not all the workers have the same number, for at least one of them,
+ # this inequality will be verified.
+ raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
+ "at least one worker has a different one.")
+
+
+def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
+ """Broadcast the tensors from the given parameters to all workers.
+ This can be used to ensure that all workers have the same model to start with.
+ """
+ if not is_distributed():
+ return
+ tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
+ _check_number_of_params(tensors)
+ handles = []
+ for tensor in tensors:
+ # src = int(rank()) # added code
+ handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
+ handles.append(handle)
+ for handle in handles:
+ handle.wait()
+
+
+def sync_buffer(buffers, average=True):
+ """
+ Sync buffers across workers. If average is False, broadcast from rank 0 instead of averaging.
+ """
+ if not is_distributed():
+ return
+ handles = []
+ for buffer in buffers:
+ if torch.is_floating_point(buffer.data):
+ if average:
+ handle = torch.distributed.all_reduce(
+ buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+ else:
+ handle = torch.distributed.broadcast(
+ buffer.data, src=0, async_op=True)
+ handles.append((buffer, handle))
+ for buffer, handle in handles:
+ handle.wait()
+ if average:
+ buffer.data /= world_size()
+
+
+def sync_grad(params):
+ """
+ Simpler alternative to DistributedDataParallel, that doesn't rely
+ on any black magic. For simple models it can also be as fast.
+ Just call this on your model parameters after the call to backward!
+ """
+ if not is_distributed():
+ return
+ handles = []
+ for p in params:
+ if p.grad is not None:
+ handle = torch.distributed.all_reduce(
+ p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+ handles.append((p, handle))
+ for p, handle in handles:
+ handle.wait()
+ p.grad.data /= world_size()
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.):
+ """Average a dictionary of metrics across all workers, using the optional
+ `count` as unnormalized weight.
+ """
+ if not is_distributed():
+ return metrics
+ keys, values = zip(*metrics.items())
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
+ tensor *= count
+ all_reduce(tensor)
+ averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
+ return dict(zip(keys, averaged))
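+
+# Usage sketch (illustrative): in a single-process run these helpers are
+# no-ops, so training code can call them unconditionally. `model` and
+# `batch_count` below are placeholders.
+#
+# sync_grad(model.parameters()) # right after loss.backward()
+# metrics = average_metrics({"loss": 0.17}, count=batch_count)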
diff --git a/speechtokenizer/quantization/vq.py b/speechtokenizer/quantization/vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f6e65f99c2743b5b1a1ca507ce57ad4e9132da1
--- /dev/null
+++ b/speechtokenizer/quantization/vq.py
@@ -0,0 +1,120 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Residual vector quantizer implementation."""
+
+from dataclasses import dataclass, field
+import math
+import typing as tp
+
+import torch
+from torch import nn
+
+from .core_vq import ResidualVectorQuantization
+
+
+@dataclass
+class QuantizedResult:
+ quantized: torch.Tensor
+ codes: torch.Tensor
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
+ penalty: tp.Optional[torch.Tensor] = None
+ metrics: dict = field(default_factory=dict)
+
+
+class ResidualVectorQuantizer(nn.Module):
+ """Residual Vector Quantizer.
+ Args:
+ dimension (int): Dimension of the codebooks.
+ n_q (int): Number of residual vector quantizers used.
+ bins (int): Codebook size.
+ decay (float): Decay for exponential moving average over the codebooks.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ a randomly selected vector from the current batch.
+ """
+
+ def __init__(
+ self,
+ dimension: int = 256,
+ n_q: int = 8,
+ bins: int = 1024,
+ decay: float = 0.99,
+ kmeans_init: bool = True,
+ kmeans_iters: int = 50,
+ threshold_ema_dead_code: int = 2,
+ ):
+ super().__init__()
+ self.n_q = n_q
+ self.dimension = dimension
+ self.bins = bins
+ self.decay = decay
+ self.kmeans_init = kmeans_init
+ self.kmeans_iters = kmeans_iters
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+ self.vq = ResidualVectorQuantization(
+ dim=self.dimension,
+ codebook_size=self.bins,
+ num_quantizers=self.n_q,
+ decay=self.decay,
+ kmeans_init=self.kmeans_init,
+ kmeans_iters=self.kmeans_iters,
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ n_q: tp.Optional[int] = None,
+ layers: tp.Optional[list] = None,
+ st: tp.Optional[int] = None,
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, tp.List[torch.Tensor]]:
+ """Residual vector quantization on the given input tensor.
+ Args:
+ x (torch.Tensor): Input tensor.
+ n_q (int): Number of quantizers used to quantize. Default: all quantizers.
+ layers (list): Indices of the layers whose quantized output should also be returned. Default: None.
+ st (int): Index of the first quantizer layer to use. Default: 0.
+ Returns:
+ A tuple of the quantized representation, the codes for each quantizer,
+ the mean commitment loss, and the per-layer quantized outputs requested
+ via `layers`.
+ """
+ n_q = n_q if n_q else self.n_q
+ if layers and max(layers) >= n_q:
+ raise ValueError(
+ f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
+ )
+ quantized, codes, commit_loss, quantized_list = self.vq(
+ x, n_q=n_q, layers=layers, st=st
+ )
+ return quantized, codes, torch.mean(commit_loss), quantized_list
+
+ def encode(
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
+ ) -> torch.Tensor:
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
+ The RVQ encode method sets the appropriate number of quantizer to use
+ and returns indices for each quantizer.
+ Args:
+ x (torch.Tensor): Input tensor.
+ n_q (int): Number of quantizer used to quantize. Default: All quantizers.
+ st (int): Start to encode input from which layers. Default: 0.
+ """
+ n_q = n_q if n_q else self.n_q
+ st = st or 0
+ codes = self.vq.encode(x, n_q=n_q, st=st)
+ return codes
+
+ def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
+ """Decode the given codes to the quantized representation.
+ Args:
+ codes (torch.Tensor): Input indices for each quantizer.
+ st (int): Index of the first quantizer layer the codes correspond to. Default: 0.
+ """
+ quantized = self.vq.decode(codes, st=st)
+ return quantized
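+
+# End-to-end sketch (illustrative, matching the 1024-dim / 8-quantizer setup in
+# the bundled SpeechTokenizer config): quantize an encoder output, keep only
+# the codes, and reconstruct the latent from them.
+#
+# rvq = ResidualVectorQuantizer(dimension=1024, n_q=8, bins=1024)
+# z = torch.randn(1, 1024, 50) # [B, D, frames]
+# codes = rvq.encode(z) # [8, 1, 50]
+# z_hat = rvq.decode(codes) # [1, 1024, 50]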
diff --git a/voicemark.pth b/voicemark.pth
new file mode 100644
index 0000000000000000000000000000000000000000..4ab3d9eb9e512f435625c3de8b87762da9ccb11c
--- /dev/null
+++ b/voicemark.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff9975b18ee175ad05ac0edac88583db468fd56867b3bcdd2efa6584de85bfe4
+size 86430822
diff --git a/voicemark_overview.png b/voicemark_overview.png
new file mode 100644
index 0000000000000000000000000000000000000000..e565824fa3f4c229dc3323ee47f0c5f72ab99233
--- /dev/null
+++ b/voicemark_overview.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddb86c3c1b83ef3a14dbe3c8a1960718534733d53a406f8ea00222ffd5fd989d
+size 297901