import torch
import torch.nn as nn
import joblib
from components.semantic_extractor.WavLM import WavLM, WavLMConfig

class ApplyKmeans(nn.Module):
    def __init__(self, km_path, device='cuda'):
        super(ApplyKmeans, self).__init__()
        print(f'Init k-means model from {km_path}')
        self.km_model = joblib.load(km_path)
        # Centroids stored transposed as (feature_dim, n_clusters) so that
        # distances to all centroids can be computed with a single matmul.
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)
        self.C = torch.from_numpy(self.C_np).to(device)
        self.Cnorm = torch.from_numpy(self.Cnorm_np).to(device)
        # Frozen embedding table holding the centroids, shaped (n_clusters, feature_dim).
        n_features, n_clusters = self.C.shape
        self.emb = nn.Embedding(num_embeddings=n_clusters, embedding_dim=n_features)
        self.emb.weight.data = self.C.transpose(0, 1)
        self.emb.weight.requires_grad = False

    def forward(self, x, b, t):
        # Rebuild the centroid tensors on the input's device if they are
        # missing (e.g. if the attributes were dropped during (de)serialization).
        if not hasattr(self, 'C'):
            self.C = torch.from_numpy(self.C_np).to(x.device)
        if not hasattr(self, 'Cnorm'):
            self.Cnorm = torch.from_numpy(self.Cnorm_np).to(x.device)
        # Squared Euclidean distance ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2
        # for every (frame, centroid) pair, computed in one matmul.
        dist = x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
        # Index of the nearest centroid per frame, reshaped to (batch, time).
        tokens = dist.argmin(dim=-1).reshape(b, t)
        return tokens

def get_ssl_model(ckpt_path, km_path, device='cuda', type='xlsr'):
    # Imported at function scope so it is available to both branches and to
    # the k-means download below (it was previously only imported in the
    # 'wavlm' branch, so the 'xlsr' path raised a NameError).
    from huggingface_hub import hf_hub_download
    if type == 'xlsr':
        print(f'Init xlsr model from {ckpt_path}')
        import fairseq
        import argparse
        task_arg = argparse.Namespace(task='audio_pretraining')
        task = fairseq.tasks.setup_task(task_arg)
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path], task=task)
        model = model[0]
        model.eval()
    elif type == 'wavlm':
        print(f'Init wavlm model from {ckpt_path}')
        # The WavLM-Large checkpoint is fetched from the hub; ckpt_path is
        # only used in the log message above.
        cpt = hf_hub_download(repo_id="yaoxunji/gense", filename="WavLM-Large.pt")
        cpt = torch.load(cpt, map_location="cpu")
        cfg = WavLMConfig(cpt["cfg"])
        model = WavLM(cfg)
        model.load_state_dict(cpt["model"])
        model = model.eval()
        model = model.requires_grad_(False)
    else:
        raise NotImplementedError(f'Unknown SSL model type: {type}')
    # The k-means checkpoint is always fetched from the hub, overriding the
    # km_path argument passed in.
    km_path = hf_hub_download(repo_id="yaoxunji/gense", filename="wavlm_km.mdl")
    km_model = ApplyKmeans(km_path, device)
    return model, km_model
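
# A minimal usage sketch (not part of the original file): load the WavLM
# variant and quantize a waveform into discrete unit indices. The dummy
# waveform, the `output_layer=6` choice, and the checkpoint/k-means paths
# are illustrative assumptions, not values taken from the repository.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # For type='wavlm', both checkpoints are downloaded from the hub, so the
    # paths passed here only affect the log messages.
    model, km_model = get_ssl_model(ckpt_path='WavLM-Large.pt',
                                    km_path='wavlm_km.mdl',
                                    device=device, type='wavlm')
    model = model.to(device)
    wav = torch.randn(1, 16000, device=device)  # 1 s of dummy 16 kHz audio
    with torch.no_grad():
        feats, _ = model.extract_features(wav, output_layer=6)  # (B, T, 1024)
    b, t, d = feats.shape
    tokens = km_model(feats.reshape(b * t, d), b, t)  # (B, T) unit indices
    print(tokens.shape)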