"""Drop or replace speaker embedding rows (spk_embed) in a model checkpoint."""
import argparse
import pathlib
import re

import torch


def modify_spk_embed(spk_embed):
    # Overwrite the selected rows of the speaker embedding matrix in place.
    num_spk, hidden_size = spk_embed.shape
    all_ids = set(range(num_spk))
    # Resolve which speaker IDs to drop from either --drop or --retain.
    if args.drop is not None:
        drop_ids = set([int(i) for i in args.drop.split(',') if i != '']).intersection(all_ids)
    else:
        drop_ids = all_ids - set([int(i) for i in args.retain.split(',') if i != ''])

    # Build one replacement value per dropped ID according to the --fill method.
    fill_list = None
    if args.fill == 'zeros':
        fill_list = [0. for _ in drop_ids]
    elif args.fill == 'random':
        fill_list = [torch.randn(1, hidden_size, dtype=torch.float32, device='cpu') for _ in drop_ids]
    elif args.fill == 'mean':
        # Replace dropped rows with the mean of all original embeddings.
        mean = torch.mean(spk_embed, dim=0, keepdim=True)
        fill_list = [mean for _ in drop_ids]
    elif args.fill == 'cyclic':
        # Reuse the retained embeddings round-robin for the dropped IDs.
        retain_ids = sorted(all_ids - drop_ids)
        num_retain = len(retain_ids)
        fill_list = [spk_embed[retain_ids[i % num_retain], :] for i, _ in enumerate(drop_ids)]

    for spk_id, fill in zip(sorted(drop_ids), fill_list):
        spk_embed[spk_id, :] = fill


parser = argparse.ArgumentParser(description='Drop or edit spk_embed in a checkpoint.')
parser.add_argument('input', type=str, help='Path to the input file')
parser.add_argument('output', type=str, help='Path to the output file')
drop_retain_group = parser.add_mutually_exclusive_group()
drop_retain_group.add_argument('--drop', type=str, required=False, metavar='ID,ID,...',
                               help='Drop specific speaker IDs.')
drop_retain_group.add_argument('--retain', type=str, required=False, metavar='ID,ID,...',
                               help='Retain specific speaker IDs and drop all the others.')
parser.add_argument('--fill', type=str, required=False, default='zeros', metavar='METHOD',
                    choices=['zeros', 'random', 'mean', 'cyclic'],
                    help='Specify a filling method for the dropped embedding. '
                         'Available methods: zeros, random, mean, cyclic')
parser.add_argument('--overwrite', required=False, default=False,
                    action='store_true', help='Overwrite if the output file exists.')
args = parser.parse_args()
assert args.drop is not None or args.retain is not None, 'Either --drop or --retain should be specified.'
if args.drop and not re.fullmatch(r'(\d+)?(,\d+)*,?', args.drop):
    print(f'Invalid format for --drop: \'{args.drop}\'')
    exit(-1)
if args.retain and not re.fullmatch(r'(\d+)?(,\d+)*,?', args.retain):
    print(f'Invalid format for --retain: \'{args.retain}\'')
    exit(-1)

input_ckpt = pathlib.Path(args.input).resolve()
output_ckpt = pathlib.Path(args.output).resolve()
assert input_ckpt.exists(), 'The input file does not exist.'
assert args.overwrite or not output_ckpt.exists(), \
    'The output file already exists or is the same as the input file.\n' \
    'This is not recommended because spk_embed dropping scripts may not be stable, ' \
    'and you may be at risk of losing your model.\n' \
    'If you are sure to OVERWRITE the existing file, please re-run this script with the \'--overwrite\' argument.'

# Load the checkpoint on CPU and patch every speaker embedding matrix found in it.
ckpt_loaded = torch.load(input_ckpt, map_location='cpu')
state_dict = ckpt_loaded['state_dict']
if 'model.fs2.spk_embed.weight' in state_dict:
    modify_spk_embed(state_dict['model.fs2.spk_embed.weight'])
if 'model.spk_embed.weight' in state_dict:
    modify_spk_embed(state_dict['model.spk_embed.weight'])

torch.save(ckpt_loaded, output_ckpt)
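
# Example usage (a sketch; the script name, checkpoint paths, and speaker IDs
# below are placeholders, not part of the original script):
#   python drop_spk_embed.py old.ckpt pruned.ckpt --retain 0,2 --fill cyclic
#   python drop_spk_embed.py old.ckpt old.ckpt --drop 1 --fill mean --overwrite
# The first call keeps speakers 0 and 2 and refills every other row with retained
# embeddings picked round-robin; the second overwrites speaker 1's row with the
# mean embedding and, because of --overwrite, writes the result back over the input.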