diff --git a/README.md b/README.md index 56715545d8657e9074cf9161fa8d984634e29b45..892883c05dde4e31fbd21349642716e77f442232 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ --- title: VoiceMark -emoji: 📚 -colorFrom: blue -colorTo: blue +emoji: 📊 +colorFrom: purple +colorTo: green sdk: gradio -sdk_version: 5.30.0 +sdk_version: 5.16.0 app_file: app.py pinned: false -license: apache-2.0 -short_description: Zero-Shot Voice Cloning-Resistant Watermarking Approach Leve +license: gpl +short_description: Zero-Shot Voice Cloning-Resistant Watermarking --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..db6a9252be66cbf9a963be863b107dcb2d227ac8 --- /dev/null +++ b/app.py @@ -0,0 +1,137 @@ +import gradio as gr +import json +import os +import torchaudio +from infer import ( + WatermarkSolver, + hamming_distance +) + +# Predefined watermarks (instead of loading from a JSON file) +watermarks = { + "VoiceMark": "1000010101010011", + "Voice Cloning": "1111111001000010", + "Speech Security": "1011101100001110", + "Audio Watermarking": "0110110011100010", + "Deep Learning": "0000100111111000", + "Artificial Intelligence": "0010000100011111", + "Hello World": "0001111101110001", + "Happy New Year": "1101011011011101", + "World Peace": "0011110010011110", + "Good Morning": "0000001011000010", +} + +# Initialize WatermarkSolver model +solver = WatermarkSolver() +solver.load_model(checkpoint_dir="./", checkpoint_name="voicemark.pth", strict=True) + +# Gradio interface +with gr.Blocks() as demo: + gr.Markdown( + "# VoiceMark: Zero-Shot Voice Cloning-Resistant Watermarking Approach Leveraging Speaker-Specific Latents" + ) + with gr.Column(): + gr.Image( + value="voicemark_overview.png", + width=925, + height=487, + elem_id="overview_image", + label="overview" + ) + # Step 1: Upload audio and select watermark + gr.HTML("

<div style="text-align: center;">The overall architecture of our proposed VoiceMark</div>
") + + # Step 1: Upload audio and select watermark + gr.Markdown( + """ + **Step 1**: Upload an audio file or select one from the provided samples, choose a watermark, and generate the watermarked audio. + """ + ) + + with gr.Row(): + with gr.Column(): + audio_input = gr.Audio(label="Upload Audio", type="filepath") + + gr.Examples( + examples=[ + ["audios/1.wav"], + ["audios/2.wav"], + ["audios/3.wav"], + ["audios/4.wav"], + ["audios/5.wav"], + ], + inputs=audio_input, + label="Sample Audios (Click to Use)" + ) + + with gr.Column(): + audio_output = gr.Audio(label="Watermarked Audio", type="filepath") + watermark_list = gr.Dropdown( + label="Select Watermark", choices=list(watermarks.keys()), interactive=True + ) + add_watermark_button = gr.Button("Add Watermark to Audio") + + # Step 2: TTS tools demo links + gr.Markdown( + """ + **Step 2**: Download the generated watermarked audio, then use Zero-Shot Voice Cloning tools to generate the cloned audio. Some available tools are: + - [CosyVoice2: Scalable Streaming Speech Synthesis with Large Language Models](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B) + - [F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching](https://huggingface.co/spaces/mrfakename/E2-F5-TTS) + - [MaskGCT: Zero-Shot Text-to-Speech with Masked Generative Codec Transformer](https://huggingface.co/spaces/amphion/maskgct) + """ + ) + + # Step 3: Upload cloned audio to decode watermark + gr.Markdown( + """ + **Step 3**: Upload the cloned audio and decode your watermark. + """ + ) + + with gr.Row(): + decode_audio_input = gr.Audio(label="Upload Cloned Audio", type="filepath") + with gr.Column(): + decoded_watermark_output = gr.Textbox(label="Decoded Watermark") + decode_button = gr.Button("Decode Watermark") + + def process_audio(audio_path, watermark_text): + if not audio_path: + return "No audio selected. Please upload or select a sample." 
+ try: + watermarked_audio = solver.infer_for_ui( + audio_path, watermarks[watermark_text] + ) + return watermarked_audio + except ValueError as e: + return str(e) + + add_watermark_button.click( + process_audio, + inputs=[audio_input, watermark_list], + outputs=audio_output + ) + + def decode_watermark(audio_path): + try: + detect_prob, decoded_id = solver.decode_for_ui(audio_path) + if detect_prob < 1e-2: + return "No matching watermark found" + closest_match = None + min_distance = float("inf") + for text, id_bin in watermarks.items(): + distance = hamming_distance(decoded_id, id_bin, base=16) + if distance < min_distance: + closest_match = text + min_distance = distance + if min_distance < 10: + return closest_match + return "No matching watermark found" + except ValueError as e: + return str(e) + + decode_button.click( + decode_watermark, inputs=decode_audio_input, outputs=decoded_watermark_output + ) + +# Launch the Gradio app +demo.launch() diff --git a/audios/1.wav b/audios/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..8bc9491ca8b637dce252df5b56025b8e6c76b58f --- /dev/null +++ b/audios/1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb97a4d6a89ce3cfff4d504de24ef4d79658cc68cc765da7c905b30ecb5e9f1d +size 160078 diff --git a/audios/2.wav b/audios/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..e584955749e912955badac8e7b95eecb7ee5d581 --- /dev/null +++ b/audios/2.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5cf698739bad50c02ab96375bed49594a0b1a4421406b6ed10edc432f09dc94 +size 152078 diff --git a/audios/3.wav b/audios/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..a2cb187543ddf7d7ebad1ed521576bfdb5c5d9e8 --- /dev/null +++ b/audios/3.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4673c2f94dec181e7603df7564af0778d6589a7d7f4878f30a9c42bdfc2324f7 +size 160078 diff --git a/audios/4.wav b/audios/4.wav new file mode 100644 index 0000000000000000000000000000000000000000..f9bdb6702975b272f60c98cc2c8f209c6a11f640 --- /dev/null +++ b/audios/4.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10c63ef8e2866b7e4b98d15cbbc084da667b350ab4d084f338b656661b619ceb +size 98750 diff --git a/audios/5.wav b/audios/5.wav new file mode 100644 index 0000000000000000000000000000000000000000..650cf3edb5d1b6007b77a980945a6960bcd54498 --- /dev/null +++ b/audios/5.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:049b7061915306d541d2d8c059c6c9599502997562bc32120f16a873a55e499e +size 63072 diff --git a/infer.py b/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a5d0cfe9fea48e7e12659d77d5ca5a6673e758d --- /dev/null +++ b/infer.py @@ -0,0 +1,175 @@ +from models import SBW +import torchaudio +import torch +import os + +def hamming_distance(s1, s2, base=2): + """ + Calculate the absolute difference between two strings interpreted in chunks of a specified base. + + Args: + s1 (str): The first binary string. + s2 (str): The second binary string. + base (int): The base to interpret the binary strings (e.g., 2 for binary, 16 for hexadecimal). + + Returns: + int: The sum of absolute differences for corresponding chunks. + + Raises: + ValueError: If the strings are not of equal length or invalid for the given base. 
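
A worked example of the chunked distance this function computes when called with `base=16`, as `decode_watermark` above does (illustrative values; the first ID is the predefined "VoiceMark" watermark):

```python
# With base=16 the 16-bit IDs are split into four 4-bit chunks, each chunk is read
# as an integer, and the absolute differences of corresponding chunks are summed.
s1 = "1000010101010011"                                       # "VoiceMark" ID
s2 = "1000010101010111"                                       # one decoded bit differs
chunks1 = [int(s1[i:i + 4], 2) for i in range(0, 16, 4)]      # [8, 5, 5, 3]
chunks2 = [int(s2[i:i + 4], 2) for i in range(0, 16, 4)]      # [8, 5, 5, 7]
distance = sum(abs(a - b) for a, b in zip(chunks1, chunks2))  # 4, below the match threshold of 10
```

A single flipped high-order bit within a chunk contributes up to 8 to this distance, so the threshold of 10 used in `decode_watermark` tolerates roughly one badly corrupted chunk.
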
+ """ + if len(s1) != len(s2): + raise ValueError("Both strings must be of equal length") + + # Determine the chunk size for the given base + import math + + chunk_size = int(math.log(base, 2)) + + if len(s1) % chunk_size != 0 or len(s2) % chunk_size != 0: + raise ValueError( + f"Binary strings must be a multiple of {chunk_size} bits for base {base}" + ) + + # Split binary strings into chunks + def to_chunks(binary_str): + return [ + binary_str[i : i + chunk_size] + for i in range(0, len(binary_str), chunk_size) + ] + + chunks_s1 = to_chunks(s1) + chunks_s2 = to_chunks(s2) + + # Convert chunks to integers and calculate absolute difference + def absolute_difference_chunks(chunk1, chunk2): + int1 = int(chunk1, 2) + int2 = int(chunk2, 2) + return abs(int1 - int2) + + return sum( + absolute_difference_chunks(c1, c2) for c1, c2 in zip(chunks_s1, chunks_s2) + ) + +def random_message(nbits: int, batch_size: int) -> torch.Tensor: + """Return random message as 0/1 tensor.""" + if nbits == 0: + return torch.tensor([]) + return torch.randint(0, 2, (batch_size, nbits)) + + +def string_to_message(message_str: str, batch_size: int) -> torch.Tensor: + """ + Convert a binary string to a message tensor. + + Args: + message_str (str): A string of '0's and '1's. + batch_size (int): The batch size for the output tensor. + + Returns: + torch.Tensor: A tensor of shape (batch_size, nbits) where nbits is the length of message_str. + """ + # Convert the string to a list of integers (0 and 1) + message_list = [int(bit) for bit in message_str] + nbits = len(message_list) + # Convert the list to a tensor and repeat it for the batch size + message_tensor = torch.tensor(message_list).unsqueeze(0).repeat(batch_size, 1) + return message_tensor + + +class WatermarkSolver: + def __init__(self): + self.device = 'cpu' + self.sample_rate = 16000 + self.nbits = 16 + self.model = SBW() + + def load_model( + self, checkpoint_dir, checkpoint_name, strict=False + ): + """ + Load the latest model weights from the checkpoint directory. 
+ """ + checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")] + if not checkpoint_files: + print("No checkpoint files found in the directory.") + return + + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + print(f"Loading model weights from {checkpoint_path}...") + + # Load the checkpoint + checkpoint = torch.load(checkpoint_path,map_location=torch.device('cpu'),weights_only=False) + + # Load model state dict with strict=False to allow missing keys (semantic_encoder) + model_state_dict = checkpoint["model_state_dict"] + + new_state_dict = {} + for k, v in model_state_dict.items(): + new_key = k.replace("module.", "") + new_state_dict[new_key] = v + if not strict: + new_state_dict = { + k: v + for k, v in new_state_dict.items() + if k in self.model.state_dict() + and self.model.state_dict()[k].shape == v.shape + } + self.model.load_state_dict(new_state_dict, strict=False) + + print("Model state dict loaded successfully.") + self.epoch = checkpoint["epoch"] + print(f"Checkpoint loaded from epoch {checkpoint['epoch']}.") + + def infer_for_ui(self, input_audio_path, watermark, output_audio_path="tmp"): + message_str = watermark + self.model.eval() + + waveform, sample_rate = torchaudio.load(input_audio_path) + + if sample_rate != self.sample_rate: + resampler = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=self.sample_rate + ) + waveform = resampler(waveform) + + waveform = waveform[:1].to(self.device).unsqueeze(1) + message = string_to_message(message_str=message_str, batch_size=1).to( + self.device + ) + + with torch.no_grad(): + output = self.model( + waveform, + message=message, + ) + y_wm = output["recon_wm"] + + os.makedirs(output_audio_path, exist_ok=True) + watermarked_audio_path = os.path.join( + output_audio_path, + f"{os.path.splitext(os.path.basename(input_audio_path))[0]}_watermarked.wav", + ) + torchaudio.save(watermarked_audio_path, y_wm.squeeze(1).cpu(), self.sample_rate) + return watermarked_audio_path + + def decode_for_ui(self, input_audio_path): + self.model.eval() + + waveform, sample_rate = torchaudio.load(input_audio_path) + + if sample_rate != self.sample_rate: + resampler = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=self.sample_rate + ) + waveform = resampler(waveform) + + waveform = waveform[:1].to(self.device).unsqueeze(1) + + with torch.no_grad(): + detect_prob, detected_message_tensor, _ = self.model.detect_watermark(waveform) + detected_id = "".join( + map(str, detected_message_tensor.squeeze(0).cpu().numpy().tolist()) + ) + + return detect_prob,detected_id diff --git a/models.py b/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c197b7824e64a92834f53810468c5c67534b62d9 --- /dev/null +++ b/models.py @@ -0,0 +1,299 @@ +import torch +import torch.nn as nn +from typing import Dict, Optional, Tuple + +import torch.nn.functional as F +import torch +import torch.nn as nn + + +class WMDetector(nn.Module): + """ + Detect watermarks in an audio signal using a Transformer architecture, + where the watermark bits are split into bytes (8 bits each). + We assume nbits is a multiple of 8. + """ + + def __init__( + self, input_channels: int, nbits: int, nchunk_size: int, d_model: int = 512 + ): + """ + Args: + input_channels (int): Number of input channels in the audio feature (e.g., mel channels). + nbits (int): Total number of bits in the watermark, must be a multiple of 8. + d_model (int): Embedding dimension for the Transformer. 
+ """ + super().__init__() + self.nchunk_size = nchunk_size + assert nbits % nchunk_size == 0, "nbits must be a multiple of 8!" + self.nbits = nbits + self.d_model = d_model + # Number of bytes + self.nchunks = nbits // nchunk_size + + # 1D convolution to map the input channels to d_model + self.embedding = nn.Conv1d(input_channels, d_model, kernel_size=1) + + # Transformer encoder block + self.transformer = nn.TransformerEncoder( + nn.TransformerEncoderLayer( + d_model=d_model, + nhead=1, + dim_feedforward=d_model * 2, + activation="gelu", + batch_first=True, + ), + num_layers=8, + ) + + # A linear head for watermark presence detection (binary) + self.watermark_head = nn.Linear(d_model, 1) + + # For each byte, we perform a 256-way classification + self.message_heads = nn.ModuleList( + nn.Linear(d_model, 2**nchunk_size) for _ in range(self.nchunks) + ) + + # Learnable embeddings for each byte chunk (instead of per bit) + # Shape: [nchunks, d_model] + self.nchunk_embeddings = nn.Parameter(torch.randn(self.nchunks, d_model)) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of the detector. + + Returns: + logits (torch.Tensor): Watermark detection logits of shape [batch, seq_len]. + chunk_logits (torch.Tensor): Byte-level classification logits of shape [batch, nchunks, 256]. + """ + batch_size, input_channels, time_steps = x.shape + + # 1) Map [batch, in_channels, time_steps] → [batch, time_steps, d_model] + x = self.embedding(x).permute(0, 2, 1) # [batch, time_steps, d_model] + + # 2) Prepend chunk embeddings at the beginning of the sequence + # [nchunks, d_model] → [1, nchunks, d_model] → [batch, nchunks, d_model] + nchunk_embeds = self.nchunk_embeddings.unsqueeze(0).expand(batch_size, -1, -1) + # Concatenate along the time dimension: [batch, nchunks + time_steps, d_model] + x = torch.cat([nchunk_embeds, x], dim=1) + + # 3) Pass through the Transformer + x = self.transformer(x) + # x has shape [batch, nchunks + time_steps, d_model] + + # (a) Watermark presence detection: skip the first nchunks + detection_part = x[:, self.nchunks :] # [batch, time_steps, d_model] + logits = self.watermark_head(detection_part).squeeze(-1) # [batch, time_steps] + + # (b) Message decoding: use the first nchunks + message_part = x[:, : self.nchunks] # [batch, nchunks, d_model] + chunk_logits_list = [] + for i, head in enumerate(self.message_heads): + # message_part[:, i, :] has shape [batch, d_model] + # each head outputs [batch, 256] + chunk_vec = message_part[:, i, :] + chunk_logits_list.append(head(chunk_vec).unsqueeze(1)) # [batch, 1, 256] + + # Concatenate along the 'nchunks' dimension → [batch, nchunks, 256] + chunk_logits = torch.cat(chunk_logits_list, dim=1) + + return logits, chunk_logits + + def detect_watermark( + self, + x: torch.Tensor, + sample_rate: Optional[int] = None, + threshold: float = 0.5, + ) -> Tuple[float, torch.Tensor, torch.Tensor]: + """ + A convenience function for inference. + + Returns: + detect_prob (float): Probability that the audio is watermarked. + binary_message (torch.Tensor): The recovered message of shape [batch, nbits] (binary). + detected (torch.Tensor): The sigmoid values of the per-timestep watermark detection. 
+ """ + logits, chunk_logits = self.forward(x) + # logits: [batch, seq_len] → raw logits for watermark presence detection + # chunk_logits: [batch, nchunks, 256] → classification logits for each byte + + # (1) Compute watermark detection probability + detected = torch.sigmoid(logits) # [batch, seq_len] + detect_prob = detected.mean(dim=-1).cpu().item() + + # (2) Decode the message: chunk_logits has shape [batch, nchunks, 256] + chunk_probs = F.softmax(chunk_logits, dim=-1) # [batch, nchunks, 256] + chunk_indices = torch.argmax( + chunk_probs, dim=-1 + ) # [batch, nchunks], each in [0..255] + # (3) Convert each byte back to 8 bits + # Finally, assemble into a [batch, nbits] binary tensor + binary_message = [] + for i in range(self.nchunks): + chunk_val = chunk_indices[:, i] # [batch] + # Extract 8 bits from the integer (0..255) + chunk_bits = [] + for b in range(self.nchunk_size): + bit_b = (chunk_val >> b) & 1 # get bit b + chunk_bits.append(bit_b.unsqueeze(-1)) + # Concatenate bits to shape [batch, 8] + chunk_bits = torch.cat(chunk_bits, dim=-1) + binary_message.append(chunk_bits) + + # Concatenate all bytes → [batch, nbits] + binary_message = torch.cat(binary_message, dim=-1) + + return detect_prob, binary_message, detected + + + +class WMEmbedder(nn.Module): + """ + A class that takes a secret message, processes it into chunk embeddings + (as a small sequence), and uses a TransformerDecoder to do cross-attention + between the original hidden (target) and the watermark tokens (memory). + """ + + def __init__( + self, + nbits: int, # total bits in the secret message + input_dim: int, # the input dimension (e.g. audio feature dimension) + nchunk_size: int, + hidden_dim: int = 256, + num_heads: int = 1, + num_layers: int = 4, + ): + super().__init__() + self.nchunk_size = nchunk_size + assert nbits % nchunk_size == 0, "nbits must be a multiple of nchunk_size!" + self.nbits = nbits + self.nchunks = nbits // nchunk_size # how many chunks + + # Each chunk (0..2^nchunk_size - 1) maps to an embedding of size [hidden_dim] + self.msg_embeddings = nn.ModuleList( + nn.Embedding(2**nchunk_size, hidden_dim) for _ in range(self.nchunks) + ) + + # Linear to project [input_dim] -> [hidden_dim] + self.input_projection = nn.Linear(input_dim, hidden_dim) + + # TransformerDecoder for cross-attention + # d_model=hidden_dim, so the decoder expects [b, seq_len, hidden_dim] as tgt + # and [b, memory_len, hidden_dim] as memory + decoder_layer = nn.TransformerDecoderLayer( + d_model=hidden_dim, + nhead=num_heads, + dim_feedforward=2 * hidden_dim, + activation="gelu", + batch_first=True, # so shape is [batch, seq, feature] + ) + self.transformer_decoder = nn.TransformerDecoder( + decoder_layer, num_layers=num_layers + ) + + # Project [hidden_dim] -> [input_dim] + # self.output_projection1 = nn.Linear(hidden_dim * 2, hidden_dim) + self.output_projection = nn.Linear(hidden_dim, input_dim) + + def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden: [batch, input_dim, seq_len] + msg: [batch, nbits] + Returns: + A tensor [batch, input_dim, seq_len] with watermark injected. 
+ """ + b, in_dim, seq_len = hidden.shape + + # 1) Project input features to [b, seq_len, hidden_dim] + hidden_projected = self.input_projection( + hidden.permute(0, 2, 1) + ) # => [b, seq_len, hidden_dim] + + # 2) Convert the msg bits into a sequence of chunk embeddings + # We keep each chunk as one token => [b, nchunks, hidden_dim] + chunk_emb_list = [] + for i in range(self.nchunks): + # msg[:, i*nchunk_size : (i+1)*nchunk_size] => shape [b, nchunk_size] + chunk_bits = msg[:, i * self.nchunk_size : (i + 1) * self.nchunk_size] + chunk_val = torch.zeros_like(chunk_bits[:, 0]) # shape [b] + for bit_idx in range(self.nchunk_size): + # shift bits + chunk_val += chunk_bits[:, bit_idx] << bit_idx + + # embedding => [b, hidden_dim] + chunk_emb = self.msg_embeddings[i](chunk_val) + chunk_emb_list.append(chunk_emb.unsqueeze(1)) # => [b,1,hidden_dim] + + # Concat => [b, nchunks, hidden_dim] + chunk_emb_seq = torch.cat(chunk_emb_list, dim=1) # [b, nchunks, hidden_dim] + + # 3) Use chunk_emb_seq as memory, hidden_projected as target for TransformerDecoder + # + # TransformerDecoder forward signature: + # transformer_decoder(tgt, memory, ...) + # => [b, seq_len, hidden_dim] + x_decoded = self.transformer_decoder( + tgt=hidden_projected, # [b, seq_len, hidden_dim] + memory=chunk_emb_seq, # [b, nchunks, hidden_dim] + ) + + # 4) Project back to input_dim => [b, seq_len, input_dim] + x_output = self.output_projection(x_decoded) + + # 5) permute back to [b, input_dim, seq_len] + x_output = x_output.permute(0, 2, 1) # => [b, input_dim, seq_len] + + # 6) (Optional) Residual with original hidden + x_output = x_output + hidden + + return x_output + + +from speechtokenizer import SpeechTokenizer + + +class SBW(nn.Module): + def __init__(self): + super().__init__() + self.nbits = 16 + config_path = ( + "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json" + ) + ckpt_path = "speechtokenizer/pretrained_model/SpeechTokenizer.pt" + self.st_model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path) + self.msg_processor = WMEmbedder( + nbits=16, + input_dim=1024, + nchunk_size=4, + ) + self.detector = WMDetector( + 1024, + 16, + nchunk_size=4, + ) + + def detect_watermark( + self, x: torch.Tensor, return_logits=False + ) -> Tuple[float, torch.Tensor]: + embedding = self.st_model.forward_feature(x) + if return_logits: + return self.detector(embedding) + return self.detector.detect_watermark(embedding) + + def forward( + self, + speech_input: torch.Tensor, + message: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + recon, recon_wm, acoustic, acoustic_wm = self.st_model( + speech_input, msg_processor=self.msg_processor, message=message + ) + wav_length = min(speech_input.size(-1), recon_wm.size(-1)) + speech_input = speech_input[..., :wav_length] + recon = recon[..., :wav_length] + recon_wm = recon_wm[..., :wav_length] + return { + "recon": recon, + "recon_wm": recon_wm, + } diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1220e33b5d43e1e68b2ef3d599f7630434d02b95 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +torchaudio +beartype +einops +omegaconf \ No newline at end of file diff --git a/speechtokenizer/__init__.py b/speechtokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..390bd890dffbf95ce01d141eb5013177dcea468c --- /dev/null +++ b/speechtokenizer/__init__.py @@ -0,0 +1,3 @@ +from .model import SpeechTokenizer + +__version__ = '1.0.0' \ No newline at end of 
file diff --git a/speechtokenizer/model.py b/speechtokenizer/model.py new file mode 100644 index 0000000000000000000000000000000000000000..cec8a020639e1107b283268c70e5305d288da46c --- /dev/null +++ b/speechtokenizer/model.py @@ -0,0 +1,215 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Aug 30 15:47:55 2023 +@author: zhangxin +""" + +import random +from models import WMEmbedder +from .modules.seanet import SEANetEncoder, SEANetDecoder +from .quantization import ResidualVectorQuantizer +import torch.nn as nn +from einops import rearrange +import torch +import numpy as np + + +class SpeechTokenizer(nn.Module): + def __init__(self, config): + """ + + Parameters + ---------- + config : json + Model Config. + + """ + super().__init__() + self.encoder = SEANetEncoder( + n_filters=config.get("n_filters"), + dimension=config.get("dimension"), + ratios=config.get("strides"), + lstm=config.get("lstm_layers"), + bidirectional=config.get("bidirectional"), + dilation_base=config.get("dilation_base"), + residual_kernel_size=config.get("residual_kernel_size"), + n_residual_layers=config.get("n_residual_layers"), + activation=config.get("activation"), + ) + self.sample_rate = config.get("sample_rate") + self.n_q = config.get("n_q") + self.downsample_rate = np.prod(config.get("strides")) + if config.get("dimension") != config.get("semantic_dimension"): + self.transform = nn.Linear( + config.get("dimension"), config.get("semantic_dimension") + ) + else: + self.transform = nn.Identity() + self.quantizer = ResidualVectorQuantizer( + dimension=config.get("dimension"), + n_q=config.get("n_q"), + bins=config.get("codebook_size"), + ) + self.decoder = SEANetDecoder( + n_filters=config.get("n_filters"), + dimension=config.get("dimension"), + ratios=config.get("strides"), + lstm=config.get("lstm_layers"), + bidirectional=False, + dilation_base=config.get("dilation_base"), + residual_kernel_size=config.get("residual_kernel_size"), + n_residual_layers=config.get("n_residual_layers"), + activation=config.get("activation"), + ) + + @classmethod + def load_from_checkpoint(cls, config_path: str, ckpt_path: str): + """ + + Parameters + ---------- + config_path : str + Path of model configuration file. + ckpt_path : str + Path of model checkpoint. + + Returns + ------- + model : SpeechTokenizer + SpeechTokenizer model. + + """ + import json + + with open(config_path) as f: + cfg = json.load(f) + model = cls(cfg) + params = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(params) + return model + + def forward( + self, + x: torch.tensor, + n_q: int = None, + layers: list = [0], + msg_processor: WMEmbedder = None, + message: torch.Tensor = None, + ): + """ + + Parameters + ---------- + x : torch.tensor + Input wavs. Shape: (batch, channels, timesteps). + n_q : int, optional + Number of quantizers in RVQ used to encode. The default is all layers. + layers : list[int], optional + Layers of RVQ should return quantized result. The default is the first layer. + + Returns + ------- + o : torch.tensor + Output wavs. Shape: (batch, channels, timesteps). + commit_loss : torch.tensor + Commitment loss from residual vector quantizers. + feature : torch.tensor + Output of RVQ's first layer. 
Shape: (batch, timesteps, dimension) + + """ + with torch.no_grad(): + e = self.encoder(x) + quantized_full, _, _, quantized_list = self.quantizer( + e, n_q=n_q, layers=[0, 1, 2, 3, 4, 5, 6, 7], st=0 + ) + # semantic, _, _, _ = self.quantizer(e, n_q=1, st=0) + # acoustic = e - semantic + o = self.decoder(quantized_full) + + subset = quantized_list[1:] + + # half_len = len(subset) // 2 + # selected_for_processing = random.sample(subset, half_len) + + # selected_ids = set(id(x) for x in selected_for_processing) + # acoustic_wm = sum( + # msg_processor(x, message) if id(x) in selected_ids else x for x in subset + # ) + + acoustic_wm = sum(msg_processor(x, message) for x in subset) + + acoustic = e - quantized_list[0] + + # e_wm = acoustic_wm + e_wm = quantized_list[0] + acoustic_wm + + o_wm = self.decoder(e_wm) + + return (o, o_wm, acoustic, acoustic_wm) + + def forward_feature(self, x: torch.tensor, layers: list = None): + """ + + Parameters + ---------- + x : torch.tensor + Input wavs. Shape should be (batch, channels, timesteps). + layers : list[int], optional + Layers of RVQ should return quantized result. The default is all layers. + + Returns + ------- + quantized_list : list[torch.tensor] + Quantized of required layers. + + """ + e = self.encoder(x) + with torch.no_grad(): + semantic, _, _, _ = self.quantizer(e, st=0, n_q=1) + acoustic = e - semantic + return acoustic + + def encode(self, x: torch.tensor, n_q: int = None, st: int = None): + """ + + Parameters + ---------- + x : torch.tensor + Input wavs. Shape: (batch, channels, timesteps). + n_q : int, optional + Number of quantizers in RVQ used to encode. The default is all layers. + st : int, optional + Start quantizer index in RVQ. The default is 0. + + Returns + ------- + codes : torch.tensor + Output indices for each quantizer. Shape: (n_q, batch, timesteps) + + """ + e = self.encoder(x) + if st is None: + st = 0 + n_q = n_q if n_q else self.n_q + codes = self.quantizer.encode(e, n_q=n_q, st=st) + return codes + + def decode(self, codes: torch.tensor, st: int = 0): + """ + + Parameters + ---------- + codes : torch.tensor + Indices for each quantizer. Shape: (n_q, batch, timesteps). + st : int, optional + Start quantizer index in RVQ. The default is 0. + + Returns + ------- + o : torch.tensor + Reconstruct wavs from codes. Shape: (batch, channels, timesteps) + + """ + quantized = self.quantizer.decode(codes, st=st) + o = self.decoder(quantized) + return o diff --git a/speechtokenizer/modules/__init__.py b/speechtokenizer/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d61ab031045422c4000cf26388f24d072572089 --- /dev/null +++ b/speechtokenizer/modules/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
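
For orientation, a rough usage sketch of the `SpeechTokenizer` defined above, using the pretrained config and checkpoint shipped in this diff (illustrative only; the real checkpoint is required for it to run):

```python
import torch
from speechtokenizer import SpeechTokenizer

st = SpeechTokenizer.load_from_checkpoint(
    "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json",
    "speechtokenizer/pretrained_model/SpeechTokenizer.pt",
)
wav = torch.randn(1, 1, 16000)      # one second of 16 kHz audio
codes = st.encode(wav)              # (n_q, batch, timesteps) RVQ indices
recon = st.decode(codes)            # (batch, channels, timesteps) waveform
acoustic = st.forward_feature(wav)  # residual features that the watermark detector reads
```
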
+ +"""Torch modules.""" + +# flake8: noqa +from .conv import ( + pad1d, + unpad1d, + NormConv1d, + NormConvTranspose1d, + NormConv2d, + NormConvTranspose2d, + SConv1d, + SConvTranspose1d, +) +from .lstm import SLSTM +from .seanet import SEANetEncoder, SEANetDecoder \ No newline at end of file diff --git a/speechtokenizer/modules/__pycache__/__init__.cpython-310.pyc b/speechtokenizer/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a32d6b7c2de6cbbc18cb783ad1943e6f17caee15 Binary files /dev/null and b/speechtokenizer/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/speechtokenizer/modules/__pycache__/__init__.cpython-311.pyc b/speechtokenizer/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..488e7cede536d6a98ece07eeb4e16d1f120eb081 Binary files /dev/null and b/speechtokenizer/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/speechtokenizer/modules/__pycache__/__init__.cpython-39.pyc b/speechtokenizer/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9416d9a323d941a0fce79863e2d0b82105666e20 Binary files /dev/null and b/speechtokenizer/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/speechtokenizer/modules/__pycache__/conv.cpython-310.pyc b/speechtokenizer/modules/__pycache__/conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf35c42c62599415234e4815f176c9bdd83f87d3 Binary files /dev/null and b/speechtokenizer/modules/__pycache__/conv.cpython-310.pyc differ diff --git a/speechtokenizer/modules/__pycache__/conv.cpython-311.pyc b/speechtokenizer/modules/__pycache__/conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef84ffa2942d241c0de6898dab490d0ce0b35f57 Binary files /dev/null and b/speechtokenizer/modules/__pycache__/conv.cpython-311.pyc differ diff --git a/speechtokenizer/modules/__pycache__/conv.cpython-39.pyc b/speechtokenizer/modules/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e51bc6d346bf5659edb9b16f0243f3f9c7d6c0e7 Binary files /dev/null and b/speechtokenizer/modules/__pycache__/conv.cpython-39.pyc differ diff --git a/speechtokenizer/modules/__pycache__/lstm.cpython-310.pyc b/speechtokenizer/modules/__pycache__/lstm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8824b269519aa9c34b331b6b63293c3a2a23d97d Binary files /dev/null and b/speechtokenizer/modules/__pycache__/lstm.cpython-310.pyc differ diff --git a/speechtokenizer/modules/__pycache__/lstm.cpython-311.pyc b/speechtokenizer/modules/__pycache__/lstm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..363de3b98fbc8a0a5a3b47f692cf264b364024eb Binary files /dev/null and b/speechtokenizer/modules/__pycache__/lstm.cpython-311.pyc differ diff --git a/speechtokenizer/modules/__pycache__/lstm.cpython-39.pyc b/speechtokenizer/modules/__pycache__/lstm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f361e0ff5749098e2d04d067fa85a2d47ba3f57 Binary files /dev/null and b/speechtokenizer/modules/__pycache__/lstm.cpython-39.pyc differ diff --git a/speechtokenizer/modules/__pycache__/norm.cpython-310.pyc b/speechtokenizer/modules/__pycache__/norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..731f81f0d438c2c70f1ad5ecedfb212f4544549c Binary 
files /dev/null and b/speechtokenizer/modules/__pycache__/norm.cpython-310.pyc differ diff --git a/speechtokenizer/modules/__pycache__/norm.cpython-311.pyc b/speechtokenizer/modules/__pycache__/norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..defcb330a18291545a4885475612955e93ac39be Binary files /dev/null and b/speechtokenizer/modules/__pycache__/norm.cpython-311.pyc differ diff --git a/speechtokenizer/modules/__pycache__/norm.cpython-39.pyc b/speechtokenizer/modules/__pycache__/norm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5221b03d1254b82ea0f893b3d45e1e4e475807b8 Binary files /dev/null and b/speechtokenizer/modules/__pycache__/norm.cpython-39.pyc differ diff --git a/speechtokenizer/modules/__pycache__/seanet.cpython-310.pyc b/speechtokenizer/modules/__pycache__/seanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8cfdd3816c3dcc8856f66fd61878e964d391f5a Binary files /dev/null and b/speechtokenizer/modules/__pycache__/seanet.cpython-310.pyc differ diff --git a/speechtokenizer/modules/__pycache__/seanet.cpython-311.pyc b/speechtokenizer/modules/__pycache__/seanet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d11fac416dc5ed2647d2ea67ab0891e7c49492c Binary files /dev/null and b/speechtokenizer/modules/__pycache__/seanet.cpython-311.pyc differ diff --git a/speechtokenizer/modules/__pycache__/seanet.cpython-39.pyc b/speechtokenizer/modules/__pycache__/seanet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..670f9f718da0ac17d0488413d887e72ba8749892 Binary files /dev/null and b/speechtokenizer/modules/__pycache__/seanet.cpython-39.pyc differ diff --git a/speechtokenizer/modules/conv.py b/speechtokenizer/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..f37963901118e10bfe009e3c6d27e47be06be7cc --- /dev/null +++ b/speechtokenizer/modules/conv.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Convolutional layers wrappers and utilities.""" + +import math +import typing as tp +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +from .norm import ConvLayerNorm + + +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_layer_norm', 'layer_norm', 'time_group_norm']) + + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return weight_norm(module) + elif norm == 'spectral_norm': + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. 
+ """ + assert norm in CONV_NORMALIZATIONS + if norm == 'layer_norm': + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`. + """ + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. 
+ """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect'): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' + f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + padding_total = (kernel_size - 1) * dilation - (stride - 1) + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + return self.conv(x) + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. 
+ """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, causal: bool = False, + norm: str = 'none', trim_right_ratio: float = 1., + norm_kwargs: tp.Dict[str, tp.Any] = {}): + super().__init__() + self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, + causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1., \ + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y diff --git a/speechtokenizer/modules/lstm.py b/speechtokenizer/modules/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..ba669aef7249f84f8944b2f2079bf29ff34235e0 --- /dev/null +++ b/speechtokenizer/modules/lstm.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""LSTM layers module.""" + +from torch import nn + + +class SLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional: bool=False): + super().__init__() + self.bidirectional = bidirectional + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional) + + def forward(self, x): + x = x.permute(2, 0, 1) + y, _ = self.lstm(x) + if self.bidirectional: + x = x.repeat(1, 1, 2) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y diff --git a/speechtokenizer/modules/norm.py b/speechtokenizer/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..19970e0a21ea1c10461cb56d776619dd5f64ff36 --- /dev/null +++ b/speechtokenizer/modules/norm.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Normalization modules.""" + +import typing as tp + +import einops +import torch +from torch import nn + + +class ConvLayerNorm(nn.LayerNorm): + """ + Convolution-friendly LayerNorm that moves channels to last dimensions + before running the normalization and moves them back to original position right after. 
+ """ + def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = einops.rearrange(x, 'b ... t -> b t ...') + x = super().forward(x) + x = einops.rearrange(x, 'b t ... -> b ... t') + return diff --git a/speechtokenizer/modules/seanet.py b/speechtokenizer/modules/seanet.py new file mode 100644 index 0000000000000000000000000000000000000000..a2456a2c2fee8de888af6138ef121d3b4dde81c9 --- /dev/null +++ b/speechtokenizer/modules/seanet.py @@ -0,0 +1,408 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Encodec SEANet-based encoder and decoder implementation.""" + +import typing as tp + +import numpy as np +import torch.nn as nn +import torch + +from . import SConv1d, SConvTranspose1d, SLSTM + + +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +class SEANetResnetBlock(nn.Module): + """Residual block from SEANet model. + Args: + dim (int): Dimension of the input/output + kernel_sizes (list): List of kernel sizes for the convolutions. + dilations (list): List of dilations for the convolutions. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + compress (int): Reduced dimensionality in residual branches (from Demucs v3) + true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. 
+ """ + + def __init__( + self, + dim: int, + kernel_sizes: tp.List[int] = [3, 1], + dilations: tp.List[int] = [1, 1], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "weight_norm", + norm_params: tp.Dict[str, tp.Any] = {}, + causal: bool = False, + pad_mode: str = "reflect", + compress: int = 2, + true_skip: bool = True, + ): + super().__init__() + assert len(kernel_sizes) == len( + dilations + ), "Number of kernel sizes should match number of dilations" + act = getattr(nn, activation) if activation != "Snake" else Snake1d + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params) if activation != "Snake" else act(in_chs), + SConv1d( + in_chs, + out_chs, + kernel_size=kernel_size, + dilation=dilation, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = SConv1d( + dim, + dim, + kernel_size=1, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. 
+ """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "weight_norm", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = False, + compress: int = 2, + lstm: int = 2, + bidirectional: bool = False, + ): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) # čŽĄįŽ—äš˜į§¯ + + act = getattr(nn, activation) if activation != "Snake" else Snake1d + mult = 1 + model: tp.List[nn.Module] = [ + SConv1d( + channels, + mult * n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + norm=norm, + norm_params=norm_params, + activation=activation, + activation_params=activation_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + # Add downsampling layers + model += [ + ( + act(**activation_params) + if activation != "Snake" + else act(mult * n_filters) + ), + SConv1d( + mult * n_filters, + mult * n_filters * 2, + kernel_size=ratio * 2, + stride=ratio, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + mult *= 2 + + if lstm: + model += [ + SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional) + ] + + mult = mult * 2 if bidirectional else mult + model += [ + ( + act(**activation_params) + if activation != "Snake" + else act(mult * n_filters) + ), + SConv1d( + mult * n_filters, + dimension, + last_kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. 
+ true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. + If equal to 1.0, it means that all the trimming is done at the right. + """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + final_activation: tp.Optional[str] = None, + final_activation_params: tp.Optional[dict] = None, + norm: str = "weight_norm", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = False, + compress: int = 2, + lstm: int = 2, + trim_right_ratio: float = 1.0, + bidirectional: bool = False, + ): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) if activation != "Snake" else Snake1d + mult = int(2 ** len(self.ratios)) + model: tp.List[nn.Module] = [ + SConv1d( + dimension, + mult * n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + + if lstm: + model += [ + SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional) + ] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add upsampling layers + model += [ + ( + act(**activation_params) + if activation != "Snake" + else act(mult * n_filters) + ), + SConvTranspose1d( + mult * n_filters, + mult * n_filters // 2, + kernel_size=ratio * 2, + stride=ratio, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + trim_right_ratio=trim_right_ratio, + ), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters // 2, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + activation=activation, + activation_params=activation_params, + norm=norm, + norm_params=norm_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params) if activation != "Snake" else act(n_filters), + SConv1d( + n_filters, + channels, + last_kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + # Add optional final activation to decoder (eg. 
tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [final_act(**final_activation_params)] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y + + +def test(): + import torch + + encoder = SEANetEncoder() + decoder = SEANetDecoder() + x = torch.randn(1, 1, 24000) + z = encoder(x) + print("z ", z.shape) + assert 1 == 2 + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + +if __name__ == "__main__": + test() diff --git a/speechtokenizer/pretrained_model/SpeechTokenizer.pt b/speechtokenizer/pretrained_model/SpeechTokenizer.pt new file mode 100644 index 0000000000000000000000000000000000000000..c81992fd5ac3e9726ff58f659f3115bf032aac32 --- /dev/null +++ b/speechtokenizer/pretrained_model/SpeechTokenizer.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d04593b6c9a4b475f91ca481141a6ef5b23e6ac112f347dd2b2717f193c1c728 +size 481906997 diff --git a/speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json b/speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json new file mode 100644 index 0000000000000000000000000000000000000000..fd35d63489d9b631f4bb419fcba5f74d1f76aa40 --- /dev/null +++ b/speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json @@ -0,0 +1,47 @@ +{ + "resblock": "1", + "num_gpus": 3, + "batch_size": 60, + "learning_rate": 0.0001, + "adam_b1": 0.5, + "adam_b2": 0.9, + "lr_decay": 0.98, + "seed": 1234, + "lambda_distill": 0.15, + + "n_filters": 64, + "strides": [8,5,4,2], + "dimension": 1024, + "semantic_dimension": 768, + "bidirectional": true, + "dilation_base": 2, + "residual_kernel_size": 3, + "n_residual_layers": 1, + "lstm_layers": 2, + "activation": "ELU", + + "segment_size": 48000, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 240, + "win_size": 1024, + + "sampling_rate": 16000, + "sample_rate": 16000, + + "codebook_size": 1024, + "n_q": 8, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 12, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54322", + "world_size": 1 + } +} diff --git a/speechtokenizer/quantization/__init__.py b/speechtokenizer/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfabe52b8cb6f260cdda6137b34df2f4736bd02f --- /dev/null +++ b/speechtokenizer/quantization/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
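The `speechtokenizer_hubert_avg_config.json` added above lines up closely with the SEANet constructor arguments earlier in this diff. Below is a minimal sketch of that correspondence; `encoder_kwargs_from_config` is a hypothetical helper (the real loading code lives in the SpeechTokenizer model class, which is not part of this excerpt), and the field-to-argument mapping noted in the comments (`strides` to `ratios`, `lstm_layers` to `lstm`) is an assumption based on the defaults shown.

```python
import json

# Hypothetical helper: collect SEANetEncoder kwargs from the JSON config above.
# Argument names follow SEANetEncoder.__init__; the field mapping is assumed.
def encoder_kwargs_from_config(path: str) -> dict:
    with open(path) as f:
        cfg = json.load(f)
    return dict(
        n_filters=cfg["n_filters"],          # 64
        dimension=cfg["dimension"],          # 1024
        ratios=cfg["strides"],               # [8, 5, 4, 2] -> hop length 320
        lstm=cfg["lstm_layers"],             # 2
        bidirectional=cfg["bidirectional"],  # true
        dilation_base=cfg["dilation_base"],
        residual_kernel_size=cfg["residual_kernel_size"],
        n_residual_layers=cfg["n_residual_layers"],
        activation=cfg["activation"],        # "ELU"
    )

# Example (paths as in this diff):
# kwargs = encoder_kwargs_from_config(
#     "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json")
# encoder = SEANetEncoder(**kwargs)
```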
+ +# flake8: noqa +from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/speechtokenizer/quantization/__pycache__/__init__.cpython-310.pyc b/speechtokenizer/quantization/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a9aa18ebd773618b79afb090c24d6ed9ca9b716 Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/__init__.cpython-310.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/__init__.cpython-311.pyc b/speechtokenizer/quantization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65723079defbf8b70d4f8b568141253e2af6b02f Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/__init__.cpython-311.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/__init__.cpython-39.pyc b/speechtokenizer/quantization/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76200d4a860c5b09a9f6f794af1ca96d0eaf7e58 Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/__init__.cpython-39.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/core_vq.cpython-310.pyc b/speechtokenizer/quantization/__pycache__/core_vq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c62dc7ad30c7d70ae643cc3ae638980a01ec73d5 Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/core_vq.cpython-310.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/core_vq.cpython-311.pyc b/speechtokenizer/quantization/__pycache__/core_vq.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4acb73bc60e0133ae2d3421300d0922e58eda002 Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/core_vq.cpython-311.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/core_vq.cpython-39.pyc b/speechtokenizer/quantization/__pycache__/core_vq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..359211656337bd3cb468b4005292dc9a36bbdc07 Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/core_vq.cpython-39.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/distrib.cpython-310.pyc b/speechtokenizer/quantization/__pycache__/distrib.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35d5648fb781d6b57dacce130f3a4da5a4e96af7 Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/distrib.cpython-310.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/distrib.cpython-311.pyc b/speechtokenizer/quantization/__pycache__/distrib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cdf666fb73d46c7e3beaeb7f55c23583108dd16 Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/distrib.cpython-311.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/distrib.cpython-39.pyc b/speechtokenizer/quantization/__pycache__/distrib.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..069fc2a8a74117c9bfeb955243c55c785cc2d16b Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/distrib.cpython-39.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/vq.cpython-310.pyc b/speechtokenizer/quantization/__pycache__/vq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..537d9ea38286c2ef6ee329b43e421bd988c69388 Binary files /dev/null and 
b/speechtokenizer/quantization/__pycache__/vq.cpython-310.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/vq.cpython-311.pyc b/speechtokenizer/quantization/__pycache__/vq.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96dec9ab5693e6ffdc62cdfe7210c901c530e461 Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/vq.cpython-311.pyc differ diff --git a/speechtokenizer/quantization/__pycache__/vq.cpython-39.pyc b/speechtokenizer/quantization/__pycache__/vq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9847026c05a9d3f900fb985c58a658d1ebe4a74a Binary files /dev/null and b/speechtokenizer/quantization/__pycache__/vq.cpython-39.pyc differ diff --git a/speechtokenizer/quantization/ac.py b/speechtokenizer/quantization/ac.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f3e5dcd385cd273a145effa3f53ce7ccfdc74c --- /dev/null +++ b/speechtokenizer/quantization/ac.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Arithmetic coder.""" + +import io +import math +import random +import typing as tp +import torch + +from ..binary import BitPacker, BitUnpacker + + +def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int, + roundoff: float = 1e-8, min_range: int = 2, + check: bool = True) -> torch.Tensor: + """Turn the given PDF into a quantized CDF that splits + [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional + to the PDF. + + Args: + pdf (torch.Tensor): probability distribution, shape should be `[N]`. + total_range_bits (int): see `ArithmeticCoder`, the typical range we expect + during the coding process is `[0, 2 ** total_range_bits - 1]`. + roundoff (float): will round the pdf up to that level to remove difference coming + from e.g. evaluating the Language Model on different architectures. + min_range (int): minimum range width. Should always be at least 2 for numerical + stability. Use this to avoid pathological behavior is a value + that is expected to be rare actually happens in real life. + check (bool): if True, checks that nothing bad happened, can be deactivated for speed. + """ + pdf = pdf.detach() + if roundoff: + pdf = (pdf / roundoff).floor() * roundoff + # interpolate with uniform distribution to achieve desired minimum probability. + total_range = 2 ** total_range_bits + cardinality = len(pdf) + alpha = min_range * cardinality / total_range + assert alpha <= 1, "you must reduce min_range" + ranges = (((1 - alpha) * total_range) * pdf).floor().long() + ranges += min_range + quantized_cdf = torch.cumsum(ranges, dim=-1) + if min_range < 2: + raise ValueError("min_range must be at least 2.") + if check: + assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1] + if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range: + raise ValueError("You must increase your total_range_bits.") + return quantized_cdf + + +class ArithmeticCoder: + """ArithmeticCoder, + Let us take a distribution `p` over `N` symbols, and assume we have a stream + of random variables `s_t` sampled from `p`. Let us assume that we have a budget + of `B` bits that we can afford to write on device. There are `2**B` possible numbers, + corresponding to the range `[0, 2 ** B - 1]`. 
We can map each of those number to a single + sequence `(s_t)` by doing the following: + + 1) Initialize the current range to` [0 ** 2 B - 1]`. + 2) For each time step t, split the current range into contiguous chunks, + one for each possible outcome, with size roughly proportional to `p`. + For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks + would be `{[0, 2], [3, 3]}`. + 3) Select the chunk corresponding to `s_t`, and replace the current range with this. + 4) When done encoding all the values, just select any value remaining in the range. + + You will notice that this procedure can fail: for instance if at any point in time + the range is smaller than `N`, then we can no longer assign a non-empty chunk to each + possible outcome. Intuitively, the more likely a value is, the less the range width + will reduce, and the longer we can go on encoding values. This makes sense: for any efficient + coding scheme, likely outcomes would take less bits, and more of them can be coded + with a fixed budget. + + In practice, we do not know `B` ahead of time, but we have a way to inject new bits + when the current range decreases below a given limit (given by `total_range_bits`), without + having to redo all the computations. If we encode mostly likely values, we will seldom + need to inject new bits, but a single rare value can deplete our stock of entropy! + + In this explanation, we assumed that the distribution `p` was constant. In fact, the present + code works for any sequence `(p_t)` possibly different for each timestep. + We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller + the KL between the true distribution and `p_t`, the most efficient the coding will be. + + Args: + fo (IO[bytes]): file-like object to which the bytes will be written to. + total_range_bits (int): the range `M` described above is `2 ** total_range_bits. + Any time the current range width fall under this limit, new bits will + be injected to rescale the initial range. + """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + assert total_range_bits <= 30 + self.total_range_bits = total_range_bits + self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. + self.low: int = 0 + self.high: int = 0 + self.max_bit: int = -1 + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + + @property + def delta(self) -> int: + """Return the current range width.""" + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # If self.low and self.high start with the sames bits, + # those won't change anymore as we always just increase the range + # by powers of 2, and we can flush them out to the bit stream. + assert self.high >= self.low, (self.low, self.high) + assert self.high < 2 ** (self.max_bit + 1) + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= (b1 << self.max_bit) + self.high -= (b1 << self.max_bit) + assert self.high >= self.low, (self.high, self.low, self.max_bit) + assert self.low >= 0 + self.max_bit -= 1 + self.packer.push(b1) + else: + break + + def push(self, symbol: int, quantized_cdf: torch.Tensor): + """Push the given symbol on the stream, flushing out bits + if possible. + + Args: + symbol (int): symbol to encode with the AC. + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. 
+ """ + while self.delta < 2 ** self.total_range_bits: + self.low *= 2 + self.high = self.high * 2 + 1 + self.max_bit += 1 + + range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() + range_high = quantized_cdf[symbol].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) + assert self.low <= self.high + self.high = self.low + effective_high + self.low = self.low + effective_low + assert self.low <= self.high, (effective_low, effective_high, range_low, range_high) + self._dbg.append((self.low, self.high)) + self._dbg2.append((self.low, self.high)) + outs = self._flush_common_prefix() + assert self.low <= self.high + assert self.max_bit >= -1 + assert self.max_bit <= 61, self.max_bit + return outs + + def flush(self): + """Flush the remaining information to the stream. + """ + while self.max_bit >= 0: + b1 = (self.low >> self.max_bit) & 1 + self.packer.push(b1) + self.max_bit -= 1 + self.packer.flush() + + +class ArithmeticDecoder: + """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. + + Note that this must be called with **exactly** the same parameters and sequence + of quantized cdf as the arithmetic encoder or the wrong values will be decoded. + + If the AC encoder current range is [L, H], with `L` and `H` having the some common + prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. + For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside + `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained + for a specific sequence of symbols and a binary-search allows us to decode those symbols. + At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, + and we will need to read new bits from the stream and repeat the process. + + """ + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + self.total_range_bits = total_range_bits + self.low: int = 0 + self.high: int = 0 + self.current: int = 0 + self.max_bit: int = -1 + self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. + # Following is for debugging + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + self._last: tp.Any = None + + @property + def delta(self) -> int: + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # Given the current range [L, H], if both have a common prefix, + # we know we can remove it from our representation to avoid handling large numbers. + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= (b1 << self.max_bit) + self.high -= (b1 << self.max_bit) + self.current -= (b1 << self.max_bit) + assert self.high >= self.low + assert self.low >= 0 + self.max_bit -= 1 + else: + break + + def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: + """Pull a symbol, reading as many bits from the stream as required. + This returns `None` when the stream has been exhausted. + + Args: + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. This must be **exatly** + the same cdf as the one used at encoding time. 
+ """ + while self.delta < 2 ** self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return (mid, low, high, self.current) + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git a/speechtokenizer/quantization/core_vq.py b/speechtokenizer/quantization/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..c8f32ab303e884a120863f5e360e2af95c98429f --- /dev/null +++ b/speechtokenizer/quantization/core_vq.py @@ -0,0 +1,385 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" +import typing as tp + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F + +from .distrib import broadcast_tensors, rank + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
+ """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = ( + uniform_init if not kmeans_init else torch.zeros + ) + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + # broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + # broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + # self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + # if self.training: + # # We do the expiry of code at that point as buffers are in sync + # # and all the workers will take the same decision. 
+ # self.expire_codes_(x) + # ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + # embed_sum = x.t() @ embed_onehot + # ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + # cluster_size = ( + # laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + # * self.cluster_size.sum() + # ) + # embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + # self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = ( + nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + ) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + ) + + def forward( + self, + x, + n_q: tp.Optional[int] = None, + layers: tp.Optional[list] = None, + st: tp.Optional[int] = None, + ): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + out_quantized = [] + n_q = n_q or len(self.layers) + st = st or 0 + for i, layer in enumerate(self.layers[st:n_q]): + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + if layers and i in layers: + out_quantized.append(quantized) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses, out_quantized + + def encode( + self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None + ) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + st = st or 0 + for layer in self.layers[st:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[st + i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/speechtokenizer/quantization/distrib.py b/speechtokenizer/quantization/distrib.py new file mode 100644 index 0000000000000000000000000000000000000000..4edc88a532c9cf842bd2ee9a56f2c0b5dcb7c83d --- /dev/null +++ b/speechtokenizer/quantization/distrib.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Torch distributed utilities.""" + +import typing as tp + +import torch + + +def rank(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): + if is_distributed(): + return torch.distributed.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: tp.List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. + if not is_distributed() or not params: + return + #print('params[0].device ', params[0].device) + tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) + all_reduce(tensor) + if tensor.item() != len(params) * world_size(): + # If not all the workers have the same number, for at least one of them, + # this inequality will be verified. 
+ raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, " + "at least one worker has a different one.") + + +def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): + """Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. + """ + if not is_distributed(): + return + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + _check_number_of_params(tensors) + handles = [] + for tensor in tensors: + # src = int(rank()) # added code + handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() + + +def sync_buffer(buffers, average=True): + """ + Sync grad for buffers. If average is False, broadcast instead of averaging. + """ + if not is_distributed(): + return + handles = [] + for buffer in buffers: + if torch.is_floating_point(buffer.data): + if average: + handle = torch.distributed.all_reduce( + buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + else: + handle = torch.distributed.broadcast( + buffer.data, src=0, async_op=True) + handles.append((buffer, handle)) + for buffer, handle in handles: + handle.wait() + if average: + buffer.data /= world_size + + +def sync_grad(params): + """ + Simpler alternative to DistributedDataParallel, that doesn't rely + on any black magic. For simple models it can also be as fast. + Just call this on your model parameters after the call to backward! + """ + if not is_distributed(): + return + handles = [] + for p in params: + if p.grad is not None: + handle = torch.distributed.all_reduce( + p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + handles.append((p, handle)) + for p, handle in handles: + handle.wait() + p.grad.data /= world_size() + + +def average_metrics(metrics: tp.Dict[str, float], count=1.): + """Average a dictionary of metrics across all workers, using the optional + `count` as unormalized weight. + """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged)) diff --git a/speechtokenizer/quantization/vq.py b/speechtokenizer/quantization/vq.py new file mode 100644 index 0000000000000000000000000000000000000000..5f6e65f99c2743b5b1a1ca507ce57ad4e9132da1 --- /dev/null +++ b/speechtokenizer/quantization/vq.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Residual vector quantizer implementation.""" + +from dataclasses import dataclass, field +import math +import typing as tp + +import torch +from torch import nn + +from .core_vq import ResidualVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class ResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer. + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + bins (int): Codebook size. 
+ decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dimension: int = 256, + n_q: int = 8, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.n_q = n_q + self.dimension = dimension + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + ) + + def forward( + self, + x: torch.Tensor, + n_q: tp.Optional[int] = None, + layers: tp.Optional[list] = None, + st: tp.Optional[int] = None, + ) -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (torch.Tensor): Input tensor. + n_q (int): Number of quantizer used to quantize. Default: All quantizers. + layers (list): Layer that need to return quantized. Defalt: None. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated numbert quantizers and layer quantized required to return. + """ + n_q = n_q if n_q else self.n_q + if layers and max(layers) >= n_q: + raise ValueError( + f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B." + ) + quantized, codes, commit_loss, quantized_list = self.vq( + x, n_q=n_q, layers=layers, st=st + ) + return quantized, codes, torch.mean(commit_loss), quantized_list + + def encode( + self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None + ) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. + Args: + x (torch.Tensor): Input tensor. + n_q (int): Number of quantizer used to quantize. Default: All quantizers. + st (int): Start to encode input from which layers. Default: 0. + """ + n_q = n_q if n_q else self.n_q + st = st or 0 + codes = self.vq.encode(x, n_q=n_q, st=st) + return codes + + def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor: + """Decode the given codes to the quantized representation. + Args: + codes (torch.Tensor): Input indices for each quantizer. + st (int): Start to decode input codes from which layers. Default: 0. 
+ """ + quantized = self.vq.decode(codes, st=st) + return quantized diff --git a/voicemark.pth b/voicemark.pth new file mode 100644 index 0000000000000000000000000000000000000000..4ab3d9eb9e512f435625c3de8b87762da9ccb11c --- /dev/null +++ b/voicemark.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff9975b18ee175ad05ac0edac88583db468fd56867b3bcdd2efa6584de85bfe4 +size 86430822 diff --git a/voicemark_overview.png b/voicemark_overview.png new file mode 100644 index 0000000000000000000000000000000000000000..e565824fa3f4c229dc3323ee47f0c5f72ab99233 --- /dev/null +++ b/voicemark_overview.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddb86c3c1b83ef3a14dbe3c8a1960718534733d53a406f8ea00222ffd5fd989d +size 297901