# NeuCoSVC/modules/wavlm_encoder.py
from pathlib import Path

import resampy
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from torch import Tensor
from torchaudio.sox_effects import apply_effects_tensor

from modules.wavlm.WavLM import WavLM, WavLMConfig


class WavLMEncoder(nn.Module):
def __init__(self,
ckpt_path,
device='cpu'
):
"""
Load the WavLM large checkpoint from the original paper. See https://github.com/microsoft/unilm/tree/master/wavlm for details.
Args:
ckpt_path : checkpoint path of WavLM.
"""
super().__init__()
        wavlm_check_point = torch.load(ckpt_path, map_location='cpu')
cfg = WavLMConfig(wavlm_check_point['cfg'])
wavlm = WavLM(cfg)
wavlm.load_state_dict(wavlm_check_point['model'])
wavlm = wavlm.to(device)
# store wavlm
self.wavlm = wavlm.eval()
self.device = torch.device(device)
        self.sr = 16000

@torch.inference_mode()
def get_features(self, path, output_layer=None, weights=None, vad_trigger_level=0):
"""
Returns the features of the waveform at `path` as a tensor of shape (seq_len, dim).
Optionally, performs Voice Activity Detection (VAD) trimming on the start and end of the waveform
using the `vad_trigger_level`.
If the `output_layer` is specified, the result of the corresponding layer is returned.
If the `weights` are specified, the weighted result of the corresponding layers is returned.
If neither `output_layer` nor `weights` are specified, the result of all layers is returned.
Args:
path (str or torch.Tensor): Path to the audio waveform file or a tensor representing the waveform.
output_layer (int, optional): Index of the layer to extract the features from. Defaults to None.
weights (torch.Tensor, optional): Weights to apply to the features of each layer. Defaults to None.
vad_trigger_level (float, optional): VAD trigger level for trimming silence. Defaults to 0.
Returns:
torch.Tensor: Extracted WavLM features of the waveform.
"""
# load audio
        if isinstance(path, (str, Path)):
            # isinstance is required here: Path() returns platform-specific
            # subclasses such as PosixPath, which a plain type() check misses
            x, sr = torchaudio.load(path, normalize=True)
            if sr != self.sr:
                print(f'Resampling audio from {sr} Hz to {self.sr} Hz.')
                x = resampy.resample(x.numpy(), sr, self.sr, axis=1)
                x = torch.from_numpy(x).to(dtype=torch.float)
                sr = self.sr
        else:
            x: Tensor = path
            sr = self.sr

        if x.dim() == 1:
            x = x[None]  # add a channel dimension: (T,) -> (1, T)
        assert sr == self.sr, f"input audio sample rate must be 16kHz. Got {sr}"
        # trim silence from front and back
        if vad_trigger_level > 1e-3:
            transform = T.Vad(sample_rate=sr, trigger_level=vad_trigger_level)
            # T.Vad only trims leading silence, so trim the front, reverse the
            # waveform, trim the (now leading) trailing silence, then reverse back
            x_front_trim = transform(x)
            waveform_reversed, sr = apply_effects_tensor(x_front_trim, sr, [["reverse"]])
            waveform_reversed_front_trim = transform(waveform_reversed)
            waveform_end_trim, sr = apply_effects_tensor(
                waveform_reversed_front_trim, sr, [["reverse"]]
            )
            x = waveform_end_trim
        # extract the representation of each layer
        wav_input_16khz = x.to(self.device)
        if output_layer is not None:
            # fast path: only compute up to the requested layer
            features = self.wavlm.extract_features(wav_input_16khz, output_layer=output_layer, ret_layer_results=False)[0]
            features = torch.squeeze(features)  # (seq_len, dim)
        else:
            # slower path: run all layers and keep every intermediate representation
            rep, layer_results = self.wavlm.extract_features(wav_input_16khz, output_layer=self.wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
            features = torch.cat([h.transpose(0, 1) for h, _ in layer_results], dim=0)  # (n_layers, seq_len, dim)
            if weights is not None:
                # reshape so the weights broadcast over the layer axis; view
                # handles both (n_layers,) and (n_layers, 1) shaped weights
                features = (features * weights.view(-1, 1, 1)).sum(dim=0)  # (seq_len, dim)

        return features
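

if __name__ == '__main__':
    # Minimal usage sketch covering the three output modes of get_features
    # (all layers, a single layer, and a weighted layer sum). This block is
    # illustrative and not part of the original module: the checkpoint and
    # audio paths below are placeholders; substitute your own files.
    encoder = WavLMEncoder(ckpt_path='pretrained/WavLM-Large.pt', device='cpu')

    # all-layer features: (n_layers, seq_len, dim)
    feats_all = encoder.get_features('example.wav')

    # single-layer features: (seq_len, dim)
    feats_l6 = encoder.get_features('example.wav', output_layer=6)

    # weighted sum over layers (uniform weights here): (seq_len, dim)
    w = torch.softmax(torch.zeros(feats_all.shape[0]), dim=0)
    feats_weighted = encoder.get_features('example.wav', weights=w)

    print(feats_all.shape, feats_l6.shape, feats_weighted.shape)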