# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 15:47:55 2023
@author: zhangxin
"""
import json
import random  # used by the disabled random-layer-selection variant in forward()

import numpy as np
import torch
import torch.nn as nn

from models import WMEmbedder
from .modules.seanet import SEANetEncoder, SEANetDecoder
from .quantization import ResidualVectorQuantizer
class SpeechTokenizer(nn.Module):
def __init__(self, config):
"""
Parameters
----------
        config : dict
            Model configuration (parsed from a JSON file).
"""
super().__init__()
self.encoder = SEANetEncoder(
n_filters=config.get("n_filters"),
dimension=config.get("dimension"),
ratios=config.get("strides"),
lstm=config.get("lstm_layers"),
bidirectional=config.get("bidirectional"),
dilation_base=config.get("dilation_base"),
residual_kernel_size=config.get("residual_kernel_size"),
n_residual_layers=config.get("n_residual_layers"),
activation=config.get("activation"),
)
self.sample_rate = config.get("sample_rate")
self.n_q = config.get("n_q")
self.downsample_rate = np.prod(config.get("strides"))
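        # Worked example (the stride values are an assumption, not read from a
        # shipped config): with strides (8, 5, 4, 2), downsample_rate =
        # 8 * 5 * 4 * 2 = 320, so 16 kHz input audio yields
        # 16000 / 320 = 50 embedding frames per second.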
if config.get("dimension") != config.get("semantic_dimension"):
self.transform = nn.Linear(
config.get("dimension"), config.get("semantic_dimension")
)
else:
self.transform = nn.Identity()
self.quantizer = ResidualVectorQuantizer(
dimension=config.get("dimension"),
n_q=config.get("n_q"),
bins=config.get("codebook_size"),
)
self.decoder = SEANetDecoder(
n_filters=config.get("n_filters"),
dimension=config.get("dimension"),
ratios=config.get("strides"),
lstm=config.get("lstm_layers"),
bidirectional=False,
dilation_base=config.get("dilation_base"),
residual_kernel_size=config.get("residual_kernel_size"),
n_residual_layers=config.get("n_residual_layers"),
activation=config.get("activation"),
)
@classmethod
def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
"""
Parameters
----------
config_path : str
Path of model configuration file.
ckpt_path : str
Path of model checkpoint.
Returns
-------
model : SpeechTokenizer
SpeechTokenizer model.
"""
        with open(config_path) as f:
            cfg = json.load(f)
model = cls(cfg)
params = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(params)
return model
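    # Example usage (the paths are hypothetical, not files in this repo):
    #   model = SpeechTokenizer.load_from_checkpoint(
    #       "config.json", "SpeechTokenizer.pt"
    #   )
    #   model.eval()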
    def forward(
        self,
        x: torch.Tensor,
        n_q: int = None,
        layers: list = [0],
        msg_processor: WMEmbedder = None,
        message: torch.Tensor = None,
    ):
"""
Parameters
----------
x : torch.tensor
Input wavs. Shape: (batch, channels, timesteps).
n_q : int, optional
Number of quantizers in RVQ used to encode. The default is all layers.
layers : list[int], optional
Layers of RVQ should return quantized result. The default is the first layer.
Returns
-------
o : torch.tensor
Output wavs. Shape: (batch, channels, timesteps).
commit_loss : torch.tensor
Commitment loss from residual vector quantizers.
feature : torch.tensor
Output of RVQ's first layer. Shape: (batch, timesteps, dimension)
"""
        with torch.no_grad():
            e = self.encoder(x)
            quantized_full, _, _, quantized_list = self.quantizer(
                e, n_q=n_q, layers=[0, 1, 2, 3, 4, 5, 6, 7], st=0
            )
            # semantic, _, _, _ = self.quantizer(e, n_q=1, st=0)
            # acoustic = e - semantic
        o = self.decoder(quantized_full)
        # Watermark every acoustic layer (RVQ layers 1 and up); layer 0 carries
        # the semantic content and is left untouched.
        subset = quantized_list[1:]
        # Disabled variant: watermark only a random half of the acoustic layers.
        # half_len = len(subset) // 2
        # selected_for_processing = random.sample(subset, half_len)
        # selected_ids = set(id(q) for q in selected_for_processing)
        # acoustic_wm = sum(
        #     msg_processor(q, message) if id(q) in selected_ids else q for q in subset
        # )
        acoustic_wm = sum(msg_processor(q, message) for q in subset)
        acoustic = e - quantized_list[0]
        # e_wm = acoustic_wm
        e_wm = quantized_list[0] + acoustic_wm
        o_wm = self.decoder(e_wm)
        return o, o_wm, acoustic, acoustic_wm
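    # Note on the msg_processor contract (inferred from its use above, not from
    # a documented API): WMEmbedder.__call__(quantized_layer, message) must
    # return a tensor with the same shape as quantized_layer, so the watermarked
    # layers can be summed and added back onto the semantic layer.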
    def forward_feature(self, x: torch.Tensor, layers: list = None):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        layers : list[int], optional
            Unused; kept for interface compatibility.

        Returns
        -------
        acoustic : torch.Tensor
            Acoustic embedding: encoder output minus the quantized semantic
            (first RVQ) layer.
        """
        e = self.encoder(x)
        with torch.no_grad():
            semantic, _, _, _ = self.quantizer(e, st=0, n_q=1)
        acoustic = e - semantic
        return acoustic
    def encode(self, x: torch.Tensor, n_q: int = None, st: int = None):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default is all layers.
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        codes : torch.Tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps).
        """
e = self.encoder(x)
if st is None:
st = 0
n_q = n_q if n_q else self.n_q
codes = self.quantizer.encode(e, n_q=n_q, st=st)
return codes
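    # Example: encode(x, n_q=1) returns only the semantic tokens (RVQ layer 0),
    # while encode(x, st=1) returns the acoustic tokens (layers 1 and up),
    # mirroring the semantic/acoustic split used in forward_feature above.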
    def decode(self, codes: torch.Tensor, st: int = 0):
        """
        Parameters
        ----------
        codes : torch.Tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        o : torch.Tensor
            Reconstructed wavs from codes. Shape: (batch, channels, timesteps).
        """
quantized = self.quantizer.decode(codes, st=st)
o = self.decoder(quantized)
return o
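

if __name__ == "__main__":
    # Minimal smoke test (a sketch: "config.json" and "SpeechTokenizer.pt" are
    # hypothetical paths; substitute a real config and checkpoint).
    model = SpeechTokenizer.load_from_checkpoint("config.json", "SpeechTokenizer.pt")
    model.eval()
    # One second of random noise shaped (batch, channels, timesteps).
    wav = torch.randn(1, 1, model.sample_rate)
    with torch.no_grad():
        codes = model.encode(wav)    # (n_q, batch, frames)
        recon = model.decode(codes)  # (batch, channels, timesteps)
    print(codes.shape, recon.shape)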