from pathlib import Path

import resampy
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from torch import Tensor
from torchaudio.sox_effects import apply_effects_tensor

from modules.wavlm.WavLM import WavLM, WavLMConfig

class WavLMEncoder(nn.Module):
    def __init__(self, ckpt_path, device='cpu'):
        """
        Load the WavLM Large checkpoint from the original paper.
        See https://github.com/microsoft/unilm/tree/master/wavlm for details.

        Args:
            ckpt_path (str): Path to the WavLM checkpoint.
            device (str, optional): Device to load the model onto. Defaults to 'cpu'.
        """
        super().__init__()
        # map_location lets a GPU-trained checkpoint load on CPU-only machines
        checkpoint = torch.load(ckpt_path, map_location=device)
        cfg = WavLMConfig(checkpoint['cfg'])
        wavlm = WavLM(cfg)
        wavlm.load_state_dict(checkpoint['model'])
        # keep the frozen model in eval mode; WavLM expects 16 kHz input
        self.wavlm = wavlm.to(device).eval()
        self.device = torch.device(device)
        self.sr = 16000
    @torch.inference_mode()
    def get_features(self, path, output_layer=None, weights=None, vad_trigger_level=0):
        """
        Return WavLM features for the waveform at `path`.

        Optionally trims silence from the start and end of the waveform with
        Voice Activity Detection (VAD), controlled by `vad_trigger_level`.
        If `output_layer` is given, only that layer's output is returned.
        If `weights` is given, a weighted sum over all layer outputs is
        returned. If neither is given, the stacked outputs of all layers are
        returned.

        Args:
            path (str, Path, or torch.Tensor): Path to the audio file, or a tensor holding the waveform.
            output_layer (int, optional): Index of the layer to extract features from. Defaults to None.
            weights (torch.Tensor, optional): Per-layer weights of shape (n_layers,). Defaults to None.
            vad_trigger_level (float, optional): VAD trigger level for trimming silence. Defaults to 0.

        Returns:
            torch.Tensor: WavLM features of shape (seq_len, dim), or (n_layers, seq_len, dim) when all layers are returned.
        """
        # load audio
        if isinstance(path, (str, Path)):
            x, sr = torchaudio.load(path, normalize=True)
            if sr != self.sr:
                print(f'Original audio sr is {sr}, resampling to {self.sr}.')
                x = resampy.resample(x.numpy(), sr, self.sr, axis=1)
                x = torch.from_numpy(x).to(dtype=torch.float)
                sr = self.sr
        else:
            x: Tensor = path
            sr = self.sr
            if x.dim() == 1:
                x = x[None]  # add a channel dimension: (1, n_samples)

        assert sr == self.sr, f"input audio sample rate must be 16kHz. Got {sr}"
        # trim silence from front and back
        if vad_trigger_level > 1e-3:
            transform = T.Vad(sample_rate=sr, trigger_level=vad_trigger_level)
            # T.Vad only trims leading silence, so trim, reverse, trim again,
            # and reverse back to also remove trailing silence
            x_front_trim = transform(x)
            waveform_reversed, sr = apply_effects_tensor(x_front_trim, sr, [["reverse"]])
            waveform_reversed_front_trim = transform(waveform_reversed)
            waveform_end_trim, sr = apply_effects_tensor(
                waveform_reversed_front_trim, sr, [["reverse"]]
            )
            x = waveform_end_trim
        # extract the representation of each layer
        wav_input_16khz = x.to(self.device)
        if output_layer is not None:
            # fast path: run the encoder only up to the requested layer
            features = self.wavlm.extract_features(
                wav_input_16khz, output_layer=output_layer, ret_layer_results=False
            )[0]
            features = features.squeeze(0)  # drop the batch dimension: (seq_len, dim)
        else:
            # slower path: run every layer and keep all intermediate outputs
            rep, layer_results = self.wavlm.extract_features(
                wav_input_16khz,
                output_layer=self.wavlm.cfg.encoder_layers,
                ret_layer_results=True,
            )[0]
            # each entry of layer_results is (seq_len, batch, dim);
            # stack them into (n_layers, seq_len, dim)
            features = torch.cat([x.transpose(0, 1) for x, _ in layer_results], dim=0)
            if weights is not None:
                # weighted sum over layers; reshape weights so they broadcast
                # over (seq_len, dim)
                features = (features * weights.view(-1, 1, 1)).sum(dim=0)  # (seq_len, dim)
        return features
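

if __name__ == "__main__":
    # Usage sketch (illustrative; not part of the original module). The
    # checkpoint and audio paths below are placeholders: point them at a
    # downloaded WavLM Large checkpoint and a 16 kHz mono wav file.
    encoder = WavLMEncoder(ckpt_path="checkpoints/WavLM-Large.pt", device="cpu")

    # Single-layer fastpath, with VAD trimming enabled.
    feats = encoder.get_features("example.wav", output_layer=6, vad_trigger_level=7.0)
    print(feats.shape)  # (seq_len, dim)

    # Weighted combination over all layer outputs. layer_results includes the
    # encoder input, so there are encoder_layers + 1 outputs to weight;
    # the uniform weights here are purely illustrative.
    n_layers = encoder.wavlm.cfg.encoder_layers + 1
    weights = torch.full((n_layers,), 1.0 / n_layers)
    feats_weighted = encoder.get_features("example.wav", weights=weights)
    print(feats_weighted.shape)  # (seq_len, dim)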