# Spaces:
# Sleeping
# Sleeping
import argparse
import pathlib
import re
import sys

import torch
def _parse_ids(id_list):
    """Return the set of integer IDs in a comma-separated string such as ``'0,3,'``.

    Empty items (from a trailing comma or an empty string) are ignored.
    """
    return {int(item) for item in id_list.split(',') if item != ''}


def modify_spk_embed(spk_embed, drop=None, retain=None, fill=None):
    """Overwrite selected rows of a speaker-embedding table in place.

    ``spk_embed`` is unpacked as ``(num_spk, hidden_size)`` and indexed as
    ``[speaker_id, :]``, so it must be a 2-D tensor.

    Args:
        spk_embed: 2-D tensor whose dropped rows are overwritten in place.
        drop: comma-separated IDs to drop; IDs outside ``range(num_spk)`` are
            ignored. Takes precedence over ``retain``.
        retain: comma-separated IDs to keep; every other row is dropped.
        fill: how dropped rows are refilled -- ``'zeros'``; ``'random'``
            (``torch.randn`` samples); ``'mean'`` (mean over ALL rows, taken
            before any row is changed); or ``'cyclic'`` (copies of the
            retained rows, cycled in ascending ID order).

        When neither ``drop`` nor ``retain`` is given, both are read from the
        module-level ``args``; likewise ``fill`` when it is not given.

    Raises:
        ValueError: if ``fill`` is not a known method, or ``'cyclic'`` is
            requested while rows are dropped but none is retained.
    """
    # Fall back to the parsed CLI arguments so existing callers keep working.
    if drop is None and retain is None:
        drop, retain = args.drop, args.retain
    if fill is None:
        fill = args.fill

    num_spk, hidden_size = spk_embed.shape
    all_ids = set(range(num_spk))
    if drop is not None:
        drop_ids = _parse_ids(drop) & all_ids
    else:
        drop_ids = all_ids - _parse_ids(retain)
    drop_order = sorted(drop_ids)

    if fill == 'zeros':
        fill_rows = [0. for _ in drop_order]
    elif fill == 'random':
        # One (hidden_size,) row per dropped ID, created on the table's device;
        # the assignment below casts it to the table's dtype.
        fill_rows = [torch.randn(hidden_size, dtype=torch.float32, device=spk_embed.device)
                     for _ in drop_order]
    elif fill == 'mean':
        # Computed before the write loop, so it includes the rows being dropped.
        mean = torch.mean(spk_embed, dim=0)
        fill_rows = [mean for _ in drop_order]
    elif fill == 'cyclic':
        retain_ids = sorted(all_ids - drop_ids)
        if drop_order and not retain_ids:
            raise ValueError("Fill method 'cyclic' needs at least one retained speaker ID.")
        # Retained rows are never written below, so these views stay valid.
        fill_rows = [spk_embed[retain_ids[i % len(retain_ids)], :]
                     for i in range(len(drop_order))]
    else:
        raise ValueError(f'Unknown fill method: {fill!r}')

    for spk_id, row in zip(drop_order, fill_rows):
        spk_embed[spk_id, :] = row
# Accepts e.g. '', '3', '0,2,5' and a trailing comma such as '0,2,'.
_ID_LIST_PATTERN = re.compile(r'(\d+)?(,\d+)*,?')


def _id_list(value):
    """argparse ``type`` that validates a comma-separated list of speaker IDs.

    The string itself is returned unchanged; ``modify_spk_embed`` splits it later.

    Raises:
        argparse.ArgumentTypeError: if ``value`` does not match ``_ID_LIST_PATTERN``.
    """
    if _ID_LIST_PATTERN.fullmatch(value) is None:
        raise argparse.ArgumentTypeError(f'Invalid format: \'{value}\'')
    return value


parser = argparse.ArgumentParser(description='Drop or edit spk_embed in a checkpoint.')
parser.add_argument('input', type=str, help='Path to the input file')
parser.add_argument('output', type=str, help='Path to the output file')
# argparse enforces "exactly one of --drop / --retain"; unlike an `assert`,
# this check is not stripped when running under `python -O`.
drop_retain_group = parser.add_mutually_exclusive_group(required=True)
drop_retain_group.add_argument('--drop', type=_id_list, required=False, metavar='ID,ID,...',
                               help='Drop specific speaker IDs.')
drop_retain_group.add_argument('--retain', type=_id_list, required=False, metavar='ID,ID,...',
                               help='Retain specific speaker IDs and drop all the others.')
parser.add_argument('--fill', type=str, required=False, default='zeros', metavar='METHOD',
                    choices=['zeros', 'random', 'mean', 'cyclic'],
                    help='Specify a filling method for the dropped embedding. '
                         'Available methods: zeros, random, mean, cyclic')
parser.add_argument('--overwrite', required=False, default=False,
                    action='store_true', help='Overwrite if the output file exists.')

# State-dict keys checked for a speaker-embedding weight (presumably one per
# model layout -- confirm against the model definitions).
SPK_EMBED_KEYS = ('model.fs2.spk_embed.weight', 'model.spk_embed.weight')

if __name__ == '__main__':
    # `args` stays a module-level global: modify_spk_embed reads it.
    args = parser.parse_args()

    input_ckpt = pathlib.Path(args.input).resolve()
    output_ckpt = pathlib.Path(args.output).resolve()
    if not input_ckpt.exists():
        sys.exit('The input file does not exist.')
    if output_ckpt.exists() and not args.overwrite:
        sys.exit('The output file already exists or is the same as the input file.\n'
                 'This is not recommended because spk_embed dropping scripts may not be stable, '
                 'and you may be at risk of losing your model.\n'
                 'If you are sure to OVERWRITE the existing file, please re-run this script with the \'--overwrite\' argument.')

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only run this script on checkpoints from a trusted source.
    ckpt_loaded = torch.load(input_ckpt, map_location='cpu')
    state_dict = ckpt_loaded.get('state_dict') if isinstance(ckpt_loaded, dict) else None
    if state_dict is None:
        sys.exit('The input file does not contain a \'state_dict\' entry.')

    embed_keys = [key for key in SPK_EMBED_KEYS if key in state_dict]
    if not embed_keys:
        # Writing an unchanged copy would silently hide a wrong input file.
        sys.exit('No spk_embed weight found in the checkpoint; nothing was written.')
    for key in embed_keys:
        # modify_spk_embed edits the tensor in place, so ckpt_loaded sees the change.
        modify_spk_embed(state_dict[key])
    torch.save(ckpt_loaded, output_ckpt)