liampond committed · Commit c42fe7e · 0 Parent(s)

Clean deploy snapshot

Files changed (50)
  1. .bashrc +9 -0
  2. .devcontainer/devcontainer.json +33 -0
  3. .gitignore +34 -0
  4. .streamlit/config.toml +3 -0
  5. .vscode/settings.json +6 -0
  6. LICENSE +21 -0
  7. Makefile +26 -0
  8. README.md +99 -0
  9. app.sh +2 -0
  10. augmentation/spec_stretch.py +92 -0
  11. basics/base_augmentation.py +28 -0
  12. basics/base_binarizer.py +347 -0
  13. basics/base_dataset.py +58 -0
  14. basics/base_exporter.py +59 -0
  15. basics/base_module.py +18 -0
  16. basics/base_pe.py +7 -0
  17. basics/base_svs_infer.py +131 -0
  18. basics/base_task.py +520 -0
  19. basics/base_vocoder.py +23 -0
  20. configs/CantusSVS_acoustic.yaml +149 -0
  21. configs/CantusSVS_variance.yaml +153 -0
  22. configs/base.yaml +94 -0
  23. configs/defaults/acoustic.yaml +138 -0
  24. configs/defaults/base.yaml +94 -0
  25. configs/defaults/variance.yaml +145 -0
  26. configs/templates/config_acoustic.yaml +105 -0
  27. configs/templates/config_variance.yaml +129 -0
  28. deployment/.gitignore +7 -0
  29. deployment/__init__.py +0 -0
  30. deployment/benchmarks/infer_acoustic.py +32 -0
  31. deployment/benchmarks/infer_nsf_hifigan.py +16 -0
  32. deployment/exporters/__init__.py +3 -0
  33. deployment/exporters/acoustic_exporter.py +405 -0
  34. deployment/exporters/nsf_hifigan_exporter.py +120 -0
  35. deployment/exporters/variance_exporter.py +781 -0
  36. deployment/modules/__init__.py +0 -0
  37. deployment/modules/diffusion.py +220 -0
  38. deployment/modules/fastspeech2.py +153 -0
  39. deployment/modules/nsf_hifigan.py +16 -0
  40. deployment/modules/rectified_flow.py +123 -0
  41. deployment/modules/toplevel.py +392 -0
  42. dictionaries/.gitignore +3 -0
  43. docs/BestPractices.md +618 -0
  44. docs/ConfigurationSchemas.md +2109 -0
  45. docs/GettingStarted.md +164 -0
  46. docs/resources/arch-acoustic.drawio +0 -0
  47. docs/resources/arch-overview.drawio +123 -0
  48. docs/resources/arch-overview.jpg +0 -0
  49. docs/resources/arch-variance.drawio +0 -0
  50. inference/dpm_solver_pytorch.py +1305 -0
.bashrc ADDED
@@ -0,0 +1,9 @@
+ # Warn if Arrow is not loaded
+ if ! module list 2>&1 | grep -q arrow; then
+     echo -e "\n\033[1;33m[WARNING]\033[0m Arrow module is not loaded! Run: module load arrow/19.0.1"
+ fi
+
+ # Warn if GCC is not loaded
+ if ! module list 2>&1 | grep -q gcc; then
+     echo -e "\n\033[1;33m[WARNING]\033[0m GCC module is not loaded! Run: module load gcc/12.3"
+ fi
.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,33 @@
+ {
+     "name": "Python 3",
+     // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
+     "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye",
+     "customizations": {
+         "codespaces": {
+             "openFiles": [
+                 "README.md",
+                 "webapp/app.py"
+             ]
+         },
+         "vscode": {
+             "settings": {},
+             "extensions": [
+                 "ms-python.python",
+                 "ms-python.vscode-pylance"
+             ]
+         }
+     },
+     "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y <packages.txt; [ -f requirements.txt ] && pip3 install --user -r requirements.txt; pip3 install --user streamlit; echo '✅ Packages installed and Requirements met'",
+     "postAttachCommand": {
+         "server": "streamlit run webapp/app.py --server.enableCORS false --server.enableXsrfProtection false"
+     },
+     "portsAttributes": {
+         "8501": {
+             "label": "Application",
+             "onAutoForward": "openPreview"
+         }
+     },
+     "forwardPorts": [
+         8501
+     ]
+ }
.gitignore ADDED
@@ -0,0 +1,34 @@
+ artifacts/
+ checkpoints/
+ venv/
+ .venv/
+ svenv/
+ env-py311/
+ *.swp
+ data/
+ DiffSinger/
+ training_logs/
+ training_*.out
+ __pycache__/
+
+ *.onnx
+ *.ckpt
+ *.pt
+ *.zip
+ *.out
+ *.log
+ *.pyc
+ *.pyo
+ *.swp
+ artifacts/
+ data/NNSVS_training_data/
+ data/binary/
+ outputs/
+
+ commands.txt
+ 19Apr-Commands.sh
+
+ webapp/output/
+ webapp/tmp_ds/
+ webapp/uploaded_ds/
+ webapp/uploaded_mei/
.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
+ [server]
+ fileWatcherType = "none"
+ runOnSave = false
.vscode/settings.json ADDED
@@ -0,0 +1,6 @@
+ {
+     "python.defaultInterpreterPath": "webapp/venv/bin/python",
+     "python.analysis.extraPaths": ["./inference"],
+     "python.analysis.indexing": true,
+     "python.analysis.typeCheckingMode": "off"
+ }
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Liam Pond
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
Makefile ADDED
@@ -0,0 +1,26 @@
+ .PHONY: help setup clean run
+
+ # Default command: show help
+ help:
+ 	@echo ""
+ 	@echo "Available commands:"
+ 	@echo "  make setup  - Set up virtual environment (use 'make setup reset=1' to force rebuild)"
+ 	@echo "  make clean  - Remove the virtual environment"
+ 	@echo "  make run    - Launch the Streamlit app"
+ 	@echo ""
+
+ # Set up the environment
+ setup:
+ ifeq ($(reset),1)
+ 	rm -rf venv
+ endif
+ 	bash scripts/setup_env.sh
+
+ # Remove the virtual environment
+ clean:
+ 	rm -rf venv
+
+ # Run the Streamlit app
+ run:
+ 	source venv/bin/activate && streamlit run webapp/app.py
+
README.md ADDED
@@ -0,0 +1,99 @@
+ # CantusSVS
+
+ ## Table of Contents
+ - [About CantusSVS](#about-cantussvs)
+ - [Quick Start](#quick-start)
+ - [Preparing Your Input](#preparing-your-input)
+ - [Running Locally](#running-locally)
+ - [FAQ](#faq)
+
+ ---
+
+ ## About CantusSVS
+
+ CantusSVS is a singing voice synthesis tool that automatically generates audio playback for the Latin chants in Cantus. You can use CantusSVS directly in the browser at [**https://cantussvs.streamlit.app**](https://cantussvs.streamlit.app). For training and inference, we use **DiffSinger**, a diffusion-based singing voice synthesis model described in the paper below:
+
+ **DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism**
+
+ Liu, Jinglin, Chengxi Li, Yi Ren, Feiyang Chen, and Zhou Zhao. 2022. "DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism." In *Proceedings of the AAAI Conference on Artificial Intelligence* 36 (10): 11020–11028. [arXiv:2105.02446](https://arxiv.org/abs/2105.02446), [doi:10.1609/aaai.v36i10.21350](http://dx.doi.org/10.1609/aaai.v36i10.21350).
+
+ Training was done on Cedar, a cluster provided by the Digital Research Alliance of Canada. To set up training locally, follow [this tutorial](https://youtu.be/Sxt11TAflV0?feature=shared) by [tigermeat](https://www.youtube.com/@spicytigermeat).
+
+ For general help with training and creating a dataset, [this tutorial](https://docs.google.com/document/d/1uMsepxbdUW65PfIWL1pt2OM6ZKa5ybTTJOpZ733Ht6s/view) by [PixPrucer](https://bsky.app/profile/pixprucer.bsky.social) is an excellent guide. For further help, join the [DiffSinger Discord server](https://discord.gg/DZ6fhEUfnb).
+
+ The dataset used for this project was built from [*Adventus: Dominica prima adventus Domini*](https://youtu.be/ThnPySybDJs?feature=shared), the first track from [Psallentes](https://psallentes.com/)' album *Salzinnes Saints*. Psallentes is a Belgian women's chorus that specializes in Late Medieval and Renaissance music. *Salzinnes Saints* is an album of music from the [Salzinnes Antiphonal](https://www.smu.ca/academics/archives/the-salzinnes-antiphonal.html), a mid-sixteenth-century choirbook with the music and text for the Liturgy of the Hours.
+
+ ---
+
+ ## Quick Start
+
+ 1. Clone the repository:
+
+ ```bash
+ git clone https://github.com/yourusername/CantusSVS.git
+ cd CantusSVS
+ ```
+
+ 2. Set up the environment:
+
+ ```bash
+ make setup
+ ```
+
+ 3. Run the web app locally:
+
+ ```bash
+ make run
+ ```
+
+ 4. Open your browser at:
+
+ ```
+ http://localhost:8501
+ ```
+
+ Or just use the hosted app here: [https://cantussvs.streamlit.app](https://cantussvs.streamlit.app)
+
+ ---
+
+ ## Preparing Your Input
+
+ - Most music notation software can export `.mei` files; MuseScore 4 is a free option.
+ - Input format must be `.mei` (Music Encoding Initiative XML).
+ - Only **monophonic** scores are supported (one staff, one voice).
+ - Lyrics must be embedded in the MEI file and aligned with notes.
+
+ Validation tool:
+
+ ```bash
+ python scripts/validate_mei.py your_song.mei
+ ```
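For a sense of what the monophony and lyric requirements involve, here is a small illustrative sketch (not the repository's `scripts/validate_mei.py`, whose actual checks may differ) that inspects an MEI file with Python's standard library:

```python
import sys
import xml.etree.ElementTree as ET

MEI_NS = '{http://www.music-encoding.org/ns/mei}'  # standard MEI namespace

root = ET.parse(sys.argv[1]).getroot()

staff_defs = list(root.iter(f'{MEI_NS}staffDef'))
notes = list(root.iter(f'{MEI_NS}note'))
lyric_notes = [n for n in notes if n.find(f'.//{MEI_NS}syl') is not None]

print(f'staff definitions: {len(staff_defs)} (a monophonic score should have 1)')
print(f'notes: {len(notes)}, notes carrying a lyric syllable: {len(lyric_notes)}')
```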
+
+ ---
+
+ ## Running Locally
+
+ 1. Drop your `.mei` file into the upload area of the web app.
+
+ 2. Choose settings:
+    - Tempo (BPM)
+    - Output file name (optional)
+
+ 3. Hit "Synthesize" and download the resulting `.wav` file.
+
+ Generated files:
+ - `.wav`: final audio output
+ - `.mel.npy`: intermediate mel-spectrogram
+ - `.info.json`: metadata (phoneme sequence, note mapping)
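To inspect these outputs programmatically, a minimal sketch looks like the following (file names follow the list above; reading the WAV assumes the `soundfile` package, and the exact JSON fields depend on the app):

```python
import json

import numpy as np
import soundfile as sf  # assumption: used here only to read the WAV output

mel = np.load('output.mel.npy')             # mel-spectrogram, frames x mel bins
with open('output.info.json', encoding='utf-8') as f:
    info = json.load(f)                     # phoneme sequence, note mapping, ...
audio, sample_rate = sf.read('output.wav')  # final synthesized audio

print(mel.shape, sample_rate, list(info.keys()))
```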
+
+ ---
+
+ ## FAQ
+
+ **Q: Can I synthesize polyphonic (multi-voice) chants?**
+ A: No, only monophonic scores are currently supported. In the future, however, polyphonic chants could be synthesized by layering multiple monophonic voices.
+
+ **Q: Can I change the voice timbre?**
+ A: In the web app, only the provided pre-trained model is available. DiffSinger learns the timbre of its training dataset, so if you train your own model, you can control the timbre that way.
+
+ ---
app.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ streamlit run webapp/app.py
augmentation/spec_stretch.py ADDED
@@ -0,0 +1,92 @@
1
+ from copy import deepcopy
2
+
3
+ import librosa
4
+ import numpy as np
5
+ import torch
6
+
7
+ from basics.base_augmentation import BaseAugmentation, require_same_keys
8
+ from basics.base_pe import BasePE
9
+ from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
10
+ from modules.fastspeech.tts_modules import LengthRegulator
11
+ from utils.binarizer_utils import get_mel_torch, get_mel2ph_torch
12
+ from utils.hparams import hparams
13
+ from utils.infer_utils import resample_align_curve
14
+
15
+
16
+ class SpectrogramStretchAugmentation(BaseAugmentation):
17
+ """
18
+ This class contains methods for frequency-domain and time-domain stretching augmentation.
19
+ """
20
+
21
+ def __init__(self, data_dirs: list, augmentation_args: dict, pe: BasePE = None):
22
+ super().__init__(data_dirs, augmentation_args)
23
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
24
+ self.lr = LengthRegulator().to(self.device)
25
+ self.pe = pe
26
+
27
+ @require_same_keys
28
+ def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None) -> dict:
29
+ aug_item = deepcopy(item)
30
+ waveform, _ = librosa.load(aug_item['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
31
+ mel = get_mel_torch(
32
+ waveform, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'],
33
+ hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'],
34
+ fmin=hparams['fmin'], fmax=hparams['fmax'],
35
+ keyshift=key_shift, speed=speed, device=self.device
36
+ )
37
+
38
+ aug_item['mel'] = mel
39
+
40
+ if speed != 1. or hparams['use_speed_embed']:
41
+ aug_item['length'] = mel.shape[0]
42
+ aug_item['speed'] = int(np.round(hparams['hop_size'] * speed)) / hparams['hop_size'] # real speed
43
+ aug_item['seconds'] /= aug_item['speed']
44
+ aug_item['ph_dur'] /= aug_item['speed']
45
+ aug_item['mel2ph'] = get_mel2ph_torch(
46
+ self.lr, torch.from_numpy(aug_item['ph_dur']), aug_item['length'], self.timestep, device=self.device
47
+ ).cpu().numpy()
48
+
49
+ f0, _ = self.pe.get_pitch(
50
+ waveform, samplerate=hparams['audio_sample_rate'], length=aug_item['length'],
51
+ hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
52
+ speed=speed, interp_uv=True
53
+ )
54
+ aug_item['f0'] = f0.astype(np.float32)
55
+
56
+ # NOTE: variance curves are directly resampled according to speed,
57
+ # despite how frequency-domain features change after the augmentation.
58
+ # For acoustic models, this can bring more (but not much) difficulty
59
+ # to learn how variance curves affect the mel spectrograms, since
60
+ # they must realize how the augmentation causes the mismatch.
61
+ #
62
+ # This is a simple way to combine augmentation and variances. However,
63
+ # dealing with variance curves like this will decrease the accuracy of
64
+ # variance controls. In most situations, not being ~100% accurate
65
+ # will not ruin the user experience. For example, it does not matter
66
+ # if the energy does not exactly equal the RMS; it is just fine
67
+ # as long as higher energy can bring higher loudness and strength.
68
+ # The neural network itself cannot be 100% accurate, though.
69
+ #
70
+ # There are yet other choices to simulate variance curves:
71
+ # 1. Re-extract the features from resampled waveforms;
72
+ # 2. Re-extract the features from re-constructed waveforms using
73
+ # the transformed mel spectrograms through the vocoder.
74
+ # But there are actually no perfect ways to make them all accurate
75
+ # and stable.
76
+ for v_name in VARIANCE_CHECKLIST:
77
+ if v_name in item:
78
+ aug_item[v_name] = resample_align_curve(
79
+ aug_item[v_name],
80
+ original_timestep=self.timestep,
81
+ target_timestep=self.timestep * aug_item['speed'],
82
+ align_length=aug_item['length']
83
+ )
84
+
85
+ if key_shift != 0. or hparams['use_key_shift_embed']:
86
+ if replace_spk_id is None:
87
+ aug_item['key_shift'] = key_shift
88
+ else:
89
+ aug_item['spk_id'] = replace_spk_id
90
+ aug_item['f0'] *= 2 ** (key_shift / 12)
91
+
92
+ return aug_item
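To make the two augmentation branches above concrete, here is a small numeric sketch (values invented for illustration; only the arithmetic mirrors `process_item`): a key shift of `key_shift` semitones scales F0 by `2 ** (key_shift / 12)`, and the requested speed is quantized to the hop grid before variance curves are resampled.

```python
import numpy as np

hop_size, sample_rate = 512, 44100  # hypothetical hparams values
timestep = hop_size / sample_rate   # seconds per frame, as in BaseAugmentation

# Pitch branch: shifting up by 2 semitones multiplies F0 by about 1.122.
key_shift = 2.0
f0 = np.array([220.0, 246.9, 261.6])
f0_shifted = f0 * 2 ** (key_shift / 12)

# Time branch: the requested speed is rounded to a whole number of hop samples.
speed = 1.25
real_speed = int(np.round(hop_size * speed)) / hop_size
# Variance curves sampled every `timestep` seconds are then resampled to
# `timestep * real_speed`, matching the resample_align_curve call above.

print(f0_shifted, real_speed, timestep * real_speed)
```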
basics/base_augmentation.py ADDED
@@ -0,0 +1,28 @@
+ from utils.hparams import hparams
+
+
+ class BaseAugmentation:
+     """
+     Base class for data augmentation.
+     All methods of this class should be thread-safe.
+     1. *process_item*:
+         Apply augmentation to one piece of data.
+     """
+     def __init__(self, data_dirs: list, augmentation_args: dict):
+         self.raw_data_dirs = data_dirs
+         self.augmentation_args = augmentation_args
+         self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
+
+     def process_item(self, item: dict, **kwargs) -> dict:
+         raise NotImplementedError()
+
+
+ def require_same_keys(func):
+     def run(*args, **kwargs):
+         item: dict = args[1]
+         res: dict = func(*args, **kwargs)
+         assert set(item.keys()) == set(res.keys()), 'Item keys mismatch after augmentation.\n' \
+                                                     f'Before: {sorted(item.keys())}\n' \
+                                                     f'After: {sorted(res.keys())}'
+         return res
+     return run
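To illustrate the contract that `require_same_keys` enforces, here is a minimal hypothetical subclass (not part of this commit; the `waveform` field is invented for the example). The decorator asserts that augmentation modifies values but never adds or removes keys:

```python
from copy import deepcopy

from basics.base_augmentation import BaseAugmentation, require_same_keys


class GainAugmentation(BaseAugmentation):
    """Toy augmentation: scale the amplitude of a (hypothetical) waveform field."""

    @require_same_keys
    def process_item(self, item: dict, gain: float = 1.0, **kwargs) -> dict:
        aug_item = deepcopy(item)
        aug_item['waveform'] = aug_item['waveform'] * gain  # only existing keys are touched
        return aug_item
```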
basics/base_binarizer.py ADDED
@@ -0,0 +1,347 @@
1
+ import json
2
+ import pathlib
3
+ import pickle
4
+ import random
5
+ import shutil
6
+ import warnings
7
+ from copy import deepcopy
8
+
9
+ import numpy as np
10
+ import torch
11
+ from tqdm import tqdm
12
+
13
+ from utils.hparams import hparams
14
+ from utils.indexed_datasets import IndexedDatasetBuilder
15
+ from utils.multiprocess_utils import chunked_multiprocess_run
16
+ from utils.phoneme_utils import build_phoneme_list, locate_dictionary
17
+ from utils.plot import distribution_to_figure
18
+ from utils.text_encoder import TokenTextEncoder
19
+
20
+
21
+ class BinarizationError(Exception):
22
+ pass
23
+
24
+
25
+ class BaseBinarizer:
26
+ """
27
+ Base class for data processing.
28
+ 1. *process* and *process_data_split*:
29
+ process the entire dataset and generate the train-valid split (supports parallel processing);
30
+ 2. *process_item*:
31
+ process a single piece of data;
32
+ 3. *get_pitch*:
33
+ infer the pitch using some algorithm;
34
+ 4. *get_align*:
35
+ get the alignment using 'mel2ph' format (see https://arxiv.org/abs/1905.09263).
36
+ 5. phoneme encoder, voice encoder, etc.
37
+
38
+ Subclasses should define:
39
+ 1. *load_metadata*:
40
+ how to read multiple datasets from files;
41
+ 2. *train_item_names*, *valid_item_names*, *test_item_names*:
42
+ how to split the dataset;
43
+ 3. load_ph_set:
44
+ the phoneme set.
45
+ """
46
+
47
+ def __init__(self, data_dir=None, data_attrs=None):
48
+ if data_dir is None:
49
+ data_dir = hparams['raw_data_dir']
50
+ if not isinstance(data_dir, list):
51
+ data_dir = [data_dir]
52
+
53
+ self.raw_data_dirs = [pathlib.Path(d) for d in data_dir]
54
+ self.binary_data_dir = pathlib.Path(hparams['binary_data_dir'])
55
+ self.data_attrs = [] if data_attrs is None else data_attrs
56
+
57
+ self.binarization_args = hparams['binarization_args']
58
+ self.augmentation_args = hparams.get('augmentation_args', {})
59
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60
+
61
+ self.spk_map = None
62
+ self.spk_ids = hparams['spk_ids']
63
+ self.speakers = hparams['speakers']
64
+ self.build_spk_map()
65
+
66
+ self.items = {}
67
+ self.item_names: list = None
68
+ self._train_item_names: list = None
69
+ self._valid_item_names: list = None
70
+
71
+ self.phone_encoder = TokenTextEncoder(vocab_list=build_phoneme_list())
72
+ self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
73
+
74
+ def build_spk_map(self):
75
+ assert isinstance(self.speakers, list), 'Speakers must be a list'
76
+ assert len(self.speakers) == len(self.raw_data_dirs), \
77
+ 'Number of raw data dirs must equal number of speaker names!'
78
+ if len(self.spk_ids) == 0:
79
+ self.spk_ids = list(range(len(self.raw_data_dirs)))
80
+ else:
81
+ assert len(self.spk_ids) == len(self.raw_data_dirs), \
82
+ 'Length of explicitly given spk_ids must equal the number of raw datasets.'
83
+ assert max(self.spk_ids) < hparams['num_spk'], \
84
+ f'Index in spk_id sequence {self.spk_ids} is out of range. All values should be smaller than num_spk.'
85
+
86
+ self.spk_map = {}
87
+ for spk_name, spk_id in zip(self.speakers, self.spk_ids):
88
+ if spk_name in self.spk_map and self.spk_map[spk_name] != spk_id:
89
+ raise ValueError(f'Invalid speaker ID assignment. Name \'{spk_name}\' is assigned '
90
+ f'with different speaker IDs: {self.spk_map[spk_name]} and {spk_id}.')
91
+ self.spk_map[spk_name] = spk_id
92
+
93
+ print("| spk_map: ", self.spk_map)
94
+
95
+ def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id):
96
+ raise NotImplementedError()
97
+
98
+ def split_train_valid_set(self, item_names):
99
+ """
100
+ Split the dataset into training set and validation set.
101
+ :return: train_item_names, valid_item_names
102
+ """
103
+ prefixes = {str(pr): 1 for pr in hparams['test_prefixes']}
104
+ valid_item_names = {}
105
+ # Add prefixes that specify a speaker index and exactly match an item name to the validation set
106
+ for prefix in deepcopy(prefixes):
107
+ if prefix in item_names:
108
+ valid_item_names[prefix] = 1
109
+ prefixes.pop(prefix)
110
+ # Add prefixes that exactly match an item name without a speaker ID to the validation set
111
+ for prefix in deepcopy(prefixes):
112
+ matched = False
113
+ for name in item_names:
114
+ if name.split(':')[-1] == prefix:
115
+ valid_item_names[name] = 1
116
+ matched = True
117
+ if matched:
118
+ prefixes.pop(prefix)
119
+ # Add names that start with one of the remaining prefixes to the validation set
120
+ for prefix in deepcopy(prefixes):
121
+ matched = False
122
+ for name in item_names:
123
+ if name.startswith(prefix):
124
+ valid_item_names[name] = 1
125
+ matched = True
126
+ if matched:
127
+ prefixes.pop(prefix)
128
+ for prefix in deepcopy(prefixes):
129
+ matched = False
130
+ for name in item_names:
131
+ if name.split(':')[-1].startswith(prefix):
132
+ valid_item_names[name] = 1
133
+ matched = True
134
+ if matched:
135
+ prefixes.pop(prefix)
136
+
137
+ if len(prefixes) != 0:
138
+ warnings.warn(
139
+ f'The following rules in test_prefixes have no matching names in the dataset: {", ".join(prefixes.keys())}',
140
+ category=UserWarning
141
+ )
142
+ warnings.filterwarnings('default')
143
+
144
+ valid_item_names = list(valid_item_names.keys())
145
+ assert len(valid_item_names) > 0, 'Validation set is empty!'
146
+ train_item_names = [x for x in item_names if x not in set(valid_item_names)]
147
+ assert len(train_item_names) > 0, 'Training set is empty!'
148
+
149
+ return train_item_names, valid_item_names
150
+
151
+ @property
152
+ def train_item_names(self):
153
+ return self._train_item_names
154
+
155
+ @property
156
+ def valid_item_names(self):
157
+ return self._valid_item_names
158
+
159
+ def meta_data_iterator(self, prefix):
160
+ if prefix == 'train':
161
+ item_names = self.train_item_names
162
+ else:
163
+ item_names = self.valid_item_names
164
+ for item_name in item_names:
165
+ meta_data = self.items[item_name]
166
+ yield item_name, meta_data
167
+
168
+ def process(self):
169
+ # load each dataset
170
+ for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs):
171
+ self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id)
172
+ self.item_names = sorted(list(self.items.keys()))
173
+ self._train_item_names, self._valid_item_names = self.split_train_valid_set(self.item_names)
174
+
175
+ if self.binarization_args['shuffle']:
176
+ random.shuffle(self.item_names)
177
+
178
+ self.binary_data_dir.mkdir(parents=True, exist_ok=True)
179
+
180
+ # Copy spk_map and dictionary to binary data dir
181
+ spk_map_fn = self.binary_data_dir / 'spk_map.json'
182
+ with open(spk_map_fn, 'w', encoding='utf-8') as f:
183
+ json.dump(self.spk_map, f)
184
+ shutil.copy(locate_dictionary(), self.binary_data_dir / 'dictionary.txt')
185
+ self.check_coverage()
186
+
187
+ # Process valid set and train set
188
+ try:
189
+ self.process_dataset('valid')
190
+ self.process_dataset(
191
+ 'train',
192
+ num_workers=int(self.binarization_args['num_workers']),
193
+ apply_augmentation=any(args['enabled'] for args in self.augmentation_args.values())
194
+ )
195
+ except KeyboardInterrupt:
196
+ exit(-1)
197
+
198
+ def check_coverage(self):
199
+ # Group by phonemes in the dictionary.
200
+ ph_required = set(build_phoneme_list())
201
+ phoneme_map = {}
202
+ for ph in ph_required:
203
+ phoneme_map[ph] = 0
204
+ ph_occurred = []
205
+
206
+ # Load and count those phones that appear in the actual data
207
+ for item_name in self.items:
208
+ ph_occurred += self.items[item_name]['ph_seq']
209
+ if len(ph_occurred) == 0:
210
+ raise BinarizationError(f'Empty tokens in {item_name}.')
211
+ for ph in ph_occurred:
212
+ if ph not in ph_required:
213
+ continue
214
+ phoneme_map[ph] += 1
215
+ ph_occurred = set(ph_occurred)
216
+
217
+ print('===== Phoneme Distribution Summary =====')
218
+ for i, key in enumerate(sorted(phoneme_map.keys())):
219
+ if i == len(ph_required) - 1:
220
+ end = '\n'
221
+ elif i % 10 == 9:
222
+ end = ',\n'
223
+ else:
224
+ end = ', '
225
+ print(f'\'{key}\': {phoneme_map[key]}', end=end)
226
+
227
+ # Draw graph.
228
+ x = sorted(phoneme_map.keys())
229
+ values = [phoneme_map[k] for k in x]
230
+ plt = distribution_to_figure(
231
+ title='Phoneme Distribution Summary',
232
+ x_label='Phoneme', y_label='Number of occurrences',
233
+ items=x, values=values
234
+ )
235
+ filename = self.binary_data_dir / 'phoneme_distribution.jpg'
236
+ plt.savefig(fname=filename,
237
+ bbox_inches='tight',
238
+ pad_inches=0.25)
239
+ print(f'| save summary to \'{filename}\'')
240
+
241
+ # Check unrecognizable or missing phonemes
242
+ if ph_occurred != ph_required:
243
+ unrecognizable_phones = ph_occurred.difference(ph_required)
244
+ missing_phones = ph_required.difference(ph_occurred)
245
+ raise BinarizationError('transcriptions and dictionary mismatch.\n'
246
+ f' (+) {sorted(unrecognizable_phones)}\n'
247
+ f' (-) {sorted(missing_phones)}')
248
+
249
+ def process_dataset(self, prefix, num_workers=0, apply_augmentation=False):
250
+ args = []
251
+ builder = IndexedDatasetBuilder(self.binary_data_dir, prefix=prefix, allowed_attr=self.data_attrs)
252
+ total_sec = {k: 0.0 for k in self.spk_map}
253
+ total_raw_sec = {k: 0.0 for k in self.spk_map}
254
+ extra_info = {'names': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}}
255
+ max_no = -1
256
+
257
+ for item_name, meta_data in self.meta_data_iterator(prefix):
258
+ args.append([item_name, meta_data, self.binarization_args])
259
+
260
+ aug_map = self.arrange_data_augmentation(self.meta_data_iterator(prefix)) if apply_augmentation else {}
261
+
262
+ def postprocess(_item):
263
+ nonlocal total_sec, total_raw_sec, extra_info, max_no
264
+ if _item is None:
265
+ return
266
+ item_no = builder.add_item(_item)
267
+ max_no = max(max_no, item_no)
268
+ for k, v in _item.items():
269
+ if isinstance(v, np.ndarray):
270
+ if k not in extra_info:
271
+ extra_info[k] = {}
272
+ extra_info[k][item_no] = v.shape[0]
273
+ extra_info['names'][item_no] = _item['name'].split(':', 1)[-1]
274
+ extra_info['spk_ids'][item_no] = _item['spk_id']
275
+ extra_info['spk_names'][item_no] = _item['spk_name']
276
+ extra_info['lengths'][item_no] = _item['length']
277
+ total_raw_sec[_item['spk_name']] += _item['seconds']
278
+ total_sec[_item['spk_name']] += _item['seconds']
279
+
280
+ for task in aug_map.get(_item['name'], []):
281
+ aug_item = task['func'](_item, **task['kwargs'])
282
+ aug_item_no = builder.add_item(aug_item)
283
+ max_no = max(max_no, aug_item_no)
284
+ for k, v in aug_item.items():
285
+ if isinstance(v, np.ndarray):
286
+ if k not in extra_info:
287
+ extra_info[k] = {}
288
+ extra_info[k][aug_item_no] = v.shape[0]
289
+ extra_info['names'][aug_item_no] = aug_item['name'].split(':', 1)[-1]
290
+ extra_info['spk_ids'][aug_item_no] = aug_item['spk_id']
291
+ extra_info['spk_names'][aug_item_no] = aug_item['spk_name']
292
+ extra_info['lengths'][aug_item_no] = aug_item['length']
293
+ total_sec[aug_item['spk_name']] += aug_item['seconds']
294
+
295
+ try:
296
+ if num_workers > 0:
297
+ # code for parallel processing
298
+ for item in tqdm(
299
+ chunked_multiprocess_run(self.process_item, args, num_workers=num_workers),
300
+ total=len(list(self.meta_data_iterator(prefix)))
301
+ ):
302
+ postprocess(item)
303
+ else:
304
+ # code for single cpu processing
305
+ for a in tqdm(args):
306
+ item = self.process_item(*a)
307
+ postprocess(item)
308
+ for k in extra_info:
309
+ assert set(extra_info[k]) == set(range(max_no + 1)), f'Item numbering is not consecutive.'
310
+ extra_info[k] = list(map(lambda x: x[1], sorted(extra_info[k].items(), key=lambda x: x[0])))
311
+ except KeyboardInterrupt:
312
+ builder.finalize()
313
+ raise
314
+
315
+ builder.finalize()
316
+ if prefix == "train":
317
+ extra_info.pop("names")
318
+ extra_info.pop("spk_names")
319
+ with open(self.binary_data_dir / f"{prefix}.meta", "wb") as f:
320
+ # noinspection PyTypeChecker
321
+ pickle.dump(extra_info, f)
322
+ if apply_augmentation:
323
+ print(f"| {prefix} total duration (before augmentation): {sum(total_raw_sec.values()):.2f}s")
324
+ print(
325
+ f"| {prefix} respective duration (before augmentation): "
326
+ + ', '.join(f'{k}={v:.2f}s' for k, v in total_raw_sec.items())
327
+ )
328
+ print(
329
+ f"| {prefix} total duration (after augmentation): "
330
+ f"{sum(total_sec.values()):.2f}s ({sum(total_sec.values()) / sum(total_raw_sec.values()):.2f}x)"
331
+ )
332
+ print(
333
+ f"| {prefix} respective duration (after augmentation): "
334
+ + ', '.join(f'{k}={v:.2f}s' for k, v in total_sec.items())
335
+ )
336
+ else:
337
+ print(f"| {prefix} total duration: {sum(total_raw_sec.values()):.2f}s")
338
+ print(f"| {prefix} respective duration: " + ', '.join(f'{k}={v:.2f}s' for k, v in total_raw_sec.items()))
339
+
340
+ def arrange_data_augmentation(self, data_iterator):
341
+ """
342
+ Code for all types of data augmentation should be added here.
343
+ """
344
+ raise NotImplementedError()
345
+
346
+ def process_item(self, item_name, meta_data, binarization_args):
347
+ raise NotImplementedError()
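The prefix-matching rules in `split_train_valid_set` can be summarized with a small worked example (item names and prefixes invented; this is a simplified sketch that ignores the precedence used only for warning about unused prefixes):

```python
# Item names follow the "<speaker>:<segment>" form used by the binarizer.
item_names = ['choir:adventus_001', 'choir:adventus_002', 'choir:kyrie_001']
test_prefixes = ['adventus_001', 'kyrie']   # hypothetical hparams['test_prefixes']

valid = set()
for prefix in test_prefixes:
    for name in item_names:
        if (name == prefix or name.split(':')[-1] == prefix
                or name.startswith(prefix) or name.split(':')[-1].startswith(prefix)):
            valid.add(name)
train = [n for n in item_names if n not in valid]

print(sorted(valid))  # ['choir:adventus_001', 'choir:kyrie_001']
print(train)          # ['choir:adventus_002']
```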
basics/base_dataset.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ import pickle
+
+ import torch
+ from torch.utils.data import Dataset
+
+ from utils.hparams import hparams
+ from utils.indexed_datasets import IndexedDataset
+
+
+ class BaseDataset(Dataset):
+     """
+     Base class for datasets.
+     1. *sizes*:
+         clipped length if "max_frames" is set;
+     2. *num_frames*:
+         unclipped length.
+
+     Subclasses should define:
+     1. *collate*:
+         take the longest data, pad other data to the same length;
+     2. *__getitem__*:
+         the index function.
+     """
+
+     def __init__(self, prefix, size_key='lengths', preload=False):
+         super().__init__()
+         self.prefix = prefix
+         self.data_dir = hparams['binary_data_dir']
+         with open(os.path.join(self.data_dir, f'{self.prefix}.meta'), 'rb') as f:
+             self.metadata = pickle.load(f)
+         self.sizes = self.metadata[size_key]
+         self._indexed_ds = IndexedDataset(self.data_dir, self.prefix)
+         if preload:
+             self.indexed_ds = [self._indexed_ds[i] for i in range(len(self._indexed_ds))]
+             del self._indexed_ds
+         else:
+             self.indexed_ds = self._indexed_ds
+
+     def __getitem__(self, index):
+         return {'_idx': index, **self.indexed_ds[index]}
+
+     def __len__(self):
+         return len(self.sizes)
+
+     def num_frames(self, index):
+         return self.sizes[index]
+
+     def size(self, index):
+         """Return an example's size as a float or tuple. This value is used when
+         filtering a dataset with ``--max-positions``."""
+         return self.sizes[index]
+
+     def collater(self, samples):
+         return {
+             'size': len(samples),
+             'indices': torch.LongTensor([s['_idx'] for s in samples])
+         }
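As a rough usage sketch (assuming a binarized dataset already exists under `hparams['binary_data_dir']`, and keeping in mind that the base `collater` only batches sizes and indices, so real tasks override it to pad variable-length arrays):

```python
from torch.utils.data import DataLoader

from basics.base_dataset import BaseDataset

dataset = BaseDataset(prefix='valid', size_key='lengths', preload=False)
loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=dataset.collater,  # subclasses replace this with a padding collater
    num_workers=0,
)

for batch in loader:
    print(batch['size'], batch['indices'])  # number of samples and their dataset indices
    break
```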
basics/base_exporter.py ADDED
@@ -0,0 +1,59 @@
+ import json
+ from pathlib import Path
+ from typing import Union
+
+ import torch
+ import torch.nn as nn
+
+ from utils.hparams import hparams
+
+
+ class BaseExporter:
+     def __init__(
+             self,
+             device: Union[str, torch.device] = None,
+             cache_dir: Path = None,
+             **kwargs
+     ):
+         self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.cache_dir: Path = cache_dir.resolve() if cache_dir is not None \
+             else Path(__file__).parent.parent / 'deployment' / 'cache'
+         self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+     # noinspection PyMethodMayBeStatic
+     def build_spk_map(self) -> dict:
+         if hparams['use_spk_id']:
+             with open(Path(hparams['work_dir']) / 'spk_map.json', 'r', encoding='utf8') as f:
+                 spk_map = json.load(f)
+             assert isinstance(spk_map, dict) and len(spk_map) > 0, 'Invalid or empty speaker map!'
+             assert len(spk_map) == len(set(spk_map.values())), 'Duplicate speaker id in speaker map!'
+             return spk_map
+         else:
+             return {}
+
+     def build_model(self) -> nn.Module:
+         """
+         Creates an instance of nn.Module and loads its state dict on the target device.
+         """
+         raise NotImplementedError()
+
+     def export_model(self, path: Path):
+         """
+         Exports the model to ONNX format.
+         :param path: the target model path
+         """
+         raise NotImplementedError()
+
+     def export_attachments(self, path: Path):
+         """
+         Exports related files and configs (e.g. the dictionary) to the target directory.
+         :param path: the target directory
+         """
+         raise NotImplementedError()
+
+     def export(self, path: Path):
+         """
+         Exports all the artifacts to the target directory.
+         :param path: the target directory
+         """
+         raise NotImplementedError()
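A minimal illustrative subclass (not part of this commit) showing how the four methods are meant to fit together; the real exporters under `deployment/exporters/` build and trace actual models instead of writing placeholders:

```python
from pathlib import Path

from basics.base_exporter import BaseExporter


class DummyExporter(BaseExporter):
    """Writes placeholder artifacts to demonstrate the export flow."""

    def build_model(self):
        raise NotImplementedError  # a real exporter loads the trained checkpoint here

    def export_model(self, path: Path):
        path.write_bytes(b'')  # placeholder for torch.onnx.export(...)

    def export_attachments(self, path: Path):
        (path / 'dictionary.txt').write_text('example entry\n', encoding='utf8')

    def export(self, path: Path):
        path.mkdir(parents=True, exist_ok=True)
        self.export_model(path / 'model.onnx')
        self.export_attachments(path)


DummyExporter(device='cpu').export(Path('artifacts/example'))
```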
basics/base_module.py ADDED
@@ -0,0 +1,18 @@
+ from torch import nn
+
+
+ class CategorizedModule(nn.Module):
+     @property
+     def category(self):
+         raise NotImplementedError()
+
+     def check_category(self, category):
+         if category is None:
+             raise RuntimeError('Category is not specified in this checkpoint.\n'
+                                'If this is a checkpoint in the old format, please consider '
+                                'migrating it to the new format via the following command:\n'
+                                'python scripts/migrate.py ckpt <INPUT_CKPT> <OUTPUT_CKPT>')
+         elif category != self.category:
+             raise RuntimeError('Category mismatches!\n'
+                                f'This checkpoint is of the category \'{category}\', '
+                                f'but a checkpoint of category \'{self.category}\' is required.')
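A small sketch (hypothetical subclass and in-memory checkpoint, for illustration only) of how `check_category` guards checkpoint loading:

```python
from basics.base_module import CategorizedModule


class AcousticLikeModule(CategorizedModule):
    """Illustrative subclass; real models define their architecture on top of this."""

    @property
    def category(self):
        return 'acoustic'


model = AcousticLikeModule()
ckpt = {'category': 'acoustic', 'state_dict': {}}  # stand-in for torch.load(<ckpt path>)
model.check_category(ckpt.get('category'))         # raises RuntimeError on None or mismatch
```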
basics/base_pe.py ADDED
@@ -0,0 +1,7 @@
+ class BasePE:
+     def get_pitch(
+             self, waveform, samplerate, length,
+             *, hop_size, f0_min=65, f0_max=1100,
+             speed=1, interp_uv=False
+     ):
+         raise NotImplementedError()
basics/base_svs_infer.py ADDED
@@ -0,0 +1,131 @@
1
+ # coding=utf8
2
+ import numpy as np
3
+ import torch
4
+ from torch import Tensor
5
+ from typing import Tuple, Dict
6
+
7
+ from utils.hparams import hparams
8
+ from utils.infer_utils import resample_align_curve
9
+
10
+
11
+ class BaseSVSInfer:
12
+ """
13
+ Base class for SVS inference models.
14
+ Subclasses should define:
15
+ 1. *build_model*:
16
+ how to build the model;
17
+ 2. *run_model*:
18
+ how to run the model (typically, generate a mel-spectrogram and
19
+ pass it to the pre-built vocoder);
20
+ 3. *preprocess_input*:
21
+ how to preprocess user input.
22
+ 4. *infer_once*
23
+ infer from raw inputs to the final outputs
24
+ """
25
+
26
+ def __init__(self, device=None):
27
+ if device is None:
28
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
29
+ self.device = device
30
+ self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
31
+ self.spk_map = {}
32
+ self.model: torch.nn.Module = None
33
+
34
+ def build_model(self, ckpt_steps=None) -> torch.nn.Module:
35
+ raise NotImplementedError()
36
+
37
+ def load_speaker_mix(self, param_src: dict, summary_dst: dict,
38
+ mix_mode: str = 'frame', mix_length: int = None) -> Tuple[Tensor, Tensor]:
39
+ """
40
+
41
+ :param param_src: param dict
42
+ :param summary_dst: summary dict
43
+ :param mix_mode: 'token' or 'frame'
44
+ :param mix_length: total tokens or frames to mix
45
+ :return: spk_mix_id [B=1, 1, N], spk_mix_value [B=1, T, N]
46
+ """
47
+ assert mix_mode == 'token' or mix_mode == 'frame'
48
+ param_key = 'spk_mix' if mix_mode == 'frame' else 'ph_spk_mix'
49
+ summary_solo_key = 'spk' if mix_mode == 'frame' else 'ph_spk'
50
+ spk_mix_map = param_src.get(param_key) # { spk_name: value } or { spk_name: "value value value ..." }
51
+ dynamic = False
52
+ if spk_mix_map is None:
53
+ # Get the first speaker
54
+ for name in self.spk_map.keys():
55
+ spk_mix_map = {name: 1.0}
56
+ break
57
+ else:
58
+ for name in spk_mix_map:
59
+ assert name in self.spk_map, f'Speaker \'{name}\' not found.'
60
+ if len(spk_mix_map) == 1:
61
+ summary_dst[summary_solo_key] = list(spk_mix_map.keys())[0]
62
+ elif any([isinstance(val, str) for val in spk_mix_map.values()]):
63
+ print_mix = '|'.join(spk_mix_map.keys())
64
+ summary_dst[param_key] = f'dynamic({print_mix})'
65
+ dynamic = True
66
+ else:
67
+ print_mix = '|'.join([f'{n}:{"%.3f" % spk_mix_map[n]}' for n in spk_mix_map])
68
+ summary_dst[param_key] = f'static({print_mix})'
69
+ spk_mix_id_list = []
70
+ spk_mix_value_list = []
71
+ if dynamic:
72
+ for name, values in spk_mix_map.items():
73
+ spk_mix_id_list.append(self.spk_map[name])
74
+ if isinstance(values, str):
75
+ # this speaker has a variable proportion
76
+ if mix_mode == 'token':
77
+ cur_spk_mix_value = values.split()
78
+ assert len(cur_spk_mix_value) == mix_length, \
79
+ 'Speaker mix checks failed. In dynamic token-level mix, ' \
80
+ 'number of proportion values must equal number of tokens.'
81
+ cur_spk_mix_value = torch.from_numpy(
82
+ np.array(cur_spk_mix_value, 'float32')
83
+ ).to(self.device)[None] # => [B=1, T]
84
+ else:
85
+ cur_spk_mix_value = torch.from_numpy(resample_align_curve(
86
+ np.array(values.split(), 'float32'),
87
+ original_timestep=float(param_src['spk_mix_timestep']),
88
+ target_timestep=self.timestep,
89
+ align_length=mix_length
90
+ )).to(self.device)[None] # => [B=1, T]
91
+ assert torch.all(cur_spk_mix_value >= 0.), \
92
+ f'Speaker mix checks failed.\n' \
93
+ f'Proportions of speaker \'{name}\' on some {mix_mode}s are negative.'
94
+ else:
95
+ # this speaker has a constant proportion
96
+ assert values >= 0., f'Speaker mix checks failed.\n' \
97
+ f'Proportion of speaker \'{name}\' is negative.'
98
+ cur_spk_mix_value = torch.full(
99
+ (1, mix_length), fill_value=values,
100
+ dtype=torch.float32, device=self.device
101
+ )
102
+ spk_mix_value_list.append(cur_spk_mix_value)
103
+ spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None] # => [B=1, 1, N]
104
+ spk_mix_value = torch.stack(spk_mix_value_list, dim=2) # [B=1, T] => [B=1, T, N]
105
+ spk_mix_value_sum = torch.sum(spk_mix_value, dim=2, keepdim=True) # => [B=1, T, 1]
106
+ assert torch.all(spk_mix_value_sum > 0.), \
107
+ f'Speaker mix checks failed.\n' \
108
+ f'Proportions of speaker mix on some frames sum to zero.'
109
+ spk_mix_value /= spk_mix_value_sum # normalize
110
+ else:
111
+ for name, value in spk_mix_map.items():
112
+ spk_mix_id_list.append(self.spk_map[name])
113
+ assert value >= 0., f'Speaker mix checks failed.\n' \
114
+ f'Proportion of speaker \'{name}\' is negative.'
115
+ spk_mix_value_list.append(value)
116
+ spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None] # => [B=1, 1, N]
117
+ spk_mix_value = torch.FloatTensor(spk_mix_value_list).to(self.device)[None, None] # => [B=1, 1, N]
118
+ spk_mix_value_sum = spk_mix_value.sum()
119
+ assert spk_mix_value_sum > 0., f'Speaker mix checks failed.\n' \
120
+ f'Proportions of speaker mix sum to zero.'
121
+ spk_mix_value /= spk_mix_value_sum # normalize
122
+ return spk_mix_id, spk_mix_value
123
+
124
+ def preprocess_input(self, param: dict, idx=0) -> Dict[str, torch.Tensor]:
125
+ raise NotImplementedError()
126
+
127
+ def forward_model(self, sample: Dict[str, torch.Tensor]):
128
+ raise NotImplementedError()
129
+
130
+ def run_inference(self, params, **kwargs):
131
+ raise NotImplementedError()
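The `param_src` dictionary consumed by `load_speaker_mix` accepts either constant proportions or whitespace-separated curves, as the code above shows. A hypothetical two-speaker example (speaker names invented for illustration):

```python
# Static frame-level mix: constant proportions, normalized internally to sum to 1.
params_static = {
    'spk_mix': {'alto': 0.7, 'soprano': 0.3},
}

# Dynamic frame-level mix: each speaker gets a space-separated curve sampled every
# 'spk_mix_timestep' seconds; curves are resampled to the model's frame timestep
# and normalized frame by frame.
params_dynamic = {
    'spk_mix': {
        'alto': '1.0 0.8 0.5 0.2 0.0',
        'soprano': '0.0 0.2 0.5 0.8 1.0',
    },
    'spk_mix_timestep': 0.05,
}
```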
basics/base_task.py ADDED
@@ -0,0 +1,520 @@
1
+ import logging
2
+ import os
3
+ import pathlib
4
+ import shutil
5
+ import sys
6
+ from typing import Dict
7
+
8
+ import matplotlib
9
+
10
+ import utils
11
+ from utils.text_encoder import TokenTextEncoder
12
+
13
+ matplotlib.use('Agg')
14
+
15
+ import torch.utils.data
16
+ from torchmetrics import Metric, MeanMetric
17
+ import lightning.pytorch as pl
18
+ from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_only
19
+
20
+ from basics.base_module import CategorizedModule
21
+ from utils.hparams import hparams
22
+ from utils.training_utils import (
23
+ DsModelCheckpoint, DsTQDMProgressBar,
24
+ DsBatchSampler, DsTensorBoardLogger,
25
+ get_latest_checkpoint_path, get_strategy
26
+ )
27
+ from utils.phoneme_utils import locate_dictionary, build_phoneme_list
28
+
29
+ torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))
30
+
31
+ log_format = '%(asctime)s %(message)s'
32
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO,
33
+ format=log_format, datefmt='%m/%d %I:%M:%S %p')
34
+
35
+
36
+ class BaseTask(pl.LightningModule):
37
+ """
38
+ Base class for training tasks.
39
+ 1. *load_ckpt*:
40
+ load checkpoint;
41
+ 2. *training_step*:
42
+ record and log the loss;
43
+ 3. *optimizer_step*:
44
+ run backwards step;
45
+ 4. *start*:
46
+ load training configs, backup code, log to tensorboard, start training;
47
+ 5. *configure_ddp* and *init_ddp_connection*:
48
+ start parallel training.
49
+
50
+ Subclasses should define:
51
+ 1. *build_model*, *build_optimizer*, *build_scheduler*:
52
+ how to build the model, the optimizer and the training scheduler;
53
+ 2. *_training_step*:
54
+ one training step of the model;
55
+ 3. *on_validation_end* and *_on_validation_end*:
56
+ postprocess the validation output.
57
+ """
58
+
59
+ def __init__(self, *args, **kwargs):
60
+ super().__init__(*args, **kwargs)
61
+ self.max_batch_frames = hparams['max_batch_frames']
62
+ self.max_batch_size = hparams['max_batch_size']
63
+ self.max_val_batch_frames = hparams['max_val_batch_frames']
64
+ if self.max_val_batch_frames == -1:
65
+ hparams['max_val_batch_frames'] = self.max_val_batch_frames = self.max_batch_frames
66
+ self.max_val_batch_size = hparams['max_val_batch_size']
67
+ if self.max_val_batch_size == -1:
68
+ hparams['max_val_batch_size'] = self.max_val_batch_size = self.max_batch_size
69
+
70
+ self.training_sampler = None
71
+ self.skip_immediate_validation = False
72
+ self.skip_immediate_ckpt_save = False
73
+
74
+ self.phone_encoder = self.build_phone_encoder()
75
+ self.build_model()
76
+
77
+ self.valid_losses: Dict[str, Metric] = {}
78
+ self.valid_metrics: Dict[str, Metric] = {}
79
+
80
+ def _finish_init(self):
81
+ self.register_validation_loss('total_loss')
82
+ self.build_losses_and_metrics()
83
+ assert len(self.valid_losses) > 0, "No validation loss registered. Please check your configuration file."
84
+
85
+ ###########
86
+ # Training, validation and testing
87
+ ###########
88
+ def setup(self, stage):
89
+ self.train_dataset = self.dataset_cls('train')
90
+ self.valid_dataset = self.dataset_cls('valid')
91
+ self.num_replicas = (self.trainer.distributed_sampler_kwargs or {}).get('num_replicas', 1)
92
+
93
+ def get_need_freeze_state_dict_key(self, model_state_dict) -> list:
94
+ key_list = []
95
+ for i in hparams['frozen_params']:
96
+ for j in model_state_dict:
97
+ if j.startswith(i):
98
+ key_list.append(j)
99
+ return list(set(key_list))
100
+
101
+ def freeze_params(self) -> None:
102
+ model_state_dict = self.state_dict().keys()
103
+ freeze_key = self.get_need_freeze_state_dict_key(model_state_dict=model_state_dict)
104
+
105
+ for i in freeze_key:
106
+ params=self.get_parameter(i)
107
+
108
+ params.requires_grad = False
109
+
110
+ def unfreeze_all_params(self) -> None:
111
+ for i in self.model.parameters():
112
+ i.requires_grad = True
113
+
114
+ def load_finetune_ckpt(
115
+ self, state_dict
116
+ ) -> None:
117
+ adapt_shapes = hparams['finetune_strict_shapes']
118
+ if not adapt_shapes:
119
+ cur_model_state_dict = self.state_dict()
120
+ unmatched_keys = []
121
+ for key, param in state_dict.items():
122
+ if key in cur_model_state_dict:
123
+ new_param = cur_model_state_dict[key]
124
+ if new_param.shape != param.shape:
125
+ unmatched_keys.append(key)
126
+ print('| Unmatched keys: ', key, new_param.shape, param.shape)
127
+ for key in unmatched_keys:
128
+ del state_dict[key]
129
+ self.load_state_dict(state_dict, strict=False)
130
+
131
+ def load_pre_train_model(self):
132
+ pre_train_ckpt_path = hparams['finetune_ckpt_path']
133
+ blacklist = hparams['finetune_ignored_params']
134
+ # whitelist=hparams['pre_train_whitelist']
135
+ if blacklist is None:
136
+ blacklist = []
137
+ # if whitelist is None:
138
+ # raise RuntimeError("")
139
+
140
+ if pre_train_ckpt_path is not None:
141
+ ckpt = torch.load(pre_train_ckpt_path)
142
+ # if ckpt.get('category') is None:
143
+ # raise RuntimeError("")
144
+
145
+ if isinstance(self.model, CategorizedModule):
146
+ self.model.check_category(ckpt.get('category'))
147
+
148
+ state_dict = {}
149
+ for i in ckpt['state_dict']:
150
+ # if 'diffusion' in i:
151
+ # if i in rrrr:
152
+ # continue
153
+ skip = False
154
+ for b in blacklist:
155
+ if i.startswith(b):
156
+ skip = True
157
+ break
158
+
159
+ if skip:
160
+ continue
161
+
162
+ state_dict[i] = ckpt['state_dict'][i]
163
+ print(i)
164
+ return state_dict
165
+ else:
166
+ raise RuntimeError("")
167
+
168
+ @staticmethod
169
+ def build_phone_encoder():
170
+ phone_list = build_phoneme_list()
171
+ return TokenTextEncoder(vocab_list=phone_list)
172
+
173
+ def _build_model(self):
174
+ raise NotImplementedError()
175
+
176
+ def build_model(self):
177
+ self.model = self._build_model()
178
+ # utils.load_warp(self)
179
+ self.unfreeze_all_params()
180
+ if hparams['freezing_enabled']:
181
+ self.freeze_params()
182
+ if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None:
183
+ self.load_finetune_ckpt(self.load_pre_train_model())
184
+ self.print_arch()
185
+
186
+ @rank_zero_only
187
+ def print_arch(self):
188
+ utils.print_arch(self.model)
189
+
190
+ def build_losses_and_metrics(self):
191
+ raise NotImplementedError()
192
+
193
+ def register_validation_metric(self, name: str, metric: Metric):
194
+ assert isinstance(metric, Metric)
195
+ self.valid_metrics[name] = metric
196
+
197
+ def register_validation_loss(self, name: str, Aggregator: Metric = MeanMetric):
198
+ assert issubclass(Aggregator, Metric)
199
+ self.valid_losses[name] = Aggregator()
200
+
201
+ def run_model(self, sample, infer=False):
202
+ """
203
+ steps:
204
+ 1. run the full model
205
+ 2. calculate losses if not infer
206
+ """
207
+ raise NotImplementedError()
208
+
209
+ def on_train_epoch_start(self):
210
+ if self.training_sampler is not None:
211
+ self.training_sampler.set_epoch(self.current_epoch)
212
+
213
+ def _training_step(self, sample):
214
+ """
215
+ :return: total loss: torch.Tensor, loss_log: dict, other_log: dict
216
+ """
217
+ losses = self.run_model(sample)
218
+ total_loss = sum(losses.values())
219
+ return total_loss, {**losses, 'batch_size': float(sample['size'])}
220
+
221
+ def training_step(self, sample, batch_idx):
222
+ total_loss, log_outputs = self._training_step(sample)
223
+
224
+ # logs to progress bar
225
+ self.log_dict(log_outputs, prog_bar=True, logger=False, on_step=True, on_epoch=False)
226
+ self.log('lr', self.lr_schedulers().get_last_lr()[0], prog_bar=True, logger=False, on_step=True, on_epoch=False)
227
+ # logs to tensorboard
228
+ if self.global_step % hparams['log_interval'] == 0:
229
+ tb_log = {f'training/{k}': v for k, v in log_outputs.items()}
230
+ tb_log['training/lr'] = self.lr_schedulers().get_last_lr()[0]
231
+ self.logger.log_metrics(tb_log, step=self.global_step)
232
+
233
+ return total_loss
234
+
235
+ # def on_before_optimizer_step(self, *args, **kwargs):
236
+ # self.log_dict(grad_norm(self, norm_type=2))
237
+
238
+ def _on_validation_start(self):
239
+ pass
240
+
241
+ def on_validation_start(self):
242
+ if self.skip_immediate_validation:
243
+ rank_zero_debug("Skip validation")
244
+ return
245
+ self._on_validation_start()
246
+ for metric in self.valid_losses.values():
247
+ metric.to(self.device)
248
+ metric.reset()
249
+ for metric in self.valid_metrics.values():
250
+ metric.to(self.device)
251
+ metric.reset()
252
+
253
+ def _validation_step(self, sample, batch_idx):
254
+ """
255
+
256
+ :param sample:
257
+ :param batch_idx:
258
+ :return: loss_log: dict, weight: int
259
+ """
260
+ raise NotImplementedError()
261
+
262
+ def validation_step(self, sample, batch_idx):
263
+ """
264
+
265
+ :param sample:
266
+ :param batch_idx:
267
+ """
268
+ if self.skip_immediate_validation:
269
+ rank_zero_debug("Skip validation")
270
+ return
271
+ if sample['size'] > 0:
272
+ with torch.autocast(self.device.type, enabled=False):
273
+ losses, weight = self._validation_step(sample, batch_idx)
274
+ losses = {
275
+ 'total_loss': sum(losses.values()),
276
+ **losses
277
+ }
278
+ for k, v in losses.items():
279
+ self.valid_losses[k].update(v, weight=weight)
280
+
281
+ def _on_validation_epoch_end(self):
282
+ pass
283
+
284
+ def on_validation_epoch_end(self):
285
+ if self.skip_immediate_validation:
286
+ self.skip_immediate_validation = False
287
+ self.skip_immediate_ckpt_save = True
288
+ return
289
+ self._on_validation_epoch_end()
290
+ loss_vals = {k: v.compute() for k, v in self.valid_losses.items()}
291
+ metric_vals = {k: v.compute() for k, v in self.valid_metrics.items()}
292
+ self.log('val_loss', loss_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True)
293
+ self.logger.log_metrics({f'validation/{k}': v for k, v in loss_vals.items()}, step=self.global_step)
294
+ self.logger.log_metrics({f'metrics/{k}': v for k, v in metric_vals.items()}, step=self.global_step)
295
+
296
+ # noinspection PyMethodMayBeStatic
297
+ def build_scheduler(self, optimizer):
298
+ from utils import build_lr_scheduler_from_config
299
+
300
+ scheduler_args = hparams['lr_scheduler_args']
301
+ assert scheduler_args['scheduler_cls'] != ''
302
+ scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args)
303
+ return scheduler
304
+
305
+ # noinspection PyMethodMayBeStatic
306
+ def build_optimizer(self, model):
307
+ from utils import build_object_from_class_name
308
+
309
+ optimizer_args = hparams['optimizer_args']
310
+ assert optimizer_args['optimizer_cls'] != ''
311
+ if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
312
+ optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
313
+ optimizer = build_object_from_class_name(
314
+ optimizer_args['optimizer_cls'],
315
+ torch.optim.Optimizer,
316
+ model.parameters(),
317
+ **optimizer_args
318
+ )
319
+ return optimizer
320
+
321
+ def configure_optimizers(self):
322
+ optm = self.build_optimizer(self.model)
323
+ scheduler = self.build_scheduler(optm)
324
+ if scheduler is None:
325
+ return optm
326
+ return {
327
+ "optimizer": optm,
328
+ "lr_scheduler": {
329
+ "scheduler": scheduler,
330
+ "interval": "step",
331
+ "frequency": 1
332
+ }
333
+ }
334
+
335
+ def train_dataloader(self):
336
+ self.training_sampler = DsBatchSampler(
337
+ self.train_dataset,
338
+ max_batch_frames=self.max_batch_frames,
339
+ max_batch_size=self.max_batch_size,
340
+ num_replicas=self.num_replicas,
341
+ rank=self.global_rank,
342
+ sort_by_similar_size=hparams['sort_by_len'],
343
+ size_reversed=True,
344
+ required_batch_count_multiple=hparams['accumulate_grad_batches'],
345
+ shuffle_sample=True,
346
+ shuffle_batch=True
347
+ )
348
+ return torch.utils.data.DataLoader(
349
+ self.train_dataset,
350
+ collate_fn=self.train_dataset.collater,
351
+ batch_sampler=self.training_sampler,
352
+ num_workers=hparams['ds_workers'],
353
+ prefetch_factor=hparams['dataloader_prefetch_factor'],
354
+ pin_memory=True,
355
+ persistent_workers=True
356
+ )
357
+
358
+ def val_dataloader(self):
359
+ sampler = DsBatchSampler(
360
+ self.valid_dataset,
361
+ max_batch_frames=self.max_val_batch_frames,
362
+ max_batch_size=self.max_val_batch_size,
363
+ num_replicas=self.num_replicas,
364
+ rank=self.global_rank,
365
+ shuffle_sample=False,
366
+ shuffle_batch=False,
367
+ disallow_empty_batch=False,
368
+ pad_batch_assignment=False
369
+ )
370
+ return torch.utils.data.DataLoader(
371
+ self.valid_dataset,
372
+ collate_fn=self.valid_dataset.collater,
373
+ batch_sampler=sampler,
374
+ num_workers=hparams['ds_workers'],
375
+ prefetch_factor=hparams['dataloader_prefetch_factor'],
376
+ persistent_workers=True
377
+ )
378
+
379
+ def test_dataloader(self):
380
+ return self.val_dataloader()
381
+
382
+ def on_test_start(self):
383
+ self.on_validation_start()
384
+
385
+ def test_step(self, sample, batch_idx):
386
+ return self.validation_step(sample, batch_idx)
387
+
388
+ def on_test_end(self):
389
+ return self.on_validation_end()
390
+
391
+ ###########
392
+ # Running configuration
393
+ ###########
394
+
395
+ @classmethod
396
+ def start(cls):
397
+ task = cls()
398
+
399
+ # if pre_train is not None:
400
+ # task.load_state_dict(pre_train,strict=False)
401
+ # print("load success-------------------------------------------------------------------")
402
+
403
+ work_dir = pathlib.Path(hparams['work_dir'])
404
+ trainer = pl.Trainer(
405
+ accelerator=hparams['pl_trainer_accelerator'],
406
+ devices=hparams['pl_trainer_devices'],
407
+ num_nodes=hparams['pl_trainer_num_nodes'],
408
+ strategy=get_strategy(
409
+ hparams['pl_trainer_devices'],
410
+ hparams['pl_trainer_num_nodes'],
411
+ hparams['pl_trainer_accelerator'],
412
+ hparams['pl_trainer_strategy'],
413
+ hparams['pl_trainer_precision'],
414
+ ),
415
+ precision=hparams['pl_trainer_precision'],
416
+ callbacks=[
417
+ DsModelCheckpoint(
418
+ dirpath=work_dir,
419
+ filename='model_ckpt_steps_{step}',
420
+ auto_insert_metric_name=False,
421
+ monitor='step',
422
+ mode='max',
423
+ save_last=False,
424
+ # every_n_train_steps=hparams['val_check_interval'],
425
+ save_top_k=hparams['num_ckpt_keep'],
426
+ permanent_ckpt_start=hparams['permanent_ckpt_start'],
427
+ permanent_ckpt_interval=hparams['permanent_ckpt_interval'],
428
+ verbose=True
429
+ ),
430
+ # LearningRateMonitor(logging_interval='step'),
431
+ DsTQDMProgressBar(),
432
+ ],
433
+ logger=DsTensorBoardLogger(
434
+ save_dir=str(work_dir),
435
+ name='lightning_logs',
436
+ version='latest'
437
+ ),
438
+ gradient_clip_val=hparams['clip_grad_norm'],
439
+ val_check_interval=hparams['val_check_interval'] * hparams['accumulate_grad_batches'],
440
+ # so this is global_steps
441
+ check_val_every_n_epoch=None,
442
+ log_every_n_steps=1,
443
+ max_steps=hparams['max_updates'],
444
+ use_distributed_sampler=False,
445
+ num_sanity_val_steps=hparams['num_sanity_val_steps'],
446
+ accumulate_grad_batches=hparams['accumulate_grad_batches']
447
+ )
448
+ if not hparams['infer']: # train
449
+ @rank_zero_only
450
+ def train_payload_copy():
451
+ # Copy spk_map.json and dictionary.txt to work dir
452
+ binary_dir = pathlib.Path(hparams['binary_data_dir'])
453
+ spk_map = work_dir / 'spk_map.json'
454
+ spk_map_src = binary_dir / 'spk_map.json'
455
+ if not spk_map.exists() and spk_map_src.exists():
456
+ shutil.copy(spk_map_src, spk_map)
457
+ print(f'| Copied spk map to {spk_map}.')
458
+ dictionary = work_dir / 'dictionary.txt'
459
+ dict_src = binary_dir / 'dictionary.txt'
460
+ if not dictionary.exists():
461
+ if dict_src.exists():
462
+ shutil.copy(dict_src, dictionary)
463
+ else:
464
+ shutil.copy(locate_dictionary(), dictionary)
465
+ print(f'| Copied dictionary to {dictionary}.')
466
+
467
+ train_payload_copy()
468
+ trainer.fit(task, ckpt_path=get_latest_checkpoint_path(work_dir))
469
+ else:
470
+ trainer.test(task)
471
+
472
+ def on_save_checkpoint(self, checkpoint):
473
+ if isinstance(self.model, CategorizedModule):
474
+ checkpoint['category'] = self.model.category
475
+ checkpoint['trainer_stage'] = self.trainer.state.stage.value
476
+
477
+ def on_load_checkpoint(self, checkpoint):
478
+ from lightning.pytorch.trainer.states import RunningStage
479
+ from utils import simulate_lr_scheduler
480
+ if checkpoint.get('trainer_stage', '') == RunningStage.VALIDATING.value:
481
+ self.skip_immediate_validation = True
482
+
483
+ optimizer_args = hparams['optimizer_args']
484
+ scheduler_args = hparams['lr_scheduler_args']
485
+
486
+ if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
487
+ optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
488
+
489
+ if checkpoint.get('optimizer_states', None):
490
+ opt_states = checkpoint['optimizer_states']
491
+ assert len(opt_states) == 1 # only support one optimizer
492
+ opt_state = opt_states[0]
493
+ for param_group in opt_state['param_groups']:
494
+ for k, v in optimizer_args.items():
495
+ if k in param_group and param_group[k] != v:
496
+ if 'lr_schedulers' in checkpoint and checkpoint['lr_schedulers'] and k == 'lr':
497
+ continue
498
+ rank_zero_info(f'| Overriding optimizer parameter {k} from checkpoint: {param_group[k]} -> {v}')
499
+ param_group[k] = v
500
+ if 'initial_lr' in param_group and param_group['initial_lr'] != optimizer_args['lr']:
501
+ rank_zero_info(
502
+ f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}'
503
+ )
504
+ param_group['initial_lr'] = optimizer_args['lr']
505
+
506
+ if checkpoint.get('lr_schedulers', None):
507
+ assert checkpoint.get('optimizer_states', False)
508
+ assert len(checkpoint['lr_schedulers']) == 1 # only support one scheduler
509
+ checkpoint['lr_schedulers'][0] = simulate_lr_scheduler(
510
+ optimizer_args, scheduler_args,
511
+ step_count=checkpoint['global_step'],
512
+ num_param_groups=len(checkpoint['optimizer_states'][0]['param_groups'])
513
+ )
514
+ for param_group, new_lr in zip(
515
+ checkpoint['optimizer_states'][0]['param_groups'],
516
+ checkpoint['lr_schedulers'][0]['_last_lr'],
517
+ ):
518
+ if param_group['lr'] != new_lr:
519
+ rank_zero_info(f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}')
520
+ param_group['lr'] = new_lr
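
The on_load_checkpoint hook above re-applies the configured optimizer settings and rebuilds the learning-rate schedule at the restored global step via utils.simulate_lr_scheduler. As a rough illustration only (the repository's helper may differ in detail), the StepLR scheduler configured in configs/base.yaml has a closed-form learning rate at any given step:

# Sketch (assumption, not the repo's simulate_lr_scheduler): closed-form LR of
# torch.optim.lr_scheduler.StepLR after `step_count` scheduler steps.
def stepped_lr(base_lr: float, step_count: int, step_size: int, gamma: float) -> float:
    return base_lr * gamma ** (step_count // step_size)

# With the defaults in configs/base.yaml (lr=0.0004, step_size=50000, gamma=0.5):
print(stepped_lr(0.0004, 100_000, 50_000, 0.5))  # -> 0.0001
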
basics/base_vocoder.py ADDED
@@ -0,0 +1,23 @@
1
+ class BaseVocoder:
2
+ def to_device(self, device):
3
+ """
4
+
5
+ :param device: torch.device or str
6
+ """
7
+ raise NotImplementedError()
8
+
9
+ def get_device(self):
10
+ """
11
+
12
+ :return: device: torch.device or str
13
+ """
14
+ raise NotImplementedError()
15
+
16
+ def spec2wav(self, mel, **kwargs):
17
+ """
18
+
19
+ :param mel: [T, 80]
20
+ :return: wav: [T']
21
+ """
22
+
23
+ raise NotImplementedError()
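
A concrete vocoder is expected to implement the three methods above. The following is a minimal hypothetical sketch (not a class from this repository; `generator` stands in for any mel-to-waveform network) of how the interface can be satisfied:

import torch

from basics.base_vocoder import BaseVocoder


class DummyVocoder(BaseVocoder):
    def __init__(self, generator: torch.nn.Module):
        self.generator = generator
        self.device = torch.device('cpu')

    def to_device(self, device):
        self.device = torch.device(device)
        self.generator.to(self.device)

    def get_device(self):
        return self.device

    def spec2wav(self, mel, **kwargs):
        # mel: [T, num_mel_bins] -> waveform: [T']
        with torch.no_grad():
            mel = torch.as_tensor(mel, dtype=torch.float32, device=self.device)[None]
            return self.generator(mel).squeeze(0).cpu().numpy()
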
configs/CantusSVS_acoustic.yaml ADDED
@@ -0,0 +1,149 @@
1
+ base_config:
2
+ - configs/base.yaml
3
+
4
+ pl_trainer_precision: '16-mixed'
5
+ pe: 'rmvpe'
6
+ pe_ckpt: checkpoints/dependency_checkpoints/rmvpe/model.pt
7
+
8
+ task_cls: training.acoustic_task.AcousticTask
9
+ spk_ids: []
10
+
11
+ num_spk: 1
12
+ speakers:
13
+ - regular # index 0
14
+ test_prefixes:
15
+ # - {index_speaker}:{name_of_wav}
16
+ # regular (0)
17
+ - 0:Adventus_seg012
18
+ - 0:Adventus_seg024
19
+ - 0:Adventus_seg036
20
+ - 0:Adventus_seg048
21
+ - 0:Adventus_seg060
22
+
23
+ raw_data_dir:
24
+ - data/NNSVS_training_data/regular/diffsinger_db #0
25
+
26
+ hnsep: vr
27
+ hnsep_ckpt: checkpoints/dependency_checkpoints/vr/model.pt
28
+
29
+ vocoder: NsfHifiGAN
30
+ vocoder_ckpt: checkpoints/dependency_checkpoints/nsf-hifigan/model.ckpt
31
+ audio_sample_rate: 44100
32
+ audio_num_mel_bins: 128
33
+ hop_size: 512 # Hop size.
34
+ fft_size: 2048 # FFT size.
35
+ win_size: 2048 # Window size.
36
+ fmin: 40
37
+ fmax: 16000
38
+
39
+ binarization_args:
40
+ shuffle: true
41
+ num_workers: 0
42
+ augmentation_args:
43
+ random_pitch_shifting:
44
+ enabled: true
45
+ range: [-3., 3.]
46
+ scale: 0.75
47
+ fixed_pitch_shifting:
48
+ enabled: false
49
+ targets: [-3., 3.]
50
+ scale: 0.5
51
+ random_time_stretching:
52
+ enabled: true
53
+ range: [0.8, 1.2]
54
+ scale: 0.75
55
+
56
+ binary_data_dir: 'data/binary/regular_acoustic_v1'
57
+ binarizer_cls: preprocessing.acoustic_binarizer.AcousticBinarizer
58
+ dictionary: dictionaries/latin_dictionary.txt
59
+ spec_min: [-12]
60
+ spec_max: [0]
61
+ mel_vmin: -14.
62
+ mel_vmax: 4.
63
+ mel_base: 'e'
64
+ energy_smooth_width: 0.12
65
+ breathiness_smooth_width: 0.12
66
+ voicing_smooth_width: 0.12
67
+ tension_smooth_width: 0.12
68
+
69
+ use_spk_id: false
70
+ use_energy_embed: false
71
+ use_breathiness_embed: false
72
+ use_voicing_embed: false
73
+ use_tension_embed: false
74
+ use_key_shift_embed: true
75
+ use_speed_embed: true
76
+
77
+ diffusion_type: reflow
78
+ time_scale_factor: 1000
79
+ timesteps: 1000
80
+ max_beta: 0.02
81
+ enc_ffn_kernel_size: 3
82
+ use_rope: true
83
+ rel_pos: true
84
+ sampling_algorithm: euler
85
+ sampling_steps: 20
86
+ diff_accelerator: ddim
87
+ diff_speedup: 10
88
+ hidden_size: 256
89
+ backbone_type: 'lynxnet'
90
+ backbone_args:
91
+ num_channels: 1024
92
+ num_layers: 6
93
+ kernel_size: 31
94
+ dropout_rate: 0.0
95
+ strong_cond: true
96
+ main_loss_type: l2
97
+ main_loss_log_norm: false
98
+ schedule_type: 'linear'
99
+
100
+ # shallow diffusion
101
+ use_shallow_diffusion: true
102
+ T_start: 0.4
103
+ T_start_infer: 0.4
104
+ K_step: 400
105
+ K_step_infer: 400
106
+
107
+ shallow_diffusion_args:
108
+ train_aux_decoder: true
109
+ train_diffusion: true
110
+ val_gt_start: false
111
+ aux_decoder_arch: convnext
112
+ aux_decoder_args:
113
+ num_channels: 512
114
+ num_layers: 6
115
+ kernel_size: 7
116
+ dropout_rate: 0.1
117
+ aux_decoder_grad: 0.1
118
+
119
+ lambda_aux_mel_loss: 0.2
120
+
121
+ # train and eval
122
+ num_sanity_val_steps: 1
123
+ optimizer_args:
124
+ lr: 0.0006
125
+ lr_scheduler_args:
126
+ step_size: 10000
127
+ gamma: 0.75
128
+ max_batch_frames: 50000
129
+ max_batch_size: 16
130
+ dataset_size_key: 'lengths'
131
+ val_with_vocoder: true
132
+ val_check_interval: 1000
133
+ num_valid_plots: 10
134
+ max_updates: 1000000
135
+ num_ckpt_keep: 5
136
+ permanent_ckpt_start: 20000
137
+ permanent_ckpt_interval: 5000
138
+
139
+ finetune_enabled: false
140
+ finetune_ckpt_path: null
141
+
142
+ finetune_ignored_params:
143
+ - model.fs2.encoder.embed_tokens
144
+ - model.fs2.txt_embed
145
+ - model.fs2.spk_embed
146
+ finetune_strict_shapes: true
147
+
148
+ freezing_enabled: false
149
+ frozen_params: []
configs/CantusSVS_variance.yaml ADDED
@@ -0,0 +1,153 @@
1
+ base_config:
2
+ - configs/base.yaml
3
+
4
+ pl_trainer_precision: '16-mixed'
5
+ pe: 'rmvpe'
6
+ pe_ckpt: checkpoints/dependency_checkpoints/rmvpe/model.pt
7
+
8
+ task_cls: training.variance_task.VarianceTask
9
+ spk_ids: []
10
+
11
+ num_spk: 1
12
+ speakers:
13
+ - regular # index 0
14
+ test_prefixes:
15
+ # - {index_speaker}:{name_of_wav}
16
+ # regular (0)
17
+ - 0:Adventus_seg012
18
+ - 0:Adventus_seg024
19
+ - 0:Adventus_seg036
20
+ - 0:Adventus_seg048
21
+ - 0:Adventus_seg060
22
+
23
+ raw_data_dir:
24
+ - data/NNSVS_training_data/regular/diffsinger_db #0
25
+
26
+ audio_sample_rate: 44100
27
+ hop_size: 512 # Hop size.
28
+ fft_size: 2048 # FFT size.
29
+ win_size: 2048 # Window size.
30
+ midi_smooth_width: 0.06 # in seconds
31
+
32
+ binarization_args:
33
+ shuffle: true
34
+ num_workers: 0
35
+ prefer_ds: false
36
+
37
+ binary_data_dir: 'data/binary/regular_variance_v1'
38
+ binarizer_cls: preprocessing.variance_binarizer.VarianceBinarizer
39
+ dictionary: dictionaries/latin_dictionary.txt
40
+
41
+ use_spk_id: false
42
+
43
+ enc_ffn_kernel_size: 3
44
+ use_rope: true
45
+ rel_pos: true
46
+ hidden_size: 256
47
+
48
+ predict_dur: true
49
+ predict_pitch: true
50
+ predict_energy: false
51
+ predict_breathiness: false
52
+ predict_voicing: false
53
+ predict_tension: false
54
+
55
+ dur_prediction_args:
56
+ arch: fs2
57
+ hidden_size: 512
58
+ dropout: 0.1
59
+ num_layers: 5
60
+ kernel_size: 3
61
+ log_offset: 1.0
62
+ loss_type: mse
63
+ lambda_pdur_loss: 0.3
64
+ lambda_wdur_loss: 1.0
65
+ lambda_sdur_loss: 3.0
66
+
67
+ use_melody_encoder: false
68
+ melody_encoder_args:
69
+ hidden_size: 128
70
+ enc_layers: 4
71
+ use_glide_embed: false
72
+ glide_types: [up, down]
73
+ glide_embed_scale: 11.313708498984760 # sqrt(128)
74
+
75
+ pitch_prediction_args:
76
+ pitd_norm_min: -8.0
77
+ pitd_norm_max: 8.0
78
+ pitd_clip_min: -12.0
79
+ pitd_clip_max: 12.0
80
+ repeat_bins: 64
81
+ backbone_type: 'wavenet'
82
+ backbone_args:
83
+ num_layers: 20
84
+ num_channels: 256
85
+ dilation_cycle_length: 5
86
+
87
+ energy_db_min: -96.0
88
+ energy_db_max: -12.0
89
+ energy_smooth_width: 0.12
90
+
91
+ breathiness_db_min: -96.0
92
+ breathiness_db_max: -20.0
93
+ breathiness_smooth_width: 0.12
94
+ voicing_db_min: -96.0
95
+ voicing_db_max: -12.0
96
+ voicing_smooth_width: 0.12
97
+
98
+ tension_logit_min: -10.0
99
+ tension_logit_max: 10.0
100
+ tension_smooth_width: 0.12
101
+
102
+ variances_prediction_args:
103
+ total_repeat_bins: 48
104
+ backbone_type: 'wavenet'
105
+ backbone_args:
106
+ num_layers: 10
107
+ num_channels: 192
108
+ dilation_cycle_length: 4
109
+
110
+ lambda_dur_loss: 1.0
111
+ lambda_pitch_loss: 1.0
112
+ lambda_var_loss: 1.0
113
+
114
+ diffusion_type: reflow # ddpm
115
+ time_scale_factor: 1000
116
+ schedule_type: 'linear'
117
+ K_step: 1000
118
+ timesteps: 1000
119
+ max_beta: 0.02
120
+ main_loss_type: l2
121
+ main_loss_log_norm: true
122
+ sampling_algorithm: euler
123
+ sampling_steps: 20
124
+ diff_accelerator: ddim
125
+ diff_speedup: 10
126
+
127
+ # train and eval
128
+ num_sanity_val_steps: 1
129
+ optimizer_args:
130
+ lr: 0.0006
131
+ lr_scheduler_args:
132
+ step_size: 10000
133
+ gamma: 0.75
134
+ max_batch_frames: 50000
135
+ max_batch_size: 16
136
+ dataset_size_key: 'lengths'
137
+ val_check_interval: 1000
138
+ num_valid_plots: 10
139
+ max_updates: 1000000
140
+ num_ckpt_keep: 5
141
+ permanent_ckpt_start: 20000
142
+ permanent_ckpt_interval: 5000
143
+
144
+ finetune_enabled: false
145
+ finetune_ckpt_path: null
146
+ finetune_ignored_params:
147
+ - model.spk_embed
148
+ - model.fs2.txt_embed
149
+ - model.fs2.encoder.embed_tokens
150
+ finetune_strict_shapes: true
151
+
152
+ freezing_enabled: false
153
+ frozen_params: []
configs/base.yaml ADDED
@@ -0,0 +1,94 @@
1
+ # task
2
+ task_cls: null
3
+
4
+ #############
5
+ # dataset
6
+ #############
7
+ sort_by_len: true
8
+ raw_data_dir: null
9
+ binary_data_dir: null
10
+ binarizer_cls: null
11
+ binarization_args:
12
+ shuffle: false
13
+ num_workers: 0
14
+
15
+ audio_sample_rate: 44100
16
+ hop_size: 512
17
+ win_size: 2048
18
+ fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
19
+ sampler_frame_count_grid: 6
20
+ ds_workers: 4
21
+ dataloader_prefetch_factor: 2
22
+
23
+ #########
24
+ # model
25
+ #########
26
+ hidden_size: 256
27
+ dropout: 0.1
28
+ use_pos_embed: true
29
+ enc_layers: 4
30
+ num_heads: 2
31
+ enc_ffn_kernel_size: 9
32
+ ffn_act: gelu
33
+ use_spk_id: false
34
+
35
+ ###########
36
+ # optimization
37
+ ###########
38
+ optimizer_args:
39
+ optimizer_cls: torch.optim.AdamW
40
+ lr: 0.0004
41
+ beta1: 0.9
42
+ beta2: 0.98
43
+ weight_decay: 0
44
+ lr_scheduler_args:
45
+ scheduler_cls: torch.optim.lr_scheduler.StepLR
46
+ step_size: 50000
47
+ gamma: 0.5
48
+ clip_grad_norm: 1
49
+
50
+ ###########
51
+ # train and eval
52
+ ###########
53
+ num_ckpt_keep: 5
54
+ accumulate_grad_batches: 1
55
+ log_interval: 100
56
+ num_sanity_val_steps: 1 # steps of validation at the beginning
57
+ val_check_interval: 2000
58
+ max_updates: 120000
59
+ max_batch_frames: 32000
60
+ max_batch_size: 100000
61
+ max_val_batch_frames: 60000
62
+ max_val_batch_size: 1
63
+ pe: parselmouth
64
+ pe_ckpt: 'checkpoints/rmvpe/model.pt'
65
+ hnsep: vr
66
+ hnsep_ckpt: 'checkpoints/vr/model.pt'
67
+ f0_min: 65
68
+ f0_max: 1100
69
+ num_valid_plots: 10
70
+
71
+ ###########
72
+ # pytorch lightning
73
+ # Read https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api for possible values
74
+ ###########
75
+ pl_trainer_accelerator: 'auto'
76
+ pl_trainer_devices: 'auto'
77
+ pl_trainer_precision: '16-mixed'
78
+ pl_trainer_num_nodes: 1
79
+ pl_trainer_strategy:
80
+ name: auto
81
+ process_group_backend: nccl
82
+ find_unused_parameters: false
83
+ nccl_p2p: true
84
+
85
+ ###########
86
+ # finetune
87
+ ###########
88
+ finetune_enabled: false
89
+ finetune_ckpt_path: null
90
+ finetune_ignored_params: []
91
+ finetune_strict_shapes: true
92
+
93
+ freezing_enabled: false
94
+ frozen_params: []
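
Derived configs such as configs/CantusSVS_acoustic.yaml point back to this file through their base_config list, and keys defined in the derived file override the base values. The repository's own loader (utils.hparams) is the authoritative implementation; the snippet below is only a simplified sketch of the layering idea and uses a shallow merge, whereas nested sections such as optimizer_args would need a recursive merge:

import pathlib
import yaml


def load_config(path: str) -> dict:
    cfg = yaml.safe_load(pathlib.Path(path).read_text(encoding='utf8')) or {}
    bases = cfg.pop('base_config', [])
    if isinstance(bases, str):
        bases = [bases]
    merged = {}
    for base in bases:   # earlier bases are overridden by later ones
        merged.update(load_config(base))
    merged.update(cfg)   # the derived file overrides all of its bases
    return merged


# hparams = load_config('configs/CantusSVS_acoustic.yaml')
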
configs/defaults/acoustic.yaml ADDED
@@ -0,0 +1,138 @@
1
+ base_config:
2
+ - configs/base.yaml
3
+
4
+ task_cls: training.acoustic_task.AcousticTask
5
+ num_spk: 1
6
+ speakers:
7
+ - opencpop
8
+ spk_ids: []
9
+ test_prefixes: [
10
+ '2044',
11
+ '2086',
12
+ '2092',
13
+ '2093',
14
+ '2100',
15
+ ]
16
+
17
+ vocoder: NsfHifiGAN
18
+ vocoder_ckpt: checkpoints/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt
19
+ audio_sample_rate: 44100
20
+ audio_num_mel_bins: 128
21
+ hop_size: 512 # Hop size.
22
+ fft_size: 2048 # FFT size.
23
+ win_size: 2048 # FFT size.
24
+ fmin: 40
25
+ fmax: 16000
26
+
27
+ binarization_args:
28
+ shuffle: true
29
+ num_workers: 0
30
+ augmentation_args:
31
+ random_pitch_shifting:
32
+ enabled: false
33
+ range: [-5., 5.]
34
+ scale: 0.75
35
+ fixed_pitch_shifting:
36
+ enabled: false
37
+ targets: [-5., 5.]
38
+ scale: 0.5
39
+ random_time_stretching:
40
+ enabled: false
41
+ range: [0.5, 2.]
42
+ scale: 0.75
43
+
44
+ raw_data_dir: 'data/opencpop/raw'
45
+ binary_data_dir: 'data/opencpop/binary'
46
+ binarizer_cls: preprocessing.acoustic_binarizer.AcousticBinarizer
47
+ dictionary: dictionaries/opencpop-extension.txt
48
+ spec_min: [-12]
49
+ spec_max: [0]
50
+ mel_vmin: -14.
51
+ mel_vmax: 4.
52
+ mel_base: 'e'
53
+ energy_smooth_width: 0.12
54
+ breathiness_smooth_width: 0.12
55
+ voicing_smooth_width: 0.12
56
+ tension_smooth_width: 0.12
57
+
58
+ use_spk_id: false
59
+ use_energy_embed: false
60
+ use_breathiness_embed: false
61
+ use_voicing_embed: false
62
+ use_tension_embed: false
63
+ use_key_shift_embed: false
64
+ use_speed_embed: false
65
+
66
+ diffusion_type: reflow
67
+ time_scale_factor: 1000
68
+ timesteps: 1000
69
+ max_beta: 0.02
70
+ enc_ffn_kernel_size: 3
71
+ use_rope: true
72
+ rel_pos: true
73
+ sampling_algorithm: euler
74
+ sampling_steps: 20
75
+ diff_accelerator: ddim
76
+ diff_speedup: 10
77
+ hidden_size: 256
78
+ backbone_type: 'lynxnet'
79
+ backbone_args:
80
+ num_channels: 1024
81
+ num_layers: 6
82
+ kernel_size: 31
83
+ dropout_rate: 0.0
84
+ strong_cond: true
85
+ main_loss_type: l2
86
+ main_loss_log_norm: false
87
+ schedule_type: 'linear'
88
+
89
+ # shallow diffusion
90
+ use_shallow_diffusion: true
91
+ T_start: 0.4
92
+ T_start_infer: 0.4
93
+ K_step: 400
94
+ K_step_infer: 400
95
+
96
+ shallow_diffusion_args:
97
+ train_aux_decoder: true
98
+ train_diffusion: true
99
+ val_gt_start: false
100
+ aux_decoder_arch: convnext
101
+ aux_decoder_args:
102
+ num_channels: 512
103
+ num_layers: 6
104
+ kernel_size: 7
105
+ dropout_rate: 0.1
106
+ aux_decoder_grad: 0.1
107
+
108
+ lambda_aux_mel_loss: 0.2
109
+
110
+ # train and eval
111
+ num_sanity_val_steps: 1
112
+ optimizer_args:
113
+ lr: 0.0006
114
+ lr_scheduler_args:
115
+ step_size: 10000
116
+ gamma: 0.75
117
+ max_batch_frames: 50000
118
+ max_batch_size: 64
119
+ dataset_size_key: 'lengths'
120
+ val_with_vocoder: true
121
+ val_check_interval: 2000
122
+ num_valid_plots: 10
123
+ max_updates: 160000
124
+ num_ckpt_keep: 5
125
+ permanent_ckpt_start: 80000
126
+ permanent_ckpt_interval: 20000
127
+
128
+ finetune_enabled: false
129
+ finetune_ckpt_path: null
130
+
131
+ finetune_ignored_params:
132
+ - model.fs2.encoder.embed_tokens
133
+ - model.fs2.txt_embed
134
+ - model.fs2.spk_embed
135
+ finetune_strict_shapes: true
136
+
137
+ freezing_enabled: false
138
+ frozen_params: []
configs/defaults/base.yaml ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # task
2
+ task_cls: null
3
+
4
+ #############
5
+ # dataset
6
+ #############
7
+ sort_by_len: true
8
+ raw_data_dir: null
9
+ binary_data_dir: null
10
+ binarizer_cls: null
11
+ binarization_args:
12
+ shuffle: false
13
+ num_workers: 0
14
+
15
+ audio_sample_rate: 44100
16
+ hop_size: 512
17
+ win_size: 2048
18
+ fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
19
+ sampler_frame_count_grid: 6
20
+ ds_workers: 4
21
+ dataloader_prefetch_factor: 2
22
+
23
+ #########
24
+ # model
25
+ #########
26
+ hidden_size: 256
27
+ dropout: 0.1
28
+ use_pos_embed: true
29
+ enc_layers: 4
30
+ num_heads: 2
31
+ enc_ffn_kernel_size: 9
32
+ ffn_act: gelu
33
+ use_spk_id: false
34
+
35
+ ###########
36
+ # optimization
37
+ ###########
38
+ optimizer_args:
39
+ optimizer_cls: torch.optim.AdamW
40
+ lr: 0.0004
41
+ beta1: 0.9
42
+ beta2: 0.98
43
+ weight_decay: 0
44
+ lr_scheduler_args:
45
+ scheduler_cls: torch.optim.lr_scheduler.StepLR
46
+ step_size: 50000
47
+ gamma: 0.5
48
+ clip_grad_norm: 1
49
+
50
+ ###########
51
+ # train and eval
52
+ ###########
53
+ num_ckpt_keep: 5
54
+ accumulate_grad_batches: 1
55
+ log_interval: 100
56
+ num_sanity_val_steps: 1 # steps of validation at the beginning
57
+ val_check_interval: 2000
58
+ max_updates: 120000
59
+ max_batch_frames: 32000
60
+ max_batch_size: 100000
61
+ max_val_batch_frames: 60000
62
+ max_val_batch_size: 1
63
+ pe: parselmouth
64
+ pe_ckpt: 'checkpoints/rmvpe/model.pt'
65
+ hnsep: vr
66
+ hnsep_ckpt: 'checkpoints/vr/model.pt'
67
+ f0_min: 65
68
+ f0_max: 1100
69
+ num_valid_plots: 10
70
+
71
+ ###########
72
+ # pytorch lightning
73
+ # Read https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api for possible values
74
+ ###########
75
+ pl_trainer_accelerator: 'auto'
76
+ pl_trainer_devices: 'auto'
77
+ pl_trainer_precision: '16-mixed'
78
+ pl_trainer_num_nodes: 1
79
+ pl_trainer_strategy:
80
+ name: auto
81
+ process_group_backend: nccl
82
+ find_unused_parameters: false
83
+ nccl_p2p: true
84
+
85
+ ###########
86
+ # finetune
87
+ ###########
88
+ finetune_enabled: false
89
+ finetune_ckpt_path: null
90
+ finetune_ignored_params: []
91
+ finetune_strict_shapes: true
92
+
93
+ freezing_enabled: false
94
+ frozen_params: []
configs/defaults/variance.yaml ADDED
@@ -0,0 +1,145 @@
1
+ base_config:
2
+ - configs/base.yaml
3
+
4
+ task_cls: training.variance_task.VarianceTask
5
+ num_spk: 1
6
+ speakers:
7
+ - opencpop
8
+ spk_ids: []
9
+ test_prefixes: [
10
+ '2044',
11
+ '2086',
12
+ '2092',
13
+ '2093',
14
+ '2100',
15
+ ]
16
+
17
+ audio_sample_rate: 44100
18
+ hop_size: 512 # Hop size.
19
+ fft_size: 2048 # FFT size.
20
+ win_size: 2048 # Window size.
21
+ midi_smooth_width: 0.06 # in seconds
22
+
23
+ binarization_args:
24
+ shuffle: true
25
+ num_workers: 0
26
+ prefer_ds: false
27
+
28
+ raw_data_dir: 'data/opencpop_variance/raw'
29
+ binary_data_dir: 'data/opencpop_variance/binary'
30
+ binarizer_cls: preprocessing.variance_binarizer.VarianceBinarizer
31
+ dictionary: dictionaries/opencpop-extension.txt
32
+
33
+ use_spk_id: false
34
+
35
+ enc_ffn_kernel_size: 3
36
+ use_rope: true
37
+ rel_pos: true
38
+ hidden_size: 256
39
+
40
+ predict_dur: true
41
+ predict_pitch: true
42
+ predict_energy: false
43
+ predict_breathiness: false
44
+ predict_voicing: false
45
+ predict_tension: false
46
+
47
+ dur_prediction_args:
48
+ arch: fs2
49
+ hidden_size: 512
50
+ dropout: 0.1
51
+ num_layers: 5
52
+ kernel_size: 3
53
+ log_offset: 1.0
54
+ loss_type: mse
55
+ lambda_pdur_loss: 0.3
56
+ lambda_wdur_loss: 1.0
57
+ lambda_sdur_loss: 3.0
58
+
59
+ use_melody_encoder: false
60
+ melody_encoder_args:
61
+ hidden_size: 128
62
+ enc_layers: 4
63
+ use_glide_embed: false
64
+ glide_types: [up, down]
65
+ glide_embed_scale: 11.313708498984760 # sqrt(128)
66
+
67
+ pitch_prediction_args:
68
+ pitd_norm_min: -8.0
69
+ pitd_norm_max: 8.0
70
+ pitd_clip_min: -12.0
71
+ pitd_clip_max: 12.0
72
+ repeat_bins: 64
73
+ backbone_type: 'wavenet'
74
+ backbone_args:
75
+ num_layers: 20
76
+ num_channels: 256
77
+ dilation_cycle_length: 5
78
+
79
+ energy_db_min: -96.0
80
+ energy_db_max: -12.0
81
+ energy_smooth_width: 0.12
82
+
83
+ breathiness_db_min: -96.0
84
+ breathiness_db_max: -20.0
85
+ breathiness_smooth_width: 0.12
86
+ voicing_db_min: -96.0
87
+ voicing_db_max: -12.0
88
+ voicing_smooth_width: 0.12
89
+
90
+ tension_logit_min: -10.0
91
+ tension_logit_max: 10.0
92
+ tension_smooth_width: 0.12
93
+
94
+ variances_prediction_args:
95
+ total_repeat_bins: 48
96
+ backbone_type: 'wavenet'
97
+ backbone_args:
98
+ num_layers: 10
99
+ num_channels: 192
100
+ dilation_cycle_length: 4
101
+
102
+ lambda_dur_loss: 1.0
103
+ lambda_pitch_loss: 1.0
104
+ lambda_var_loss: 1.0
105
+
106
+ diffusion_type: reflow # ddpm
107
+ time_scale_factor: 1000
108
+ schedule_type: 'linear'
109
+ K_step: 1000
110
+ timesteps: 1000
111
+ max_beta: 0.02
112
+ main_loss_type: l2
113
+ main_loss_log_norm: true
114
+ sampling_algorithm: euler
115
+ sampling_steps: 20
116
+ diff_accelerator: ddim
117
+ diff_speedup: 10
118
+
119
+ # train and eval
120
+ num_sanity_val_steps: 1
121
+ optimizer_args:
122
+ lr: 0.0006
123
+ lr_scheduler_args:
124
+ step_size: 10000
125
+ gamma: 0.75
126
+ max_batch_frames: 80000
127
+ max_batch_size: 48
128
+ dataset_size_key: 'lengths'
129
+ val_check_interval: 2000
130
+ num_valid_plots: 10
131
+ max_updates: 160000
132
+ num_ckpt_keep: 5
133
+ permanent_ckpt_start: 80000
134
+ permanent_ckpt_interval: 10000
135
+
136
+ finetune_enabled: false
137
+ finetune_ckpt_path: null
138
+ finetune_ignored_params:
139
+ - model.spk_embed
140
+ - model.fs2.txt_embed
141
+ - model.fs2.encoder.embed_tokens
142
+ finetune_strict_shapes: true
143
+
144
+ freezing_enabled: false
145
+ frozen_params: []
configs/templates/config_acoustic.yaml ADDED
@@ -0,0 +1,105 @@
1
+ base_config: configs/acoustic.yaml
2
+
3
+ raw_data_dir:
4
+ - data/xxx1/raw
5
+ - data/xxx2/raw
6
+ speakers:
7
+ - speaker1
8
+ - speaker2
9
+ spk_ids: []
10
+ test_prefixes:
11
+ - wav1
12
+ - wav2
13
+ - wav3
14
+ - wav4
15
+ - wav5
16
+ dictionary: dictionaries/opencpop-extension.txt
17
+ binary_data_dir: data/xxx/binary
18
+ binarization_args:
19
+ num_workers: 0
20
+ pe: parselmouth
21
+ pe_ckpt: 'checkpoints/rmvpe/model.pt'
22
+ hnsep: vr
23
+ hnsep_ckpt: 'checkpoints/vr/model.pt'
24
+ vocoder: NsfHifiGAN
25
+ vocoder_ckpt: checkpoints/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt
26
+
27
+ use_spk_id: false
28
+ num_spk: 1
29
+
30
+ # NOTICE: before enabling variance embeddings, please read the docs at
31
+ # https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#choosing-variance-parameters
32
+ use_energy_embed: false
33
+ use_breathiness_embed: false
34
+ use_voicing_embed: false
35
+ use_tension_embed: false
36
+
37
+ use_key_shift_embed: true
38
+ use_speed_embed: true
39
+
40
+ augmentation_args:
41
+ random_pitch_shifting:
42
+ enabled: true
43
+ range: [-5., 5.]
44
+ scale: 0.75
45
+ fixed_pitch_shifting:
46
+ enabled: false
47
+ targets: [-5., 5.]
48
+ scale: 0.5
49
+ random_time_stretching:
50
+ enabled: true
51
+ range: [0.5, 2.]
52
+ scale: 0.75
53
+
54
+ # diffusion and shallow diffusion
55
+ diffusion_type: reflow
56
+ enc_ffn_kernel_size: 3
57
+ use_rope: true
58
+ use_shallow_diffusion: true
59
+ T_start: 0.4
60
+ T_start_infer: 0.4
61
+ K_step: 300
62
+ K_step_infer: 300
63
+ backbone_type: 'lynxnet'
64
+ backbone_args:
65
+ num_channels: 1024
66
+ num_layers: 6
67
+ kernel_size: 31
68
+ dropout_rate: 0.0
69
+ strong_cond: true
70
+ #backbone_type: 'wavenet'
71
+ #backbone_args:
72
+ # num_channels: 512
73
+ # num_layers: 20
74
+ # dilation_cycle_length: 4
75
+ shallow_diffusion_args:
76
+ train_aux_decoder: true
77
+ train_diffusion: true
78
+ val_gt_start: false
79
+ aux_decoder_arch: convnext
80
+ aux_decoder_args:
81
+ num_channels: 512
82
+ num_layers: 6
83
+ kernel_size: 7
84
+ dropout_rate: 0.1
85
+ aux_decoder_grad: 0.1
86
+ lambda_aux_mel_loss: 0.2
87
+
88
+ optimizer_args:
89
+ lr: 0.0006
90
+ lr_scheduler_args:
91
+ scheduler_cls: torch.optim.lr_scheduler.StepLR
92
+ step_size: 10000
93
+ gamma: 0.75
94
+ max_batch_frames: 50000
95
+ max_batch_size: 64
96
+ max_updates: 160000
97
+
98
+ num_valid_plots: 10
99
+ val_with_vocoder: true
100
+ val_check_interval: 2000
101
+ num_ckpt_keep: 5
102
+ permanent_ckpt_start: 120000
103
+ permanent_ckpt_interval: 20000
104
+ pl_trainer_devices: 'auto'
105
+ pl_trainer_precision: '16-mixed'
configs/templates/config_variance.yaml ADDED
@@ -0,0 +1,129 @@
1
+ base_config:
2
+ - configs/variance.yaml
3
+
4
+ raw_data_dir:
5
+ - data/xxx1/raw
6
+ - data/xxx2/raw
7
+ speakers:
8
+ - speaker1
9
+ - speaker2
10
+ spk_ids: []
11
+ test_prefixes:
12
+ - wav1
13
+ - wav2
14
+ - wav3
15
+ - wav4
16
+ - wav5
17
+ dictionary: dictionaries/opencpop-extension.txt
18
+ binary_data_dir: data/xxx/binary
19
+ binarization_args:
20
+ num_workers: 0
21
+
22
+ pe: parselmouth
23
+ pe_ckpt: 'checkpoints/rmvpe/model.pt'
24
+ hnsep: vr
25
+ hnsep_ckpt: 'checkpoints/vr/model.pt'
26
+
27
+ use_spk_id: false
28
+ num_spk: 1
29
+ # NOTICE: before enabling variance modules, please read the docs at
30
+ # https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#mutual-influence-between-variance-modules
31
+ predict_dur: false
32
+ predict_pitch: false
33
+ # NOTICE: before enabling variance predictions, please read the docs at
34
+ # https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#choosing-variance-parameters
35
+ predict_energy: false
36
+ predict_breathiness: false
37
+ predict_voicing: false
38
+ predict_tension: false
39
+
40
+ energy_db_min: -96.0
41
+ energy_db_max: -12.0
42
+
43
+ breathiness_db_min: -96.0
44
+ breathiness_db_max: -20.0
45
+
46
+ voicing_db_min: -96.0
47
+ voicing_db_max: -12.0
48
+
49
+ tension_logit_min: -10.0
50
+ tension_logit_max: 10.0
51
+
52
+ enc_ffn_kernel_size: 3
53
+ use_rope: true
54
+ hidden_size: 256
55
+ dur_prediction_args:
56
+ arch: fs2
57
+ hidden_size: 512
58
+ dropout: 0.1
59
+ num_layers: 5
60
+ kernel_size: 3
61
+ log_offset: 1.0
62
+ loss_type: mse
63
+ lambda_pdur_loss: 0.3
64
+ lambda_wdur_loss: 1.0
65
+ lambda_sdur_loss: 3.0
66
+
67
+ use_melody_encoder: false
68
+ melody_encoder_args:
69
+ hidden_size: 128
70
+ enc_layers: 4
71
+ use_glide_embed: false
72
+ glide_types: [up, down]
73
+ glide_embed_scale: 11.313708498984760 # sqrt(128)
74
+
75
+ diffusion_type: reflow
76
+
77
+ pitch_prediction_args:
78
+ pitd_norm_min: -8.0
79
+ pitd_norm_max: 8.0
80
+ pitd_clip_min: -12.0
81
+ pitd_clip_max: 12.0
82
+ repeat_bins: 64
83
+ backbone_type: 'wavenet'
84
+ backbone_args:
85
+ num_layers: 20
86
+ num_channels: 256
87
+ dilation_cycle_length: 5
88
+ # backbone_type: 'lynxnet'
89
+ # backbone_args:
90
+ # num_layers: 6
91
+ # num_channels: 512
92
+ # dropout_rate: 0.0
93
+ # strong_cond: true
94
+
95
+ variances_prediction_args:
96
+ total_repeat_bins: 48
97
+ backbone_type: 'wavenet'
98
+ backbone_args:
99
+ num_layers: 10
100
+ num_channels: 192
101
+ dilation_cycle_length: 4
102
+ # backbone_type: 'lynxnet'
103
+ # backbone_args:
104
+ # num_layers: 6
105
+ # num_channels: 384
106
+ # dropout_rate: 0.0
107
+ # strong_cond: true
108
+
109
+ lambda_dur_loss: 1.0
110
+ lambda_pitch_loss: 1.0
111
+ lambda_var_loss: 1.0
112
+
113
+ optimizer_args:
114
+ lr: 0.0006
115
+ lr_scheduler_args:
116
+ scheduler_cls: torch.optim.lr_scheduler.StepLR
117
+ step_size: 10000
118
+ gamma: 0.75
119
+ max_batch_frames: 80000
120
+ max_batch_size: 48
121
+ max_updates: 160000
122
+
123
+ num_valid_plots: 10
124
+ val_check_interval: 2000
125
+ num_ckpt_keep: 5
126
+ permanent_ckpt_start: 80000
127
+ permanent_ckpt_interval: 10000
128
+ pl_trainer_devices: 'auto'
129
+ pl_trainer_precision: '16-mixed'
deployment/.gitignore ADDED
@@ -0,0 +1,7 @@
1
+ *.ds
2
+ *.onnx
3
+ *.npy
4
+ *.wav
5
+ temp/
6
+ cache/
7
+ assets/
deployment/__init__.py ADDED
File without changes
deployment/benchmarks/infer_acoustic.py ADDED
@@ -0,0 +1,32 @@
1
+ import numpy as np
2
+ import onnxruntime as ort
3
+ import tqdm
4
+
5
+ n_tokens = 10
6
+ n_frames = 100
7
+ n_runs = 20
8
+ speedup = 20
9
+ provider = 'DmlExecutionProvider'
10
+
11
+ tokens = np.array([[1] * n_tokens], dtype=np.int64)
12
+ durations = np.array([[n_frames // n_tokens] * n_tokens], dtype=np.int64)
13
+ f0 = np.array([[440.] * n_frames], dtype=np.float32)
14
+ speedup = np.array(speedup, dtype=np.int64)
15
+
16
+ session = ort.InferenceSession('model1.onnx', providers=[provider])
17
+ for _ in tqdm.tqdm(range(n_runs)):
18
+ session.run(['mel'], {
19
+ 'tokens': tokens,
20
+ 'durations': durations,
21
+ 'f0': f0,
22
+ 'speedup': speedup
23
+ })
24
+
25
+ session = ort.InferenceSession('model2.onnx', providers=[provider])
26
+ for _ in tqdm.tqdm(range(n_runs)):
27
+ session.run(['mel'], {
28
+ 'tokens': tokens,
29
+ 'durations': durations,
30
+ 'f0': f0,
31
+ 'speedup': speedup
32
+ })
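
Both benchmark scripts hard-code 'DmlExecutionProvider' (DirectML), which is only available in the onnxruntime-directml build on Windows. A small hypothetical helper like the one below, which is not part of the repository, can be used to fall back to the CPU provider on other platforms:

import onnxruntime as ort


def pick_provider(preferred: str = 'DmlExecutionProvider') -> str:
    # Use the preferred provider if this onnxruntime build offers it, else fall back to CPU.
    return preferred if preferred in ort.get_available_providers() else 'CPUExecutionProvider'


# session = ort.InferenceSession('model1.onnx', providers=[pick_provider()])
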
deployment/benchmarks/infer_nsf_hifigan.py ADDED
@@ -0,0 +1,16 @@
1
+ import numpy as np
2
+ import onnxruntime as ort
3
+ import tqdm
4
+
5
+ n_frames = 1000
6
+ n_runs = 20
7
+ mel = np.random.randn(1, n_frames, 128).astype(np.float32)
8
+ f0 = np.random.randn(1, n_frames).astype(np.float32) + 440.
9
+ provider = 'DmlExecutionProvider'
10
+
11
+ session = ort.InferenceSession('nsf_hifigan.onnx', providers=[provider])
12
+ for _ in tqdm.tqdm(range(n_runs)):
13
+ session.run(['waveform'], {
14
+ 'mel': mel,
15
+ 'f0': f0
16
+ })
deployment/exporters/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .acoustic_exporter import DiffSingerAcousticExporter
2
+ from .variance_exporter import DiffSingerVarianceExporter
3
+ from .nsf_hifigan_exporter import NSFHiFiGANExporter
deployment/exporters/acoustic_exporter.py ADDED
@@ -0,0 +1,405 @@
1
+ import shutil
2
+ from pathlib import Path
3
+ from typing import List, Union, Tuple, Dict
4
+
5
+ import onnx
6
+ import onnxsim
7
+ import torch
8
+ import yaml
9
+
10
+ from basics.base_exporter import BaseExporter
11
+ from deployment.modules.toplevel import DiffSingerAcousticONNX
12
+ from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
13
+ from utils import load_ckpt, onnx_helper, remove_suffix
14
+ from utils.hparams import hparams
15
+ from utils.phoneme_utils import locate_dictionary, build_phoneme_list
16
+ from utils.text_encoder import TokenTextEncoder
17
+
18
+
19
+ class DiffSingerAcousticExporter(BaseExporter):
20
+ def __init__(
21
+ self,
22
+ device: Union[str, torch.device] = 'cpu',
23
+ cache_dir: Path = None,
24
+ ckpt_steps: int = None,
25
+ freeze_gender: float = None,
26
+ freeze_velocity: bool = False,
27
+ export_spk: List[Tuple[str, Dict[str, float]]] = None,
28
+ freeze_spk: Tuple[str, Dict[str, float]] = None
29
+ ):
30
+ super().__init__(device=device, cache_dir=cache_dir)
31
+ # Basic attributes
32
+ self.model_name: str = hparams['exp_name']
33
+ self.ckpt_steps: int = ckpt_steps
34
+ self.spk_map: dict = self.build_spk_map()
35
+ self.vocab = TokenTextEncoder(vocab_list=build_phoneme_list())
36
+ self.model = self.build_model()
37
+ self.fs2_aux_cache_path = self.cache_dir / (
38
+ 'fs2_aux.onnx' if self.model.use_shallow_diffusion else 'fs2.onnx'
39
+ )
40
+ self.diffusion_cache_path = self.cache_dir / 'diffusion.onnx'
41
+
42
+ # Attributes for logging
43
+ self.model_class_name = remove_suffix(self.model.__class__.__name__, 'ONNX')
44
+ fs2_aux_cls_logging = [remove_suffix(self.model.fs2.__class__.__name__, 'ONNX')]
45
+ if self.model.use_shallow_diffusion:
46
+ fs2_aux_cls_logging.append(remove_suffix(
47
+ self.model.aux_decoder.decoder.__class__.__name__, 'ONNX'
48
+ ))
49
+ self.fs2_aux_class_name = ', '.join(fs2_aux_cls_logging)
50
+ self.aux_decoder_class_name = remove_suffix(
51
+ self.model.aux_decoder.decoder.__class__.__name__, 'ONNX'
52
+ ) if self.model.use_shallow_diffusion else None
53
+ self.backbone_class_name = remove_suffix(self.model.diffusion.backbone.__class__.__name__, 'ONNX')
54
+ self.diffusion_class_name = remove_suffix(self.model.diffusion.__class__.__name__, 'ONNX')
55
+
56
+ # Attributes for exporting
57
+ self.expose_gender = freeze_gender is None
58
+ self.expose_velocity = not freeze_velocity
59
+ self.freeze_spk: Tuple[str, Dict[str, float]] = freeze_spk \
60
+ if hparams['use_spk_id'] else None
61
+ self.export_spk: List[Tuple[str, Dict[str, float]]] = export_spk \
62
+ if hparams['use_spk_id'] and export_spk is not None else []
63
+ if hparams['use_key_shift_embed'] and not self.expose_gender:
64
+ shift_min, shift_max = hparams['augmentation_args']['random_pitch_shifting']['range']
65
+ key_shift = freeze_gender * shift_max if freeze_gender >= 0. else freeze_gender * abs(shift_min)
66
+ key_shift = max(min(key_shift, shift_max), shift_min) # clip key shift
67
+ self.model.fs2.register_buffer('frozen_key_shift', torch.FloatTensor([key_shift]).to(self.device))
68
+ if hparams['use_spk_id']:
69
+ if not self.export_spk and self.freeze_spk is None:
70
+ # In case the user did not specify any speaker settings:
71
+ if len(self.spk_map) == 1:
72
+ # If there is only one speaker, freeze him/her.
73
+ first_spk = next(iter(self.spk_map.keys()))
74
+ self.freeze_spk = (first_spk, {first_spk: 1.0})
75
+ else:
76
+ # If there are multiple speakers, export them all.
77
+ self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()]
78
+ if self.freeze_spk is not None:
79
+ self.model.fs2.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1]))
80
+
81
+ def build_model(self) -> DiffSingerAcousticONNX:
82
+ model = DiffSingerAcousticONNX(
83
+ vocab_size=len(self.vocab),
84
+ out_dims=hparams['audio_num_mel_bins']
85
+ ).eval().to(self.device)
86
+ load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps,
87
+ prefix_in_ckpt='model', strict=True, device=self.device)
88
+ return model
89
+
90
+ def export(self, path: Path):
91
+ path.mkdir(parents=True, exist_ok=True)
92
+ model_name = self.model_name
93
+ if self.freeze_spk is not None:
94
+ model_name += '.' + self.freeze_spk[0]
95
+ self.export_model(path / f'{model_name}.onnx')
96
+ self.export_attachments(path)
97
+
98
+ def export_model(self, path: Path):
99
+ self._torch_export_model()
100
+ fs2_aux_onnx = self._optimize_fs2_aux_graph(onnx.load(self.fs2_aux_cache_path))
101
+ diffusion_onnx = self._optimize_diffusion_graph(onnx.load(self.diffusion_cache_path))
102
+ model_onnx = self._merge_fs2_aux_diffusion_graphs(fs2_aux_onnx, diffusion_onnx)
103
+ onnx.save(model_onnx, path)
104
+ self.fs2_aux_cache_path.unlink()
105
+ self.diffusion_cache_path.unlink()
106
+ print(f'| export model => {path}')
107
+
108
+ def export_attachments(self, path: Path):
109
+ for spk in self.export_spk:
110
+ self._export_spk_embed(
111
+ path / f'{self.model_name}.{spk[0]}.emb',
112
+ self._perform_spk_mix(spk[1])
113
+ )
114
+ self._export_dictionary(path / 'dictionary.txt')
115
+ self._export_phonemes(path / f'{self.model_name}.phonemes.txt')
116
+
117
+ model_name = self.model_name
118
+ if self.freeze_spk is not None:
119
+ model_name += '.' + self.freeze_spk[0]
120
+ dsconfig = {
121
+ # basic configs
122
+ 'phonemes': f'{self.model_name}.phonemes.txt',
123
+ 'acoustic': f'{model_name}.onnx',
124
+ 'hidden_size': hparams['hidden_size'],
125
+ 'vocoder': 'nsf_hifigan_44.1k_hop512_128bin_2024.02',
126
+ }
127
+ # multi-speaker
128
+ if len(self.export_spk) > 0:
129
+ dsconfig['speakers'] = [f'{self.model_name}.{spk[0]}' for spk in self.export_spk]
130
+ # parameters
131
+ if self.expose_gender:
132
+ dsconfig['augmentation_args'] = {
133
+ 'random_pitch_shifting': {
134
+ 'range': hparams['augmentation_args']['random_pitch_shifting']['range']
135
+ }
136
+ }
137
+ dsconfig['use_key_shift_embed'] = self.expose_gender
138
+ dsconfig['use_speed_embed'] = self.expose_velocity
139
+ for variance in VARIANCE_CHECKLIST:
140
+ dsconfig[f'use_{variance}_embed'] = (variance in self.model.fs2.variance_embed_list)
141
+ # sampling acceleration and shallow diffusion
142
+ dsconfig['use_continuous_acceleration'] = True
143
+ dsconfig['use_variable_depth'] = self.model.use_shallow_diffusion
144
+ dsconfig['max_depth'] = 1 - self.model.diffusion.t_start
145
+ # mel specification
146
+ dsconfig['sample_rate'] = hparams['audio_sample_rate']
147
+ dsconfig['hop_size'] = hparams['hop_size']
148
+ dsconfig['win_size'] = hparams['win_size']
149
+ dsconfig['fft_size'] = hparams['fft_size']
150
+ dsconfig['num_mel_bins'] = hparams['audio_num_mel_bins']
151
+ dsconfig['mel_fmin'] = hparams['fmin']
152
+ dsconfig['mel_fmax'] = hparams['fmax'] if hparams['fmax'] is not None else hparams['audio_sample_rate'] / 2
153
+ dsconfig['mel_base'] = 'e'
154
+ dsconfig['mel_scale'] = 'slaney'
155
+ config_path = path / 'dsconfig.yaml'
156
+ with open(config_path, 'w', encoding='utf8') as fw:
157
+ yaml.safe_dump(dsconfig, fw, sort_keys=False)
158
+ print(f'| export configs => {config_path} **PLEASE EDIT BEFORE USE**')
159
+
160
+ @torch.no_grad()
161
+ def _torch_export_model(self):
162
+ # Prepare inputs for FastSpeech2 and aux decoder tracing
163
+ n_frames = 10
164
+ tokens = torch.LongTensor([[1]]).to(self.device)
165
+ durations = torch.LongTensor([[n_frames]]).to(self.device)
166
+ f0 = torch.FloatTensor([[440.] * n_frames]).to(self.device)
167
+ variances = {
168
+ v_name: torch.zeros(1, n_frames, dtype=torch.float32, device=self.device)
169
+ for v_name in self.model.fs2.variance_embed_list
170
+ }
171
+ kwargs: Dict[str, torch.Tensor] = {}
172
+ arguments = (tokens, durations, f0, variances, kwargs)
173
+ input_names = ['tokens', 'durations', 'f0'] + self.model.fs2.variance_embed_list
174
+ dynamix_axes = {
175
+ 'tokens': {
176
+ 1: 'n_tokens'
177
+ },
178
+ 'durations': {
179
+ 1: 'n_tokens'
180
+ },
181
+ 'f0': {
182
+ 1: 'n_frames'
183
+ },
184
+ **{
185
+ v_name: {
186
+ 1: 'n_frames'
187
+ }
188
+ for v_name in self.model.fs2.variance_embed_list
189
+ }
190
+ }
191
+ if hparams['use_key_shift_embed']:
192
+ if self.expose_gender:
193
+ kwargs['gender'] = torch.rand((1, n_frames), dtype=torch.float32, device=self.device)
194
+ input_names.append('gender')
195
+ dynamix_axes['gender'] = {
196
+ 1: 'n_frames'
197
+ }
198
+ if hparams['use_speed_embed']:
199
+ if self.expose_velocity:
200
+ kwargs['velocity'] = torch.rand((1, n_frames), dtype=torch.float32, device=self.device)
201
+ input_names.append('velocity')
202
+ dynamix_axes['velocity'] = {
203
+ 1: 'n_frames'
204
+ }
205
+ if hparams['use_spk_id'] and not self.freeze_spk:
206
+ kwargs['spk_embed'] = torch.rand(
207
+ (1, n_frames, hparams['hidden_size']),
208
+ dtype=torch.float32, device=self.device
209
+ )
210
+ input_names.append('spk_embed')
211
+ dynamix_axes['spk_embed'] = {
212
+ 1: 'n_frames'
213
+ }
214
+ dynamix_axes['condition'] = {
215
+ 1: 'n_frames'
216
+ }
217
+
218
+ # PyTorch ONNX export for FastSpeech2 and aux decoder
219
+ output_names = ['condition']
220
+ if self.model.use_shallow_diffusion:
221
+ output_names.append('aux_mel')
222
+ dynamix_axes['aux_mel'] = {
223
+ 1: 'n_frames'
224
+ }
225
+ print(f'Exporting {self.fs2_aux_class_name}...')
226
+ torch.onnx.export(
227
+ self.model.view_as_fs2_aux(),
228
+ arguments,
229
+ self.fs2_aux_cache_path,
230
+ input_names=input_names,
231
+ output_names=output_names,
232
+ dynamic_axes=dynamix_axes,
233
+ opset_version=15
234
+ )
235
+
236
+ condition = torch.rand((1, n_frames, hparams['hidden_size']), device=self.device)
237
+
238
+ # Prepare inputs for backbone tracing and GaussianDiffusion scripting
239
+ shape = (1, 1, hparams['audio_num_mel_bins'], n_frames)
240
+ noise = torch.randn(shape, device=self.device)
241
+ x_aux = torch.randn((1, n_frames, hparams['audio_num_mel_bins']), device=self.device)
242
+ dummy_time = (torch.rand((1,), device=self.device) * self.model.diffusion.time_scale_factor).float()
243
+ dummy_depth = torch.tensor(0.1, device=self.device)
244
+ dummy_steps = 5
245
+
246
+ print(f'Tracing {self.backbone_class_name} backbone...')
247
+ if self.model.diffusion_type == 'ddpm':
248
+ major_mel_decoder = self.model.view_as_diffusion()
249
+ elif self.model.diffusion_type == 'reflow':
250
+ major_mel_decoder = self.model.view_as_reflow()
251
+ else:
252
+ raise ValueError(f'Invalid diffusion type: {self.model.diffusion_type}')
253
+ major_mel_decoder.diffusion.set_backbone(
254
+ torch.jit.trace(
255
+ major_mel_decoder.diffusion.backbone,
256
+ (
257
+ noise,
258
+ dummy_time,
259
+ condition.transpose(1, 2)
260
+ )
261
+ )
262
+ )
263
+
264
+ print(f'Scripting {self.diffusion_class_name}...')
265
+ diffusion_inputs = [
266
+ condition,
267
+ *([x_aux, dummy_depth] if self.model.use_shallow_diffusion else [])
268
+ ]
269
+ major_mel_decoder = torch.jit.script(
270
+ major_mel_decoder,
271
+ example_inputs=[
272
+ (
273
+ *diffusion_inputs,
274
+ 1 # p_sample branch
275
+ ),
276
+ (
277
+ *diffusion_inputs,
278
+ dummy_steps # p_sample_plms branch
279
+ )
280
+ ]
281
+ )
282
+
283
+ # PyTorch ONNX export for GaussianDiffusion
284
+ print(f'Exporting {self.diffusion_class_name}...')
285
+ torch.onnx.export(
286
+ major_mel_decoder,
287
+ (
288
+ *diffusion_inputs,
289
+ dummy_steps
290
+ ),
291
+ self.diffusion_cache_path,
292
+ input_names=[
293
+ 'condition',
294
+ *(['x_aux', 'depth'] if self.model.use_shallow_diffusion else []),
295
+ 'steps'
296
+ ],
297
+ output_names=[
298
+ 'mel'
299
+ ],
300
+ dynamic_axes={
301
+ 'condition': {
302
+ 1: 'n_frames'
303
+ },
304
+ **({'x_aux': {1: 'n_frames'}} if self.model.use_shallow_diffusion else {}),
305
+ 'mel': {
306
+ 1: 'n_frames'
307
+ }
308
+ },
309
+ opset_version=15
310
+ )
311
+
312
+ @torch.no_grad()
313
+ def _perform_spk_mix(self, spk_mix: Dict[str, float]):
314
+ spk_mix_ids = []
315
+ spk_mix_values = []
316
+ for name, value in spk_mix.items():
317
+ spk_mix_ids.append(self.spk_map[name])
318
+ assert value >= 0., f'Speaker mix checks failed.\n' \
319
+ f'Proportion of speaker \'{name}\' is negative.'
320
+ spk_mix_values.append(value)
321
+ spk_mix_id_N = torch.LongTensor(spk_mix_ids).to(self.device)[None] # => [1, N]
322
+ spk_mix_value_N = torch.FloatTensor(spk_mix_values).to(self.device)[None] # => [1, N]
323
+ spk_mix_value_sum = spk_mix_value_N.sum()
324
+ assert spk_mix_value_sum > 0., f'Speaker mix checks failed.\n' \
325
+ f'Proportions of speaker mix sum to zero.'
326
+ spk_mix_value_N /= spk_mix_value_sum # normalize
327
+ spk_mix_embed = torch.sum(
328
+ self.model.fs2.spk_embed(spk_mix_id_N) * spk_mix_value_N.unsqueeze(2), # => [1, N, H]
329
+ dim=1, keepdim=False
330
+ ) # => [1, H]
331
+ return spk_mix_embed
332
+
333
+ def _optimize_fs2_aux_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto:
334
+ print(f'Running ONNX Simplifier on {self.fs2_aux_class_name}...')
335
+ fs2, check = onnxsim.simplify(fs2, include_subgraph=True)
336
+ assert check, 'Simplified ONNX model could not be validated'
337
+ print(f'| optimize graph: {self.fs2_aux_class_name}')
338
+ return fs2
339
+
340
+ def _optimize_diffusion_graph(self, diffusion: onnx.ModelProto) -> onnx.ModelProto:
341
+ onnx_helper.model_override_io_shapes(diffusion, output_shapes={
342
+ 'mel': (1, 'n_frames', hparams['audio_num_mel_bins'])
343
+ })
344
+ print(f'Running ONNX Simplifier #1 on {self.diffusion_class_name}...')
345
+ diffusion, check = onnxsim.simplify(diffusion, include_subgraph=True)
346
+ assert check, 'Simplified ONNX model could not be validated'
347
+ onnx_helper.graph_fold_back_to_squeeze(diffusion.graph)
348
+ onnx_helper.graph_extract_conditioner_projections(
349
+ graph=diffusion.graph, op_type='Conv',
350
+ weight_pattern=r'diffusion\..*\.conditioner_projection\.weight',
351
+ alias_prefix='/diffusion/backbone/cache'
352
+ )
353
+ onnx_helper.graph_remove_unused_values(diffusion.graph)
354
+ print(f'Running ONNX Simplifier #2 on {self.diffusion_class_name}...')
355
+ diffusion, check = onnxsim.simplify(
356
+ diffusion,
357
+ include_subgraph=True
358
+ )
359
+ assert check, 'Simplified ONNX model could not be validated'
360
+ print(f'| optimize graph: {self.diffusion_class_name}')
361
+ return diffusion
362
+
363
+ def _merge_fs2_aux_diffusion_graphs(self, fs2: onnx.ModelProto, diffusion: onnx.ModelProto) -> onnx.ModelProto:
364
+ onnx_helper.model_add_prefixes(
365
+ fs2, dim_prefix=('fs2aux.' if self.model.use_shallow_diffusion else 'fs2.'),
366
+ ignored_pattern=r'(n_tokens)|(n_frames)'
367
+ )
368
+ onnx_helper.model_add_prefixes(diffusion, dim_prefix='diffusion.', ignored_pattern='n_frames')
369
+ print(f'Merging {self.fs2_aux_class_name} and {self.diffusion_class_name} '
370
+ f'back into {self.model_class_name}...')
371
+ merged = onnx.compose.merge_models(
372
+ fs2, diffusion, io_map=[
373
+ ('condition', 'condition'),
374
+ *([('aux_mel', 'x_aux')] if self.model.use_shallow_diffusion else []),
375
+ ],
376
+ prefix1='', prefix2='', doc_string='',
377
+ producer_name=fs2.producer_name, producer_version=fs2.producer_version,
378
+ domain=fs2.domain, model_version=fs2.model_version
379
+ )
380
+ merged.graph.name = fs2.graph.name
381
+
382
+ print(f'Running ONNX Simplifier on {self.model_class_name}...')
383
+ merged, check = onnxsim.simplify(
384
+ merged,
385
+ include_subgraph=True
386
+ )
387
+ assert check, 'Simplified ONNX model could not be validated'
388
+ print(f'| optimize graph: {self.model_class_name}')
389
+
390
+ return merged
391
+
392
+ # noinspection PyMethodMayBeStatic
393
+ def _export_spk_embed(self, path: Path, spk_embed: torch.Tensor):
394
+ with open(path, 'wb') as f:
395
+ f.write(spk_embed.cpu().numpy().tobytes())
396
+ print(f'| export spk embed => {path}')
397
+
398
+ # noinspection PyMethodMayBeStatic
399
+ def _export_dictionary(self, path: Path):
400
+ print(f'| export dictionary => {path}')
401
+ shutil.copy(locate_dictionary(), path)
402
+
403
+ def _export_phonemes(self, path: Path):
404
+ self.vocab.store_to_file(path)
405
+ print(f'| export phonemes => {path}')
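
The _perform_spk_mix method above reduces a named speaker mix to a single embedding: proportions are checked for non-negativity, normalized to sum to one, and used as convex-combination weights over the speaker embedding table. A standalone illustration with assumed shapes (hypothetical sizes, not values from a real checkpoint):

import torch

num_spk, hidden_size = 3, 256
spk_embed = torch.nn.Embedding(num_spk, hidden_size)

ids = torch.LongTensor([[0, 2]])            # [1, N] speaker ids in the mix
weights = torch.FloatTensor([[0.7, 0.3]])   # [1, N] non-negative proportions
weights = weights / weights.sum()           # normalize so the mix sums to 1
mix_embed = (spk_embed(ids) * weights.unsqueeze(2)).sum(dim=1)  # [1, hidden_size]
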
deployment/exporters/nsf_hifigan_exporter.py ADDED
@@ -0,0 +1,120 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ import onnx
6
+ import onnxsim
7
+ import torch
8
+ import yaml
9
+ from torch import nn
10
+
11
+ from basics.base_exporter import BaseExporter
12
+ from deployment.modules.nsf_hifigan import NSFHiFiGANONNX
13
+ from utils import load_ckpt, remove_suffix
14
+ from utils.hparams import hparams
15
+
16
+
17
+ class NSFHiFiGANExporter(BaseExporter):
18
+ def __init__(
19
+ self,
20
+ device: Union[str, torch.device] = 'cpu',
21
+ cache_dir: Path = None,
22
+ model_path: Path = None,
23
+ model_name: str = 'nsf_hifigan'
24
+ ):
25
+ super().__init__(device=device, cache_dir=cache_dir)
26
+ self.model_path = model_path
27
+ self.model_name = model_name
28
+ self.model = self.build_model()
29
+ self.model_class_name = remove_suffix(self.model.__class__.__name__, 'ONNX')
30
+ self.model_cache_path = (self.cache_dir / self.model_name).with_suffix('.onnx')
31
+
32
+ def build_model(self) -> nn.Module:
33
+ config_path = self.model_path.with_name('config.json')
34
+ with open(config_path, 'r', encoding='utf8') as f:
35
+ config = json.load(f)
36
+ assert hparams.get('mel_base') == 'e', (
37
+ "Mel base must be set to \'e\' according to 2nd stage of the migration plan. "
38
+ "See https://github.com/openvpi/DiffSinger/releases/tag/v2.3.0 for more details."
39
+ )
40
+ model = NSFHiFiGANONNX(config).eval().to(self.device)
41
+ load_ckpt(model.generator, str(self.model_path),
42
+ prefix_in_ckpt=None, key_in_ckpt='generator',
43
+ strict=True, device=self.device)
44
+ model.generator.remove_weight_norm()
45
+ return model
46
+
47
+ def export(self, path: Path):
48
+ path.mkdir(parents=True, exist_ok=True)
49
+ self.export_model(path / self.model_cache_path.name)
50
+ self.export_attachments(path)
51
+
52
+ def export_model(self, path: Path):
53
+ self._torch_export_model()
54
+ model_onnx = self._optimize_model_graph(onnx.load(self.model_cache_path))
55
+ onnx.save(model_onnx, path)
56
+ self.model_cache_path.unlink()
57
+ print(f'| export model => {path}')
58
+
59
+ def export_attachments(self, path: Path):
60
+ config_path = path / 'vocoder.yaml'
61
+ with open(config_path, 'w', encoding='utf8') as fw:
62
+ yaml.safe_dump({
63
+ # basic configs
64
+ 'name': self.model_name,
65
+ 'model': self.model_cache_path.name,
66
+ # mel specifications
67
+ 'sample_rate': hparams['audio_sample_rate'],
68
+ 'hop_size': hparams['hop_size'],
69
+ 'win_size': hparams['win_size'],
70
+ 'fft_size': hparams['fft_size'],
71
+ 'num_mel_bins': hparams['audio_num_mel_bins'],
72
+ 'mel_fmin': hparams['fmin'],
73
+ 'mel_fmax': hparams['fmax'] if hparams['fmax'] is not None else hparams['audio_sample_rate'] / 2,
74
+ 'mel_base': 'e',
75
+ 'mel_scale': 'slaney',
76
+ }, fw, sort_keys=False)
77
+ print(f'| export configs => {config_path} **PLEASE EDIT BEFORE USE**')
78
+
79
+ @torch.no_grad()
80
+ def _torch_export_model(self):
81
+ # Prepare inputs for NSFHiFiGAN
82
+ n_frames = 10
83
+ mel = torch.randn((1, n_frames, hparams['audio_num_mel_bins']), dtype=torch.float32, device=self.device)
84
+ f0 = torch.randn((1, n_frames), dtype=torch.float32, device=self.device) + 440.
85
+
86
+ # PyTorch ONNX export for NSFHiFiGAN
87
+ print(f'Exporting {self.model_class_name}...')
88
+ torch.onnx.export(
89
+ self.model,
90
+ (
91
+ mel,
92
+ f0
93
+ ),
94
+ self.model_cache_path,
95
+ input_names=[
96
+ 'mel',
97
+ 'f0'
98
+ ],
99
+ output_names=[
100
+ 'waveform'
101
+ ],
102
+ dynamic_axes={
103
+ 'mel': {
104
+ 1: 'n_frames'
105
+ },
106
+ 'f0': {
107
+ 1: 'n_frames'
108
+ },
109
+ 'waveform': {
110
+ 1: 'n_samples'
111
+ }
112
+ },
113
+ opset_version=15
114
+ )
115
+
116
+ def _optimize_model_graph(self, model: onnx.ModelProto) -> onnx.ModelProto:
117
+ print(f'Running ONNX simplifier for {self.model_class_name}...')
118
+ model, check = onnxsim.simplify(model, include_subgraph=True)
119
+ assert check, 'Simplified ONNX model could not be validated'
120
+ return model
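
build_model above refuses to export unless hparams['mel_base'] is 'e', i.e. the mel spectrogram is expected on a natural-log scale. For a checkpoint trained on log10 mels the values would in principle rescale by ln(10), since log_e(x) = log_10(x) * ln(10); the snippet below only illustrates that relationship and is not a supported migration path (see the release notes linked in the assertion for the actual plan):

import numpy as np

mel_log10 = np.random.randn(100, 128).astype(np.float32)  # assumed shape [T, num_mel_bins]
mel_ln = mel_log10 * np.log(10.0)                          # log10 scale -> natural-log scale
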
deployment/exporters/variance_exporter.py ADDED
@@ -0,0 +1,781 @@
1
+ import shutil
2
+ from pathlib import Path
3
+ from typing import Union, List, Tuple, Dict
4
+
5
+ import onnx
6
+ import onnxsim
7
+ import torch
8
+ import yaml
9
+
10
+ from basics.base_exporter import BaseExporter
11
+ from deployment.modules.toplevel import DiffSingerVarianceONNX
12
+ from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
13
+ from utils import load_ckpt, onnx_helper, remove_suffix
14
+ from utils.hparams import hparams
15
+ from utils.phoneme_utils import locate_dictionary, build_phoneme_list
16
+ from utils.text_encoder import TokenTextEncoder
17
+
18
+
19
+ class DiffSingerVarianceExporter(BaseExporter):
20
+ def __init__(
21
+ self,
22
+ device: Union[str, torch.device] = 'cpu',
23
+ cache_dir: Path = None,
24
+ ckpt_steps: int = None,
25
+ freeze_glide: bool = False,
26
+ freeze_expr: bool = False,
27
+ export_spk: List[Tuple[str, Dict[str, float]]] = None,
28
+ freeze_spk: Tuple[str, Dict[str, float]] = None
29
+ ):
30
+ super().__init__(device=device, cache_dir=cache_dir)
31
+ # Basic attributes
32
+ self.model_name: str = hparams['exp_name']
33
+ self.ckpt_steps: int = ckpt_steps
34
+ self.spk_map: dict = self.build_spk_map()
35
+ self.vocab = TokenTextEncoder(vocab_list=build_phoneme_list())
36
+ self.model = self.build_model()
37
+ self.linguistic_encoder_cache_path = self.cache_dir / 'linguistic.onnx'
38
+ self.dur_predictor_cache_path = self.cache_dir / 'dur.onnx'
39
+ self.pitch_preprocess_cache_path = self.cache_dir / 'pitch_pre.onnx'
40
+ self.pitch_predictor_cache_path = self.cache_dir / 'pitch.onnx'
41
+ self.pitch_postprocess_cache_path = self.cache_dir / 'pitch_post.onnx'
42
+ self.variance_preprocess_cache_path = self.cache_dir / 'variance_pre.onnx'
43
+ self.multi_var_predictor_cache_path = self.cache_dir / 'variance.onnx'
44
+ self.variance_postprocess_cache_path = self.cache_dir / 'variance_post.onnx'
45
+
46
+ # Attributes for logging
47
+ self.fs2_class_name = remove_suffix(self.model.fs2.__class__.__name__, 'ONNX')
48
+ self.dur_predictor_class_name = \
49
+ remove_suffix(self.model.fs2.dur_predictor.__class__.__name__, 'ONNX') \
50
+ if self.model.predict_dur else None
51
+ self.pitch_backbone_class_name = \
52
+ remove_suffix(self.model.pitch_predictor.backbone.__class__.__name__, 'ONNX') \
53
+ if self.model.predict_pitch else None
54
+ self.pitch_predictor_class_name = \
55
+ remove_suffix(self.model.pitch_predictor.__class__.__name__, 'ONNX') \
56
+ if self.model.predict_pitch else None
57
+ self.variance_backbone_class_name = \
58
+ remove_suffix(self.model.variance_predictor.backbone.__class__.__name__, 'ONNX') \
59
+ if self.model.predict_variances else None
60
+ self.multi_var_predictor_class_name = \
61
+ remove_suffix(self.model.variance_predictor.__class__.__name__, 'ONNX') \
62
+ if self.model.predict_variances else None
63
+
64
+ # Attributes for exporting
65
+ self.expose_expr = not freeze_expr
66
+ self.freeze_glide = freeze_glide
67
+ self.freeze_spk: Tuple[str, Dict[str, float]] = freeze_spk \
68
+ if hparams['use_spk_id'] else None
69
+ self.export_spk: List[Tuple[str, Dict[str, float]]] = export_spk \
70
+ if hparams['use_spk_id'] and export_spk is not None else []
71
+ if hparams['use_spk_id']:
72
+ if not self.export_spk and self.freeze_spk is None:
73
+ # In case the user did not specify any speaker settings:
74
+ if len(self.spk_map) == 1:
75
+ # If there is only one speaker, freeze that speaker.
76
+ first_spk = next(iter(self.spk_map.keys()))
77
+ self.freeze_spk = (first_spk, {first_spk: 1.0})
78
+ else:
79
+ # If there are multiple speakers, export them all.
80
+ self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()]
81
+ if self.freeze_spk is not None:
82
+ self.model.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1]))
83
+
84
+ def build_model(self) -> DiffSingerVarianceONNX:
85
+ model = DiffSingerVarianceONNX(
86
+ vocab_size=len(self.vocab)
87
+ ).eval().to(self.device)
88
+ load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps,
89
+ prefix_in_ckpt='model', strict=True, device=self.device)
90
+ model.build_smooth_op(self.device)
91
+ return model
92
+
93
+ def export(self, path: Path):
94
+ path.mkdir(parents=True, exist_ok=True)
95
+ model_name = self.model_name
96
+ if self.freeze_spk is not None:
97
+ model_name += '.' + self.freeze_spk[0]
98
+ self.export_model(path, model_name)
99
+ self.export_attachments(path)
100
+
101
+ def export_model(self, path: Path, model_name: str = None):
102
+ self._torch_export_model()
103
+ linguistic_onnx = self._optimize_linguistic_graph(onnx.load(self.linguistic_encoder_cache_path))
104
+ linguistic_path = path / f'{model_name}.linguistic.onnx'
105
+ onnx.save(linguistic_onnx, linguistic_path)
106
+ print(f'| export linguistic encoder => {linguistic_path}')
107
+ self.linguistic_encoder_cache_path.unlink()
108
+ if self.model.predict_dur:
109
+ dur_predictor_onnx = self._optimize_dur_predictor_graph(onnx.load(self.dur_predictor_cache_path))
110
+ dur_predictor_path = path / f'{model_name}.dur.onnx'
111
+ onnx.save(dur_predictor_onnx, dur_predictor_path)
112
+ self.dur_predictor_cache_path.unlink()
113
+ print(f'| export dur predictor => {dur_predictor_path}')
114
+ if self.model.predict_pitch:
115
+ pitch_predictor_onnx = self._optimize_merge_pitch_predictor_graph(
116
+ onnx.load(self.pitch_preprocess_cache_path),
117
+ onnx.load(self.pitch_predictor_cache_path),
118
+ onnx.load(self.pitch_postprocess_cache_path)
119
+ )
120
+ pitch_predictor_path = path / f'{model_name}.pitch.onnx'
121
+ onnx.save(pitch_predictor_onnx, pitch_predictor_path)
122
+ self.pitch_preprocess_cache_path.unlink()
123
+ self.pitch_predictor_cache_path.unlink()
124
+ self.pitch_postprocess_cache_path.unlink()
125
+ print(f'| export pitch predictor => {pitch_predictor_path}')
126
+ if self.model.predict_variances:
127
+ variance_predictor_onnx = self._optimize_merge_variance_predictor_graph(
128
+ onnx.load(self.variance_preprocess_cache_path),
129
+ onnx.load(self.multi_var_predictor_cache_path),
130
+ onnx.load(self.variance_postprocess_cache_path)
131
+ )
132
+ variance_predictor_path = path / f'{model_name}.variance.onnx'
133
+ onnx.save(variance_predictor_onnx, variance_predictor_path)
134
+ self.variance_preprocess_cache_path.unlink()
135
+ self.multi_var_predictor_cache_path.unlink()
136
+ self.variance_postprocess_cache_path.unlink()
137
+ print(f'| export variance predictor => {variance_predictor_path}')
138
+
139
+ def export_attachments(self, path: Path):
140
+ for spk in self.export_spk:
141
+ self._export_spk_embed(
142
+ path / f'{self.model_name}.{spk[0]}.emb',
143
+ self._perform_spk_mix(spk[1])
144
+ )
145
+ self._export_dictionary(path / 'dictionary.txt')
146
+ self._export_phonemes(path / f'{self.model_name}.phonemes.txt')
147
+
148
+ model_name = self.model_name
149
+ if self.freeze_spk is not None:
150
+ model_name += '.' + self.freeze_spk[0]
151
+ dsconfig = {
152
+ # basic configs
153
+ 'phonemes': f'{self.model_name}.phonemes.txt',
154
+ 'linguistic': f'{model_name}.linguistic.onnx',
155
+ 'hidden_size': self.model.hidden_size,
156
+ 'predict_dur': self.model.predict_dur,
157
+ }
158
+ # multi-speaker
159
+ if len(self.export_spk) > 0:
160
+ dsconfig['speakers'] = [f'{self.model_name}.{spk[0]}' for spk in self.export_spk]
161
+ # functionalities
162
+ if self.model.predict_dur:
163
+ dsconfig['dur'] = f'{model_name}.dur.onnx'
164
+ if self.model.predict_pitch:
165
+ dsconfig['pitch'] = f'{model_name}.pitch.onnx'
166
+ dsconfig['use_expr'] = self.expose_expr
167
+ dsconfig['use_note_rest'] = self.model.use_melody_encoder
168
+ if self.model.predict_variances:
169
+ dsconfig['variance'] = f'{model_name}.variance.onnx'
170
+ for variance in VARIANCE_CHECKLIST:
171
+ dsconfig[f'predict_{variance}'] = (variance in self.model.variance_prediction_list)
172
+ # sampling acceleration
173
+ dsconfig['use_continuous_acceleration'] = True
174
+ # frame specifications
175
+ dsconfig['sample_rate'] = hparams['audio_sample_rate']
176
+ dsconfig['hop_size'] = hparams['hop_size']
177
+ config_path = path / 'dsconfig.yaml'
178
+ with open(config_path, 'w', encoding='utf8') as fw:
179
+ yaml.safe_dump(dsconfig, fw, sort_keys=False)
180
+ print(f'| export configs => {config_path} **PLEASE EDIT BEFORE USE**')
181
+
182
+ @torch.no_grad()
183
+ def _torch_export_model(self):
184
+ # Prepare inputs for FastSpeech2 and dur predictor tracing
185
+ tokens = torch.LongTensor([[1] * 5]).to(self.device)
186
+ ph_dur = torch.LongTensor([[3, 5, 2, 1, 4]]).to(self.device)
187
+ word_div = torch.LongTensor([[2, 2, 1]]).to(self.device)
188
+ word_dur = torch.LongTensor([[8, 3, 4]]).to(self.device)
189
+ encoder_out = torch.rand(1, 5, hparams['hidden_size'], dtype=torch.float32, device=self.device)
190
+ x_masks = tokens == 0
191
+ ph_midi = torch.LongTensor([[60] * 5]).to(self.device)
192
+ encoder_output_names = ['encoder_out', 'x_masks']
193
+ encoder_common_axes = {
194
+ 'encoder_out': {
195
+ 1: 'n_tokens'
196
+ },
197
+ 'x_masks': {
198
+ 1: 'n_tokens'
199
+ }
200
+ }
201
+ input_spk_embed = hparams['use_spk_id'] and not self.freeze_spk
202
+
203
+ print(f'Exporting {self.fs2_class_name}...')
204
+ if self.model.predict_dur:
205
+ torch.onnx.export(
206
+ self.model.view_as_linguistic_encoder(),
207
+ (
208
+ tokens,
209
+ word_div,
210
+ word_dur
211
+ ),
212
+ self.linguistic_encoder_cache_path,
213
+ input_names=[
214
+ 'tokens',
215
+ 'word_div',
216
+ 'word_dur'
217
+ ],
218
+ output_names=encoder_output_names,
219
+ dynamic_axes={
220
+ 'tokens': {
221
+ 1: 'n_tokens'
222
+ },
223
+ 'word_div': {
224
+ 1: 'n_words'
225
+ },
226
+ 'word_dur': {
227
+ 1: 'n_words'
228
+ },
229
+ **encoder_common_axes
230
+ },
231
+ opset_version=15
232
+ )
233
+
234
+ print(f'Exporting {self.dur_predictor_class_name}...')
235
+ torch.onnx.export(
236
+ self.model.view_as_dur_predictor(),
237
+ (
238
+ encoder_out,
239
+ x_masks,
240
+ ph_midi,
241
+ *([torch.rand(
242
+ 1, 5, hparams['hidden_size'],
243
+ dtype=torch.float32, device=self.device
244
+ )] if input_spk_embed else [])
245
+ ),
246
+ self.dur_predictor_cache_path,
247
+ input_names=[
248
+ 'encoder_out',
249
+ 'x_masks',
250
+ 'ph_midi',
251
+ *(['spk_embed'] if input_spk_embed else [])
252
+ ],
253
+ output_names=[
254
+ 'ph_dur_pred'
255
+ ],
256
+ dynamic_axes={
257
+ 'ph_midi': {
258
+ 1: 'n_tokens'
259
+ },
260
+ 'ph_dur_pred': {
261
+ 1: 'n_tokens'
262
+ },
263
+ **({'spk_embed': {1: 'n_tokens'}} if input_spk_embed else {}),
264
+ **encoder_common_axes
265
+ },
266
+ opset_version=15
267
+ )
268
+ else:
269
+ torch.onnx.export(
270
+ self.model.view_as_linguistic_encoder(),
271
+ (
272
+ tokens,
273
+ ph_dur
274
+ ),
275
+ self.linguistic_encoder_cache_path,
276
+ input_names=[
277
+ 'tokens',
278
+ 'ph_dur'
279
+ ],
280
+ output_names=encoder_output_names,
281
+ dynamic_axes={
282
+ 'tokens': {
283
+ 1: 'n_tokens'
284
+ },
285
+ 'ph_dur': {
286
+ 1: 'n_tokens'
287
+ },
288
+ **encoder_common_axes
289
+ },
290
+ opset_version=15
291
+ )
292
+
293
+ # Common dummy inputs
294
+ dummy_time = (torch.rand((1,), device=self.device) * hparams.get('time_scale_factor', 1.0)).float()
295
+ dummy_steps = 5
296
+
297
+ if self.model.predict_pitch:
298
+ use_melody_encoder = hparams.get('use_melody_encoder', False)
299
+ use_glide_embed = use_melody_encoder and hparams['use_glide_embed'] and not self.freeze_glide
300
+ # Prepare inputs for preprocessor of the pitch predictor
301
+ note_midi = torch.FloatTensor([[60.] * 4]).to(self.device)
302
+ note_dur = torch.LongTensor([[2, 6, 3, 4]]).to(self.device)
303
+ pitch = torch.FloatTensor([[60.] * 15]).to(self.device)
304
+ retake = torch.ones_like(pitch, dtype=torch.bool)
305
+ pitch_input_args = (
306
+ encoder_out,
307
+ ph_dur,
308
+ {
309
+ 'note_midi': note_midi,
310
+ **({'note_rest': note_midi >= 0} if use_melody_encoder else {}),
311
+ 'note_dur': note_dur,
312
+ **({'note_glide': torch.zeros_like(note_midi, dtype=torch.long)} if use_glide_embed else {}),
313
+ 'pitch': pitch,
314
+ **({'expr': torch.ones_like(pitch)} if self.expose_expr else {}),
315
+ 'retake': retake,
316
+ **({'spk_embed': torch.rand(
317
+ 1, 15, hparams['hidden_size'], dtype=torch.float32, device=self.device
318
+ )} if input_spk_embed else {})
319
+ }
320
+ )
321
+ torch.onnx.export(
322
+ self.model.view_as_pitch_preprocess(),
323
+ pitch_input_args,
324
+ self.pitch_preprocess_cache_path,
325
+ input_names=[
326
+ 'encoder_out', 'ph_dur', 'note_midi',
327
+ *(['note_rest'] if use_melody_encoder else []),
328
+ 'note_dur',
329
+ *(['note_glide'] if use_glide_embed else []),
330
+ 'pitch',
331
+ *(['expr'] if self.expose_expr else []),
332
+ 'retake',
333
+ *(['spk_embed'] if input_spk_embed else [])
334
+ ],
335
+ output_names=[
336
+ 'pitch_cond', 'base_pitch'
337
+ ],
338
+ dynamic_axes={
339
+ 'encoder_out': {
340
+ 1: 'n_tokens'
341
+ },
342
+ 'ph_dur': {
343
+ 1: 'n_tokens'
344
+ },
345
+ 'note_midi': {
346
+ 1: 'n_notes'
347
+ },
348
+ **({'note_rest': {1: 'n_notes'}} if use_melody_encoder else {}),
349
+ 'note_dur': {
350
+ 1: 'n_notes'
351
+ },
352
+ **({'note_glide': {1: 'n_notes'}} if use_glide_embed else {}),
353
+ 'pitch': {
354
+ 1: 'n_frames'
355
+ },
356
+ **({'expr': {1: 'n_frames'}} if self.expose_expr else {}),
357
+ 'retake': {
358
+ 1: 'n_frames'
359
+ },
360
+ 'pitch_cond': {
361
+ 1: 'n_frames'
362
+ },
363
+ 'base_pitch': {
364
+ 1: 'n_frames'
365
+ },
366
+ **({'spk_embed': {1: 'n_frames'}} if input_spk_embed else {})
367
+ },
368
+ opset_version=15
369
+ )
370
+
371
+ # Prepare inputs for backbone tracing and pitch predictor scripting
372
+ shape = (1, 1, hparams['pitch_prediction_args']['repeat_bins'], 15)
373
+ noise = torch.randn(shape, device=self.device)
374
+ condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device)
375
+
376
+ print(f'Tracing {self.pitch_backbone_class_name} backbone...')
377
+ pitch_predictor = self.model.view_as_pitch_predictor()
378
+ pitch_predictor.pitch_predictor.set_backbone(
379
+ torch.jit.trace(
380
+ pitch_predictor.pitch_predictor.backbone,
381
+ (
382
+ noise,
383
+ dummy_time,
384
+ condition
385
+ )
386
+ )
387
+ )
388
+
389
+ print(f'Scripting {self.pitch_predictor_class_name}...')
390
+ pitch_predictor = torch.jit.script(
391
+ pitch_predictor,
392
+ example_inputs=[
393
+ (
394
+ condition.transpose(1, 2),
395
+ 1 # p_sample branch
396
+ ),
397
+ (
398
+ condition.transpose(1, 2),
399
+ dummy_steps # p_sample_plms branch
400
+ )
401
+ ]
402
+ )
403
+
404
+ print(f'Exporting {self.pitch_predictor_class_name}...')
405
+ torch.onnx.export(
406
+ pitch_predictor,
407
+ (
408
+ condition.transpose(1, 2),
409
+ dummy_steps
410
+ ),
411
+ self.pitch_predictor_cache_path,
412
+ input_names=[
413
+ 'pitch_cond',
414
+ 'steps'
415
+ ],
416
+ output_names=[
417
+ 'x_pred'
418
+ ],
419
+ dynamic_axes={
420
+ 'pitch_cond': {
421
+ 1: 'n_frames'
422
+ },
423
+ 'x_pred': {
424
+ 1: 'n_frames'
425
+ }
426
+ },
427
+ opset_version=15
428
+ )
429
+
430
+ # Prepare inputs for postprocessor of the pitch predictor
431
+ torch.onnx.export(
432
+ self.model.view_as_pitch_postprocess(),
433
+ (
434
+ pitch,
435
+ pitch
436
+ ),
437
+ self.pitch_postprocess_cache_path,
438
+ input_names=[
439
+ 'x_pred',
440
+ 'base_pitch'
441
+ ],
442
+ output_names=[
443
+ 'pitch_pred'
444
+ ],
445
+ dynamic_axes={
446
+ 'x_pred': {
447
+ 1: 'n_frames'
448
+ },
449
+ 'base_pitch': {
450
+ 1: 'n_frames'
451
+ },
452
+ 'pitch_pred': {
453
+ 1: 'n_frames'
454
+ }
455
+ },
456
+ opset_version=15
457
+ )
458
+
459
+ if self.model.predict_variances:
460
+ total_repeat_bins = hparams['variances_prediction_args']['total_repeat_bins']
461
+ repeat_bins = total_repeat_bins // len(self.model.variance_prediction_list)
462
+
463
+ # Prepare inputs for preprocessor of the multi-variance predictor
464
+ pitch = torch.FloatTensor([[60.] * 15]).to(self.device)
465
+ variances = {
466
+ v_name: torch.FloatTensor([[0.] * 15]).to(self.device)
467
+ for v_name in self.model.variance_prediction_list
468
+ }
469
+ retake = torch.ones_like(pitch, dtype=torch.bool)[..., None].tile(len(self.model.variance_prediction_list))
470
+ torch.onnx.export(
471
+ self.model.view_as_variance_preprocess(),
472
+ (
473
+ encoder_out,
474
+ ph_dur,
475
+ pitch,
476
+ variances,
477
+ retake,
478
+ *([torch.rand(
479
+ 1, 15, hparams['hidden_size'],
480
+ dtype=torch.float32, device=self.device
481
+ )] if input_spk_embed else [])
482
+ ),
483
+ self.variance_preprocess_cache_path,
484
+ input_names=[
485
+ 'encoder_out', 'ph_dur', 'pitch',
486
+ *self.model.variance_prediction_list,
487
+ 'retake',
488
+ *(['spk_embed'] if input_spk_embed else [])
489
+ ],
490
+ output_names=[
491
+ 'variance_cond'
492
+ ],
493
+ dynamic_axes={
494
+ 'encoder_out': {
495
+ 1: 'n_tokens'
496
+ },
497
+ 'ph_dur': {
498
+ 1: 'n_tokens'
499
+ },
500
+ 'pitch': {
501
+ 1: 'n_frames'
502
+ },
503
+ **{
504
+ v_name: {
505
+ 1: 'n_frames'
506
+ }
507
+ for v_name in self.model.variance_prediction_list
508
+ },
509
+ 'retake': {
510
+ 1: 'n_frames'
511
+ },
512
+ **({'spk_embed': {1: 'n_frames'}} if input_spk_embed else {})
513
+ },
514
+ opset_version=15
515
+ )
516
+
517
+ # Prepare inputs for backbone tracing and multi-variance predictor scripting
518
+ shape = (1, len(self.model.variance_prediction_list), repeat_bins, 15)
519
+ noise = torch.randn(shape, device=self.device)
520
+ condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device)
521
+ step = (torch.rand((1,), device=self.device) * hparams['K_step']).long()
522
+
523
+ print(f'Tracing {self.variance_backbone_class_name} backbone...')
524
+ multi_var_predictor = self.model.view_as_variance_predictor()
525
+ multi_var_predictor.variance_predictor.set_backbone(
526
+ torch.jit.trace(
527
+ multi_var_predictor.variance_predictor.backbone,
528
+ (
529
+ noise,
530
+ step,
531
+ condition
532
+ )
533
+ )
534
+ )
535
+
536
+ print(f'Scripting {self.multi_var_predictor_class_name}...')
537
+ multi_var_predictor = torch.jit.script(
538
+ multi_var_predictor,
539
+ example_inputs=[
540
+ (
541
+ condition.transpose(1, 2),
542
+ 1 # p_sample branch
543
+ ),
544
+ (
545
+ condition.transpose(1, 2),
546
+ dummy_steps # p_sample_plms branch
547
+ )
548
+ ]
549
+ )
550
+
551
+ print(f'Exporting {self.multi_var_predictor_class_name}...')
552
+ torch.onnx.export(
553
+ multi_var_predictor,
554
+ (
555
+ condition.transpose(1, 2),
556
+ dummy_steps
557
+ ),
558
+ self.multi_var_predictor_cache_path,
559
+ input_names=[
560
+ 'variance_cond',
561
+ 'steps'
562
+ ],
563
+ output_names=[
564
+ 'xs_pred'
565
+ ],
566
+ dynamic_axes={
567
+ 'variance_cond': {
568
+ 1: 'n_frames'
569
+ },
570
+ 'xs_pred': {
571
+ (1 if len(self.model.variance_prediction_list) == 1 else 2): 'n_frames'
572
+ }
573
+ },
574
+ opset_version=15
575
+ )
576
+
577
+ # Prepare inputs for postprocessor of the multi-variance predictor
578
+ xs_shape = (1, 15) \
579
+ if len(self.model.variance_prediction_list) == 1 \
580
+ else (1, len(self.model.variance_prediction_list), 15)
581
+ xs_pred = torch.randn(xs_shape, dtype=torch.float32, device=self.device)
582
+ torch.onnx.export(
583
+ self.model.view_as_variance_postprocess(),
584
+ (
585
+ xs_pred
586
+ ),
587
+ self.variance_postprocess_cache_path,
588
+ input_names=[
589
+ 'xs_pred'
590
+ ],
591
+ output_names=[
592
+ f'{v_name}_pred'
593
+ for v_name in self.model.variance_prediction_list
594
+ ],
595
+ dynamic_axes={
596
+ 'xs_pred': {
597
+ (1 if len(self.model.variance_prediction_list) == 1 else 2): 'n_frames'
598
+ },
599
+ **{
600
+ f'{v_name}_pred': {
601
+ 1: 'n_frames'
602
+ }
603
+ for v_name in self.model.variance_prediction_list
604
+ }
605
+ },
606
+ opset_version=15
607
+ )
608
+
609
+ @torch.no_grad()
610
+ def _perform_spk_mix(self, spk_mix: Dict[str, float]):
611
+ spk_mix_ids = []
612
+ spk_mix_values = []
613
+ for name, value in spk_mix.items():
614
+ spk_mix_ids.append(self.spk_map[name])
615
+ assert value >= 0., f'Speaker mix checks failed.\n' \
616
+ f'Proportion of speaker \'{name}\' is negative.'
617
+ spk_mix_values.append(value)
618
+ spk_mix_id_N = torch.LongTensor(spk_mix_ids).to(self.device)[None] # => [1, N]
619
+ spk_mix_value_N = torch.FloatTensor(spk_mix_values).to(self.device)[None] # => [1, N]
620
+ spk_mix_value_sum = spk_mix_value_N.sum()
621
+ assert spk_mix_value_sum > 0., f'Speaker mix checks failed.\n' \
622
+ f'Proportions of speaker mix sum to zero.'
623
+ spk_mix_value_N /= spk_mix_value_sum # normalize
624
+ spk_mix_embed = torch.sum(
625
+ self.model.spk_embed(spk_mix_id_N) * spk_mix_value_N.unsqueeze(2), # => [1, N, H]
626
+ dim=1, keepdim=False
627
+ ) # => [1, H]
628
+ return spk_mix_embed
629
+
630
+ def _optimize_linguistic_graph(self, linguistic: onnx.ModelProto) -> onnx.ModelProto:
631
+ onnx_helper.model_override_io_shapes(
632
+ linguistic,
633
+ output_shapes={
634
+ 'encoder_out': (1, 'n_tokens', hparams['hidden_size'])
635
+ }
636
+ )
637
+ print(f'Running ONNX Simplifier on {self.fs2_class_name}...')
638
+ linguistic, check = onnxsim.simplify(linguistic, include_subgraph=True)
639
+ assert check, 'Simplified ONNX model could not be validated'
640
+ print(f'| optimize graph: {self.fs2_class_name}')
641
+ return linguistic
642
+
643
+ def _optimize_dur_predictor_graph(self, dur_predictor: onnx.ModelProto) -> onnx.ModelProto:
644
+ onnx_helper.model_override_io_shapes(
645
+ dur_predictor,
646
+ output_shapes={
647
+ 'ph_dur_pred': (1, 'n_tokens')
648
+ }
649
+ )
650
+ print(f'Running ONNX Simplifier on {self.dur_predictor_class_name}...')
651
+ dur_predictor, check = onnxsim.simplify(dur_predictor, include_subgraph=True)
652
+ assert check, 'Simplified ONNX model could not be validated'
653
+ print(f'| optimize graph: {self.dur_predictor_class_name}')
654
+ return dur_predictor
655
+
656
+ def _optimize_merge_pitch_predictor_graph(
657
+ self, pitch_pre: onnx.ModelProto, pitch_predictor: onnx.ModelProto, pitch_post: onnx.ModelProto
658
+ ) -> onnx.ModelProto:
659
+ onnx_helper.model_override_io_shapes(
660
+ pitch_pre, output_shapes={'pitch_cond': (1, 'n_frames', hparams['hidden_size'])}
661
+ )
662
+ pitch_pre, check = onnxsim.simplify(pitch_pre, include_subgraph=True)
663
+ assert check, 'Simplified ONNX model could not be validated'
664
+
665
+ onnx_helper.model_override_io_shapes(
666
+ pitch_predictor, output_shapes={'pitch_pred': (1, 'n_frames')}
667
+ )
668
+ print(f'Running ONNX Simplifier #1 on {self.pitch_predictor_class_name}...')
669
+ pitch_predictor, check = onnxsim.simplify(pitch_predictor, include_subgraph=True)
670
+ assert check, 'Simplified ONNX model could not be validated'
671
+ onnx_helper.graph_fold_back_to_squeeze(pitch_predictor.graph)
672
+ onnx_helper.graph_extract_conditioner_projections(
673
+ graph=pitch_predictor.graph, op_type='Conv',
674
+ weight_pattern=r'pitch_predictor\..*\.conditioner_projection\.weight',
675
+ alias_prefix='/pitch_predictor/backbone/cache'
676
+ )
677
+ onnx_helper.graph_remove_unused_values(pitch_predictor.graph)
678
+ print(f'Running ONNX Simplifier #2 on {self.pitch_predictor_class_name}...')
679
+ pitch_predictor, check = onnxsim.simplify(pitch_predictor, include_subgraph=True)
680
+ assert check, 'Simplified ONNX model could not be validated'
681
+
682
+ onnx_helper.model_add_prefixes(pitch_pre, node_prefix='/pre', ignored_pattern=r'.*embed.*')
683
+ onnx_helper.model_add_prefixes(pitch_pre, dim_prefix='pre.', ignored_pattern='(n_tokens)|(n_notes)|(n_frames)')
684
+ onnx_helper.model_add_prefixes(pitch_post, node_prefix='/post', ignored_pattern=None)
685
+ onnx_helper.model_add_prefixes(pitch_post, dim_prefix='post.', ignored_pattern='n_frames')
686
+ pitch_pre_diffusion = onnx.compose.merge_models(
687
+ pitch_pre, pitch_predictor, io_map=[('pitch_cond', 'pitch_cond')],
688
+ prefix1='', prefix2='', doc_string='',
689
+ producer_name=pitch_pre.producer_name, producer_version=pitch_pre.producer_version,
690
+ domain=pitch_pre.domain, model_version=pitch_pre.model_version
691
+ )
692
+ pitch_pre_diffusion.graph.name = pitch_pre.graph.name
693
+ pitch_predictor = onnx.compose.merge_models(
694
+ pitch_pre_diffusion, pitch_post, io_map=[
695
+ ('x_pred', 'x_pred'), ('base_pitch', 'base_pitch')
696
+ ], prefix1='', prefix2='', doc_string='',
697
+ producer_name=pitch_pre.producer_name, producer_version=pitch_pre.producer_version,
698
+ domain=pitch_pre.domain, model_version=pitch_pre.model_version
699
+ )
700
+ pitch_predictor.graph.name = pitch_pre.graph.name
701
+
702
+ print(f'| optimize graph: {self.pitch_predictor_class_name}')
703
+ return pitch_predictor
704
+
705
+ def _optimize_merge_variance_predictor_graph(
706
+ self, var_pre: onnx.ModelProto, var_diffusion: onnx.ModelProto, var_post: onnx.ModelProto
707
+ ):
708
+ onnx_helper.model_override_io_shapes(
709
+ var_pre, output_shapes={'variance_cond': (1, 'n_frames', hparams['hidden_size'])}
710
+ )
711
+ var_pre, check = onnxsim.simplify(var_pre, include_subgraph=True)
712
+ assert check, 'Simplified ONNX model could not be validated'
713
+
714
+ onnx_helper.model_override_io_shapes(
715
+ var_diffusion, output_shapes={
716
+ 'xs_pred': (1, 'n_frames')
717
+ if len(self.model.variance_prediction_list) == 1
718
+ else (1, len(self.model.variance_prediction_list), 'n_frames')
719
+ }
720
+ )
721
+ print(f'Running ONNX Simplifier #1 on'
722
+ f' {self.multi_var_predictor_class_name}...')
723
+ var_diffusion, check = onnxsim.simplify(var_diffusion, include_subgraph=True)
724
+ assert check, 'Simplified ONNX model could not be validated'
725
+ onnx_helper.graph_fold_back_to_squeeze(var_diffusion.graph)
726
+ onnx_helper.graph_extract_conditioner_projections(
727
+ graph=var_diffusion.graph, op_type='Conv',
728
+ weight_pattern=r'variance_predictor\..*\.conditioner_projection\.weight',
729
+ alias_prefix='/variance_predictor/backbone/cache'
730
+ )
731
+ onnx_helper.graph_remove_unused_values(var_diffusion.graph)
732
+ print(f'Running ONNX Simplifier #2 on {self.multi_var_predictor_class_name}...')
733
+ var_diffusion, check = onnxsim.simplify(var_diffusion, include_subgraph=True)
734
+ assert check, 'Simplified ONNX model could not be validated'
735
+
736
+ var_post, check = onnxsim.simplify(var_post, include_subgraph=True)
737
+ assert check, 'Simplified ONNX model could not be validated'
738
+
739
+ ignored_variance_names = '|'.join([f'({v_name})' for v_name in self.model.variance_prediction_list])
740
+ onnx_helper.model_add_prefixes(
741
+ var_pre, node_prefix='/pre', value_info_prefix='/pre', initializer_prefix='/pre',
742
+ ignored_pattern=fr'.*((embed)|{ignored_variance_names}).*'
743
+ )
744
+ onnx_helper.model_add_prefixes(var_pre, dim_prefix='pre.', ignored_pattern='(n_tokens)|(n_frames)')
745
+ onnx_helper.model_add_prefixes(
746
+ var_post, node_prefix='/post', value_info_prefix='/post', initializer_prefix='/post',
747
+ ignored_pattern=None
748
+ )
749
+ onnx_helper.model_add_prefixes(var_post, dim_prefix='post.', ignored_pattern='n_frames')
750
+
751
+ print(f'Merging {self.multi_var_predictor_class_name} subroutines...')
752
+ var_pre_diffusion = onnx.compose.merge_models(
753
+ var_pre, var_diffusion, io_map=[('variance_cond', 'variance_cond')],
754
+ prefix1='', prefix2='', doc_string='',
755
+ producer_name=var_pre.producer_name, producer_version=var_pre.producer_version,
756
+ domain=var_pre.domain, model_version=var_pre.model_version
757
+ )
758
+ var_pre_diffusion.graph.name = var_pre.graph.name
759
+ var_predictor = onnx.compose.merge_models(
760
+ var_pre_diffusion, var_post, io_map=[('xs_pred', 'xs_pred')],
761
+ prefix1='', prefix2='', doc_string='',
762
+ producer_name=var_pre.producer_name, producer_version=var_pre.producer_version,
763
+ domain=var_pre.domain, model_version=var_pre.model_version
764
+ )
765
+ var_predictor.graph.name = var_pre.graph.name
766
+ return var_predictor
767
+
768
+ # noinspection PyMethodMayBeStatic
769
+ def _export_spk_embed(self, path: Path, spk_embed: torch.Tensor):
770
+ with open(path, 'wb') as f:
771
+ f.write(spk_embed.cpu().numpy().tobytes())
772
+ print(f'| export spk embed => {path}')
773
+
774
+ # noinspection PyMethodMayBeStatic
775
+ def _export_dictionary(self, path: Path):
776
+ print(f'| export dictionary => {path}')
777
+ shutil.copy(locate_dictionary(), path)
778
+
779
+ def _export_phonemes(self, path: Path):
780
+ self.vocab.store_to_file(path)
781
+ print(f'| export phonemes => {path}')
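For orientation, a minimal sketch of driving this exporter from Python, assuming `hparams` has already been populated for the target experiment (the constructor reads `exp_name`, `work_dir`, `use_spk_id`, and friends) and with placeholder paths:

```python
from pathlib import Path

from deployment.exporters.variance_exporter import DiffSingerVarianceExporter

# Placeholder paths; hparams must be initialized before this point.
exporter = DiffSingerVarianceExporter(
    device='cpu',
    cache_dir=Path('deployment/cache'),
    ckpt_steps=None,  # None usually means: pick the latest checkpoint in work_dir
)
exporter.export(Path('artifacts/my_variance_model'))
```

As implemented above, `export()` writes the merged `*.linguistic.onnx`, `*.dur.onnx`, `*.pitch.onnx` and `*.variance.onnx` graphs plus `dictionary.txt`, the phoneme list and `dsconfig.yaml` into the target directory.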
deployment/modules/__init__.py ADDED
File without changes
deployment/modules/diffusion.py ADDED
@@ -0,0 +1,220 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ from modules.core import (
9
+ GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion
10
+ )
11
+
12
+
13
+ def extract(a, t):
14
+ return a[t].reshape((1, 1, 1, 1))
15
+
16
+
17
+ # noinspection PyMethodOverriding
18
+ class GaussianDiffusionONNX(GaussianDiffusion):
19
+ @property
20
+ def backbone(self):
21
+ return self.denoise_fn
22
+
23
+ # The setter for the property `backbone` is omitted because it would cause TorchScript to fail
24
+ # @backbone.setter
25
+ @torch.jit.unused
26
+ def set_backbone(self, value):
27
+ self.denoise_fn = value
28
+
29
+ def q_sample(self, x_start, t, noise):
30
+ return (
31
+ extract(self.sqrt_alphas_cumprod, t) * x_start +
32
+ extract(self.sqrt_one_minus_alphas_cumprod, t) * noise
33
+ )
34
+
35
+ def p_sample(self, x, t, cond):
36
+ x_pred = self.denoise_fn(x, t, cond)
37
+ x_recon = (
38
+ extract(self.sqrt_recip_alphas_cumprod, t) * x -
39
+ extract(self.sqrt_recipm1_alphas_cumprod, t) * x_pred
40
+ )
41
+ # The clamp below was inherited from the original DiffSinger repository
42
+ # but is disabled here due to loudness issues when speedup = 1.
43
+ # x_recon = torch.clamp(x_recon, min=-1., max=1.)
44
+
45
+ model_mean = (
46
+ extract(self.posterior_mean_coef1, t) * x_recon +
47
+ extract(self.posterior_mean_coef2, t) * x
48
+ )
49
+ model_log_variance = extract(self.posterior_log_variance_clipped, t)
50
+ noise = torch.randn_like(x)
51
+ # no noise when t == 0
52
+ nonzero_mask = ((t > 0).float()).reshape(1, 1, 1, 1)
53
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
54
+
55
+ def p_sample_ddim(self, x, t, interval: int, cond):
56
+ a_t = extract(self.alphas_cumprod, t)
57
+ t_prev = t - interval
58
+ a_prev = extract(self.alphas_cumprod, t_prev * (t_prev > 0))
59
+
60
+ noise_pred = self.denoise_fn(x, t, cond=cond)
61
+ x_prev = a_prev.sqrt() * (
62
+ x / a_t.sqrt() + (((1 - a_prev) / a_prev).sqrt() - ((1 - a_t) / a_t).sqrt()) * noise_pred
63
+ )
64
+ return x_prev
65
+
66
+ def plms_get_x_pred(self, x, noise_t, t, t_prev):
67
+ a_t = extract(self.alphas_cumprod, t)
68
+ a_prev = extract(self.alphas_cumprod, t_prev)
69
+ a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
70
+
71
+ x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (
72
+ a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
73
+ x_pred = x + x_delta
74
+
75
+ return x_pred
76
+
77
+ def p_sample_plms(self, x_prev, t, interval: int, cond, noise_list: List[Tensor], stage: int):
78
+ noise_pred = self.denoise_fn(x_prev, t, cond)
79
+ t_prev = t - interval
80
+ t_prev = t_prev * (t_prev > 0)
81
+ if stage == 0:
82
+ x_pred = self.plms_get_x_pred(x_prev, noise_pred, t, t_prev)
83
+ noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond)
84
+ noise_pred_prime = (noise_pred + noise_pred_prev) / 2.
85
+ elif stage == 1:
86
+ noise_pred_prime = (3. * noise_pred - noise_list[-1]) / 2.
87
+ elif stage == 2:
88
+ noise_pred_prime = (23. * noise_pred - 16. * noise_list[-1] + 5. * noise_list[-2]) / 12.
89
+ else:
90
+ noise_pred_prime = (55. * noise_pred - 59. * noise_list[-1] + 37.
91
+ * noise_list[-2] - 9. * noise_list[-3]) / 24.
92
+ x_prev = self.plms_get_x_pred(x_prev, noise_pred_prime, t, t_prev)
93
+ return noise_pred, x_prev
94
+
95
+ def norm_spec(self, x):
96
+ k = (self.spec_max - self.spec_min) / 2.
97
+ b = (self.spec_max + self.spec_min) / 2.
98
+ return (x - b) / k
99
+
100
+ def denorm_spec(self, x):
101
+ k = (self.spec_max - self.spec_min) / 2.
102
+ b = (self.spec_max + self.spec_min) / 2.
103
+ return x * k + b
104
+
105
+ def forward(self, condition, x_start=None, depth=None, steps: int = 10):
106
+ condition = condition.transpose(1, 2) # [1, T, H] => [1, H, T]
107
+ device = condition.device
108
+ n_frames = condition.shape[2]
109
+
110
+ noise = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
111
+ if x_start is None:
112
+ speedup = max(1, self.timesteps // steps)
113
+ speedup = self.timestep_factors[torch.sum(self.timestep_factors <= speedup) - 1]
114
+ step_range = torch.arange(0, self.k_step, speedup, dtype=torch.long, device=device).flip(0)[:, None]
115
+ x = noise
116
+ else:
117
+ depth_int64 = min(torch.round(depth * self.timesteps).long(), self.k_step)
118
+ speedup = max(1, depth_int64 // steps)
119
+ depth_int64 = depth_int64 // speedup * speedup # make depth_int64 a multiple of speedup
120
+ step_range = torch.arange(0, depth_int64, speedup, dtype=torch.long, device=device).flip(0)[:, None]
121
+ x_start = self.norm_spec(x_start).transpose(-2, -1)
122
+ if self.num_feats == 1:
123
+ x_start = x_start[:, None, :, :]
124
+ if depth_int64 >= self.timesteps:
125
+ x = noise
126
+ elif depth_int64 > 0:
127
+ x = self.q_sample(
128
+ x_start, torch.full((1,), depth_int64 - 1, device=device, dtype=torch.long), noise
129
+ )
130
+ else:
131
+ x = x_start
132
+
133
+ if speedup > 1:
134
+ for t in step_range:
135
+ x = self.p_sample_ddim(x, t, interval=speedup, cond=condition)
136
+ # plms_noise_stage: int = 0
137
+ # noise_list: List[Tensor] = []
138
+ # for t in step_range:
139
+ # noise_pred, x = self.p_sample_plms(
140
+ # x, t, interval=speedup, cond=condition,
141
+ # noise_list=noise_list, stage=plms_noise_stage
142
+ # )
143
+ # if plms_noise_stage == 0:
144
+ # noise_list = [noise_pred]
145
+ # plms_noise_stage = plms_noise_stage + 1
146
+ # else:
147
+ # if plms_noise_stage >= 3:
148
+ # noise_list.pop(0)
149
+ # else:
150
+ # plms_noise_stage = plms_noise_stage + 1
151
+ # noise_list.append(noise_pred)
152
+ else:
153
+ for t in step_range:
154
+ x = self.p_sample(x, t, cond=condition)
155
+
156
+ if self.num_feats == 1:
157
+ x = x.squeeze(1).permute(0, 2, 1) # [B, 1, M, T] => [B, T, M]
158
+ else:
159
+ x = x.permute(0, 1, 3, 2) # [B, F, M, T] => [B, F, T, M]
160
+ x = self.denorm_spec(x)
161
+ return x
162
+
163
+
164
+ class PitchDiffusionONNX(GaussianDiffusionONNX, PitchDiffusion):
165
+ def __init__(self, vmin: float, vmax: float,
166
+ cmin: float, cmax: float, repeat_bins,
167
+ timesteps=1000, k_step=1000,
168
+ backbone_type=None, backbone_args=None,
169
+ betas=None):
170
+ self.vmin = vmin
171
+ self.vmax = vmax
172
+ self.cmin = cmin
173
+ self.cmax = cmax
174
+ super(PitchDiffusion, self).__init__(
175
+ vmin=vmin, vmax=vmax, repeat_bins=repeat_bins,
176
+ timesteps=timesteps, k_step=k_step,
177
+ backbone_type=backbone_type, backbone_args=backbone_args,
178
+ betas=betas
179
+ )
180
+
181
+ def clamp_spec(self, x):
182
+ return x.clamp(min=self.cmin, max=self.cmax)
183
+
184
+ def denorm_spec(self, x):
185
+ d = (self.spec_max - self.spec_min) / 2.
186
+ m = (self.spec_max + self.spec_min) / 2.
187
+ x = x * d + m
188
+ x = x.mean(dim=-1)
189
+ return x
190
+
191
+
192
+ class MultiVarianceDiffusionONNX(GaussianDiffusionONNX, MultiVarianceDiffusion):
193
+ def __init__(
194
+ self, ranges: List[Tuple[float, float]],
195
+ clamps: List[Tuple[float | None, float | None] | None],
196
+ repeat_bins, timesteps=1000, k_step=1000,
197
+ backbone_type=None, backbone_args=None,
198
+ betas=None
199
+ ):
200
+ assert len(ranges) == len(clamps)
201
+ self.clamps = clamps
202
+ vmin = [r[0] for r in ranges]
203
+ vmax = [r[1] for r in ranges]
204
+ if len(vmin) == 1:
205
+ vmin = vmin[0]
206
+ if len(vmax) == 1:
207
+ vmax = vmax[0]
208
+ super(MultiVarianceDiffusion, self).__init__(
209
+ vmin=vmin, vmax=vmax, repeat_bins=repeat_bins,
210
+ timesteps=timesteps, k_step=k_step,
211
+ backbone_type=backbone_type, backbone_args=backbone_args,
212
+ betas=betas
213
+ )
214
+
215
+ def denorm_spec(self, x):
216
+ d = (self.spec_max - self.spec_min) / 2.
217
+ m = (self.spec_max + self.spec_min) / 2.
218
+ x = x * d + m
219
+ x = x.mean(dim=-1)
220
+ return x
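For readers mapping the sampling code above back to the diffusion literature: with $\bar\alpha_t$ = `alphas_cumprod[t]` and $\epsilon_\theta$ the denoising backbone, `q_sample` is the closed-form forward step and `p_sample_ddim` the deterministic DDIM update with skip interval $\Delta$, matching the arithmetic in the code:

$$
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon
$$

$$
x_{t-\Delta} = \sqrt{\bar\alpha_{t-\Delta}}\left(\frac{x_t}{\sqrt{\bar\alpha_t}}
+ \left(\sqrt{\frac{1-\bar\alpha_{t-\Delta}}{\bar\alpha_{t-\Delta}}}
- \sqrt{\frac{1-\bar\alpha_t}{\bar\alpha_t}}\right)\epsilon_\theta(x_t, t)\right)
$$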
deployment/modules/fastspeech2.py ADDED
@@ -0,0 +1,153 @@
1
+ import copy
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from modules.commons.common_layers import NormalInitEmbedding as Embedding
9
+ from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
10
+ from modules.fastspeech.variance_encoder import FastSpeech2Variance
11
+ from utils.hparams import hparams
12
+ from utils.text_encoder import PAD_INDEX
13
+
14
+ f0_bin = 256
15
+ f0_max = 1100.0
16
+ f0_min = 50.0
17
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
18
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
19
+
20
+
21
+ def f0_to_coarse(f0):
22
+ f0_mel = 1127 * (1 + f0 / 700).log()
23
+ a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
24
+ b = f0_mel_min * a - 1.
25
+ f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
26
+ torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
27
+ f0_coarse = torch.round(f0_mel).long()
28
+ return f0_coarse
29
+
30
+
31
+ class LengthRegulator(nn.Module):
32
+ # noinspection PyMethodMayBeStatic
33
+ def forward(self, dur):
34
+ token_idx = torch.arange(1, dur.shape[1] + 1, device=dur.device)[None, :, None]
35
+ dur_cumsum = torch.cumsum(dur, dim=1)
36
+ dur_cumsum_prev = F.pad(dur_cumsum, (1, -1), mode='constant', value=0)
37
+ pos_idx = torch.arange(dur.sum(dim=1).max(), device=dur.device)[None, None]
38
+ token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
39
+ mel2ph = (token_idx * token_mask).sum(dim=1)
40
+ return mel2ph
41
+
42
+
43
+ class FastSpeech2AcousticONNX(FastSpeech2Acoustic):
44
+ def __init__(self, vocab_size):
45
+ super().__init__(vocab_size=vocab_size)
46
+
47
+ # for temporary compatibility; will be completely removed in the future
48
+ self.f0_embed_type = hparams.get('f0_embed_type', 'continuous')
49
+ if self.f0_embed_type == 'discrete':
50
+ self.pitch_embed = Embedding(300, hparams['hidden_size'], PAD_INDEX)
51
+
52
+ self.lr = LengthRegulator()
53
+ if hparams['use_key_shift_embed']:
54
+ self.shift_min, self.shift_max = hparams['augmentation_args']['random_pitch_shifting']['range']
55
+ if hparams['use_speed_embed']:
56
+ self.speed_min, self.speed_max = hparams['augmentation_args']['random_time_stretching']['range']
57
+
58
+ # noinspection PyMethodOverriding
59
+ def forward(self, tokens, durations, f0, variances: dict, gender=None, velocity=None, spk_embed=None):
60
+ txt_embed = self.txt_embed(tokens)
61
+ durations = durations * (tokens > 0)
62
+ mel2ph = self.lr(durations)
63
+ f0 = f0 * (mel2ph > 0)
64
+ mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size']))
65
+ dur_embed = self.dur_embed(durations.float()[:, :, None])
66
+ encoded = self.encoder(txt_embed, dur_embed, tokens == PAD_INDEX)
67
+ encoded = F.pad(encoded, (0, 0, 1, 0))
68
+ condition = torch.gather(encoded, 1, mel2ph)
69
+
70
+ if self.f0_embed_type == 'discrete':
71
+ pitch = f0_to_coarse(f0)
72
+ pitch_embed = self.pitch_embed(pitch)
73
+ else:
74
+ f0_mel = (1 + f0 / 700).log()
75
+ pitch_embed = self.pitch_embed(f0_mel[:, :, None])
76
+ condition += pitch_embed
77
+
78
+ if self.use_variance_embeds:
79
+ variance_embeds = torch.stack([
80
+ self.variance_embeds[v_name](variances[v_name][:, :, None])
81
+ for v_name in self.variance_embed_list
82
+ ], dim=-1).sum(-1)
83
+ condition += variance_embeds
84
+
85
+ if hparams['use_key_shift_embed']:
86
+ if hasattr(self, 'frozen_key_shift'):
87
+ key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None])
88
+ else:
89
+ gender = torch.clip(gender, min=-1., max=1.)
90
+ gender_mask = (gender < 0.).float()
91
+ key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min))
92
+ key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
93
+ condition += key_shift_embed
94
+
95
+ if hparams['use_speed_embed']:
96
+ if velocity is not None:
97
+ velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max)
98
+ speed_embed = self.speed_embed(velocity[:, :, None])
99
+ else:
100
+ speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None])
101
+ condition += speed_embed
102
+
103
+ if hparams['use_spk_id']:
104
+ if hasattr(self, 'frozen_spk_embed'):
105
+ condition += self.frozen_spk_embed
106
+ else:
107
+ condition += spk_embed
108
+ return condition
109
+
110
+
111
+ class FastSpeech2VarianceONNX(FastSpeech2Variance):
112
+ def __init__(self, vocab_size):
113
+ super().__init__(vocab_size=vocab_size)
114
+ self.lr = LengthRegulator()
115
+
116
+ def forward_encoder_word(self, tokens, word_div, word_dur):
117
+ txt_embed = self.txt_embed(tokens)
118
+ ph2word = self.lr(word_div)
119
+ onset = ph2word > F.pad(ph2word, [1, -1])
120
+ onset_embed = self.onset_embed(onset.long())
121
+ ph_word_dur = torch.gather(F.pad(word_dur, [1, 0]), 1, ph2word)
122
+ word_dur_embed = self.word_dur_embed(ph_word_dur.float()[:, :, None])
123
+ x_masks = tokens == PAD_INDEX
124
+ return self.encoder(txt_embed, onset_embed + word_dur_embed, x_masks), x_masks
125
+
126
+ def forward_encoder_phoneme(self, tokens, ph_dur):
127
+ txt_embed = self.txt_embed(tokens)
128
+ ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
129
+ x_masks = tokens == PAD_INDEX
130
+ return self.encoder(txt_embed, ph_dur_embed, x_masks), x_masks
131
+
132
+ def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
133
+ midi_embed = self.midi_embed(ph_midi)
134
+ dur_cond = encoder_out + midi_embed
135
+ if hparams['use_spk_id'] and spk_embed is not None:
136
+ dur_cond += spk_embed
137
+ ph_dur = self.dur_predictor(dur_cond, x_masks=x_masks)
138
+ return ph_dur
139
+
140
+ def view_as_encoder(self):
141
+ model = copy.deepcopy(self)
142
+ if self.predict_dur:
143
+ del model.dur_predictor
144
+ model.forward = model.forward_encoder_word
145
+ else:
146
+ model.forward = model.forward_encoder_phoneme
147
+ return model
148
+
149
+ def view_as_dur_predictor(self):
150
+ model = copy.deepcopy(self)
151
+ del model.encoder
152
+ model.forward = model.forward_dur_predictor
153
+ return model
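The `LengthRegulator` above turns per-token frame durations into a frame-to-token index map (`mel2ph`), which the encoders use to gather token-level features onto the frame axis. A tiny illustrative example, assuming the repository is on `PYTHONPATH`:

```python
import torch

from deployment.modules.fastspeech2 import LengthRegulator

# Durations [2, 3, 1] (in frames) for three tokens expand into a 6-frame map
# holding the 1-based index of the token each frame belongs to.
lr = LengthRegulator()
print(lr(torch.tensor([[2, 3, 1]])))  # tensor([[1, 1, 2, 2, 2, 3]])
```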
deployment/modules/nsf_hifigan.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
2
+
3
+ from modules.nsf_hifigan.env import AttrDict
4
+ from modules.nsf_hifigan.models import Generator
5
+
6
+
7
+ # noinspection SpellCheckingInspection
8
+ class NSFHiFiGANONNX(torch.nn.Module):
9
+ def __init__(self, attrs: dict):
10
+ super().__init__()
11
+ self.generator = Generator(AttrDict(attrs))
12
+
13
+ def forward(self, mel: torch.Tensor, f0: torch.Tensor):
14
+ mel = mel.transpose(1, 2)
15
+ wav = self.generator(mel, f0)
16
+ return wav.squeeze(1)
deployment/modules/rectified_flow.py ADDED
@@ -0,0 +1,123 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+
7
+ from modules.core import (
8
+ RectifiedFlow, PitchRectifiedFlow, MultiVarianceRectifiedFlow
9
+ )
10
+
11
+
12
+ class RectifiedFlowONNX(RectifiedFlow):
13
+ @property
14
+ def backbone(self):
15
+ return self.velocity_fn
16
+
17
+ # The setter for the property `backbone` is omitted because it would cause TorchScript to fail
18
+ # @backbone.setter
19
+ @torch.jit.unused
20
+ def set_backbone(self, value):
21
+ self.velocity_fn = value
22
+
23
+ def sample_euler(self, x, t, dt: float, cond):
24
+ x += self.velocity_fn(x, t * self.time_scale_factor, cond) * dt
25
+ return x
26
+
27
+ def norm_spec(self, x):
28
+ k = (self.spec_max - self.spec_min) / 2.
29
+ b = (self.spec_max + self.spec_min) / 2.
30
+ return (x - b) / k
31
+
32
+ def denorm_spec(self, x):
33
+ k = (self.spec_max - self.spec_min) / 2.
34
+ b = (self.spec_max + self.spec_min) / 2.
35
+ return x * k + b
36
+
37
+ def forward(self, condition, x_end=None, depth=None, steps: int = 10):
38
+ condition = condition.transpose(1, 2) # [1, T, H] => [1, H, T]
39
+ device = condition.device
40
+ n_frames = condition.shape[2]
41
+ noise = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
42
+ if x_end is None:
43
+ t_start = 0.
44
+ x = noise
45
+ else:
46
+ t_start = torch.max(1 - depth, torch.tensor(self.t_start, dtype=torch.float32, device=device))
47
+ x_end = self.norm_spec(x_end).transpose(-2, -1)
48
+ if self.num_feats == 1:
49
+ x_end = x_end[:, None, :, :]
50
+ if t_start <= 0.:
51
+ x = noise
52
+ elif t_start >= 1.:
53
+ x = x_end
54
+ else:
55
+ x = t_start * x_end + (1 - t_start) * noise
56
+
57
+ t_width = 1. - t_start
58
+ if t_width >= 0.:
59
+ dt = t_width / max(1, steps)
60
+ for t in torch.arange(steps, dtype=torch.long, device=device)[:, None].float() * dt + t_start:
61
+ x = self.sample_euler(x, t, dt, condition)
62
+
63
+ if self.num_feats == 1:
64
+ x = x.squeeze(1).permute(0, 2, 1) # [B, 1, M, T] => [B, T, M]
65
+ else:
66
+ x = x.permute(0, 1, 3, 2) # [B, F, M, T] => [B, F, T, M]
67
+ x = self.denorm_spec(x)
68
+ return x
69
+
70
+
71
+ class PitchRectifiedFlowONNX(RectifiedFlowONNX, PitchRectifiedFlow):
72
+ def __init__(self, vmin: float, vmax: float,
73
+ cmin: float, cmax: float, repeat_bins,
74
+ time_scale_factor=1000,
75
+ backbone_type=None, backbone_args=None):
76
+ self.vmin = vmin
77
+ self.vmax = vmax
78
+ self.cmin = cmin
79
+ self.cmax = cmax
80
+ super(PitchRectifiedFlow, self).__init__(
81
+ vmin=vmin, vmax=vmax, repeat_bins=repeat_bins,
82
+ time_scale_factor=time_scale_factor,
83
+ backbone_type=backbone_type, backbone_args=backbone_args
84
+ )
85
+
86
+ def clamp_spec(self, x):
87
+ return x.clamp(min=self.cmin, max=self.cmax)
88
+
89
+ def denorm_spec(self, x):
90
+ d = (self.spec_max - self.spec_min) / 2.
91
+ m = (self.spec_max + self.spec_min) / 2.
92
+ x = x * d + m
93
+ x = x.mean(dim=-1)
94
+ return x
95
+
96
+
97
+ class MultiVarianceRectifiedFlowONNX(RectifiedFlowONNX, MultiVarianceRectifiedFlow):
98
+ def __init__(
99
+ self, ranges: List[Tuple[float, float]],
100
+ clamps: List[Tuple[float | None, float | None] | None],
101
+ repeat_bins, time_scale_factor=1000,
102
+ backbone_type=None, backbone_args=None
103
+ ):
104
+ assert len(ranges) == len(clamps)
105
+ self.clamps = clamps
106
+ vmin = [r[0] for r in ranges]
107
+ vmax = [r[1] for r in ranges]
108
+ if len(vmin) == 1:
109
+ vmin = vmin[0]
110
+ if len(vmax) == 1:
111
+ vmax = vmax[0]
112
+ super(MultiVarianceRectifiedFlow, self).__init__(
113
+ vmin=vmin, vmax=vmax, repeat_bins=repeat_bins,
114
+ time_scale_factor=time_scale_factor,
115
+ backbone_type=backbone_type, backbone_args=backbone_args
116
+ )
117
+
118
+ def denorm_spec(self, x):
119
+ d = (self.spec_max - self.spec_min) / 2.
120
+ m = (self.spec_max + self.spec_min) / 2.
121
+ x = x * d + m
122
+ x = x.mean(dim=-1)
123
+ return x
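`sample_euler` above is a fixed-step Euler integration of the learned velocity field: with $s$ = `time_scale_factor`, $v_\theta$ the backbone and $\Delta t = (1 - t_{\text{start}})/\text{steps}$, each step in `forward` computes

$$
x_{t+\Delta t} = x_t + v_\theta(x_t,\, s\,t)\,\Delta t,
\qquad t = t_{\text{start}},\ t_{\text{start}}+\Delta t,\ \dots
$$

so a larger `steps` value trades inference time for a finer ODE discretization.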
deployment/modules/toplevel.py ADDED
@@ -0,0 +1,392 @@
1
+ import copy
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+
9
+ from deployment.modules.diffusion import (
10
+ GaussianDiffusionONNX, PitchDiffusionONNX, MultiVarianceDiffusionONNX
11
+ )
12
+ from deployment.modules.rectified_flow import (
13
+ RectifiedFlowONNX, PitchRectifiedFlowONNX, MultiVarianceRectifiedFlowONNX
14
+ )
15
+ from deployment.modules.fastspeech2 import FastSpeech2AcousticONNX, FastSpeech2VarianceONNX
16
+ from modules.toplevel import DiffSingerAcoustic, DiffSingerVariance
17
+ from utils.hparams import hparams
18
+
19
+
20
+ class DiffSingerAcousticONNX(DiffSingerAcoustic):
21
+ def __init__(self, vocab_size, out_dims):
22
+ super().__init__(vocab_size, out_dims)
23
+ del self.fs2
24
+ del self.diffusion
25
+ self.fs2 = FastSpeech2AcousticONNX(
26
+ vocab_size=vocab_size
27
+ )
28
+ if self.diffusion_type == 'ddpm':
29
+ self.diffusion = GaussianDiffusionONNX(
30
+ out_dims=out_dims,
31
+ num_feats=1,
32
+ timesteps=hparams['timesteps'],
33
+ k_step=hparams['K_step'],
34
+ backbone_type=self.backbone_type,
35
+ backbone_args=self.backbone_args,
36
+ spec_min=hparams['spec_min'],
37
+ spec_max=hparams['spec_max']
38
+ )
39
+ elif self.diffusion_type == 'reflow':
40
+ self.diffusion = RectifiedFlowONNX(
41
+ out_dims=out_dims,
42
+ num_feats=1,
43
+ t_start=hparams['T_start'],
44
+ time_scale_factor=hparams['time_scale_factor'],
45
+ backbone_type=self.backbone_type,
46
+ backbone_args=self.backbone_args,
47
+ spec_min=hparams['spec_min'],
48
+ spec_max=hparams['spec_max']
49
+ )
50
+ else:
51
+ raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
52
+ self.mel_base = hparams.get('mel_base', '10')
53
+
54
+ def ensure_mel_base(self, mel):
55
+ if self.mel_base != 'e':
56
+ # log10 mel to log mel
57
+ mel = mel * 2.30259
58
+ return mel
59
+
60
+ def forward_fs2_aux(
61
+ self,
62
+ tokens: Tensor,
63
+ durations: Tensor,
64
+ f0: Tensor,
65
+ variances: dict,
66
+ gender: Tensor = None,
67
+ velocity: Tensor = None,
68
+ spk_embed: Tensor = None
69
+ ):
70
+ condition = self.fs2(
71
+ tokens, durations, f0, variances=variances,
72
+ gender=gender, velocity=velocity, spk_embed=spk_embed
73
+ )
74
+ if self.use_shallow_diffusion:
75
+ aux_mel_pred = self.aux_decoder(condition, infer=True)
76
+ return condition, aux_mel_pred
77
+ else:
78
+ return condition
79
+
80
+ def forward_shallow_diffusion(
81
+ self, condition: Tensor, x_start: Tensor,
82
+ depth, steps: int
83
+ ) -> Tensor:
84
+ mel_pred = self.diffusion(condition, x_start=x_start, depth=depth, steps=steps)
85
+ return self.ensure_mel_base(mel_pred)
86
+
87
+ def forward_diffusion(self, condition: Tensor, steps: int):
88
+ mel_pred = self.diffusion(condition, steps=steps)
89
+ return self.ensure_mel_base(mel_pred)
90
+
91
+ def forward_shallow_reflow(
92
+ self, condition: Tensor, x_end: Tensor,
93
+ depth, steps: int
94
+ ):
95
+ mel_pred = self.diffusion(condition, x_end=x_end, depth=depth, steps=steps)
96
+ return self.ensure_mel_base(mel_pred)
97
+
98
+ def forward_reflow(self, condition: Tensor, steps: int):
99
+ mel_pred = self.diffusion(condition, steps=steps)
100
+ return self.ensure_mel_base(mel_pred)
101
+
102
+ def view_as_fs2_aux(self) -> nn.Module:
103
+ model = copy.deepcopy(self)
104
+ del model.diffusion
105
+ model.forward = model.forward_fs2_aux
106
+ return model
107
+
108
+ def view_as_diffusion(self) -> nn.Module:
109
+ model = copy.deepcopy(self)
110
+ del model.fs2
111
+ if self.use_shallow_diffusion:
112
+ del model.aux_decoder
113
+ model.forward = model.forward_shallow_diffusion
114
+ else:
115
+ model.forward = model.forward_diffusion
116
+ return model
117
+
118
+ def view_as_reflow(self) -> nn.Module:
119
+ model = copy.deepcopy(self)
120
+ del model.fs2
121
+ if self.use_shallow_diffusion:
122
+ del model.aux_decoder
123
+ model.forward = model.forward_shallow_reflow
124
+ else:
125
+ model.forward = model.forward_reflow
126
+ return model
127
+
128
+
129
+ class DiffSingerVarianceONNX(DiffSingerVariance):
130
+ def __init__(self, vocab_size):
131
+ super().__init__(vocab_size=vocab_size)
132
+ del self.fs2
133
+ self.fs2 = FastSpeech2VarianceONNX(
134
+ vocab_size=vocab_size
135
+ )
136
+ self.hidden_size = hparams['hidden_size']
137
+ if self.predict_pitch:
138
+ del self.pitch_predictor
139
+ self.smooth: nn.Conv1d = None
140
+ pitch_hparams = hparams['pitch_prediction_args']
141
+ if self.diffusion_type == 'ddpm':
142
+ self.pitch_predictor = PitchDiffusionONNX(
143
+ vmin=pitch_hparams['pitd_norm_min'],
144
+ vmax=pitch_hparams['pitd_norm_max'],
145
+ cmin=pitch_hparams['pitd_clip_min'],
146
+ cmax=pitch_hparams['pitd_clip_max'],
147
+ repeat_bins=pitch_hparams['repeat_bins'],
148
+ timesteps=hparams['timesteps'],
149
+ k_step=hparams['K_step'],
150
+ backbone_type=self.pitch_backbone_type,
151
+ backbone_args=self.pitch_backbone_args
152
+ )
153
+ elif self.diffusion_type == 'reflow':
154
+ self.pitch_predictor = PitchRectifiedFlowONNX(
155
+ vmin=pitch_hparams['pitd_norm_min'],
156
+ vmax=pitch_hparams['pitd_norm_max'],
157
+ cmin=pitch_hparams['pitd_clip_min'],
158
+ cmax=pitch_hparams['pitd_clip_max'],
159
+ repeat_bins=pitch_hparams['repeat_bins'],
160
+ time_scale_factor=hparams['time_scale_factor'],
161
+ backbone_type=self.pitch_backbone_type,
162
+ backbone_args=self.pitch_backbone_args
163
+ )
164
+ else:
165
+ raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
166
+ if self.predict_variances:
167
+ del self.variance_predictor
168
+ if self.diffusion_type == 'ddpm':
169
+ self.variance_predictor = self.build_adaptor(cls=MultiVarianceDiffusionONNX)
170
+ elif self.diffusion_type == 'reflow':
171
+ self.variance_predictor = self.build_adaptor(cls=MultiVarianceRectifiedFlowONNX)
172
+ else:
173
+ raise NotImplementedError(self.diffusion_type)
174
+
175
+ def build_smooth_op(self, device):
176
+ smooth_kernel_size = round(hparams['midi_smooth_width'] * hparams['audio_sample_rate'] / hparams['hop_size'])
177
+ smooth = nn.Conv1d(
178
+ in_channels=1,
179
+ out_channels=1,
180
+ kernel_size=smooth_kernel_size,
181
+ bias=False,
182
+ padding='same',
183
+ padding_mode='replicate'
184
+ ).eval()
185
+ smooth_kernel = torch.sin(torch.from_numpy(
186
+ np.linspace(0, 1, smooth_kernel_size).astype(np.float32) * np.pi
187
+ ))
188
+ smooth_kernel /= smooth_kernel.sum()
189
+ smooth.weight.data = smooth_kernel[None, None]
190
+ self.smooth = smooth.to(device)
191
+
192
+ def embed_frozen_spk(self, encoder_out):
193
+ if hparams['use_spk_id'] and hasattr(self, 'frozen_spk_embed'):
194
+ encoder_out += self.frozen_spk_embed
195
+ return encoder_out
196
+
197
+ def forward_linguistic_encoder_word(self, tokens, word_div, word_dur):
198
+ encoder_out, x_masks = self.fs2.forward_encoder_word(tokens, word_div, word_dur)
199
+ encoder_out = self.embed_frozen_spk(encoder_out)
200
+ return encoder_out, x_masks
201
+
202
+ def forward_linguistic_encoder_phoneme(self, tokens, ph_dur):
203
+ encoder_out, x_masks = self.fs2.forward_encoder_phoneme(tokens, ph_dur)
204
+ encoder_out = self.embed_frozen_spk(encoder_out)
205
+ return encoder_out, x_masks
206
+
207
+ def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
208
+ return self.fs2.forward_dur_predictor(encoder_out, x_masks, ph_midi, spk_embed=spk_embed)
209
+
210
+ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):
211
+ mel2x = self.lr(x_dur)
212
+ if x_dim is not None:
213
+ x_src = F.pad(x_src, [0, 0, 1, 0])
214
+ mel2x = mel2x[..., None].repeat([1, 1, x_dim])
215
+ else:
216
+ x_src = F.pad(x_src, [1, 0])
217
+ x_cond = torch.gather(x_src, 1, mel2x)
218
+ return x_cond
219
+
220
+ def forward_pitch_preprocess(
221
+ self, encoder_out, ph_dur,
222
+ note_midi=None, note_rest=None, note_dur=None, note_glide=None,
223
+ pitch=None, expr=None, retake=None, spk_embed=None
224
+ ):
225
+ condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
226
+ if self.use_melody_encoder:
227
+ if self.melody_encoder.use_glide_embed and note_glide is None:
228
+ note_glide = torch.LongTensor([[0]]).to(encoder_out.device)
229
+ melody_encoder_out = self.melody_encoder(
230
+ note_midi, note_rest, note_dur,
231
+ glide=note_glide
232
+ )
233
+ melody_encoder_out = self.forward_mel2x_gather(melody_encoder_out, note_dur, x_dim=self.hidden_size)
234
+ condition += melody_encoder_out
235
+ if expr is None:
236
+ retake_embed = self.pitch_retake_embed(retake.long())
237
+ else:
238
+ retake_true_embed = self.pitch_retake_embed(
239
+ torch.ones(1, 1, dtype=torch.long, device=encoder_out.device)
240
+ ) # [B=1, T=1] => [B=1, T=1, H]
241
+ retake_false_embed = self.pitch_retake_embed(
242
+ torch.zeros(1, 1, dtype=torch.long, device=encoder_out.device)
243
+ ) # [B=1, T=1] => [B=1, T=1, H]
244
+ expr = (expr * retake)[:, :, None] # [B, T, 1]
245
+ retake_embed = expr * retake_true_embed + (1. - expr) * retake_false_embed
246
+ pitch_cond = condition + retake_embed
247
+ frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None)
248
+ base_pitch = self.smooth(frame_midi_pitch)
249
+ if self.use_melody_encoder:
250
+ delta_pitch = (pitch - base_pitch) * ~retake
251
+ pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None])
252
+ else:
253
+ base_pitch = base_pitch * retake + pitch * ~retake
254
+ pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
255
+ if hparams['use_spk_id'] and spk_embed is not None:
256
+ pitch_cond += spk_embed
257
+ return pitch_cond, base_pitch
258
+
259
+ def forward_pitch_reflow(
260
+ self, pitch_cond, steps: int = 10
261
+ ):
262
+ x_pred = self.pitch_predictor(pitch_cond, steps=steps)
263
+ return x_pred
264
+
265
+ def forward_pitch_postprocess(self, x_pred, base_pitch):
266
+ pitch_pred = self.pitch_predictor.clamp_spec(x_pred) + base_pitch
267
+ return pitch_pred
268
+
269
+ def forward_variance_preprocess(
270
+ self, encoder_out, ph_dur, pitch,
271
+ variances: dict = None, retake=None, spk_embed=None
272
+ ):
273
+ condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
274
+ variance_cond = condition + self.pitch_embed(pitch[:, :, None])
275
+ non_retake_masks = [
276
+ v_retake.float() # [B, T, 1]
277
+ for v_retake in (~retake).split(1, dim=2)
278
+ ]
279
+ variance_embeds = [
280
+ self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks
281
+ for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks)
282
+ ]
283
+ variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
284
+ if hparams['use_spk_id'] and spk_embed is not None:
285
+ variance_cond += spk_embed
286
+ return variance_cond
287
+
288
+ def forward_variance_reflow(self, variance_cond, steps: int = 10):
289
+ xs_pred = self.variance_predictor(variance_cond, steps=steps)
290
+ return xs_pred
291
+
292
+ def forward_variance_postprocess(self, xs_pred):
293
+ if self.variance_predictor.num_feats == 1:
294
+ xs_pred = [xs_pred]
295
+ else:
296
+ xs_pred = xs_pred.unbind(dim=1)
297
+ variance_pred = self.variance_predictor.clamp_spec(xs_pred)
298
+ return tuple(variance_pred)
299
+
300
+ def view_as_linguistic_encoder(self):
301
+ model = copy.deepcopy(self)
302
+ if self.predict_pitch:
303
+ del model.pitch_predictor
304
+ if self.use_melody_encoder:
305
+ del model.melody_encoder
306
+ if self.predict_variances:
307
+ del model.variance_predictor
308
+ model.fs2 = model.fs2.view_as_encoder()
309
+ if self.predict_dur:
310
+ model.forward = model.forward_linguistic_encoder_word
311
+ else:
312
+ model.forward = model.forward_linguistic_encoder_phoneme
313
+ return model
314
+
315
+ def view_as_dur_predictor(self):
316
+ assert self.predict_dur
317
+ model = copy.deepcopy(self)
318
+ if self.predict_pitch:
319
+ del model.pitch_predictor
320
+ if self.use_melody_encoder:
321
+ del model.melody_encoder
322
+ if self.predict_variances:
323
+ del model.variance_predictor
324
+ model.fs2 = model.fs2.view_as_dur_predictor()
325
+ model.forward = model.forward_dur_predictor
326
+ return model
327
+
328
+ def view_as_pitch_preprocess(self):
329
+ model = copy.deepcopy(self)
330
+ del model.fs2
331
+ if self.predict_pitch:
332
+ del model.pitch_predictor
333
+ if self.predict_variances:
334
+ del model.variance_predictor
335
+ model.forward = model.forward_pitch_preprocess
336
+ return model
337
+
338
+ def view_as_pitch_predictor(self):
339
+ assert self.predict_pitch
340
+ model = copy.deepcopy(self)
341
+ del model.fs2
342
+ del model.lr
343
+ if self.use_melody_encoder:
344
+ del model.melody_encoder
345
+ if self.predict_variances:
346
+ del model.variance_predictor
347
+ model.forward = model.forward_pitch_reflow
348
+ return model
349
+
350
+ def view_as_pitch_postprocess(self):
351
+ model = copy.deepcopy(self)
352
+ del model.fs2
353
+ if self.use_melody_encoder:
354
+ del model.melody_encoder
355
+ if self.predict_variances:
356
+ del model.variance_predictor
357
+ model.forward = model.forward_pitch_postprocess
358
+ return model
359
+
360
+ def view_as_variance_preprocess(self):
361
+ model = copy.deepcopy(self)
362
+ del model.fs2
363
+ if self.predict_pitch:
364
+ del model.pitch_predictor
365
+ if self.use_melody_encoder:
366
+ del model.melody_encoder
367
+ if self.predict_variances:
368
+ del model.variance_predictor
369
+ model.forward = model.forward_variance_preprocess
370
+ return model
371
+
372
+ def view_as_variance_predictor(self):
373
+ assert self.predict_variances
374
+ model = copy.deepcopy(self)
375
+ del model.fs2
376
+ del model.lr
377
+ if self.predict_pitch:
378
+ del model.pitch_predictor
379
+ if self.use_melody_encoder:
380
+ del model.melody_encoder
381
+ model.forward = model.forward_variance_reflow
382
+ return model
383
+
384
+ def view_as_variance_postprocess(self):
385
+ model = copy.deepcopy(self)
386
+ del model.fs2
387
+ if self.predict_pitch:
388
+ del model.pitch_predictor
389
+ if self.use_melody_encoder:
390
+ del model.melody_encoder
391
+ model.forward = model.forward_variance_postprocess
392
+ return model
dictionaries/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.py
2
+ *.txt
3
+ !opencpop*
docs/BestPractices.md ADDED
@@ -0,0 +1,618 @@
1
+ # Best Practices
2
+
3
+ ## Materials for training and using models
4
+
5
+ ### Datasets
6
+
7
+ A dataset mainly includes recordings and transcriptions, which is called a _raw dataset_. Raw datasets should be organized as the following folder structure:
8
+
9
+ - my_raw_data/
10
+ - wavs/
11
+ - 001.wav
12
+ - 002.wav
13
+ - ... (more recording files)
14
+ - transcriptions.csv
15
+
16
+ In the example above, the _my_raw_data_ folder is the root directory of a raw dataset.
17
+
18
+ The _transcriptions.csv_ file contains all labels of the recordings. The common column of the CSV file is `name`, which represents all recording items by their filenames **without extension**. Elements of sequence attributes should be split by `space`. Other required columns may vary according to the category of the model you are training, and will be introduced in the following sections.
19
+
20
+ ### Dictionaries
21
+
22
+ A dictionary is a .txt file, in which each line represents a mapping rule from one syllable to its phoneme sequence. The syllable and the phonemes are split by `tab`, and the phonemes are split by `space`:
23
+
24
+ ```
25
+ <syllable> <phoneme1> <phoneme2> ...
26
+ ```
27
+
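+ As an illustration, here is a minimal sketch (not the repository's actual loader) of how such a file can be parsed in Python; the file name in the usage comment is just an example:
+
+ ```python
+ # Minimal sketch: parse a tab-separated dictionary file into a mapping
+ # from syllable to phoneme list.
+ def load_dictionary(path: str) -> dict:
+     rules = {}
+     with open(path, 'r', encoding='utf-8') as f:
+         for line in f:
+             line = line.strip()
+             if not line:
+                 continue
+             syllable, phonemes = line.split('\t', maxsplit=1)
+             rules[syllable] = phonemes.split()
+     return rules
+
+ # Example usage: rules = load_dictionary('dictionaries/opencpop-extension.txt')
+ ```
+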
28
+ Syllable names and phoneme names can be customized, but with the following limitations/suggestions:
29
+
30
+ - `SP` (rest), `AP` (breath) and `<PAD>` (padding) cannot be used because they are reserved.
31
+ - `-` and `+` cannot be used because they are defined as slur tags in most singing voice synthesis editors.
32
+ - Special characters, including but not limited to `@`, `#`, `&`, `|`, `/`, `<` and `>`, should be avoided because they may be used as special tags in future format changes. Using them now is okay, and any such changes will be announced in advance.
33
+ - ASCII characters are preferred for the best encoding compatibility, but all UTF-8 characters are acceptable.
34
+
35
+ There are some preset dictionaries in the [dictionaries/](../dictionaries) folder. For the guidance of using a custom dictionary, see [Using custom dictionaries](#using-custom-dictionaries).
36
+
37
+ ### Configuration files
38
+
39
+ A configuration file is a YAML file that defines enabled features and model hyperparameters, and controls the behavior of the binarizer, trainer and inference. For more information on the configuration system and configurable attributes, see [Configuration Schemas](ConfigurationSchemas.md).
40
+
41
+ ### DS files
42
+
43
+ DS files are JSON files with the _.ds_ suffix that contain phoneme sequences, phoneme durations, music scores or curve parameters. They are mainly used to run inference on models for test and evaluation purposes, and they can be used as training data in some cases. There are some example DS files in the [samples/](../samples) folder.
44
+
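+ Since DS files are plain JSON despite the suffix, they can be inspected with ordinary tools. The snippet below is a rough sketch (the file name is hypothetical, and the exact segment schema may vary); it only reports which of the attributes mentioned in this document are present:
+
+ ```python
+ import json
+
+ # DS files are plain JSON; a file usually holds a list of segments.
+ with open('samples/example.ds', 'r', encoding='utf-8') as f:
+     segments = json.load(f)
+ if isinstance(segments, dict):
+     segments = [segments]  # wrap a single segment for uniform handling
+
+ for i, seg in enumerate(segments):
+     keys = ('ph_seq', 'ph_dur', 'ph_num', 'note_seq', 'note_dur', 'f0_seq')
+     print(f'segment {i}:', [k for k in keys if k in seg])
+ ```
+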
45
+ The current recommended way of using a model for production purposes is to use [OpenUTAU for DiffSinger](https://github.com/xunmengshe/OpenUtau). It can export DS files as well.
46
+
47
+ ### Other fundamental assets
48
+
49
+ #### Vocoders
50
+
51
+ A vocoder is a model that can reconstruct the audio waveform given the low-dimensional mel-spectrogram. The vocoder is the essential dependency if you want to train an acoustic model and hear the voice on the TensorBoard.
52
+
53
+ The [DiffSinger Community Vocoders Project](https://openvpi.github.io/vocoders) provides a universal pre-trained NSF-HiFiGAN vocoder that newcomers to this repository can use. To use it, download the model (~50 MB) from its releases and unzip it into the `checkpoints/` folder.
54
+
55
+ The pre-trained vocoder can be fine-tuned on your target dataset. This is highly recommended because a fine-tuned vocoder generates much better results on specific (seen) datasets while requiring little computing power. See the [vocoder training and fine-tuning repository](https://github.com/openvpi/SingingVocoders) for detailed instructions. After you get the fine-tuned vocoder checkpoint, you can configure it via the `vocoder_ckpt` key in your configuration file. Fine-tuned NSF-HiFiGAN vocoder checkpoints can be exported to ONNX format like other DiffSinger user models for production purposes.
56
+
57
+ Another, less recommended option: train an ultra-lightweight [DDSP vocoder](https://github.com/yxlllc/pc-ddsp) yourself first, then configure it according to the relevant [instructions](https://github.com/yxlllc/pc-ddsp/blob/master/DiffSinger.md).
58
+
59
+ #### Feature extractors or auxiliary models
60
+
61
+ RMVPE is the recommended pitch extractor of this repository, which is an NN-based algorithm and requires a pre-trained model. For more information about pitch extractors and how to configure them, see [feature extraction](#pitch-extraction).
62
+
63
+ Vocal Remover (VR) is the recommended harmonic-noise separator of this repository, which is an NN-based algorithm and requires a pre-trained model. For more information about harmonic-noise separators and how to configure them, see [feature extraction](#harmonic-noise-separation).
64
+
65
+ ## Overview: training acoustic models
66
+
67
+ An acoustic model takes low-level singing information as input, including (but not limited to) phoneme sequence, phoneme durations and F0 sequence. The only output of an acoustic model is the mel-spectrogram, which can be converted to waveform (the final audio) through the vocoder. Briefly speaking, an acoustic model takes in all features that are explicitly given, and produces the singing voice.
68
+
69
+ ### Datasets
70
+
71
+ To train an acoustic model, you must have three columns in your transcriptions.csv: `name`, `ph_seq` and `ph_dur`, where `ph_seq` is the phoneme sequence and `ph_dur` is the phoneme duration sequence in seconds. You must have all corresponding recordings declared by the `name` column in mono, WAV format.
72
+
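+ As a quick sanity check before binarization, the required columns and the length consistency between `ph_seq` and `ph_dur` can be verified with the standard library. This is only a sketch under the folder layout shown earlier, not part of the repository:
+
+ ```python
+ import csv
+
+ # Sketch: verify the columns required for acoustic training and that
+ # each item has as many phoneme durations as phonemes.
+ with open('my_raw_data/transcriptions.csv', newline='', encoding='utf-8') as f:
+     reader = csv.DictReader(f)
+     assert {'name', 'ph_seq', 'ph_dur'} <= set(reader.fieldnames or [])
+     for row in reader:
+         phonemes = row['ph_seq'].split()
+         durations = [float(d) for d in row['ph_dur'].split()]
+         assert len(phonemes) == len(durations), f"length mismatch in {row['name']}"
+ ```
+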
73
+ Training from multiple datasets in one model (so that the model is a multi-speaker model) is supported. See `speakers`, `spk_ids` and `use_spk_id` in the configuration schemas.
74
+
75
+ ### Functionalities
76
+
77
+ Functionalities of acoustic models are defined by their inputs. Acoustic models have three basic and fixed inputs: phoneme sequence, phoneme duration sequence and F0 (pitch) sequence. There are three categories of additional inputs (control parameters):
78
+
79
+ - speaker IDs: if your acoustic model is a multi-speaker model, you can use different speakers within the same model, or mix their timbre and style.
80
+ - variance parameters: these curve parameters are features extracted from the recordings, and can control the timbre and style of the singing voice. See `use_energy_embed` and `use_breathiness_embed` in the configuration schemas. Please note that variance parameters **do not have default values**, so they are usually obtained from the variance model at inference time.
81
+ - transformation parameters: these values represent transformations applied to the mel-spectrogram, and are obtained by enabling data augmentation. They are scalars at training time and sequences at inference time. See `augmentation_args`, `use_key_shift_embed` and `use_speed_embed` in the configuration schemas.
82
+
83
+ ## Overview: training variance models
84
+
85
+ A variance model takes high-level music information as input, including phoneme sequence, word division, word durations and music scores. The outputs of a variance model may include phoneme durations, pitch curve and other control parameters that will be consumed by acoustic models. Briefly speaking, a variance model works as an auxiliary tool (so-called _automatic parameter generator_) for the acoustic models.
86
+
87
+ ### Datasets
88
+
89
+ To train a variance model, you must have all the required attributes listed in the following table in your transcriptions.csv according to the functionalities enabled.
90
+
91
+ | | name | ph_seq | ph_dur | ph_num | note_seq | note_dur |
92
+ |:------------------------------:|:----:|:------:|:------:|:------:|:--------:|:--------:|
93
+ | phoneme duration prediction | ✓ | ✓ | ✓ | ✓ | | |
94
+ | pitch prediction | ✓ | ✓ | ✓ | | ✓ | ✓ |
95
+ | variance parameters prediction | ✓ | ✓ | ✓ | | | |
96
+
97
+ The recommended way of building a variance dataset is to extend an acoustic dataset. You may have all the recordings prepared like the acoustic dataset as well, or [use DS files in your variance datasets](#build-variance-datasets-with-ds-files).
98
+
99
+ Variance models support multi-speaker settings like acoustic models do.
100
+
101
+ ### Functionalities
102
+
103
+ Functionalities of variance models are defined by their outputs. There are three main prediction modules that can be enabled/disabled independently:
104
+
105
+ - Duration Predictor: predicts the phoneme durations. See `predict_dur` in the configuration schemas.
106
+ - Pitch Predictor: predicts the pitch curve. See `predict_pitch` in the configuration schemas.
107
+ - Multi-Variance Predictor: jointly predicts other variance parameters. See `predict_energy` and `predict_breathiness` in the configuration schemas.
108
+
109
+ There may be some mutual influence between the modules above when they are enabled together. See [mutual influence between variance modules](#mutual-influence-between-variance-modules) for more details.
110
+
111
+ ## Using custom dictionaries
112
+
113
+ This section is about using a custom grapheme-to-phoneme dictionary for any language(s).
114
+
115
+ ### Add a dictionary
116
+
117
+ Assume that you have made a dictionary file named `my_dict.txt`. Edit your configuration file:
118
+
119
+ ```yaml
120
+ dictionary: my_dict.txt
121
+ ```
122
+
123
+ Then you can binarize your data as normal. The phonemes in your dataset must cover, and only cover, the phonemes that appear in your dictionary. Otherwise, the binarizer will raise an error:
124
+
125
+ ```
126
+ AssertionError: transcriptions and dictionary mismatch.
127
+ (+) ['E', 'En', 'i0', 'ir']
128
+ (-) ['AP', 'SP']
129
+ ```
130
+
131
+ This means there are 4 unexpected symbols in the data labels (`ir`, `i0`, `E`, `En`) and 2 missing phonemes that are not covered by the data labels (`AP`, `SP`).
132
+
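+ The check itself is plain set arithmetic between the phonemes found in the labels and those defined by the dictionary (plus the reserved `AP`/`SP` symbols). A rough sketch of the idea, with made-up example values rather than the binarizer's actual code:
+
+ ```python
+ # Sketch: reproduce the coverage check as set differences.
+ label_phonemes = {'a', 'b', 'E', 'En', 'i0', 'ir'}  # collected from ph_seq labels
+ dict_phonemes = {'a', 'b', 'AP', 'SP'}               # dictionary + reserved symbols
+
+ unexpected = sorted(label_phonemes - dict_phonemes)  # reported as (+)
+ missing = sorted(dict_phonemes - label_phonemes)     # reported as (-)
+ if unexpected or missing:
+     raise AssertionError(
+         f'transcriptions and dictionary mismatch.\n (+) {unexpected}\n (-) {missing}'
+     )
+ ```
+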
133
+ Once the coverage checks have passed, a phoneme distribution summary will be saved into your binary data directory. Below is an example.
134
+
135
+ ![phoneme-distribution](resources/phoneme-distribution.jpg)
136
+
137
+ During the binarization process, each phoneme will be assigned a unique phoneme ID according to the order of their names. There is one padding index (marked as `<PAD>`) before all real phoneme IDs.
138
+
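+ A minimal sketch of that ID assignment (assuming simple sorted name order; the actual binarizer may differ in details):
+
+ ```python
+ # Sketch: <PAD> takes index 0, real phonemes follow in name order.
+ phonemes = sorted({'AP', 'SP', 'a', 'b', 'ch', 'ir'})  # example phoneme set
+ phoneme_to_id = {'<PAD>': 0, **{p: i + 1 for i, p in enumerate(phonemes)}}
+ print(phoneme_to_id)  # {'<PAD>': 0, 'AP': 1, 'SP': 2, 'a': 3, ...}
+ ```
+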
139
+ The dictionary used to binarize the dataset will be copied to the binary data directory by the binarizer, and will be copied again to the experiment directory by the trainer. When exported to ONNX, the dictionary and the phoneme sequence ordered by IDs will be saved to the artifact directory. You do not need to carry the original dictionary file for training and inference.
140
+
141
+ ### Preset dictionaries
142
+
143
+ There are currently some preset dictionaries for you to use directly:
144
+
145
+ | dictionary | filename | description |
146
+ |:------------------:|:----------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
147
+ | Opencpop | opencpop.txt | The original dictionary used by the Opencpop mandarin singing dataset, fully aligned with the pinyin writing system. We copied the dictionary from [here](http://wenet.org.cn/opencpop/resources/annotationformat/), removed 5 syllables that have no occurrences in the data labels (`hm`, `hng`, `m`, `n` and `ng`) and added aliases for some syllables (e.g. `jv` for `ju`). Due to pronunciation issues, this dictionary is deprecated and remains only for backward compatibility. |
148
+ | Opencpop extension | opencpop-extension.txt | The modified version of the Opencpop dictionary, with stricter phoneme division rules for some pinyin syllables. For example, `ci` is mapped to `c i0` and `chi` is mapped to `ch ir` to distinguish them from `bi` (`b i`). This dictionary is now used as the default dictionary for mandarin Chinese. There are also many new syllables for more phoneme combinations. |
149
+
150
+ ### Submit or propose a new dictionary
151
+
152
+ You can submit or propose a new dictionary by raising a topic in [Discussions](https://github.com/openvpi/DiffSinger/discussions). Any dictionary to be formally supported in the main branch must match the following principles:
153
+
154
+ - Only monolingual dictionaries are accepted for now. Support for multilingual dictionaries will be designed in the future.
155
+ - All syllables and phonemes in the dictionary should have linguistic meanings. Style tags (vocal fry, falsetto, etc.) should not appear in the dictionary.
156
+ - Its syllables should be standard spelling or phonetic transcriptions (like pinyin in mandarin Chinese and romaji in Japanese) for easy integration with G2P modules.
157
+ - Its phonemes should cover all (or almost all) possible pronunciations in that language.
158
+ - Every syllable and every phoneme should have one, and only one certain pronunciation, in all or almost all situations in that language. Some slight context-based pronunciation differences are allowed as the networks can learn.
159
+ - Most native speakers/singers of that language should be able to easily cover all phonemes in the dictionary. This means the dictionary should not contain extremely rare or highly customized phonemes of some dialects or accents.
160
+ - It should not bring too much difficulty and complexity to the data labeling workflow, and it should be easy to use for end users of voicebanks.
161
+
162
+ ## Build variance datasets with DS files
163
+
164
+ By default, the variance binarizer loads attributes from transcriptions.csv and searches for recording files (*.wav) to extract features and parameters. These attributes and parameters also exist in DS files, which are normally used for inference. This section introduces the required settings and important notes to build a variance dataset from DS files.
165
+
166
+ First of all, you should edit your configuration file to enable loading from DS files:
167
+
168
+ ```yaml
169
+ binarization_args:
170
+ prefer_ds: true # prefer loading from DS files
171
+ ```
172
+
173
+ Then you should prepare some DS files which are properly segmented. If you export DS files with OpenUTAU for DiffSinger, the DS files are already segmented according to the spaces between notes. You should put these DS files in a folder named `ds` in your raw dataset directory (beside the `wavs` folder).
174
+
175
+ The DS files should also use the same dictionary as your target model. The required attributes vary with your target functionalities, as listed below:
176
+
177
+ | attribute name | required by duration prediction | required by pitch prediction | required by variance parameters prediction | previous source | current source |
178
+ |:----------------------------:|:-------------------------------:|:----------------------------:|:------------------------------------------:|:---------------:|:--------------:|
179
+ | `name` | ✓ | ✓ | ✓ | CSV | CSV |
180
+ | `ph_seq` | ✓ | ✓ | ✓ | CSV | DS/CSV |
181
+ | `ph_dur` | ✓ | ✓ | ✓ | CSV | DS/CSV |
182
+ | `ph_num` | ✓ | | | CSV | DS/CSV |
183
+ | `note_seq` | | ✓ | | CSV | DS/CSV |
184
+ | `note_dur` | | ✓ | | CSV | DS/CSV |
185
+ | `f0_seq` | ✓ | ✓ | ✓ | WAV | DS/WAV |
186
+ | `energy`, `breathiness`, ... | | | ✓ | WAV | DS/WAV |
187
+
188
+ This means you only need one column in transcriptions.csv, the `name` column, to declare all DS files included in the dataset. The name pattern can be:
189
+
190
+ - Full name: `some-name` will first match the first segment in `some-name.ds`.
191
+ - Name with index: `some-name#0` and `some-name#1` will match segment 0 and segment 1 in `some-name.ds` if there is no match with the full name (a minimal sketch of this resolution logic follows below).
192
+
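+ A rough sketch of how a `name` entry could be resolved to a DS segment under these rules (illustrative only; the binarizer's real matching logic may differ):
+
+ ```python
+ import json
+ import os
+
+ def resolve_ds_segment(raw_dir: str, name: str):
+     # Try the full name first, then fall back to the name#index pattern.
+     base, _, index = name.partition('#')
+     for candidate, seg_index in ((name, 0), (base, int(index) if index else 0)):
+         path = os.path.join(raw_dir, 'ds', candidate + '.ds')
+         if os.path.exists(path):
+             with open(path, 'r', encoding='utf-8') as f:
+                 segments = json.load(f)
+             segments = segments if isinstance(segments, list) else [segments]
+             return segments[seg_index]
+     return None  # no matching DS file; fall back to CSV/WAV as described below
+ ```
+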
193
+ Though not recommended, the binarizer will still try to load attributes from transcriptions.csv or extract parameters from recordings if there are no matching DS files. In this case the full name matching logic is applied (the same as the normal binarization process).
194
+
195
+ ## Choosing variance parameters
196
+
197
+ Variance parameters are a type of parameters that are significantly related to singing styles and emotions, have no default values and need to be predicted by the variance models. Choosing the proper variance parameters can bring more controllability and expressiveness to your singing models. In this section, we are only talking about **narrowly defined variance parameters**, i.e. all variance parameters except pitch.
198
+
199
+ ### Supported variance parameters
200
+
201
+ #### Energy
202
+
203
+ > WARNING
204
+ >
205
+ > This parameter is no longer recommended in favor of the new voicing parameter, which is less coupled with breathiness than energy is.
206
+
207
+ Energy is defined as the RMS curve of the singing, in dB, which can control the strength of voice to a certain extent.
208
+
209
+ #### Breathiness
210
+
211
+ Breathiness is defined as the RMS curve of the aperiodic part of the singing, in dB, which can control the power of the air and unvoiced consonants in the voice.
212
+
213
+ #### Voicing
214
+
215
+ Voicing is defined as the RMS curve of the harmonic part of the singing, in dB, which can control the power of the harmonics in vowels and voiced consonants in the voice.
216
+
217
+ #### Tension
218
+
219
+ Tension is mostly related to the ratio of the base harmonic to the full harmonics, which can be used to control the strength and timbre of the voice. The ratio is calculated as
220
+ $$
221
+ r = \frac{\text{RMS}(H_{full}-H_{base})}{\text{RMS}(H_{full})}
222
+ $$
223
+ where $H_{full}$ is the full harmonics and $H_{base}$ is the base harmonic. The ratio is then mapped to the final domain via the inverse of the Sigmoid function:
224
+ $$
225
+ T = \log{\frac{r}{1-r}}
226
+ $$
227
+ where $T$ is the tension value.
228
+
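+ For intuition, the sketch below shows how these curves could be computed with NumPy once the harmonic and aperiodic parts are available. Frame length and floor constants are arbitrary choices here, not the repository's exact settings; energy, breathiness and voicing would be `rms_db` of the full waveform, the aperiodic part and the harmonic part respectively:
+
+ ```python
+ import numpy as np
+
+ def rms_db(x: np.ndarray, frame: int = 512) -> np.ndarray:
+     # Frame-wise RMS in dB; a small floor avoids log of zero.
+     n = len(x) // frame
+     frames = x[:n * frame].reshape(n, frame)
+     rms = np.sqrt(np.mean(frames ** 2, axis=1))
+     return 20.0 * np.log10(np.maximum(rms, 1e-5))
+
+ def tension(h_full: np.ndarray, h_base: np.ndarray, frame: int = 512) -> np.ndarray:
+     # r = RMS(H_full - H_base) / RMS(H_full), then T = log(r / (1 - r)).
+     n = min(len(h_full), len(h_base)) // frame
+     full = h_full[:n * frame].reshape(n, frame)
+     base = h_base[:n * frame].reshape(n, frame)
+     r = np.sqrt(np.mean((full - base) ** 2, axis=1))
+     r /= np.sqrt(np.mean(full ** 2, axis=1)) + 1e-9
+     r = np.clip(r, 1e-5, 1.0 - 1e-5)
+     return np.log(r / (1.0 - r))
+ ```
+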
229
+ ### Principles of choosing multiple parameters
230
+
231
+ #### Energy, breathiness and voicing
232
+
233
+ These three parameters should **NOT** be enabled together. Energy is the RMS of the full waveform, which is the composition of the harmonic part and the aperiodic part. Therefore, these three parameters are coupled with each other.
234
+
235
+ #### Energy, voicing and tension
236
+
237
+ When voicing (or energy) is enabled, it almost fixes the loudness. However, tension sometimes relies on the implicitly predicted loudness for more expressiveness, because when a person sings with higher tension, he/she usually produces a louder voice. For this reason, some people may find their models or datasets _less natural_ with tension control. To be specific, changing tension will change the timbre but keep the loudness, while changing voicing (or energy) will change the loudness but keep the timbre. This behavior can be suitable for some, but not all, datasets and users. Therefore, it is highly recommended to conduct some experiments on the actual datasets used to train the model.
238
+
239
+ ## Mutual influence between variance modules
240
+
241
+ In some recent experiments and research, mutual influence between the modules of variance models has been observed. In practice, being aware of this influence and making use of it can improve accuracy and avoid instability of the model.
242
+
243
+ ### Influence on the duration predictor
244
+
245
+ The duration predictor benefits from its downstream modules, like the pitch predictor and the variance predictor.
246
+
247
+ The experiments were conducted on both manually refined datasets and automatically labeled datasets, and with pitch predictors driven by both base pitch and the melody encoder. All results show that when either the pitch predictor or the variance predictor is enabled together with the duration predictor, the rhythm correctness and duration accuracy significantly outperform those of a duration predictor trained alone.
248
+
249
+ A possible reason for this difference is the lack of information carried by pure phoneme duration sequences, which may not fully represent real-world phoneme features. With the help of frame-level feature predictors, the encoder learns more about the voice features related to phoneme types and durations, thus making the duration predictor produce better results.
250
+
251
+ ### Influence on frame-level feature predictors
252
+
253
+ Frame-level feature predictors, including the pitch predictor and the variance predictor, have better performance when trained without enabling the duration predictor.
254
+
255
+ The experiments found that when the duration predictor is enabled, the pitch accuracy drops and the dynamics of variance parameters sometimes become unstable. This has nothing to do with the gradients from the duration predictor: applying a scale factor to the gradients makes no difference, even when the gradients are completely cut off.
256
+
257
+ A possible reason for this phenomenon is the lack of direct phoneme duration input. When the duration predictor is enabled, the model takes in word durations instead of phoneme durations; when there is no duration predictor, the phoneme duration sequence is taken in directly and passed through the attention-based linguistic encoder. With direct modeling of the phoneme durations, the frame-level predictors can have a better understanding of the context, thus producing better results.
258
+
259
+ Another set of experiments showed that there is no significant influence between the pitch predictor and the variance predictor. When they are enabled together without the duration predictor, both can converge well and produce satisfactory results. No conclusion can be drawn on this issue, and it can depend on the dataset.
260
+
261
+ ### Suggested procedures of training variance models
262
+
263
+ According to the experiment results and the analysis above, the suggested procedures of training a set of variance models are listed below:
264
+
265
+ 1. Train the duration predictor together with the variance predictor, and discard the variance predictor part.
266
+ 2. Train the pitch predictor and the variance predictor separately or together.
267
+ 3. If interested, compare across different combinations in step 2 and choose the best.
268
+
269
+ ## Feature extraction
270
+
271
+ Feature extraction is the process of extracting low-level features from the recordings, which are needed as inputs for the acoustic models, or as outputs for the variance models.
272
+
273
+ ### Pitch extraction
274
+
275
+ A pitch extractor estimates pitch (F0 sequence) from given recordings. F0 (fundamental frequency) is one of the most important components of singing voice that is needed by both acoustic models and variance models.
276
+
277
+ ```yaml
278
+ pe: parselmouth # pitch extractor type
279
+ pe_ckpt: checkpoints/xxx/model.pt # pitch extractor model path (if it requires any)
280
+ ```
281
+
282
+ #### Parselmouth
283
+
284
+ [Parselmouth](https://github.com/YannickJadoul/Parselmouth) is the default pitch extractor in this repository. It is based on DSP algorithms, runs fast on CPU and can get accurate F0 on clean and normal recordings.
285
+
286
+ To use parselmouth, simply include the following line in your configuration file:
287
+
288
+ ```yaml
289
+ pe: parselmouth
290
+ ```
291
+
292
+ #### RMVPE (recommended)
293
+
294
+ [RMVPE](https://github.com/Dream-High/RMVPE) (Robust Model for Vocal Pitch Estimation) is a state-of-the-art NN-based pitch estimation model for singing voice. It runs slower than Parselmouth and consumes more memory, but it uses CUDA to accelerate computation (if available) and produces better results on noisy recordings and edge cases.
295
+
296
+ To enable RMVPE, download its pre-trained checkpoint from [here](https://github.com/yxlllc/RMVPE/releases), extract it into the `checkpoints/` folder and edit the configuration file:
297
+
298
+ ```yaml
299
+ pe: rmvpe
300
+ pe_ckpt: checkpoints/rmvpe/model.pt
301
+ ```
302
+
303
+ #### Harvest
304
+
305
+ Harvest (from the paper _Harvest: A high-performance fundamental frequency estimator from speech signals_) is the recommended pitch extractor from Masanori Morise's [WORLD](https://github.com/mmorise/World), a free software package for high-quality speech analysis, manipulation and synthesis. It is a state-of-the-art algorithmic pitch estimator designed for speech, but has seen use in singing voice synthesis. It runs the slowest of the supported extractors, but provides very accurate F0 on clean and normal recordings.
306
+
307
+ To use Harvest, simply include the following line in your configuration file:
308
+
309
+ ```yaml
310
+ pe: harvest
311
+ ```
312
+
313
+ **Note:** It is also recommended to change the F0 detection range for Harvest in accordance with your dataset, as these are hard boundaries for this algorithm and the defaults may not suit every use case. To change the F0 detection range, you may include or edit this part in the configuration file:
314
+
315
+ ```yaml
316
+ f0_min: 65 # Minimum F0 to detect
317
+ f0_max: 800 # Maximum F0 to detect
318
+ ```
319
+
320
+ ### Harmonic-noise separation
321
+
322
+ Harmonic-noise separation is the process of separating the harmonic part and the aperiodic part of the singing voice. These parts are the fundamental components from which variance parameters such as breathiness, voicing and tension are calculated.
323
+
324
+ #### WORLD
325
+
326
+ This algorithm uses Masanori Morise's [WORLD](https://github.com/mmorise/World), a free software for high-quality speech analysis, manipulation and synthesis. It uses CPU (no CUDA required) but runs relatively slow.
327
+
328
+ To use WORLD, simply include the following line in your configuration file:
329
+
330
+ ```yaml
331
+ hnsep: world
332
+ ```
333
+
334
+ #### Vocal Remover (recommended)
335
+
336
+ Vocal Remover (VR) was originally a popular NN-based algorithm for music source separation that removes the vocal part from music. This repository uses a specially trained model for harmonic-noise separation. VR extracts much cleaner harmonic parts, utilizes CUDA to accelerate computation (if available) and runs much faster than WORLD. However, it consumes more memory and should not be used with too many parallel workers.
337
+
338
+ To enable VR, download its pre-trained checkpoint from [here](https://github.com/yxlllc/vocal-remover/releases), extract it into the `checkpoints/` folder and edit the configuration file:
339
+
340
+ ```yaml
341
+ hnsep: vr
342
+ hnsep_ckpt: checkpoints/vr/model.pt
343
+ ```
344
+
345
+ ## Shallow diffusion
346
+
347
+ Shallow diffusion, first introduced in the original DiffSinger [paper](https://arxiv.org/abs/2105.02446), is a mechanism that can improve quality and save inference time for diffusion models. Instead of starting the diffusion process from pure Gaussian noise as classic diffusion does, shallow diffusion adds shallow Gaussian noise to a low-quality result generated by a simple network (called the auxiliary decoder), skipping many unnecessary steps at the beginning. By combining shallow diffusion with sampling acceleration algorithms, we can get better results at the same inference speed as before, or achieve higher inference speed without quality deterioration.
348
+
349
+ Currently, acoustic models in this repository support shallow diffusion. The main switch of shallow diffusion is `use_shallow_diffusion` in the configuration file, and most arguments of shallow diffusion can be adjusted under `shallow_diffusion_args`. See [Configuration Schemas](ConfigurationSchemas.md) for more details.
350
+
351
+ ### Train full shallow diffusion models from scratch
352
+
353
+ To train a full shallow diffusion model from scratch, simply introduce the following settings in your configuration file:
354
+
355
+ ```yaml
356
+ use_shallow_diffusion: true
357
+ K_step: 400 # adjust according to your needs
358
+ K_step_infer: 400 # should be <= K_step
359
+ ```
360
+
361
+ Please note that when shallow diffusion is enabled, only the last $K$ diffusion steps will be trained. Unlike classic diffusion models which are trained on full steps, the limit of `K_step` can make the training more efficient. However, `K_step` should not be set too small because without enough diffusion depth (steps), the low-quality auxiliary decoder results cannot be well refined. 200 ~ 400 should be the proper range of `K_step`.
362
+
363
+ The auxiliary decoder and the diffusion decoder share the same linguistic encoder, which receives gradients from both decoders. In some experiments, it was found that gradients from the auxiliary decoder can cause a mismatch between the encoder and the diffusion decoder, resulting in the latter being unable to produce reasonable results. To prevent this, a configuration item called `aux_decoder_grad` is introduced to apply a scale factor to the gradients from the auxiliary decoder during training. To adjust this factor, introduce the following in the configuration file:
364
+
365
+ ```yaml
366
+ shallow_diffusion_args:
367
+ aux_decoder_grad: 0.1 # should not be too high
368
+ ```
369
+
370
+ ### Train auxiliary decoder and diffusion decoder separately
371
+
372
+ Training a full shallow diffusion model can consume more memory because the auxiliary decoder is also in the training graph. In memory-limited situations, the two decoders can be trained separately, i.e. one decoder after the other.
373
+
374
+ **STEP 1: train the diffusion decoder**
375
+
376
+ In the first stage, the linguistic encoder and the diffusion decoder are trained together, while the auxiliary decoder is left unchanged. Edit your configuration file like this:
377
+
378
+ ```yaml
379
+ use_shallow_diffusion: true # make sure the main option is turned on
380
+ shallow_diffusion_args:
381
+ train_aux_decoder: false # exclude the auxiliary decoder from the training graph
382
+ train_diffusion: true # train diffusion decoder as normal
383
+ val_gt_start: true # should be true because the auxiliary decoder is not trained yet
384
+ ```
385
+
386
+ Start training until `max_updates` is reached, or until you get satisfactory results on the TensorBoard.
387
+
388
+ **STEP 2: train the auxiliary decoder**
389
+
390
+ In the second stage, the auxiliary decoder is trained besides the linguistic encoder and the diffusion decoder. Edit your configuration file like this:
391
+
392
+ ```yaml
393
+ shallow_diffusion_args:
394
+ train_aux_decoder: true
395
+ train_diffusion: false # exclude the diffusion decoder from the training graph
396
+ lambda_aux_mel_loss: 1.0 # no more need to limit the auxiliary loss
397
+ ```
398
+
399
+ Then you should freeze the encoder to prevent it from getting updates. This is because if the encoder changes, it no longer matches the diffusion decoder, making the latter unable to produce correct results. Edit your configuration file:
400
+
401
+ ```yaml
402
+ freezing_enabled: true
403
+ frozen_params:
404
+ - model.fs2 # the linguistic encoder
405
+ ```
406
+
407
+ You should also manually reset your learning rate scheduler because this is a new training process for the auxiliary decoder. Possible ways are:
408
+
409
+ 1. Rename the latest checkpoint to `model_ckpt_steps_0.ckpt` and remove the other checkpoints from the directory.
410
+ 2. Increase the initial learning rate (if you use a scheduler that decreases the LR over training steps) so that the auxiliary decoder gets proper learning rate.
411
+
412
+ Additionally, `max_updates` should be adjusted to ensure enough training steps for the auxiliary decoder.
413
+
414
+ Once you have finished the configuration above, you can resume training. The auxiliary decoder normally does not need many steps to train, and you can stop training when you get stable results on the TensorBoard. Because this step is much more complicated than the previous one, it is recommended to run some inference after everything is finished to verify that the model is trained properly.
415
+
416
+ ### Add shallow diffusion to classic diffusion models
417
+
418
+ Actually, all classic DDPMs have the ability to be "shallow". If you want to add shallow diffusion functionality to a previously trained classic diffusion model, the only thing you need to do is train an auxiliary decoder for it.
419
+
420
+ Before you start, you should edit the configuration file to ensure that you use the same datasets, and that you do not remove or add any of the functionalities of the old model. Then you can configure the old checkpoint in your configuration file:
421
+
422
+ ```yaml
423
+ finetune_enabled: true
424
+ finetune_ckpt_path: xxx.ckpt # path to your old checkpoint
425
+ finetune_ignored_params: [] # do not ignore any parameters
426
+ ```
427
+
428
+ Then you can follow the instructions in STEP 2 of the [previous section](#train-auxiliary-decoder-and-diffusion-decoder-separately) to finish your training.
429
+
430
+ ## Performance tuning
431
+
432
+ This section is about accelerating training and utilizing hardware.
433
+
434
+ ### Data loader and batch sampler
435
+
436
+ The data loader loads data pieces from the binary dataset, and the batch sampler forms batches according to data lengths.
437
+
438
+ To configure the data loader, edit your configuration file:
439
+
440
+ ```yaml
441
+ ds_workers: 4 # number of DataLoader workers
442
+ dataloader_prefetch_factor: 2 # load data in advance
443
+ ```
444
+
445
+ To configure the batch sampler, edit your configuration file:
446
+
447
+ ```yaml
448
+ sampler_frame_count_grid: 6 # lower value means higher speed but less randomness
449
+ ```
450
+
451
+ For more details of the batch sampler algorithm and this configuration key, see [sampler_frame_count_grid](ConfigurationSchemas.md#sampler_frame_count_grid).
452
+
453
+ ### Automatic mixed precision
454
+
455
+ Enabling automatic mixed precision (AMP) can accelerate training and save GPU memory. DiffSinger has adopted the latest version of PyTorch Lightning for AMP functionalities.
456
+
457
+ By default, the training runs in FP32 precision. To enable AMP, edit your configuration file:
458
+
459
+ ```yaml
460
+ pl_trainer_precision: 16-mixed # FP16 precision
461
+ ```
462
+
463
+ or
464
+
465
+ ```yaml
466
+ pl_trainer_precision: bf16-mixed # BF16 precision
467
+ ```
468
+
469
+ For more precision options, please check out the [official documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision).
470
+
471
+ ### Training on multiple GPUs
472
+
473
+ Using distributed data parallel (DDP) can divide training tasks across multiple GPUs and synchronize gradients and weights between them. DiffSinger has adopted the latest version of PyTorch Lightning for DDP functionalities.
474
+
475
+ By default, the trainer will utilize all CUDA devices defined in the `CUDA_VISIBLE_DEVICES` environment variable (empty means using all available devices). If you want to specify which GPUs to use, edit your configuration file:
476
+
477
+ ```yaml
478
+ pl_trainer_devices: [0, 1, 2, 3] # use the first 4 GPUs defined in CUDA_VISIBLE_DEVICES
479
+ ```
480
+
481
+ Please note that `max_batch_size` and `max_batch_frames` are values for **each** GPU.
482
+
483
+ By default, the trainer uses NCCL as the DDP backend. If this gets stuck on your machine, try disabling P2P first via
484
+
485
+ ```yaml
486
+ nccl_p2p: false # disable P2P in NCCL
487
+ ```
488
+
489
+ Or if your machine does not support NCCL, you can switch to Gloo instead:
490
+
491
+ ```yaml
492
+ pl_trainer_strategy:
493
+ name: ddp # must manually choose a strategy instead of 'auto'
494
+ process_group_backend: gloo # however, it has a lower performance than NCCL
495
+ ```
496
+
497
+ ### Gradient accumulation
498
+
499
+ Gradient accumulation means accumulating gradients over several batches before each weight update. This can simulate a larger batch size at a lower GPU memory cost.
500
+
501
+ By default, the trainer calls `backward()` each time the losses are calculated through one batch of data. To enable gradient accumulation, edit your configuration file:
502
+
503
+ ```yaml
504
+ accumulate_grad_batches: 4 # the actual batch size will be 4x.
505
+ ```
506
+
507
+ Please note that enabling gradient accumulation will slow down training because the losses must be calculated several times before the weights are updated (1 update to the weights = 1 actual training step).
508
+
509
+ ## Optimizers and learning rate schedulers
510
+
511
+ The optimizer and the learning rate scheduler can take an important role in the training process. DiffSinger uses a flexible configuration logic for these two modules.
512
+
513
+ ### Basic configurations
514
+
515
+ The optimizer and learning rate scheduler used during training can be configured by their full class name and keyword arguments in the configuration file. Take the following as an example for the optimizer:
516
+
517
+ ```yaml
518
+ optimizer_args:
519
+ optimizer_cls: torch.optim.AdamW # class name of optimizer
520
+ lr: 0.0004
521
+ beta1: 0.9
522
+ beta2: 0.98
523
+ weight_decay: 0
524
+ ```
525
+
526
+ and for the learning rate scheduler:
527
+
528
+ ```yaml
529
+ lr_scheduler_args:
530
+ scheduler_cls: torch.optim.lr_scheduler.StepLR # class name of learning rate schedule
531
+ warmup_steps: 2000
532
+ step_size: 50000
533
+ gamma: 0.5
534
+ ```
535
+
536
+ Note that `optimizer_args` and `lr_scheduler_args` will be filtered down to the parameters actually needed and passed to `__init__` as keyword arguments (`kwargs`) when constructing the optimizer and scheduler. Therefore, you can specify all arguments according to your needs in the configuration file to directly control the behavior of optimization and LR scheduling. Parameters that exist in the configuration but are not needed in `__init__` are tolerated.
537
+
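+ A minimal sketch of what such signature-based filtering can look like (an illustration using `inspect`, not the repository's exact implementation):
+
+ ```python
+ import inspect
+
+ import torch
+
+ def build_object(cls, all_args: dict, **extra):
+     # Keep only the keyword arguments that the constructor actually accepts.
+     accepted = set(inspect.signature(cls.__init__).parameters) - {'self'}
+     kwargs = {k: v for k, v in all_args.items() if k in accepted}
+     return cls(**extra, **kwargs)
+
+ model = torch.nn.Linear(4, 4)
+ optimizer_args = {'optimizer_cls': 'torch.optim.AdamW', 'lr': 0.0004, 'weight_decay': 0}
+ # 'optimizer_cls' is silently dropped because AdamW.__init__ does not accept it.
+ optimizer = build_object(torch.optim.AdamW, optimizer_args, params=model.parameters())
+ ```
+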
538
+ Also, note that the LR scheduler performs scheduling on the granularity of steps, not epochs.
539
+
540
+ A special case applies when a tuple is needed in `__init__`: `beta1` and `beta2` are treated separately and combined into a tuple in the code. You could also try passing an array instead (as an experiment, AdamW does accept `[beta1, beta2]`). If another special treatment is required, please submit an issue.
541
+
542
+ For PyTorch built-in optimizers and LR schedulers, see the official [documentation](https://pytorch.org/docs/stable/optim.html) of the `torch.optim` package. If you find other optimizers or learning rate schedulers useful, you can raise a topic in [Discussions](https://github.com/openvpi/DiffSinger/discussions), open [Issues](https://github.com/openvpi/DiffSinger/issues) or submit [PRs](https://github.com/openvpi/DiffSinger/pulls) if supporting them requires new code or dependencies.
543
+
544
+ ### Composite LR schedulers
545
+
546
+ Some LR schedulers like `SequentialLR` and `ChainedScheduler` may use other schedulers as arguments. Besides built-in types, there is a special design to configure these scheduler objects. See the following example.
547
+
548
+ ```yaml
549
+ lr_scheduler_args:
550
+ scheduler_cls: torch.optim.lr_scheduler.SequentialLR
551
+ schedulers:
552
+ - cls: torch.optim.lr_scheduler.ExponentialLR
553
+ gamma: 0.5
554
+ - cls: torch.optim.lr_scheduler.LinearLR
555
+ - cls: torch.optim.lr_scheduler.MultiStepLR
556
+ milestones:
557
+ - 10
558
+ - 20
559
+ milestones:
560
+ - 10
561
+ - 20
562
+ ```
563
+
564
+ Scheduler objects will be constructed recursively whenever `cls` is present in the sub-arguments. Please note that `cls` must be a scheduler class, because this is a special design.
565
+
566
+ **WARNING:** Nested `SequentialLR` and `ChainedScheduler` have unexpected behavior. **DO NOT** nest them. Also, make sure the scheduler is _chainable_ before using it in `ChainedScheduler`.
567
+
568
+ ## Fine-tuning and parameter freezing
569
+
570
+ ### Fine-tuning from existing checkpoints
571
+
572
+ By default, training starts from scratch with randomly initialized parameters. However, if you already have pre-trained checkpoints and need to adapt them to other datasets with their functionalities unchanged, fine-tuning may save training steps and time. In general, you need to add the following structure to the configuration file:
573
+
574
+ ```yaml
575
+ # take acoustic models as an example
576
+ finetune_enabled: true # the main switch to enable fine-tuning
577
+ finetune_ckpt_path: checkpoints/pretrained/model_ckpt_steps_320000.ckpt # path to your pre-trained checkpoint
578
+ finetune_ignored_params: # prefix rules to exclude specific parameters when loading the checkpoints
579
+ - model.fs2.encoder.embed_tokens # in case when the phoneme set is changed
580
+ - model.fs2.txt_embed # same as above
581
+ - model.fs2.spk_embed # in case when the speaker set is changed
582
+ finetune_strict_shapes: true # whether to raise an error when parameter shapes mismatch
583
+ ```
584
+
585
+ For the pre-trained checkpoint, it must be a file saved with `torch.save`, containing a `dict` object and a `state_dict` key, like the following example:
586
+
587
+ ```json5
588
+ {
589
+ "state_dict": {
590
+ "model.fs2.txt_embed": null, // torch.Tensor
591
+ "model.fs2.pitch_embed.weight": null, // torch.Tensor
592
+ "model.fs2.pitch_embed.bias": null, // torch.Tensor
593
+ // ... (other parameters)
594
+ }
595
+ // ... (other possible keys)
596
+ }
597
+ ```
598
+
599
+ **IMPORTANT NOTES**:
600
+
601
+ - The pre-trained checkpoint is **loaded only once** at the beginning of the training experiment. You may interrupt the training at any time, but after this new experiment has saved its own checkpoint, the pre-trained checkpoint will not be loaded again when the training is resumed.
602
+ - Only the state dict of the checkpoint will be loaded. The optimizer state in the pre-trained checkpoint will be ignored.
603
+ - The parameter name matching is **not strict** when loading the pre-trained checkpoint. This means that parameters missing from the state dict will be left randomly initialized, and redundant parameters will be ignored without any warnings or errors (see the sketch after these notes). There are cases where the tensor shapes mismatch between the pre-trained state dict and the model - edit `finetune_strict_shapes` to change the behavior when dealing with this.
604
+ - Be careful if you want to change the functionalities when fine-tuning. Starting from a checkpoint trained under different functionalities may be even slower than training from scratch.
605
+
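+ To illustrate the notes above (a sketch only, not the trainer's actual loading code), non-strict loading with ignored prefixes could look like this:
+
+ ```python
+ import torch
+
+ def load_pretrained(model: torch.nn.Module, ckpt_path: str, ignored_prefixes: list):
+     # Drop parameters matching the ignored prefixes, then load non-strictly.
+     ckpt = torch.load(ckpt_path, map_location='cpu')
+     state_dict = {
+         k: v for k, v in ckpt['state_dict'].items()
+         if not any(k.startswith(p) for p in ignored_prefixes)
+     }
+     # strict=False: missing parameters stay randomly initialized,
+     # redundant parameters are ignored without errors.
+     return model.load_state_dict(state_dict, strict=False)
+ ```
+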
606
+ ### Freezing model parameters
607
+
608
+ Sometimes you want to freeze part of the model during training or fine-tuning to save GPU memory, accelerate the training process or avoid catastrophic forgetting. Parameter freezing may also be useful if you want to add/remove functionalities from pre-trained checkpoints. In general, you need to add the following structure into the configuration file:
609
+
610
+ ```yaml
611
+ # take acoustic models as an example
612
+ freezing_enabled: true # main switch to enable parameter freezing
613
+ frozen_params: # prefix rules to freeze specific parameters during training
614
+ - model.fs2.encoder
615
+ - model.fs2.pitch_embed
616
+ ```
617
+
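+ The effect of such prefix rules can be pictured with a few lines of PyTorch (a sketch, not the trainer's actual code; the variable names are hypothetical):
+
+ ```python
+ import torch
+
+ def freeze_by_prefix(model: torch.nn.Module, frozen_prefixes: list):
+     # Disable gradients for every parameter whose name starts with a configured prefix.
+     for name, param in model.named_parameters():
+         if any(name.startswith(p) for p in frozen_prefixes):
+             param.requires_grad = False
+
+ # Hypothetical usage with the prefixes from the YAML example above:
+ # freeze_by_prefix(task, ['model.fs2.encoder', 'model.fs2.pitch_embed'])
+ ```
+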
618
+ You may interrupt the training and change the settings above at any time. Sometimes this will result in a mismatching optimizer state, which will be discarded silently.
docs/ConfigurationSchemas.md ADDED
@@ -0,0 +1,2109 @@
1
+ # Configuration Schemas
2
+
3
+ ## The configuration system
4
+
5
+ DiffSinger uses a cascading configuration system based on YAML files. All configuration files ultimately inherit from and override [configs/base.yaml](../configs/base.yaml), and each file can directly override another file by setting the `base_config` attribute. The overriding rules are listed below, followed by a minimal sketch of the merging logic:
6
+
7
+ - Configuration keys with the same path and the same name will be replaced. Other paths and names will be merged.
8
+ - All configurations in the inheritance chain will be squashed (via the rule above) as the final configuration.
9
+ - The trainer will save the final configuration in the experiment directory, which is detached from the chain and made independent from other configuration files.
10
+
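+ A minimal sketch of that merging rule (illustrative only; the actual loader may differ in details):
+
+ ```python
+ def merge_config(base: dict, override: dict) -> dict:
+     # Leaf keys with the same path are replaced; nested dicts are merged recursively.
+     result = dict(base)
+     for key, value in override.items():
+         if isinstance(value, dict) and isinstance(result.get(key), dict):
+             result[key] = merge_config(result[key], value)
+         else:
+             result[key] = value
+     return result
+
+ # Squashing an inheritance chain (base.yaml -> ... -> the experiment config):
+ # final = {}
+ # for cfg in chain_of_loaded_yaml_dicts:
+ #     final = merge_config(final, cfg)
+ ```
+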
11
+ ## Configurable parameters
12
+
13
+ The following are the meanings and usages of all editable keys in a configuration file.
14
+
15
+ Each configuration key (including nested keys) is described with a brief explanation and several attributes, listed as follows:
16
+
17
+ | Attribute | Explanation |
18
+ |:---------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
19
+ | visibility | Represents what kind(s) of models and tasks this configuration belongs to. |
20
+ | scope | The scope of effects of the configuration, indicating what it can influence within the whole pipeline. Possible values are:<br>**nn** - This configuration is related to how the neural networks are formed and initialized. Modifying it will result in failure when loading or resuming from checkpoints.<br>**preprocessing** - This configuration controls how raw data pieces or inference inputs are converted to inputs of neural networks. Binarizers should be re-run if this configuration is modified.<br>**training** - This configuration describes the training procedures. Most training configurations can affect training performance, memory consumption, device utilization and loss calculation. Modifying training-only configurations will not cause severe inconsistency or errors in most situations.<br>**inference** - This configuration describes the calculation logic through the model graph. Changing it can lead to inconsistent or wrong outputs of inference or validation.<br>**others** - Other configurations not discussed above. Will have different effects according to the descriptions. |
21
+ | customizability | The level of customizability of the configuration. Possible values are:<br>**required** - This configuration **must** be set or modified according to the actual situation or condition, otherwise errors can be raised.<br>**recommended** - It is recommended to adjust this configuration according to the dataset, requirements, environment and hardware. Most functionality-related and feature-related configurations are at this level, and all configurations in this level are widely tested with different values. However, leaving it unchanged will not cause problems.<br>**normal** - There is no need to modify it as the default value is carefully tuned and widely validated. However, one can still use another value if there are some special requirements or situations.<br>**not recommended** - No other values except the default one of this configuration are tested. Modifying it will not cause errors, but may cause unpredictable or significant impacts to the pipelines.<br>**reserved** - This configuration **must not** be modified. It appears in the configuration file only for future scalability, and currently changing it will result in errors. |
22
+ | type | Value type of the configuration. Follows the syntax of Python type hints. |
23
+ | constraints | Value constraints of the configuration. |
24
+ | default | Default value of the configuration. Uses YAML value syntax. |
25
+
26
+ ### accumulate_grad_batches
27
+
28
+ The number of training steps whose gradients are accumulated before each `optimizer.step()` call. 1 means no gradient accumulation.
29
+
30
+ <table><tbody>
31
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
32
+ <tr><td align="center"><b>scope</b></td><td>training</td>
33
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
34
+ <tr><td align="center"><b>type</b></td><td>int</td>
35
+ <tr><td align="center"><b>default</b></td><td>1</td>
36
+ </tbody></table>
37
+
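+ As a rough sketch (illustrative numbers only), the effective batch size seen by each optimizer step is the per-step batch size multiplied by this value:
+
+ ```yaml
+ # Hypothetical example: each optimizer step then covers up to 48 * 4 = 192 samples.
+ max_batch_size: 48
+ accumulate_grad_batches: 4
+ ```
+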
38
+ ### audio_num_mel_bins
39
+
40
+ Number of mel channels for the mel-spectrogram.
41
+
42
+ <table><tbody>
43
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
44
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
45
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
46
+ <tr><td align="center"><b>type</b></td><td>int</td>
47
+ <tr><td align="center"><b>default</b></td><td>128</td>
48
+ </tbody></table>
49
+
50
+ ### audio_sample_rate
51
+
52
+ Sampling rate of waveforms.
53
+
54
+ <table><tbody>
55
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
56
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
57
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
58
+ <tr><td align="center"><b>type</b></td><td>int</td>
59
+ <tr><td align="center"><b>default</b></td><td>44100</td>
60
+ </tbody></table>
61
+
62
+ ### augmentation_args
63
+
64
+ Arguments for data augmentation.
65
+
66
+ <table><tbody>
67
+ <tr><td align="center"><b>type</b></td><td>dict</td>
68
+ </tbody></table>
69
+
70
+ ### augmentation_args.fixed_pitch_shifting
71
+
72
+ Arguments for fixed pitch shifting augmentation.
73
+
74
+ <table><tbody>
75
+ <tr><td align="center"><b>type</b></td><td>dict</td>
76
+ </tbody></table>
77
+
78
+ ### augmentation_args.fixed_pitch_shifting.enabled
79
+
80
+ Whether to apply fixed pitch shifting augmentation.
81
+
82
+ <table><tbody>
83
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
84
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
85
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
86
+ <tr><td align="center"><b>type</b></td><td>bool</td>
87
+ <tr><td align="center"><b>default</b></td><td>false</td>
88
+ <tr><td align="center"><b>constraints</b></td><td>Must be false if <a href="#augmentation_argsrandom_pitch_shiftingenabled">augmentation_args.random_pitch_shifting.enabled</a> is set to true.</td>
89
+ </tbody></table>
90
+
91
+ ### augmentation_args.fixed_pitch_shifting.scale
92
+
93
+ Scale ratio of each target in fixed pitch shifting augmentation.
94
+
95
+ <table><tbody>
96
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
97
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
98
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
99
+ <tr><td align="center"><b>type</b></td><td>float</td>
100
+ <tr><td align="center"><b>default</b></td><td>0.5</td>
101
+ </tbody></table>
102
+
103
+ ### augmentation_args.fixed_pitch_shifting.targets
104
+
105
+ Targets (in semitones) of fixed pitch shifting augmentation.
106
+
107
+ <table><tbody>
108
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
109
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
110
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
111
+ <tr><td align="center"><b>type</b></td><td>tuple</td>
112
+ <tr><td align="center"><b>default</b></td><td>[-5.0, 5.0]</td>
113
+ </tbody></table>
114
+
115
+ ### augmentation_args.random_pitch_shifting
116
+
117
+ Arguments for random pitch shifting augmentation.
118
+
119
+ <table><tbody>
120
+ <tr><td align="center"><b>type</b></td><td>dict</td>
121
+ </tbody></table>
122
+
123
+ ### augmentation_args.random_pitch_shifting.enabled
124
+
125
+ Whether to apply random pitch shifting augmentation.
126
+
127
+ <table><tbody>
128
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
129
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
130
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
131
+ <tr><td align="center"><b>type</b></td><td>bool</td>
132
+ <tr><td align="center"><b>default</b></td><td>true</td>
133
+ <tr><td align="center"><b>constraints</b></td><td>Must be false if <a href="#augmentation_argsfixed_pitch_shiftingenabled">augmentation_args.fixed_pitch_shifting.enabled</a> is set to true.</td>
134
+ </tbody></table>
135
+
136
+ ### augmentation_args.random_pitch_shifting.range
137
+
138
+ Range of the random pitch shifting (in semitones).
139
+
140
+ <table><tbody>
141
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
142
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
143
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
144
+ <tr><td align="center"><b>type</b></td><td>tuple</td>
145
+ <tr><td align="center"><b>default</b></td><td>[-5.0, 5.0]</td>
146
+ </tbody></table>
147
+
148
+ ### augmentation_args.random_pitch_shifting.scale
149
+
150
+ Scale ratio of the random pitch shifting augmentation.
151
+
152
+ <table><tbody>
153
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
154
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
155
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
156
+ <tr><td align="center"><b>type</b></td><td>float</td>
157
+ <tr><td align="center"><b>default</b></td><td>0.75</td>
158
+ </tbody></table>
159
+
160
+ ### augmentation_args.random_time_stretching
161
+
162
+ Arguments for random time stretching augmentation.
163
+
164
+ <table><tbody>
165
+ <tr><td align="center"><b>type</b></td><td>dict</td>
166
+ </tbody></table>
167
+
168
+ ### augmentation_args.random_time_stretching.enabled
169
+
170
+ Whether to apply random time stretching augmentation.
171
+
172
+ <table><tbody>
173
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
174
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
175
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
176
+ <tr><td align="center"><b>type</b></td><td>bool</td>
177
+ <tr><td align="center"><b>default</b></td><td>true</td>
178
+ </tbody></table>
179
+
180
+ ### augmentation_args.random_time_stretching.range
181
+
182
+ Range of random time stretching factors.
183
+
184
+ <table><tbody>
185
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
186
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
187
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
188
+ <tr><td align="center"><b>type</b></td><td>tuple</td>
189
+ <tr><td align="center"><b>default</b></td><td>[0.5, 2]</td>
190
+ </tbody></table>
191
+
192
+ ### augmentation_args.random_time_stretching.scale
193
+
194
+ Scale ratio of random time stretching augmentation.
195
+
196
+ <table><tbody>
197
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
198
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
199
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
200
+ <tr><td align="center"><b>type</b></td><td>float</td>
201
+ <tr><td align="center"><b>default</b></td><td>0.75</td>
202
+ </tbody></table>
203
+
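+ Putting the keys above together, a sketch of an `augmentation_args` block that enables random pitch shifting and random time stretching might look like this (the values shown are simply the documented defaults, repeated here for illustration):
+
+ ```yaml
+ augmentation_args:
+   random_pitch_shifting:
+     enabled: true
+     range: [-5.0, 5.0]
+     scale: 0.75
+   random_time_stretching:
+     enabled: true
+     range: [0.5, 2]
+     scale: 0.75
+   fixed_pitch_shifting:
+     enabled: false   # must stay false while random pitch shifting is enabled
+ ```
+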
204
+ ### backbone_args
205
+
206
+ Keyword arguments for the backbone of the main decoder module.
207
+
208
+ <table><tbody>
209
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
210
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
211
+ <tr><td align="center"><b>type</b></td><td>dict</td>
212
+ </tbody></table>
213
+
214
+ Some available arguments are listed below.
215
+
216
+ | argument name | for backbone type | description |
217
+ |:---------------------:|:-----------------:|:-----------------------------------------------------------------------------------------------------------:|
218
+ | num_layers | wavenet/lynxnet | Number of layer blocks, or depth of the network |
219
+ | num_channels | wavenet/lynxnet | Number of channels, or width of the network |
220
+ | dilation_cycle_length | wavenet | Length $k$ of the cycle $2^0, 2^1, \ldots, 2^k$ of convolution dilation factors through WaveNet residual blocks. |
221
+
222
+ ### backbone_type
223
+
224
+ Backbone type of the main decoder/predictor module.
225
+
226
+ <table><tbody>
227
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
228
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
229
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
230
+ <tr><td align="center"><b>type</b></td><td>str</td>
231
+ <tr><td align="center"><b>default</b></td><td>lynxnet</td>
232
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'wavenet', 'lynxnet'.</td>
233
+ </tbody></table>
234
+
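+ For example, a WaveNet backbone could be selected and sized as in the sketch below (the numbers are purely illustrative, not recommended settings):
+
+ ```yaml
+ backbone_type: wavenet
+ backbone_args:
+   num_layers: 20
+   num_channels: 256
+   dilation_cycle_length: 4
+ ```
+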
235
+ ### base_config
236
+
237
+ Path(s) of other config files that the current config is based on and will override.
238
+
239
+ <table><tbody>
240
+ <tr><td align="center"><b>scope</b></td><td>others</td>
241
+ <tr><td align="center"><b>type</b></td><td>Union[str, list]</td>
242
+ </tbody></table>
243
+
244
+ ### binarization_args
245
+
246
+ Arguments for binarizers.
247
+
248
+ <table><tbody>
249
+ <tr><td align="center"><b>type</b></td><td>dict</td>
250
+ </tbody></table>
251
+
252
+ ### binarization_args.num_workers
253
+
254
+ Number of worker subprocesses when running binarizers. More workers can speed up preprocessing but consume more memory. 0 means the main process does everything itself.
255
+
256
+ <table><tbody>
257
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
258
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
259
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
260
+ <tr><td align="center"><b>type</b></td><td>int</td>
261
+ <tr><td align="center"><b>default</b></td><td>1</td>
262
+ </tbody></table>
263
+
264
+ ### binarization_args.prefer_ds
265
+
266
+ Whether to prefer loading attributes and parameters from DS files.
267
+
268
+ <table><tbody>
269
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
270
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
271
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
272
+ <tr><td align="center"><b>type</b></td><td>bool</td>
273
+ <tr><td align="center"><b>default</b></td><td>False</td>
274
+ </tbody></table>
275
+
276
+ ### binarization_args.shuffle
277
+
278
+ Whether the binarized dataset will be shuffled.
279
+
280
+ <table><tbody>
281
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
282
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
283
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
284
+ <tr><td align="center"><b>type</b></td><td>bool</td>
285
+ <tr><td align="center"><b>default</b></td><td>true</td>
286
+ </tbody></table>
287
+
288
+ ### binarizer_cls
289
+
290
+ Binarizer class name.
291
+
292
+ <table><tbody>
293
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
294
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
295
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
296
+ <tr><td align="center"><b>type</b></td><td>str</td>
297
+ </tbody></table>
298
+
299
+ ### binary_data_dir
300
+
301
+ Path to the binarized dataset.
302
+
303
+ <table><tbody>
304
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
305
+ <tr><td align="center"><b>scope</b></td><td>preprocessing, training</td>
306
+ <tr><td align="center"><b>customizability</b></td><td>required</td>
307
+ <tr><td align="center"><b>type</b></td><td>str</td>
308
+ </tbody></table>
309
+
310
+ ### breathiness_db_max
311
+
312
+ Maximum breathiness value in dB used for normalization to [-1, 1].
313
+
314
+ <table><tbody>
315
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
316
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
317
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
318
+ <tr><td align="center"><b>type</b></td><td>float</td>
319
+ <tr><td align="center"><b>default</b></td><td>-20.0</td>
320
+ </tbody></table>
321
+
322
+ ### breathiness_db_min
323
+
324
+ Minimum breathiness value in dB used for normalization to [-1, 1].
325
+
326
+ <table><tbody>
327
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
328
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
329
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
330
+ <tr><td align="center"><b>type</b></td><td>float</td>
331
+ <tr><td align="center"><b>default</b></td><td>-96.0</td>
332
+ </tbody></table>
333
+
334
+ ### breathiness_smooth_width
335
+
336
+ Length of sinusoidal smoothing convolution kernel (in seconds) on extracted breathiness curve.
337
+
338
+ <table><tbody>
339
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
340
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
341
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
342
+ <tr><td align="center"><b>type</b></td><td>float</td>
343
+ <tr><td align="center"><b>default</b></td><td>0.12</td>
344
+ </tbody></table>
345
+
346
+ ### clip_grad_norm
347
+
348
+ The value at which to clip gradients. Equivalent to `gradient_clip_val` in `lightning.pytorch.Trainer`.
349
+
350
+ <table><tbody>
351
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
352
+ <tr><td align="center"><b>scope</b></td><td>training</td>
353
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
354
+ <tr><td align="center"><b>type</b></td><td>float</td>
355
+ <tr><td align="center"><b>default</b></td><td>1</td>
356
+ </tbody></table>
357
+
358
+ ### dataloader_prefetch_factor
359
+
360
+ Number of batches loaded in advance by each `torch.utils.data.DataLoader` worker.
361
+
362
+ <table><tbody>
363
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
364
+ <tr><td align="center"><b>scope</b></td><td>training</td>
365
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
366
+ <tr><td align="center"><b>type</b></td><td>int</td>
367
+ <tr><td align="center"><b>default</b></td><td>2</td>
368
+ </tbody></table>
369
+
370
+ ### dataset_size_key
371
+
372
+ The key in the binarized metadata that is used as the `sizes` array when batching by size.
373
+
374
+ <table><tbody>
375
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
376
+ <tr><td align="center"><b>scope</b></td><td>training</td>
377
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
378
+ <tr><td align="center"><b>type</b></td><td>str</td>
379
+ <tr><td align="center"><b>default</b></td><td>lengths</td>
380
+ </tbody></table>
381
+
382
+ ### dictionary
383
+
384
+ Path to the word-phoneme mapping dictionary file. Training data must fully cover phonemes in the dictionary.
385
+
386
+ <table><tbody>
387
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
388
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
389
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
390
+ <tr><td align="center"><b>type</b></td><td>str</td>
391
+ </tbody></table>
392
+
393
+ ### diff_accelerator
394
+
395
+ DDPM sampling acceleration method. The following methods are currently available:
396
+
397
+ - DDIM: the DDIM method from [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)
398
+ - PNDM: the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778)
399
+ - DPM-Solver++ adapted from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://github.com/LuChengTHU/dpm-solver)
400
+ - UniPC adapted from [UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models](https://github.com/wl-zhao/UniPC)
401
+
402
+ <table><tbody>
403
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
404
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
405
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
406
+ <tr><td align="center"><b>type</b></td><td>str</td>
407
+ <tr><td align="center"><b>default</b></td><td>dpm-solver</td>
408
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'ddim', 'pndm', 'dpm-solver', 'unipc'.</td>
409
+ </tbody></table>
410
+
411
+ ### diff_speedup
412
+
413
+ DDPM sampling speed-up ratio. 1 means no speed-up.
414
+
415
+ <table><tbody>
416
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
417
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
418
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
419
+ <tr><td align="center"><b>type</b></td><td>int</td>
420
+ <tr><td align="center"><b>default</b></td><td>10</td>
421
+ <tr><td align="center"><b>constraints</b></td><td>Must be a factor of <a href="#K_step">K_step</a>.</td>
422
+ </tbody></table>
423
+
424
+ ### diffusion_type
425
+
426
+ The type of ODE-based generative model algorithm. The following models are currently available:
427
+
428
+ - Denoising Diffusion Probabilistic Models (DDPM) from [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
429
+ - Rectified Flow from [Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow](https://arxiv.org/abs/2209.03003)
430
+
431
+ <table><tbody>
432
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
433
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
434
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
435
+ <tr><td align="center"><b>type</b></td><td>str</td>
436
+ <tr><td align="center"><b>default</b></td><td>reflow</td>
437
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'ddpm', 'reflow'.</td>
438
+ </tbody></table>
439
+
440
+ ### dropout
441
+
442
+ Dropout rate in some FastSpeech2 modules.
443
+
444
+ <table><tbody>
445
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
446
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
447
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
448
+ <tr><td align="center"><b>type</b></td><td>float</td>
449
+ <tr><td align="center"><b>default</b></td><td>0.1</td>
450
+ </tbody></table>
451
+
452
+ ### ds_workers
453
+
454
+ Number of workers of `torch.utils.data.DataLoader`.
455
+
456
+ <table><tbody>
457
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
458
+ <tr><td align="center"><b>scope</b></td><td>training</td>
459
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
460
+ <tr><td align="center"><b>type</b></td><td>int</td>
461
+ <tr><td align="center"><b>default</b></td><td>4</td>
462
+ </tbody></table>
463
+
464
+ ### dur_prediction_args
465
+
466
+ Arguments for phoneme duration prediction.
467
+
468
+ <table><tbody>
469
+ <tr><td align="center"><b>type</b></td><td>dict</td>
470
+ </tbody></table>
471
+
472
+ ### dur_prediction_args.arch
473
+
474
+ Architecture of duration predictor.
475
+
476
+ <table><tbody>
477
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
478
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
479
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
480
+ <tr><td align="center"><b>type</b></td><td>str</td>
481
+ <tr><td align="center"><b>default</b></td><td>fs2</td>
482
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'fs2'.</td>
483
+ </tbody></table>
484
+
485
+ ### dur_prediction_args.dropout
486
+
487
+ Dropout rate in duration predictor of FastSpeech2.
488
+
489
+ <table><tbody>
490
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
491
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
492
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
493
+ <tr><td align="center"><b>type</b></td><td>float</td>
494
+ <tr><td align="center"><b>default</b></td><td>0.1</td>
495
+ </tbody></table>
496
+
497
+ ### dur_prediction_args.hidden_size
498
+
499
+ Dimensions of hidden layers in duration predictor of FastSpeech2.
500
+
501
+ <table><tbody>
502
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
503
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
504
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
505
+ <tr><td align="center"><b>type</b></td><td>int</td>
506
+ <tr><td align="center"><b>default</b></td><td>512</td>
507
+ </tbody></table>
508
+
509
+ ### dur_prediction_args.kernel_size
510
+
511
+ Kernel size of convolution layers of duration predictor of FastSpeech2.
512
+
513
+ <table><tbody>
514
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
515
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
516
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
517
+ <tr><td align="center"><b>type</b></td><td>int</td>
518
+ <tr><td align="center"><b>default</b></td><td>3</td>
519
+ </tbody></table>
520
+
521
+ ### dur_prediction_args.lambda_pdur_loss
522
+
523
+ Coefficient of single phone duration loss when calculating joint duration loss.
524
+
525
+ <table><tbody>
526
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
527
+ <tr><td align="center"><b>scope</b></td><td>training</td>
528
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
529
+ <tr><td align="center"><b>type</b></td><td>float</td>
530
+ <tr><td align="center"><b>default</b></td><td>0.3</td>
531
+ </tbody></table>
532
+
533
+ ### dur_prediction_args.lambda_sdur_loss
534
+
535
+ Coefficient of sentence duration loss when calculating joint duration loss.
536
+
537
+ <table><tbody>
538
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
539
+ <tr><td align="center"><b>scope</b></td><td>training</td>
540
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
541
+ <tr><td align="center"><b>type</b></td><td>float</td>
542
+ <tr><td align="center"><b>default</b></td><td>3.0</td>
543
+ </tbody></table>
544
+
545
+ ### dur_prediction_args.lambda_wdur_loss
546
+
547
+ Coefficient of word duration loss when calculating joint duration loss.
548
+
549
+ <table><tbody>
550
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
551
+ <tr><td align="center"><b>scope</b></td><td>training</td>
552
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
553
+ <tr><td align="center"><b>type</b></td><td>float</td>
554
+ <tr><td align="center"><b>default</b></td><td>1.0</td>
555
+ </tbody></table>
556
+
557
+ ### dur_prediction_args.log_offset
558
+
559
+ Offset for log domain duration loss calculation, where the following transformation is applied:
560
+ $$
561
+ D' = \ln{(D+d)}
562
+ $$
563
+ with the offset value $d$.
564
+
565
+ <table><tbody>
566
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
567
+ <tr><td align="center"><b>scope</b></td><td>training</td>
568
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
569
+ <tr><td align="center"><b>type</b></td><td>float</td>
570
+ <tr><td align="center"><b>default</b></td><td>1.0</td>
571
+ </tbody></table>
572
+
573
+ ### dur_prediction_args.loss_type
574
+
575
+ Underlying loss type of duration loss.
576
+
577
+ <table><tbody>
578
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
579
+ <tr><td align="center"><b>scope</b></td><td>training</td>
580
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
581
+ <tr><td align="center"><b>type</b></td><td>str</td>
582
+ <tr><td align="center"><b>default</b></td><td>mse</td>
583
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'mse', 'huber'.</td>
584
+ </tbody></table>
585
+
586
+ ### dur_prediction_args.num_layers
587
+
588
+ Number of duration predictor layers.
589
+
590
+ <table><tbody>
591
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
592
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
593
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
594
+ <tr><td align="center"><b>type</b></td><td>int</td>
595
+ <tr><td align="center"><b>default</b></td><td>5</td>
596
+ </tbody></table>
597
+
598
+ ### enc_ffn_kernel_size
599
+
600
+ Kernel size of the TransformerFFNLayer convolution in the FastSpeech2 encoder.
601
+
602
+ <table><tbody>
603
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
604
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
605
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
606
+ <tr><td align="center"><b>type</b></td><td>int</td>
607
+ <tr><td align="center"><b>default</b></td><td>9</td>
608
+ </tbody></table>
609
+
610
+ ### enc_layers
611
+
612
+ Number of FastSpeech2 encoder layers.
613
+
614
+ <table><tbody>
615
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
616
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
617
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
618
+ <tr><td align="center"><b>type</b></td><td>int</td>
619
+ <tr><td align="center"><b>default</b></td><td>4</td>
620
+ </tbody></table>
621
+
622
+ ### energy_db_max
623
+
624
+ Maximum energy value in dB used for normalization to [-1, 1].
625
+
626
+ <table><tbody>
627
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
628
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
629
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
630
+ <tr><td align="center"><b>type</b></td><td>float</td>
631
+ <tr><td align="center"><b>default</b></td><td>-12.0</td>
632
+ </tbody></table>
633
+
634
+ ### energy_db_min
635
+
636
+ Minimum energy value in dB used for normalization to [-1, 1].
637
+
638
+ <table><tbody>
639
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
640
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
641
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
642
+ <tr><td align="center"><b>type</b></td><td>float</td>
643
+ <tr><td align="center"><b>default</b></td><td>-96.0</td>
644
+ </tbody></table>
645
+
646
+ ### energy_smooth_width
647
+
648
+ Length of sinusoidal smoothing convolution kernel (in seconds) on extracted energy curve.
649
+
650
+ <table><tbody>
651
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
652
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
653
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
654
+ <tr><td align="center"><b>type</b></td><td>float</td>
655
+ <tr><td align="center"><b>default</b></td><td>0.12</td>
656
+ </tbody></table>
657
+
658
+ ### f0_max
659
+
660
+ Maximum fundamental frequency (F0) in Hz for pitch extraction.
661
+
662
+ <table><tbody>
663
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
664
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
665
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
666
+ <tr><td align="center"><b>type</b></td><td>int</td>
667
+ <tr><td align="center"><b>default</b></td><td>1100</td>
668
+ </tbody></table>
669
+
670
+ ### f0_min
671
+
672
+ Minimum fundamental frequency (F0) in Hz for pitch extraction.
673
+
674
+ <table><tbody>
675
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
676
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
677
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
678
+ <tr><td align="center"><b>type</b></td><td>int</td>
679
+ <tr><td align="center"><b>default</b></td><td>65</td>
680
+ </tbody></table>
681
+
682
+ ### ffn_act
683
+
684
+ Activation function of TransformerFFNLayer in FastSpeech2 encoder:
685
+
686
+ - `torch.nn.ReLU` if 'relu'
687
+ - `torch.nn.GELU` if 'gelu'
688
+ - `torch.nn.SiLU` if 'swish'
689
+
690
+ <table><tbody>
691
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
692
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
693
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
694
+ <tr><td align="center"><b>type</b></td><td>str</td>
695
+ <tr><td align="center"><b>default</b></td><td>gelu</td>
696
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'relu', 'gelu', 'swish'.</td>
697
+ </tbody></table>
698
+
699
+ ### fft_size
700
+
701
+ Fast Fourier Transform (FFT) size for mel extraction.
702
+
703
+ <table><tbody>
704
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
705
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
706
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
707
+ <tr><td align="center"><b>type</b></td><td>int</td>
708
+ <tr><td align="center"><b>default</b></td><td>2048</td>
709
+ </tbody></table>
710
+
711
+ ### finetune_enabled
712
+
713
+ Whether to finetune from a pretrained model.
714
+
715
+ <table><tbody>
716
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
717
+ <tr><td align="center"><b>scope</b></td><td>training</td>
718
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
719
+ <tr><td align="center"><b>type</b></td><td>bool</td>
720
+ <tr><td align="center"><b>default</b></td><td>False</td>
721
+ </tbody></table>
722
+
723
+ ### finetune_ckpt_path
724
+
725
+ Path to the pretrained model for finetuning.
726
+
727
+ <table><tbody>
728
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
729
+ <tr><td align="center"><b>scope</b></td><td>training</td>
730
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
731
+ <tr><td align="center"><b>type</b></td><td>str</td>
732
+ <tr><td align="center"><b>default</b></td><td>null</td>
733
+ </tbody></table>
734
+
735
+ ### finetune_ignored_params
736
+
737
+ Prefixes of parameter key names in the state dict of the pretrained model that need to be dropped before finetuning.
738
+
739
+ <table><tbody>
740
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
741
+ <tr><td align="center"><b>scope</b></td><td>training</td>
742
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
743
+ <tr><td align="center"><b>type</b></td><td>list</td>
744
+ </tbody></table>
745
+
746
+ ### finetune_strict_shapes
747
+
748
+ Whether to raise an error if the tensor shape of any parameter mismatches between the pretrained model and the target model. If set to `False`, parameters with mismatching shapes will be skipped.
749
+
750
+ <table><tbody>
751
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
752
+ <tr><td align="center"><b>scope</b></td><td>training</td>
753
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
754
+ <tr><td align="center"><b>type</b></td><td>bool</td>
755
+ <tr><td align="center"><b>default</b></td><td>True</td>
756
+ </tbody></table>
757
+
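+ Taken together, the `finetune_*` keys above can be combined into something like the following sketch. The checkpoint path and parameter prefixes are placeholders, not real files; use prefixes that exist in your own pretrained model's state dict:
+
+ ```yaml
+ finetune_enabled: true
+ finetune_ckpt_path: checkpoints/pretrained_model/model_ckpt_steps_320000.ckpt  # placeholder path
+ finetune_ignored_params:
+   - model.fs2.encoder.embed_tokens   # placeholder prefix: drop token embeddings before finetuning
+   - model.fs2.spk_embed              # placeholder prefix: drop speaker embeddings before finetuning
+ finetune_strict_shapes: false        # skip parameters whose shapes do not match
+ ```
+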
758
+ ### fmax
759
+
760
+ Maximum frequency of mel extraction.
761
+
762
+ <table><tbody>
763
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
764
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
765
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
766
+ <tr><td align="center"><b>type</b></td><td>int</td>
767
+ <tr><td align="center"><b>default</b></td><td>16000</td>
768
+ </tbody></table>
769
+
770
+ ### fmin
771
+
772
+ Minimum frequency of mel extraction.
773
+
774
+ <table><tbody>
775
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
776
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
777
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
778
+ <tr><td align="center"><b>type</b></td><td>int</td>
779
+ <tr><td align="center"><b>default</b></td><td>40</td>
780
+ </tbody></table>
781
+
782
+ ### freezing_enabled
783
+
784
+ Whether to enable parameter freezing during training.
785
+
786
+ <table><tbody>
787
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
788
+ <tr><td align="center"><b>scope</b></td><td>training</td>
789
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
790
+ <tr><td align="center"><b>type</b></td><td>bool</td>
791
+ <tr><td align="center"><b>default</b></td><td>False</td>
792
+ </tbody></table>
793
+
794
+ ### frozen_params
795
+
796
+ Parameter name prefixes to freeze during training.
797
+
798
+ <table><tbody>
799
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
800
+ <tr><td align="center"><b>scope</b></td><td>training</td>
801
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
802
+ <tr><td align="center"><b>type</b></td><td>list</td>
803
+ <tr><td align="center"><b>default</b></td><td>[]</td>
804
+ </tbody></table>
805
+
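+ A sketch combining the two freezing keys above (the prefix below is a placeholder; use prefixes that exist in your model's state dict):
+
+ ```yaml
+ freezing_enabled: true
+ frozen_params:
+   - model.fs2.encoder   # placeholder prefix: freeze the linguistic encoder
+ ```
+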
806
+ ### glide_embed_scale
807
+
808
+ The scale factor multiplied onto the glide embedding values in the melody encoder.
809
+
810
+ <table><tbody>
811
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
812
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
813
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
814
+ <tr><td align="center"><b>type</b></td><td>float</td>
815
+ <tr><td align="center"><b>default</b></td><td>11.313708498984760</td>
816
+ </tbody></table>
817
+
818
+ ### glide_types
819
+
820
+ Type names of glide notes.
821
+
822
+ <table><tbody>
823
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
824
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
825
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
826
+ <tr><td align="center"><b>type</b></td><td>list</td>
827
+ <tr><td align="center"><b>default</b></td><td>[up, down]</td>
828
+ </tbody></table>
829
+
830
+ ### hidden_size
831
+
832
+ Dimension of hidden layers of FastSpeech2, token and parameter embeddings, and diffusion condition.
833
+
834
+ <table><tbody>
835
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
836
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
837
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
838
+ <tr><td align="center"><b>type</b></td><td>int</td>
839
+ <tr><td align="center"><b>default</b></td><td>256</td>
840
+ </tbody></table>
841
+
842
+ ### hnsep
843
+
844
+ Harmonic-noise separation algorithm type.
845
+
846
+ <table><tbody>
847
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
848
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
849
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
850
+ <tr><td align="center"><b>type</b></td><td>str</td>
851
+ <tr><td align="center"><b>default</b></td><td>world</td>
852
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'world', 'vr'.</td>
853
+ </tbody></table>
854
+
855
+ ### hnsep_ckpt
856
+
857
+ Checkpoint or model path of NN-based harmonic-noise separator.
858
+
859
+ <table><tbody>
860
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
861
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
862
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
863
+ <tr><td align="center"><b>type</b></td><td>str</td>
864
+ </tbody></table>
865
+
866
+ ### hop_size
867
+
868
+ Hop size or step length (in number of waveform samples) of mel and feature extraction.
869
+
870
+ <table><tbody>
871
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
872
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
873
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
874
+ <tr><td align="center"><b>type</b></td><td>int</td>
875
+ <tr><td align="center"><b>default</b></td><td>512</td>
876
+ </tbody></table>
877
+
878
+ ### lambda_aux_mel_loss
879
+
880
+ Coefficient of aux mel loss when calculating total loss of acoustic model with shallow diffusion.
881
+
882
+ <table><tbody>
883
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
884
+ <tr><td align="center"><b>scope</b></td><td>training</td>
885
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
886
+ <tr><td align="center"><b>type</b></td><td>float</td>
887
+ <tr><td align="center"><b>default</b></td><td>0.2</td>
888
+ </tbody></table>
889
+
890
+ ### lambda_dur_loss
891
+
892
+ Coefficient of duration loss when calculating total loss of variance model.
893
+
894
+ <table><tbody>
895
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
896
+ <tr><td align="center"><b>scope</b></td><td>training</td>
897
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
898
+ <tr><td align="center"><b>type</b></td><td>float</td>
899
+ <tr><td align="center"><b>default</b></td><td>1.0</td>
900
+ </tbody></table>
901
+
902
+ ### lambda_pitch_loss
903
+
904
+ Coefficient of pitch loss when calculating total loss of variance model.
905
+
906
+ <table><tbody>
907
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
908
+ <tr><td align="center"><b>scope</b></td><td>training</td>
909
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
910
+ <tr><td align="center"><b>type</b></td><td>float</td>
911
+ <tr><td align="center"><b>default</b></td><td>1.0</td>
912
+ </tbody></table>
913
+
914
+ ### lambda_var_loss
915
+
916
+ Coefficient of variance loss (all variance parameters other than pitch, like energy, breathiness, etc.) when calculating total loss of variance model.
917
+
918
+ <table><tbody>
919
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
920
+ <tr><td align="center"><b>scope</b></td><td>training</td>
921
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
922
+ <tr><td align="center"><b>type</b></td><td>float</td>
923
+ <tr><td align="center"><b>default</b></td><td>1.0</td>
924
+ </tbody></table>
925
+
926
+ ### K_step
927
+
928
+ Maximum number of DDPM steps used by shallow diffusion.
929
+
930
+ <table><tbody>
931
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
932
+ <tr><td align="center"><b>scope</b></td><td>training</td>
933
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
934
+ <tr><td align="center"><b>type</b></td><td>int</td>
935
+ <tr><td align="center"><b>default</b></td><td>400</td>
936
+ </tbody></table>
937
+
938
+ ### K_step_infer
939
+
940
+ Number of DDPM steps used during shallow diffusion inference. Normally set to the same value as [K_step](#K_step).
941
+
942
+ <table><tbody>
943
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
944
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
945
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
946
+ <tr><td align="center"><b>type</b></td><td>int</td>
947
+ <tr><td align="center"><b>default</b></td><td>400</td>
948
+ <tr><td align="center"><b>constraints</b></td><td>Should be no larger than K_step.</td>
949
+ </tbody></table>
950
+
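+ A worked example with illustrative values: since [diff_speedup](#diff_speedup) must divide K_step, the settings below result in roughly 400 / 10 = 40 denoising steps actually being computed at inference time.
+
+ ```yaml
+ K_step: 400
+ K_step_infer: 400        # no larger than K_step
+ diff_speedup: 10         # 400 / 10 = 40 accelerated sampling steps
+ diff_accelerator: dpm-solver
+ ```
+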
951
+ ### log_interval
952
+
953
+ Controls how often to log within training steps. Equivalent to `log_every_n_steps` in `lightning.pytorch.Trainer`.
954
+
955
+ <table><tbody>
956
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
957
+ <tr><td align="center"><b>scope</b></td><td>training</td>
958
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
959
+ <tr><td align="center"><b>type</b></td><td>int</td>
960
+ <tr><td align="center"><b>default</b></td><td>100</td>
961
+ </tbody></table>
962
+
963
+ ### lr_scheduler_args
964
+
965
+ Arguments of learning rate scheduler. Keys will be used as keyword arguments of the `__init__()` method of [lr_scheduler_args.scheduler_cls](#lr_scheduler_argsscheduler_cls).
966
+
967
+ <table><tbody>
968
+ <tr><td align="center"><b>type</b></td><td>dict</td>
969
+ </tbody></table>
970
+
971
+ ### lr_scheduler_args.scheduler_cls
972
+
973
+ Learning rate scheduler class name.
974
+
975
+ <table><tbody>
976
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
977
+ <tr><td align="center"><b>scope</b></td><td>training</td>
978
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
979
+ <tr><td align="center"><b>type</b></td><td>str</td>
980
+ <tr><td align="center"><b>default</b></td><td>torch.optim.lr_scheduler.StepLR</td>
981
+ </tbody></table>
982
+
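+ Since the remaining keys under `lr_scheduler_args` are passed to the scheduler's `__init__()`, a `StepLR` configuration might look like the sketch below (the step size and decay factor are illustrative values, not recommendations):
+
+ ```yaml
+ lr_scheduler_args:
+   scheduler_cls: torch.optim.lr_scheduler.StepLR
+   step_size: 50000   # illustrative: decay every 50k training steps
+   gamma: 0.5         # illustrative: halve the learning rate at each decay
+ ```
+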
983
+ ### main_loss_log_norm
984
+
985
+ Whether to use log-normalized weight for the main loss. This is similar to the method in the Stable Diffusion 3 paper [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206).
986
+
987
+ <table><tbody>
988
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
989
+ <tr><td align="center"><b>scope</b></td><td>training</td>
990
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
991
+ <tr><td align="center"><b>type</b></td><td>bool</td>
992
+ </tbody></table>
993
+
994
+ ### main_loss_type
995
+
996
+ Loss type of the main decoder/predictor.
997
+
998
+ <table><tbody>
999
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1000
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1001
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1002
+ <tr><td align="center"><b>type</b></td><td>str</td>
1003
+ <tr><td align="center"><b>default</b></td><td>l2</td>
1004
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'l1', 'l2'.</td>
1005
+ </tbody></table>
1006
+
1007
+ ### max_batch_frames
1008
+
1009
+ Maximum number of data frames in each training batch. Used to dynamically control the batch size.
1010
+
1011
+ <table><tbody>
1012
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1013
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1014
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1015
+ <tr><td align="center"><b>type</b></td><td>int</td>
1016
+ <tr><td align="center"><b>default</b></td><td>80000</td>
1017
+ </tbody></table>
1018
+
1019
+ ### max_batch_size
1020
+
1021
+ The maximum training batch size.
1022
+
1023
+ <table><tbody>
1024
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1025
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1026
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1027
+ <tr><td align="center"><b>type</b></td><td>int</td>
1028
+ <tr><td align="center"><b>default</b></td><td>48</td>
1029
+ </tbody></table>
1030
+
1031
+ ### max_beta
1032
+
1033
+ Max beta of the DDPM noise schedule.
1034
+
1035
+ <table><tbody>
1036
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1037
+ <tr><td align="center"><b>scope</b></td><td>nn, inference</td>
1038
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1039
+ <tr><td align="center"><b>type</b></td><td>float</td>
1040
+ <tr><td align="center"><b>default</b></td><td>0.02</td>
1041
+ </tbody></table>
1042
+
1043
+ ### max_updates
1044
+
1045
+ Stop training after this number of steps. Equivalent to `max_steps` in `lightning.pytorch.Trainer`.
1046
+
1047
+ <table><tbody>
1048
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1049
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1050
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1051
+ <tr><td align="center"><b>type</b></td><td>int</td>
1052
+ <tr><td align="center"><b>default</b></td><td>320000</td>
1053
+ </tbody></table>
1054
+
1055
+ ### max_val_batch_frames
1056
+
1057
+ Maximum number of data frames in each validation batch.
1058
+
1059
+ <table><tbody>
1060
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1061
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1062
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1063
+ <tr><td align="center"><b>type</b></td><td>int</td>
1064
+ <tr><td align="center"><b>default</b></td><td>60000</td>
1065
+ </tbody></table>
1066
+
1067
+ ### max_val_batch_size
1068
+
1069
+ The maximum validation batch size.
1070
+
1071
+ <table><tbody>
1072
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1073
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1074
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1075
+ <tr><td align="center"><b>type</b></td><td>int</td>
1076
+ <tr><td align="center"><b>default</b></td><td>1</td>
1077
+ </tbody></table>
1078
+
1079
+ ### mel_base
1080
+
1081
+ The logarithmic base of mel spectrogram calculation.
1082
+
1083
+ **WARNING: Since v2.4.0 release, this value is no longer configurable for preprocessing new datasets.**
1084
+
1085
+ <table><tbody>
1086
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1087
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
1088
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
1089
+ <tr><td align="center"><b>type</b></td><td>str</td>
1090
+ <tr><td align="center"><b>default</b></td><td>e</td>
1091
+ </tbody></table>
1092
+
1093
+ ### mel_vmax
1094
+
1095
+ Maximum mel spectrogram heatmap value for TensorBoard plotting.
1096
+
1097
+ <table><tbody>
1098
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1099
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1100
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1101
+ <tr><td align="center"><b>type</b></td><td>float</td>
1102
+ <tr><td align="center"><b>default</b></td><td>1.5</td>
1103
+ </tbody></table>
1104
+
1105
+ ### mel_vmin
1106
+
1107
+ Minimum mel spectrogram heatmap value for TensorBoard plotting.
1108
+
1109
+ <table><tbody>
1110
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1111
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1112
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1113
+ <tr><td align="center"><b>type</b></td><td>float</td>
1114
+ <tr><td align="center"><b>default</b></td><td>-6.0</td>
1115
+ </tbody></table>
1116
+
1117
+ ### melody_encoder_args
1118
+
1119
+ Arguments for the melody encoder. Available sub-keys: `hidden_size`, `enc_layers`, `enc_ffn_kernel_size`, `ffn_act`, `dropout`, `num_heads`, `use_pos_embed`, `rel_pos`. If any of these parameters is not set under this configuration key, its value is inherited from the linguistic encoder.
1120
+
1121
+ <table><tbody>
1122
+ <tr><td align="center"><b>type</b></td><td>dict</td>
1123
+ </tbody></table>
1124
+
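+ For instance, a sketch that overrides only part of the melody encoder while the remaining sub-keys are inherited from the linguistic encoder (the numbers are illustrative):
+
+ ```yaml
+ melody_encoder_args:
+   hidden_size: 128   # overridden here
+   enc_layers: 2      # overridden here; ffn_act, dropout, etc. are inherited
+ ```
+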
1125
+ ### midi_smooth_width
1126
+
1127
+ Length of sinusoidal smoothing convolution kernel (in seconds) on the step function representing MIDI sequence for base pitch calculation.
1128
+
1129
+ <table><tbody>
1130
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1131
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
1132
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1133
+ <tr><td align="center"><b>type</b></td><td>float</td>
1134
+ <tr><td align="center"><b>default</b></td><td>0.06</td>
1135
+ </tbody></table>
1136
+
1137
+ ### nccl_p2p
1138
+
1139
+ Whether to enable P2P when using NCCL as the backend. Set it to `false` if the training process gets stuck at startup.
1140
+
1141
+ <table><tbody>
1142
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1143
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1144
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1145
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1146
+ <tr><td align="center"><b>default</b></td><td>true</td>
1147
+ </tbody></table>
1148
+
1149
+ ### num_ckpt_keep
1150
+
1151
+ Number of newest checkpoints kept during training.
1152
+
1153
+ <table><tbody>
1154
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1155
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1156
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1157
+ <tr><td align="center"><b>type</b></td><td>int</td>
1158
+ <tr><td align="center"><b>default</b></td><td>5</td>
1159
+ </tbody></table>
1160
+
1161
+ ### num_heads
1162
+
1163
+ The number of attention heads of `torch.nn.MultiheadAttention` in FastSpeech2 encoder.
1164
+
1165
+ <table><tbody>
1166
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1167
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
1168
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1169
+ <tr><td align="center"><b>type</b></td><td>int</td>
1170
+ <tr><td align="center"><b>default</b></td><td>2</td>
1171
+ </tbody></table>
1172
+
1173
+ ### num_sanity_val_steps
1174
+
1175
+ Number of sanity validation steps at the beginning.
1176
+
1177
+ <table><tbody>
1178
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1179
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1180
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
1181
+ <tr><td align="center"><b>type</b></td><td>int</td>
1182
+ <tr><td align="center"><b>default</b></td><td>1</td>
1183
+ </tbody></table>
1184
+
1185
+ ### num_spk
1186
+
1187
+ Maximum number of speakers in multi-speaker models.
1188
+
1189
+ <table><tbody>
1190
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1191
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
1192
+ <tr><td align="center"><b>customizability</b></td><td>required</td>
1193
+ <tr><td align="center"><b>type</b></td><td>int</td>
1194
+ <tr><td align="center"><b>default</b></td><td>1</td>
1195
+ </tbody></table>
1196
+
1197
+ ### num_valid_plots
1198
+
1199
+ Number of validation plots in each validation. Plots will be chosen from the start of the validation set.
1200
+
1201
+ <table><tbody>
1202
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1203
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1204
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1205
+ <tr><td align="center"><b>type</b></td><td>int</td>
1206
+ <tr><td align="center"><b>default</b></td><td>10</td>
1207
+ </tbody></table>
1208
+
1209
+ ### optimizer_args
1210
+
1211
+ Arguments of optimizer. Keys will be used as keyword arguments of the `__init__()` method of [optimizer_args.optimizer_cls](#optimizer_argsoptimizer_cls).
1212
+
1213
+ <table><tbody>
1214
+ <tr><td align="center"><b>type</b></td><td>dict</td>
1215
+ </tbody></table>
1216
+
1217
+ ### optimizer_args.optimizer_cls
1218
+
1219
+ Optimizer class name.
1220
+
1221
+ <table><tbody>
1222
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1223
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1224
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
1225
+ <tr><td align="center"><b>type</b></td><td>str</td>
1226
+ <tr><td align="center"><b>default</b></td><td>torch.optim.AdamW</td>
1227
+ </tbody></table>
1228
+
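+ As with the scheduler arguments, the remaining keys under `optimizer_args` are forwarded to the optimizer's `__init__()`; a sketch for the default `AdamW` class (illustrative hyperparameters, not recommendations):
+
+ ```yaml
+ optimizer_args:
+   optimizer_cls: torch.optim.AdamW
+   lr: 0.0004
+   weight_decay: 0.0
+ ```
+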
1229
+ ### pe
1230
+
1231
+ Pitch extraction algorithm type.
1232
+
1233
+ <table><tbody>
1234
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1235
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
1236
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1237
+ <tr><td align="center"><b>type</b></td><td>str</td>
1238
+ <tr><td align="center"><b>default</b></td><td>parselmouth</td>
1239
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'parselmouth', 'rmvpe', 'harvest'.</td>
1240
+ </tbody></table>
1241
+
1242
+ ### pe_ckpt
1243
+
1244
+ Checkpoint or model path of NN-based pitch extractor.
1245
+
1246
+ <table><tbody>
1247
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1248
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
1249
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1250
+ <tr><td align="center"><b>type</b></td><td>str</td>
1251
+ </tbody></table>
1252
+
1253
+ ### permanent_ckpt_interval
1254
+
1255
+ The interval (in number of training steps) of permanent checkpoints. Permanent checkpoints will not be removed even if they are not the newest ones.
1256
+
1257
+ <table><tbody>
1258
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1259
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1260
+ <tr><td align="center"><b>type</b></td><td>int</td>
1261
+ <tr><td align="center"><b>default</b></td><td>40000</td>
1262
+ </tbody></table>
1263
+
1264
+ ### permanent_ckpt_start
1265
+
1266
+ Checkpoints will be marked as permanent every [permanent_ckpt_interval](#permanent_ckpt_interval) training steps after this number of training steps.
1267
+
1268
+ <table><tbody>
1269
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1270
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1271
+ <tr><td align="center"><b>type</b></td><td>int</td>
1272
+ <tr><td align="center"><b>default</b></td><td>120000</td>
1273
+ </tbody></table>
1274
+
1275
+ ### pitch_prediction_args
1276
+
1277
+ Arguments for pitch prediction.
1278
+
1279
+ <table><tbody>
1280
+ <tr><td align="center"><b>type</b></td><td>dict</td>
1281
+ </tbody></table>
1282
+
1283
+ ### pitch_prediction_args.backbone_args
1284
+
1285
+ Equivalent to [backbone_args](#backbone_args) but only for the pitch predictor model. If not set, the root-level [backbone_args](#backbone_args) are used.
1286
+
1287
+ <table><tbody>
1288
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1289
+ </tbody></table>
1290
+
1291
+ ### pitch_prediction_args.backbone_type
1292
+
1293
+ Equivalent to [backbone_type](#backbone_type) but only for the pitch predictor model.
1294
+
1295
+ <table><tbody>
1296
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1297
+ <tr><td align="center"><b>default</b></td><td>wavenet</td>
1298
+ </tbody></table>
1299
+
1300
+ ### pitch_prediction_args.pitd_clip_max
1301
+
1302
+ Maximum clipping value (in semitones) of pitch delta between actual pitch and base pitch.
1303
+
1304
+ <table><tbody>
1305
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1306
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1307
+ <tr><td align="center"><b>type</b></td><td>float</td>
1308
+ <tr><td align="center"><b>default</b></td><td>12.0</td>
1309
+ </tbody></table>
1310
+
1311
+ ### pitch_prediction_args.pitd_clip_min
1312
+
1313
+ Minimum clipping value (in semitones) of pitch delta between actual pitch and base pitch.
1314
+
1315
+ <table><tbody>
1316
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1317
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1318
+ <tr><td align="center"><b>type</b></td><td>float</td>
1319
+ <tr><td align="center"><b>default</b></td><td>-12.0</td>
1320
+ </tbody></table>
1321
+
1322
+ ### pitch_prediction_args.pitd_norm_max
1323
+
1324
+ Maximum pitch delta value in semitones used for normalization to [-1, 1].
1325
+
1326
+ <table><tbody>
1327
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1328
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1329
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1330
+ <tr><td align="center"><b>type</b></td><td>float</td>
1331
+ <tr><td align="center"><b>default</b></td><td>8.0</td>
1332
+ </tbody></table>
1333
+
1334
+ ### pitch_prediction_args.pitd_norm_min
1335
+
1336
+ Minimum pitch delta value in semitones used for normalization to [-1, 1].
1337
+
1338
+ <table><tbody>
1339
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1340
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1341
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1342
+ <tr><td align="center"><b>type</b></td><td>float</td>
1343
+ <tr><td align="center"><b>default</b></td><td>-8.0</td>
1344
+ </tbody></table>
1345
+
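+ For illustration, a small sketch of how a pitch delta (in semitones) could be clipped with `pitd_clip_min`/`pitd_clip_max` and then linearly mapped towards [-1, 1] with `pitd_norm_min`/`pitd_norm_max`; this plain linear mapping is an assumption for demonstration and is not copied from the repository code:
+
+ ```python
+ def normalize_pitd(pitd, clip_min=-12.0, clip_max=12.0, norm_min=-8.0, norm_max=8.0):
+     """Clip a pitch delta and map it linearly so that [norm_min, norm_max] -> [-1, 1] (sketch)."""
+     pitd = max(clip_min, min(clip_max, pitd))
+     return 2.0 * (pitd - norm_min) / (norm_max - norm_min) - 1.0
+
+ print(normalize_pitd(3.0))   # 0.375
+ print(normalize_pitd(20.0))  # clipped to 12.0 first, giving 1.5 (outside [-1, 1], since the clip range is wider)
+ ```
+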
1346
+ ### pitch_prediction_args.repeat_bins
1347
+
1348
+ Number of repeating bins in the pitch predictor.
1349
+
1350
+ <table><tbody>
1351
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1352
+ <tr><td align="center"><b>scope</b></td><td>nn, inference</td>
1353
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1354
+ <tr><td align="center"><b>type</b></td><td>int</td>
1355
+ <tr><td align="center"><b>default</b></td><td>64</td>
1356
+ </tbody></table>
1357
+
1358
+ ### pl_trainer_accelerator
1359
+
1360
+ Type of Lightning trainer hardware accelerator.
1361
+
1362
+ <table><tbody>
1363
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1364
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1365
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1366
+ <tr><td align="center"><b>type</b></td><td>str</td>
1367
+ <tr><td align="center"><b>default</b></td><td>auto</td>
1368
+ <tr><td align="center"><b>constraints</b></td><td>See <a href="https://lightning.ai/docs/pytorch/stable/extensions/accelerator.html?highlight=accelerator">Accelerator — PyTorch Lightning 2.X.X documentation</a> for available values.</td>
1369
+ </tbody></table>
1370
+
1371
+ ### pl_trainer_devices
1372
+
1373
+ Determines on which device(s) the model should be trained.
1374
+
1375
+ If set to 'auto', all devices made visible by the `CUDA_VISIBLE_DEVICES` environment variable are used, or all available devices if that variable is not set. Any other value behaves like `CUDA_VISIBLE_DEVICES` and filters which devices are visible to the trainer.
1376
+
1377
+ <table><tbody>
1378
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1379
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1380
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1381
+ <tr><td align="center"><b>type</b></td><td>str</td>
1382
+ <tr><td align="center"><b>default</b></td><td>auto</td>
1383
+ </tbody></table>
1384
+
1385
+ ### pl_trainer_precision
1386
+
1387
+ The computation precision of training.
1388
+
1389
+ <table><tbody>
1390
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1391
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1392
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1393
+ <tr><td align="center"><b>type</b></td><td>str</td>
1394
+ <tr><td align="center"><b>default</b></td><td>16-mixed</td>
1395
+ <tr><td align="center"><b>constraints</b></td><td>Choose from '32-true', 'bf16-mixed', '16-mixed'. See more possible values at <a href="https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api">Trainer — PyTorch Lightning 2.X.X documentation</a>.</td>
1396
+ </tbody></table>
1397
+
1398
+ ### pl_trainer_num_nodes
1399
+
1400
+ Number of nodes in the training cluster of Lightning trainer.
1401
+
1402
+ <table><tbody>
1403
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1404
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1405
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
1406
+ <tr><td align="center"><b>type</b></td><td>int</td>
1407
+ <tr><td align="center"><b>default</b></td><td>1</td>
1408
+ </tbody></table>
1409
+
1410
+ ### pl_trainer_strategy
1411
+
1412
+ Arguments of Lightning Strategy. Values will be used as keyword arguments when constructing the Strategy object.
1413
+
1414
+ <table><tbody>
1415
+ <tr><td align="center"><b>type</b></td><td>dict</td>
1416
+ </tbody></table>
1417
+
1418
+ ### pl_trainer_strategy.name
1419
+
1420
+ Strategy name for the Lightning trainer.
1421
+
1422
+ <table><tbody>
1423
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1424
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1425
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
1426
+ <tr><td align="center"><b>type</b></td><td>str</td>
1427
+ <tr><td align="center"><b>default</b></td><td>auto</td>
1428
+ </tbody></table>
1429
+
1430
+ ### predict_breathiness
1431
+
1432
+ Whether to enable breathiness prediction.
1433
+
1434
+ <table><tbody>
1435
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1436
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
1437
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1438
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1439
+ <tr><td align="center"><b>default</b></td><td>false</td>
1440
+ </tbody></table>
1441
+
1442
+ ### predict_dur
1443
+
1444
+ Whether to enable phoneme duration prediction.
1445
+
1446
+ <table><tbody>
1447
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1448
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
1449
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1450
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1451
+ <tr><td align="center"><b>default</b></td><td>true</td>
1452
+ </tbody></table>
1453
+
1454
+ ### predict_energy
1455
+
1456
+ Whether to enable energy prediction.
1457
+
1458
+ <table><tbody>
1459
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1460
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
1461
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1462
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1463
+ <tr><td align="center"><b>default</b></td><td>false</td>
1464
+ </tbody></table>
1465
+
1466
+ ### predict_pitch
1467
+
1468
+ Whether to enable pitch prediction.
1469
+
1470
+ <table><tbody>
1471
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1472
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
1473
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1474
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1475
+ <tr><td align="center"><b>default</b></td><td>true</td>
1476
+ </tbody></table>
1477
+
1478
+ ### predict_tension
1479
+
1480
+ Whether to enable tension prediction.
1481
+
1482
+ <table><tbody>
1483
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1484
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
1485
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1486
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1487
+ <tr><td align="center"><b>default</b></td><td>true</td>
1488
+ </tbody></table>
1489
+
1490
+ ### predict_voicing
1491
+
1492
+ Whether to enable voicing prediction.
1493
+
1494
+ <table><tbody>
1495
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1496
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
1497
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1498
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1499
+ <tr><td align="center"><b>default</b></td><td>true</td>
1500
+ </tbody></table>
1501
+
1502
+ ### raw_data_dir
1503
+
1504
+ Path(s) to the raw dataset including wave files, transcriptions, etc.
1505
+
1506
+ <table><tbody>
1507
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1508
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
1509
+ <tr><td align="center"><b>customizability</b></td><td>required</td>
1510
+ <tr><td align="center"><b>type</b></td><td>str, List[str]</td>
1511
+ </tbody></table>
1512
+
1513
+ ### rel_pos
1514
+
1515
+ Whether to use relative positional encoding in FastSpeech2 module.
1516
+
1517
+ <table><tbody>
1518
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1519
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
1520
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1521
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1522
+ <tr><td align="center"><b>default</b></td><td>true</td>
1523
+ </tbody></table>
1524
+
1525
+ ### sampler_frame_count_grid
1526
+
1527
+ The batch sampler applies an algorithm called _sorting by similar length_ when collecting batches. Data samples are first grouped by their approximate lengths and then shuffled within each group. Assuming this value is set to $L_{grid}$, the approximate length of a data sample with real length $L_{real}$ is calculated as:
1528
+
1529
+ $$
1530
+ L_{approx} = \lfloor\frac{L_{real}}{L_{grid}}\rfloor\cdot L_{grid}
1531
+ $$
1532
+
1533
+ Training performance on some datasets may be very sensitive to this value. Set it to 1 (completely sorted by length without shuffling) to get the theoretically best performance.
1534
+
1535
+ <table><tbody>
1536
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1537
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1538
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1539
+ <tr><td align="center"><b>type</b></td><td>int</td>
1540
+ <tr><td align="center"><b>default</b></td><td>6</td>
1541
+ </tbody></table>
1542
+
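+ A minimal sketch of the grouping rule above (illustrative only; the real sampler also handles shuffling and batch assembly):
+
+ ```python
+ from collections import defaultdict
+
+ def approx_length(real_length, grid=6):
+     # L_approx = floor(L_real / L_grid) * L_grid
+     return (real_length // grid) * grid
+
+ # Samples are grouped by approximate length before being shuffled within each group.
+ lengths = [123, 125, 128, 310, 314, 47]
+ groups = defaultdict(list)
+ for length in lengths:
+     groups[approx_length(length)].append(length)
+ print(dict(groups))  # {120: [123, 125], 126: [128], 306: [310], 312: [314], 42: [47]}
+ ```
+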
1543
+ ### sampling_algorithm
1544
+
1545
+ The algorithm to solve the ODE of Rectified Flow. The following methods are currently available:
1546
+
1547
+ - Euler: The Euler method.
1548
+ - Runge-Kutta (order 2): The 2nd-order Runge-Kutta method.
1549
+ - Runge-Kutta (order 4): The 4th-order Runge-Kutta method.
1550
+ - Runge-Kutta (order 5): The 5th-order Runge-Kutta method.
1551
+
1552
+ <table><tbody>
1553
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1554
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1555
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1556
+ <tr><td align="center"><b>type</b></td><td>str</td>
1557
+ <tr><td align="center"><b>default</b></td><td>euler</td>
1558
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'euler', 'rk2', 'rk4', 'rk5'.</td>
1559
+ </tbody></table>
1560
+
1561
+ ### sampling_steps
1562
+
1563
+ The total number of sampling steps to solve the ODE of Rectified Flow. Note that this value may not equal the NFE (Number of Function Evaluations) because some methods require more than one function evaluation per step.
1564
+
1565
+ <table><tbody>
1566
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1567
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1568
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1569
+ <tr><td align="center"><b>type</b></td><td>int</td>
1570
+ <tr><td align="center"><b>default</b></td><td>20</td>
1571
+ </tbody></table>
1572
+
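+ To make the relation between `sampling_algorithm`, `sampling_steps` and NFE concrete, here is a toy sketch contrasting an Euler step with a 2nd-order Runge-Kutta (midpoint) step on a generic ODE dx/dt = f(x, t); it is not the project's actual solver code:
+
+ ```python
+ def euler_step(f, x, t, dt):
+     # 1 function evaluation per step: NFE == sampling_steps
+     return x + dt * f(x, t)
+
+ def rk2_step(f, x, t, dt):
+     # 2 function evaluations per step (midpoint method): NFE == 2 * sampling_steps
+     k1 = f(x, t)
+     k2 = f(x + 0.5 * dt * k1, t + 0.5 * dt)
+     return x + dt * k2
+
+ # Toy ODE dx/dt = -x integrated from t=0 to t=1 in 20 steps; both results approximate exp(-1) ~ 0.3679.
+ f = lambda x, t: -x
+ x_euler = x_rk2 = 1.0
+ t, dt = 0.0, 1.0 / 20
+ for _ in range(20):
+     x_euler = euler_step(f, x_euler, t, dt)
+     x_rk2 = rk2_step(f, x_rk2, t, dt)
+     t += dt
+ print(x_euler, x_rk2)
+ ```
+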
1573
+ ### schedule_type
1574
+
1575
+ The DDPM schedule type.
1576
+
1577
+ <table><tbody>
1578
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1579
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
1580
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1581
+ <tr><td align="center"><b>type</b></td><td>str</td>
1582
+ <tr><td align="center"><b>default</b></td><td>linear</td>
1583
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'linear', 'cosine'.</td>
1584
+ </tbody></table>
1585
+
1586
+ ### shallow_diffusion_args
1587
+
1588
+ Arguments for shallow diffusion.
1589
+
1590
+ <table><tbody>
1591
+ <tr><td align="center"><b>type</b></td><td>dict</td>
1592
+ </tbody></table>
1593
+
1594
+ ### shallow_diffusion_args.aux_decoder_arch
1595
+
1596
+ Architecture type of the auxiliary decoder.
1597
+
1598
+ <table><tbody>
1599
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1600
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
1601
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
1602
+ <tr><td align="center"><b>type</b></td><td>str</td>
1603
+ <tr><td align="center"><b>default</b></td><td>convnext</td>
1604
+ <tr><td align="center"><b>constraints</b></td><td>Choose from 'convnext'.</td>
1605
+ </tbody></table>
1606
+
1607
+ ### shallow_diffusion_args.aux_decoder_args
1608
+
1609
+ Keyword arguments for dynamically constructing the auxiliary decoder.
1610
+
1611
+ <table><tbody>
1612
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1613
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
1614
+ <tr><td align="center"><b>type</b></td><td>dict</td>
1615
+ </tbody></table>
1616
+
1617
+ ### shallow_diffusion_args.aux_decoder_grad
1618
+
1619
+ Scale factor of the gradients from the auxiliary decoder to the encoder.
1620
+
1621
+ <table><tbody>
1622
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1623
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1624
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1625
+ <tr><td align="center"><b>type</b></td><td>float</td>
1626
+ <tr><td align="center"><b>default</b></td><td>0.1</td>
1627
+ </tbody></table>
1628
+
1629
+ ### shallow_diffusion_args.train_aux_decoder
1630
+
1631
+ Whether to forward and backward the auxiliary decoder during training. If set to `false`, the auxiliary decoder stays in memory but does not receive any updates.
1632
+
1633
+ <table><tbody>
1634
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1635
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1636
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1637
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1638
+ <tr><td align="center"><b>default</b></td><td>true</td>
1639
+ </tbody></table>
1640
+
1641
+ ### shallow_diffusion_args.train_diffusion
1642
+
1643
+ Whether to forward and backward the diffusion (main) decoder during training. If set to `false`, the diffusion decoder stays in memory but does not receive any updates.
1644
+
1645
+ <table><tbody>
1646
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1647
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1648
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1649
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1650
+ <tr><td align="center"><b>default</b></td><td>true</td>
1651
+ </tbody></table>
1652
+
1653
+ ### shallow_diffusion_args.val_gt_start
1654
+
1655
+ Whether to use the ground truth as `x_start` in the shallow diffusion validation process. If set to `true`, Gaussian noise is added to the ground truth before shallow diffusion is performed; otherwise the noise is added to the output of the auxiliary decoder. This option is useful when the auxiliary decoder has not been trained yet.
1656
+
1657
+ <table><tbody>
1658
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1659
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1660
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1661
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1662
+ <tr><td align="center"><b>default</b></td><td>false</td>
1663
+ </tbody></table>
1664
+
1665
+ ### sort_by_len
1666
+
1667
+ Whether to apply the _sorting by similar length_ algorithm described in [sampler_frame_count_grid](#sampler_frame_count_grid). Turning off this option may slow down training because sorting by length can better utilize the computing resources.
1668
+
1669
+ <table><tbody>
1670
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1671
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1672
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1673
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1674
+ <tr><td align="center"><b>default</b></td><td>true</td>
1675
+ </tbody></table>
1676
+
1677
+ ### speakers
1678
+
1679
+ The names of speakers in a multi-speaker model. Speaker names are mapped to speaker indices and stored in spk_map.json during preprocessing.
1680
+
1681
+ <table><tbody>
1682
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1683
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
1684
+ <tr><td align="center"><b>customizability</b></td><td>required</td>
1685
+ <tr><td align="center"><b>type</b></td><td>list</td>
1686
+ </tbody></table>
1687
+
1688
+ ### spk_ids
1689
+
1690
+ The IDs of speakers in a multi-speaker model. If an empty list is given, speaker IDs are automatically generated as $0,1,2,...,N_{spk}-1$. IDs may be duplicated or discontinuous.
1691
+
1692
+ <table><tbody>
1693
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1694
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
1695
+ <tr><td align="center"><b>customizability</b></td><td>required</td>
1696
+ <tr><td align="center"><b>type</b></td><td>List[int]</td>
1697
+ <tr><td align="center"><b>default</b></td><td>[]</td>
1698
+ </tbody></table>
1699
+
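+ A small sketch of the behaviour described in [speakers](#speakers) and above (hypothetical helper with made-up speaker names, not the binarizer's actual code): an empty `spk_ids` list yields IDs $0..N_{spk}-1$, and the resulting mapping is what ends up in spk_map.json.
+
+ ```python
+ import json
+
+ def build_spk_map(speakers, spk_ids):
+     # Pair speaker names with IDs; auto-generate 0..N-1 when spk_ids is empty (sketch).
+     ids = spk_ids if spk_ids else list(range(len(speakers)))
+     assert len(ids) == len(speakers)
+     return dict(zip(speakers, ids))
+
+ print(json.dumps(build_spk_map(["alto", "soprano"], [])))                 # {"alto": 0, "soprano": 1}
+ print(json.dumps(build_spk_map(["alto", "soprano", "duet"], [0, 1, 1])))  # duplicate IDs are allowed
+ ```
+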
1700
+ ### spec_min
1701
+
1702
+ Minimum mel spectrogram value used for normalization to [-1, 1]. Different mel bins can have different minimum values.
1703
+
1704
+ <table><tbody>
1705
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1706
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1707
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1708
+ <tr><td align="center"><b>type</b></td><td>List[float]</td>
1709
+ <tr><td align="center"><b>default</b></td><td>[-5.0]</td>
1710
+ </tbody></table>
1711
+
1712
+ ### spec_max
1713
+
1714
+ Maximum mel spectrogram value used for normalization to [-1, 1]. Different mel bins can have different maximum values.
1715
+
1716
+ <table><tbody>
1717
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1718
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1719
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1720
+ <tr><td align="center"><b>type</b></td><td>List[float]</td>
1721
+ <tr><td align="center"><b>default</b></td><td>[0.0]</td>
1722
+ </tbody></table>
1723
+
1724
+ ### T_start
1725
+
1726
+ The starting value of time $t$ in the Rectified Flow ODE, which applies to $t \in (T_{start}, 1)$.
1727
+
1728
+ <table><tbody>
1729
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1730
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1731
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1732
+ <tr><td align="center"><b>type</b></td><td>float</td>
1733
+ <tr><td align="center"><b>default</b></td><td>0.4</td>
1734
+ </tbody></table>
1735
+
1736
+ ### T_start_infer
1737
+
1738
+ The starting value of time $t$ in the ODE during shallow Rectified Flow inference. Normally set to the same value as [T_start](#T_start).
1739
+
1740
+ <table><tbody>
1741
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1742
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1743
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1744
+ <tr><td align="center"><b>type</b></td><td>float</td>
1745
+ <tr><td align="center"><b>default</b></td><td>0.4</td>
1746
+ <tr><td align="center"><b>constraints</b></td><td>Should be no less than T_start.</td>
1747
+ </tbody></table>
1748
+
1749
+ ### task_cls
1750
+
1751
+ Task trainer class name.
1752
+
1753
+ <table><tbody>
1754
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1755
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1756
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
1757
+ <tr><td align="center"><b>type</b></td><td>str</td>
1758
+ </tbody></table>
1759
+
1760
+ ### tension_logit_max
1761
+
1762
+ Maximum tension logit value used for normalization to [-1, 1]. The logit is the inverse of the sigmoid function:
1763
+
1764
+ $$
1765
+ f(x) = \ln\frac{x}{1-x}
1766
+ $$
1767
+
1768
+ <table><tbody>
1769
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1770
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1771
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1772
+ <tr><td align="center"><b>type</b></td><td>float</td>
1773
+ <tr><td align="center"><b>default</b></td><td>10.0</td>
1774
+ </tbody></table>
1775
+
1776
+ ### tension_logit_min
1777
+
1778
+ Minimum tension logit value used for normalization to [-1, 1]. The logit is the inverse of the sigmoid function:
1779
+
1780
+ $$
1781
+ f(x) = \ln\frac{x}{1-x}
1782
+ $$
1783
+
1784
+ <table><tbody>
1785
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1786
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
1787
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1788
+ <tr><td align="center"><b>type</b></td><td>float</td>
1789
+ <tr><td align="center"><b>default</b></td><td>-10.0</td>
1790
+ </tbody></table>
1791
+
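+ A brief sketch of the logit transform and a subsequent linear mapping to [-1, 1] between `tension_logit_min` and `tension_logit_max`; the linear mapping is an assumption for illustration, not the repository's exact normalization code:
+
+ ```python
+ import math
+
+ def logit(x):
+     # Inverse of the sigmoid: f(x) = ln(x / (1 - x)), defined for 0 < x < 1.
+     return math.log(x / (1.0 - x))
+
+ def normalize_tension(tension, logit_min=-10.0, logit_max=10.0):
+     # Map a tension value in (0, 1) to [-1, 1] through its clipped logit (sketch).
+     v = min(max(logit(tension), logit_min), logit_max)
+     return 2.0 * (v - logit_min) / (logit_max - logit_min) - 1.0
+
+ print(normalize_tension(0.5))   # 0.0, since logit(0.5) = 0
+ print(normalize_tension(0.99))  # ~0.46
+ ```
+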
1792
+ ### tension_smooth_width
1793
+
1794
+ Length of sinusoidal smoothing convolution kernel (in seconds) on extracted tension curve.
1795
+
1796
+ <table><tbody>
1797
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1798
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
1799
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1800
+ <tr><td align="center"><b>type</b></td><td>float</td>
1801
+ <tr><td align="center"><b>default</b></td><td>0.12</td>
1802
+ </tbody></table>
1803
+
1804
+ ### test_prefixes
1805
+
1806
+ List of data item names or name prefixes for the validation set. For each string `s` in the list:
1807
+
1808
+ - If `s` equals an actual item name, that item is added to the validation set.
1809
+ - If `s` does not equal any item name, all items whose names start with `s` are added to the validation set.
1810
+
1811
+ For multi-speaker combined datasets, `ds_id:name_prefix` can be used to apply the rules above to data from a specific sub-dataset, where `ds_id` is the dataset index.
1812
+
1813
+ <table><tbody>
1814
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1815
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
1816
+ <tr><td align="center"><b>customizability</b></td><td>required</td>
1817
+ <tr><td align="center"><b>type</b></td><td>list</td>
1818
+ </tbody></table>
1819
+
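+ A compact sketch of the matching rules listed above (hypothetical helper with made-up item names, not the binarizer's actual implementation):
+
+ ```python
+ def select_validation_items(item_names, test_prefixes):
+     # Exact name match first; otherwise treat the string as a name prefix (sketch).
+     selected = set()
+     for s in test_prefixes:
+         if s in item_names:
+             selected.add(s)
+         else:
+             selected.update(name for name in item_names if name.startswith(s))
+     return sorted(selected)
+
+ items = ["hymn_001", "hymn_002", "psalm_010"]
+ print(select_validation_items(items, ["psalm_010", "hymn_00"]))
+ # ['hymn_001', 'hymn_002', 'psalm_010']
+ ```
+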
1820
+ ### time_scale_factor
1821
+
1822
+ The scale factor multiplied with the time $t$ of Rectified Flow before it is embedded into the model.
1823
+
1824
+ <table><tbody>
1825
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1826
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
1827
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1828
+ <tr><td align="center"><b>type</b></td><td>float</td>
1829
+ <tr><td align="center"><b>default</b></td><td>1000</td>
1830
+ </tbody></table>
1831
+
1832
+ ### timesteps
1833
+
1834
+ Total number of DDPM steps.
1835
+
1836
+ <table><tbody>
1837
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1838
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
1839
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1840
+ <tr><td align="center"><b>type</b></td><td>int</td>
1841
+ <tr><td align="center"><b>default</b></td><td>1000</td>
1842
+ </tbody></table>
1843
+
1844
+ ### use_breathiness_embed
1845
+
1846
+ Whether to accept and embed breathiness values into the model.
1847
+
1848
+ <table><tbody>
1849
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1850
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
1851
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1852
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1853
+ <tr><td align="center"><b>default</b></td><td>false</td>
1854
+ </tbody></table>
1855
+
1856
+ ### use_energy_embed
1857
+
1858
+ Whether to accept and embed energy values into the model.
1859
+
1860
+ <table><tbody>
1861
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1862
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
1863
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1864
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1865
+ <tr><td align="center"><b>default</b></td><td>false</td>
1866
+ </tbody></table>
1867
+
1868
+ ### use_glide_embed
1869
+
1870
+ Whether to accept and embed glide types in melody encoder.
1871
+
1872
+ <table><tbody>
1873
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1874
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
1875
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1876
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1877
+ <tr><td align="center"><b>default</b></td><td>false</td>
1878
+ <tr><td align="center"><b>constraints</b></td><td>Only take affects when melody encoder is enabled.</td>
1879
+ </tbody></table>
1880
+
1881
+ ### use_key_shift_embed
1882
+
1883
+ Whether to embed key shifting values introduced by random pitch shifting augmentation.
1884
+
1885
+ <table><tbody>
1886
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1887
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
1888
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1889
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1890
+ <tr><td align="center"><b>default</b></td><td>false</td>
1891
+ <tr><td align="center"><b>constraints</b></td><td>Must be true if random pitch shifting is enabled.</td>
1892
+ </tbody></table>
1893
+
1894
+ ### use_melody_encoder
1895
+
1896
+ Whether to enable melody encoder for the pitch predictor.
1897
+
1898
+ <table><tbody>
1899
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
1900
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
1901
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1902
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1903
+ <tr><td align="center"><b>default</b></td><td>false</td>
1904
+ </tbody></table>
1905
+
1906
+ ### use_pos_embed
1907
+
1908
+ Whether to use SinusoidalPositionalEmbedding in FastSpeech2 encoder.
1909
+
1910
+ <table><tbody>
1911
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1912
+ <tr><td align="center"><b>scope</b></td><td>nn</td>
1913
+ <tr><td align="center"><b>customizability</b></td><td>not recommended</td>
1914
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1915
+ <tr><td align="center"><b>default</b></td><td>true</td>
1916
+ </tbody></table>
1917
+
1918
+ ### use_shallow_diffusion
1919
+
1920
+ Whether to use shallow diffusion.
1921
+
1922
+ <table><tbody>
1923
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1924
+ <tr><td align="center"><b>scope</b></td><td>nn, inference</td>
1925
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1926
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1927
+ <tr><td align="center"><b>default</b></td><td>false</td>
1928
+ </tbody></table>
1929
+
1930
+ ### use_speed_embed
1931
+
1932
+ Whether to embed speed values introduced by random time stretching augmentation.
1933
+
1934
+ <table><tbody>
1935
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1936
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
1937
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1938
+ <tr><td align="center"><b>default</b></td><td>false</td>
1939
+ <tr><td align="center"><b>constraints</b></td><td>Must be true if random time stretching is enabled.</td>
1940
+ </tbody></table>
1941
+
1942
+ ### use_spk_id
1943
+
1944
+ Whether to embed the speaker ID from a multi-speaker dataset.
1945
+
1946
+ <table><tbody>
1947
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
1948
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
1949
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1950
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1951
+ <tr><td align="center"><b>default</b></td><td>false</td>
1952
+ </tbody></table>
1953
+
1954
+ ### use_tension_embed
1955
+
1956
+ Whether to accept and embed tension values into the model.
1957
+
1958
+ <table><tbody>
1959
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1960
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
1961
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1962
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1963
+ <tr><td align="center"><b>default</b></td><td>false</td>
1964
+ </tbody></table>
1965
+
1966
+ ### use_voicing_embed
1967
+
1968
+ Whether to accept and embed voicing values into the model.
1969
+
1970
+ <table><tbody>
1971
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1972
+ <tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
1973
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1974
+ <tr><td align="center"><b>type</b></td><td>boolean</td>
1975
+ <tr><td align="center"><b>default</b></td><td>false</td>
1976
+ </tbody></table>
1977
+
1978
+ ### val_check_interval
1979
+
1980
+ Interval (in number of training steps) between validation checks.
1981
+
1982
+ <table><tbody>
1983
+ <tr><td align="center"><b>visibility</b></td><td>all</td>
1984
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1985
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
1986
+ <tr><td align="center"><b>type</b></td><td>int</td>
1987
+ <tr><td align="center"><b>default</b></td><td>2000</td>
1988
+ </tbody></table>
1989
+
1990
+ ### val_with_vocoder
1991
+
1992
+ Whether to load and use the vocoder to generate audio during validation. Validation audio will not be available if this option is disabled.
1993
+
1994
+ <table><tbody>
1995
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
1996
+ <tr><td align="center"><b>scope</b></td><td>training</td>
1997
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
1998
+ <tr><td align="center"><b>type</b></td><td>bool</td>
1999
+ <tr><td align="center"><b>default</b></td><td>true</td>
2000
+ </tbody></table>
2001
+
2002
+ ### variances_prediction_args
2003
+
2004
+ Arguments for prediction of variance parameters other than pitch, like energy, breathiness, etc.
2005
+
2006
+ <table><tbody>
2007
+ <tr><td align="center"><b>type</b></td><td>dict</td>
2008
+ </tbody></table>
2009
+
2010
+ ### variances_prediction_args.backbone_args
2011
+
2012
+ Equivalent to [backbone_args](#backbone_args) but only for the multi-variance predictor.
2013
+
2014
+ <table><tbody>
2015
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
2016
+ </tbody></table>
2017
+
2018
+ ### variances_prediction_args.backbone_type
2019
+
2020
+ Equivalent to [backbone_type](#backbone_type) but only for the multi-variance predictor model. If not set, the root backbone type is used.
2021
+
2022
+ <table><tbody>
2023
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
2024
+ <tr><td align="center"><b>default</b></td><td>wavenet</td>
2025
+ </tbody></table>
2026
+
2027
+ ### variances_prediction_args.total_repeat_bins
2028
+
2029
+ Total number of repeating bins in the multi-variance predictor. Repeating bins are distributed evenly among the variance parameters.
2030
+
2031
+ <table><tbody>
2032
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
2033
+ <tr><td align="center"><b>scope</b></td><td>nn, inference</td>
2034
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
2035
+ <tr><td align="center"><b>type</b></td><td>int</td>
2036
+ <tr><td align="center"><b>default</b></td><td>48</td>
2037
+ </tbody></table>
2038
+
2039
+ ### vocoder
2040
+
2041
+ The vocoder class name.
2042
+
2043
+ <table><tbody>
2044
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
2045
+ <tr><td align="center"><b>scope</b></td><td>preprocessing, training, inference</td>
2046
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
2047
+ <tr><td align="center"><b>type</b></td><td>str</td>
2048
+ <tr><td align="center"><b>default</b></td><td>NsfHifiGAN</td>
2049
+ </tbody></table>
2050
+
2051
+ ### vocoder_ckpt
2052
+
2053
+ Path of the vocoder model.
2054
+
2055
+ <table><tbody>
2056
+ <tr><td align="center"><b>visibility</b></td><td>acoustic</td>
2057
+ <tr><td align="center"><b>scope</b></td><td>preprocessing, training, inference</td>
2058
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
2059
+ <tr><td align="center"><b>type</b></td><td>str</td>
2060
+ <tr><td align="center"><b>default</b></td><td>checkpoints/nsf_hifigan/model</td>
2061
+ </tbody></table>
2062
+
2063
+ ### voicing_db_max
2064
+
2065
+ Maximum voicing value in dB used for normalization to [-1, 1].
2066
+
2067
+ <table><tbody>
2068
+ <tr><td align="center"><b>visibility</b></td><td>variance</td>
2069
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
2070
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
2071
+ <tr><td align="center"><b>type</b></td><td>float</td>
2072
+ <tr><td align="center"><b>default</b></td><td>-20.0</td>
2073
+ </tbody></table>
2074
+
2075
+ ### voicing_db_min
2076
+
2077
+ Minimum voicing value in dB used for normalization to [-1, 1].
2078
+
2079
+ <table><tbody>
2080
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
2081
+ <tr><td align="center"><b>scope</b></td><td>inference</td>
2082
+ <tr><td align="center"><b>customizability</b></td><td>recommended</td>
2083
+ <tr><td align="center"><b>type</b></td><td>float</td>
2084
+ <tr><td align="center"><b>default</b></td><td>-96.0</td>
2085
+ </tbody></table>
2086
+
2087
+ ### voicing_smooth_width
2088
+
2089
+ Length of sinusoidal smoothing convolution kernel (in seconds) on extracted voicing curve.
2090
+
2091
+ <table><tbody>
2092
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
2093
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
2094
+ <tr><td align="center"><b>customizability</b></td><td>normal</td>
2095
+ <tr><td align="center"><b>type</b></td><td>float</td>
2096
+ <tr><td align="center"><b>default</b></td><td>0.12</td>
2097
+ </tbody></table>
2098
+
2099
+ ### win_size
2100
+
2101
+ Window size for mel or feature extraction.
2102
+
2103
+ <table><tbody>
2104
+ <tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
2105
+ <tr><td align="center"><b>scope</b></td><td>preprocessing</td>
2106
+ <tr><td align="center"><b>customizability</b></td><td>reserved</td>
2107
+ <tr><td align="center"><b>type</b></td><td>int</td>
2108
+ <tr><td align="center"><b>default</b></td><td>2048</td>
2109
+ </tbody></table>
docs/GettingStarted.md ADDED
@@ -0,0 +1,164 @@
1
+ # Getting Started
2
+
3
+ ## Installation
4
+
5
+ ### Environments and dependencies
6
+
7
+ DiffSinger requires Python 3.8 or later. We strongly recommend you create a virtual environment via Conda or venv before installing dependencies.
8
+
9
+ 1. Install the latest PyTorch following the [official instructions](https://pytorch.org/get-started/locally/) according to your OS and hardware.
10
+
11
+ 2. Install other dependencies via the following command:
12
+
13
+ ```bash
14
+ pip install -r requirements.txt
15
+ ```
16
+
17
+ ### Materials and assets
18
+
19
+ Some essential materials and assets are needed before continuing with this repository. See [materials for training and using models](BestPractices.md#materials-for-training-and-using-models) for detailed instructions.
20
+
21
+ ## Configuration
22
+
23
+ Every model needs a configuration file to run preprocessing, training, inference and deployment. Templates of configuration files are in [configs/templates](../configs/templates). Please **copy** the templates to your own data directory before you edit them.
24
+
25
+ Before you continue, it is highly recommended to read through [Best Practices](BestPractices.md), which is a more detailed tutorial on how to configure your experiments.
26
+
27
+ For more details about configurable parameters, see [Configuration Schemas](ConfigurationSchemas.md).
28
+
29
+ > Tip: to see which parameters are required or recommended to edit, search by _customizability_ in the configuration schemas.
30
+
31
+ ## Preprocessing
32
+
33
+ Raw data pieces and transcriptions should be binarized into dataset files before training. Before doing this step, please ensure all required configurations like `raw_data_dir` and `binary_data_dir` are set properly, and all your desired functionalities and features are enabled and configured.
34
+
35
+ Assume that you have a configuration file called `my_config.yaml`. Run:
36
+
37
+ ```bash
38
+ python scripts/binarize.py --config my_config.yaml
39
+ ```
40
+
41
+ Preprocessing can be accelerated through multiprocessing. See [binarization_args.num_workers](ConfigurationSchemas.md#binarization_args.num_workers) for more explanations.
42
+
43
+ ## Training
44
+
45
+ Assume that you have a configuration file called `my_config.yaml` and the name of your model is `my_experiment`. Run:
46
+
47
+ ```bash
48
+ python scripts/train.py --config my_config.yaml --exp_name my_experiment --reset
49
+ ```
50
+
51
+ Checkpoints will be saved in the `checkpoints/my_experiment/` directory. If you interrupt the program and run the above command again, training resumes automatically from the latest checkpoint.
52
+
53
+ For more suggestions related to training performance, see [performance tuning](BestPractices.md#performance-tuning).
54
+
55
+ ### TensorBoard
56
+
57
+ Run the following command to start TensorBoard:
58
+
59
+ ```bash
60
+ tensorboard --logdir checkpoints/
61
+ ```
62
+
63
+ > NOTICE
64
+ >
65
+ > If you are training a model with multiple GPUs (DDP), please add the `--reload_multifile=true` option when launching TensorBoard; otherwise it may not update properly.
66
+
67
+ ## Inference
68
+
69
+ Inference of DiffSinger is based on DS files. Assume that you have a DS file named `my_song.ds` and your model is named `my_experiment`.
70
+
71
+ If your model is a variance model, run:
72
+
73
+ ```bash
74
+ python scripts/infer.py variance my_song.ds --exp my_experiment
75
+ ```
76
+
77
+ or run
78
+
79
+ ```bash
80
+ python scripts/infer.py variance --help
81
+ ```
82
+
83
+ for more configurable options.
84
+
85
+ If your model is an acoustic model, run:
86
+
87
+ ```bash
88
+ python scripts/infer.py acoustic my_song.ds --exp my_experiment
89
+ ```
90
+
91
+ or run
92
+
93
+ ```bash
94
+ python scripts/infer.py acoustic --help
95
+ ```
96
+
97
+ for more configurable options.
98
+
99
+ ## Deployment
100
+
101
+ DiffSinger uses [ONNX](https://onnx.ai/) as the deployment format.
102
+
103
+ Due to TorchScript issues, exporting to ONNX currently requires PyTorch **1.13**. Please ensure the correct dependencies through the following steps:
104
+
105
+ 1. Create a new separate environment for exporting ONNX.
106
+
107
+ 2. Install PyTorch 1.13 following the [official instructions](https://pytorch.org/get-started/previous-versions/). A CPU-only version is enough.
108
+
109
+ 3. Install other dependencies via the following command:
110
+
111
+ ```bash
112
+ pip install -r requirements-onnx.txt
113
+ ```
114
+
115
+ Assume that you have a model named `my_experiment`.
116
+
117
+ If your model is a variance model, run:
118
+
119
+ ```bash
120
+ python scripts/export.py variance --exp my_experiment
121
+ ```
122
+
123
+ or run
124
+
125
+ ```bash
126
+ python scripts/export.py variance --help
127
+ ```
128
+
129
+ for more configurable options.
130
+
131
+ If your model is an acoustic model, run:
132
+
133
+ ```bash
134
+ python scripts/export.py acoustic --exp my_experiment
135
+ ```
136
+
137
+ or run
138
+
139
+ ```bash
140
+ python scripts/export.py acoustic --help
141
+ ```
142
+
143
+ for more configurable options.
144
+
145
+ To export an NSF-HiFiGAN vocoder checkpoint, run:
146
+
147
+ ```bash
148
+ python scripts/export.py nsf-hifigan --config CONFIG --ckpt CKPT
149
+ ```
150
+
151
+ where `CONFIG` is a configuration file with the same mel parameters as the vocoder (`configs/acoustic.yaml` works for most cases) and `CKPT` is the path of the checkpoint to be exported.
152
+
153
+ For more configurable options, run
154
+
155
+ ```bash
156
+ python scripts/export.py nsf-hifigan --help
157
+ ```
158
+
159
+ ## Other utilities
160
+
161
+ There are other useful CLI tools in the [scripts/](../scripts) directory not mentioned above:
162
+
163
+ - drop_spk.py - delete speaker embeddings from checkpoints (for data security reasons when distributing models)
164
+ - vocoder.py - bypass the acoustic model and only run the vocoder on given mel-spectrograms
docs/resources/arch-acoustic.drawio ADDED
The diff for this file is too large to render. See raw diff
 
docs/resources/arch-overview.drawio ADDED
@@ -0,0 +1,123 @@
1
+ <mxfile host="Electron" modified="2023-07-06T16:23:00.268Z" agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/21.5.1 Chrome/112.0.5615.204 Electron/24.6.0 Safari/537.36" etag="Jj5RvJnN6RNPHxq1C0K-" version="21.5.1" type="device">
2
+ <diagram name="第 1 页" id="YZcYkKBV4jQxwUdL2UfG">
3
+ <mxGraphModel dx="1270" dy="914" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
4
+ <root>
5
+ <mxCell id="0" />
6
+ <mxCell id="1" parent="0" />
7
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-8" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-1" target="FjPlvFc2R7vJgaK0KZA3-1" edge="1">
8
+ <mxGeometry relative="1" as="geometry" />
9
+ </mxCell>
10
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-9" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-1" target="FjPlvFc2R7vJgaK0KZA3-3" edge="1">
11
+ <mxGeometry relative="1" as="geometry" />
12
+ </mxCell>
13
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-10" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-1" target="FjPlvFc2R7vJgaK0KZA3-6" edge="1">
14
+ <mxGeometry relative="1" as="geometry" />
15
+ </mxCell>
16
+ <mxCell id="gMzeuS7BrVqeg4yIgOCW-1" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;font-size: 20px;&quot;&gt;Variance Model&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;strokeWidth=1.5;fillColor=#dae8fc;strokeColor=#6c8ebf;" parent="1" vertex="1">
17
+ <mxGeometry x="232.97" y="720" width="300" height="50" as="geometry" />
18
+ </mxCell>
19
+ <mxCell id="gMzeuS7BrVqeg4yIgOCW-3" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;font-size: 16px;&quot;&gt;Phoneme&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
20
+ <mxGeometry x="267.97" y="800" width="80" height="30" as="geometry" />
21
+ </mxCell>
22
+ <mxCell id="gMzeuS7BrVqeg4yIgOCW-4" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.25;entryY=1;entryDx=0;entryDy=0;strokeWidth=1;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-3" target="gMzeuS7BrVqeg4yIgOCW-1" edge="1">
23
+ <mxGeometry width="50" height="50" relative="1" as="geometry">
24
+ <mxPoint x="562.97" y="620" as="sourcePoint" />
25
+ <mxPoint x="612.97" y="570" as="targetPoint" />
26
+ </mxGeometry>
27
+ </mxCell>
28
+ <mxCell id="gMzeuS7BrVqeg4yIgOCW-6" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;font-size: 16px;&quot;&gt;Word&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
29
+ <mxGeometry x="342.97" y="800" width="80" height="30" as="geometry" />
30
+ </mxCell>
31
+ <mxCell id="gMzeuS7BrVqeg4yIgOCW-7" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeWidth=1;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-6" target="gMzeuS7BrVqeg4yIgOCW-1" edge="1">
32
+ <mxGeometry width="50" height="50" relative="1" as="geometry">
33
+ <mxPoint x="657.97" y="635" as="sourcePoint" />
34
+ <mxPoint x="402.97" y="785" as="targetPoint" />
35
+ </mxGeometry>
36
+ </mxCell>
37
+ <mxCell id="gMzeuS7BrVqeg4yIgOCW-8" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;font-size: 16px;&quot;&gt;MIDI&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
38
+ <mxGeometry x="417.97" y="800" width="80" height="30" as="geometry" />
39
+ </mxCell>
40
+ <mxCell id="gMzeuS7BrVqeg4yIgOCW-9" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.75;entryY=1;entryDx=0;entryDy=0;strokeWidth=1;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-8" target="gMzeuS7BrVqeg4yIgOCW-1" edge="1">
41
+ <mxGeometry width="50" height="50" relative="1" as="geometry">
42
+ <mxPoint x="717.97" y="635" as="sourcePoint" />
43
+ <mxPoint x="462.97" y="785" as="targetPoint" />
44
+ </mxGeometry>
45
+ </mxCell>
46
+ <mxCell id="gMzeuS7BrVqeg4yIgOCW-10" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;font-size: 20px;&quot;&gt;Acoustic Model&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;strokeWidth=1.5;fillColor=#dae8fc;strokeColor=#6c8ebf;" parent="1" vertex="1">
47
+ <mxGeometry x="232.97" y="580" width="300" height="50" as="geometry" />
48
+ </mxCell>
49
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-1" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;font-size: 16px;&quot;&gt;Duration&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
50
+ <mxGeometry x="268.97" y="660" width="78" height="30" as="geometry" />
51
+ </mxCell>
52
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-2" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;strokeWidth=1;entryX=0.25;entryY=1;entryDx=0;entryDy=0;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-1" target="gMzeuS7BrVqeg4yIgOCW-10" edge="1">
53
+ <mxGeometry width="50" height="50" relative="1" as="geometry">
54
+ <mxPoint x="442.97" y="670" as="sourcePoint" />
55
+ <mxPoint x="312.97" y="630" as="targetPoint" />
56
+ </mxGeometry>
57
+ </mxCell>
58
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-3" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;font-size: 16px;&quot;&gt;Pitch&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
59
+ <mxGeometry x="343.97" y="660" width="78" height="30" as="geometry" />
60
+ </mxCell>
61
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-4" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;strokeWidth=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-3" target="gMzeuS7BrVqeg4yIgOCW-10" edge="1">
62
+ <mxGeometry width="50" height="50" relative="1" as="geometry">
63
+ <mxPoint x="526.97" y="680" as="sourcePoint" />
64
+ <mxPoint x="391.97" y="630" as="targetPoint" />
65
+ </mxGeometry>
66
+ </mxCell>
67
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-5" value="" style="endArrow=classic;html=1;rounded=0;exitX=0;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.75;entryDx=0;entryDy=0;edgeStyle=orthogonalEdgeStyle;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-3" target="gMzeuS7BrVqeg4yIgOCW-10" edge="1">
68
+ <mxGeometry width="50" height="50" relative="1" as="geometry">
69
+ <mxPoint x="432.97" y="710" as="sourcePoint" />
70
+ <mxPoint x="482.97" y="660" as="targetPoint" />
71
+ <Array as="points">
72
+ <mxPoint x="213" y="815" />
73
+ <mxPoint x="213" y="618" />
74
+ <mxPoint x="233" y="618" />
75
+ </Array>
76
+ </mxGeometry>
77
+ </mxCell>
78
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-6" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;&quot;&gt;Variance Parameters&lt;br&gt;&lt;font style=&quot;font-size: 10px;&quot;&gt;(energy, breathiness, etc.)&lt;/font&gt;&lt;br&gt;&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
79
+ <mxGeometry x="392.97" y="660" width="130" height="30" as="geometry" />
80
+ </mxCell>
81
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-7" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;strokeWidth=1;entryX=0.75;entryY=1;entryDx=0;entryDy=0;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-6" target="gMzeuS7BrVqeg4yIgOCW-10" edge="1">
82
+ <mxGeometry width="50" height="50" relative="1" as="geometry">
83
+ <mxPoint x="625.97" y="700" as="sourcePoint" />
84
+ <mxPoint x="481.97" y="640" as="targetPoint" />
85
+ </mxGeometry>
86
+ </mxCell>
87
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-11" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;&quot;&gt;Transformation Parameters&lt;br&gt;&lt;font style=&quot;font-size: 10px;&quot;&gt;(gender &amp;amp; velocity)&lt;/font&gt;&lt;br&gt;&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
88
+ <mxGeometry x="521.94" y="650" width="92.06" height="50" as="geometry" />
89
+ </mxCell>
90
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-13" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=1;entryY=0.75;entryDx=0;entryDy=0;edgeStyle=orthogonalEdgeStyle;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-11" target="gMzeuS7BrVqeg4yIgOCW-10" edge="1">
91
+ <mxGeometry width="50" height="50" relative="1" as="geometry">
92
+ <mxPoint x="412.97" y="690" as="sourcePoint" />
93
+ <mxPoint x="462.97" y="640" as="targetPoint" />
94
+ </mxGeometry>
95
+ </mxCell>
96
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-22" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-18" target="FjPlvFc2R7vJgaK0KZA3-21" edge="1">
97
+ <mxGeometry relative="1" as="geometry" />
98
+ </mxCell>
99
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-18" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;font-size: 16px;&quot;&gt;Mel-spectrogram&lt;br&gt;&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
100
+ <mxGeometry x="315.97" y="520" width="134" height="30" as="geometry" />
101
+ </mxCell>
102
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-19" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-10" target="FjPlvFc2R7vJgaK0KZA3-18" edge="1">
103
+ <mxGeometry width="50" height="50" relative="1" as="geometry">
104
+ <mxPoint x="532.97" y="570" as="sourcePoint" />
105
+ <mxPoint x="582.97" y="520" as="targetPoint" />
106
+ </mxGeometry>
107
+ </mxCell>
108
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-33" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-21" target="FjPlvFc2R7vJgaK0KZA3-32" edge="1">
109
+ <mxGeometry relative="1" as="geometry" />
110
+ </mxCell>
111
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-21" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;font-size: 20px;&quot;&gt;Vocoder&lt;/font&gt;" style="rounded=1;whiteSpace=wrap;html=1;strokeWidth=1.5;fillColor=#e1d5e7;strokeColor=#9673a6;" parent="1" vertex="1">
112
+ <mxGeometry x="292.94" y="440.13" width="180.06" height="50" as="geometry" />
113
+ </mxCell>
114
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-32" value="" style="shape=image;verticalLabelPosition=bottom;labelBackgroundColor=default;verticalAlign=top;aspect=fixed;imageAspect=0;image=data:image/png,iVBORw0KGgoAAAANSUhEUgAABJIAAACqCAYAAAD2i2XMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABcMklEQVR4nO3dd3hTdRcH8G+694IOaCmlUChllL03ZSsoIgguVNwiuJAtsp0oLkBEX1QQFBBlyp5llNEWyuhkFVq69877R2lI2uwmuRnfz/O8z9skN/eeSpPcnHt+54jEYrEYREREREREREREKlgJHQAREREREREREZkGJpKIiIiIiIiIiEgtTCQREREREREREZFamEgiIiIiIiIiIiK1MJFERERERERERERqYSKJiIiIiMjAissqEXM7BxygTEREpkYk5qcXEREREZHBiMViNJu1S3I7ZfkoAaMhIiLSDCuSiIiIiIgMKC2vVOb2H2duChQJERGR5phIIiIiIiIyoN2X7srcnrk1VqBIiIiINMdEEhERERGRAX38b5zQIRAREWmNiSQiIiIiIiIiIlILE0lERERERAZy7V6+0CEQERHVCxNJREREREQG8vpv5+TefyuryMCREBERaYeJJCIiIiIiA0nKKJR7/+XUPANHQkREpB0mkoiIiIiI9Gx9ZAqCZu5U+PhrCiqViIiIjA0TSUREREREejZ/+2WV25xMyDBAJERERPXDRBIRERERkR4dvJqm1nYXbuXoNxAiIiIdEInFYrHQQRARERERmStlS9pqe7JzAHKKy7Hqmc6wthLpMSoi85NbXI4dMakY0bYRvJzthA6HyGyxIomIiIiIyEj8ee429sWl4eDVdKFDITI57266iDnbLuHl9VFCh0Jk1phIMlKnkzLxxX/XUF5ZJXQoREREZuXTPVfx6q9RqOBnrMGk5hRj5pYYXLuXL3QoBpdfUq7V815ez79RIk0deJCAPXcjW+BIiMwbE0lGasKaU/jmYAI2nrkpdChERGRCuGJdsfLKKoxfHYnvDydi7+U07Ii5K3RIFuON38/jj7O3MOLro0KHYnBxqXlaP7ftgr06jITI8ojFYuyISUVyRqFa2xeUVqCqSoxX1kfhxV/OoqpKjNvZRXqOksj02AgdACmn7pseERGZr3+jU1FcXonR4Y0RdzcPIT4uuHovH50DPVElFiO7qBxpeSVYdzwZMXdysWNqHzjYWqOySoyMglIs3BGHl/sGo0MTD6F/FUHN+/sSziRnSW7nl1YIGI1luXK3OplSZYF5ztnbYrV+bkm55hVJi3fEwdbGCh8OD9X6uETmoKKyCvuvpOGtDRcAACnLRynd/tDVdLzwy1mZ+4Jn75L8fOC9/vB2tYebg63ugyUyMUwkGYnC0gocvJqOAa284Sr15sQLy0RElq24rBJTN1afBM/4K0bmsR7BXjiVlFXnOWuPJaGFjyumbjyP8srqD5KdMXdxYd4QZBaWopG7I5ztLe8U4I+zt2Run03OwrM9mgoUjXkrraiEvY210GEYhcT79bsomFlQigYu9gofv5dbgq/2X8dT3QJxI7MQa48nAwCmDmoBJzvLe50T1Wg5d7dM8vqZtaex5rnOcLCxRpVYjG8PJeBuTglKKioxdVALLNoRp3R/g784AnsbK1xbPELPkRMZP366GIkP/orGrth7GBTqg3WTuwodDhERGYlOi/YpfExeEgkAPv/vutz7Oz7YVyN3B0TOGlz/4ExIlZxSmH+iU7FyYkcBojFv3x1KwGd7r+G3l7qjT0hDnE7KRGkFe/1o67l1Z7Dz7b5yH/v99A3M2XYJQN1Eac2f/L3cEny65yoe6+iPfi299RorkZDi02R7sNV+2z+ekIGw+XvR3NsZbRq745/oVMlj2y+mQh2lFVU4m5IFTyc7tPBxqXfMRKaKiSQjsSv2HgBwQgcREckoLq/U+T7v5pbgbEoWugZ56XzfxurI9ftCh2AxPtt7DQAw5+9YHPlgICasOSVwRLpV04dMJBKhpLwS9jZWEIlEcrfNLdau0ba0y6l5uJ9fCm/Xh1VJ0bdy8PYfF3AjU3HvlpFfH8PNrIePb71wR+XSHl24nV0EDyc7uFhg1SMJ505OMcavjlRr28T7hfWqFHxyVfVxDPF6IjJWfIcnIiIyUvqc3PnkqkhsfaMXOgV66u0YxiRPweSsgtIKfuElhcoqqvDJnqvILCiFm6MtPnq0DTp8/B/ySyvwz1u9MfrbExjWxhern+0i9/mX7+TqJI7kjEKZRNIza0+r7PElnUQylJuZRej32SE42VkjbuFwgx+fLM/ZlCxJYoeIDIdT24iIiIxUyJzdet3/v9HqlfKbs2ErLG+KmKEoq5YxFV/su4afjifj74upWB95A2uPJUkSOKO/PQEA2Hs5DSUKKgfn/n1JJ3HUNCuvoW2j+FVHEmVul6lYchiXmoeDV9PU3v+JxAwAQFGZ7ispiWrsjLmLxTvi8G90qlEkkUorKvHWhvPYHHVL9cZEZoKJJCPHMc5ERJap0gDjrX4+kYLUnGK9H8cY/Bl1W+79dyzk9xfKLQGqYnRFLBZj9ZEkmfv+PCf/7yh03h659yfpaPruR/9c1sl+lu++KqnOi0rJQsu5u7HyQLzC7UeuPIYXf4nC1Xuyiaz0vBIs+Ocypv9xASlKfsct527jfydTdBI7UX5JOZbuuoI3N5zH2uPJkkEUQtscdRs7Yu7WGYhBZM5Yy20ELqfWr+y5poGolZX89flERGR6Ir48YpDjnEnOwmMd/Q1yLCEdT8hQ+Fh6Xgl83BwMGI3l6PvpIZ3tq7SiEuduZKNLUy/Y2ej3WujaY0lYVSuJBAAJ6QV6Pa4yYrFYYS8mTbRf8B+8nO2QVVgGAPhy33X0b+mNdv7uyC+tgLujLXbG3MWf5x5WVySmFyLUz01y+5mfTuN6WvV/i8ikTJyeHSH3WO/9GQ0AGBLmi8YejvWOnSzXrawifLLnKnbE3BU6FIlj8fcRczsXJ5R8vhCZKyaSBJZ0vwCjVh7X+vlVVWIM+6q6LH/v9H5MJhERmYlkHVUyqDJ900W0C3BHc2/znD6TmlOM67Um+dR2M6uIiSQDKimvhIOttcbPm7U1FlvP3wEARLT2wWfjwuHpbKfr8JBTVIbFO6/Uez+Klrtp67O91zBjeKhO9lWTRKox5rsTkp8butgjo6BU5vGj1+9jRFs/HLyajnUnkiVJJABIy5Pdtob0e1iBiqV45ZVV2H3pHno089L4tSgWi3EtLR8tvF1gY615gnHLudto5OGAXs0bavxcMox7uSU6TUrryrM/nRE6BCLBcGmbwC7eylH6uKqFDVlFZYhPL0B8egFydDAZhIiMX05R
GdYdT8aSnXGY/scFuWPNyXSVV1ZhR4xhexcN/sIw1U9C6LX8ICb/fFbpNudvZhsoGgK0781Vk0QCgP1X0jF900UdRfSQWCxGh4X7dLKvw9d0Oynw+8OJqjfSgdpJJADYFHUL41dHYsr6KJxMzKzzeNL9AiTeL4D05czHpJJTqqw5moS3N17AiK+Pydx/I7MQvZYdwLrjyXKftzv2LkLn7cHwr47hAy2WFcWl5uG9P6Mx6cfTMvfnFpWr7B8lT35JOd7ZdFGjvlKWRJuWHWKxGG9uOK+HaIioPphIMhFisRhpeSVy7hcgGCIS1DubLmLhjjj8eKy6AezSXfW/ck7GYfWRRITM2Y23NhhH3wdLsXTXVaFDsCjH4jOQkK68SkwdR67rNlEDAD8c0T5ZU7tVQUWV/qYuCiHqhuKE66AvjmDwF0dkqrBypS5wqjpf3X+lOvGSWViGvy88TBjO234ZqbklWLgjDkB1kquorAJHrt/HiK+P4fXfz6P0QcJnm9Tz8kvK8dPxZBy+lo7SisoHMYjx1obz6PfpIUmyXrpHXM1FmfT8EoQv/A+DvjisPGg5vt4fj20X7uDFX6I0fq6xupyai5+OJ+O1X8+hqEz9Ju+pOcV4a8N5nHzQgH3+9kvo88kh5CuYoKnIwavpOKfkb4+IhMGlbUau5oN30Y4rWHciGcvGtsPEboFyt+WiNiLLcKjWVe61x5Px+oDmaOBir+AZZCqW7RYuoaHtciPSv8yCUpy/mYMTCRlo2sAJIT6uKK+qwobTN+HlZIdPxrUXOkSN/BOdin+iU5GyfFS995VXUg43B9t67aO4rBKzt8ViWBtffLVfceNpVfbFpaFNY3fJbX0khG9nFyElw3gbmCtqXp9TVKa0x9OFmzmSn6dvuojR4Y1RKRbjqFSyMD2/BN2WHFB6/B5LD2DVs50xc0sMrt6rTlb2CPbCH6/0xIVbOZL+Om9tuICewQ3w+X/XJM8Nnb8H0waHwNWh+uvR7exi9PnkIKZHtIS/hyNSMgsl5+DJGYXYFXsXz/cKgov9w69Td+Vc9K2hqx5XhnA8PgPvbr6IeY+EyTS0zvipFD+/0BU/HktGr+YN0CO4AYDq/x47olMxuXcQXB1scTIhA5PWVld57Yi5i4vzh2B95A0A1YMPXuzTTO1YTK1h/6ytsWjl64Ldl+7h8yfD0cTLSeiQiPSCiSQlTiZmYNv5O5g7KgzuTvU7SVFE3c+TdSeqS3qX7rwik0iKTKpbXkxE5qvmymptnRfvR/ySEbASiWDNXmkmSfoqvBDu5BSbbZ8kY5RTVIYLt3LQL8Rb6Wt2X1waXl6vvLqhXYA7PJ3sMKp9I12HafQ2n72FKX2D67WPNUeTsO3CHZmKFm1IV90cj9dP892P/42DvxE3rf7xmPwlaBPWnAIA9A1piLXPd4FYDKWJ6wlrIuv8nj8p2Le0e3kldZbUnUrKQs9lB+r0rOq8eL/M7bKKKny295rMfbezi/H+g4bhALA/Lg0pmYVIvF/d/2n7xTv4753+kselK5wqKqtw4VYO2ge448+o2/hq/3X8+lJ3tG70sGm5salJdj3zU3USqPZUtKgb2Wi34D8AwMoD8Zg5IhRlFVX4ct91ANWvJW9X+zrTCqWXi269oFki6a/z8iclGquNZ25Kfu776SEc/3AgAjydUFRWgbKKKng4yfZ1K6+sQnF5Zb0T4kSGxkSSEjXrpW1trLD08XYCRyPf21Jv8IqSUreyivDOpouY0jcYw9v6GSgyItKHjadvKnwsZM5uAMCOqX3Q1t9d4XZkfMRisV76vWgWg6CHtzijVh7HnZxidGnqib9e7yXz2M3MIhSUViDY21llEgkA5v59CQAwrM0IrZoNC+XK3Ty1v1Qr6q2iiwqP9HzFVSSauHTn4dK2ayoavGurpLwSv5xM0cu+DeFYfAZaz9uDKjEQPX8oCsoq8Nqv5+psdzYlG2chu5xp9dG6U/TUdTdXN//GB66my9yuaTq+IyYV644ny1RWtXjwmSxtxl8x+HdqH5SUV2Lv5Xto3cgNIT4uEIlEyCkqw5bzdzA6vDG8XQ1fYXzuRjae+OGkRs9ZXquKNr+0AvkqGqtfupOH+LR8BDV0hm2t96vKKjGOXE/HmeRs7L50FzcyTasaSZ7JP5/Fxpd7oOuS6sRl7IKhcJVKGg376iiS7hdi9bOd0cLHhRd0yGSYztmGgDYo+eKmD9INuMUq220rVl5ZhbKKKszeFouoG9l47be6H9REZFoW/BuncptHvtF+EiQJY1+c8I1ZaypfLZW8PoT6EpeaJ1kCJK/vTL/PDmHkymMInbdHo/22mLNbZimQsavdWFmZZrN2yb1f28bdYrEYNzILIRbX50xLlnSSYdEO1e/V2jhlBpXoNfMhwhf+h97LDyL2Tq7yJxi562n5eGvDBZyXSiIpUlhaXZUy468YTPvjIoauOIpms3Yh6X4Bpm+6iEU74jD55+pJYBWVVbiZWYToWznYfPYWcovq9ha6npaPTDnN0ZWJT8vH+NWRiHzQNL28snqZrKZJpPoYsuIoXpAzBOHHY0l48ZcorDqSaBZJJABISC/A9E0PL/xfT8vHjL+i8eWDZZVJD6rbXv31nNLBFxWVVbifr9m/NZE+sSJJgWPxsidiWYVl8NLDiFmRnM5G46TeyFNzZE9slZ3sSO9LLBaj7yeHUFhWYdQl0ESkvjVH1W8C+9upG3i6e6DJ9GOwdK/IuSJvaBtO38RHj4bB3sYy+yR1X3pAJz171DFypfwEilgsrlePHgB4bt0Zg/0eupBdWAZPFedXCekFCh9TNf1WkR+PJZlkk/XySpYOGpuhK46qvW1SRiFazq1bqTRIKoFwOTUPlVXiOhVNM7bEYMOU7sgsLMOK/dfRKdATf52rXvY1KNQH9jZW+OGZzpLt7+YWw8nOBu6OskumXvn1HJIzCjHxx1MI9naWJDIM7XhCBoJm7sTQMF94u9pj8WNtsfnsLUFi0bcTCQ8TwLtj72FzVPW/m7wlfjXVl7XP355cHYkLN3NYdU5Gg4kkBZbVOrnILtJPIkmeCqlR3gevpmN9ZIrc7V6vVWGUXVSGWdtiML5LE/Rs3gD3HlxdTVXQ+JCITMf1tHyNvvTM/fsS8krK8Xr/5kwmkdoqKsWw55mBwX1/OAFPdQ1E4v0CfH2gfokkU5NdpDqRpKg3XI20vBL4ujlodFxTTCKR5Wg+W34FXk0DawAyCaCDD6rhvt4fj2kRIcgoKEXPZQcBQJJYFovF+PlECpKl+hcJlUSS9t+DitzfDbwCRChrjz+s/p21NbbO4+NWRaJKLMbPk7tK+indzy+VLJvccv42E0lkFLi0TYGqWmvxFa3NN4T52y/XiaOgtAK7L92T2W7Z7ivYFXsPk2uViqr6Enkzs0jjUZxEZFi3szUv8f50zzU8teaUoO9fpFrtBrBC0naZENXPp3uuodOifRovUVHk4FXhl0qqS9lI+Ro5cpb0SFuop2Vk2krJEP7LOVmmFfuvIz4tH5dT8yT3PbUmEi3n7sbey2lG91qxdLW/ywHVvao
u3MxBh4X7MOTLI5j04ylJfyUA+PlECp796TQS0vXTh41IXUwkKVBZVTuRpJ/jlFVWabR9YVklpm68gHu5dauMtFlLnHS/AP0+OyTzBkVE2iurqMKeS3eRU1RW57G1x5Kw7cJttRI76XklkpOE8soqXFSj94I8p5OzcDo5S6vnkv5VVokl/TCMwUw5V0ctSVmFZp/Juvbab+d1sp8Xf4lChYbnF0KZ8VeM0v/u2y/ewdNSVRjyJKQV4HKq8fTZGfD5YaFDIAs2ZMVRHLzyMJl8KikLZRVV7JVqguLTC3AysW5ftGPxGYj48qjGF3+Ky4znwhWZPiaSFKiTSNLTceZuu6Txc/6NTsWu2LoZ7Kv35GemlRUkHU+oHk9bUm4aJ5xkvi6n5mL57qsoUDHtw9h9ue86XvvtPCb+KPvFJ+l+ARbvvIJ3NkXjqTWnFH7JE4vFEIvF6Lb0ACK+PIpzN7LQau5urDyYoHVMT605haoqViUZm5fXR6H57F04lcREn7EwRHWurqqOVNl6vn6j7A2pXMH7YXllFab9cVHl86+l5WPUyuP4/fQNtY53NoWvOTJv/4tU77VApk3d9zwA2Bx1C63n7zH4ECkyX0wkybHxzE0k1SpL3nhG9y86sViscUVSDVVl3tq6m1uMlQfiUWjiX+bJ9IxaeRyrjiRi+e4rQoeitZOJGVh1pLoh9pW7eTKP5RY/fM2eTs5C2Py9da7CV1WJ8fj3J/HMTw+TUE/8EAld5IBe+KXudBQSTlRKllFMaiNZKQaYEqSo0bauzdgSY5Dj6JO8ZR/KzFHz4tyqw+oPLiAiMlbKqjlrV7/P+Kv6M2H2NtnKY15oJG0xkSSHvMZnP59IwR9nbuJWVhH2x6XppOfIjpi7Wj9X1bDaPx9MA1C5n1q76bnsIL7cdx1tPtrLEZMkiN9O3cT5m6p7ZhijST/KX35xJjkL6bVeT2WVVVh3Ihkvr4/C+NWR6L38II7G38fFWzky0z105cj1+4hLzVO9IeldYWkFxq2KFDoMhSy5p9a7my/qdf+5ReVIy+Nna20VCr7IvL3xgtz7tZGcUYjdsXchFotx4EFjYn14fp3xLFUlIvN2/mYONp65iTnbYnEr6+GFkENX09Ft6QEcvX5fybOrp2Z2W3oAc7ZZ9rJ20g4TSRqYuTUWfT89hCnro3BYxQtTHfLWvKrr5xMpSh+f+/fDq3Lqzmuq/eVh6kbd9Gog0tTY708KHYLORKVkYfzqSLwqZ7z78t1XsS8uDWeSs3Anp7hOo3xde3KV+fx3NWX/UzCJ01j8cjJF6BAEo02vQU2EL/xPr/uvzVT6JPX79JDC5W31lVlQigGfHcLAzw/j9d/P48AV/SWRgOqkPRGRoczaGovfT9/E8+vO4NKdXJSUV+KFX87ifn4pnlt3BlmFZRj5tWwlbM1nw/8iU5BRUGoxE/NIt5hI0tK7my4KHYLapKe21fSf2Xz2Fp5ee0qmH0Ttkyv27SBNVVWJUV5ZhdScYmw8c1PtaVTmWlabXVhmVJUnhWWVOHRNv1+iSLVzKcZdcfeDhS/7OafGFDFT8dspw/ZJuXYvH5ujbmn8vNziclxT0OdRU5fuyDbdfu/PaJkli1PWR+nkOERExiQpoxCPfHMcg784InN/p0X7EFer3ULr+XswfnUkvtofb8gQyczYCB2AqcrWU48ifZCuSLp2Lx+dm3pKeidIL6Ex1eVEZDzGfHcCNzILkVdSnbCctTUWw9v4IcTXBVvP38E3kzqiU6Bnnef9fdF0msJq4mUj/MLyws9n8d87/dDS11XoUCzG3dxifL73Oib3CkK7AHe9LqvRBVNveF9fcXfz0Llp3fep+pI3yVHfjidkYnLvZgY73rCvjhrsWIo88s1xzBoRiv/i0vDpuPY4fI0VQkRkOe7k1J3sXVt5pRhnak30/TUyBc/2DNJTVGSOWJFkATILDX/ySpYp9k6uJIlUY8/le/jmYALu5BRj7Pcnse54cp3nnZVToZGWV1KvWPbHpeF0ku57DWkiykgrG+b9rfm0SNLetI0XseX8bTz67XEUlRl/kqbIwscD38lWfRKuDSEqvfZfscyG7st2X8W5G9l1rswTEZF887ZfximBz5vJtDCRZGG2nL+NhPQCuY9VWnCDVaq/vy+oV1W0cEcc/rt8D2KxGCcTM5CeXyJ36kSklj3EqiefncCU9VGYsOYUcg1QPSgWi/XepFeXTifrd9mqWCw22+WK2rh672FJ+aIdpjuVEKi+0mnu/7arjiTqZXJpam79kuPmjqcgRETCemrNKUTfykGlmX/Ok24wkWRhNpy+iYgv5V+hW30kycDRkDmZrkHfsFd+PYfn1p3BpB9Po+eyg9hyvu6UQVWTCRU5eDUdF27mSG6HL/wP6/Xc3Djmdi62njet5Xna9DFR1yPfHEfw7F2IT9NNzxNTJ/0FeeMZ02hoKW8Z1j/Rqei9/KBGr3VT1X3pAZ3v89/oVJ3vUx1CLKnThdq9joiISP/GfHcCbT/aiwtseUIqMJFUD81n70JeifbVDvKqMIgsxbH4DABQeNVjzjbNl19lFJTKbaQ6f/tljfeljt9P38CTq04avKGtLsz4K0btZujqunYvH8/+dBqXU6srcIasOIp/BPrybEyqTLDUYuGOuDr3fXcwAQAs4t+0dp+okwkZmLjmlEkmZaQT66bkjd85OZaISAjF5ZWY8VeM0GGQkWMiqRZN+rJUVokx+IsjeGvDea3G1sqrwiDjUFJeie0X7yBbR/2l9DXW2Jxp06fl6R9PK3yspLwSu2Lvqlzqlp5Xgjd+P4eTiRkQi8XYcPqmwilOc7ZdwtmUbPx5zjRfyyv2XdfZvu7nl2LYV0clCcIab2+8AADIKynHO5su4rAFTo0rNMGeQ6ZWYacPMbdzAFRX9ExaexqRSZnosHCfVvsSMgFlqCUK9T1O7SrUm1lFCrYkIiJ9K2XBA6nARFItr/56TqPt7+eXYkfMXeyIMc8rtDO3xEBsglfT62Pv5XsInbcH0/64iI6L9tX79//uUAJC5uzGoWvpSMko1FGUupNRUIoKM0h0bb94B9eULKVavDMOb/x+Hs//fEbpfmZvi8Wu2HuY9ONpjFsVidnbYvHEDydx9Pr9OhMuTN3qo/VbzlpUVoGt528jaOZOdF2yX+F2lVVivLI+Ctsu3MHkn88ysWpixGIx0vNLcDvbtL7Yi8XievVIG/3tCQDAK+s1Oy+QZ56eqiLVsWy3YfpyJdfz883CTjWIiIwak/mkChNJtVy8laPV85Lva3YC9fX+eK2OY2h/nL0lWaZiCcavjqyTTOy6pH69Mj7bew1A9dj1AZ8fNqovY7+euoEui/ejxZzd2Hv5ntDh1PHxv5cRn5avVjJv2h8XlT7+26nq3jTyXuNpeSU4cv0+Kqtkx6FKVyI9t+4Mxq+OxJYH1UfavlcYG22//N3KKkLY/L14d3O0ym2bz96FU0kP/7u2mrvbKJOqJKukvBJf/ncNzWbtQrclB0yusmre9ksIX/hfvfZRVSXGmRTZBLI2S0IvpwrX7ydRw/MToS
RnFOLNDeex4J/LFncBi4jIGLFXHSnDRJKOrDyYoHbz2mv38rFiv+6WlOibqtLGqioxjl6/j8yCUgNFpB8t5+yWW3GSUVCKkV8fQ2mFbr5Exdw2jjflP87clBkDr2k1niH8fCIFQ1YcRbNZu7Dh9E20nLMbUVJf6sRiMa7ey9O4oqqsogo/HU9GfFo+Lt3JRfelB/D8ujMY+Plh5JUon9b03p/ViZN1x5M1/4WM0MDPD2PTWc0aQOeVlKPvp4e0PmaVuLpCjH3ijFvovD1Y+aAvkimqSR7Xx4Q1kXXuy9RwyXNVlRhJJpLMqZ/6JX+mb7qInTF38cvJFDSbtUtHMRERkbbq0wuYzB8TSTo0468YXLmrunrH1F6Uyq6+pueVYOhXR/HcujPovHi/5At9VmEZ1kemmERyqaKyCkEzd6JMSTIi7m4eWs3dg3QNemgBkJt8MpYLrTO3xta5b8y3x7XalyEq7GZvi0VZZRXGrYrErK0xiE/Lx+qjSRj+1TG8omES7OX1UVi0Iw5DVhzFI988/J3VLeM9Hp9hVg2HVx7QLFnQfkH9qjwAYP+VdAz8/HC990OkT2dT6vZH07Qi6cBVy+gLpioJT0REpkUEkdAhkBFjIknHRnx9DHkl5ZiwOhKf7LkqdxtjSSSo6+m18hsYv7c5Gt2WHkBCeoHkvhZzdmPUymPotGgf5m+/jM6L9+t8MpQu3ckpRos5u9XevtvSAxqNcG41d0+d+97ccB4J6fkoFnCZyBYFzaGjtayWMnSF3cYztzBkxVEs3139Gjuo4Re1I9fv1+v4z/ykuKm3KbqTU6z2tqkabKvL45JxMeb3dX3TNHG+6ax61cqGUl5ZBbFYjNScYtzJKUZucTluZBbWe9n1YjmT/oiIyHSZ4tRZMhwboQMwRzVX608nZ2FAS280a+gMHzcHgaOqH7FYDJHoYVY6KiVL4dS52j2VQuftQdLSkbCyUpzVvpFZiFVHkpCQno83B7ZA56aeuJNTjKyCMvRq0VA3v0Qt529mY+z3JzV+3tSNF/C/kymIupENXzd7/PtWH43/fSO+PAoAuPTxMLjYP3wZisViRN/ORbOGznBzsJH5b64rs7bGYuOZ+i/5IPNy9V4eQv3cVG732HcnDBANGbtvDybg/WGthA5DIX1OKvsnOhUrJ3ZUe/v9V9L0FoumVuy7jq8PKE6EPd+zKWaNbA0HW2uN933+Zk49IiMiImPz9NrT+GZiRzwa3ljoUMgIMZGkZxPWnAIArH2uC1o3doO/h6NJNpHcGXsXj7R/+Cby90XNRkMHz96FlOWj5D6WeL8Ag784Irk9+eezMo8f+WAAmjZw1uh4qsTeztUqiVQj6kET5rS8UnRbekDu73Y/X/WyvphbOZJEWV5JOV74+aykwXPvFg3w+5QeWseoiKok0rrjyTgWfx8TugZieFs/nR+fjNOYb0/g2uIRKrdLV+Pvmszft4eMO5FkDIrKKvSa0NJUVEqW0iQSAPwv8gb+F3lDcrtrkCfGd2mCcZ0D9HJhg4iIjNvUjReYSCK5mEgykCnrowAAe6f3w9H4+i2rEcLBK+kyiSRt1szezi5CgKeT5HZFZRWibmTjqQfJNkX6f3YY0fOHwt3JVu1j3coqgkgEmeMBwP64NMm/hS7dzy+Ft6u9zH2ztsZo9Pza49NPJGTWqQQzhIUPliccunYfH49ug+d7BSnd/kamJTSRNX+lFVUoLK2As73ijwV9fClOSC9ACx8Xne+XLNcHf0bjTwXLd3UlIT0fLXxcJbeLyipw+Np9HLl2H5vUHLxhSHdyijFuVd3G4aqcTcnG2ZRsfPTPZcQuGAZrJZXFRERknvJLyuHqoP73MLIMTCQZ2NcHrmNXrPGNWVdJ6tzxlxPJ+PXUDcXbKtDnk0M4M2cwfou8gcikTMSl5qk9Tjp84X8KK5qk5ZeU44v/ruOXkyky94/p0Bg2VlYKl+PVV2RSJo7H38dbA0MQ2KA6ebX/iuq+PZPWnsazPZoq7KHx47EkvNKvudZx1fSvcra3hreLPWysNWuL9tE/l3E7uwjO9jaYNjgEV+7mIy2/BLlF5fhi3zXYWVuZzGhpUu3dzRex+tkuAIDIxEwEezvDV2rZZu3XlS68vfECdk3ri4LSChSUVMDP3fiXAadkFCK7qAwdAz2FDoXk0HcSCahenrzl9V7o3NQTUzde0Kh3nhB6Lz9Yr+cXlVWi+exdSF42kpVJREQWZuOZm/X6PkLmiYkkAzPJJBIeViBlFpRiwb/aN9TstuSA1s8NmrkT4QHu+PvN3vjfyRQs2XUFQ9v4YUqfZmjs4YiY27l4WUG10faL+j3Jf3vjBQDAhZs52Pduf43G0StLyn2y5xq6BHnh6/3xSLxfgIpKMdoHuGPVM52RWViG/JJyBHvXreaorBJjzdGkOg3f2/m7qx1XjR+PVY+5338lDZfuqJ5KSKZr7+XqXi4nEjIkTfalE7ibzuq+t1bc3TzczCxCv88OAQBOzRps9MmkAVLT5tg7wHI98cNJJCwZYfRJJF2KupGNrkFeQodBREQGVF5pPMu0yXgwkSRFky//lqammr2gVNjxvtG3c9Fs1i7J7Z0xd7Ez5q6AEcmKf1AB9Npvmo2jV6SySlynl9O9uBIcvJouWaJ3evZgmaoRsViMTov2Ibe4vM7+Yu9oN5UNAJNIFiI9vwR/SVV0lFVUwc6mupLtelqBoqfVS00SCQDWnUjG7JGt9XIcfbDk3gE5RWXwcLITOgxBaTL10xw8uSqyTnVwYWmFwim1REREZJ40W+di5nbGGk9CwthYsZRdbb+euqHWsrb6kO7z1H3pAXy29yoqq8RY8M9l9Fx2UG4SiUgdyfcLse3Cw2b6LefuNuio9zVHk/CpGX0pTbyvn+SbMYi+rX1iuraDV9OQZMb/rczZygPxWB+p+XJ3IiIyDaYwKIrffQyPFUlS/r6g2SQyS7Ip6hbEEGNyr2ZCh2L05v19yeDH/O5QIr47lGjw45L5mSCn+f2j3xxHjgE/oL8/nIhpESGwt9F8BLkQLt3JRVgjN1jJaUT87qaLhg/IhEQmZuL9P6NxJ6cYANTqhUfGobC0Av/F3cPqo0lCh0JERHokVB6pqkqMczez0aaxG5zsFKctfjiciE/2XMXSx9thUvdAA0Zo2ViRJOXirRyhQzBqm6Nu47l1p4UOg4gMLD69APfzSw16zJ+OJ+tkP2KxGLeyipB4v0BvV9Qe+eY4vth3Te5jeSXCLgc2RquPJOLXUzdw7V4+Jv54SpJEItOQW1SdVJ69LRbvbIoWOBoiItI3MaqXsxt6UvPsbbF4clUkwubvVbiN9PLq2dtiUVphuCp6S8eKJCnZRSyJUyWjoEzoEIjIAny65xreGNBC4+clpOcj4sujam/f0tcFPz3fFU28qqctFpRWwMZKBAdbzaqhvjuUiA+Ghda535zHpZ9JzkT/lt5qb19RWYXVR5Pw2V75STcyDdsu3Mbk3s30PsSCiIiMwzcH4/HlvusAgBMzB8Hfw9Egx/1DzlTr7MIybDl/G+6OthjdoTEmrImUe
fzSnVx0bip/KERhaQV+OZmC4W390FzOsCLSDBNJRERklCoqq2Bjrbpwdvnuq8goKJVpEq6u62kF6PvpoTr3b3m9p8ITEU3YmHEiSVHyTJHfTt1gEskMLPg3DpN7c5k7EZGlkJ7aFn0rx2CJpNpOJmRg0tqHq2M++CumzjZf7Y+Hn5sDPh3XHqJaPX4/2XMV6yOrz0VSlo+CWCyusw2pj0vbiIjIKH1zMEHhY9G3cnDlbh6CZu7EqiOJWiWRlHnih0jsuXSvzv05RZpVZfIE5aGTiZlCh0BERET1INRZTXpeiUwSSZFj8Rn489xtnLuRXeexEwkZkp+X7bqCbksPGLx1gzlhRdIDhQKPtSciIllfH4jH9IgQSTKmorIKH26JxZbzuk0aKfLab+fqNH9esvOKRvswhUknhqJOTq2kvFLjZYVERERkGLezi9Fr2QG81DcYL/XRbXXq9ot30MjdEe0D3BE6b4/MY5fv5mm0r/9F3oC1lQgbz9zE8fgMNPJwROL9hz2eagZFrD6SiOd7BaGsskqy3I2VSuphRdIDd3NLhA6BiIhqkZ4iN2/7ZYMlkWqcSpKtormRVaRw260Gjs0YaJIoq6xSve3GMzeZfDMB6vxbEhGR+Vmy6wpSc0uwaEecTvd79V4epv1xEeNXR9ZJIgHACz+f1Wh//0an4vHvT2Jz1G2k5pbIrVACgLXHk9H300MY/MUR5BaX43h8Brou2Y99cWla/R6WhImkB55fd0boEIiIqJYzyVloM38PgmbuxMYzNw1+/KfWnEJmwcOyZ2VJjnc3R6PKwr5gr9gfr/a2B66mq9zm43/jcFCN7UhYTPYREVF9RaVk4dFvjuPcjWzczhJ+gmtaXgme+ek0MgrK8PL6KKHDMXoWmUg6dyMLUSlZkts/HE7k+GEiIiNVWCbsKNeBnx+W/KwqT1ReVYWswjJUVYlRUl4Jc/++vfJAvNpJBXX/W7z0vyjczFRc+UXCqzT3P2wiItK7casiEXsnF0/8cBILdVzhpA0uZtOMRSWSqqrEKC6rxBM/RGLcqkhcuZsHsViMT/ZcFTo0IiIyUnklFdh2oXrZWlah8mbbrebuQadF+xA8exdC5+2xiIsU/0Trfgz8i//TrISdDKvzov1Ch0BERGbkppLWAYYyc2uszG3p85u41DxsOsvl99Isptl2ckYhxnx7HI919JfcN+LrYwJGREREpuKdTdEI8HRCckah6o2lFFjAIId/LqZiTAd/1RtqICG9QKf7I92yhL9rIiJS7sv/ruGNgS3w1f54PNMjEAGeTmo/1xgTMrX7KL298QJa+7li8c4rOHL9PgDA3dEWw9s2QnllFWytLaompw6LSSQt23UFeSUVWB95Q+hQiIjIBD25KlLoEIySOoNNcovK9R8IERERGczKgwlYeTABALDqSGKdSbfK/Hba8H0vtTFkxVGZ25fu5MHe1hov/HwWL/QOwvxHwix2wpvZJpJqj+27lW3+ywuIiIgMbf8V1c2xVx9NNEAkREREJJQbmYVo2sBZrW0/M9HWMt8eSpD8/POJFDjZWeODYaECRiQcs6zHenrtKTSbtQu5xdVXQAtKK3Dlbp7AUREREZknZSXqd3OL8f1hzRNJJxIy6hMSERERGVBkYqZa2527kYW8EvNYIv3docQH/5+AGX9FG+WSPX0xy0TSiYTqP+JXHozta/vRXiHDISIiMmvfHkyQe39ZRRV6Ljuo1T6fXnvaok7IiIiITNnR+PtqbffED+bVKiBo5k58tvcaNkfdxsVbOUKHYzBml0iqqKyS/Hw6OQtBM3cKGA0REZH5+2Lf9Tr3lZRX4pmfTtdrv98dkp+gklZeWYX3Nkdj+h8XkGoBU/KIiIiM0a7Yeyq32X7xjgEiEc6ha+ol08yBSGxml/te/TUKey+nCR0GERGRRUleNlLSm/DdzRex9bxuThZVNe988/fz2Bl7VyfHIiIiIu1N7NYEy8a2V/i4JRR5PN7RHysmdBA6DL0zu4okJpGIiIgMr9msXThwJQ1isVhnSSQAOJWUiaCZOyX/k5ZbVM4kEhERkZHYeOZWnc/qLeduI6OgVKCIDG/bhTsImrkTPZcdQGWVWdXsyDC7iiRLyHISERERERERGautb/TC2O9PCh2G4OIWDoOTnY3QYegcE0lERERERERERHqgapm+KTK7pW1ERERERERERMZg0Y44oUPQOSaSiIiISKcmdgvEgkfDdLrPq4uGo62/G5aNbafT/RIREZF+vNovGFcXDRc6DEGMat9I8vNPx5MFjEQ/zG5pW3llFVJzitH/s8NCh0JERGRRoj8aCndHW8ltXS03/+HpThjRrlGd+7su2Y/7+ZbTwJOIiMjYbX+zN8KbeMjcV15ZhZA5u4UJSAC7p/VFqJ8rms3aJbnP3HolmV1Fkq21FZo2cBY6DCIiIoty8L3+MkkkXWnoYi83iQQAZ+dE4MPhoTo/JhEREWlu48s96iSRgOrv6O8PbWn4gAQwc0QoWjdyg0gkwrTBIZL7zSmJBJhhRVKNe7kl6LHsgNBhEBERWQR5jSTrW5F0ffEI2Nkov+YlFoux9fwdvPdndL2ORURERNob2c4P3z/dWeHjYrFYpkLHnAR6OaFToAce6+iP/i29IRKJhA5J78yuIqmGn7sDvn6qg9BhEBERWazPnwyv1/NVJZEAQCQS4YnOAUhZPsosp6IQERGZgvYBHkofF4lE2PRKD8MEY0A/Pd8FR2cMxFdPdcSAVj4WkUQCzDiRBABjOvjj0PsDhA6DiIjIrCm6cDOucwACvZwMGwwREREZXERrX5XbBDU0rxY0nZt6YrAav7c5MutEEgA0a+iM1o3chA6DiIjIbI3p4K/wsaMzBhowEiIiIjK0Xs0boIWPi8rtfN0csGNqHwNEpHstfV3QxMtR5r61z3URKBrhmVfHJwV+fakbtp2/g8T7BWjp64qFO+KEDomIiEyMp5MtsovK1d4+edlIs+0FIM3FXj+nEn+YYfk7ERGRuWnk7oANL6v/md3W312P0eiHi70N/nmrD25lFWHm1lj0C/FGj2AveDrbCR2aYMy+Igmonvjycr9gLH+iPcZ1CRA6HCIiMkE73u6r1tU2AAh7MK3DEozrrPpz9dV+wRrts2uQJ3oEN9A2JNKz/i29hQ6BiIiMwIV5QxA5a7DQYejF6wOaS36e0LUJHGytEeLrii2v98K0iBB0t/DzFItIJEmzqnVi/8sLXdEp0EOYYIiIyCQsfbwd/D0c4aHGePuvJnTAplctp5pm1shQldsMaOWj0T5trS3u9MSkDGzFRBIREUFnFTmLxrTRyX506cPhofjtpe6Y0KUJpkeECB2O0bGIpW3SXOxtML5LADZH3QZQPapvw8s98OGWGGy/mCpwdEREZGxGtvPDhK5NAABWVsqrjC59PExvS72MkZOdNextrFVu17O5ZlftmjYwr2ac5sZSqu2IiEixxY+11fq5f77WEycTMuHqYINgb2d4OhnnErE+IQ3RJ6Sh0GEYJcs525Xy6bhwPNczCPcLShHsXb1M4eunOjKRRERE
dXz/dGfJz6qSRJaURAIAGxWJNWmdAj1w/maO0m1cHWwwJMwXM4errnIi4WiaGCQiIvNyffEI2NloXz3cNcgLXYO8JLfjUvN0EZZOdGjigRnDWgkdhtGzrDNeKabY5IuIiAxrUKjskixleZMtr/fUczTGR6zBtlMHheCFX84q3aZHcAN8Ob5DvWIi/Wvp6yp0CEREJKD6JJHkad3IcJ8rf7zSA0+tOSVz31cTOqBbMy80cndg1a2a2ISAiIiM3v53+wly3A9rVcbMGRWmcNvOTb0UPma2NMgkDQz1wYmZg5Ru4+agugcVGQc3B4u9FklERDomEonQSslFilfUHNrh62Zf574zcwYjYckILH6sLfa/2x89ghvgj1d6YNnYdpJtqsRiNPZwZBJJA0wkSVn/YjehQyAiolreHtQCLXxccWHeEIMeN9jbGa38ZE9qmjVk7x5pa5/votH2/h6O+PvN3pg7qrXcxyNaa9aUmwyPDUeJiEgfvhgfjgYKmndXVql35Uq6UuqdiJbY8nov+Lg6wMbaCs/0aCqZvtsjuAEmdguUbKvm7kkKE0lS+nGcLRGR0XntwfhVT2c7pCwfhTcHVt9+tV8wzsyp38jZv9/sjQvzhiBhyQjsntZXcv/cUa1x8L0B9dq3JdBm9G2HJh6Y0jcY5w2cGCTdeK1/9etv9bOaJRGJiMi0BTVw0uv+2/q7I2puhNzHKqvE+O2l7vhmYkdM6h4odxsAmNKnunJpYCtvTIsIQeemnmodu4qZJI2xLpmIiIyak53sR9UHw0Lx/tBWkvLjyFmD0HPZQY33O3tkKDo08ZDcbt3IDcnLRqK0ogoOtqonkVH9eMm56sjzOONnZ119DbJn8wZwsLVCSXmVwBEREZG+TY8IwdRBIZi1NQZd9LiUX3pp2ROdArDlfPWkdbFYLJmeNiTMF1YiYHCob53ei0PCfDG4tQ8auTuqdbzxXQJwMjETo9o30tFvYDmYSCIiIqP1xys95N4vfaLRyN0RKctH4VZWEWZvi8Wx+Ayl+5zUPRCzRoTCVU4/HpFIxCSSgBp5OAgdAqlgJdVx/tzcIWjz0V4BoyEiIn3a/mZveDjZItDLCSKRCJ+OCzfYsTsGekgSSf6eDxNDDrbWWPxYO7nPEYmAAA/1K6c+HRcOsVjM3khaYCKJ1GZtJVJ7fSoRkS700GDpVBMvJ/z6UncA1SXQhWUVKC6rxI3MInRrprurZx89GoaP/42Tua+pgnJvR1trFJdX6uzY5ubnF7riRHwG+rb0xq2sInQKVK8EnYyDsz1PI4mIzFWXpp4Il6rcNjRvV3v89lJ3HLyajud7Bcnd5sgHA9D/s8OS2/IuEqrCJJJ2eAZARERGadpg7Zv6WluJ4OZgCzcHW/i66bbK5YXezXDpTp7kKtmj4Y3x7pCWcrf9eEwbzPgrRqfHNycDW/lgYCs22DYVg0L5b0VEZCn+fK2nIMdd/WxnXLyVgyGtfWFlJZIsaZOnaQNnRM8fit9O30DXIC+48AKHwfC/NBERGaXJCq4+GZtvJnZU+BivcZE5+eJJwy1pICIi4eyd3k+wSp1hbfwwrI2f2tu7O9nizYEt9BgRycOpbbW82j9Y6BCMlgjA8rHy16MSEemap4IRsMZADC7znTOytdAhkAH1DG4g9zV5bMZAAaIhIiJ9auXnKnQIZOSYSKqluw77aJijCV2bCB0CEVmAPi0UlzGTcZjcO0joEMiANipofN/Eywk/v9DVwNEQEZG+dAr0EDoEMgFMJJHa3hnSks3IiMgg1r/YTegQlGNBEj3wUp9mQoeglt3T+upv53w9EBGZjVXPdhY6BDIB7JFEatn/bn8093YWOgwishDSI8ZNmTkn3431N1v7XBdMWR9lsOMFejkhZfkouY/9GpmCedsvGywWZUL9XBGzYCiW7bqCjWduafTcg+/111NURERkTDo08YCPq26HlJB5YkVSLWJeVZOrhY+LWX8hIiLj8dPzXYQOQaXpES3h6mCDNwY0V7qd2Iw/VGysjfMUIiLM16DHU7bke0xHfwNGopqbgy2WjW2P+CUjMLyNHxxtrVU+Z/ubvRHs7aJ0m5bspUFEZHRe7ttMo2mbjrbWWPMcq5FIPaxIIiIiozGirR8GtzZsIkAbgQ2ccHH+UFibSeUUac9BSTLGzcEWvm72SMsrNWBE8klfDLK1tpJZurA/Lg2RSZloH+COhi722BGTijPJWdg9rR/sbFQnDP09HPUSMxERae+9oa1gb2OF4vJKhM3fW+fxfe/0w5AVRzFzRChe66/8whhRbUwkCeTtQS2w8mCC0GFoZcWEcGw4fRNnU7KFDoWIzMznJjRenEkkUoerg61RJJKUiQjzlank6s1m90REJs3exkpyocPJzgbRHw1F1yX78fHoNkjNKUbP5g0Q4uuqcGk2kSpMJNViqFUIbfzdDXMgPXi8YwAe7xiAoJk7hQ6FiMyMs715fSxxSTC18nVFQnqBoDEsGtNG0OMTEZFh7JjaB9G3c9AvxFvmfndHW1xfPEKgqMgcmdcZu5Fr3cgNTnbWGNHWj72YiIgsgDn3SDIFLXxc9JrEWfxYW5XbfDQ6DDtj7+otBnWM66y4jxMREZmPtv7uaGvCBQtkOoyzU6aAPJ1t9bZvF3trbHm9F6b0DQZn5RIRyerS1FPoEMjMqJPo0dbIdn54pkdTldt5OdnpLQZ1OdqpbqpNRESm7ZkegUKHQBaEiaRaOgXq74uMq8PDJJU5XKS+snC40CEQkRn5dlInoUPQOS5tE8ah9wfg9ynd0aaxm873/d87/bD08Xb4dqL5/b3Wx8wRoUKHQERk0T4YxvdhMhwmkmrR5Ul/c29nAMAj7RuhQxMPLJTqUeDmqL/KJ117VsEVV0c7azxuZKONyXxsfrUn+rRoiP3v9sOqZziK1BL4uTsIHQKpqXeLBkKHoFSzhs56aRgdHuCOlr6umNQ9EFYm0my9rb/uk2nydGvmZZDjEBGRfO4m9P2STB97JOnRgfcGoKS8Uu5o4F7NjfskvMbrA5rjw+GKs9sDWnlj24U7Wu9/cKgPDlxN1/r5ZL66NfPCb1O6AwBa+Lji8PsDMODzw8IGZYGc7axRWFYpdBgmy9baNJINmlo+tr3QIQhi1sjWGj9H6Ol+TxqoP5KrmTXKJyIiIsVYkaQnX46vHmEtL4kEmM9yB183zSoIImcNkrn90+SuSF42Er1bNMCgUB9seb2XLsMjE+PmUP1FZOnj7eo8FtTQWWcjSuOXjMDKiR11si9zF7tgGOKXcMqHtoa39dP5Pn1c7XHo/QFY8GgYJnYTph+Cpfbc0SYxKBKJ4Okk3FViRVXFuhbi64ppg0MMciwiIpLF81oyNCaSdMzO2gprn+uCxzqYx5KvR9o3Uvp4dw1K2Ru62KORuyPOzxsCoHqaDlB9kv37lB5YN7krOjf15BuhBYv+aCguzh+CSd0Vfzk+8F5/jO8SUK/j2FpbYXR4Yzzf0zBfsEyZlZUIttZW2PxqT70eR59NkYVkb6P7hMuxDweiWUNnTO7dDMvGtoO/h6POj6GKOfT
504aVlheB3h/WSseRqM+QS/DeGdLSYMciIqJqUwe1wOjwxkKHQRaGiSQdm9w7CBFhvoL3TpCeflSfpRVtGisfH6lJZdXe6X0BAF7Odrj08TDsmdZX7naavBFGtPbFgFbeKreb2K0JkpeNxLXFw3FmzmC190/6tevtvjgzZzAWPdYW0fOHQiQSwUPFhKPm3i74dFy4Vscb3yVApiquPpWBx2YM1Pq5pkL6faRbMy/0b6n6taYpXzd7/PNWbzytJHlIDx35YECd5JR0FciaZw3TT8ze1jJPH8IDPLR63oi2yi/K6MumV3oIclwiIjKcR9oziUSGZ5lngnoytqM/3jWSq3HSPRlEqPtlOfhBI/BHlSRtfN3sdRpTA5eH+3Oxt4GNteI/v+1v9la6r1A/VxybMRBrn++CX17ohtOzB+OXF7oq3P6NAS0gEolgb2MNH1cH7J3eT/NfgHRqaJgvwhq7wcfVAc/2aAp3DZd+bHldswqZRu4O+HRcOBq5P6zeeKVfsEb7kObqYINBoT5aP19Ij7RvhL4hqhsR156i9vPkrlg+th0auujuveHrpzqifYCH2Sz31aeI1j5o2sC5zv3jOgdg/YvdcH7eEAxt83Apna21SK1Ee23LxtZdWirt1X7BcHMwjYaeuv670vYikZeznU5fN+rqHmz8/Rgbuii/eEBERIqdmT0YrfxchQ6DLBATSTrS1t8Nn4xrr7AnkqF9JlWx4eZYtwHmzql9sf/dfuijg8k7f6tI+gCanyiGN/HAhpe7Kz1mEy8nyW1fNwcMaOWDmAVD5W5fu5dTKz9XpCwfhR1T+2gUF+lO16D6Tfjp3NQLPz7XRe3tPx1XtzlwYw9HJC4diRnDNV92IoIIrw9orvHzhDRjeCskLh2Jbyd1wq8vdUfS0pFYOKaNwmbAtaeoWVmJ8FS3QETNjUDXIE+5z9HEZ+Pao4cJfNE1FrUTezWsrETo19IbXs7V77OXPh6G8/OG4PriEfjlhW6IXzICzRo+TED1U1FZ9lRX5c2Z3xjQQsPICVC9VFzX/nlL9WezMXhr4MO/pxUTwtG5af3fW4iILIWPhv1qiXSFiSQ5mnvXveKrzE/Pd8GOqX1hq6TCxtCaeD2suqh9Rfbxjv5wtLNGCx/Z7PX6F7vJ3Fa3B0aHJh4KHxvYyhvjOgdgkxb9VXo1b4glj7fFqmc64Ysnw3Fm9mBcWTgclz4epjBh5+Zgi+RlI7HplR7oFOiBRu4O2P9uP9jZyP+3aevvjt0KltiRciPa+sFPyw+vJY+3xeTeQfWOYUiYL5zUaPp7ffEI9A2R/+XZ2kqENwa00Kg6aVT7RnB3skXXIC9Ezx+K5GUjMT1C8yazKctHYcUE7ZbpqcPfwxFPdg7Ay32bIXnZSLwxoIVM0sjKSoTnegYhcelI7J7WF/MeCZM8tmhMG6X73vxqT5mlb+oa3yUATbwc8em49niyi2GmSZmD3dP6qn2hwsXeBl7OdpL3fltrK0nStYGzHda/2E1h43pbaxFEIhF6K7jI8PbgEI2rB82FdDJOG7NGKp6Aqg8tfYW5Qq3pOVR7qXMIDyc7/PVa/fqxhQcoX5JPRERE9cdZrXJoMqr30fDGGNzaV4/RKLbl9Z544odIldtJ/zZXFw2HvVRSRboqRNVVamU+Ht0GucXleHtwCPp8chC3s4sBAKM7NMbjHbVvjPx0d82bIYtEInQPboCtb6h3NbZ1IzdcXTQcH/wVg3+jUzU+nqWJ/mgonOysYWtthaoqMVIyCzHoiyMa7UObf1dF/p3aB4NVHF9RIlHa7JGtMbytH1ztbTBkxVGl234nVRlS86Va04bHNb1sHu8YgHc2RWv0XFV+eLoTvF3tEd7EQ+0Ed+tGbmjdyA1NvZxQXF6pdNkrUP06++XFbmj70V614xrbyV/r/laWbMfUPmjdyK1e+2jh44LjHw5EA+eHy6s2vNwdk348XetY1Yn1tc91Rev5e+rsZ4hAn3fastbh0rYD7/av1/Ptbaxx6P0B2HD6Bg5cSUdSRqGOIntoaJgvbmYV4ZkeTQWrkP7ztV7otGif2ttLv3f6uNrXeznitjd6Y+ofF7Az5m699kNEZEht/d3Qv6U3vjuUKHP/OxEt8eupFGQUlAEAQnxcMGtkKGZtjcWX4zsIEClRNeMpoTEi6k5l6d2iAT6Ts1zGEPa90w+dm6q3NEj613GwtZY5SQv2dsH+d/vhwoNJatLkjWBX5PleQXj7QcPXdZMf9ipytTeNK9cOttb4ZmJHhX2WWvm6qtVTxpw917MpkpeNhLujrSQ5YWUlQrC3i6BxNfd2QSslV96vLR6u9r46BXoiRMur+N6u8vuffP1UB7nTDaV72dT0I9swpTt6Na/fUq+5o1pjRLtG6BLkpVWVZESYr8okUg0XextMULOqKLyJB094tPDWwBZo66+bCosATyc4SlXw9WreUOb1sf/d/pI+C45qVPqZAl3+HroYotGsoTPmjArDS32b6SAiYMvrvSQ/r5gQjjXPdcGe6f3wTA/hJlLWLLFUlwjV/dcWP9ZW5YAPdVhZifDdpE5qVasSERma9ICM6I+G4n8vdkP/lt74dmKnOhdaj80YiGkRITg7J0Lm/kGhvjg9OwK9W1j2dxMSFiuS5FD3Kt6UvsGCXfHT5MuuvGbb0movcasREabdleeWvq5YNrYdYm7nmFwz4gGtfJC8bCT+jLqNqBtZeGNACzjZWUvWH/8TnYq3N14QOEphLByjm/HsazXoa1RfP7/QVefj1397SX7vrv4tvdHCxwUJ6QWS+87PGwIvZzuM6eCPoJk7Fe7zwHsDcCurCK0buaFrMy+M/f4kYu/kahybq70NpvTVvoG4NmaOCMW1tHxcvJWjdDtOj9LO8LZ+qjeqB3sbayQuHQkR6iZKujXzwpnkLJn7mmm4bMlcrHpGfn8qbbX399DJfjo39UTyspHIKSqHp4YJHKMhAgbWOlcI9HLCzawijXclXX36Qu+gOlf2iYgMbdfbfTHtjwuIf3B+OD0iBC72NgjxdYG7oy36t/SWmYp7ffEIycTtmov/IpEIYzo0xvaLqXhrEPsUknFgRZIcnz+puMrIz80BHwxrha8mdMAAPYzCVoeiyhht+9Xow8RugVg2tr1OruAamkgkwviuTfDpuHAENXSWaWI3OrwxouZG6HyinTGb/0gYfn2pm9JtrixUr+onZfkorROUyvQIrlvxs/PtPhjYSrtEZvySEZKfpRMgfUMaoo+C159IJFI57QoA3B1t6yRVXOxtJEuXbK2tZKoMNPGvAM3jPZ3t8PebvfFqf8UJrM+fDDeaQQRCqM90QF1VIyljbSWS+179+5TueG9ISzRt4IS+IQ0RNTcCLvaWd/1p0Zg2GN5Wt42y2wW4Y7ka7xfK1PQ0E4lEpptEUkDecAR1nJdTXU1EJIQVE8KRsnwUwhq7YdSDYQvBDZ0hEonwcr9gDFBwjmpnYwWRSFRnme+K8R1wYuYgjOngr/fYidRheWeEalBUofP1Ux2M4sWrqEFts4bOuJdXAqD6xLJvSENkF5XBzcFWcj
/VX0MXezzboyk+/++60KHo3bKx7TCxW6DK7dRZPlKfL9OqfDgiFE28nDA0zA+lFZV1lvBoytbaCidnDkJWYZnMF3llS+gAKK39++u1nth24Q5mDAtV2azYzsYKr/YPxoZTN5FfWqFWzKPDGyOons2A6+OtgS2w+khSnfsVNXW2JDOHh2LN0br/bVSRlyA1JFtrK0wdHIKpgzVvJG8u3hrYAs/2DNLLvrVdknDp42FmldCTVzWt6fI4oDqJJP3fxUVPS+vb+btjYKgPVh6I18v+icg8SPeIfXNgC4T6ucltdaAuKyuRxv04ifSJFUkaaNO4fs1O1RHc0Bmj2unmyuf6F7vh37f6aNQ8nNTzWn/jGfs+WGpJwJgOjZG8bKTO9q1OEkkd307qiNkjW+tkX/I42VUv6Qps4IQQX1ed9EVp7OEoSSL981ZvvD6gOd4Z0lLpc9oHeMjclv5S0yXIC0seb6f2xKtZI1oj9uNhmDtKvf9uUwUudXZ1sEXSUt397ZkTbSszI0yssbW5eSeiJd4f1kpv+9e2isickkiAbB9Hbc0d1bpO8un5Xk3RN6QhFj32cFl2cENnJC8bqdVra/nYdvB1s8dnT7bHOxEh2Du9X73jJiLjFdbIDZN7BWHOyNZ4pH0jDGzlrfY59pQ+sn3wbK2tMLytn9lVj5JlYyJJA4oqlXTt8yeVTzXqo+AqZnMf2WqEmrLIRu7Gs+TNXNhYWwleLQAACx4Nw0+Tu0p6UT3fK6jeE29qDGylm6Wbjd0d8Eh79Zo3G6v2AR74cHgonFV8gbOzsUL8khH49aVuODlzkFrT4lSZ0jcYRz8YqHI7bZuE65KVlQjfTuoouT2mg2n/u+uLOg3guwZ5CtowmYBpEfqtxNImIVTf6X3GyMOxbnLd11Wz8xZ5veGc7Gzw60vd8WyPppKr+KPaN4JIJMLQNponkp7qFohTswYj1M8NIpEIrfxcsevtvhrvR2gNXewxJMwXYzsJX2FPZMx2TeuLBaPb4OV+wfh2Uif8/EI3mXPsl/ooHprQJUj47whE+mZel7V06I0BzfH9YWGaNCqrpgj1c1VY8j1jeChsrKwwutaXt1kjWqOwrBLj1ZyuROrp3NQTp5KyVG+oJ9JLhtY+1wU5xeVaLQeo7efJXVFaUYWI1rpplK7PJW3GyNbaCn1DdNs/LbCBEx4Nb4x/o1PlPq7LKrT6eqR9Y4T6uWJnzD2dTaYyN/Y21jj8/gBkFZUhNacYb22QbeCfvGykzhLCpJ3Nr/YUOgQZ/Vp6Y+bwUDQTcPmqPsQtHAYbOdMl1a3cBICD7/VXuc32t3rjZGImhj+YljmuUwBm/BWj1v6/fqqDpAdl7ddlQxfTqi44/P4A+Hs6wtbaCvdyS7D1/B30DG6AyKRMnR5n2xu90DHQU3I7Pa8E3ZYe0OkxiIQW3sQDk3sFobGHA1wdbOHv4YjtF1MhFosxRA/9QImMDSuSFHhXxRIWoSgbQe/mYIsFo9ugk9SHN1BdPv/dpE4yEwGo/gZrURofNTcCSUtHopuWa6RrJp5JV30A1ZUgypJIVxep1wwbqJ6eM7ytn9yTe01FtPbF872C6r0fAlY+1QGxC4YiZfkoHJD64rTvnX5Gl3Ro4eOKaQ+mklC1miWK6yZXv4aDGjqjU6AnHmlfvRz1uZ5N8ViHxjg3N8Lo/j0tkbbv0ZpydVDvNTI41Adhjd10smzXmDjZKf79Q/3Uq7IM9nZRuU1DF3uMDm8sqRK1shIpXKJesxQuxMcFyctGYkwHf3QPbiB3W00SXkJ7b0hLBDV0hu2Dz3Y/dwfELRyG36fIn0Ra4+unOsgMoADqLqUeHOqDQC8nye2Otc5DfYxoGAyZL3nJm0Cv6mERp2YNxqIxbfDRo2E6O56NlQgLRrfBK/2aY2K3QPRr6Y0vxofjywkd2FaELALP8hWwsbbCuM4B+OvcbaFDISNVO2GnjoYu1dPeNr/aE6uPJGLZ7qtqP/ePV3qgR3ADtZoXr32uC6asjwIArHm2MxxsrdHO313lSPkAT9028evc1JNfinVEJBLB1aH6S0tzbxfELRwGEURm98XSXE3pGyx3+Q1Q/W+7cExbuY+R4alaXq5LRz8YiI6L9il8PMTHBR8Ma6XVhQshdW7qiXM3suu1jxnDW+HFX6Lg4WSL0eGNsT7yho6iq/bh8FZo3cgV7fzdkVNcjrHfnwQAPNujKSZ2baLWxRR7G9N5/5XXNF9eIu/Xl7rh4NV0/HIyBV9N6IDR4Y0hEomw6ZUemLDmFL5+qvq+sZ0C0NTLCbnF5fB0toNYLMb/TqYonDR5bMZARN3Iwuhwf1hbibBi33V8rYOG5S/1aQYbKxFWH03Cxpd7YOKPp+q9TzI9r/YPxszhobh6Lx+PfnMcFVViBHo54dD7A2Alqv6cfbZnEHbH3pU857meTeu8r3Rv5oXTyVkIb+Kh8Fgv9WmG8zez2ceQLB4TSUo42Oq3YGtMh8bYflH+UhWhDAr1wcGr6RjZzk/oUEzCojFtMG/7Za2e+0LvZrh2Lx9bL9xRut3F+UNQVlGl0RW9iDBffDepE9LySjD0QSn/a/2b480N55U+TyxW+xBqeb4Xe7zoi7Ir+USkPekBBvqmrPGqKU9nW/9iN/wvMgUJ6QXYel75Z5wiA1v5YMvrvdDc2xmuDrZyE0nHZqjuH6eISCSSmcQ7uVcQPB5UGOmiIteYSFcLyRPg6Yjb2cX45Il26Bvijb4h3pg3KkxmUED3WheyapZZ1vwNi0QiTO6teDlzEy8nNJGKY3pECEa1b4QW3i5IySzEkBVHAQCXPx6GDadvYuGOOADAB8Na4bO91+Tu84enO2HEgwE1sx4M9DgxcxDiUvPw8oOLaWQZZo2o/vdv3cgNe9/phw2nb+LV/sF1KoMGt/ZFeBMPdAhwx8dj2mLhmLa4l1uC7KIyNHJ3gI21FXbGpGJImOLvQfMe0V1VE5EpM80zFAOZ0icYv526CaD6A0/Xlo1th5yichy5fl/n+9bWyokdcfhaOga2MtyJtCmrXe5++P0BGPD5YbWea2djhS8ndEDP5g1w6Fo6dsXek7udh5N2PRhGtZed/hcR5iM5WVRE3aUE8qyc2BFvb3zY6+XnyV2Z7CAik2MMU3W6BnmabBIJAJztbfDGgOrlT1+O74CgmTs13odIJELnpsorf5uoSJBoYsHoNlo9b0iYL/bFpcl9zM3BBnklFfUJSydULTHfObUv4u7myYwm13bapLpEIhFaPhgSEeztgkSpqZ8v9A6Cn7sDbmUV4YXezTCxWyB2xKRidHhjlFZUwcfVHnklFXCX06jd38MR/h6OmDuqNRbvvKLX34GMQ+3kf3NvF4XJHjsbK2x/s7fMfX7uDvCTGkw0oatuJhYTmTvTPUsxgKCGzri6aDhuZxehuRpr8DXlZGeDUe0bySSSGmjYuHHH1D74ZM9VfDg8VCcxudjbmPyELSEFNXTWuHHlk12aYFznADSbtQsA0KWpJ6LquSRAHnsbaxz/cBBSc4pRVFaJiC+P1Nlm+RPttd7/6PDGGNbGF2Ix4
GBrOuX+REQ19HHRSBufjjPc8jpTMaKtH3ZfenjB5e9aXwaFYmstP+GycmJHjA5vrFUS7fLHw9Dmo731DQ0AsOqZziqHZ7g72aJnc/l9oIQgEokwst3Di2FeNnZ4rmeQzDbykkjSpvQNxqojicgoKMOLvZvhyt08nTcVJ+Pw3dOdhA6ByCIxkaSCg601WvgYZqx2r+YNsHys/C/yr/YPxrbzd/BKP9nmkG393fHrS8obJZL+NPZ42FPosQfT8hS1BFK2llokEuHEzEGISsnCqHaNYGNthR0xqfDVQ4PKmpg3vNwdk348LfOYt6t9vfZtSv0iiIhqU9SAWZ9WPdMJr/0mu+zY3Kaz6cLyse1lEkkdlPQwMaRZI1rXqSie3CsIo8OrzwmCvZ2RdL9Q5vHaSbHanO1tELNgKNov+K9esZ2ePVgv5xGmYsfUvjh8LR2PdfSHg601isoqkFdcgVd/O4foWzm49PEwVFaJcT0tH0+uinzwnD4IauiM307dwHIN+liSMFwdbHjxkkggTCQZkQ0v95B7/9Y3eqFToCdmDg9l42Ij42Jvg+MfDkR8egH6tKieqDewlQ9OJta96tU+QH4Dyhr+Ho7wl+rXoO/KsF7NG2JyryD8cjJFr8chIjIFXz/VQZAvJKF+bpKfQ3xcsGxsO4PHYArcnWxhZ2OFsooqvDnQ8Ak/ReQtr5NeVrPtjd4I/1g2IbTosbYKE0nDH/Q1dHPQfiJcz+AGWDmxY70vDpk6P3cHPNXt4TIlJzsbONnZYNvrvVApFksm2HUNerikz8fNHi72Nnitf3P0DG6A6Zsuwt7GClfv5Rs8fnNgYyXCksfb4sMtsXrZf+yCYXrZLxGpxkSSkVr7XBfsi0vDx2PaSE5smUQyTgGeTgjwfHgi+ULvIBxLyMDRWr2vnIxwutaC0W2wYHQbnLuRjSY6nthGRGRKpBsvC+XfqX14dV2JA+/2x5Hr9zGuc4DQoSgl3eDX3dEWix9ri7l/XwIA/PVaTzR0sccnT7ST++X6h2ceLtPpG9IQx+IzND7+/17sBjsb82oYrktWViJYQfac+o9XeqCwtAI+rg8ruMKbeODQ+wMAAJfu5CL6dg6KyyoN0nvpw+Gh+GSP4oqoL8eH493N0XC0tcbmV3vi0W+P6z0mVR4Nb4ygBk54d0hLlFZU4dyNbIQ1coOnsx2GtfFD3N28OpXw9VG7FygRGRYTSUYqIswXEWEcK2mKbKytMG9UawyplUh6urvxTjBT1dCUiIj0I9DLCe383eFkZw17fvlXqomXE57pYbyfpYo806NpnbjDGsmvUpa+aPjtxE4IX6jZ8ra1z3VhEkkLPYKV94hq6++Otv7V/2Yv9WmGG5lFOJmYiehbOdgUdUunsXw1oQPGdGiM3i0aICWzCH+duy25ONm5qSe2vN4LADCibSPY2VjB2kqElOWjkF1Yhr2X72FU+0ZwsbfB2ZRsjF9dvWRvVPtG+HZiR4hEIry98QL+iX44NbqlrwuupxWojCvUzxUFpRUyQ1uc7azRPbgBZo4IlTRPB6pbg/R+UKkPVA+O6dW8IVKWj9Kqb5i0ox8MxLGE+5Llo0QkDCaSBGbpZcfmKsS3bl8tRyOsSCIiImDzqz0FO7aVlQj/vFXdOJqVx6bppT7N8NPxZADVVRnqaCdnuXu/lt4yt92dbLHosbaY96CaSR28CKl/IpEIQQ2dEdTQGUPCfCWJpG1v9MKZ5CwsU6O30qv9gjFzRChyisrxb0wq5m+/jEVj2mBspwA4P5jY2D7AA+0DPBQmTGqfV3o628ks5evWzAtJS0fWmcC3cmJHLH68LaxFIlRUimFva4XOi/ahsKxSzu8KfDK2PSLCfOElNdHylxPJCPF1lUkWqUu6Qq+dvztuZxchu6hc7ecHNnDC0w1ML6FMZG5EYrFYLHQQlkwsFuOL/66jTWM3jGjHEk1zUvuKS8ryUQJFQkRkudS5+s33Z/2q/W/w8wtdEdTA2WyaimcXlqHjon0AgE2v9EB3FdUtNZ5bd0ZmGfze6f3Qyq/uhSh1Kzg2vtzDqKavWYrE+wVwsbeRaWz+5u/nsTP2LoDqxIm/hyOOXL8Pe1sr9AvxrpOAyS8ph2s9+mLpQl5JOXbH3kXHQE/E3M7F/O2XsPb5LujVXPNkkSq3s4uwI+YuJnYNhKuDDV5eH4XCsgp0C/LC7ZxibD1/R2Z7TydbPNEpAJO6ByJYD5O0iUhzrEgSmEgkwvvDWgkdBumBdCPr53vyygkRkTHq1sxL9UakM52bemJgK+Xj6E2Np7Md1jzbGYn3CzX6e1rwaBgGfXFEctvXTfsq9eRlI1nRJpDmchIby55oh27NvDCirR98HiSYBoYq/rsXOokEVDd4n9C1uqKppa8rxnb0r1PNpCsBnk4yUzJ/mtxV8vOqI4mSn2uS/GKxmH/fREaGiSQiPfno0TBM6dsMTnY28HQS/gSBiIjqMrekhrEz16+CQx9MW9NE7YosDyc7BVuqxi/ZxsXNwRbP9woSOox60VcSSZUXegchPa8Ug6QSb/z7JjI+7MZHpCcikQgBnk7wcrbjByARkZGa1D1Q9UZEeiASifDVhA4AqvvrKPKsigbjrKojc2JvY435j4ahT4jul9QRke4wkUREREQWy92RFaMknMc6+iNl+Sh0DFQ8PfXj0W2U7uOLJ8N1HRYREZFSTCQRERERkUHUZ/mWpbKyEqFL0+pEU+emdRNOTbycDB0SERFZOPZIIiIiIovUxMtR6BAszqLHlFfXkHx/vtYTBaUVsLOxQqu5eyT3dwz0EC4oIiKyWKxIIiIiIov06RNcEmRI0yNC0MidyTttiEQiuDrYwt7GGiPaVjf2Htc5AFteU9xbiYiISF9YkUREREQWydOZ/ZHI9PzwTGehQyAiIgvHiiQiIiKySH5uDkKHYFFsrXnaSUREZA74iU5ERERmK9jbWe794zoHsPGzgUyPCEGonyue7al8jD0RERGZBpFYLBYLHQQRERGRPmQVlqHTon117k9YMgI2rJAhIiIi0hjPoIiIiMhseTnLrzpiEomIiIhIOzyLIiIiIrP27aSOQodAREREZDaYSCIiIiKz1sLHReb2mwObCxQJERERkeljIomIiIgsyvtDWwkdAhEREZHJYiKJiIiILIpIJBI6BCIiIiKTxUQSERERWYxHwxsLHQIRERGRSWMiiYiIiCzG/EfChA6BiIiIyKQxkURERERmrbGHo+Rnb1d7ASMhIiIiMn0isVgsFjoIIiIiIn26mVkEOxsr+Lk7CB0KERERkUljIomIiIiIiIiIiNTCpW1ERERERERERKQWJpKIiIiIiIiIiEgtTCQREREREREREZFamEgiIiIiIiIiIiK1MJFERERERERERERq+T/1GpFoBMsRkAAAAABJRU5ErkJggg==;" parent="1" vertex="1">
115
+ <mxGeometry x="236.79" y="370" width="292.36" height="42.48" as="geometry" />
116
+ </mxCell>
117
+ <mxCell id="FjPlvFc2R7vJgaK0KZA3-34" value="&lt;font face=&quot;Times New Roman&quot; style=&quot;font-size: 16px;&quot;&gt;Waveform&lt;br&gt;&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
118
+ <mxGeometry x="315.97" y="340" width="134" height="30" as="geometry" />
119
+ </mxCell>
120
+ </root>
121
+ </mxGraphModel>
122
+ </diagram>
123
+ </mxfile>
docs/resources/arch-overview.jpg ADDED
docs/resources/arch-variance.drawio ADDED
The diff for this file is too large to render. See raw diff
 
inference/dpm_solver_pytorch.py ADDED
@@ -0,0 +1,1305 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+
6
+ class NoiseScheduleVP:
7
+ def __init__(
8
+ self,
9
+ schedule='discrete',
10
+ betas=None,
11
+ alphas_cumprod=None,
12
+ continuous_beta_0=0.1,
13
+ continuous_beta_1=20.,
14
+ dtype=torch.float32,
15
+ ):
16
+ """Create a wrapper class for the forward SDE (VP type).
17
+
18
+ ***
19
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
20
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
21
+ ***
22
+
23
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
24
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
25
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
26
+
27
+ log_alpha_t = self.marginal_log_mean_coeff(t)
28
+ sigma_t = self.marginal_std(t)
29
+ lambda_t = self.marginal_lambda(t)
30
+
31
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
32
+
33
+ t = self.inverse_lambda(lambda_t)
34
+
35
+ ===============================================================
36
+
37
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
38
+
39
+ 1. For discrete-time DPMs:
40
+
41
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
42
+ t_i = (i + 1) / N
43
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
44
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
45
+
46
+ Args:
47
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
48
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
49
+
50
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
51
+
52
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
53
+ The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
54
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
55
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
56
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
57
+ and
58
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
59
+
60
+
61
+ 2. For continuous-time DPMs:
62
+
63
+ We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise
64
+ schedule are the default settings in Yang Song's ScoreSDE:
65
+
66
+ Args:
67
+ beta_min: A `float` number. The smallest beta for the linear schedule.
68
+ beta_max: A `float` number. The largest beta for the linear schedule.
69
+ T: A `float` number. The ending time of the forward process.
70
+
71
+ ===============================================================
72
+
73
+ Args:
74
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
75
+ 'linear' for continuous-time DPMs.
76
+ Returns:
77
+ A wrapper object of the forward SDE (VP type).
78
+
79
+ ===============================================================
80
+
81
+ Example:
82
+
83
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
84
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
85
+
86
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
87
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
88
+
89
+ # For continuous-time DPMs (VPSDE), linear schedule:
90
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
91
+
92
+ """
93
+
94
+ if schedule not in ['discrete', 'linear']:
95
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule))
96
+
97
+ self.schedule = schedule
98
+ if schedule == 'discrete':
99
+ if betas is not None:
100
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
101
+ else:
102
+ assert alphas_cumprod is not None
103
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
104
+ self.T = 1.
105
+ self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype)
106
+ self.total_N = self.log_alpha_array.shape[1]
107
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
108
+ else:
109
+ self.T = 1.
110
+ self.total_N = 1000
111
+ self.beta_0 = continuous_beta_0
112
+ self.beta_1 = continuous_beta_1
113
+
114
+ def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
115
+ """
116
+ For some beta schedules, such as the cosine schedule, the log-SNR has numerical issues.
117
+ We clip the log-SNR near t=T to a minimum of -5.1 to ensure numerical stability.
118
+ Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.
119
+ """
120
+ log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
121
+ lambs = log_alphas - log_sigmas
122
+ idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda)
123
+ if idx > 0:
124
+ log_alphas = log_alphas[:-idx]
125
+ return log_alphas
126
+
127
+ def marginal_log_mean_coeff(self, t):
128
+ """
129
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
130
+ """
131
+ if self.schedule == 'discrete':
132
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
133
+ elif self.schedule == 'linear':
134
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
135
+
136
+ def marginal_alpha(self, t):
137
+ """
138
+ Compute alpha_t of a given continuous-time label t in [0, T].
139
+ """
140
+ return torch.exp(self.marginal_log_mean_coeff(t))
141
+
142
+ def marginal_std(self, t):
143
+ """
144
+ Compute sigma_t of a given continuous-time label t in [0, T].
145
+ """
146
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
147
+
148
+ def marginal_lambda(self, t):
149
+ """
150
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
151
+ """
152
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
153
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
154
+ return log_mean_coeff - log_std
155
+
156
+ def inverse_lambda(self, lamb):
157
+ """
158
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
159
+ """
160
+ if self.schedule == 'linear':
161
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
162
+ Delta = self.beta_0**2 + tmp
163
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
164
+ elif self.schedule == 'discrete':
165
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
166
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
167
+ return t.reshape((-1,))
168
+
169
+
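As an aside from the diff itself: the class above exposes alpha_t, sigma_t and lambda_t exactly as its docstring describes, and `inverse_lambda` undoes `marginal_lambda`. A minimal sketch, assuming `NoiseScheduleVP` (and the `interpolate_fn` helper defined later in this file) is in scope; the linear-beta range and N = 1000 are made-up illustrative values, not settings taken from this repository:

```
# Illustrative only, not part of the diffed file.
import torch

betas = torch.linspace(1e-4, 0.02, 1000)          # assumed DDPM-style linear schedule
ns = NoiseScheduleVP('discrete', betas=betas)

t = torch.tensor([0.5])
alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
lambda_t = ns.marginal_lambda(t)

# For the VP schedule, alpha_t^2 + sigma_t^2 = 1, and inverse_lambda inverts marginal_lambda.
assert torch.allclose(alpha_t ** 2 + sigma_t ** 2, torch.ones(1), atol=1e-5)
assert torch.allclose(ns.inverse_lambda(lambda_t), t, atol=1e-3)
```

The tolerance on the round trip is loose only to allow for interpolation and float error.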
170
+ def model_wrapper(
171
+ model,
172
+ noise_schedule,
173
+ model_type="noise",
174
+ model_kwargs={},
175
+ guidance_type="uncond",
176
+ condition=None,
177
+ unconditional_condition=None,
178
+ guidance_scale=1.,
179
+ classifier_fn=None,
180
+ classifier_kwargs={},
181
+ ):
182
+ """Create a wrapper function for the noise prediction model.
183
+
184
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
185
+ first wrap the model function into a noise prediction model that accepts the continuous time as its input.
186
+
187
+ We support four types of the diffusion model by setting `model_type`:
188
+
189
+ 1. "noise": noise prediction model. (Trained by predicting noise).
190
+
191
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
192
+
193
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
194
+ The derivation of the "v" prediction is detailed in Appendix D of [1], and it is used in Imagen-Video [2].
195
+
196
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
197
+ arXiv preprint arXiv:2202.00512 (2022).
198
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
199
+ arXiv preprint arXiv:2210.02303 (2022).
200
+
201
+ 4. "score": marginal score function. (Trained by denoising score matching).
202
+ Note that the score function and the noise prediction model follow a simple relationship:
203
+ ```
204
+ noise(x_t, t) = -sigma_t * score(x_t, t)
205
+ ```
206
+
207
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
208
+ 1. "uncond": unconditional sampling by DPMs.
209
+ The input `model` has the following format:
210
+ ``
211
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
212
+ ``
213
+
214
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
215
+ The input `model` has the following format:
216
+ ``
217
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
218
+ ``
219
+
220
+ The input `classifier_fn` has the following format:
221
+ ``
222
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
223
+ ``
224
+
225
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
226
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
227
+
228
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
229
+ The input `model` has the following format:
230
+ ``
231
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
232
+ ``
233
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
234
+
235
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
236
+ arXiv preprint arXiv:2207.12598 (2022).
237
+
238
+
239
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
240
+ or continuous-time labels (i.e. epsilon to T).
241
+
242
+ We wrap the model function to accept only `x` and `t_continuous` as inputs and output the predicted noise:
243
+ ``
244
+ def model_fn(x, t_continuous) -> noise:
245
+ t_input = get_model_input_time(t_continuous)
246
+ return noise_pred(model, x, t_input, **model_kwargs)
247
+ ``
248
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
249
+
250
+ ===============================================================
251
+
252
+ Args:
253
+ model: A diffusion model with the corresponding format described above.
254
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
255
+ model_type: A `str`. The parameterization type of the diffusion model.
256
+ "noise" or "x_start" or "v" or "score".
257
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
258
+ guidance_type: A `str`. The type of the guidance for sampling.
259
+ "uncond" or "classifier" or "classifier-free".
260
+ condition: A pytorch tensor. The condition for the guided sampling.
261
+ Only used for "classifier" or "classifier-free" guidance type.
262
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
263
+ Only used for "classifier-free" guidance type.
264
+ guidance_scale: A `float`. The scale for the guided sampling.
265
+ classifier_fn: A classifier function. Only used for the classifier guidance.
266
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
267
+ Returns:
268
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
269
+ """
270
+
271
+ def get_model_input_time(t_continuous):
272
+ """
273
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
274
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
275
+ For continuous-time DPMs, we just use `t_continuous`.
276
+ """
277
+ if noise_schedule.schedule == 'discrete':
278
+ return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N
279
+ else:
280
+ return t_continuous
281
+
282
+ def noise_pred_fn(x, t_continuous, cond=None):
283
+ t_input = get_model_input_time(t_continuous)
284
+ if cond is None:
285
+ output = model(x, t_input, **model_kwargs)
286
+ else:
287
+ output = model(x, t_input, cond, **model_kwargs)
288
+ if model_type == "noise":
289
+ return output
290
+ elif model_type == "x_start":
291
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
292
+ return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
293
+ elif model_type == "v":
294
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
295
+ return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
296
+ elif model_type == "score":
297
+ sigma_t = noise_schedule.marginal_std(t_continuous)
298
+ return -expand_dims(sigma_t, x.dim()) * output
299
+
300
+ def cond_grad_fn(x, t_input):
301
+ """
302
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
303
+ """
304
+ with torch.enable_grad():
305
+ x_in = x.detach().requires_grad_(True)
306
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
307
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
308
+
309
+ def model_fn(x, t_continuous):
310
+ """
311
+ The noise prediction model function that is used for DPM-Solver.
312
+ """
313
+ if guidance_type == "uncond":
314
+ return noise_pred_fn(x, t_continuous)
315
+ elif guidance_type == "classifier":
316
+ assert classifier_fn is not None
317
+ t_input = get_model_input_time(t_continuous)
318
+ cond_grad = cond_grad_fn(x, t_input)
319
+ sigma_t = noise_schedule.marginal_std(t_continuous)
320
+ noise = noise_pred_fn(x, t_continuous)
321
+ return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
322
+ elif guidance_type == "classifier-free":
323
+ if guidance_scale == 1. or unconditional_condition is None:
324
+ return noise_pred_fn(x, t_continuous, cond=condition)
325
+ else:
326
+ x_in = torch.cat([x] * 2)
327
+ t_in = torch.cat([t_continuous] * 2)
328
+ c_in = torch.cat([unconditional_condition, condition])
329
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
330
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
331
+
332
+ assert model_type in ["noise", "x_start", "v", "score"]
333
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
334
+ return model_fn
335
+
336
+
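For reference, a minimal sketch of wrapping a conditional model with classifier-free guidance via `model_wrapper`, as the docstring above describes. The toy model, the condition tensors, the shapes, and the guidance scale are all made-up placeholders, not values from this repository:

```
# Illustrative only, not part of the diffed file.
import torch

def toy_model(x, t_input, cond):
    # Stand-in "noise prediction" network: just scales the input by the condition.
    return x * cond.view(-1, 1)

betas = torch.linspace(1e-4, 0.02, 1000)
ns = NoiseScheduleVP('discrete', betas=betas)

model_fn = model_wrapper(
    toy_model, ns,
    model_type="noise",
    guidance_type="classifier-free",
    condition=torch.ones(4, 1),
    unconditional_condition=torch.zeros(4, 1),
    guidance_scale=2.0,
)

x = torch.randn(4, 8)
t = torch.full((4,), 0.5)
eps = model_fn(x, t)   # guided noise prediction, shape (4, 8)
```

With `guidance_scale != 1`, the wrapper batches the conditional and unconditional branches and returns `noise_uncond + guidance_scale * (noise - noise_uncond)`.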
337
+ class DPM_Solver:
338
+ def __init__(
339
+ self,
340
+ model_fn,
341
+ noise_schedule,
342
+ algorithm_type="dpmsolver++",
343
+ correcting_x0_fn=None,
344
+ correcting_xt_fn=None,
345
+ thresholding_max_val=1.,
346
+ dynamic_thresholding_ratio=0.995,
347
+ ):
348
+ """Construct a DPM-Solver.
349
+
350
+ We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
351
+
352
+ We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
353
+ can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
354
+ dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
355
+ DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
356
+ DPMs (such as stable-diffusion).
357
+
358
+ To support advanced algorithms in image-to-image applications, we also support corrector functions for
359
+ both x0 and xt.
360
+
361
+ Args:
362
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
363
+ ``
364
+ def model_fn(x, t_continuous):
365
+ return noise
366
+ ``
367
+ The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
368
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
369
+ algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
370
+ correcting_x0_fn: A `str` or a function with the following format:
371
+ ```
372
+ def correcting_x0_fn(x0, t):
373
+ x0_new = ...
374
+ return x0_new
375
+ ```
376
+ This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
377
+ ```
378
+ x0_pred = data_pred_model(xt, t)
379
+ if correcting_x0_fn is not None:
380
+ x0_pred = correcting_x0_fn(x0_pred, t)
381
+ xt_1 = update(x0_pred, xt, t)
382
+ ```
383
+ If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
384
+ correcting_xt_fn: A function with the following format:
385
+ ```
386
+ def correcting_xt_fn(xt, t, step):
387
+ x_new = ...
388
+ return x_new
389
+ ```
390
+ This function is to correct the intermediate samples xt at each sampling step. e.g.,
391
+ ```
392
+ xt = ...
393
+ xt = correcting_xt_fn(xt, t, step)
394
+ ```
395
+ thresholding_max_val: A `float`. The max value for thresholding.
396
+ Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
397
+ dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
398
+ Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
399
+
400
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
401
+ Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
402
+ with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
403
+ """
404
+ self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
405
+ self.noise_schedule = noise_schedule
406
+ assert algorithm_type in ["dpmsolver", "dpmsolver++"]
407
+ self.algorithm_type = algorithm_type
408
+ if correcting_x0_fn == "dynamic_thresholding":
409
+ self.correcting_x0_fn = self.dynamic_thresholding_fn
410
+ else:
411
+ self.correcting_x0_fn = correcting_x0_fn
412
+ self.correcting_xt_fn = correcting_xt_fn
413
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
414
+ self.thresholding_max_val = thresholding_max_val
415
+
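A minimal construction sketch for the options the docstring above describes (DPM-Solver++ with dynamic thresholding). The constant-zero `model_fn` and the beta schedule are placeholders, not values from this repository:

```
# Illustrative only, not part of the diffed file.
import torch

betas = torch.linspace(1e-4, 0.02, 1000)
ns = NoiseScheduleVP('discrete', betas=betas)
solver = DPM_Solver(
    lambda x, t: torch.zeros_like(x), ns,          # placeholder noise model
    algorithm_type="dpmsolver++",
    correcting_x0_fn="dynamic_thresholding",
    thresholding_max_val=1.0,
    dynamic_thresholding_ratio=0.995,
)
```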
416
+ def dynamic_thresholding_fn(self, x0, t):
417
+ """
418
+ The dynamic thresholding method.
419
+ """
420
+ dims = x0.dim()
421
+ p = self.dynamic_thresholding_ratio
422
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
423
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
424
+ x0 = torch.clamp(x0, -s, s) / s
425
+ return x0
426
+
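Illustrative only: what the dynamic thresholding step above does to an out-of-range x0 prediction, reproduced on a toy tensor (the sample values are made up):

```
# Not part of the diffed file; mirrors dynamic_thresholding_fn with ratio 0.995 and max_val 1.
import torch

x0 = torch.tensor([[0.3, -3.0, 1.5, 0.1]])
s = torch.quantile(x0.abs().reshape(1, -1), 0.995, dim=1)   # per-sample threshold
s = torch.maximum(s, torch.ones_like(s)).unsqueeze(-1)      # never below thresholding_max_val
x0_rescaled = torch.clamp(x0, -s, s) / s                     # values end up in [-1, 1]
```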
427
+ def noise_prediction_fn(self, x, t):
428
+ """
429
+ Return the noise prediction model.
430
+ """
431
+ return self.model(x, t)
432
+
433
+ def data_prediction_fn(self, x, t):
434
+ """
435
+ Return the data prediction model (with corrector).
436
+ """
437
+ noise = self.noise_prediction_fn(x, t)
438
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
439
+ x0 = (x - sigma_t * noise) / alpha_t
440
+ if self.correcting_x0_fn is not None:
441
+ x0 = self.correcting_x0_fn(x0, t)
442
+ return x0
443
+
444
+ def model_fn(self, x, t):
445
+ """
446
+ Convert the model to the noise prediction model or the data prediction model.
447
+ """
448
+ if self.algorithm_type == "dpmsolver++":
449
+ return self.data_prediction_fn(x, t)
450
+ else:
451
+ return self.noise_prediction_fn(x, t)
452
+
453
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
454
+ """Compute the intermediate time steps for sampling.
455
+
456
+ Args:
457
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
458
+ - 'logSNR': uniform logSNR for the time steps.
459
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
460
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
461
+ t_T: A `float`. The starting time of the sampling (default is T).
462
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
463
+ N: An `int`. The number of time-step intervals (the returned tensor has N + 1 entries).
464
+ device: A torch device.
465
+ Returns:
466
+ A pytorch tensor of the time steps, with the shape (N + 1,).
467
+ """
468
+ if skip_type == 'logSNR':
469
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
470
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
471
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
472
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
473
+ elif skip_type == 'time_uniform':
474
+ return torch.linspace(t_T, t_0, N + 1).to(device)
475
+ elif skip_type == 'time_quadratic':
476
+ t_order = 2
477
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
478
+ return t
479
+ else:
480
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
481
+
482
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
483
+ """
484
+ Get the order of each step for sampling by the singlestep DPM-Solver.
485
+
486
+ We combine DPM-Solver-1, -2 and -3 to use all the function evaluations; this combination is named "DPM-Solver-fast".
487
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
488
+ - If order == 1:
489
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
490
+ - If order == 2:
491
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
492
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
493
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
494
+ - If order == 3:
495
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
496
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
497
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
498
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
499
+
500
+ ============================================
501
+ Args:
502
+ order: An `int`. The max order for the solver (2 or 3).
503
+ steps: An `int`. The total number of function evaluations (NFE).
504
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
505
+ - 'logSNR': uniform logSNR for the time steps.
506
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
507
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
508
+ t_T: A `float`. The starting time of the sampling (default is T).
509
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
510
+ device: A torch device.
511
+ Returns:
512
+ orders: A list of the solver order of each step.
513
+ """
514
+ if order == 3:
515
+ K = steps // 3 + 1
516
+ if steps % 3 == 0:
517
+ orders = [3,] * (K - 2) + [2, 1]
518
+ elif steps % 3 == 1:
519
+ orders = [3,] * (K - 1) + [1]
520
+ else:
521
+ orders = [3,] * (K - 1) + [2]
522
+ elif order == 2:
523
+ if steps % 2 == 0:
524
+ K = steps // 2
525
+ orders = [2,] * K
526
+ else:
527
+ K = steps // 2 + 1
528
+ orders = [2,] * (K - 1) + [1]
529
+ elif order == 1:
530
+ K = 1
531
+ orders = [1,] * steps
532
+ else:
533
+ raise ValueError("'order' must be '1' or '2' or '3'.")
534
+ if skip_type == 'logSNR':
535
+ # To reproduce the results in DPM-Solver paper
536
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
537
+ else:
538
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
539
+ return timesteps_outer, orders
540
+
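A stand-alone sketch (not part of the diffed file) that reproduces just the order-allocation rule from the docstring above, so the NFE budget can be checked by hand:

```
# Pure-Python restatement of the order allocation in get_orders_and_timesteps_for_singlestep_solver.
def allocate_orders(steps, order):
    if order == 3:
        K = steps // 3 + 1
        if steps % 3 == 0:
            return [3] * (K - 2) + [2, 1]
        elif steps % 3 == 1:
            return [3] * (K - 1) + [1]
        return [3] * (K - 1) + [2]
    if order == 2:
        K = steps // 2
        return [2] * K if steps % 2 == 0 else [2] * K + [1]
    return [1] * steps

assert allocate_orders(10, 3) == [3, 3, 3, 1]      # K = 4, steps % 3 == 1
assert sum(allocate_orders(12, 3)) == 12           # every allocation uses exactly `steps` NFE
assert allocate_orders(9, 2) == [2, 2, 2, 2, 1]
```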
541
+ def denoise_to_zero_fn(self, x, s):
542
+ """
543
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity with a first-order discretization.
544
+ """
545
+ return self.data_prediction_fn(x, s)
546
+
547
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
548
+ """
549
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
550
+
551
+ Args:
552
+ x: A pytorch tensor. The initial value at time `s`.
553
+ s: A pytorch tensor. The starting time, with the shape (1,).
554
+ t: A pytorch tensor. The ending time, with the shape (1,).
555
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
556
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
557
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
558
+ Returns:
559
+ x_t: A pytorch tensor. The approximated solution at time `t`.
560
+ """
561
+ ns = self.noise_schedule
562
+ dims = x.dim()
563
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
564
+ h = lambda_t - lambda_s
565
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
566
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
567
+ alpha_t = torch.exp(log_alpha_t)
568
+
569
+ if self.algorithm_type == "dpmsolver++":
570
+ phi_1 = torch.expm1(-h)
571
+ if model_s is None:
572
+ model_s = self.model_fn(x, s)
573
+ x_t = (
574
+ sigma_t / sigma_s * x
575
+ - alpha_t * phi_1 * model_s
576
+ )
577
+ if return_intermediate:
578
+ return x_t, {'model_s': model_s}
579
+ else:
580
+ return x_t
581
+ else:
582
+ phi_1 = torch.expm1(h)
583
+ if model_s is None:
584
+ model_s = self.model_fn(x, s)
585
+ x_t = (
586
+ torch.exp(log_alpha_t - log_alpha_s) * x
587
+ - (sigma_t * phi_1) * model_s
588
+ )
589
+ if return_intermediate:
590
+ return x_t, {'model_s': model_s}
591
+ else:
592
+ return x_t
593
+
594
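# A minimal standalone sketch of the DPM-Solver++ (data-prediction) branch of the
# first-order update above, with toy VP-style coefficients (alpha**2 + sigma**2 = 1).
# `x0_pred` stands in for model_s, i.e. the data prediction returned by model_fn at time s.
import torch

alpha_s, alpha_t = torch.tensor(0.6), torch.tensor(0.8)
sigma_s, sigma_t = torch.tensor(0.8), torch.tensor(0.6)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)   # half-logSNR at s
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)   # half-logSNR at t
h = lambda_t - lambda_s

x = torch.randn(4)            # toy sample at time s
x0_pred = torch.zeros(4)      # toy data prediction at time s
x_t = sigma_t / sigma_s * x - alpha_t * torch.expm1(-h) * x0_pred
print(x_t)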
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'):
595
+ """
596
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
597
+
598
+ Args:
599
+ x: A pytorch tensor. The initial value at time `s`.
600
+ s: A pytorch tensor. The starting time, with the shape (1,).
601
+ t: A pytorch tensor. The ending time, with the shape (1,).
602
+ r1: A `float`. The hyperparameter of the second-order solver.
603
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
604
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
605
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
606
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
607
+ The type slightly affects the performance. We recommend using the 'dpmsolver' type.
608
+ Returns:
609
+ x_t: A pytorch tensor. The approximated solution at time `t`.
610
+ """
611
+ if solver_type not in ['dpmsolver', 'taylor']:
612
+ raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
613
+ if r1 is None:
614
+ r1 = 0.5
615
+ ns = self.noise_schedule
616
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
617
+ h = lambda_t - lambda_s
618
+ lambda_s1 = lambda_s + r1 * h
619
+ s1 = ns.inverse_lambda(lambda_s1)
620
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
621
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
622
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
623
+
624
+ if self.algorithm_type == "dpmsolver++":
625
+ phi_11 = torch.expm1(-r1 * h)
626
+ phi_1 = torch.expm1(-h)
627
+
628
+ if model_s is None:
629
+ model_s = self.model_fn(x, s)
630
+ x_s1 = (
631
+ (sigma_s1 / sigma_s) * x
632
+ - (alpha_s1 * phi_11) * model_s
633
+ )
634
+ model_s1 = self.model_fn(x_s1, s1)
635
+ if solver_type == 'dpmsolver':
636
+ x_t = (
637
+ (sigma_t / sigma_s) * x
638
+ - (alpha_t * phi_1) * model_s
639
+ - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
640
+ )
641
+ elif solver_type == 'taylor':
642
+ x_t = (
643
+ (sigma_t / sigma_s) * x
644
+ - (alpha_t * phi_1) * model_s
645
+ + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s)
646
+ )
647
+ else:
648
+ phi_11 = torch.expm1(r1 * h)
649
+ phi_1 = torch.expm1(h)
650
+
651
+ if model_s is None:
652
+ model_s = self.model_fn(x, s)
653
+ x_s1 = (
654
+ torch.exp(log_alpha_s1 - log_alpha_s) * x
655
+ - (sigma_s1 * phi_11) * model_s
656
+ )
657
+ model_s1 = self.model_fn(x_s1, s1)
658
+ if solver_type == 'dpmsolver':
659
+ x_t = (
660
+ torch.exp(log_alpha_t - log_alpha_s) * x
661
+ - (sigma_t * phi_1) * model_s
662
+ - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
663
+ )
664
+ elif solver_type == 'taylor':
665
+ x_t = (
666
+ torch.exp(log_alpha_t - log_alpha_s) * x
667
+ - (sigma_t * phi_1) * model_s
668
+ - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)
669
+ )
670
+ if return_intermediate:
671
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
672
+ else:
673
+ return x_t
674
+
675
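# A minimal standalone sketch of where the intermediate evaluation of the second-order
# singlestep update above is placed: with the default r1 = 0.5, the extra model call
# happens at the midpoint of the step in half-logSNR (lambda) space.
import torch

lambda_s, lambda_t = torch.tensor(-2.0), torch.tensor(1.0)   # toy half-logSNR values
h = lambda_t - lambda_s
r1 = 0.5
lambda_s1 = lambda_s + r1 * h
assert torch.isclose(lambda_s1, (lambda_s + lambda_t) / 2)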
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'):
676
+ """
677
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
678
+
679
+ Args:
680
+ x: A pytorch tensor. The initial value at time `s`.
681
+ s: A pytorch tensor. The starting time, with the shape (1,).
682
+ t: A pytorch tensor. The ending time, with the shape (1,).
683
+ r1: A `float`. The hyperparameter of the third-order solver.
684
+ r2: A `float`. The hyperparameter of the third-order solver.
685
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
686
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
687
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
688
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
689
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
690
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
691
+ The type slightly affects the performance. We recommend using the 'dpmsolver' type.
692
+ Returns:
693
+ x_t: A pytorch tensor. The approximated solution at time `t`.
694
+ """
695
+ if solver_type not in ['dpmsolver', 'taylor']:
696
+ raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
697
+ if r1 is None:
698
+ r1 = 1. / 3.
699
+ if r2 is None:
700
+ r2 = 2. / 3.
701
+ ns = self.noise_schedule
702
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
703
+ h = lambda_t - lambda_s
704
+ lambda_s1 = lambda_s + r1 * h
705
+ lambda_s2 = lambda_s + r2 * h
706
+ s1 = ns.inverse_lambda(lambda_s1)
707
+ s2 = ns.inverse_lambda(lambda_s2)
708
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
709
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
710
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
711
+
712
+ if self.algorithm_type == "dpmsolver++":
713
+ phi_11 = torch.expm1(-r1 * h)
714
+ phi_12 = torch.expm1(-r2 * h)
715
+ phi_1 = torch.expm1(-h)
716
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
717
+ phi_2 = phi_1 / h + 1.
718
+ phi_3 = phi_2 / h - 0.5
719
+
720
+ if model_s is None:
721
+ model_s = self.model_fn(x, s)
722
+ if model_s1 is None:
723
+ x_s1 = (
724
+ (sigma_s1 / sigma_s) * x
725
+ - (alpha_s1 * phi_11) * model_s
726
+ )
727
+ model_s1 = self.model_fn(x_s1, s1)
728
+ x_s2 = (
729
+ (sigma_s2 / sigma_s) * x
730
+ - (alpha_s2 * phi_12) * model_s
731
+ + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
732
+ )
733
+ model_s2 = self.model_fn(x_s2, s2)
734
+ if solver_type == 'dpmsolver':
735
+ x_t = (
736
+ (sigma_t / sigma_s) * x
737
+ - (alpha_t * phi_1) * model_s
738
+ + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
739
+ )
740
+ elif solver_type == 'taylor':
741
+ D1_0 = (1. / r1) * (model_s1 - model_s)
742
+ D1_1 = (1. / r2) * (model_s2 - model_s)
743
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
744
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
745
+ x_t = (
746
+ (sigma_t / sigma_s) * x
747
+ - (alpha_t * phi_1) * model_s
748
+ + (alpha_t * phi_2) * D1
749
+ - (alpha_t * phi_3) * D2
750
+ )
751
+ else:
752
+ phi_11 = torch.expm1(r1 * h)
753
+ phi_12 = torch.expm1(r2 * h)
754
+ phi_1 = torch.expm1(h)
755
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
756
+ phi_2 = phi_1 / h - 1.
757
+ phi_3 = phi_2 / h - 0.5
758
+
759
+ if model_s is None:
760
+ model_s = self.model_fn(x, s)
761
+ if model_s1 is None:
762
+ x_s1 = (
763
+ (torch.exp(log_alpha_s1 - log_alpha_s)) * x
764
+ - (sigma_s1 * phi_11) * model_s
765
+ )
766
+ model_s1 = self.model_fn(x_s1, s1)
767
+ x_s2 = (
768
+ (torch.exp(log_alpha_s2 - log_alpha_s)) * x
769
+ - (sigma_s2 * phi_12) * model_s
770
+ - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
771
+ )
772
+ model_s2 = self.model_fn(x_s2, s2)
773
+ if solver_type == 'dpmsolver':
774
+ x_t = (
775
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
776
+ - (sigma_t * phi_1) * model_s
777
+ - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
778
+ )
779
+ elif solver_type == 'taylor':
780
+ D1_0 = (1. / r1) * (model_s1 - model_s)
781
+ D1_1 = (1. / r2) * (model_s2 - model_s)
782
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
783
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
784
+ x_t = (
785
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
786
+ - (sigma_t * phi_1) * model_s
787
+ - (sigma_t * phi_2) * D1
788
+ - (sigma_t * phi_3) * D2
789
+ )
790
+
791
+ if return_intermediate:
792
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
793
+ else:
794
+ return x_t
795
+
796
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
797
+ """
798
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
799
+
800
+ Args:
801
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
802
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
803
+ t_prev_list: A list of pytorch tensors. The previous times; each has the shape (1,).
804
+ t: A pytorch tensor. The ending time, with the shape (1,).
805
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
806
+ The type slightly affects the performance. We recommend using the 'dpmsolver' type.
807
+ Returns:
808
+ x_t: A pytorch tensor. The approximated solution at time `t`.
809
+ """
810
+ if solver_type not in ['dpmsolver', 'taylor']:
811
+ raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
812
+ ns = self.noise_schedule
813
+ model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
814
+ t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
815
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
816
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
817
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
818
+ alpha_t = torch.exp(log_alpha_t)
819
+
820
+ h_0 = lambda_prev_0 - lambda_prev_1
821
+ h = lambda_t - lambda_prev_0
822
+ r0 = h_0 / h
823
+ D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
824
+ if self.algorithm_type == "dpmsolver++":
825
+ phi_1 = torch.expm1(-h)
826
+ if solver_type == 'dpmsolver':
827
+ x_t = (
828
+ (sigma_t / sigma_prev_0) * x
829
+ - (alpha_t * phi_1) * model_prev_0
830
+ - 0.5 * (alpha_t * phi_1) * D1_0
831
+ )
832
+ elif solver_type == 'taylor':
833
+ x_t = (
834
+ (sigma_t / sigma_prev_0) * x
835
+ - (alpha_t * phi_1) * model_prev_0
836
+ + (alpha_t * (phi_1 / h + 1.)) * D1_0
837
+ )
838
+ else:
839
+ phi_1 = torch.expm1(h)
840
+ if solver_type == 'dpmsolver':
841
+ x_t = (
842
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
843
+ - (sigma_t * phi_1) * model_prev_0
844
+ - 0.5 * (sigma_t * phi_1) * D1_0
845
+ )
846
+ elif solver_type == 'taylor':
847
+ x_t = (
848
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
849
+ - (sigma_t * phi_1) * model_prev_0
850
+ - (sigma_t * (phi_1 / h - 1.)) * D1_0
851
+ )
852
+ return x_t
853
+
854
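# A minimal standalone sketch of the finite difference used by the multistep second-order
# update above: D1_0 rescales the difference of the two cached model outputs so that it
# approximates h * d(model)/d(lambda) over the current step, at no extra model evaluation.
import torch

lambda_prev_1, lambda_prev_0, lambda_t = torch.tensor(-2.0), torch.tensor(-1.0), torch.tensor(0.5)
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0 = h_0 / h

model_prev_1 = torch.tensor([0.10])   # toy cached model output at t_prev_1
model_prev_0 = torch.tensor([0.40])   # toy cached model output at t_prev_0
D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
print(D1_0)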
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
855
+ """
856
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
857
+
858
+ Args:
859
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
860
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
861
+ t_prev_list: A list of pytorch tensors. The previous times; each has the shape (1,).
862
+ t: A pytorch tensor. The ending time, with the shape (1,).
863
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
864
+ The type slightly affects the performance. We recommend using the 'dpmsolver' type.
865
+ Returns:
866
+ x_t: A pytorch tensor. The approximated solution at time `t`.
867
+ """
868
+ ns = self.noise_schedule
869
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
870
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
871
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
872
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
873
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
874
+ alpha_t = torch.exp(log_alpha_t)
875
+
876
+ h_1 = lambda_prev_1 - lambda_prev_2
877
+ h_0 = lambda_prev_0 - lambda_prev_1
878
+ h = lambda_t - lambda_prev_0
879
+ r0, r1 = h_0 / h, h_1 / h
880
+ D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
881
+ D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
882
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
883
+ D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
884
+ if self.algorithm_type == "dpmsolver++":
885
+ phi_1 = torch.expm1(-h)
886
+ phi_2 = phi_1 / h + 1.
887
+ phi_3 = phi_2 / h - 0.5
888
+ x_t = (
889
+ (sigma_t / sigma_prev_0) * x
890
+ - (alpha_t * phi_1) * model_prev_0
891
+ + (alpha_t * phi_2) * D1
892
+ - (alpha_t * phi_3) * D2
893
+ )
894
+ else:
895
+ phi_1 = torch.expm1(h)
896
+ phi_2 = phi_1 / h - 1.
897
+ phi_3 = phi_2 / h - 0.5
898
+ x_t = (
899
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
900
+ - (sigma_t * phi_1) * model_prev_0
901
+ - (sigma_t * phi_2) * D1
902
+ - (sigma_t * phi_3) * D2
903
+ )
904
+ return x_t
905
+
906
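# A minimal standalone sketch of the divided differences D1 and D2 formed above from the
# three cached model outputs. For a model output that is exactly quadratic in lambda, D1
# recovers h * f'(lambda_prev_0) and D2 recovers (h**2 / 2) * f'' exactly.
import torch

lam = torch.tensor([-2.0, -1.0, 0.0])        # lambda_prev_2, lambda_prev_1, lambda_prev_0
lambda_t = torch.tensor(1.0)
f = lambda z: 3.0 + 2.0 * z + 0.5 * z ** 2   # toy "model output as a function of lambda"
m2, m1, m0 = f(lam[0]), f(lam[1]), f(lam[2])

h = lambda_t - lam[2]
h_0, h_1 = lam[2] - lam[1], lam[1] - lam[0]
r0, r1 = h_0 / h, h_1 / h
D1_0 = (1. / r0) * (m0 - m1)
D1_1 = (1. / r1) * (m1 - m2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
assert torch.isclose(D1, h * (2.0 + lam[2]))   # h * f'(lambda_prev_0), f'(z) = 2 + z
assert torch.isclose(D2, 0.5 * h ** 2 * 1.0)   # (h**2 / 2) * f'', with f'' = 1.0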
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None):
907
+ """
908
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
909
+
910
+ Args:
911
+ x: A pytorch tensor. The initial value at time `s`.
912
+ s: A pytorch tensor. The starting time, with the shape (1,).
913
+ t: A pytorch tensor. The ending time, with the shape (1,).
914
+ order: An `int`. The order of DPM-Solver. We only support order == 1, 2, or 3.
915
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
916
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
917
+ The type slightly affects the performance. We recommend using the 'dpmsolver' type.
918
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
919
+ r2: A `float`. The hyperparameter of the third-order solver.
920
+ Returns:
921
+ x_t: A pytorch tensor. The approximated solution at time `t`.
922
+ """
923
+ if order == 1:
924
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
925
+ elif order == 2:
926
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1)
927
+ elif order == 3:
928
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2)
929
+ else:
930
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
931
+
932
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
933
+ """
934
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
935
+
936
+ Args:
937
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
938
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
939
+ t_prev_list: A list of pytorch tensors. The previous times; each has the shape (1,).
940
+ t: A pytorch tensor. The ending time, with the shape (1,).
941
+ order: An `int`. The order of DPM-Solver. We only support order == 1, 2, or 3.
942
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
943
+ The type slightly affects the performance. We recommend using the 'dpmsolver' type.
944
+ Returns:
945
+ x_t: A pytorch tensor. The approximated solution at time `t`.
946
+ """
947
+ if order == 1:
948
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
949
+ elif order == 2:
950
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
951
+ elif order == 3:
952
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
953
+ else:
954
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
955
+
956
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'):
957
+ """
958
+ The adaptive step size solver based on singlestep DPM-Solver.
959
+
960
+ Args:
961
+ x: A pytorch tensor. The initial value at time `t_T`.
962
+ order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
963
+ t_T: A `float`. The starting time of the sampling (default is T).
964
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
965
+ h_init: A `float`. The initial step size (for logSNR).
966
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
967
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
968
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
969
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
970
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
971
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
972
+ The type slightly affects the performance. We recommend using the 'dpmsolver' type.
973
+ Returns:
974
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
975
+
976
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
977
+ """
978
+ ns = self.noise_schedule
979
+ s = t_T * torch.ones((1,)).to(x)
980
+ lambda_s = ns.marginal_lambda(s)
981
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
982
+ h = h_init * torch.ones_like(s).to(x)
983
+ x_prev = x
984
+ nfe = 0
985
+ if order == 2:
986
+ r1 = 0.5
987
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
988
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
989
+ elif order == 3:
990
+ r1, r2 = 1. / 3., 2. / 3.
991
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type)
992
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
993
+ else:
994
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
995
+ while torch.abs((s - t_0)).mean() > t_err:
996
+ t = ns.inverse_lambda(lambda_s + h)
997
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
998
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
999
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
1000
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
1001
+ E = norm_fn((x_higher - x_lower) / delta).max()
1002
+ if torch.all(E <= 1.):
1003
+ x = x_higher
1004
+ s = t
1005
+ x_prev = x_lower
1006
+ lambda_s = ns.marginal_lambda(s)
1007
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
1008
+ nfe += order
1009
+ print('adaptive solver nfe', nfe)
1010
+ return x
1011
+
1012
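# A minimal standalone sketch of the step-size control in dpm_solver_adaptive above.
# A lower-order and a higher-order proposal for the same step are compared; the step is
# accepted when the scaled error E <= 1, and h is rescaled by theta * E**(-1/order).
# All tensors here are toy values.
import torch

atol, rtol, theta, order = 0.0078, 0.05, 0.9, 3
x_prev = torch.randn(1, 8)
x_lower = x_prev + 0.01 * torch.randn(1, 8)      # toy lower-order proposal
x_higher = x_lower + 1e-4 * torch.randn(1, 8)    # toy higher-order proposal

delta = torch.max(torch.full_like(x_lower, atol),
                  rtol * torch.max(x_lower.abs(), x_prev.abs()))
err = ((x_higher - x_lower) / delta) ** 2
E = torch.sqrt(err.mean(dim=-1, keepdim=True)).max()
accept = bool(torch.all(E <= 1.))
h_scale = theta * float(torch.float_power(E, -1. / order))
print(accept, h_scale)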
+ def add_noise(self, x, t, noise=None):
1013
+ """
1014
+ Compute the noised input xt = alpha_t * x + sigma_t * noise.
1015
+
1016
+ Args:
1017
+ x: A `torch.Tensor` with shape `(batch_size, *shape)`.
1018
+ t: A `torch.Tensor` with shape `(t_size,)`.
1019
+ Returns:
1020
+ xt with shape `(t_size, batch_size, *shape)`.
1021
+ """
1022
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
1023
+ if noise is None:
1024
+ noise = torch.randn((t.shape[0], *x.shape), device=x.device)
1025
+ x = x.reshape((-1, *x.shape))
1026
+ xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
1027
+ if t.shape[0] == 1:
1028
+ return xt.squeeze(0)
1029
+ else:
1030
+ return xt
1031
+
1032
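# A minimal standalone sketch of the forward noising computed by add_noise above, for a
# single time t with toy VP-style coefficients (alpha_t**2 + sigma_t**2 = 1).
import torch

alpha_t, sigma_t = torch.tensor(0.8), torch.tensor(0.6)
x0 = torch.randn(2, 80, 100)              # e.g. a batch of mel-spectrogram-like tensors
noise = torch.randn_like(x0)
xt = alpha_t * x0 + sigma_t * noise       # xt = alpha_t * x + sigma_t * noise
print(xt.shape)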
+ def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
1033
+ method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
1034
+ atol=0.0078, rtol=0.05, return_intermediate=False,
1035
+ ):
1036
+ """
1037
+ Invert the sample `x` from time `t_start` to time `t_end` with DPM-Solver.
1038
+ For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total number of time steps used during training.
1039
+ """
1040
+ t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start
1041
+ t_T = self.noise_schedule.T if t_end is None else t_end
1042
+ assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1043
+ return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type,
1044
+ method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type,
1045
+ atol=atol, rtol=rtol, return_intermediate=return_intermediate)
1046
+
1047
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
1048
+ method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
1049
+ atol=0.0078, rtol=0.05, return_intermediate=False,
1050
+ ):
1051
+ """
1052
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1053
+
1054
+ =====================================================
1055
+
1056
+ We support the following algorithms for both noise prediction model and data prediction model:
1057
+ - 'singlestep':
1058
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1059
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1060
+ The total number of function evaluations (NFE) == `steps`.
1061
+ Given a fixed NFE == `steps`, the sampling procedure is:
1062
+ - If `order` == 1:
1063
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1064
+ - If `order` == 2:
1065
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1066
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1067
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1068
+ - If `order` == 3:
1069
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1070
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1071
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1072
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1073
+ - 'multistep':
1074
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1075
+ We initialize the first `order` values by lower order multistep solvers.
1076
+ Given a fixed NFE == `steps`, the sampling procedure is:
1077
+ Denote K = steps.
1078
+ - If `order` == 1:
1079
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
1080
+ - If `order` == 2:
1081
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
1082
+ - If `order` == 3:
1083
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
1084
+ - 'singlestep_fixed':
1085
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1086
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1087
+ - 'adaptive':
1088
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1089
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1090
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
1091
+ (NFE) and the sample quality.
1092
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1093
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1094
+
1095
+ =====================================================
1096
+
1097
+ Some advice for choosing the algorithm:
1098
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1099
+ Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
1100
+ e.g., DPM-Solver:
1101
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
1102
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1103
+ skip_type='time_uniform', method='singlestep')
1104
+ e.g., DPM-Solver++:
1105
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1106
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1107
+ skip_type='time_uniform', method='singlestep')
1108
+ - For **guided sampling with large guidance scale** by DPMs:
1109
+ Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
1110
+ e.g.
1111
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1112
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1113
+ skip_type='time_uniform', method='multistep')
1114
+
1115
+ We support three types of `skip_type`:
1116
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
1117
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
1118
+ - 'time_quadratic': quadratic time for the time steps.
1119
+
1120
+ =====================================================
1121
+ Args:
1122
+ x: A pytorch tensor. The initial value at time `t_start`
1123
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1124
+ steps: An `int`. The total number of function evaluations (NFE).
1125
+ t_start: A `float`. The starting time of the sampling.
1126
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
1127
+ t_end: A `float`. The ending time of the sampling.
1128
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1129
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1130
+ For discrete-time DPMs:
1131
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1132
+ For continuous-time DPMs:
1133
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1134
+ order: An `int`. The order of DPM-Solver.
1135
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1136
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1137
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1138
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1139
+
1140
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1141
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when sampling
1142
+ diffusion models through diffusion SDEs for low-resolution images
1143
+ (such as CIFAR-10). However, we observed that it does not matter for
1144
+ high-resolution images. As it needs an additional NFE, we do not recommend
1145
+ it for high-resolution images.
1146
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1147
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1148
+ this trick is key to stabilizing sampling with DPM-Solver at very few steps
1149
+ (especially for steps <= 10), so we recommend setting it to `True`.
1150
+ solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
1151
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1152
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1153
+ return_intermediate: A `bool`. Whether to save the xt at each step.
1154
+ When set to `True`, the method returns a tuple (x0, intermediates); when set to `False`, it returns only x0.
1155
+ Returns:
1156
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1157
+
1158
+ """
1159
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1160
+ t_T = self.noise_schedule.T if t_start is None else t_start
1161
+ assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1162
+ if return_intermediate:
1163
+ assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
1164
+ if self.correcting_xt_fn is not None:
1165
+ assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
1166
+ device = x.device
1167
+ intermediates = []
1168
+ with torch.no_grad():
1169
+ if method == 'adaptive':
1170
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
1171
+ elif method == 'multistep':
1172
+ assert steps >= order
1173
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1174
+ assert timesteps.shape[0] - 1 == steps
1175
+ # Init the initial values.
1176
+ step = 0
1177
+ t = timesteps[step]
1178
+ t_prev_list = [t]
1179
+ model_prev_list = [self.model_fn(x, t)]
1180
+ if self.correcting_xt_fn is not None:
1181
+ x = self.correcting_xt_fn(x, t, step)
1182
+ if return_intermediate:
1183
+ intermediates.append(x)
1184
+ # Init the first `order` values by lower order multistep DPM-Solver.
1185
+ for step in range(1, order):
1186
+ t = timesteps[step]
1187
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type)
1188
+ if self.correcting_xt_fn is not None:
1189
+ x = self.correcting_xt_fn(x, t, step)
1190
+ if return_intermediate:
1191
+ intermediates.append(x)
1192
+ t_prev_list.append(t)
1193
+ model_prev_list.append(self.model_fn(x, t))
1194
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1195
+ for step in range(order, steps + 1):
1196
+ t = timesteps[step]
1197
+ # We only use lower order for steps < 10
1198
+ if lower_order_final and steps < 10:
1199
+ step_order = min(order, steps + 1 - step)
1200
+ else:
1201
+ step_order = order
1202
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type)
1203
+ if self.correcting_xt_fn is not None:
1204
+ x = self.correcting_xt_fn(x, t, step)
1205
+ if return_intermediate:
1206
+ intermediates.append(x)
1207
+ for i in range(order - 1):
1208
+ t_prev_list[i] = t_prev_list[i + 1]
1209
+ model_prev_list[i] = model_prev_list[i + 1]
1210
+ t_prev_list[-1] = t
1211
+ # We do not need to evaluate the final model value.
1212
+ if step < steps:
1213
+ model_prev_list[-1] = self.model_fn(x, t)
1214
+ elif method in ['singlestep', 'singlestep_fixed']:
1215
+ if method == 'singlestep':
1216
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
1217
+ elif method == 'singlestep_fixed':
1218
+ K = steps // order
1219
+ orders = [order,] * K
1220
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1221
+ for step, order in enumerate(orders):
1222
+ s, t = timesteps_outer[step], timesteps_outer[step + 1]
1223
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device)
1224
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1225
+ h = lambda_inner[-1] - lambda_inner[0]
1226
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1227
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1228
+ x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
1229
+ if self.correcting_xt_fn is not None:
1230
+ x = self.correcting_xt_fn(x, t, step)
1231
+ if return_intermediate:
1232
+ intermediates.append(x)
1233
+ else:
1234
+ raise ValueError("Got wrong method {}".format(method))
1235
+ if denoise_to_zero:
1236
+ t = torch.ones((1,)).to(device) * t_0
1237
+ x = self.denoise_to_zero_fn(x, t)
1238
+ if self.correcting_xt_fn is not None:
1239
+ x = self.correcting_xt_fn(x, t, step + 1)
1240
+ if return_intermediate:
1241
+ intermediates.append(x)
1242
+ if return_intermediate:
1243
+ return x, intermediates
1244
+ else:
1245
+ return x
1246
+
1247
+
1248
+
1249
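# A minimal standalone sketch of an end-to-end multistep sampling call following the
# docstring above. It assumes a discrete-time DPM with a `betas` array and a noise-
# prediction network, and that the NoiseScheduleVP and model_wrapper helpers defined
# earlier in this file keep their upstream DPM-Solver signatures; `net` is a stand-in
# for a trained model, so adapt the names to your own setup.
import torch

betas = torch.linspace(1e-4, 2e-2, 1000)                   # toy training noise schedule
net = lambda x, t: torch.zeros_like(x)                     # stand-in noise predictor
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
model_fn = model_wrapper(net, noise_schedule, model_type='noise')

dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type='dpmsolver++')
x_T = torch.randn(4, 3, 32, 32)                            # start from standard Gaussian noise
x_0 = dpm_solver.sample(x_T, steps=20, order=2,
                        skip_type='time_uniform', method='multistep')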
+ #############################################################
1250
+ # other utility functions
1251
+ #############################################################
1252
+
1253
+ def interpolate_fn(x, xp, yp):
1254
+ """
1255
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1256
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1257
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1258
+
1259
+ Args:
1260
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1261
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1262
+ yp: PyTorch tensor with shape [C, K].
1263
+ Returns:
1264
+ The function values f(x), with shape [N, C].
1265
+ """
1266
+ N, K = x.shape[0], xp.shape[1]
1267
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1268
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1269
+ x_idx = torch.argmin(x_indices, dim=2)
1270
+ cand_start_idx = x_idx - 1
1271
+ start_idx = torch.where(
1272
+ torch.eq(x_idx, 0),
1273
+ torch.tensor(1, device=x.device),
1274
+ torch.where(
1275
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1276
+ ),
1277
+ )
1278
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1279
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1280
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1281
+ start_idx2 = torch.where(
1282
+ torch.eq(x_idx, 0),
1283
+ torch.tensor(0, device=x.device),
1284
+ torch.where(
1285
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1286
+ ),
1287
+ )
1288
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1289
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1290
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1291
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1292
+ return cand
1293
+
1294
+
1295
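# A minimal standalone sketch of interpolate_fn above (assuming it is in scope, e.g.
# imported from this module). With keypoints y = x**2 at x in {0, 1, 2}, querying
# x = 0.5 returns the piecewise-linear value 0.5, and x = 3.0 extrapolates the last
# segment to 7.0.
import torch

xp = torch.tensor([[0.0, 1.0, 2.0]])       # [C=1, K=3] keypoint x-values
yp = xp ** 2                               # [1, 3] keypoint y-values
x = torch.tensor([[0.5], [3.0]])           # [N=2, C=1] query points
print(interpolate_fn(x, xp, yp))           # approximately [[0.5], [7.0]]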
+ def expand_dims(v, dims):
1296
+ """
1297
+ Expand the tensor `v` to the dim `dims`.
1298
+
1299
+ Args:
1300
+ `v`: a PyTorch tensor with shape [N].
1301
+ `dim`: a `int`.
1302
+ Returns:
1303
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1304
+ """
1305
+ return v[(...,) + (None,)*(dims - 1)]
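# A minimal standalone sketch of expand_dims above (assuming it is in scope): it appends
# trailing singleton dimensions so a per-timestep coefficient of shape [N] broadcasts
# against a batch of shape [N, C, H, W].
import torch

v = torch.tensor([0.1, 0.2, 0.3])            # shape [N] = [3]
print(expand_dims(v, 4).shape)               # torch.Size([3, 1, 1, 1])
x = torch.randn(3, 2, 8, 8)
print((expand_dims(v, x.dim()) * x).shape)   # broadcasts to torch.Size([3, 2, 8, 8])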