File size: 5,515 Bytes
cfdc687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import torch
import numpy as np
import soundfile as sf


def fast_cosine_dist(source_feats: torch.Tensor, matching_pool: torch.Tensor, device='cpu'):
    """
    Computes the pairwise cosine distance between source features and a matching pool of features.
    Like torch.cdist, but fixed dim=-1 and for cosine distance.

    The dot products are recovered from squared Euclidean distances through the identity
    ``a . b = (|a|^2 + |b|^2 - |a - b|^2) / 2``, so the pairwise work is delegated to ``torch.cdist``.

    Args:
        source_feats (torch.Tensor): Tensor of source features with shape (n_source_feats, feat_dim).
        matching_pool (torch.Tensor): Tensor of matching pool features with shape (n_matching_feats, feat_dim).
        device (str, optional): Device to perform the computation on. Defaults to 'cpu'.

    Returns:
        torch.Tensor: Tensor of shape (n_source_feats, n_matching_feats) holding the cosine distance
            (1 - cosine similarity) of every source/pool pair. Pairs that involve an all-zero feature
            vector have no defined cosine similarity; they are assigned distance 1 instead of NaN.

    """
    source_feats = source_feats.to(device)
    matching_pool = matching_pool.to(device)
    source_norms = torch.norm(source_feats, p=2, dim=-1)
    matching_norms = torch.norm(matching_pool, p=2, dim=-1)

    # cdist works on batches, hence the temporary leading dimension.
    sq_dists = torch.cdist(source_feats[None], matching_pool[None], p=2)[0] ** 2
    dotprod = (-sq_dists + source_norms[:, None] ** 2 + matching_norms[None] ** 2) / 2

    norm_prod = source_norms[:, None] * matching_norms[None]
    # Avoid 0/0: divide by 1 where a norm is zero, then force the similarity of those pairs to 0.
    valid = norm_prod > 0
    cos_sim = dotprod / torch.where(valid, norm_prod, torch.ones_like(norm_prod))
    cos_sim = torch.where(valid, cos_sim, torch.zeros_like(cos_sim))
    return 1 - cos_sim


def load_wav(wav_path, sr=None):
    """
    Loads a waveform from a wav file.

    Multi-channel audio is reduced to its first channel, and a waveform whose absolute peak
    exceeds 1.0 is scaled down so that the peak becomes 1.0.

    Args:
        wav_path (str): Path to the wav file.
        sr (int, optional): Required sample rate in Hz.
            If `sr` is specified and the loaded audio has a different sample rate, an AssertionError is raised.
            Defaults to None (any sample rate is accepted).

    Returns:
        Tuple[np.ndarray, int]: Tuple containing the loaded waveform as a NumPy array and the sample rate.

    Raises:
        AssertionError: If `sr` is given and does not match the sample rate of the file.

    """
    wav, fs = sf.read(wav_path)
    if wav.ndim != 1:
        # Multi-channel data is indexed as [frame, channel] below, so the channel count is shape[1].
        print('The wav file %s has %d channels, select the first one to proceed.' %(wav_path, wav.shape[1]))
        wav = wav[:, 0]
    if sr is not None and fs != sr:
        # Raised explicitly (instead of `assert`) so the check survives `python -O`.
        raise AssertionError(f'{sr} Hz audio is required. Got {fs}')
    if wav.size > 0:  # .max() of an empty array would raise
        peak = np.abs(wav).max()
        if peak > 1.0:
            wav /= peak
    return wav, fs


class ConfigWrapper(object):
    """
    Dict-like config wrapper that also exposes its entries as attributes, so that
    `config.sample_rate` can be written instead of `config["sample_rate"]`.

    Entries are stored as instance attributes. Nested dicts (including dict subclasses)
    are wrapped recursively, so `config.model.layers` works as well.
    """
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = ConfigWrapper(**v)
            self[k] = v

    def keys(self):
        """Return a view of the config keys."""
        return self.__dict__.keys()

    def items(self):
        """Return a view of the (key, value) pairs."""
        return self.__dict__.items()

    def values(self):
        """Return a view of the config values."""
        return self.__dict__.values()

    def to_dict_type(self):
        """Convert the wrapper (and every nested wrapper) back into plain dicts."""
        return {
            key: (value.to_dict_type() if isinstance(value, ConfigWrapper) else value)
            for key, value in self.items()
        }

    def __iter__(self):
        # Without this, `for k in config` would fall back to integer indexing
        # through __getitem__ and fail inside getattr().
        return iter(self.__dict__)

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()


def save_checkpoint(steps, epochs, model, optimizer, scheduler, checkpoint_path, dst_train=False):
    """Save checkpoint.

    Args:
        steps (int): Step counter stored under the "steps" key.
        epochs (int): Epoch counter stored under the "epochs" key.
        model (dict): Mapping with "generator" and "discriminator" entries whose state dicts are saved.
        optimizer (dict): Mapping with "generator" and "discriminator" optimizers.
        scheduler (dict): Mapping with "generator" and "discriminator" schedulers.
        checkpoint_path (str): Checkpoint path to be saved. Missing parent directories are created.
        dst_train (bool): If True, the state dict is taken from the `.module` attribute of each
            model (presumably because the models are wrapped for distributed training).

    """
    parts = ("generator", "discriminator")
    state_dict = {
        "optimizer": {name: optimizer[name].state_dict() for name in parts},
        "scheduler": {name: scheduler[name].state_dict() for name in parts},
        "steps": steps,
        "epochs": epochs,
        "model": {
            name: (model[name].module if dst_train else model[name]).state_dict()
            for name in parts
        },
    }

    # A bare filename has an empty dirname; os.makedirs('') would raise, so skip it.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(state_dict, checkpoint_path)


def load_checkpoint(model, optimizer, scheduler, checkpoint_path, load_only_params=False, dst_train=False):
    """Load checkpoint.

    Args:
        model (dict): Mapping with "generator" and "discriminator" entries that receive the saved weights.
        optimizer (dict): Mapping with "generator" and "discriminator" optimizers.
            Not accessed when `load_only_params` is True.
        scheduler (dict): Mapping with "generator" and "discriminator" schedulers.
            Not accessed when `load_only_params` is True.
        checkpoint_path (str): Checkpoint path to be loaded.
        load_only_params (bool): Whether to load only model parameters. When True, the optimizer
            and scheduler states are left untouched.
        dst_train (bool): If True, the weights are loaded into the `.module` attribute of each
            model (presumably because the models are wrapped for distributed training).

    Returns:
        Tuple[int, int]: The "steps" and "epochs" values stored in the checkpoint.

    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    parts = ("generator", "discriminator")
    for name in parts:
        target = model[name].module if dst_train else model[name]
        target.load_state_dict(state_dict["model"][name])

    if not load_only_params:
        for name in parts:
            optimizer[name].load_state_dict(state_dict["optimizer"][name])
        for name in parts:
            scheduler[name].load_state_dict(state_dict["scheduler"][name])

    steps = state_dict["steps"]
    epochs = state_dict["epochs"]

    return steps, epochs