import torch import argparse import pathlib import re def modify_spk_embed(spk_embed): num_spk, hidden_size = spk_embed.shape all_ids = set(range(num_spk)) if args.drop is not None: drop_ids = set([int(i) for i in args.drop.split(',') if i != '']).intersection(all_ids) else: drop_ids = all_ids - set([int(i) for i in args.retain.split(',') if i != '']) fill_list = None if args.fill == 'zeros': fill_list = [0. for _ in drop_ids] elif args.fill == 'random': fill_list = [torch.randn(1, hidden_size, dtype=torch.float32, device='cpu') for _ in drop_ids] elif args.fill == 'mean': mean = torch.mean(spk_embed, dim=0, keepdim=True) fill_list = [mean for _ in drop_ids] elif args.fill == 'cyclic': retain_ids = sorted(all_ids - drop_ids) num_retain = len(retain_ids) fill_list = [spk_embed[retain_ids[i % num_retain], :] for i, _ in enumerate(drop_ids)] for spk_id, fill in zip(sorted(drop_ids), fill_list): spk_embed[spk_id, :] = fill parser = argparse.ArgumentParser(description='Drop or edit spk_embed in a checkpoint.') parser.add_argument('input', type=str, help='Path to the input file') parser.add_argument('output', type=str, help='Path to the output file') drop_retain_group = parser.add_mutually_exclusive_group() drop_retain_group.add_argument('--drop', type=str, required=False, metavar='ID,ID,...', help='Drop specific speaker IDs.') drop_retain_group.add_argument('--retain', type=str, required=False, metavar='ID,ID,...', help='Retain specific speaker IDs and drop all the others.') parser.add_argument('--fill', type=str, required=False, default='zeros', metavar='METHOD', choices=['zeros', 'random', 'mean', 'cyclic'], help='Specify a filling method for the dropped embedding. ' 'Available methods: zeros, random, mean, cyclic') parser.add_argument('--overwrite', required=False, default=False, action='store_true', help='Overwrite if the output file exists.') args = parser.parse_args() assert args.drop is not None or args.retain is not None, 'Either --drop or --retain should be specified.' if args.drop and not re.fullmatch(r'(\d+)?(,\d+)*,?', args.drop): print(f'Invalid format for --drop: \'{args.drop}\'') exit(-1) if args.retain and not re.fullmatch(r'(\d+)?(,\d+)*,?', args.retain): print(f'Invalid format for --retain: \'{args.retain}\'') exit(-1) import torch input_ckpt = pathlib.Path(args.input).resolve() output_ckpt = pathlib.Path(args.output).resolve() assert input_ckpt.exists(), 'The input file does not exist.' assert args.overwrite or not output_ckpt.exists(), \ 'The output file already exists or is the same as the input file.\n' \ 'This is not recommended because spk_embed dropping scripts may not be stable, ' \ 'and you may be at risk of losing your model.\n' \ 'If you are sure to OVERWRITE the existing file, please re-run this script with the \'--overwrite\' argument.' ckpt_loaded = torch.load(input_ckpt, map_location='cpu') state_dict = ckpt_loaded['state_dict'] if 'model.fs2.spk_embed.weight' in state_dict: modify_spk_embed(state_dict['model.fs2.spk_embed.weight']) if 'model.spk_embed.weight' in state_dict: modify_spk_embed(state_dict['model.spk_embed.weight']) torch.save(ckpt_loaded, output_ckpt)