from pathlib import Path

import resampy

import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from torch import Tensor
from torchaudio.sox_effects import apply_effects_tensor

from modules.wavlm.WavLM import WavLM, WavLMConfig


class WavLMEncoder(nn.Module):

    def __init__(self,
        ckpt_path,
        device='cpu'
    ):
        """ 
        Load the WavLM large checkpoint from the original paper. See https://github.com/microsoft/unilm/tree/master/wavlm for details. 
        
        Args:
            ckpt_path : checkpoint path of WavLM.

        """
        super().__init__()
        wavlm_check_point = torch.load(ckpt_path, map_location=device)
        cfg = WavLMConfig(wavlm_check_point['cfg'])
        wavlm = WavLM(cfg)
        wavlm.load_state_dict(wavlm_check_point['model'])
        wavlm = wavlm.to(device)
        
        # store wavlm
        self.wavlm = wavlm.eval()
        self.device = torch.device(device)
        self.sr = 16000
    
    @torch.inference_mode()
    def get_features(self, path, output_layer=None, weights=None, vad_trigger_level=0):
        """
        Return the WavLM features of the waveform at `path`.
        Optionally performs Voice Activity Detection (VAD) trimming on the start and end of the
        waveform using `vad_trigger_level`.

        If `output_layer` is specified, the output of that layer is returned as a tensor of shape
        (seq_len, dim). Otherwise the outputs of all layers are stacked into a tensor of shape
        (n_layers, seq_len, dim); if `weights` is also given, a weighted sum over the layer
        dimension is returned instead, with shape (seq_len, dim).

        Args:
            path (str or torch.Tensor): Path to the audio waveform file or a tensor representing the waveform.
            output_layer (int, optional): Index of the layer to extract the features from. Defaults to None.
            weights (torch.Tensor, optional): Weights to apply to the features of each layer. Defaults to None.
            vad_trigger_level (float, optional): VAD trigger level for trimming silence. Defaults to 0.

        Returns:
            torch.Tensor: Extracted WavLM features of the waveform.

        """
        # load audio
        if isinstance(path, (str, Path)):
            x, sr = torchaudio.load(path, normalize=True)
            if sr != self.sr:
                print(f'Original audio sample rate is {sr} Hz, resampling to {self.sr} Hz.')
                x = resampy.resample(x.numpy(), sr, self.sr, axis=1)
                x = torch.from_numpy(x).to(dtype=torch.float)
                sr = self.sr
        else:
            x: Tensor = path
            sr = self.sr
            if x.dim() == 1:
                x = x[None]  # add a channel dimension: (seq_len,) -> (1, seq_len)
        assert sr == self.sr, f"Input audio sample rate must be 16 kHz. Got {sr}."
        
        # trim silence from front and back
        if vad_trigger_level > 1e-3:
            transform = T.Vad(sample_rate=sr, trigger_level=vad_trigger_level)
            x_front_trim = transform(x)  # T.Vad only trims leading silence
            # reverse the waveform, trim the (now leading) trailing silence, then reverse back
            waveform_reversed, sr = apply_effects_tensor(x_front_trim, sr, [["reverse"]])
            waveform_reversed_front_trim = transform(waveform_reversed)
            waveform_end_trim, sr = apply_effects_tensor(
                waveform_reversed_front_trim, sr, [["reverse"]]
            )
            x = waveform_end_trim

        # extract the representation of each layer
        wav_input_16khz = x.to(self.device)
        if output_layer is not None:
            # fast path: only the requested layer is computed
            features = self.wavlm.extract_features(wav_input_16khz, output_layer=output_layer, ret_layer_results=False)[0]
            features = features.squeeze(0)  # (1, seq_len, dim) -> (seq_len, dim)
        else:
            # slower path: collect the output of every layer, optionally combine them with `weights`
            rep, layer_results = self.wavlm.extract_features(wav_input_16khz, output_layer=self.wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
            features = torch.cat([x.transpose(0, 1) for x, _ in layer_results], dim=0)  # (n_layers, seq_len, dim)
            if weights is not None:
                # weighted sum over the layer dimension: (n_layers, seq_len, dim) -> (seq_len, dim)
                features = (features * weights[:, None, None]).sum(dim=0)
        
        return features
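

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): the checkpoint and audio paths below are
    # placeholders and are not shipped with this repository.
    encoder = WavLMEncoder(ckpt_path='checkpoints/WavLM-Large.pt', device='cpu')

    # Fast path: features of a single layer, shape (seq_len, dim).
    feats = encoder.get_features('example.wav', output_layer=6, vad_trigger_level=7)
    print(feats.shape)

    # Weighted combination of all layers. `weights` is assumed to have one entry per returned
    # layer (encoder_layers + 1, i.e. 25 for WavLM Large); here a uniform average is used.
    n_layers = encoder.wavlm.cfg.encoder_layers + 1
    weights = torch.full((n_layers,), 1.0 / n_layers, device=encoder.device)
    feats_weighted = encoder.get_features('example.wav', weights=weights)  # (seq_len, dim)
    print(feats_weighted.shape)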