Spaces:
Sleeping
Sleeping
liampond
commited on
Commit
·
c42fe7e
0
Parent(s):
Clean deploy snapshot
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .bashrc +9 -0
- .devcontainer/devcontainer.json +33 -0
- .gitignore +34 -0
- .streamlit/config.toml +3 -0
- .vscode/settings.json +6 -0
- LICENSE +21 -0
- Makefile +26 -0
- README.md +99 -0
- app.sh +2 -0
- augmentation/spec_stretch.py +92 -0
- basics/base_augmentation.py +28 -0
- basics/base_binarizer.py +347 -0
- basics/base_dataset.py +58 -0
- basics/base_exporter.py +59 -0
- basics/base_module.py +18 -0
- basics/base_pe.py +7 -0
- basics/base_svs_infer.py +131 -0
- basics/base_task.py +520 -0
- basics/base_vocoder.py +23 -0
- configs/CantusSVS_acoustic.yaml +149 -0
- configs/CantusSVS_variance.yaml +153 -0
- configs/base.yaml +94 -0
- configs/defaults/acoustic.yaml +138 -0
- configs/defaults/base.yaml +94 -0
- configs/defaults/variance.yaml +145 -0
- configs/templates/config_acoustic.yaml +105 -0
- configs/templates/config_variance.yaml +129 -0
- deployment/.gitignore +7 -0
- deployment/__init__.py +0 -0
- deployment/benchmarks/infer_acoustic.py +32 -0
- deployment/benchmarks/infer_nsf_hifigan.py +16 -0
- deployment/exporters/__init__.py +3 -0
- deployment/exporters/acoustic_exporter.py +405 -0
- deployment/exporters/nsf_hifigan_exporter.py +120 -0
- deployment/exporters/variance_exporter.py +781 -0
- deployment/modules/__init__.py +0 -0
- deployment/modules/diffusion.py +220 -0
- deployment/modules/fastspeech2.py +153 -0
- deployment/modules/nsf_hifigan.py +16 -0
- deployment/modules/rectified_flow.py +123 -0
- deployment/modules/toplevel.py +392 -0
- dictionaries/.gitignore +3 -0
- docs/BestPractices.md +618 -0
- docs/ConfigurationSchemas.md +2109 -0
- docs/GettingStarted.md +164 -0
- docs/resources/arch-acoustic.drawio +0 -0
- docs/resources/arch-overview.drawio +123 -0
- docs/resources/arch-overview.jpg +0 -0
- docs/resources/arch-variance.drawio +0 -0
- inference/dpm_solver_pytorch.py +1305 -0
.bashrc
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Warn when the Arrow or GCC environment modules do not show up in `module list`.
# The listing (stdout and stderr merged) is captured once and reused for both
# checks; the helper variable is unset afterwards so it does not leak into the shell.
__loaded_modules="$(module list 2>&1)"

# Warn if Arrow is not loaded
if ! grep -q arrow <<< "$__loaded_modules"; then
    echo -e "\n\033[1;33m[WARNING]\033[0m Arrow module is not loaded! Run: module load arrow/19.0.1"
fi

# Warn if GCC is not loaded
if ! grep -q gcc <<< "$__loaded_modules"; then
    echo -e "\n\033[1;33m[WARNING]\033[0m GCC module is not loaded! Run: module load gcc/12.3"
fi

unset __loaded_modules
|
.devcontainer/devcontainer.json
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "Python 3",
|
3 |
+
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
|
4 |
+
"image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye",
|
5 |
+
"customizations": {
|
6 |
+
"codespaces": {
|
7 |
+
"openFiles": [
|
8 |
+
"README.md",
|
9 |
+
"webapp/app.py"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
"vscode": {
|
13 |
+
"settings": {},
|
14 |
+
"extensions": [
|
15 |
+
"ms-python.python",
|
16 |
+
"ms-python.vscode-pylance"
|
17 |
+
]
|
18 |
+
}
|
19 |
+
},
|
20 |
+
"updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y <packages.txt; [ -f requirements.txt ] && pip3 install --user -r requirements.txt; pip3 install --user streamlit; echo '✅ Packages installed and Requirements met'",
|
21 |
+
"postAttachCommand": {
|
22 |
+
"server": "streamlit run webapp/app.py --server.enableCORS false --server.enableXsrfProtection false"
|
23 |
+
},
|
24 |
+
"portsAttributes": {
|
25 |
+
"8501": {
|
26 |
+
"label": "Application",
|
27 |
+
"onAutoForward": "openPreview"
|
28 |
+
}
|
29 |
+
},
|
30 |
+
"forwardPorts": [
|
31 |
+
8501
|
32 |
+
]
|
33 |
+
}
|
.gitignore
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
artifacts/
|
2 |
+
checkpoints/
|
3 |
+
venv/
|
4 |
+
.venv/
|
5 |
+
svenv/
|
6 |
+
env-py311/
|
7 |
+
*.swp
|
8 |
+
data/
|
9 |
+
DiffSinger/
|
10 |
+
training_logs/
|
11 |
+
training_*.out
|
12 |
+
__pycache__/
|
13 |
+
|
14 |
+
*.onnx
|
15 |
+
*.ckpt
|
16 |
+
*.pt
|
17 |
+
*.zip
|
18 |
+
*.out
|
19 |
+
*.log
|
20 |
+
*.pyc
|
21 |
+
*.pyo
|
22 |
+
*.swp
|
23 |
+
artifacts/
|
24 |
+
data/NNSVS_training_data/
|
25 |
+
data/binary/
|
26 |
+
outputs/
|
27 |
+
|
28 |
+
commands.txt
|
29 |
+
19Apr-Commands.sh
|
30 |
+
|
31 |
+
webapp/output/
|
32 |
+
webapp/tmp_ds/
|
33 |
+
webapp/uploaded_ds/
|
34 |
+
webapp/uploaded_mei/
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[server]
|
2 |
+
fileWatcherType = "none"
|
3 |
+
runOnSave = false
|
.vscode/settings.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"python.defaultInterpreterPath": "webapp/venv/bin/python",
|
3 |
+
"python.analysis.extraPaths": ["./inference"],
|
4 |
+
"python.analysis.indexing": true,
|
5 |
+
"python.analysis.typeCheckingMode": "off"
|
6 |
+
}
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 Liam Pond
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
Makefile
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY: help setup clean run
|
2 |
+
|
3 |
+
# Default command: show help
|
4 |
+
help:
|
5 |
+
@echo ""
|
6 |
+
@echo "Available commands:"
|
7 |
+
@echo " make setup - Set up virtual environment (use 'make setup reset=1' to force rebuild)"
|
8 |
+
@echo " make clean - Remove the virtual environment"
|
9 |
+
@echo " make run - Launch the Streamlit app"
|
10 |
+
@echo ""
|
11 |
+
|
12 |
+
# Set up the environment
|
13 |
+
setup:
|
14 |
+
ifeq ($(reset),1)
|
15 |
+
rm -rf venv
|
16 |
+
endif
|
17 |
+
bash scripts/setup_env.sh
|
18 |
+
|
19 |
+
# Remove the virtual environment
|
20 |
+
clean:
|
21 |
+
rm -rf venv
|
22 |
+
|
23 |
+
# Run the Streamlit app
|
24 |
+
# NOTE: `.` is used instead of `source` because make executes recipes with
# /bin/sh by default, and `source` is a bash builtin that a POSIX sh is not
# required to provide.
run:
	. venv/bin/activate && streamlit run webapp/app.py
|
26 |
+
|
README.md
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CantusSVS
|
2 |
+
|
3 |
+
## Table of Contents
|
4 |
+
- [About CantusSVS](#about-cantussvs)
|
5 |
+
- [Quick Start](#quick-start)
|
6 |
+
- [Preparing Your Input](#preparing-your-input)
|
7 |
+
- [Running Locally](#running-locally)
|
8 |
+
- [FAQ](#faq)
|
9 |
+
|
10 |
+
---
|
11 |
+
|
12 |
+
## About CantusSVS
|
13 |
+
|
14 |
+
CantusSVS is a singing voice synthesis tool that automatically generates audio playback for the Latin chants in Cantus. You can access CantusSVS directly in the browser at [**https://cantussvs.streamlit.app**](https://cantussvs.streamlit.app). For training and inference, we use **DiffSinger**, a diffusion-based singing voice synthesis model described in the paper below:
|
15 |
+
|
16 |
+
**DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism**
|
17 |
+
|
18 |
+
Liu, Jinglin, Chengxi Li, Yi Ren, Feiyang Chen, and Zhou Zhao. 2022. "DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism." In *Proceedings of the AAAI Conference on Artificial Intelligence* 36 (10): 11020–11028. [https://doi.org/10.1609/aaai.v36i10.21350](https://doi.org/10.1609/aaai.v36i10.21350). Preprint: [https://arxiv.org/abs/2105.02446](https://arxiv.org/abs/2105.02446).
|
19 |
+
|
20 |
+
Training was done using Cedar, a cluster provided by the Digital Research Alliance of Canada. To set up training locally, follow [this tutorial](https://youtu.be/Sxt11TAflV0?feature=shared) by [tigermeat](https://www.youtube.com/@spicytigermeat).
|
21 |
+
|
22 |
+
For general help training and creating a dataset, [this tutorial](https://docs.google.com/document/d/1uMsepxbdUW65PfIWL1pt2OM6ZKa5ybTTJOpZ733Ht6s/view) by [PixPrucer](https://bsky.app/profile/pixprucer.bsky.social) is an excellent guide. For help, join the [DiffSinger Discord server](https://discord.gg/DZ6fhEUfnb).
|
23 |
+
|
24 |
+
The dataset used for this project was built using [*Adventus: Dominica prima adventus Domini*](https://youtu.be/ThnPySybDJs?feature=shared), the first track from [Psallentes](https://psallentes.com/)' album *Salzinnes Saints*. Psallentes is a Belgian women's chorus that specializes in Late Medieval and Renaissance music. *Salzinnes Saints* is an album of music from the [Salzinnes Antiphonal](https://www.smu.ca/academics/archives/the-salzinnes-antiphonal.html), a mid-sixteenth century choirbook with the music and text for the Liturgy of the Hours.
|
25 |
+
|
26 |
+
---
|
27 |
+
|
28 |
+
## Quick Start
|
29 |
+
|
30 |
+
1. Clone the repository:
|
31 |
+
|
32 |
+
```bash
|
33 |
+
git clone https://github.com/yourusername/CantusSVS.git
|
34 |
+
cd CantusSVS
|
35 |
+
```
|
36 |
+
|
37 |
+
2. Set up the environment:
|
38 |
+
|
39 |
+
```bash
|
40 |
+
make setup
|
41 |
+
```
|
42 |
+
|
43 |
+
3. Run the web app locally:
|
44 |
+
|
45 |
+
```bash
|
46 |
+
make run
|
47 |
+
```
|
48 |
+
|
49 |
+
4. Open your browser at:
|
50 |
+
|
51 |
+
```
|
52 |
+
http://localhost:8501
|
53 |
+
```
|
54 |
+
|
55 |
+
Or just use the hosted app here: [https://cantussvs.streamlit.app](https://cantussvs.streamlit.app)
|
56 |
+
|
57 |
+
---
|
58 |
+
|
59 |
+
## Preparing Your Input
|
60 |
+
|
61 |
+
- Most commercial music composition software can export `.mei` files. MuseScore 4 is free to use.
|
62 |
+
- Input format must be `.mei` (Music Encoding Initiative XML).
|
63 |
+
- Only **monophonic** scores are supported (one staff, one voice).
|
64 |
+
- Lyrics must be embedded in the MEI file and aligned with notes.
|
65 |
+
|
66 |
+
Validation tool:
|
67 |
+
|
68 |
+
```bash
|
69 |
+
python scripts/validate_mei.py your_song.mei
|
70 |
+
```
|
71 |
+
|
72 |
+
---
|
73 |
+
|
74 |
+
## Running Locally
|
75 |
+
|
76 |
+
1. Drop your `.mei` file into the upload area of the web app.
|
77 |
+
|
78 |
+
2. Choose settings:
|
79 |
+
- Tempo (BPM)
|
80 |
+
- Output file name (optional)
|
81 |
+
|
82 |
+
3. Hit "Synthesize" and download the resulting `.wav` file.
|
83 |
+
|
84 |
+
Generated files:
|
85 |
+
- `.wav`: final audio output
|
86 |
+
- `.mel.npy`: intermediate mel-spectrogram
|
87 |
+
- `.info.json`: metadata (phoneme sequence, note mapping)
|
88 |
+
|
89 |
+
---
|
90 |
+
|
91 |
+
## FAQ
|
92 |
+
|
93 |
+
**Q: Can I synthesize polyphonic (multi-voice) chants?**
|
94 |
+
A: No, only monophonic scores are supported currently. However, in the future, polyphonic chants could be synthesized by layering multiple monophonic voices.
|
95 |
+
|
96 |
+
**Q: Can I change the voice timbre?**
|
97 |
+
A: In the web app, only the provided pre-trained model is available. However, DiffSinger learns the timbre of the dataset it is trained on, so if you train your own model, you can control the timbre that way.
|
98 |
+
|
99 |
+
---
|
app.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
streamlit run webapp/app.py
|
augmentation/spec_stretch.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
|
3 |
+
import librosa
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from basics.base_augmentation import BaseAugmentation, require_same_keys
|
8 |
+
from basics.base_pe import BasePE
|
9 |
+
from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
|
10 |
+
from modules.fastspeech.tts_modules import LengthRegulator
|
11 |
+
from utils.binarizer_utils import get_mel_torch, get_mel2ph_torch
|
12 |
+
from utils.hparams import hparams
|
13 |
+
from utils.infer_utils import resample_align_curve
|
14 |
+
|
15 |
+
|
16 |
+
class SpectrogramStretchAugmentation(BaseAugmentation):
    """
    Augmentation that re-renders an item's mel spectrogram with a key shift
    (frequency-domain stretching) and/or a speed change (time-domain stretching),
    and updates the dependent fields of the item accordingly.
    """

    def __init__(self, data_dirs: list, augmentation_args: dict, pe: BasePE = None):
        """
        :param data_dirs: forwarded to BaseAugmentation
        :param augmentation_args: forwarded to BaseAugmentation
        :param pe: pitch extractor; its ``get_pitch`` is called by process_item
            whenever the speed branch runs, so it must not be None in that case
        """
        super().__init__(data_dirs, augmentation_args)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.lr = LengthRegulator().to(self.device)
        self.pe = pe

    @require_same_keys
    def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None) -> dict:
        """
        Build an augmented copy of ``item``; the input item itself is deep-copied
        and left untouched.

        :param item: data item; this method reads 'wav_fn' and, depending on the
            branches taken, 'seconds', 'ph_dur', 'f0' and any key listed in
            VARIANCE_CHECKLIST
        :param key_shift: shift passed to the mel extraction; f0 is scaled by
            2 ** (key_shift / 12)
        :param speed: speed factor passed to the mel and pitch extraction
        :param replace_spk_id: if given, stored as 'spk_id' instead of storing
            'key_shift' on the augmented item
        :return: the augmented item (the decorator checks that it has the same
            keys as the input item)
        """
        aug_item = deepcopy(item)
        waveform, _ = librosa.load(aug_item['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
        mel = get_mel_torch(
            waveform, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'],
            hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'],
            fmin=hparams['fmin'], fmax=hparams['fmax'],
            keyshift=key_shift, speed=speed, device=self.device
        )

        aug_item['mel'] = mel

        if speed != 1. or hparams['use_speed_embed']:
            aug_item['length'] = mel.shape[0]
            # real speed: the scaled hop size is rounded to an integer, so the
            # effective speed can differ slightly from the requested one
            aug_item['speed'] = int(np.round(hparams['hop_size'] * speed)) / hparams['hop_size']  # real speed
            aug_item['seconds'] /= aug_item['speed']
            aug_item['ph_dur'] /= aug_item['speed']
            aug_item['mel2ph'] = get_mel2ph_torch(
                self.lr, torch.from_numpy(aug_item['ph_dur']), aug_item['length'], self.timestep, device=self.device
            ).cpu().numpy()

            # Pitch is re-extracted from the waveform at the new speed.
            f0, _ = self.pe.get_pitch(
                waveform, samplerate=hparams['audio_sample_rate'], length=aug_item['length'],
                hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
                speed=speed, interp_uv=True
            )
            aug_item['f0'] = f0.astype(np.float32)

            # NOTE: variance curves are directly resampled according to speed,
            # despite how frequency-domain features change after the augmentation.
            # For acoustic models, this can bring more (but not much) difficulty
            # to learn how variance curves affect the mel spectrograms, since
            # they must realize how the augmentation causes the mismatch.
            #
            # This is a simple way to combine augmentation and variances. However,
            # handling variance curves like this will decrease the accuracy of
            # variance controls. In most situations, not being ~100% accurate
            # will not ruin the user experience. For example, it does not matter
            # if the energy does not exactly equal the RMS; it is just fine
            # as long as higher energy can bring higher loudness and strength.
            # The neural network itself cannot be 100% accurate anyway.
            #
            # There are yet other choices to simulate variance curves:
            #   1. Re-extract the features from resampled waveforms;
            #   2. Re-extract the features from re-constructed waveforms using
            #      the transformed mel spectrograms through the vocoder.
            # But there are actually no perfect ways to make them all accurate
            # and stable.
            for v_name in VARIANCE_CHECKLIST:
                if v_name in item:
                    aug_item[v_name] = resample_align_curve(
                        aug_item[v_name],
                        original_timestep=self.timestep,
                        target_timestep=self.timestep * aug_item['speed'],
                        align_length=aug_item['length']
                    )

        if key_shift != 0. or hparams['use_key_shift_embed']:
            # Either record the shift itself, or re-label the item as another speaker.
            if replace_spk_id is None:
                aug_item['key_shift'] = key_shift
            else:
                aug_item['spk_id'] = replace_spk_id
            # Shift the pitch curve by key_shift semitones (12 per octave).
            aug_item['f0'] *= 2 ** (key_shift / 12)

        return aug_item
|
basics/base_augmentation.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.hparams import hparams
|
2 |
+
|
3 |
+
|
4 |
+
class BaseAugmentation:
    """
    Common base for data augmentation implementations.

    Every method of a subclass is expected to be thread-safe.
    Subclasses implement:
        1. *process_item*:
            apply the augmentation to a single piece of data.
    """

    def __init__(self, data_dirs: list, augmentation_args: dict):
        """
        Store the raw data directories and the augmentation arguments, and derive
        the duration of one frame from the 'hop_size' and 'audio_sample_rate'
        hyperparameters.
        """
        self.raw_data_dirs = data_dirs
        self.augmentation_args = augmentation_args
        hop_size = hparams['hop_size']
        sample_rate = hparams['audio_sample_rate']
        self.timestep = hop_size / sample_rate

    def process_item(self, item: dict, **kwargs) -> dict:
        """Return the augmented version of ``item``; must be overridden."""
        raise NotImplementedError()
|
18 |
+
|
19 |
+
|
20 |
+
def require_same_keys(func):
    """
    Decorator for ``process_item``-style methods: verify that the returned dict
    has exactly the same keys as the input item.

    The wrapped callable is expected to take the instance as its first
    positional argument and the item either as the second positional argument
    or as the keyword argument ``item``.

    :param func: the method to wrap
    :return: the wrapping function, carrying the name and docstring of ``func``
    :raises AssertionError: (from the wrapper) if the key sets differ
    """
    import functools  # local import so this block does not rely on file-level imports

    @functools.wraps(func)
    def run(*args, **kwargs):
        # args[0] is the instance; the item may be positional or keyword.
        item: dict = args[1] if len(args) > 1 else kwargs['item']
        res: dict = func(*args, **kwargs)
        # Raised explicitly (instead of using `assert`) so that the check is
        # still performed when Python runs with -O.
        if set(item.keys()) != set(res.keys()):
            raise AssertionError(
                'Item keys mismatch after augmentation.\n'
                f'Before: {sorted(item.keys())}\n'
                f'After: {sorted(res.keys())}'
            )
        return res

    return run
|
basics/base_binarizer.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import pathlib
|
3 |
+
import pickle
|
4 |
+
import random
|
5 |
+
import shutil
|
6 |
+
import warnings
|
7 |
+
from copy import deepcopy
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from utils.hparams import hparams
|
14 |
+
from utils.indexed_datasets import IndexedDatasetBuilder
|
15 |
+
from utils.multiprocess_utils import chunked_multiprocess_run
|
16 |
+
from utils.phoneme_utils import build_phoneme_list, locate_dictionary
|
17 |
+
from utils.plot import distribution_to_figure
|
18 |
+
from utils.text_encoder import TokenTextEncoder
|
19 |
+
|
20 |
+
|
21 |
+
class BinarizationError(Exception):
    """Raised by the binarizer when the raw data cannot be binarized as-is."""
|
23 |
+
|
24 |
+
|
25 |
+
class BaseBinarizer:
|
26 |
+
"""
|
27 |
+
Base class for data processing.
|
28 |
+
1. *process* and *process_data_split*:
|
29 |
+
process entire data, generate the train-test split (support parallel processing);
|
30 |
+
2. *process_item*:
|
31 |
+
process a single piece of data;
|
32 |
+
3. *get_pitch*:
|
33 |
+
infer the pitch using some algorithm;
|
34 |
+
4. *get_align*:
|
35 |
+
get the alignment using 'mel2ph' format (see https://arxiv.org/abs/1905.09263).
|
36 |
+
5. phoneme encoder, voice encoder, etc.
|
37 |
+
|
38 |
+
Subclasses should define:
|
39 |
+
1. *load_meta_data*:
|
40 |
+
how to read multiple datasets from files;
|
41 |
+
2. *train_item_names*, *valid_item_names*, *test_item_names*:
|
42 |
+
how to split the dataset;
|
43 |
+
3. load_ph_set:
|
44 |
+
the phoneme set.
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(self, data_dir=None, data_attrs=None):
    """
    Read the binarization settings from ``hparams`` and prepare empty state.

    :param data_dir: one raw data directory or a list of them; defaults to
        hparams['raw_data_dir']
    :param data_attrs: attribute names handed to the dataset builder as the
        allowed attributes; defaults to an empty list
    """
    # Fall back to the configured location and always work with a list.
    dirs = hparams['raw_data_dir'] if data_dir is None else data_dir
    if not isinstance(dirs, list):
        dirs = [dirs]

    self.raw_data_dirs = list(map(pathlib.Path, dirs))
    self.binary_data_dir = pathlib.Path(hparams['binary_data_dir'])
    self.data_attrs = data_attrs if data_attrs is not None else []

    self.binarization_args = hparams['binarization_args']
    self.augmentation_args = hparams.get('augmentation_args', {})
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.device = torch.device(device_name)

    # Speaker bookkeeping; build_spk_map() validates these and fills spk_map.
    self.spk_map = None
    self.spk_ids = hparams['spk_ids']
    self.speakers = hparams['speakers']
    self.build_spk_map()

    # Filled in later by process().
    self.items = {}
    self.item_names: list = None
    self._train_item_names: list = None
    self._valid_item_names: list = None

    self.phone_encoder = TokenTextEncoder(vocab_list=build_phoneme_list())
    # Duration of one frame: hop size divided by the sample rate.
    self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
|
73 |
+
|
74 |
+
def build_spk_map(self):
    """
    Validate the speaker configuration and build ``self.spk_map``
    (speaker name -> speaker id).

    When no speaker ids were given, they default to 0..N-1, one per raw data
    directory. A name may appear several times only if it always maps to the
    same id.

    :raises AssertionError: if the speaker settings are inconsistent
    :raises ValueError: if one speaker name is assigned two different ids
    """
    num_datasets = len(self.raw_data_dirs)
    assert isinstance(self.speakers, list), 'Speakers must be a list'
    assert len(self.speakers) == num_datasets, \
        'Number of raw data dirs must equal number of speaker names!'
    if len(self.spk_ids) == 0:
        self.spk_ids = list(range(num_datasets))
    else:
        assert len(self.spk_ids) == num_datasets, \
            'Length of explicitly given spk_ids must equal the number of raw datasets.'
    assert max(self.spk_ids) < hparams['num_spk'], \
        f'Index in spk_id sequence {self.spk_ids} is out of range. All values should be smaller than num_spk.'

    mapping = self.spk_map = {}
    for spk_name, spk_id in zip(self.speakers, self.spk_ids):
        known_id = mapping.get(spk_name, spk_id)
        if known_id != spk_id:
            raise ValueError(f'Invalid speaker ID assignment. Name \'{spk_name}\' is assigned '
                             f'with different speaker IDs: {known_id} and {spk_id}.')
        mapping[spk_name] = spk_id

    print("| spk_map: ", self.spk_map)
|
94 |
+
|
95 |
+
def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id):
    """
    Load the metadata of one raw dataset; must be overridden by subclasses.

    ``process()`` calls this once per raw data directory, passing the dataset
    index as ``ds_id`` and the matching speaker id as ``spk_id``.
    """
    raise NotImplementedError()
|
97 |
+
|
98 |
+
def split_train_valid_set(self, item_names):
    """
    Split the dataset into training set and validation set.

    Each rule in hparams['test_prefixes'] is tried against the item names in
    four passes, from the most to the least specific; a rule is consumed by the
    first pass in which it matches at least one name:
        1. the rule equals a full item name;
        2. the rule equals the part of a name after the last ':';
        3. a full item name starts with the rule;
        4. the part of a name after the last ':' starts with the rule.
    Rules that never match are reported with a UserWarning.

    :param item_names: names of all items (strings)
    :return: train_item_names, valid_item_names
    :raises AssertionError: if either resulting set is empty
    """
    # Plain dicts serve as insertion-ordered sets here.
    prefixes = {str(pr): 1 for pr in hparams['test_prefixes']}
    valid_item_names = {}

    # Add prefixes that specified speaker index and matches exactly item name to test set
    all_names = set(item_names)
    for prefix in list(prefixes):
        if prefix in all_names:
            valid_item_names[prefix] = 1
            prefixes.pop(prefix)

    # The remaining three passes only differ in how a name is compared to a rule.
    matchers = (
        # exact match on the name without speaker id
        lambda name, prefix: name.split(':')[-1] == prefix,
        # prefix match on the full name
        lambda name, prefix: name.startswith(prefix),
        # prefix match on the name without speaker id
        lambda name, prefix: name.split(':')[-1].startswith(prefix),
    )
    for is_match in matchers:
        # Iterate over a snapshot because matched rules are removed on the fly.
        for prefix in list(prefixes):
            hits = [name for name in item_names if is_match(name, prefix)]
            if hits:
                valid_item_names.update(dict.fromkeys(hits, 1))
                prefixes.pop(prefix)

    if len(prefixes) != 0:
        warnings.warn(
            f'The following rules in test_prefixes have no matching names in the dataset: {", ".join(prefixes.keys())}',
            category=UserWarning
        )
        # NOTE(review): this changes the process-wide warning filter; kept as in
        # the original implementation.
        warnings.filterwarnings('default')

    # Membership is tested against the dict (O(1)) instead of rebuilding a set
    # for every item name.
    train_item_names = [x for x in item_names if x not in valid_item_names]
    valid_item_names = list(valid_item_names.keys())
    assert len(valid_item_names) > 0, 'Validation set is empty!'
    assert len(train_item_names) > 0, 'Training set is empty!'

    return train_item_names, valid_item_names
|
150 |
+
|
151 |
+
@property
def train_item_names(self):
    """Names of the items in the training split (None until ``process()`` has run the split)."""
    names = self._train_item_names
    return names
|
154 |
+
|
155 |
+
@property
def valid_item_names(self):
    """Names of the items in the validation split (None until ``process()`` has run the split)."""
    names = self._valid_item_names
    return names
|
158 |
+
|
159 |
+
def meta_data_iterator(self, prefix):
    """
    Yield ``(item_name, meta_data)`` pairs of one split.

    :param prefix: 'train' selects the training split; any other value selects
        the validation split
    """
    names = self.train_item_names if prefix == 'train' else self.valid_item_names
    for name in names:
        yield name, self.items[name]
|
167 |
+
|
168 |
+
def process(self):
    """
    Run the whole binarization pipeline.

    Steps: load the metadata of every raw dataset, split the item names into
    training and validation sets, write 'spk_map.json' and 'dictionary.txt'
    into the binary data directory, check phoneme coverage, then binarize the
    validation set followed by the training set. Augmentation is applied to the
    training set if any entry of the augmentation arguments is enabled.

    A KeyboardInterrupt during binarization terminates the process with exit
    status -1.
    """
    # Local import: `sys.exit` replaces the `exit` helper, which is injected by
    # the `site` module and is not guaranteed to exist (e.g. `python -S`).
    import sys

    # load each dataset
    for ds_id, (spk_id, data_dir) in enumerate(zip(self.spk_ids, self.raw_data_dirs)):
        self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id)
    self.item_names = sorted(self.items.keys())
    self._train_item_names, self._valid_item_names = self.split_train_valid_set(self.item_names)

    # NOTE(review): the shuffle happens after the split, so it only reorders
    # self.item_names and not the two split lists — confirm this is intended.
    if self.binarization_args['shuffle']:
        random.shuffle(self.item_names)

    self.binary_data_dir.mkdir(parents=True, exist_ok=True)

    # Copy spk_map and dictionary to binary data dir
    spk_map_fn = self.binary_data_dir / 'spk_map.json'
    with open(spk_map_fn, 'w', encoding='utf-8') as f:
        json.dump(self.spk_map, f)
    shutil.copy(locate_dictionary(), self.binary_data_dir / 'dictionary.txt')
    self.check_coverage()

    # Process valid set and train set
    try:
        self.process_dataset('valid')
        self.process_dataset(
            'train',
            num_workers=int(self.binarization_args['num_workers']),
            apply_augmentation=any(args['enabled'] for args in self.augmentation_args.values())
        )
    except KeyboardInterrupt:
        sys.exit(-1)
|
197 |
+
|
198 |
+
def check_coverage(self):
    """
    Check that the phonemes used by the loaded items and the phonemes of the
    dictionary are exactly the same set.

    Prints the number of occurrences of every dictionary phoneme, saves the
    distribution figure as 'phoneme_distribution.jpg' in the binary data
    directory, and then compares the two phoneme sets.

    :raises BinarizationError: if any item has an empty phoneme sequence, or if
        the transcriptions contain phonemes missing from the dictionary ('+')
        or the dictionary contains phonemes never used ('-')
    """
    # Group by phonemes in the dictionary.
    ph_required = set(build_phoneme_list())
    phoneme_map = dict.fromkeys(ph_required, 0)
    ph_occurred = []

    # Load and count those phones that appear in the actual data
    for item_name, meta_data in self.items.items():
        ph_seq = meta_data['ph_seq']
        # Check the item's own sequence: testing the accumulated list would only
        # ever detect empty sequences at the very beginning of the dataset.
        if len(ph_seq) == 0:
            raise BinarizationError(f'Empty tokens in {item_name}.')
        ph_occurred += ph_seq
    for ph in ph_occurred:
        # Phonemes unknown to the dictionary are reported below, not counted.
        if ph in ph_required:
            phoneme_map[ph] += 1
    ph_occurred = set(ph_occurred)

    phonemes = sorted(phoneme_map.keys())

    print('===== Phoneme Distribution Summary =====')
    for i, key in enumerate(phonemes):
        # Ten entries per printed line; the last entry ends the listing.
        if i == len(ph_required) - 1:
            end = '\n'
        elif i % 10 == 9:
            end = ',\n'
        else:
            end = ', '
        print(f'\'{key}\': {phoneme_map[key]}', end=end)

    # Draw graph.
    values = [phoneme_map[k] for k in phonemes]
    plt = distribution_to_figure(
        title='Phoneme Distribution Summary',
        x_label='Phoneme', y_label='Number of occurrences',
        items=phonemes, values=values
    )
    filename = self.binary_data_dir / 'phoneme_distribution.jpg'
    plt.savefig(fname=filename,
                bbox_inches='tight',
                pad_inches=0.25)
    print(f'| save summary to \'{filename}\'')

    # Check unrecognizable or missing phonemes
    if ph_occurred != ph_required:
        unrecognizable_phones = ph_occurred.difference(ph_required)
        missing_phones = ph_required.difference(ph_occurred)
        raise BinarizationError('transcriptions and dictionary mismatch.\n'
                                f' (+) {sorted(unrecognizable_phones)}\n'
                                f' (-) {sorted(missing_phones)}')
|
248 |
+
|
249 |
+
def process_dataset(self, prefix, num_workers=0, apply_augmentation=False):
    """
    Binarize every item of one split and write the results to ``self.binary_data_dir``.

    Items come from ``self.meta_data_iterator(prefix)`` and are converted by
    ``self.process_item``. Each non-``None`` result is appended to an
    ``IndexedDatasetBuilder``; per-item metadata (names, speaker ids/names, lengths
    and the first dimension of every ``np.ndarray`` attribute) is pickled to
    ``{prefix}.meta``. For the ``"train"`` prefix, ``names`` and ``spk_names`` are
    dropped from the metadata before it is saved.

    :param prefix: name of the split, also used as the output file prefix
    :param num_workers: number of worker processes; 0 processes items in this process
    :param apply_augmentation: if True, run the tasks returned by
        ``self.arrange_data_augmentation`` on each processed item
    """
    builder = IndexedDatasetBuilder(self.binary_data_dir, prefix=prefix, allowed_attr=self.data_attrs)
    args = [
        [item_name, meta_data, self.binarization_args]
        for item_name, meta_data in self.meta_data_iterator(prefix)
    ]
    total_sec = {k: 0.0 for k in self.spk_map}
    total_raw_sec = {k: 0.0 for k in self.spk_map}
    extra_info = {'names': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}}
    max_no = -1

    aug_map = self.arrange_data_augmentation(self.meta_data_iterator(prefix)) if apply_augmentation else {}

    def record(entry):
        """Append ``entry`` to the builder and register its metadata and duration."""
        nonlocal max_no
        entry_no = builder.add_item(entry)
        max_no = max(max_no, entry_no)
        for k, v in entry.items():
            if isinstance(v, np.ndarray):
                extra_info.setdefault(k, {})[entry_no] = v.shape[0]
        extra_info['names'][entry_no] = entry['name'].split(':', 1)[-1]
        extra_info['spk_ids'][entry_no] = entry['spk_id']
        extra_info['spk_names'][entry_no] = entry['spk_name']
        extra_info['lengths'][entry_no] = entry['length']
        total_sec[entry['spk_name']] += entry['seconds']

    def postprocess(_item):
        """Record a processed item followed by all of its augmented variants."""
        if _item is None:
            return
        record(_item)
        # Only original items count towards the pre-augmentation duration.
        total_raw_sec[_item['spk_name']] += _item['seconds']
        for task in aug_map.get(_item['name'], []):
            record(task['func'](_item, **task['kwargs']))

    try:
        if num_workers > 0:
            # code for parallel processing
            for item in tqdm(
                    chunked_multiprocess_run(self.process_item, args, num_workers=num_workers),
                    # args holds exactly one entry per metadata item, so the
                    # iterator does not have to be consumed again just to count.
                    total=len(args)
            ):
                postprocess(item)
        else:
            # code for single cpu processing
            for a in tqdm(args):
                item = self.process_item(*a)
                postprocess(item)
        for k in extra_info:
            assert set(extra_info[k]) == set(range(max_no + 1)), 'Item numbering is not consecutive.'
            # Turn {item_no: value} into a list ordered by item number.
            extra_info[k] = [extra_info[k][no] for no in sorted(extra_info[k])]
    except KeyboardInterrupt:
        # Flush what has been written so far before propagating the interrupt.
        builder.finalize()
        raise

    builder.finalize()
    if prefix == "train":
        extra_info.pop("names")
        extra_info.pop("spk_names")
    with open(self.binary_data_dir / f"{prefix}.meta", "wb") as f:
        # noinspection PyTypeChecker
        pickle.dump(extra_info, f)

    raw_total = sum(total_raw_sec.values())
    if apply_augmentation:
        aug_total = sum(total_sec.values())
        # An empty split has no raw duration; report NaN instead of dividing by zero.
        ratio = aug_total / raw_total if raw_total > 0 else float('nan')
        print(f"| {prefix} total duration (before augmentation): {raw_total:.2f}s")
        print(
            f"| {prefix} respective duration (before augmentation): "
            + ', '.join(f'{k}={v:.2f}s' for k, v in total_raw_sec.items())
        )
        print(
            f"| {prefix} total duration (after augmentation): "
            f"{aug_total:.2f}s ({ratio:.2f}x)"
        )
        print(
            f"| {prefix} respective duration (after augmentation): "
            + ', '.join(f'{k}={v:.2f}s' for k, v in total_sec.items())
        )
    else:
        print(f"| {prefix} total duration: {raw_total:.2f}s")
        print(f"| {prefix} respective duration: " + ', '.join(f'{k}={v:.2f}s' for k, v in total_raw_sec.items()))
|
339 |
+
|
340 |
+
def arrange_data_augmentation(self, data_iterator):
    """
    Plan the data augmentation for a split; every augmentation type belongs here.

    Abstract: subclasses must override this method.

    :param data_iterator: iterator over the metadata of the split
    :return: presumably a mapping from item name to a list of task dicts with
        ``'func'`` and ``'kwargs'`` entries, which is how ``process_dataset``
        consumes the result -- TODO confirm against the subclasses
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError()
|
345 |
+
|
346 |
+
def process_item(self, item_name, meta_data, binarization_args):
    """
    Convert one raw item into its binarized form.

    Abstract: subclasses must override this method. ``process_dataset`` skips a
    ``None`` result and otherwise reads the ``name``, ``spk_id``, ``spk_name``,
    ``length`` and ``seconds`` entries of the returned mapping.

    :param item_name: name of the item
    :param meta_data: metadata of the item as yielded by ``meta_data_iterator``
    :param binarization_args: the binarizer's ``binarization_args``
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError()
|
basics/base_dataset.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
|
7 |
+
from utils.hparams import hparams
|
8 |
+
from utils.indexed_datasets import IndexedDataset
|
9 |
+
|
10 |
+
|
11 |
+
class BaseDataset(Dataset):
    """
    Base class for datasets stored as ``{prefix}.meta`` plus an ``IndexedDataset``.

    The metadata entry selected by ``size_key`` is kept in ``self.sizes``; both
    ``num_frames`` and ``size`` look a value up in it, and ``__len__`` reports
    its length.

    Subclasses typically extend:
        1. *collater*:
            build a batch from a list of samples (this base version only
            reports the batch size and the sample indices);
        2. *__getitem__*:
            the index function.
    """

    def __init__(self, prefix, size_key='lengths', preload=False):
        """
        :param prefix: split name; selects ``{prefix}.meta`` and the indexed dataset
        :param size_key: key of the metadata entry used as ``self.sizes``
        :param preload: if True, read every item into a list up front
        """
        super().__init__()
        self.prefix = prefix
        self.data_dir = hparams['binary_data_dir']
        meta_path = os.path.join(self.data_dir, f'{self.prefix}.meta')
        with open(meta_path, 'rb') as meta_file:
            self.metadata = pickle.load(meta_file)
        self.sizes = self.metadata[size_key]
        self._indexed_ds = IndexedDataset(self.data_dir, self.prefix)
        if not preload:
            self.indexed_ds = self._indexed_ds
        else:
            # Materialize all items, then drop the reference to the on-disk dataset.
            self.indexed_ds = [self._indexed_ds[i] for i in range(len(self._indexed_ds))]
            del self._indexed_ds

    def __getitem__(self, index):
        """Return the stored item at ``index`` with its index added under ``'_idx'``."""
        sample = {'_idx': index}
        sample.update(self.indexed_ds[index])
        return sample

    def __len__(self):
        """Return the number of entries in ``self.sizes``."""
        return len(self.sizes)

    def num_frames(self, index):
        """Return ``self.sizes[index]``."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def collater(self, samples):
        """
        Build the common part of a batch.

        :param samples: list of samples as returned by ``__getitem__``
        :return: dict with the batch ``'size'`` and a LongTensor of sample ``'indices'``
        """
        batch_size = len(samples)
        indices = torch.LongTensor([sample['_idx'] for sample in samples])
        return {'size': batch_size, 'indices': indices}
|
basics/base_exporter.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from utils.hparams import hparams
|
9 |
+
|
10 |
+
|
11 |
+
class BaseExporter:
    """
    Base class for model exporters.

    Stores the target device and a cache directory (created on construction).
    ``build_model``, ``export_model``, ``export_attachments`` and ``export`` are
    abstract and must be implemented by subclasses.
    """

    def __init__(
            self,
            device: Union[str, torch.device] = None,
            cache_dir: Path = None,
            **kwargs
    ):
        """
        :param device: target device; defaults to CUDA when available, else CPU
        :param cache_dir: cache directory; defaults to ``deployment/cache`` two
            levels above this file
        :param kwargs: accepted and ignored by this base class
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        if cache_dir is None:
            cache_dir = Path(__file__).parent.parent / 'deployment' / 'cache'
        else:
            cache_dir = cache_dir.resolve()
        self.cache_dir: Path = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    # noinspection PyMethodMayBeStatic
    def build_spk_map(self) -> dict:
        """
        Load the speaker map from ``spk_map.json`` in the work directory.

        :return: the loaded mapping, or an empty dict when ``use_spk_id`` is off
        :raises AssertionError: if the map is not a non-empty dict or contains
            duplicate speaker ids
        """
        if not hparams['use_spk_id']:
            return {}
        spk_map_path = Path(hparams['work_dir']) / 'spk_map.json'
        with open(spk_map_path, 'r', encoding='utf8') as f:
            spk_map = json.load(f)
        assert isinstance(spk_map, dict) and len(spk_map) > 0, 'Invalid or empty speaker map!'
        assert len(spk_map) == len(set(spk_map.values())), 'Duplicate speaker id in speaker map!'
        return spk_map

    def build_model(self) -> nn.Module:
        """
        Creates an instance of nn.Module and load its state dict on the target device.
        """
        raise NotImplementedError()

    def export_model(self, path: Path):
        """
        Exports the model to ONNX format.
        :param path: the target model path
        """
        raise NotImplementedError()

    def export_attachments(self, path: Path):
        """
        Exports related files and configs (e.g. the dictionary) to the target directory.
        :param path: the target directory
        """
        raise NotImplementedError()

    def export(self, path: Path):
        """
        Exports all the artifacts to the target directory.
        :param path: the target directory
        """
        raise NotImplementedError()
|
basics/base_module.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
|
4 |
+
class CategorizedModule(nn.Module):
    """Module that declares a category and can verify a checkpoint's category against it."""

    @property
    def category(self):
        """Category of this module; subclasses must override this property."""
        raise NotImplementedError()

    def check_category(self, category):
        """
        Verify that ``category`` (as read from a checkpoint) matches this module.

        :param category: the category stored in the checkpoint, or None if absent
        :raises RuntimeError: if ``category`` is None or differs from ``self.category``
        """
        if category is None:
            raise RuntimeError(
                'Category is not specified in this checkpoint.\n'
                'If this is a checkpoint in the old format, please consider '
                'migrating it to the new format via the following command:\n'
                'python scripts/migrate.py ckpt <INPUT_CKPT> <OUTPUT_CKPT>'
            )
        required = self.category
        if category != required:
            raise RuntimeError(
                'Category mismatches!\n'
                f'This checkpoint is of the category \'{category}\', '
                f'but a checkpoint of category \'{required}\' is required.'
            )
|
basics/base_pe.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class BasePE:
    """Interface for pitch extractors; subclasses implement ``get_pitch``."""

    def get_pitch(
            self, waveform, samplerate, length,
            *, hop_size, f0_min=65, f0_max=1100,
            speed=1, interp_uv=False
    ):
        """
        Extract pitch from a waveform. Abstract in this base class.

        Everything after ``length`` is keyword-only. The exact meaning of the
        arguments is defined by the implementations -- NOTE(review): not
        visible from this file.

        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError()
|
basics/base_svs_infer.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf8
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import Tensor
|
5 |
+
from typing import Tuple, Dict
|
6 |
+
|
7 |
+
from utils.hparams import hparams
|
8 |
+
from utils.infer_utils import resample_align_curve
|
9 |
+
|
10 |
+
|
11 |
+
class BaseSVSInfer:
    """
    Base class for SVS inference models.
    Subclasses should define:
    1. *build_model*:
        how to build the model;
    2. *preprocess_input*:
        how to preprocess user input;
    3. *forward_model*:
        how to run the model on a preprocessed sample;
    4. *run_inference*:
        infer from raw inputs to the final outputs.
    """

    def __init__(self, device=None):
        """
        :param device: device used for created tensors; defaults to 'cuda' when
            available, otherwise 'cpu'
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        # Hop size divided by the audio sample rate.
        self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
        # Speaker name -> speaker id; starts empty and is not filled by this class.
        self.spk_map = {}
        self.model: torch.nn.Module = None

    def build_model(self, ckpt_steps=None) -> torch.nn.Module:
        """Build and return the model. Abstract in this base class."""
        raise NotImplementedError()

    def load_speaker_mix(self, param_src: dict, summary_dst: dict,
                         mix_mode: str = 'frame', mix_length: int = None) -> Tuple[Tensor, Tensor]:
        """
        Resolve the speaker mix of one request into id and proportion tensors.

        The mix is read from ``param_src['spk_mix']`` (frame mode) or
        ``param_src['ph_spk_mix']`` (token mode). It maps each speaker name to a
        number, or to a string of space-separated numbers for a proportion that
        varies over time. When the key is absent, the first speaker of
        ``self.spk_map`` is used alone. Proportions are normalized so that they
        sum to 1 across speakers.

        :param param_src: param dict; in frame mode with string proportions,
            'spk_mix_timestep' is read from it as well
        :param summary_dst: summary dict; receives a description of the chosen mix
        :param mix_mode: 'token' or 'frame'
        :param mix_length: total tokens or frames to mix
        :return: spk_mix_id [B=1, 1, N] and spk_mix_value, which is [B=1, T, N]
            when any proportion is given as a string and [B=1, 1, N] otherwise
        :raises AssertionError: on an unknown mix mode or speaker, an empty speaker
            map with no mix given, negative proportions, proportions summing to
            zero, or a token-level curve whose length differs from ``mix_length``
        """
        assert mix_mode == 'token' or mix_mode == 'frame'
        param_key = 'spk_mix' if mix_mode == 'frame' else 'ph_spk_mix'
        summary_solo_key = 'spk' if mix_mode == 'frame' else 'ph_spk'
        spk_mix_map = param_src.get(param_key)  # { spk_name: value } or { spk_name: "value value value ..." }
        dynamic = False
        if spk_mix_map is None:
            # Get the first speaker. Fail with a clear message instead of an opaque
            # TypeError from len(None) when there is no speaker to fall back to.
            assert len(self.spk_map) > 0, \
                'Speaker mix checks failed.\n' \
                'No speaker mix is specified and the speaker map is empty.'
            spk_mix_map = {next(iter(self.spk_map)): 1.0}
        else:
            for name in spk_mix_map:
                assert name in self.spk_map, f'Speaker \'{name}\' not found.'
        if len(spk_mix_map) == 1:
            summary_dst[summary_solo_key] = next(iter(spk_mix_map))
        elif any(isinstance(val, str) for val in spk_mix_map.values()):
            print_mix = '|'.join(spk_mix_map.keys())
            summary_dst[param_key] = f'dynamic({print_mix})'
            dynamic = True
        else:
            print_mix = '|'.join([f'{n}:{"%.3f" % spk_mix_map[n]}' for n in spk_mix_map])
            summary_dst[param_key] = f'static({print_mix})'
        spk_mix_id_list = []
        spk_mix_value_list = []
        if dynamic:
            for name, values in spk_mix_map.items():
                spk_mix_id_list.append(self.spk_map[name])
                if isinstance(values, str):
                    # this speaker has a variable proportion
                    if mix_mode == 'token':
                        cur_spk_mix_value = values.split()
                        assert len(cur_spk_mix_value) == mix_length, \
                            'Speaker mix checks failed. In dynamic token-level mix, ' \
                            'number of proportion values must equal number of tokens.'
                        cur_spk_mix_value = torch.from_numpy(
                            np.array(cur_spk_mix_value, 'float32')
                        ).to(self.device)[None]  # => [B=1, T]
                    else:
                        cur_spk_mix_value = torch.from_numpy(resample_align_curve(
                            np.array(values.split(), 'float32'),
                            original_timestep=float(param_src['spk_mix_timestep']),
                            target_timestep=self.timestep,
                            align_length=mix_length
                        )).to(self.device)[None]  # => [B=1, T]
                    assert torch.all(cur_spk_mix_value >= 0.), \
                        f'Speaker mix checks failed.\n' \
                        f'Proportions of speaker \'{name}\' on some {mix_mode}s are negative.'
                else:
                    # this speaker has a constant proportion
                    assert values >= 0., f'Speaker mix checks failed.\n' \
                                         f'Proportion of speaker \'{name}\' is negative.'
                    cur_spk_mix_value = torch.full(
                        (1, mix_length), fill_value=values,
                        dtype=torch.float32, device=self.device
                    )
                spk_mix_value_list.append(cur_spk_mix_value)
            spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value = torch.stack(spk_mix_value_list, dim=2)  # [B=1, T] => [B=1, T, N]
            spk_mix_value_sum = torch.sum(spk_mix_value, dim=2, keepdim=True)  # => [B=1, T, 1]
            assert torch.all(spk_mix_value_sum > 0.), \
                'Speaker mix checks failed.\n' \
                'Proportions of speaker mix on some frames sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize
        else:
            for name, value in spk_mix_map.items():
                spk_mix_id_list.append(self.spk_map[name])
                assert value >= 0., f'Speaker mix checks failed.\n' \
                                    f'Proportion of speaker \'{name}\' is negative.'
                spk_mix_value_list.append(value)
            spk_mix_id = torch.LongTensor(spk_mix_id_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value = torch.FloatTensor(spk_mix_value_list).to(self.device)[None, None]  # => [B=1, 1, N]
            spk_mix_value_sum = spk_mix_value.sum()
            assert spk_mix_value_sum > 0., 'Speaker mix checks failed.\n' \
                                           'Proportions of speaker mix sum to zero.'
            spk_mix_value /= spk_mix_value_sum  # normalize
        return spk_mix_id, spk_mix_value

    def preprocess_input(self, param: dict, idx=0) -> Dict[str, torch.Tensor]:
        """Turn one param dict into a dict of tensors. Abstract in this base class."""
        raise NotImplementedError()

    def forward_model(self, sample: Dict[str, torch.Tensor]):
        """Run the model on one preprocessed sample. Abstract in this base class."""
        raise NotImplementedError()

    def run_inference(self, params, **kwargs):
        """Run inference on ``params``. Abstract in this base class."""
        raise NotImplementedError()
|
basics/base_task.py
ADDED
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
import shutil
|
5 |
+
import sys
|
6 |
+
from typing import Dict
|
7 |
+
|
8 |
+
import matplotlib
|
9 |
+
|
10 |
+
import utils
|
11 |
+
from utils.text_encoder import TokenTextEncoder
|
12 |
+
|
13 |
+
matplotlib.use('Agg')
|
14 |
+
|
15 |
+
import torch.utils.data
|
16 |
+
from torchmetrics import Metric, MeanMetric
|
17 |
+
import lightning.pytorch as pl
|
18 |
+
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_only
|
19 |
+
|
20 |
+
from basics.base_module import CategorizedModule
|
21 |
+
from utils.hparams import hparams
|
22 |
+
from utils.training_utils import (
|
23 |
+
DsModelCheckpoint, DsTQDMProgressBar,
|
24 |
+
DsBatchSampler, DsTensorBoardLogger,
|
25 |
+
get_latest_checkpoint_path, get_strategy
|
26 |
+
)
|
27 |
+
from utils.phoneme_utils import locate_dictionary, build_phoneme_list
|
28 |
+
|
29 |
+
# Tensor sharing strategy for torch.multiprocessing: taken from the
# TORCH_SHARE_STRATEGY environment variable, 'file_system' when it is unset.
torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))

# Root logger configuration: INFO and above go to stdout, each record prefixed
# with a month/day timestamp in 12-hour format.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
34 |
+
|
35 |
+
|
36 |
+
class BaseTask(pl.LightningModule):
|
37 |
+
"""
|
38 |
+
Base class for training tasks.
|
39 |
+
1. *load_ckpt*:
|
40 |
+
load checkpoint;
|
41 |
+
2. *training_step*:
|
42 |
+
record and log the loss;
|
43 |
+
3. *optimizer_step*:
|
44 |
+
run backwards step;
|
45 |
+
4. *start*:
|
46 |
+
load training configs, backup code, log to tensorboard, start training;
|
47 |
+
5. *configure_ddp* and *init_ddp_connection*:
|
48 |
+
start parallel training.
|
49 |
+
|
50 |
+
Subclasses should define:
|
51 |
+
1. *build_model*, *build_optimizer*, *build_scheduler*:
|
52 |
+
how to build the model, the optimizer and the training scheduler;
|
53 |
+
2. *_training_step*:
|
54 |
+
one training step of the model;
|
55 |
+
3. *on_validation_end* and *_on_validation_end*:
|
56 |
+
postprocess the validation output.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(self, *args, **kwargs):
    """
    Read the batch limits from ``hparams``, build the phone encoder and the model.

    A validation limit of -1 in ``hparams`` is replaced (both on the task and in
    ``hparams`` itself) by the corresponding training limit.
    """
    super().__init__(*args, **kwargs)
    self.max_batch_frames = hparams['max_batch_frames']
    self.max_batch_size = hparams['max_batch_size']

    val_batch_frames = hparams['max_val_batch_frames']
    if val_batch_frames == -1:
        val_batch_frames = self.max_batch_frames
        hparams['max_val_batch_frames'] = val_batch_frames
    self.max_val_batch_frames = val_batch_frames

    val_batch_size = hparams['max_val_batch_size']
    if val_batch_size == -1:
        val_batch_size = self.max_batch_size
        hparams['max_val_batch_size'] = val_batch_size
    self.max_val_batch_size = val_batch_size

    # Set by train_dataloader(); used to forward the epoch number to the sampler.
    self.training_sampler = None
    self.skip_immediate_validation = False
    self.skip_immediate_ckpt_save = False

    self.phone_encoder = self.build_phone_encoder()
    self.build_model()

    # Filled through register_validation_loss() / register_validation_metric().
    self.valid_losses: Dict[str, Metric] = {}
    self.valid_metrics: Dict[str, Metric] = {}
|
79 |
+
|
80 |
+
def _finish_init(self):
|
81 |
+
self.register_validation_loss('total_loss')
|
82 |
+
self.build_losses_and_metrics()
|
83 |
+
assert len(self.valid_losses) > 0, "No validation loss registered. Please check your configuration file."
|
84 |
+
|
85 |
+
###########
|
86 |
+
# Training, validation and testing
|
87 |
+
###########
|
88 |
+
def setup(self, stage):
|
89 |
+
self.train_dataset = self.dataset_cls('train')
|
90 |
+
self.valid_dataset = self.dataset_cls('valid')
|
91 |
+
self.num_replicas = (self.trainer.distributed_sampler_kwargs or {}).get('num_replicas', 1)
|
92 |
+
|
93 |
+
def get_need_freeze_state_dict_key(self, model_state_dict) -> list:
    """
    Select the state-dict keys covered by ``hparams['frozen_params']``.

    :param model_state_dict: iterable of state-dict keys
    :return: the keys that start with any of the configured prefixes, without
        duplicates
    """
    matched = {
        key
        for prefix in hparams['frozen_params']
        for key in model_state_dict
        if key.startswith(prefix)
    }
    return list(matched)
|
100 |
+
|
101 |
+
def freeze_params(self) -> None:
    """
    Disable gradients for every parameter selected by ``get_need_freeze_state_dict_key``.

    The candidate keys come from ``self.state_dict()``, which lists buffers
    alongside parameters. Keys that do not resolve to a parameter are skipped.
    """
    model_state_dict = self.state_dict().keys()
    freeze_key = self.get_need_freeze_state_dict_key(model_state_dict=model_state_dict)

    for key in freeze_key:
        try:
            param = self.get_parameter(key)
        except AttributeError:
            # get_parameter() raises AttributeError for state-dict entries that
            # are not parameters (e.g. BatchNorm running statistics).
            continue
        param.requires_grad = False
|
109 |
+
|
110 |
+
def unfreeze_all_params(self) -> None:
    """Enable gradients for every parameter of ``self.model``."""
    for param in self.model.parameters():
        param.requires_grad_(True)
|
113 |
+
|
114 |
+
def load_finetune_ckpt(
        self, state_dict
) -> None:
    """
    Load a fine-tuning state dict into this task non-strictly.

    Unless ``hparams['finetune_strict_shapes']`` is set, entries whose shape
    differs from the current parameter of the same name are reported and
    removed from ``state_dict`` (the argument is modified in place) before
    loading.

    :param state_dict: mapping from parameter name to tensor
    """
    strict_shapes = hparams['finetune_strict_shapes']
    if not strict_shapes:
        current_state = self.state_dict()
        mismatched = []
        for key, param in state_dict.items():
            if key not in current_state:
                continue
            current_param = current_state[key]
            if current_param.shape == param.shape:
                continue
            mismatched.append(key)
            print('| Unmatched keys: ', key, current_param.shape, param.shape)
        # Delete after the scan so the dict is not changed while iterating it.
        for key in mismatched:
            del state_dict[key]
    self.load_state_dict(state_dict, strict=False)
|
130 |
+
|
131 |
+
def load_pre_train_model(self):
    """
    Load the state dict of the pre-trained checkpoint used for fine-tuning.

    The checkpoint path is ``hparams['finetune_ckpt_path']``. Entries whose name
    starts with any prefix in ``hparams['finetune_ignored_params']`` are left
    out. If ``self.model`` is a ``CategorizedModule``, the checkpoint category
    is verified first.

    :return: dict from parameter name to tensor; every kept name is printed
    :raises RuntimeError: if no checkpoint path is configured, or (from
        ``check_category``) if the checkpoint category does not match
    """
    pre_train_ckpt_path = hparams['finetune_ckpt_path']
    blacklist = hparams['finetune_ignored_params']
    if pre_train_ckpt_path is None:
        raise RuntimeError(
            '\'finetune_ckpt_path\' is not set; cannot load a pre-trained checkpoint.'
        )
    ignored_prefixes = tuple(blacklist) if blacklist is not None else ()

    # Load onto the CPU so a checkpoint saved from a GPU can also be read on a
    # host without one; load_state_dict() copies values into the model afterwards.
    ckpt = torch.load(pre_train_ckpt_path, map_location='cpu')

    if isinstance(self.model, CategorizedModule):
        self.model.check_category(ckpt.get('category'))

    state_dict = {}
    for key, value in ckpt['state_dict'].items():
        # str.startswith accepts a tuple; an empty tuple never matches.
        if key.startswith(ignored_prefixes):
            continue
        state_dict[key] = value
        print(key)
    return state_dict
|
167 |
+
|
168 |
+
@staticmethod
def build_phone_encoder():
    """Return a ``TokenTextEncoder`` whose vocabulary is ``build_phoneme_list()``."""
    return TokenTextEncoder(vocab_list=build_phoneme_list())
|
172 |
+
|
173 |
+
def _build_model(self):
|
174 |
+
raise NotImplementedError()
|
175 |
+
|
176 |
+
def build_model(self):
    """
    Build ``self.model`` and apply the freezing and fine-tuning settings.

    All parameters are first made trainable; the configured ones are frozen when
    ``freezing_enabled`` is set. The fine-tuning checkpoint is loaded only when
    ``finetune_enabled`` is set and the work directory has no checkpoint yet.
    """
    self.model = self._build_model()
    # utils.load_warp(self)
    self.unfreeze_all_params()
    if hparams['freezing_enabled']:
        self.freeze_params()
    if hparams['finetune_enabled']:
        work_dir = pathlib.Path(hparams['work_dir'])
        if get_latest_checkpoint_path(work_dir) is None:
            pretrained_state = self.load_pre_train_model()
            self.load_finetune_ckpt(pretrained_state)
    self.print_arch()
|
185 |
+
|
186 |
+
@rank_zero_only
def print_arch(self):
    """Print the architecture of ``self.model`` via ``utils.print_arch``."""
    utils.print_arch(self.model)
|
189 |
+
|
190 |
+
def build_losses_and_metrics(self):
    """
    Set up the task's losses and metrics; called from ``_finish_init``.

    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError()
|
192 |
+
|
193 |
+
def register_validation_metric(self, name: str, metric: Metric):
    """
    Store ``metric`` in ``self.valid_metrics`` under ``name``.

    :param name: key of the metric
    :param metric: a ``Metric`` instance
    :raises AssertionError: if ``metric`` is not a ``Metric``
    """
    assert isinstance(metric, Metric)
    self.valid_metrics[name] = metric
|
196 |
+
|
197 |
+
def register_validation_loss(self, name: str, Aggregator: Metric = MeanMetric):
    """
    Create an aggregator for a validation loss and store it in ``self.valid_losses``.

    :param name: key of the loss
    :param Aggregator: a ``Metric`` subclass (the class itself, not an instance);
        it is instantiated without arguments
    :raises AssertionError: if ``Aggregator`` is not a subclass of ``Metric``
    """
    assert issubclass(Aggregator, Metric)
    aggregator = Aggregator()
    self.valid_losses[name] = aggregator
|
200 |
+
|
201 |
+
def run_model(self, sample, infer=False):
    """
    Run the full model on ``sample``; abstract in this base class.

    steps:
        1. run the full model
        2. calculate losses if not infer

    ``_training_step`` sums the ``values()`` of the result, so with
    ``infer=False`` implementations are expected to return a dict of losses.

    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError()
|
208 |
+
|
209 |
+
def on_train_epoch_start(self):
    """Forward the current epoch to the training sampler, if one has been created."""
    sampler = self.training_sampler
    if sampler is None:
        return
    sampler.set_epoch(self.current_epoch)
|
212 |
+
|
213 |
+
def _training_step(self, sample):
|
214 |
+
"""
|
215 |
+
:return: total loss: torch.Tensor, loss_log: dict, other_log: dict
|
216 |
+
"""
|
217 |
+
losses = self.run_model(sample)
|
218 |
+
total_loss = sum(losses.values())
|
219 |
+
return total_loss, {**losses, 'batch_size': float(sample['size'])}
|
220 |
+
|
221 |
+
def training_step(self, sample, batch_idx):
    """
    Run one training step and log its outputs.

    The losses and the learning rate always go to the progress bar; every
    ``hparams['log_interval']`` global steps they are also sent to the logger
    under the ``training/`` prefix.

    :param sample: the batch
    :param batch_idx: index of the batch; not used here
    :return: the total loss
    """
    total_loss, log_outputs = self._training_step(sample)

    # logs to progress bar
    self.log_dict(log_outputs, prog_bar=True, logger=False, on_step=True, on_epoch=False)
    current_lr = self.lr_schedulers().get_last_lr()[0]
    self.log('lr', current_lr, prog_bar=True, logger=False, on_step=True, on_epoch=False)

    # logs to tensorboard
    if self.global_step % hparams['log_interval'] == 0:
        tb_log = {f'training/{name}': value for name, value in log_outputs.items()}
        tb_log['training/lr'] = current_lr
        self.logger.log_metrics(tb_log, step=self.global_step)

    return total_loss
|
234 |
+
|
235 |
+
# def on_before_optimizer_step(self, *args, **kwargs):
|
236 |
+
# self.log_dict(grad_norm(self, norm_type=2))
|
237 |
+
|
238 |
+
def _on_validation_start(self):
|
239 |
+
pass
|
240 |
+
|
241 |
+
def on_validation_start(self):
    """
    Prepare a validation run.

    Does nothing (apart from a debug message) while ``skip_immediate_validation``
    is set. Otherwise calls the ``_on_validation_start`` hook, then moves every
    registered validation loss and metric to ``self.device`` and resets it.
    """
    if self.skip_immediate_validation:
        rank_zero_debug("Skip validation")
        return
    self._on_validation_start()
    # Losses first, then metrics, in registration order.
    for registry in (self.valid_losses, self.valid_metrics):
        for metric in registry.values():
            metric.to(self.device)
            metric.reset()
|
252 |
+
|
253 |
+
def _validation_step(self, sample, batch_idx):
|
254 |
+
"""
|
255 |
+
|
256 |
+
:param sample:
|
257 |
+
:param batch_idx:
|
258 |
+
:return: loss_log: dict, weight: int
|
259 |
+
"""
|
260 |
+
raise NotImplementedError()
|
261 |
+
|
262 |
+
def validation_step(self, sample, batch_idx):
    """
    Validate one batch and accumulate its losses.

    Skipped while ``skip_immediate_validation`` is set and for batches whose
    ``sample['size']`` is not positive. The losses returned by
    ``_validation_step`` (computed with autocast disabled) are extended with
    their sum under ``total_loss`` and fed to the registered aggregators.

    :param sample: the batch
    :param batch_idx: index of the batch
    """
    if self.skip_immediate_validation:
        rank_zero_debug("Skip validation")
        return
    if not sample['size'] > 0:
        return
    with torch.autocast(self.device.type, enabled=False):
        losses, weight = self._validation_step(sample, batch_idx)
    losses = {'total_loss': sum(losses.values()), **losses}
    for name, value in losses.items():
        self.valid_losses[name].update(value, weight=weight)
|
280 |
+
|
281 |
+
def _on_validation_epoch_end(self):
|
282 |
+
pass
|
283 |
+
|
284 |
+
def on_validation_epoch_end(self):
    """
    Compute and log the validation losses and metrics.

    If ``skip_immediate_validation`` is set, it is cleared,
    ``skip_immediate_ckpt_save`` is set instead and nothing is logged.
    Otherwise the total loss goes to the progress bar as ``val_loss``, and all
    losses/metrics are sent to the logger under ``validation/`` and ``metrics/``.
    """
    if self.skip_immediate_validation:
        self.skip_immediate_validation = False
        self.skip_immediate_ckpt_save = True
        return
    self._on_validation_epoch_end()
    loss_vals = {name: aggregator.compute() for name, aggregator in self.valid_losses.items()}
    metric_vals = {name: metric.compute() for name, metric in self.valid_metrics.items()}
    self.log('val_loss', loss_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True)
    step = self.global_step
    self.logger.log_metrics({f'validation/{name}': value for name, value in loss_vals.items()}, step=step)
    self.logger.log_metrics({f'metrics/{name}': value for name, value in metric_vals.items()}, step=step)
|
295 |
+
|
296 |
+
# noinspection PyMethodMayBeStatic
|
297 |
+
def build_scheduler(self, optimizer):
    """
    Build the learning-rate scheduler described by ``hparams['lr_scheduler_args']``.

    :param optimizer: the optimizer to schedule
    :return: whatever ``build_lr_scheduler_from_config`` returns
    :raises AssertionError: if ``scheduler_cls`` is an empty string
    """
    from utils import build_lr_scheduler_from_config

    scheduler_args = hparams['lr_scheduler_args']
    assert scheduler_args['scheduler_cls'] != ''
    return build_lr_scheduler_from_config(optimizer, scheduler_args)
|
304 |
+
|
305 |
+
# noinspection PyMethodMayBeStatic
|
306 |
+
def build_optimizer(self, model):
    """
    Build the optimizer described by ``hparams['optimizer_args']``.

    When ``beta1`` and ``beta2`` are configured without ``betas``, a ``betas``
    tuple is added to ``hparams['optimizer_args']`` itself.

    :param model: the module whose ``parameters()`` are optimized
    :return: the optimizer built by ``build_object_from_class_name``
    :raises AssertionError: if ``optimizer_cls`` is an empty string
    """
    from utils import build_object_from_class_name

    optimizer_args = hparams['optimizer_args']
    assert optimizer_args['optimizer_cls'] != ''
    has_split_betas = 'beta1' in optimizer_args and 'beta2' in optimizer_args
    if has_split_betas and 'betas' not in optimizer_args:
        optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
    return build_object_from_class_name(
        optimizer_args['optimizer_cls'],
        torch.optim.Optimizer,
        model.parameters(),
        **optimizer_args
    )
|
320 |
+
|
321 |
+
def configure_optimizers(self):
    """
    Build the optimizer and, if available, its per-step scheduler.

    :return: the optimizer alone when ``build_scheduler`` returns None,
        otherwise a dict with the optimizer and a scheduler config that steps
        the scheduler on every training step
    """
    optimizer = self.build_optimizer(self.model)
    scheduler = self.build_scheduler(optimizer)
    if scheduler is None:
        return optimizer
    scheduler_config = {
        "scheduler": scheduler,
        "interval": "step",
        "frequency": 1
    }
    return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
|
334 |
+
|
335 |
+
def train_dataloader(self):
    """
    Create the training data loader.

    A new ``DsBatchSampler`` over ``self.train_dataset`` is stored in
    ``self.training_sampler`` and used as the batch sampler. The number of
    workers and the prefetch factor come from ``hparams``.

    :return: a ``torch.utils.data.DataLoader`` over the training dataset
    """
    self.training_sampler = DsBatchSampler(
        self.train_dataset,
        max_batch_frames=self.max_batch_frames,
        max_batch_size=self.max_batch_size,
        num_replicas=self.num_replicas,
        rank=self.global_rank,
        sort_by_similar_size=hparams['sort_by_len'],
        size_reversed=True,
        required_batch_count_multiple=hparams['accumulate_grad_batches'],
        shuffle_sample=True,
        shuffle_batch=True
    )
    num_workers = hparams['ds_workers']
    # DataLoader raises ValueError if prefetch_factor or persistent_workers is
    # set while num_workers == 0, so only pass them when workers are used.
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs = {
            'prefetch_factor': hparams['dataloader_prefetch_factor'],
            'persistent_workers': True,
        }
    return torch.utils.data.DataLoader(
        self.train_dataset,
        collate_fn=self.train_dataset.collater,
        batch_sampler=self.training_sampler,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs
    )
|
357 |
+
|
358 |
+
def val_dataloader(self):
    """Build the validation ``DataLoader`` driven by a ``DsBatchSampler``.

    The sampler is created with both shuffling flags off,
    ``disallow_empty_batch=False`` and ``pad_batch_assignment=False``, and
    uses the validation batch limits held on the task.

    :return: a ``torch.utils.data.DataLoader`` over ``self.valid_dataset``
    """
    sampler = DsBatchSampler(
        self.valid_dataset,
        max_batch_frames=self.max_val_batch_frames,
        max_batch_size=self.max_val_batch_size,
        num_replicas=self.num_replicas,
        rank=self.global_rank,
        shuffle_sample=False,
        shuffle_batch=False,
        disallow_empty_batch=False,
        pad_batch_assignment=False
    )

    num_workers = hparams['ds_workers']
    # ``prefetch_factor`` and ``persistent_workers`` are only accepted by
    # DataLoader when worker processes are used; passing them with
    # ``num_workers == 0`` raises ValueError, so add them conditionally.
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs['prefetch_factor'] = hparams['dataloader_prefetch_factor']
        worker_kwargs['persistent_workers'] = True

    return torch.utils.data.DataLoader(
        self.valid_dataset,
        collate_fn=self.valid_dataset.collater,
        batch_sampler=sampler,
        num_workers=num_workers,
        **worker_kwargs
    )
|
378 |
+
|
379 |
+
def test_dataloader(self):
|
380 |
+
return self.val_dataloader()
|
381 |
+
|
382 |
+
def on_test_start(self):
    """Run the validation-start hook when testing begins; returns nothing."""
    # Testing reuses the validation setup rather than defining its own.
    self.on_validation_start()
|
384 |
+
|
385 |
+
def test_step(self, sample, batch_idx):
|
386 |
+
return self.validation_step(sample, batch_idx)
|
387 |
+
|
388 |
+
def on_test_end(self):
    """Run the validation-end hook when testing finishes and return its result."""
    result = self.on_validation_end()
    return result
|
390 |
+
|
391 |
+
###########
|
392 |
+
# Running configuration
|
393 |
+
###########
|
394 |
+
|
395 |
+
@classmethod
def start(cls):
    """Instantiate the task, build a Lightning ``Trainer`` from ``hparams`` and run it.

    When ``hparams['infer']`` is falsy the task is trained: on rank zero,
    ``spk_map.json`` and ``dictionary.txt`` are copied into the work dir if
    they are not already there, then ``trainer.fit`` is called with the
    checkpoint path returned by ``get_latest_checkpoint_path(work_dir)``.
    Otherwise ``trainer.test`` is called.
    """
    task = cls()

    work_dir = pathlib.Path(hparams['work_dir'])
    grad_accum = hparams['accumulate_grad_batches']

    strategy = get_strategy(
        hparams['pl_trainer_devices'],
        hparams['pl_trainer_num_nodes'],
        hparams['pl_trainer_accelerator'],
        hparams['pl_trainer_strategy'],
        hparams['pl_trainer_precision'],
    )
    checkpoint_callback = DsModelCheckpoint(
        dirpath=work_dir,
        filename='model_ckpt_steps_{step}',
        auto_insert_metric_name=False,
        monitor='step',
        mode='max',
        save_last=False,
        # every_n_train_steps=hparams['val_check_interval'],
        save_top_k=hparams['num_ckpt_keep'],
        permanent_ckpt_start=hparams['permanent_ckpt_start'],
        permanent_ckpt_interval=hparams['permanent_ckpt_interval'],
        verbose=True
    )
    logger = DsTensorBoardLogger(
        save_dir=str(work_dir),
        name='lightning_logs',
        version='latest'
    )

    trainer = pl.Trainer(
        accelerator=hparams['pl_trainer_accelerator'],
        devices=hparams['pl_trainer_devices'],
        num_nodes=hparams['pl_trainer_num_nodes'],
        strategy=strategy,
        precision=hparams['pl_trainer_precision'],
        callbacks=[
            checkpoint_callback,
            # LearningRateMonitor(logging_interval='step'),
            DsTQDMProgressBar(),
        ],
        logger=logger,
        gradient_clip_val=hparams['clip_grad_norm'],
        # Scaled by the accumulation factor so that the configured interval
        # is counted in global steps (per the original author's note).
        val_check_interval=hparams['val_check_interval'] * grad_accum,
        check_val_every_n_epoch=None,
        log_every_n_steps=1,
        max_steps=hparams['max_updates'],
        use_distributed_sampler=False,
        num_sanity_val_steps=hparams['num_sanity_val_steps'],
        accumulate_grad_batches=grad_accum
    )

    if hparams['infer']:
        trainer.test(task)
        return

    # train
    @rank_zero_only
    def train_payload_copy():
        # Copy spk_map.json and dictionary.txt to work dir
        binary_dir = pathlib.Path(hparams['binary_data_dir'])
        # shutil.copy does not create missing parent directories, so make
        # sure the destination exists before the first copy.
        work_dir.mkdir(parents=True, exist_ok=True)

        spk_map = work_dir / 'spk_map.json'
        spk_map_src = binary_dir / 'spk_map.json'
        if not spk_map.exists() and spk_map_src.exists():
            shutil.copy(spk_map_src, spk_map)
            print(f'| Copied spk map to {spk_map}.')

        dictionary = work_dir / 'dictionary.txt'
        dict_src = binary_dir / 'dictionary.txt'
        if not dictionary.exists():
            # Prefer the dictionary stored with the binarized data; fall
            # back to whatever locate_dictionary() resolves.
            if dict_src.exists():
                shutil.copy(dict_src, dictionary)
            else:
                shutil.copy(locate_dictionary(), dictionary)
            print(f'| Copied dictionary to {dictionary}.')

    train_payload_copy()
    trainer.fit(task, ckpt_path=get_latest_checkpoint_path(work_dir))
|
471 |
+
|
472 |
+
def on_save_checkpoint(self, checkpoint):
    """Add task metadata to a checkpoint dict before it is written.

    :param checkpoint: checkpoint dict, modified in place; gains
        ``'category'`` when the model is a ``CategorizedModule`` and always
        gains ``'trainer_stage'``
    """
    model = self.model
    if isinstance(model, CategorizedModule):
        checkpoint['category'] = model.category
    # Recorded so that on load the task can tell which stage the trainer
    # was in when the checkpoint was taken.
    current_stage = self.trainer.state.stage
    checkpoint['trainer_stage'] = current_stage.value
|
476 |
+
|
477 |
+
def on_load_checkpoint(self, checkpoint):
    """Reconcile a loaded checkpoint with the current ``hparams``.

    The checkpoint dict is edited in place:

    * if it was saved during validation, ``self.skip_immediate_validation``
      is set to ``True``;
    * optimizer param-group values that differ from
      ``hparams['optimizer_args']`` are overridden (``lr`` is left alone
      when scheduler state is present), as is ``initial_lr``;
    * the single saved scheduler state is replaced by the result of
      ``simulate_lr_scheduler`` and each param group's ``lr`` is synced to
      its ``_last_lr``.

    :param checkpoint: checkpoint dict being loaded
    """
    from lightning.pytorch.trainer.states import RunningStage
    from utils import simulate_lr_scheduler

    saved_stage = checkpoint.get('trainer_stage', '')
    if saved_stage == RunningStage.VALIDATING.value:
        self.skip_immediate_validation = True

    optimizer_args = hparams['optimizer_args']
    scheduler_args = hparams['lr_scheduler_args']

    # Same beta1/beta2 -> betas folding as build_optimizer, so the
    # comparison against saved param groups sees a ``betas`` entry.
    betas_missing = 'betas' not in optimizer_args
    if betas_missing and 'beta1' in optimizer_args and 'beta2' in optimizer_args:
        optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])

    if checkpoint.get('optimizer_states', None):
        saved_opt_states = checkpoint['optimizer_states']
        assert len(saved_opt_states) == 1  # only support one optimizer
        for group in saved_opt_states[0]['param_groups']:
            for key, wanted in optimizer_args.items():
                if key in group and group[key] != wanted:
                    # With scheduler state present, lr is synced from the
                    # simulated scheduler below instead.
                    if 'lr_schedulers' in checkpoint and checkpoint['lr_schedulers'] and key == 'lr':
                        continue
                    rank_zero_info(f'| Overriding optimizer parameter {key} from checkpoint: {group[key]} -> {wanted}')
                    group[key] = wanted
            if 'initial_lr' in group and group['initial_lr'] != optimizer_args['lr']:
                rank_zero_info(
                    f'| Overriding optimizer parameter initial_lr from checkpoint: {group["initial_lr"]} -> {optimizer_args["lr"]}'
                )
                group['initial_lr'] = optimizer_args['lr']

    if checkpoint.get('lr_schedulers', None):
        assert checkpoint.get('optimizer_states', False)
        assert len(checkpoint['lr_schedulers']) == 1  # only support one scheduler
        saved_groups = checkpoint['optimizer_states'][0]['param_groups']
        checkpoint['lr_schedulers'][0] = simulate_lr_scheduler(
            optimizer_args, scheduler_args,
            step_count=checkpoint['global_step'],
            num_param_groups=len(saved_groups)
        )
        simulated_lrs = checkpoint['lr_schedulers'][0]['_last_lr']
        for group, new_lr in zip(saved_groups, simulated_lrs):
            if group['lr'] != new_lr:
                rank_zero_info(f'| Overriding optimizer parameter lr from checkpoint: {group["lr"]} -> {new_lr}')
                group['lr'] = new_lr
|
basics/base_vocoder.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class BaseVocoder:
    """Interface for vocoders; every method must be overridden by subclasses."""

    def to_device(self, device):
        """Move the vocoder to a device.

        :param device: torch.device or str
        """
        raise NotImplementedError

    def get_device(self):
        """Report the device the vocoder is on.

        :return: device: torch.device or str
        """
        raise NotImplementedError

    def spec2wav(self, mel, **kwargs):
        """Convert a mel spectrogram to a waveform.

        :param mel: [T, 80]
        :return: wav: [T']
        """
        raise NotImplementedError
|
configs/CantusSVS_acoustic.yaml
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/base.yaml
|
3 |
+
|
4 |
+
pl_trainer_precision: '16-mixed'
|
5 |
+
pe: 'rmvpe'
|
6 |
+
pe_ckpt: checkpoints/dependency_checkpoints/rmvpe/model.pt
|
7 |
+
|
8 |
+
task_cls: training.acoustic_task.AcousticTask
|
9 |
+
spk_ids: []
|
10 |
+
|
11 |
+
num_spk: 1
|
12 |
+
speakers:
|
13 |
+
- regular # index 0
|
14 |
+
test_prefixes:
|
15 |
+
# - {index_speaker}:{name_of_wav}
|
16 |
+
# regular (0)
|
17 |
+
- 0:Adventus_seg012
|
18 |
+
- 0:Adventus_seg024
|
19 |
+
- 0:Adventus_seg036
|
20 |
+
- 0:Adventus_seg048
|
21 |
+
- 0:Adventus_seg060
|
22 |
+
|
23 |
+
raw_data_dir:
|
24 |
+
- data/NNSVS_training_data/regular/diffsinger_db #0
|
25 |
+
|
26 |
+
hnsep: vr
|
27 |
+
hnsep_ckpt: checkpoints/dependency_checkpoints/vr/model.pt
|
28 |
+
|
29 |
+
vocoder: NsfHifiGAN
|
30 |
+
vocoder_ckpt: checkpoints/dependency_checkpoints/nsf-hifigan/model.ckpt
|
31 |
+
audio_sample_rate: 44100
|
32 |
+
audio_num_mel_bins: 128
|
33 |
+
hop_size: 512 # Hop size.
|
34 |
+
fft_size: 2048 # FFT size.
|
35 |
+
win_size: 2048 # Window size.
|
36 |
+
fmin: 40
|
37 |
+
fmax: 16000
|
38 |
+
|
39 |
+
binarization_args:
|
40 |
+
shuffle: true
|
41 |
+
num_workers: 0
|
42 |
+
augmentation_args:
|
43 |
+
random_pitch_shifting:
|
44 |
+
enabled: true
|
45 |
+
range: [-3., 3.]
|
46 |
+
scale: 0.75
|
47 |
+
fixed_pitch_shifting:
|
48 |
+
enabled: false
|
49 |
+
targets: [-3., 3.]
|
50 |
+
scale: 0.5
|
51 |
+
random_time_stretching:
|
52 |
+
enabled: true
|
53 |
+
range: [0.8, 1.2]
|
54 |
+
scale: 0.75
|
55 |
+
|
56 |
+
binary_data_dir: 'data/binary/regular_acoustic_v1'
|
57 |
+
binarizer_cls: preprocessing.acoustic_binarizer.AcousticBinarizer
|
58 |
+
dictionary: dictionaries/latin_dictionary.txt
|
59 |
+
spec_min: [-12]
|
60 |
+
spec_max: [0]
|
61 |
+
mel_vmin: -14.
|
62 |
+
mel_vmax: 4.
|
63 |
+
mel_base: 'e'
|
64 |
+
energy_smooth_width: 0.12
|
65 |
+
breathiness_smooth_width: 0.12
|
66 |
+
voicing_smooth_width: 0.12
|
67 |
+
tension_smooth_width: 0.12
|
68 |
+
|
69 |
+
use_spk_id: false
|
70 |
+
use_energy_embed: false
|
71 |
+
use_breathiness_embed: false
|
72 |
+
use_voicing_embed: false
|
73 |
+
use_tension_embed: false
|
74 |
+
use_key_shift_embed: true
|
75 |
+
use_speed_embed: true
|
76 |
+
|
77 |
+
diffusion_type: reflow
|
78 |
+
time_scale_factor: 1000
|
79 |
+
timesteps: 1000
|
80 |
+
max_beta: 0.02
|
81 |
+
enc_ffn_kernel_size: 3
|
82 |
+
use_rope: true
|
83 |
+
rel_pos: true
|
84 |
+
sampling_algorithm: euler
|
85 |
+
sampling_steps: 20
|
86 |
+
diff_accelerator: ddim
|
87 |
+
diff_speedup: 10
|
88 |
+
hidden_size: 256
|
89 |
+
backbone_type: 'lynxnet'
|
90 |
+
backbone_args:
|
91 |
+
num_channels: 1024
|
92 |
+
num_layers: 6
|
93 |
+
kernel_size: 31
|
94 |
+
dropout_rate: 0.0
|
95 |
+
strong_cond: true
|
96 |
+
main_loss_type: l2
|
97 |
+
main_loss_log_norm: false
|
98 |
+
schedule_type: 'linear'
|
99 |
+
|
100 |
+
# shallow diffusion
|
101 |
+
use_shallow_diffusion: true
|
102 |
+
T_start: 0.4
|
103 |
+
T_start_infer: 0.4
|
104 |
+
K_step: 400
|
105 |
+
K_step_infer: 400
|
106 |
+
|
107 |
+
shallow_diffusion_args:
|
108 |
+
train_aux_decoder: true
|
109 |
+
train_diffusion: true
|
110 |
+
val_gt_start: false
|
111 |
+
aux_decoder_arch: convnext
|
112 |
+
aux_decoder_args:
|
113 |
+
num_channels: 512
|
114 |
+
num_layers: 6
|
115 |
+
kernel_size: 7
|
116 |
+
dropout_rate: 0.1
|
117 |
+
aux_decoder_grad: 0.1
|
118 |
+
|
119 |
+
lambda_aux_mel_loss: 0.2
|
120 |
+
|
121 |
+
# train and eval
|
122 |
+
num_sanity_val_steps: 1
|
123 |
+
optimizer_args:
|
124 |
+
lr: 0.0006
|
125 |
+
lr_scheduler_args:
|
126 |
+
step_size: 10000
|
127 |
+
gamma: 0.75
|
128 |
+
max_batch_frames: 50000
|
129 |
+
max_batch_size: 16
|
130 |
+
dataset_size_key: 'lengths'
|
131 |
+
val_with_vocoder: true
|
132 |
+
val_check_interval: 1000
|
133 |
+
num_valid_plots: 10
|
134 |
+
max_updates: 1000000
|
135 |
+
num_ckpt_keep: 5
|
136 |
+
permanent_ckpt_start: 20000
|
137 |
+
permanent_ckpt_interval: 5000
|
138 |
+
|
139 |
+
finetune_enabled: false
|
140 |
+
finetune_ckpt_path: null
|
141 |
+
|
142 |
+
finetune_ignored_params:
|
143 |
+
- model.fs2.encoder.embed_tokens
|
144 |
+
- model.fs2.txt_embed
|
145 |
+
- model.fs2.spk_embed
|
146 |
+
finetune_strict_shapes: true
|
147 |
+
|
148 |
+
freezing_enabled: false
|
149 |
+
frozen_params: []
|
configs/CantusSVS_variance.yaml
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/base.yaml
|
3 |
+
|
4 |
+
pl_trainer_precision: '16-mixed'
|
5 |
+
pe: 'rmvpe'
|
6 |
+
pe_ckpt: checkpoints/dependency_checkpoints/rmvpe/model.pt
|
7 |
+
|
8 |
+
task_cls: training.variance_task.VarianceTask
|
9 |
+
spk_ids: []
|
10 |
+
|
11 |
+
num_spk: 1
|
12 |
+
speakers:
|
13 |
+
- regular # index 0
|
14 |
+
test_prefixes:
|
15 |
+
# - {index_speaker}:{name_of_wav}
|
16 |
+
# regular (0)
|
17 |
+
- 0:Adventus_seg012
|
18 |
+
- 0:Adventus_seg024
|
19 |
+
- 0:Adventus_seg036
|
20 |
+
- 0:Adventus_seg048
|
21 |
+
- 0:Adventus_seg060
|
22 |
+
|
23 |
+
raw_data_dir:
|
24 |
+
- data/NNSVS_training_data/regular/diffsinger_db #0
|
25 |
+
|
26 |
+
audio_sample_rate: 44100
|
27 |
+
hop_size: 512 # Hop size.
|
28 |
+
fft_size: 2048 # FFT size.
|
29 |
+
win_size: 2048 # Window size.
|
30 |
+
midi_smooth_width: 0.06 # in seconds
|
31 |
+
|
32 |
+
binarization_args:
|
33 |
+
shuffle: true
|
34 |
+
num_workers: 0
|
35 |
+
prefer_ds: false
|
36 |
+
|
37 |
+
binary_data_dir: 'data/binary/regular_variance_v1'
|
38 |
+
binarizer_cls: preprocessing.variance_binarizer.VarianceBinarizer
|
39 |
+
dictionary: dictionaries/latin_dictionary.txt
|
40 |
+
|
41 |
+
use_spk_id: false
|
42 |
+
|
43 |
+
enc_ffn_kernel_size: 3
|
44 |
+
use_rope: true
|
45 |
+
rel_pos: true
|
46 |
+
hidden_size: 256
|
47 |
+
|
48 |
+
predict_dur: true
|
49 |
+
predict_pitch: true
|
50 |
+
predict_energy: false
|
51 |
+
predict_breathiness: false
|
52 |
+
predict_voicing: false
|
53 |
+
predict_tension: false
|
54 |
+
|
55 |
+
dur_prediction_args:
|
56 |
+
arch: fs2
|
57 |
+
hidden_size: 512
|
58 |
+
dropout: 0.1
|
59 |
+
num_layers: 5
|
60 |
+
kernel_size: 3
|
61 |
+
log_offset: 1.0
|
62 |
+
loss_type: mse
|
63 |
+
lambda_pdur_loss: 0.3
|
64 |
+
lambda_wdur_loss: 1.0
|
65 |
+
lambda_sdur_loss: 3.0
|
66 |
+
|
67 |
+
use_melody_encoder: false
|
68 |
+
melody_encoder_args:
|
69 |
+
hidden_size: 128
|
70 |
+
enc_layers: 4
|
71 |
+
use_glide_embed: false
|
72 |
+
glide_types: [up, down]
|
73 |
+
glide_embed_scale: 11.313708498984760 # sqrt(128)
|
74 |
+
|
75 |
+
pitch_prediction_args:
|
76 |
+
pitd_norm_min: -8.0
|
77 |
+
pitd_norm_max: 8.0
|
78 |
+
pitd_clip_min: -12.0
|
79 |
+
pitd_clip_max: 12.0
|
80 |
+
repeat_bins: 64
|
81 |
+
backbone_type: 'wavenet'
|
82 |
+
backbone_args:
|
83 |
+
num_layers: 20
|
84 |
+
num_channels: 256
|
85 |
+
dilation_cycle_length: 5
|
86 |
+
|
87 |
+
energy_db_min: -96.0
|
88 |
+
energy_db_max: -12.0
|
89 |
+
energy_smooth_width: 0.12
|
90 |
+
|
91 |
+
breathiness_db_min: -96.0
|
92 |
+
breathiness_db_max: -20.0
|
93 |
+
breathiness_smooth_width: 0.12
|
94 |
+
voicing_db_min: -96.0
|
95 |
+
voicing_db_max: -12.0
|
96 |
+
voicing_smooth_width: 0.12
|
97 |
+
|
98 |
+
tension_logit_min: -10.0
|
99 |
+
tension_logit_max: 10.0
|
100 |
+
tension_smooth_width: 0.12
|
101 |
+
|
102 |
+
variances_prediction_args:
|
103 |
+
total_repeat_bins: 48
|
104 |
+
backbone_type: 'wavenet'
|
105 |
+
backbone_args:
|
106 |
+
num_layers: 10
|
107 |
+
num_channels: 192
|
108 |
+
dilation_cycle_length: 4
|
109 |
+
|
110 |
+
lambda_dur_loss: 1.0
|
111 |
+
lambda_pitch_loss: 1.0
|
112 |
+
lambda_var_loss: 1.0
|
113 |
+
|
114 |
+
diffusion_type: reflow # ddpm
|
115 |
+
time_scale_factor: 1000
|
116 |
+
schedule_type: 'linear'
|
117 |
+
K_step: 1000
|
118 |
+
timesteps: 1000
|
119 |
+
max_beta: 0.02
|
120 |
+
main_loss_type: l2
|
121 |
+
main_loss_log_norm: true
|
122 |
+
sampling_algorithm: euler
|
123 |
+
sampling_steps: 20
|
124 |
+
diff_accelerator: ddim
|
125 |
+
diff_speedup: 10
|
126 |
+
|
127 |
+
# train and eval
|
128 |
+
num_sanity_val_steps: 1
|
129 |
+
optimizer_args:
|
130 |
+
lr: 0.0006
|
131 |
+
lr_scheduler_args:
|
132 |
+
step_size: 10000
|
133 |
+
gamma: 0.75
|
134 |
+
max_batch_frames: 50000
|
135 |
+
max_batch_size: 16
|
136 |
+
dataset_size_key: 'lengths'
|
137 |
+
val_check_interval: 1000
|
138 |
+
num_valid_plots: 10
|
139 |
+
max_updates: 1000000
|
140 |
+
num_ckpt_keep: 5
|
141 |
+
permanent_ckpt_start: 20000
|
142 |
+
permanent_ckpt_interval: 5000
|
143 |
+
|
144 |
+
finetune_enabled: false
|
145 |
+
finetune_ckpt_path: null
|
146 |
+
finetune_ignored_params:
|
147 |
+
- model.spk_embed
|
148 |
+
- model.fs2.txt_embed
|
149 |
+
- model.fs2.encoder.embed_tokens
|
150 |
+
finetune_strict_shapes: true
|
151 |
+
|
152 |
+
freezing_enabled: false
|
153 |
+
frozen_params: []
|
configs/base.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# task
|
2 |
+
task_cls: null
|
3 |
+
|
4 |
+
#############
|
5 |
+
# dataset
|
6 |
+
#############
|
7 |
+
sort_by_len: true
|
8 |
+
raw_data_dir: null
|
9 |
+
binary_data_dir: null
|
10 |
+
binarizer_cls: null
|
11 |
+
binarization_args:
|
12 |
+
shuffle: false
|
13 |
+
num_workers: 0
|
14 |
+
|
15 |
+
audio_sample_rate: 44100
|
16 |
+
hop_size: 512
|
17 |
+
win_size: 2048
|
18 |
+
fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
|
19 |
+
sampler_frame_count_grid: 6
|
20 |
+
ds_workers: 4
|
21 |
+
dataloader_prefetch_factor: 2
|
22 |
+
|
23 |
+
#########
|
24 |
+
# model
|
25 |
+
#########
|
26 |
+
hidden_size: 256
|
27 |
+
dropout: 0.1
|
28 |
+
use_pos_embed: true
|
29 |
+
enc_layers: 4
|
30 |
+
num_heads: 2
|
31 |
+
enc_ffn_kernel_size: 9
|
32 |
+
ffn_act: gelu
|
33 |
+
use_spk_id: false
|
34 |
+
|
35 |
+
###########
|
36 |
+
# optimization
|
37 |
+
###########
|
38 |
+
optimizer_args:
|
39 |
+
optimizer_cls: torch.optim.AdamW
|
40 |
+
lr: 0.0004
|
41 |
+
beta1: 0.9
|
42 |
+
beta2: 0.98
|
43 |
+
weight_decay: 0
|
44 |
+
lr_scheduler_args:
|
45 |
+
scheduler_cls: torch.optim.lr_scheduler.StepLR
|
46 |
+
step_size: 50000
|
47 |
+
gamma: 0.5
|
48 |
+
clip_grad_norm: 1
|
49 |
+
|
50 |
+
###########
|
51 |
+
# train and eval
|
52 |
+
###########
|
53 |
+
num_ckpt_keep: 5
|
54 |
+
accumulate_grad_batches: 1
|
55 |
+
log_interval: 100
|
56 |
+
num_sanity_val_steps: 1 # steps of validation at the beginning
|
57 |
+
val_check_interval: 2000
|
58 |
+
max_updates: 120000
|
59 |
+
max_batch_frames: 32000
|
60 |
+
max_batch_size: 100000
|
61 |
+
max_val_batch_frames: 60000
|
62 |
+
max_val_batch_size: 1
|
63 |
+
pe: parselmouth
|
64 |
+
pe_ckpt: 'checkpoints/rmvpe/model.pt'
|
65 |
+
hnsep: vr
|
66 |
+
hnsep_ckpt: 'checkpoints/vr/model.pt'
|
67 |
+
f0_min: 65
|
68 |
+
f0_max: 1100
|
69 |
+
num_valid_plots: 10
|
70 |
+
|
71 |
+
###########
|
72 |
+
# pytorch lightning
|
73 |
+
# Read https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api for possible values
|
74 |
+
###########
|
75 |
+
pl_trainer_accelerator: 'auto'
|
76 |
+
pl_trainer_devices: 'auto'
|
77 |
+
pl_trainer_precision: '16-mixed'
|
78 |
+
pl_trainer_num_nodes: 1
|
79 |
+
pl_trainer_strategy:
|
80 |
+
name: auto
|
81 |
+
process_group_backend: nccl
|
82 |
+
find_unused_parameters: false
|
83 |
+
nccl_p2p: true
|
84 |
+
|
85 |
+
###########
|
86 |
+
# finetune
|
87 |
+
###########
|
88 |
+
finetune_enabled: false
|
89 |
+
finetune_ckpt_path: null
|
90 |
+
finetune_ignored_params: []
|
91 |
+
finetune_strict_shapes: true
|
92 |
+
|
93 |
+
freezing_enabled: false
|
94 |
+
frozen_params: []
|
configs/defaults/acoustic.yaml
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/base.yaml
|
3 |
+
|
4 |
+
task_cls: training.acoustic_task.AcousticTask
|
5 |
+
num_spk: 1
|
6 |
+
speakers:
|
7 |
+
- opencpop
|
8 |
+
spk_ids: []
|
9 |
+
test_prefixes: [
|
10 |
+
'2044',
|
11 |
+
'2086',
|
12 |
+
'2092',
|
13 |
+
'2093',
|
14 |
+
'2100',
|
15 |
+
]
|
16 |
+
|
17 |
+
vocoder: NsfHifiGAN
|
18 |
+
vocoder_ckpt: checkpoints/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt
|
19 |
+
audio_sample_rate: 44100
|
20 |
+
audio_num_mel_bins: 128
|
21 |
+
hop_size: 512 # Hop size.
|
22 |
+
fft_size: 2048 # FFT size.
|
23 |
+
win_size: 2048 # Window size.
|
24 |
+
fmin: 40
|
25 |
+
fmax: 16000
|
26 |
+
|
27 |
+
binarization_args:
|
28 |
+
shuffle: true
|
29 |
+
num_workers: 0
|
30 |
+
augmentation_args:
|
31 |
+
random_pitch_shifting:
|
32 |
+
enabled: false
|
33 |
+
range: [-5., 5.]
|
34 |
+
scale: 0.75
|
35 |
+
fixed_pitch_shifting:
|
36 |
+
enabled: false
|
37 |
+
targets: [-5., 5.]
|
38 |
+
scale: 0.5
|
39 |
+
random_time_stretching:
|
40 |
+
enabled: false
|
41 |
+
range: [0.5, 2.]
|
42 |
+
scale: 0.75
|
43 |
+
|
44 |
+
raw_data_dir: 'data/opencpop/raw'
|
45 |
+
binary_data_dir: 'data/opencpop/binary'
|
46 |
+
binarizer_cls: preprocessing.acoustic_binarizer.AcousticBinarizer
|
47 |
+
dictionary: dictionaries/opencpop-extension.txt
|
48 |
+
spec_min: [-12]
|
49 |
+
spec_max: [0]
|
50 |
+
mel_vmin: -14.
|
51 |
+
mel_vmax: 4.
|
52 |
+
mel_base: 'e'
|
53 |
+
energy_smooth_width: 0.12
|
54 |
+
breathiness_smooth_width: 0.12
|
55 |
+
voicing_smooth_width: 0.12
|
56 |
+
tension_smooth_width: 0.12
|
57 |
+
|
58 |
+
use_spk_id: false
|
59 |
+
use_energy_embed: false
|
60 |
+
use_breathiness_embed: false
|
61 |
+
use_voicing_embed: false
|
62 |
+
use_tension_embed: false
|
63 |
+
use_key_shift_embed: false
|
64 |
+
use_speed_embed: false
|
65 |
+
|
66 |
+
diffusion_type: reflow
|
67 |
+
time_scale_factor: 1000
|
68 |
+
timesteps: 1000
|
69 |
+
max_beta: 0.02
|
70 |
+
enc_ffn_kernel_size: 3
|
71 |
+
use_rope: true
|
72 |
+
rel_pos: true
|
73 |
+
sampling_algorithm: euler
|
74 |
+
sampling_steps: 20
|
75 |
+
diff_accelerator: ddim
|
76 |
+
diff_speedup: 10
|
77 |
+
hidden_size: 256
|
78 |
+
backbone_type: 'lynxnet'
|
79 |
+
backbone_args:
|
80 |
+
num_channels: 1024
|
81 |
+
num_layers: 6
|
82 |
+
kernel_size: 31
|
83 |
+
dropout_rate: 0.0
|
84 |
+
strong_cond: true
|
85 |
+
main_loss_type: l2
|
86 |
+
main_loss_log_norm: false
|
87 |
+
schedule_type: 'linear'
|
88 |
+
|
89 |
+
# shallow diffusion
|
90 |
+
use_shallow_diffusion: true
|
91 |
+
T_start: 0.4
|
92 |
+
T_start_infer: 0.4
|
93 |
+
K_step: 400
|
94 |
+
K_step_infer: 400
|
95 |
+
|
96 |
+
shallow_diffusion_args:
|
97 |
+
train_aux_decoder: true
|
98 |
+
train_diffusion: true
|
99 |
+
val_gt_start: false
|
100 |
+
aux_decoder_arch: convnext
|
101 |
+
aux_decoder_args:
|
102 |
+
num_channels: 512
|
103 |
+
num_layers: 6
|
104 |
+
kernel_size: 7
|
105 |
+
dropout_rate: 0.1
|
106 |
+
aux_decoder_grad: 0.1
|
107 |
+
|
108 |
+
lambda_aux_mel_loss: 0.2
|
109 |
+
|
110 |
+
# train and eval
|
111 |
+
num_sanity_val_steps: 1
|
112 |
+
optimizer_args:
|
113 |
+
lr: 0.0006
|
114 |
+
lr_scheduler_args:
|
115 |
+
step_size: 10000
|
116 |
+
gamma: 0.75
|
117 |
+
max_batch_frames: 50000
|
118 |
+
max_batch_size: 64
|
119 |
+
dataset_size_key: 'lengths'
|
120 |
+
val_with_vocoder: true
|
121 |
+
val_check_interval: 2000
|
122 |
+
num_valid_plots: 10
|
123 |
+
max_updates: 160000
|
124 |
+
num_ckpt_keep: 5
|
125 |
+
permanent_ckpt_start: 80000
|
126 |
+
permanent_ckpt_interval: 20000
|
127 |
+
|
128 |
+
finetune_enabled: false
|
129 |
+
finetune_ckpt_path: null
|
130 |
+
|
131 |
+
finetune_ignored_params:
|
132 |
+
- model.fs2.encoder.embed_tokens
|
133 |
+
- model.fs2.txt_embed
|
134 |
+
- model.fs2.spk_embed
|
135 |
+
finetune_strict_shapes: true
|
136 |
+
|
137 |
+
freezing_enabled: false
|
138 |
+
frozen_params: []
|
configs/defaults/base.yaml
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# task
|
2 |
+
task_cls: null
|
3 |
+
|
4 |
+
#############
|
5 |
+
# dataset
|
6 |
+
#############
|
7 |
+
sort_by_len: true
|
8 |
+
raw_data_dir: null
|
9 |
+
binary_data_dir: null
|
10 |
+
binarizer_cls: null
|
11 |
+
binarization_args:
|
12 |
+
shuffle: false
|
13 |
+
num_workers: 0
|
14 |
+
|
15 |
+
audio_sample_rate: 44100
|
16 |
+
hop_size: 512
|
17 |
+
win_size: 2048
|
18 |
+
fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
|
19 |
+
sampler_frame_count_grid: 6
|
20 |
+
ds_workers: 4
|
21 |
+
dataloader_prefetch_factor: 2
|
22 |
+
|
23 |
+
#########
|
24 |
+
# model
|
25 |
+
#########
|
26 |
+
hidden_size: 256
|
27 |
+
dropout: 0.1
|
28 |
+
use_pos_embed: true
|
29 |
+
enc_layers: 4
|
30 |
+
num_heads: 2
|
31 |
+
enc_ffn_kernel_size: 9
|
32 |
+
ffn_act: gelu
|
33 |
+
use_spk_id: false
|
34 |
+
|
35 |
+
###########
|
36 |
+
# optimization
|
37 |
+
###########
|
38 |
+
optimizer_args:
|
39 |
+
optimizer_cls: torch.optim.AdamW
|
40 |
+
lr: 0.0004
|
41 |
+
beta1: 0.9
|
42 |
+
beta2: 0.98
|
43 |
+
weight_decay: 0
|
44 |
+
lr_scheduler_args:
|
45 |
+
scheduler_cls: torch.optim.lr_scheduler.StepLR
|
46 |
+
step_size: 50000
|
47 |
+
gamma: 0.5
|
48 |
+
clip_grad_norm: 1
|
49 |
+
|
50 |
+
###########
|
51 |
+
# train and eval
|
52 |
+
###########
|
53 |
+
num_ckpt_keep: 5
|
54 |
+
accumulate_grad_batches: 1
|
55 |
+
log_interval: 100
|
56 |
+
num_sanity_val_steps: 1 # steps of validation at the beginning
|
57 |
+
val_check_interval: 2000
|
58 |
+
max_updates: 120000
|
59 |
+
max_batch_frames: 32000
|
60 |
+
max_batch_size: 100000
|
61 |
+
max_val_batch_frames: 60000
|
62 |
+
max_val_batch_size: 1
|
63 |
+
pe: parselmouth
|
64 |
+
pe_ckpt: 'checkpoints/rmvpe/model.pt'
|
65 |
+
hnsep: vr
|
66 |
+
hnsep_ckpt: 'checkpoints/vr/model.pt'
|
67 |
+
f0_min: 65
|
68 |
+
f0_max: 1100
|
69 |
+
num_valid_plots: 10
|
70 |
+
|
71 |
+
###########
|
72 |
+
# pytorch lightning
|
73 |
+
# Read https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api for possible values
|
74 |
+
###########
|
75 |
+
pl_trainer_accelerator: 'auto'
|
76 |
+
pl_trainer_devices: 'auto'
|
77 |
+
pl_trainer_precision: '16-mixed'
|
78 |
+
pl_trainer_num_nodes: 1
|
79 |
+
pl_trainer_strategy:
|
80 |
+
name: auto
|
81 |
+
process_group_backend: nccl
|
82 |
+
find_unused_parameters: false
|
83 |
+
nccl_p2p: true
|
84 |
+
|
85 |
+
###########
|
86 |
+
# finetune
|
87 |
+
###########
|
88 |
+
finetune_enabled: false
|
89 |
+
finetune_ckpt_path: null
|
90 |
+
finetune_ignored_params: []
|
91 |
+
finetune_strict_shapes: true
|
92 |
+
|
93 |
+
freezing_enabled: false
|
94 |
+
frozen_params: []
|
configs/defaults/variance.yaml
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/base.yaml
|
3 |
+
|
4 |
+
task_cls: training.variance_task.VarianceTask
|
5 |
+
num_spk: 1
|
6 |
+
speakers:
|
7 |
+
- opencpop
|
8 |
+
spk_ids: []
|
9 |
+
test_prefixes: [
|
10 |
+
'2044',
|
11 |
+
'2086',
|
12 |
+
'2092',
|
13 |
+
'2093',
|
14 |
+
'2100',
|
15 |
+
]
|
16 |
+
|
17 |
+
audio_sample_rate: 44100
|
18 |
+
hop_size: 512 # Hop size.
|
19 |
+
fft_size: 2048 # FFT size.
|
20 |
+
win_size: 2048 # Window size.
|
21 |
+
midi_smooth_width: 0.06 # in seconds
|
22 |
+
|
23 |
+
binarization_args:
|
24 |
+
shuffle: true
|
25 |
+
num_workers: 0
|
26 |
+
prefer_ds: false
|
27 |
+
|
28 |
+
raw_data_dir: 'data/opencpop_variance/raw'
|
29 |
+
binary_data_dir: 'data/opencpop_variance/binary'
|
30 |
+
binarizer_cls: preprocessing.variance_binarizer.VarianceBinarizer
|
31 |
+
dictionary: dictionaries/opencpop-extension.txt
|
32 |
+
|
33 |
+
use_spk_id: false
|
34 |
+
|
35 |
+
enc_ffn_kernel_size: 3
|
36 |
+
use_rope: true
|
37 |
+
rel_pos: true
|
38 |
+
hidden_size: 256
|
39 |
+
|
40 |
+
predict_dur: true
|
41 |
+
predict_pitch: true
|
42 |
+
predict_energy: false
|
43 |
+
predict_breathiness: false
|
44 |
+
predict_voicing: false
|
45 |
+
predict_tension: false
|
46 |
+
|
47 |
+
dur_prediction_args:
|
48 |
+
arch: fs2
|
49 |
+
hidden_size: 512
|
50 |
+
dropout: 0.1
|
51 |
+
num_layers: 5
|
52 |
+
kernel_size: 3
|
53 |
+
log_offset: 1.0
|
54 |
+
loss_type: mse
|
55 |
+
lambda_pdur_loss: 0.3
|
56 |
+
lambda_wdur_loss: 1.0
|
57 |
+
lambda_sdur_loss: 3.0
|
58 |
+
|
59 |
+
use_melody_encoder: false
|
60 |
+
melody_encoder_args:
|
61 |
+
hidden_size: 128
|
62 |
+
enc_layers: 4
|
63 |
+
use_glide_embed: false
|
64 |
+
glide_types: [up, down]
|
65 |
+
glide_embed_scale: 11.313708498984760 # sqrt(128)
|
66 |
+
|
67 |
+
pitch_prediction_args:
|
68 |
+
pitd_norm_min: -8.0
|
69 |
+
pitd_norm_max: 8.0
|
70 |
+
pitd_clip_min: -12.0
|
71 |
+
pitd_clip_max: 12.0
|
72 |
+
repeat_bins: 64
|
73 |
+
backbone_type: 'wavenet'
|
74 |
+
backbone_args:
|
75 |
+
num_layers: 20
|
76 |
+
num_channels: 256
|
77 |
+
dilation_cycle_length: 5
|
78 |
+
|
79 |
+
energy_db_min: -96.0
|
80 |
+
energy_db_max: -12.0
|
81 |
+
energy_smooth_width: 0.12
|
82 |
+
|
83 |
+
breathiness_db_min: -96.0
|
84 |
+
breathiness_db_max: -20.0
|
85 |
+
breathiness_smooth_width: 0.12
|
86 |
+
voicing_db_min: -96.0
|
87 |
+
voicing_db_max: -12.0
|
88 |
+
voicing_smooth_width: 0.12
|
89 |
+
|
90 |
+
tension_logit_min: -10.0
|
91 |
+
tension_logit_max: 10.0
|
92 |
+
tension_smooth_width: 0.12
|
93 |
+
|
94 |
+
variances_prediction_args:
|
95 |
+
total_repeat_bins: 48
|
96 |
+
backbone_type: 'wavenet'
|
97 |
+
backbone_args:
|
98 |
+
num_layers: 10
|
99 |
+
num_channels: 192
|
100 |
+
dilation_cycle_length: 4
|
101 |
+
|
102 |
+
lambda_dur_loss: 1.0
|
103 |
+
lambda_pitch_loss: 1.0
|
104 |
+
lambda_var_loss: 1.0
|
105 |
+
|
106 |
+
diffusion_type: reflow # ddpm
|
107 |
+
time_scale_factor: 1000
|
108 |
+
schedule_type: 'linear'
|
109 |
+
K_step: 1000
|
110 |
+
timesteps: 1000
|
111 |
+
max_beta: 0.02
|
112 |
+
main_loss_type: l2
|
113 |
+
main_loss_log_norm: true
|
114 |
+
sampling_algorithm: euler
|
115 |
+
sampling_steps: 20
|
116 |
+
diff_accelerator: ddim
|
117 |
+
diff_speedup: 10
|
118 |
+
|
119 |
+
# train and eval
|
120 |
+
num_sanity_val_steps: 1
|
121 |
+
optimizer_args:
|
122 |
+
lr: 0.0006
|
123 |
+
lr_scheduler_args:
|
124 |
+
step_size: 10000
|
125 |
+
gamma: 0.75
|
126 |
+
max_batch_frames: 80000
|
127 |
+
max_batch_size: 48
|
128 |
+
dataset_size_key: 'lengths'
|
129 |
+
val_check_interval: 2000
|
130 |
+
num_valid_plots: 10
|
131 |
+
max_updates: 160000
|
132 |
+
num_ckpt_keep: 5
|
133 |
+
permanent_ckpt_start: 80000
|
134 |
+
permanent_ckpt_interval: 10000
|
135 |
+
|
136 |
+
finetune_enabled: false
|
137 |
+
finetune_ckpt_path: null
|
138 |
+
finetune_ignored_params:
|
139 |
+
- model.spk_embed
|
140 |
+
- model.fs2.txt_embed
|
141 |
+
- model.fs2.encoder.embed_tokens
|
142 |
+
finetune_strict_shapes: true
|
143 |
+
|
144 |
+
freezing_enabled: false
|
145 |
+
frozen_params: []
|
configs/templates/config_acoustic.yaml
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config: configs/acoustic.yaml
|
2 |
+
|
3 |
+
raw_data_dir:
|
4 |
+
- data/xxx1/raw
|
5 |
+
- data/xxx2/raw
|
6 |
+
speakers:
|
7 |
+
- speaker1
|
8 |
+
- speaker2
|
9 |
+
spk_ids: []
|
10 |
+
test_prefixes:
|
11 |
+
- wav1
|
12 |
+
- wav2
|
13 |
+
- wav3
|
14 |
+
- wav4
|
15 |
+
- wav5
|
16 |
+
dictionary: dictionaries/opencpop-extension.txt
|
17 |
+
binary_data_dir: data/xxx/binary
|
18 |
+
binarization_args:
|
19 |
+
num_workers: 0
|
20 |
+
pe: parselmouth
|
21 |
+
pe_ckpt: 'checkpoints/rmvpe/model.pt'
|
22 |
+
hnsep: vr
|
23 |
+
hnsep_ckpt: 'checkpoints/vr/model.pt'
|
24 |
+
vocoder: NsfHifiGAN
|
25 |
+
vocoder_ckpt: checkpoints/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt
|
26 |
+
|
27 |
+
use_spk_id: false
|
28 |
+
num_spk: 1
|
29 |
+
|
30 |
+
# NOTICE: before enabling variance embeddings, please read the docs at
|
31 |
+
# https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#choosing-variance-parameters
|
32 |
+
use_energy_embed: false
|
33 |
+
use_breathiness_embed: false
|
34 |
+
use_voicing_embed: false
|
35 |
+
use_tension_embed: false
|
36 |
+
|
37 |
+
use_key_shift_embed: true
|
38 |
+
use_speed_embed: true
|
39 |
+
|
40 |
+
augmentation_args:
|
41 |
+
random_pitch_shifting:
|
42 |
+
enabled: true
|
43 |
+
range: [-5., 5.]
|
44 |
+
scale: 0.75
|
45 |
+
fixed_pitch_shifting:
|
46 |
+
enabled: false
|
47 |
+
targets: [-5., 5.]
|
48 |
+
scale: 0.5
|
49 |
+
random_time_stretching:
|
50 |
+
enabled: true
|
51 |
+
range: [0.5, 2.]
|
52 |
+
scale: 0.75
|
53 |
+
|
54 |
+
# diffusion and shallow diffusion
|
55 |
+
diffusion_type: reflow
|
56 |
+
enc_ffn_kernel_size: 3
|
57 |
+
use_rope: true
|
58 |
+
use_shallow_diffusion: true
|
59 |
+
T_start: 0.4
|
60 |
+
T_start_infer: 0.4
|
61 |
+
K_step: 300
|
62 |
+
K_step_infer: 300
|
63 |
+
backbone_type: 'lynxnet'
|
64 |
+
backbone_args:
|
65 |
+
num_channels: 1024
|
66 |
+
num_layers: 6
|
67 |
+
kernel_size: 31
|
68 |
+
dropout_rate: 0.0
|
69 |
+
strong_cond: true
|
70 |
+
#backbone_type: 'wavenet'
|
71 |
+
#backbone_args:
|
72 |
+
# num_channels: 512
|
73 |
+
# num_layers: 20
|
74 |
+
# dilation_cycle_length: 4
|
75 |
+
shallow_diffusion_args:
|
76 |
+
train_aux_decoder: true
|
77 |
+
train_diffusion: true
|
78 |
+
val_gt_start: false
|
79 |
+
aux_decoder_arch: convnext
|
80 |
+
aux_decoder_args:
|
81 |
+
num_channels: 512
|
82 |
+
num_layers: 6
|
83 |
+
kernel_size: 7
|
84 |
+
dropout_rate: 0.1
|
85 |
+
aux_decoder_grad: 0.1
|
86 |
+
lambda_aux_mel_loss: 0.2
|
87 |
+
|
88 |
+
optimizer_args:
|
89 |
+
lr: 0.0006
|
90 |
+
lr_scheduler_args:
|
91 |
+
scheduler_cls: torch.optim.lr_scheduler.StepLR
|
92 |
+
step_size: 10000
|
93 |
+
gamma: 0.75
|
94 |
+
max_batch_frames: 50000
|
95 |
+
max_batch_size: 64
|
96 |
+
max_updates: 160000
|
97 |
+
|
98 |
+
num_valid_plots: 10
|
99 |
+
val_with_vocoder: true
|
100 |
+
val_check_interval: 2000
|
101 |
+
num_ckpt_keep: 5
|
102 |
+
permanent_ckpt_start: 120000
|
103 |
+
permanent_ckpt_interval: 20000
|
104 |
+
pl_trainer_devices: 'auto'
|
105 |
+
pl_trainer_precision: '16-mixed'
|
configs/templates/config_variance.yaml
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/variance.yaml
|
3 |
+
|
4 |
+
raw_data_dir:
|
5 |
+
- data/xxx1/raw
|
6 |
+
- data/xxx2/raw
|
7 |
+
speakers:
|
8 |
+
- speaker1
|
9 |
+
- speaker2
|
10 |
+
spk_ids: []
|
11 |
+
test_prefixes:
|
12 |
+
- wav1
|
13 |
+
- wav2
|
14 |
+
- wav3
|
15 |
+
- wav4
|
16 |
+
- wav5
|
17 |
+
dictionary: dictionaries/opencpop-extension.txt
|
18 |
+
binary_data_dir: data/xxx/binary
|
19 |
+
binarization_args:
|
20 |
+
num_workers: 0
|
21 |
+
|
22 |
+
pe: parselmouth
|
23 |
+
pe_ckpt: 'checkpoints/rmvpe/model.pt'
|
24 |
+
hnsep: vr
|
25 |
+
hnsep_ckpt: 'checkpoints/vr/model.pt'
|
26 |
+
|
27 |
+
use_spk_id: false
|
28 |
+
num_spk: 1
|
29 |
+
# NOTICE: before enabling variance modules, please read the docs at
|
30 |
+
# https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#mutual-influence-between-variance-modules
|
31 |
+
predict_dur: false
|
32 |
+
predict_pitch: false
|
33 |
+
# NOTICE: before enabling variance predictions, please read the docs at
|
34 |
+
# https://github.com/openvpi/DiffSinger/tree/main/docs/BestPractices.md#choosing-variance-parameters
|
35 |
+
predict_energy: false
|
36 |
+
predict_breathiness: false
|
37 |
+
predict_voicing: false
|
38 |
+
predict_tension: false
|
39 |
+
|
40 |
+
energy_db_min: -96.0
|
41 |
+
energy_db_max: -12.0
|
42 |
+
|
43 |
+
breathiness_db_min: -96.0
|
44 |
+
breathiness_db_max: -20.0
|
45 |
+
|
46 |
+
voicing_db_min: -96.0
|
47 |
+
voicing_db_max: -12.0
|
48 |
+
|
49 |
+
tension_logit_min: -10.0
|
50 |
+
tension_logit_max: 10.0
|
51 |
+
|
52 |
+
enc_ffn_kernel_size: 3
|
53 |
+
use_rope: true
|
54 |
+
hidden_size: 256
|
55 |
+
dur_prediction_args:
|
56 |
+
arch: fs2
|
57 |
+
hidden_size: 512
|
58 |
+
dropout: 0.1
|
59 |
+
num_layers: 5
|
60 |
+
kernel_size: 3
|
61 |
+
log_offset: 1.0
|
62 |
+
loss_type: mse
|
63 |
+
lambda_pdur_loss: 0.3
|
64 |
+
lambda_wdur_loss: 1.0
|
65 |
+
lambda_sdur_loss: 3.0
|
66 |
+
|
67 |
+
use_melody_encoder: false
|
68 |
+
melody_encoder_args:
|
69 |
+
hidden_size: 128
|
70 |
+
enc_layers: 4
|
71 |
+
use_glide_embed: false
|
72 |
+
glide_types: [up, down]
|
73 |
+
glide_embed_scale: 11.313708498984760 # sqrt(128)
|
74 |
+
|
75 |
+
diffusion_type: reflow
|
76 |
+
|
77 |
+
pitch_prediction_args:
|
78 |
+
pitd_norm_min: -8.0
|
79 |
+
pitd_norm_max: 8.0
|
80 |
+
pitd_clip_min: -12.0
|
81 |
+
pitd_clip_max: 12.0
|
82 |
+
repeat_bins: 64
|
83 |
+
backbone_type: 'wavenet'
|
84 |
+
backbone_args:
|
85 |
+
num_layers: 20
|
86 |
+
num_channels: 256
|
87 |
+
dilation_cycle_length: 5
|
88 |
+
# backbone_type: 'lynxnet'
|
89 |
+
# backbone_args:
|
90 |
+
# num_layers: 6
|
91 |
+
# num_channels: 512
|
92 |
+
# dropout_rate: 0.0
|
93 |
+
# strong_cond: true
|
94 |
+
|
95 |
+
variances_prediction_args:
|
96 |
+
total_repeat_bins: 48
|
97 |
+
backbone_type: 'wavenet'
|
98 |
+
backbone_args:
|
99 |
+
num_layers: 10
|
100 |
+
num_channels: 192
|
101 |
+
dilation_cycle_length: 4
|
102 |
+
# backbone_type: 'lynxnet'
|
103 |
+
# backbone_args:
|
104 |
+
# num_layers: 6
|
105 |
+
# num_channels: 384
|
106 |
+
# dropout_rate: 0.0
|
107 |
+
# strong_cond: true
|
108 |
+
|
109 |
+
lambda_dur_loss: 1.0
|
110 |
+
lambda_pitch_loss: 1.0
|
111 |
+
lambda_var_loss: 1.0
|
112 |
+
|
113 |
+
optimizer_args:
|
114 |
+
lr: 0.0006
|
115 |
+
lr_scheduler_args:
|
116 |
+
scheduler_cls: torch.optim.lr_scheduler.StepLR
|
117 |
+
step_size: 10000
|
118 |
+
gamma: 0.75
|
119 |
+
max_batch_frames: 80000
|
120 |
+
max_batch_size: 48
|
121 |
+
max_updates: 160000
|
122 |
+
|
123 |
+
num_valid_plots: 10
|
124 |
+
val_check_interval: 2000
|
125 |
+
num_ckpt_keep: 5
|
126 |
+
permanent_ckpt_start: 80000
|
127 |
+
permanent_ckpt_interval: 10000
|
128 |
+
pl_trainer_devices: 'auto'
|
129 |
+
pl_trainer_precision: '16-mixed'
|
deployment/.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.ds
|
2 |
+
*.onnx
|
3 |
+
*.npy
|
4 |
+
*.wav
|
5 |
+
temp/
|
6 |
+
cache/
|
7 |
+
assets/
|
deployment/__init__.py
ADDED
File without changes
|
deployment/benchmarks/infer_acoustic.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import onnxruntime as ort
|
3 |
+
import tqdm
|
4 |
+
|
5 |
+
# ---------------------------------------------------------------------------
# Benchmark: run two exported acoustic ONNX models back to back on identical
# dummy inputs and let tqdm report the per-run throughput of each.
# ---------------------------------------------------------------------------

# Benchmark settings
n_tokens = 10
n_frames = 100
n_runs = 20
speedup = 20
provider = 'DmlExecutionProvider'

# Dummy inputs: constant token id 1, frames split evenly across the tokens,
# and a flat 440 Hz pitch curve.
tokens = np.full((1, n_tokens), 1, dtype=np.int64)
durations = np.full((1, n_tokens), n_frames // n_tokens, dtype=np.int64)
f0 = np.full((1, n_frames), 440., dtype=np.float32)
speedup = np.array(speedup, dtype=np.int64)  # fed to the graph as a 0-d int64 array

inputs = {
    'tokens': tokens,
    'durations': durations,
    'f0': f0,
    'speedup': speedup,
}

# Both models receive exactly the same feed so their timings are comparable.
for model_path in ('model1.onnx', 'model2.onnx'):
    session = ort.InferenceSession(model_path, providers=[provider])
    for _ in tqdm.tqdm(range(n_runs)):
        session.run(['mel'], inputs)
|
deployment/benchmarks/infer_nsf_hifigan.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import onnxruntime as ort
|
3 |
+
import tqdm
|
4 |
+
|
5 |
+
# ---------------------------------------------------------------------------
# Benchmark: repeatedly run an exported NSF-HiFiGAN ONNX vocoder on random
# inputs and let tqdm report the per-run throughput.
# ---------------------------------------------------------------------------

# Benchmark settings
n_frames = 1000
n_runs = 20

# Random mel-spectrogram (128 bins per frame) and a pitch curve jittered
# around 440; both are cast to float32 before being fed to the session.
mel = np.random.randn(1, n_frames, 128).astype(np.float32)
f0 = np.random.randn(1, n_frames).astype(np.float32) + 440.
provider = 'DmlExecutionProvider'

feeds = {
    'mel': mel,
    'f0': f0,
}

session = ort.InferenceSession('nsf_hifigan.onnx', providers=[provider])
for _ in tqdm.tqdm(range(n_runs)):
    session.run(['waveform'], feeds)
|
deployment/exporters/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .acoustic_exporter import DiffSingerAcousticExporter
|
2 |
+
from .variance_exporter import DiffSingerVarianceExporter
|
3 |
+
from .nsf_hifigan_exporter import NSFHiFiGANExporter
|
deployment/exporters/acoustic_exporter.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import List, Union, Tuple, Dict
|
4 |
+
|
5 |
+
import onnx
|
6 |
+
import onnxsim
|
7 |
+
import torch
|
8 |
+
import yaml
|
9 |
+
|
10 |
+
from basics.base_exporter import BaseExporter
|
11 |
+
from deployment.modules.toplevel import DiffSingerAcousticONNX
|
12 |
+
from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
|
13 |
+
from utils import load_ckpt, onnx_helper, remove_suffix
|
14 |
+
from utils.hparams import hparams
|
15 |
+
from utils.phoneme_utils import locate_dictionary, build_phoneme_list
|
16 |
+
from utils.text_encoder import TokenTextEncoder
|
17 |
+
|
18 |
+
|
19 |
+
class DiffSingerAcousticExporter(BaseExporter):
    """Export a DiffSinger acoustic checkpoint as one merged ONNX model.

    The export runs in three stages:

    1. The FastSpeech2 part (plus the aux decoder when shallow diffusion is
       enabled) and the diffusion part are exported to two separate ONNX
       files inside ``cache_dir``.
    2. Each graph is simplified with onnxsim / ``onnx_helper`` utilities.
    3. The two graphs are merged, simplified once more and saved.

    Besides the model, :meth:`export_attachments` writes speaker embeddings,
    the dictionary, the phoneme list and a ``dsconfig.yaml`` description.
    """

    def __init__(
            self,
            device: Union[str, torch.device] = 'cpu',
            cache_dir: Path = None,
            ckpt_steps: int = None,
            freeze_gender: float = None,
            freeze_velocity: bool = False,
            export_spk: List[Tuple[str, Dict[str, float]]] = None,
            freeze_spk: Tuple[str, Dict[str, float]] = None
    ):
        """Load the checkpoint and decide which inputs are exposed or frozen.

        :param device: device the model is loaded on and traced with
        :param cache_dir: directory for the intermediate ONNX files
        :param ckpt_steps: forwarded to ``load_ckpt`` to select a checkpoint
        :param freeze_gender: ``None`` exposes the ``gender`` input; a number
            bakes a fixed key shift into the model instead (only when
            ``hparams['use_key_shift_embed']`` is on)
        :param freeze_velocity: ``False`` exposes the ``velocity`` input
        :param export_spk: ``(name, {speaker: proportion})`` mixes exported as
            separate embedding files; ignored unless ``hparams['use_spk_id']``
        :param freeze_spk: a single ``(name, mix)`` baked into the model;
            ignored unless ``hparams['use_spk_id']``
        """
        super().__init__(device=device, cache_dir=cache_dir)
        # Basic attributes
        self.model_name: str = hparams['exp_name']
        self.ckpt_steps: int = ckpt_steps
        self.spk_map: dict = self.build_spk_map()
        self.vocab = TokenTextEncoder(vocab_list=build_phoneme_list())
        self.model = self.build_model()
        self.fs2_aux_cache_path = self.cache_dir / (
            'fs2_aux.onnx' if self.model.use_shallow_diffusion else 'fs2.onnx'
        )
        self.diffusion_cache_path = self.cache_dir / 'diffusion.onnx'

        # Attributes for logging
        self.model_class_name = remove_suffix(self.model.__class__.__name__, 'ONNX')
        self.aux_decoder_class_name = remove_suffix(
            self.model.aux_decoder.decoder.__class__.__name__, 'ONNX'
        ) if self.model.use_shallow_diffusion else None
        fs2_aux_cls_logging = [remove_suffix(self.model.fs2.__class__.__name__, 'ONNX')]
        if self.model.use_shallow_diffusion:
            fs2_aux_cls_logging.append(self.aux_decoder_class_name)
        self.fs2_aux_class_name = ', '.join(fs2_aux_cls_logging)
        self.backbone_class_name = remove_suffix(self.model.diffusion.backbone.__class__.__name__, 'ONNX')
        self.diffusion_class_name = remove_suffix(self.model.diffusion.__class__.__name__, 'ONNX')

        # Attributes for exporting
        self.expose_gender = freeze_gender is None
        self.expose_velocity = not freeze_velocity
        self.freeze_spk: Tuple[str, Dict[str, float]] = freeze_spk \
            if hparams['use_spk_id'] else None
        self.export_spk: List[Tuple[str, Dict[str, float]]] = export_spk \
            if hparams['use_spk_id'] and export_spk is not None else []
        if hparams['use_key_shift_embed'] and not self.expose_gender:
            # Map the frozen gender value onto the pitch shifting range used
            # for augmentation: non-negative values scale the upper bound,
            # negative values scale the magnitude of the lower bound.
            shift_min, shift_max = hparams['augmentation_args']['random_pitch_shifting']['range']
            key_shift = freeze_gender * shift_max if freeze_gender >= 0. else freeze_gender * abs(shift_min)
            key_shift = max(min(key_shift, shift_max), shift_min)  # clip key shift
            self.model.fs2.register_buffer('frozen_key_shift', torch.FloatTensor([key_shift]).to(self.device))
        if hparams['use_spk_id']:
            if not self.export_spk and self.freeze_spk is None:
                # In case the user did not specify any speaker settings:
                if len(self.spk_map) == 1:
                    # If there is only one speaker, freeze him/her.
                    first_spk = next(iter(self.spk_map.keys()))
                    self.freeze_spk = (first_spk, {first_spk: 1.0})
                else:
                    # If there are multiple speakers, export them all.
                    self.export_spk = [(name, {name: 1.0}) for name in self.spk_map.keys()]
            if self.freeze_spk is not None:
                self.model.fs2.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1]))

    def build_model(self) -> DiffSingerAcousticONNX:
        """Instantiate the ONNX-ready model in eval mode and load its weights."""
        model = DiffSingerAcousticONNX(
            vocab_size=len(self.vocab),
            out_dims=hparams['audio_num_mel_bins']
        ).eval().to(self.device)
        load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps,
                  prefix_in_ckpt='model', strict=True, device=self.device)
        return model

    def _exported_model_name(self) -> str:
        """Return the base name of the exported model file.

        The name of the frozen speaker, if any, is appended to the experiment
        name so the file name reflects which speaker is baked in.
        """
        if self.freeze_spk is None:
            return self.model_name
        return self.model_name + '.' + self.freeze_spk[0]

    def export(self, path: Path):
        """Export the model and all attachments into directory ``path``."""
        path.mkdir(parents=True, exist_ok=True)
        self.export_model(path / f'{self._exported_model_name()}.onnx')
        self.export_attachments(path)

    def export_model(self, path: Path):
        """Export, optimize and merge both sub-graphs, then save to ``path``.

        The intermediate ONNX files in the cache directory are removed
        afterwards, including when one of the steps raises.
        """
        try:
            self._torch_export_model()
            fs2_aux_onnx = self._optimize_fs2_aux_graph(onnx.load(self.fs2_aux_cache_path))
            diffusion_onnx = self._optimize_diffusion_graph(onnx.load(self.diffusion_cache_path))
            model_onnx = self._merge_fs2_aux_diffusion_graphs(fs2_aux_onnx, diffusion_onnx)
            onnx.save(model_onnx, path)
        finally:
            # A failure may happen before either file has been written, so a
            # missing file must not mask the original exception.
            self.fs2_aux_cache_path.unlink(missing_ok=True)
            self.diffusion_cache_path.unlink(missing_ok=True)
        print(f'| export model => {path}')

    def export_attachments(self, path: Path):
        """Write speaker embeddings, dictionary, phonemes and ``dsconfig.yaml``.

        The ``use_key_shift_embed`` / ``use_speed_embed`` entries of the
        config mirror the conditions under which ``_torch_export_model`` adds
        the ``gender`` / ``velocity`` inputs, so the config never advertises
        an input that the exported graph does not have.
        """
        for spk in self.export_spk:
            self._export_spk_embed(
                path / f'{self.model_name}.{spk[0]}.emb',
                self._perform_spk_mix(spk[1])
            )
        self._export_dictionary(path / 'dictionary.txt')
        self._export_phonemes(path / f'{self.model_name}.phonemes.txt')

        # An input exists in the graph only if the model was trained with the
        # embedding AND the user chose to expose it.
        use_key_shift = bool(hparams['use_key_shift_embed']) and self.expose_gender
        use_speed = bool(hparams['use_speed_embed']) and self.expose_velocity

        dsconfig = {
            # basic configs
            'phonemes': f'{self.model_name}.phonemes.txt',
            'acoustic': f'{self._exported_model_name()}.onnx',
            'hidden_size': hparams['hidden_size'],
            'vocoder': 'nsf_hifigan_44.1k_hop512_128bin_2024.02',
        }
        # multi-speaker
        if len(self.export_spk) > 0:
            dsconfig['speakers'] = [f'{self.model_name}.{spk[0]}' for spk in self.export_spk]
        # parameters
        if use_key_shift:
            dsconfig['augmentation_args'] = {
                'random_pitch_shifting': {
                    'range': hparams['augmentation_args']['random_pitch_shifting']['range']
                }
            }
        dsconfig['use_key_shift_embed'] = use_key_shift
        dsconfig['use_speed_embed'] = use_speed
        for variance in VARIANCE_CHECKLIST:
            dsconfig[f'use_{variance}_embed'] = (variance in self.model.fs2.variance_embed_list)
        # sampling acceleration and shallow diffusion
        dsconfig['use_continuous_acceleration'] = True
        dsconfig['use_variable_depth'] = self.model.use_shallow_diffusion
        dsconfig['max_depth'] = 1 - self.model.diffusion.t_start
        # mel specification
        dsconfig['sample_rate'] = hparams['audio_sample_rate']
        dsconfig['hop_size'] = hparams['hop_size']
        dsconfig['win_size'] = hparams['win_size']
        dsconfig['fft_size'] = hparams['fft_size']
        dsconfig['num_mel_bins'] = hparams['audio_num_mel_bins']
        dsconfig['mel_fmin'] = hparams['fmin']
        # Fall back to half the sample rate when no upper bound is configured.
        dsconfig['mel_fmax'] = hparams['fmax'] if hparams['fmax'] is not None else hparams['audio_sample_rate'] / 2
        dsconfig['mel_base'] = 'e'
        dsconfig['mel_scale'] = 'slaney'
        config_path = path / 'dsconfig.yaml'
        with open(config_path, 'w', encoding='utf8') as fw:
            yaml.safe_dump(dsconfig, fw, sort_keys=False)
        print(f'| export configs => {config_path} **PLEASE EDIT BEFORE USE**')

    @torch.no_grad()
    def _torch_export_model(self):
        """Export both sub-graphs to the cache directory with dummy inputs.

        The FastSpeech2 (+ aux decoder) view is exported directly. For the
        diffusion part the backbone is traced first, the surrounding sampler
        is scripted, and the scripted module is then exported.
        """
        # Prepare inputs for FastSpeech2 and aux decoder tracing
        n_frames = 10
        tokens = torch.LongTensor([[1]]).to(self.device)
        durations = torch.LongTensor([[n_frames]]).to(self.device)
        f0 = torch.FloatTensor([[440.] * n_frames]).to(self.device)
        variances = {
            v_name: torch.zeros(1, n_frames, dtype=torch.float32, device=self.device)
            for v_name in self.model.fs2.variance_embed_list
        }
        # ``kwargs`` is filled in below; the tuple keeps a reference to the
        # same dict, so later additions are seen by the export call.
        kwargs: Dict[str, torch.Tensor] = {}
        arguments = (tokens, durations, f0, variances, kwargs)
        input_names = ['tokens', 'durations', 'f0'] + self.model.fs2.variance_embed_list
        dynamic_axes = {
            'tokens': {1: 'n_tokens'},
            'durations': {1: 'n_tokens'},
            'f0': {1: 'n_frames'},
            **{
                v_name: {1: 'n_frames'}
                for v_name in self.model.fs2.variance_embed_list
            }
        }
        if hparams['use_key_shift_embed'] and self.expose_gender:
            kwargs['gender'] = torch.rand((1, n_frames), dtype=torch.float32, device=self.device)
            input_names.append('gender')
            dynamic_axes['gender'] = {1: 'n_frames'}
        if hparams['use_speed_embed'] and self.expose_velocity:
            kwargs['velocity'] = torch.rand((1, n_frames), dtype=torch.float32, device=self.device)
            input_names.append('velocity')
            dynamic_axes['velocity'] = {1: 'n_frames'}
        if hparams['use_spk_id'] and not self.freeze_spk:
            kwargs['spk_embed'] = torch.rand(
                (1, n_frames, hparams['hidden_size']),
                dtype=torch.float32, device=self.device
            )
            input_names.append('spk_embed')
            dynamic_axes['spk_embed'] = {1: 'n_frames'}
        dynamic_axes['condition'] = {1: 'n_frames'}

        # PyTorch ONNX export for FastSpeech2 and aux decoder
        output_names = ['condition']
        if self.model.use_shallow_diffusion:
            output_names.append('aux_mel')
            dynamic_axes['aux_mel'] = {1: 'n_frames'}
        print(f'Exporting {self.fs2_aux_class_name}...')
        torch.onnx.export(
            self.model.view_as_fs2_aux(),
            arguments,
            self.fs2_aux_cache_path,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=15
        )

        condition = torch.rand((1, n_frames, hparams['hidden_size']), device=self.device)

        # Prepare inputs for backbone tracing and GaussianDiffusion scripting
        shape = (1, 1, hparams['audio_num_mel_bins'], n_frames)
        noise = torch.randn(shape, device=self.device)
        x_aux = torch.randn((1, n_frames, hparams['audio_num_mel_bins']), device=self.device)
        dummy_time = (torch.rand((1,), device=self.device) * self.model.diffusion.time_scale_factor).float()
        dummy_depth = torch.tensor(0.1, device=self.device)
        dummy_steps = 5

        print(f'Tracing {self.backbone_class_name} backbone...')
        if self.model.diffusion_type == 'ddpm':
            major_mel_decoder = self.model.view_as_diffusion()
        elif self.model.diffusion_type == 'reflow':
            major_mel_decoder = self.model.view_as_reflow()
        else:
            raise ValueError(f'Invalid diffusion type: {self.model.diffusion_type}')
        major_mel_decoder.diffusion.set_backbone(
            torch.jit.trace(
                major_mel_decoder.diffusion.backbone,
                (
                    noise,
                    dummy_time,
                    condition.transpose(1, 2)
                )
            )
        )

        print(f'Scripting {self.diffusion_class_name}...')
        # The aux mel and depth inputs only exist with shallow diffusion.
        diffusion_inputs = [
            condition,
            *([x_aux, dummy_depth] if self.model.use_shallow_diffusion else [])
        ]
        major_mel_decoder = torch.jit.script(
            major_mel_decoder,
            example_inputs=[
                (
                    *diffusion_inputs,
                    1  # p_sample branch
                ),
                (
                    *diffusion_inputs,
                    dummy_steps  # p_sample_plms branch
                )
            ]
        )

        # PyTorch ONNX export for GaussianDiffusion
        print(f'Exporting {self.diffusion_class_name}...')
        torch.onnx.export(
            major_mel_decoder,
            (
                *diffusion_inputs,
                dummy_steps
            ),
            self.diffusion_cache_path,
            input_names=[
                'condition',
                *(['x_aux', 'depth'] if self.model.use_shallow_diffusion else []),
                'steps'
            ],
            output_names=[
                'mel'
            ],
            dynamic_axes={
                'condition': {1: 'n_frames'},
                **({'x_aux': {1: 'n_frames'}} if self.model.use_shallow_diffusion else {}),
                'mel': {1: 'n_frames'}
            },
            opset_version=15
        )

    @torch.no_grad()
    def _perform_spk_mix(self, spk_mix: Dict[str, float]):
        """Blend speaker embeddings into one embedding of shape ``[1, H]``.

        :param spk_mix: speaker name -> non-negative proportion; proportions
            are normalized to sum to one before mixing
        :raises KeyError: if a speaker name is not in ``self.spk_map``
        :raises AssertionError: if a proportion is negative or all are zero
        """
        spk_mix_ids = []
        spk_mix_values = []
        for name, value in spk_mix.items():
            spk_mix_ids.append(self.spk_map[name])
            assert value >= 0., f'Speaker mix checks failed.\n' \
                                f'Proportion of speaker \'{name}\' is negative.'
            spk_mix_values.append(value)
        spk_mix_id_N = torch.LongTensor(spk_mix_ids).to(self.device)[None]  # => [1, N]
        spk_mix_value_N = torch.FloatTensor(spk_mix_values).to(self.device)[None]  # => [1, N]
        spk_mix_value_sum = spk_mix_value_N.sum()
        assert spk_mix_value_sum > 0., f'Speaker mix checks failed.\n' \
                                       f'Proportions of speaker mix sum to zero.'
        spk_mix_value_N /= spk_mix_value_sum  # normalize
        spk_mix_embed = torch.sum(
            self.model.fs2.spk_embed(spk_mix_id_N) * spk_mix_value_N.unsqueeze(2),  # => [1, N, H]
            dim=1, keepdim=False
        )  # => [1, H]
        return spk_mix_embed

    def _optimize_fs2_aux_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto:
        """Simplify the FastSpeech2 (+ aux decoder) graph with onnxsim."""
        print(f'Running ONNX Simplifier on {self.fs2_aux_class_name}...')
        fs2, check = onnxsim.simplify(fs2, include_subgraph=True)
        assert check, 'Simplified ONNX model could not be validated'
        print(f'| optimize graph: {self.fs2_aux_class_name}')
        return fs2

    def _optimize_diffusion_graph(self, diffusion: onnx.ModelProto) -> onnx.ModelProto:
        """Simplify the diffusion graph in two onnxsim passes.

        Between the passes the graph is rewritten with ``onnx_helper``
        utilities; the second pass simplifies the result of those rewrites.
        """
        onnx_helper.model_override_io_shapes(diffusion, output_shapes={
            'mel': (1, 'n_frames', hparams['audio_num_mel_bins'])
        })
        print(f'Running ONNX Simplifier #1 on {self.diffusion_class_name}...')
        diffusion, check = onnxsim.simplify(diffusion, include_subgraph=True)
        assert check, 'Simplified ONNX model could not be validated'
        onnx_helper.graph_fold_back_to_squeeze(diffusion.graph)
        onnx_helper.graph_extract_conditioner_projections(
            graph=diffusion.graph, op_type='Conv',
            weight_pattern=r'diffusion\..*\.conditioner_projection\.weight',
            alias_prefix='/diffusion/backbone/cache'
        )
        onnx_helper.graph_remove_unused_values(diffusion.graph)
        print(f'Running ONNX Simplifier #2 on {self.diffusion_class_name}...')
        diffusion, check = onnxsim.simplify(
            diffusion,
            include_subgraph=True
        )
        assert check, 'Simplified ONNX model could not be validated'
        print(f'| optimize graph: {self.diffusion_class_name}')
        return diffusion

    def _merge_fs2_aux_diffusion_graphs(self, fs2: onnx.ModelProto, diffusion: onnx.ModelProto) -> onnx.ModelProto:
        """Merge both graphs into one model and simplify the result.

        ``condition`` (and ``aux_mel`` -> ``x_aux`` with shallow diffusion) is
        wired from the first graph into the second. Dimension names are
        prefixed per sub-graph, except for the shared ``n_tokens`` /
        ``n_frames`` names that match the ignored patterns.
        """
        onnx_helper.model_add_prefixes(
            fs2, dim_prefix=('fs2aux.' if self.model.use_shallow_diffusion else 'fs2.'),
            ignored_pattern=r'(n_tokens)|(n_frames)'
        )
        onnx_helper.model_add_prefixes(diffusion, dim_prefix='diffusion.', ignored_pattern='n_frames')
        print(f'Merging {self.fs2_aux_class_name} and {self.diffusion_class_name} '
              f'back into {self.model_class_name}...')
        merged = onnx.compose.merge_models(
            fs2, diffusion, io_map=[
                ('condition', 'condition'),
                *([('aux_mel', 'x_aux')] if self.model.use_shallow_diffusion else []),
            ],
            prefix1='', prefix2='', doc_string='',
            producer_name=fs2.producer_name, producer_version=fs2.producer_version,
            domain=fs2.domain, model_version=fs2.model_version
        )
        merged.graph.name = fs2.graph.name

        print(f'Running ONNX Simplifier on {self.model_class_name}...')
        merged, check = onnxsim.simplify(
            merged,
            include_subgraph=True
        )
        assert check, 'Simplified ONNX model could not be validated'
        print(f'| optimize graph: {self.model_class_name}')

        return merged

    # noinspection PyMethodMayBeStatic
    def _export_spk_embed(self, path: Path, spk_embed: torch.Tensor):
        """Write the raw bytes of ``spk_embed`` (as a NumPy array) to ``path``."""
        with open(path, 'wb') as f:
            f.write(spk_embed.cpu().numpy().tobytes())
        print(f'| export spk embed => {path}')

    # noinspection PyMethodMayBeStatic
    def _export_dictionary(self, path: Path):
        """Copy the dictionary located by ``locate_dictionary()`` to ``path``."""
        shutil.copy(locate_dictionary(), path)
        # Report only after the copy succeeded, like the other export helpers.
        print(f'| export dictionary => {path}')

    def _export_phonemes(self, path: Path):
        """Store the phoneme vocabulary to ``path``."""
        self.vocab.store_to_file(path)
        print(f'| export phonemes => {path}')
|
deployment/exporters/nsf_hifigan_exporter.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
import onnx
|
6 |
+
import onnxsim
|
7 |
+
import torch
|
8 |
+
import yaml
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from basics.base_exporter import BaseExporter
|
12 |
+
from deployment.modules.nsf_hifigan import NSFHiFiGANONNX
|
13 |
+
from utils import load_ckpt, remove_suffix
|
14 |
+
from utils.hparams import hparams
|
15 |
+
|
16 |
+
|
17 |
+
class NSFHiFiGANExporter(BaseExporter):
    """Exporter that converts an NSF-HiFiGAN vocoder checkpoint to ONNX.

    The model is built from a checkpoint plus the ``config.json`` that sits
    next to it, exported with ``torch.onnx.export`` into a cache file,
    simplified with ``onnxsim`` and finally saved to the target directory
    together with a ``vocoder.yaml`` description file.
    """

    def __init__(
            self,
            device: Union[str, torch.device] = 'cpu',
            cache_dir: Path = None,
            model_path: Path = None,
            model_name: str = 'nsf_hifigan'
    ):
        """
        :param device: device on which the model is built and traced
        :param cache_dir: forwarded to ``BaseExporter``; ``self.cache_dir`` is
            used afterwards as the directory of the intermediate ONNX file
        :param model_path: path of the vocoder checkpoint; ``config.json`` is
            looked up in the same directory
        :param model_name: stem of the exported file and the ``name`` entry of
            ``vocoder.yaml``
        """
        super().__init__(device=device, cache_dir=cache_dir)
        self.model_path = model_path
        self.model_name = model_name
        self.model = self.build_model()
        self.model_class_name = remove_suffix(self.model.__class__.__name__, 'ONNX')
        # Append the extension rather than calling Path.with_suffix(): the
        # latter replaces everything after the last dot, so a name such as
        # 'nsf_hifigan_44.1k' would silently become 'nsf_hifigan_44.onnx'.
        self.model_cache_path = self.cache_dir / f'{self.model_name}.onnx'

    def build_model(self) -> nn.Module:
        """Build the ONNX-ready vocoder and load the generator weights.

        :return: the model in eval mode on ``self.device`` with weight norm
            removed from its generator
        :raises AssertionError: if ``hparams['mel_base']`` is not ``'e'``
        """
        config_path = self.model_path.with_name('config.json')
        with open(config_path, 'r', encoding='utf8') as f:
            config = json.load(f)
        assert hparams.get('mel_base') == 'e', (
            "Mel base must be set to \'e\' according to 2nd stage of the migration plan. "
            "See https://github.com/openvpi/DiffSinger/releases/tag/v2.3.0 for more details."
        )
        model = NSFHiFiGANONNX(config).eval().to(self.device)
        load_ckpt(model.generator, str(self.model_path),
                  prefix_in_ckpt=None, key_in_ckpt='generator',
                  strict=True, device=self.device)
        model.generator.remove_weight_norm()
        return model

    def export(self, path: Path):
        """Export the model and its attachments into directory ``path``.

        The directory is created when missing.
        """
        path.mkdir(parents=True, exist_ok=True)
        self.export_model(path / self.model_cache_path.name)
        self.export_attachments(path)

    def export_model(self, path: Path):
        """Export, simplify and save the ONNX model to file ``path``.

        The intermediate cache file is removed even when one of the steps
        raises, so a failed export does not leave a stale file behind.
        """
        try:
            self._torch_export_model()
            model_onnx = self._optimize_model_graph(onnx.load(self.model_cache_path))
            onnx.save(model_onnx, path)
        finally:
            # The cache file may not exist if the torch export itself failed.
            if self.model_cache_path.exists():
                self.model_cache_path.unlink()
        print(f'| export model => {path}')

    def export_attachments(self, path: Path):
        """Write ``vocoder.yaml`` describing the exported model into ``path``.

        The mel specification values are read from ``hparams``; ``mel_fmax``
        falls back to half of the sample rate when ``hparams['fmax']`` is None.
        """
        config_path = path / 'vocoder.yaml'
        with open(config_path, 'w', encoding='utf8') as fw:
            yaml.safe_dump({
                # basic configs
                'name': self.model_name,
                'model': self.model_cache_path.name,
                # mel specifications
                'sample_rate': hparams['audio_sample_rate'],
                'hop_size': hparams['hop_size'],
                'win_size': hparams['win_size'],
                'fft_size': hparams['fft_size'],
                'num_mel_bins': hparams['audio_num_mel_bins'],
                'mel_fmin': hparams['fmin'],
                'mel_fmax': hparams['fmax'] if hparams['fmax'] is not None else hparams['audio_sample_rate'] / 2,
                'mel_base': 'e',
                'mel_scale': 'slaney',
            }, fw, sort_keys=False)
        print(f'| export configs => {config_path} **PLEASE EDIT BEFORE USE**')

    @torch.no_grad()
    def _torch_export_model(self):
        """Run ``torch.onnx.export`` on the model into ``self.model_cache_path``.

        Dummy inputs of 10 frames are used; the frame axis of both inputs and
        the sample axis of the output are declared dynamic.
        """
        # Prepare inputs for NSFHiFiGAN
        n_frames = 10
        mel = torch.randn((1, n_frames, hparams['audio_num_mel_bins']), dtype=torch.float32, device=self.device)
        # Random values centred on 440 serve as the dummy f0 curve.
        f0 = torch.randn((1, n_frames), dtype=torch.float32, device=self.device) + 440.

        # PyTorch ONNX export for NSFHiFiGAN
        print(f'Exporting {self.model_class_name}...')
        torch.onnx.export(
            self.model,
            (
                mel,
                f0
            ),
            self.model_cache_path,
            input_names=[
                'mel',
                'f0'
            ],
            output_names=[
                'waveform'
            ],
            dynamic_axes={
                'mel': {
                    1: 'n_frames'
                },
                'f0': {
                    1: 'n_frames'
                },
                'waveform': {
                    1: 'n_samples'
                }
            },
            opset_version=15
        )

    def _optimize_model_graph(self, model: onnx.ModelProto) -> onnx.ModelProto:
        """Simplify ``model`` with onnxsim and return the simplified graph.

        :raises AssertionError: if onnxsim reports that validation failed
        """
        print(f'Running ONNX simplifier for {self.model_class_name}...')
        model, check = onnxsim.simplify(model, include_subgraph=True)
        assert check, 'Simplified ONNX model could not be validated'
        return model
|
deployment/exporters/variance_exporter.py
ADDED
@@ -0,0 +1,781 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Union, List, Tuple, Dict
|
4 |
+
|
5 |
+
import onnx
|
6 |
+
import onnxsim
|
7 |
+
import torch
|
8 |
+
import yaml
|
9 |
+
|
10 |
+
from basics.base_exporter import BaseExporter
|
11 |
+
from deployment.modules.toplevel import DiffSingerVarianceONNX
|
12 |
+
from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
|
13 |
+
from utils import load_ckpt, onnx_helper, remove_suffix
|
14 |
+
from utils.hparams import hparams
|
15 |
+
from utils.phoneme_utils import locate_dictionary, build_phoneme_list
|
16 |
+
from utils.text_encoder import TokenTextEncoder
|
17 |
+
|
18 |
+
|
19 |
+
class DiffSingerVarianceExporter(BaseExporter):
|
20 |
+
def __init__(
        self,
        device: Union[str, torch.device] = 'cpu',
        cache_dir: Path = None,
        ckpt_steps: int = None,
        freeze_glide: bool = False,
        freeze_expr: bool = False,
        export_spk: List[Tuple[str, Dict[str, float]]] = None,
        freeze_spk: Tuple[str, Dict[str, float]] = None
):
    """Build the variance model and decide which speakers to freeze/export.

    :param device: device on which the model is built and traced
    :param cache_dir: forwarded to ``BaseExporter``; intermediate ONNX files
        are placed under ``self.cache_dir``
    :param ckpt_steps: forwarded to ``load_ckpt`` when loading the checkpoint
    :param freeze_glide: when True, no ``note_glide`` input is exported
    :param freeze_expr: when True, no ``expr`` input is exported
    :param export_spk: ``(name, mix)`` pairs to export as speaker embeddings
    :param freeze_spk: ``(name, mix)`` pair whose mixed embedding is
        registered on the model as buffer ``frozen_spk_embed``
    """
    super().__init__(device=device, cache_dir=cache_dir)

    # --- core state; the vocabulary must exist before build_model() runs ---
    self.model_name: str = hparams['exp_name']
    self.ckpt_steps: int = ckpt_steps
    self.spk_map: dict = self.build_spk_map()
    self.vocab = TokenTextEncoder(vocab_list=build_phoneme_list())
    self.model = self.build_model()

    # --- locations of the intermediate ONNX files ---
    cache_root = self.cache_dir
    self.linguistic_encoder_cache_path = cache_root / 'linguistic.onnx'
    self.dur_predictor_cache_path = cache_root / 'dur.onnx'
    self.pitch_preprocess_cache_path = cache_root / 'pitch_pre.onnx'
    self.pitch_predictor_cache_path = cache_root / 'pitch.onnx'
    self.pitch_postprocess_cache_path = cache_root / 'pitch_post.onnx'
    self.variance_preprocess_cache_path = cache_root / 'variance_pre.onnx'
    self.multi_var_predictor_cache_path = cache_root / 'variance.onnx'
    self.variance_postprocess_cache_path = cache_root / 'variance_post.onnx'

    # --- class names used in log messages (None for disabled predictors) ---
    def _label(module):
        return remove_suffix(module.__class__.__name__, 'ONNX')

    net = self.model
    self.fs2_class_name = _label(net.fs2)

    self.dur_predictor_class_name = None
    if net.predict_dur:
        self.dur_predictor_class_name = _label(net.fs2.dur_predictor)

    self.pitch_backbone_class_name = None
    self.pitch_predictor_class_name = None
    if net.predict_pitch:
        self.pitch_backbone_class_name = _label(net.pitch_predictor.backbone)
        self.pitch_predictor_class_name = _label(net.pitch_predictor)

    self.variance_backbone_class_name = None
    self.multi_var_predictor_class_name = None
    if net.predict_variances:
        self.variance_backbone_class_name = _label(net.variance_predictor.backbone)
        self.multi_var_predictor_class_name = _label(net.variance_predictor)

    # --- export switches ---
    self.expose_expr = not freeze_expr
    self.freeze_glide = freeze_glide

    if not hparams['use_spk_id']:
        # Speaker options are ignored for models without speaker IDs.
        self.freeze_spk: Tuple[str, Dict[str, float]] = None
        self.export_spk: List[Tuple[str, Dict[str, float]]] = []
        return

    self.freeze_spk: Tuple[str, Dict[str, float]] = freeze_spk
    self.export_spk: List[Tuple[str, Dict[str, float]]] = [] if export_spk is None else export_spk

    if self.freeze_spk is None and not self.export_spk:
        # No speaker settings given: freeze the only speaker, or export
        # every speaker when there are several.
        speaker_names = list(self.spk_map.keys())
        if len(speaker_names) == 1:
            sole_speaker = speaker_names[0]
            self.freeze_spk = (sole_speaker, {sole_speaker: 1.0})
        else:
            self.export_spk = [(spk_name, {spk_name: 1.0}) for spk_name in speaker_names]

    if self.freeze_spk is not None:
        frozen_embed = self._perform_spk_mix(self.freeze_spk[1])
        self.model.register_buffer('frozen_spk_embed', frozen_embed)
|
83 |
+
|
84 |
+
def build_model(self) -> DiffSingerVarianceONNX:
    """Create the ONNX-ready variance model and load its checkpoint.

    The model is sized by ``len(self.vocab)``, put in eval mode on
    ``self.device``, loaded from ``hparams['work_dir']`` and has its
    smoothing operator built before being returned.
    """
    net = DiffSingerVarianceONNX(vocab_size=len(self.vocab))
    net = net.eval()
    net = net.to(self.device)
    load_ckpt(
        net,
        hparams['work_dir'],
        ckpt_steps=self.ckpt_steps,
        prefix_in_ckpt='model',
        strict=True,
        device=self.device,
    )
    net.build_smooth_op(self.device)
    return net
|
92 |
+
|
93 |
+
def export(self, path: Path):
    """Export all ONNX models and attachments into directory ``path``.

    The directory is created when missing. When a speaker is frozen, its
    name is appended to the model name used for the ONNX files.
    """
    path.mkdir(parents=True, exist_ok=True)
    frozen = self.freeze_spk
    if frozen is None:
        target_name = self.model_name
    else:
        target_name = self.model_name + '.' + frozen[0]
    self.export_model(path, target_name)
    self.export_attachments(path)
|
100 |
+
|
101 |
+
def export_model(self, path: Path, model_name: str = None):
    """Export, optimize and save every enabled sub-model into ``path``.

    The linguistic encoder is always exported; the duration, pitch and
    variance predictors are exported when the corresponding
    ``self.model.predict_*`` flag is set. Intermediate cache files are
    deleted once their optimized counterpart has been saved.

    :param path: target directory
    :param model_name: stem of the exported file names. When omitted, the
        name used by ``export()`` is taken (``self.model_name`` plus the
        frozen speaker name, if any) so that the files match the names
        written to ``dsconfig.yaml`` instead of being called ``None.*``.
    """
    if model_name is None:
        model_name = self.model_name
        if self.freeze_spk is not None:
            model_name += '.' + self.freeze_spk[0]

    self._torch_export_model()

    # Linguistic encoder (always present)
    linguistic_onnx = self._optimize_linguistic_graph(onnx.load(self.linguistic_encoder_cache_path))
    linguistic_path = path / f'{model_name}.linguistic.onnx'
    onnx.save(linguistic_onnx, linguistic_path)
    print(f'| export linguistic encoder => {linguistic_path}')
    self.linguistic_encoder_cache_path.unlink()

    # Duration predictor
    if self.model.predict_dur:
        dur_predictor_onnx = self._optimize_dur_predictor_graph(onnx.load(self.dur_predictor_cache_path))
        dur_predictor_path = path / f'{model_name}.dur.onnx'
        onnx.save(dur_predictor_onnx, dur_predictor_path)
        self.dur_predictor_cache_path.unlink()
        print(f'| export dur predictor => {dur_predictor_path}')

    # Pitch predictor: pre-process, diffusion and post-process graphs are
    # merged into a single model.
    if self.model.predict_pitch:
        pitch_predictor_onnx = self._optimize_merge_pitch_predictor_graph(
            onnx.load(self.pitch_preprocess_cache_path),
            onnx.load(self.pitch_predictor_cache_path),
            onnx.load(self.pitch_postprocess_cache_path)
        )
        pitch_predictor_path = path / f'{model_name}.pitch.onnx'
        onnx.save(pitch_predictor_onnx, pitch_predictor_path)
        self.pitch_preprocess_cache_path.unlink()
        self.pitch_predictor_cache_path.unlink()
        self.pitch_postprocess_cache_path.unlink()
        print(f'| export pitch predictor => {pitch_predictor_path}')

    # Multi-variance predictor: merged from three graphs as well.
    if self.model.predict_variances:
        variance_predictor_onnx = self._optimize_merge_variance_predictor_graph(
            onnx.load(self.variance_preprocess_cache_path),
            onnx.load(self.multi_var_predictor_cache_path),
            onnx.load(self.variance_postprocess_cache_path)
        )
        variance_predictor_path = path / f'{model_name}.variance.onnx'
        onnx.save(variance_predictor_onnx, variance_predictor_path)
        self.variance_preprocess_cache_path.unlink()
        self.multi_var_predictor_cache_path.unlink()
        self.variance_postprocess_cache_path.unlink()
        print(f'| export variance predictor => {variance_predictor_path}')
|
138 |
+
|
139 |
+
def export_attachments(self, path: Path):
    """Write speaker embeddings, dictionary, phoneme list and ``dsconfig.yaml``.

    Keys of ``dsconfig.yaml`` are emitted in insertion order
    (``sort_keys=False``), so the order in which they are added below is
    part of the output.
    """
    base_name = self.model_name

    # One embedding file per exported speaker mix.
    for spk_entry in self.export_spk:
        emb_path = path / f'{base_name}.{spk_entry[0]}.emb'
        self._export_spk_embed(emb_path, self._perform_spk_mix(spk_entry[1]))
    self._export_dictionary(path / 'dictionary.txt')
    self._export_phonemes(path / f'{base_name}.phonemes.txt')

    # ONNX files carry the frozen speaker name, if any.
    frozen = self.freeze_spk
    onnx_stem = base_name if frozen is None else base_name + '.' + frozen[0]

    net = self.model
    # basic configs
    dsconfig = {
        'phonemes': f'{base_name}.phonemes.txt',
        'linguistic': f'{onnx_stem}.linguistic.onnx',
        'hidden_size': net.hidden_size,
        'predict_dur': net.predict_dur,
    }
    # multi-speaker
    if len(self.export_spk) > 0:
        dsconfig['speakers'] = [f'{base_name}.{spk_entry[0]}' for spk_entry in self.export_spk]
    # functionalities
    if net.predict_dur:
        dsconfig['dur'] = f'{onnx_stem}.dur.onnx'
    if net.predict_pitch:
        dsconfig.update({
            'pitch': f'{onnx_stem}.pitch.onnx',
            'use_expr': self.expose_expr,
            'use_note_rest': net.use_melody_encoder,
        })
    if net.predict_variances:
        dsconfig['variance'] = f'{onnx_stem}.variance.onnx'
        for var_name in VARIANCE_CHECKLIST:
            dsconfig[f'predict_{var_name}'] = (var_name in net.variance_prediction_list)
    # sampling acceleration, then frame specifications
    dsconfig.update({
        'use_continuous_acceleration': True,
        'sample_rate': hparams['audio_sample_rate'],
        'hop_size': hparams['hop_size'],
    })

    config_path = path / 'dsconfig.yaml'
    with open(config_path, 'w', encoding='utf8') as fw:
        yaml.safe_dump(dsconfig, fw, sort_keys=False)
    print(f'| export configs => {config_path} **PLEASE EDIT BEFORE USE**')
|
181 |
+
|
182 |
+
@torch.no_grad()
def _torch_export_model(self):
    """Export every enabled sub-module to its ONNX cache file.

    The linguistic encoder is always exported. Depending on
    ``self.model.predict_dur`` / ``predict_pitch`` / ``predict_variances``
    the duration predictor, the pitch predictor (pre-process, diffusion,
    post-process) and the multi-variance predictor (pre-process, diffusion,
    post-process) are exported as well. The diffusion parts are exported
    from a scripted module whose backbone has been replaced by a traced
    one. All exports use opset 15 and dummy inputs built below.
    """
    # Prepare inputs for FastSpeech2 and dur predictor tracing
    # Dummy sequence: 5 tokens; ph_dur and word_dur each sum to 15, which
    # matches the length of the frame-level dummies used further down.
    tokens = torch.LongTensor([[1] * 5]).to(self.device)
    ph_dur = torch.LongTensor([[3, 5, 2, 1, 4]]).to(self.device)
    word_div = torch.LongTensor([[2, 2, 1]]).to(self.device)
    word_dur = torch.LongTensor([[8, 3, 4]]).to(self.device)
    encoder_out = torch.rand(1, 5, hparams['hidden_size'], dtype=torch.float32, device=self.device)
    x_masks = tokens == 0
    ph_midi = torch.LongTensor([[60] * 5]).to(self.device)
    encoder_output_names = ['encoder_out', 'x_masks']
    # Dynamic axes shared by everything that produces or consumes the
    # encoder outputs.
    encoder_common_axes = {
        'encoder_out': {
            1: 'n_tokens'
        },
        'x_masks': {
            1: 'n_tokens'
        }
    }
    # A speaker embedding becomes a graph input only when speaker IDs are
    # used and no speaker has been frozen.
    input_spk_embed = hparams['use_spk_id'] and not self.freeze_spk

    print(f'Exporting {self.fs2_class_name}...')
    if self.model.predict_dur:
        # With duration prediction the encoder takes word-level inputs.
        torch.onnx.export(
            self.model.view_as_linguistic_encoder(),
            (
                tokens,
                word_div,
                word_dur
            ),
            self.linguistic_encoder_cache_path,
            input_names=[
                'tokens',
                'word_div',
                'word_dur'
            ],
            output_names=encoder_output_names,
            dynamic_axes={
                'tokens': {
                    1: 'n_tokens'
                },
                'word_div': {
                    1: 'n_words'
                },
                'word_dur': {
                    1: 'n_words'
                },
                **encoder_common_axes
            },
            opset_version=15
        )

        print(f'Exporting {self.dur_predictor_class_name}...')
        torch.onnx.export(
            self.model.view_as_dur_predictor(),
            (
                encoder_out,
                x_masks,
                ph_midi,
                # optional token-level speaker embedding
                *([torch.rand(
                    1, 5, hparams['hidden_size'],
                    dtype=torch.float32, device=self.device
                )] if input_spk_embed else [])
            ),
            self.dur_predictor_cache_path,
            input_names=[
                'encoder_out',
                'x_masks',
                'ph_midi',
                *(['spk_embed'] if input_spk_embed else [])
            ],
            output_names=[
                'ph_dur_pred'
            ],
            dynamic_axes={
                'ph_midi': {
                    1: 'n_tokens'
                },
                'ph_dur_pred': {
                    1: 'n_tokens'
                },
                **({'spk_embed': {1: 'n_tokens'}} if input_spk_embed else {}),
                **encoder_common_axes
            },
            opset_version=15
        )
    else:
        # Without duration prediction the encoder takes phoneme durations
        # directly.
        torch.onnx.export(
            self.model.view_as_linguistic_encoder(),
            (
                tokens,
                ph_dur
            ),
            self.linguistic_encoder_cache_path,
            input_names=[
                'tokens',
                'ph_dur'
            ],
            output_names=encoder_output_names,
            dynamic_axes={
                'tokens': {
                    1: 'n_tokens'
                },
                'ph_dur': {
                    1: 'n_tokens'
                },
                **encoder_common_axes
            },
            opset_version=15
        )

    # Common dummy inputs
    dummy_time = (torch.rand((1,), device=self.device) * hparams.get('time_scale_factor', 1.0)).float()
    dummy_steps = 5

    if self.model.predict_pitch:
        use_melody_encoder = hparams.get('use_melody_encoder', False)
        use_glide_embed = use_melody_encoder and hparams['use_glide_embed'] and not self.freeze_glide
        # Prepare inputs for preprocessor of the pitch predictor
        # 4 dummy notes whose durations sum to 15 frames.
        note_midi = torch.FloatTensor([[60.] * 4]).to(self.device)
        note_dur = torch.LongTensor([[2, 6, 3, 4]]).to(self.device)
        pitch = torch.FloatTensor([[60.] * 15]).to(self.device)
        retake = torch.ones_like(pitch, dtype=torch.bool)
        # Positional args followed by a dict of keyword args; the optional
        # entries must stay in the same order as `input_names` below.
        pitch_input_args = (
            encoder_out,
            ph_dur,
            {
                'note_midi': note_midi,
                **({'note_rest': note_midi >= 0} if use_melody_encoder else {}),
                'note_dur': note_dur,
                **({'note_glide': torch.zeros_like(note_midi, dtype=torch.long)} if use_glide_embed else {}),
                'pitch': pitch,
                **({'expr': torch.ones_like(pitch)} if self.expose_expr else {}),
                'retake': retake,
                **({'spk_embed': torch.rand(
                    1, 15, hparams['hidden_size'], dtype=torch.float32, device=self.device
                )} if input_spk_embed else {})
            }
        )
        torch.onnx.export(
            self.model.view_as_pitch_preprocess(),
            pitch_input_args,
            self.pitch_preprocess_cache_path,
            input_names=[
                'encoder_out', 'ph_dur', 'note_midi',
                *(['note_rest'] if use_melody_encoder else []),
                'note_dur',
                *(['note_glide'] if use_glide_embed else []),
                'pitch',
                *(['expr'] if self.expose_expr else []),
                'retake',
                *(['spk_embed'] if input_spk_embed else [])
            ],
            output_names=[
                'pitch_cond', 'base_pitch'
            ],
            dynamic_axes={
                'encoder_out': {
                    1: 'n_tokens'
                },
                'ph_dur': {
                    1: 'n_tokens'
                },
                'note_midi': {
                    1: 'n_notes'
                },
                **({'note_rest': {1: 'n_notes'}} if use_melody_encoder else {}),
                'note_dur': {
                    1: 'n_notes'
                },
                **({'note_glide': {1: 'n_notes'}} if use_glide_embed else {}),
                'pitch': {
                    1: 'n_frames'
                },
                **({'expr': {1: 'n_frames'}} if self.expose_expr else {}),
                'retake': {
                    1: 'n_frames'
                },
                'pitch_cond': {
                    1: 'n_frames'
                },
                'base_pitch': {
                    1: 'n_frames'
                },
                **({'spk_embed': {1: 'n_frames'}} if input_spk_embed else {})
            },
            opset_version=15
        )

        # Prepare inputs for backbone tracing and pitch predictor scripting
        shape = (1, 1, hparams['pitch_prediction_args']['repeat_bins'], 15)
        noise = torch.randn(shape, device=self.device)
        condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device)

        print(f'Tracing {self.pitch_backbone_class_name} backbone...')
        pitch_predictor = self.model.view_as_pitch_predictor()
        # Replace the backbone by its traced version before scripting the
        # surrounding predictor.
        pitch_predictor.pitch_predictor.set_backbone(
            torch.jit.trace(
                pitch_predictor.pitch_predictor.backbone,
                (
                    noise,
                    dummy_time,
                    condition
                )
            )
        )

        print(f'Scripting {self.pitch_predictor_class_name}...')
        pitch_predictor = torch.jit.script(
            pitch_predictor,
            example_inputs=[
                (
                    condition.transpose(1, 2),
                    1  # p_sample branch
                ),
                (
                    condition.transpose(1, 2),
                    dummy_steps  # p_sample_plms branch
                )
            ]
        )

        print(f'Exporting {self.pitch_predictor_class_name}...')
        torch.onnx.export(
            pitch_predictor,
            (
                condition.transpose(1, 2),
                dummy_steps
            ),
            self.pitch_predictor_cache_path,
            input_names=[
                'pitch_cond',
                'steps'
            ],
            output_names=[
                'x_pred'
            ],
            dynamic_axes={
                'pitch_cond': {
                    1: 'n_frames'
                },
                'x_pred': {
                    1: 'n_frames'
                }
            },
            opset_version=15
        )

        # Export the postprocessor of the pitch predictor; the dummy
        # `pitch` tensor stands in for both `x_pred` and `base_pitch`.
        torch.onnx.export(
            self.model.view_as_pitch_postprocess(),
            (
                pitch,
                pitch
            ),
            self.pitch_postprocess_cache_path,
            input_names=[
                'x_pred',
                'base_pitch'
            ],
            output_names=[
                'pitch_pred'
            ],
            dynamic_axes={
                'x_pred': {
                    1: 'n_frames'
                },
                'base_pitch': {
                    1: 'n_frames'
                },
                'pitch_pred': {
                    1: 'n_frames'
                }
            },
            opset_version=15
        )

    if self.model.predict_variances:
        total_repeat_bins = hparams['variances_prediction_args']['total_repeat_bins']
        # The total number of bins is split evenly across the predicted
        # variance parameters.
        repeat_bins = total_repeat_bins // len(self.model.variance_prediction_list)

        # Prepare inputs for preprocessor of the multi-variance predictor
        pitch = torch.FloatTensor([[60.] * 15]).to(self.device)
        variances = {
            v_name: torch.FloatTensor([[0.] * 15]).to(self.device)
            for v_name in self.model.variance_prediction_list
        }
        # One retake flag per frame and per predicted variance parameter.
        retake = torch.ones_like(pitch, dtype=torch.bool)[..., None].tile(len(self.model.variance_prediction_list))
        torch.onnx.export(
            self.model.view_as_variance_preprocess(),
            (
                encoder_out,
                ph_dur,
                pitch,
                variances,
                retake,
                # optional frame-level speaker embedding
                *([torch.rand(
                    1, 15, hparams['hidden_size'],
                    dtype=torch.float32, device=self.device
                )] if input_spk_embed else [])
            ),
            self.variance_preprocess_cache_path,
            input_names=[
                'encoder_out', 'ph_dur', 'pitch',
                *self.model.variance_prediction_list,
                'retake',
                *(['spk_embed'] if input_spk_embed else [])
            ],
            output_names=[
                'variance_cond'
            ],
            dynamic_axes={
                'encoder_out': {
                    1: 'n_tokens'
                },
                'ph_dur': {
                    1: 'n_tokens'
                },
                'pitch': {
                    1: 'n_frames'
                },
                **{
                    v_name: {
                        1: 'n_frames'
                    }
                    for v_name in self.model.variance_prediction_list
                },
                'retake': {
                    1: 'n_frames'
                },
                **({'spk_embed': {1: 'n_frames'}} if input_spk_embed else {})
            },
            opset_version=15
        )

        # Prepare inputs for backbone tracing and multi-variance predictor scripting
        shape = (1, len(self.model.variance_prediction_list), repeat_bins, 15)
        noise = torch.randn(shape, device=self.device)
        condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device)
        # NOTE(review): the pitch backbone above is traced with the float
        # `dummy_time`, whereas this one is traced with an integer step that
        # requires hparams['K_step'] to exist - confirm this is intended.
        step = (torch.rand((1,), device=self.device) * hparams['K_step']).long()

        print(f'Tracing {self.variance_backbone_class_name} backbone...')
        multi_var_predictor = self.model.view_as_variance_predictor()
        # Same scheme as for pitch: trace the backbone, then script the
        # predictor that wraps it.
        multi_var_predictor.variance_predictor.set_backbone(
            torch.jit.trace(
                multi_var_predictor.variance_predictor.backbone,
                (
                    noise,
                    step,
                    condition
                )
            )
        )

        print(f'Scripting {self.multi_var_predictor_class_name}...')
        multi_var_predictor = torch.jit.script(
            multi_var_predictor,
            example_inputs=[
                (
                    condition.transpose(1, 2),
                    1  # p_sample branch
                ),
                (
                    condition.transpose(1, 2),
                    dummy_steps  # p_sample_plms branch
                )
            ]
        )

        print(f'Exporting {self.multi_var_predictor_class_name}...')
        torch.onnx.export(
            multi_var_predictor,
            (
                condition.transpose(1, 2),
                dummy_steps
            ),
            self.multi_var_predictor_cache_path,
            input_names=[
                'variance_cond',
                'steps'
            ],
            output_names=[
                'xs_pred'
            ],
            dynamic_axes={
                'variance_cond': {
                    1: 'n_frames'
                },
                'xs_pred': {
                    # the frame axis is 1 for a single variance parameter
                    # and 2 otherwise (see xs_shape below)
                    (1 if len(self.model.variance_prediction_list) == 1 else 2): 'n_frames'
                }
            },
            opset_version=15
        )

        # Prepare inputs for postprocessor of the multi-variance predictor
        xs_shape = (1, 15) \
            if len(self.model.variance_prediction_list) == 1 \
            else (1, len(self.model.variance_prediction_list), 15)
        xs_pred = torch.randn(xs_shape, dtype=torch.float32, device=self.device)
        torch.onnx.export(
            self.model.view_as_variance_postprocess(),
            (
                xs_pred
            ),
            self.variance_postprocess_cache_path,
            input_names=[
                'xs_pred'
            ],
            output_names=[
                f'{v_name}_pred'
                for v_name in self.model.variance_prediction_list
            ],
            dynamic_axes={
                'xs_pred': {
                    (1 if len(self.model.variance_prediction_list) == 1 else 2): 'n_frames'
                },
                **{
                    f'{v_name}_pred': {
                        1: 'n_frames'
                    }
                    for v_name in self.model.variance_prediction_list
                }
            },
            opset_version=15
        )
|
608 |
+
|
609 |
+
@torch.no_grad()
def _perform_spk_mix(self, spk_mix: Dict[str, float]):
    """Return the weighted average of speaker embeddings described by ``spk_mix``.

    :param spk_mix: mapping from speaker name (a key of ``self.spk_map``) to
        a non-negative proportion; proportions are normalized to sum to 1
    :return: the mixed embedding, reduced over the speaker axis
    :raises AssertionError: if a proportion is negative or all are zero
    """
    speaker_ids = []
    proportions = []
    for name, value in spk_mix.items():
        # Look the speaker up first so that an unknown name fails before
        # the proportion is validated.
        speaker_ids.append(self.spk_map[name])
        assert value >= 0., (
            f"Speaker mix checks failed.\nProportion of speaker '{name}' is negative."
        )
        proportions.append(value)

    id_row = torch.LongTensor(speaker_ids).to(self.device).unsqueeze(0)  # => [1, N]
    weight_row = torch.FloatTensor(proportions).to(self.device).unsqueeze(0)  # => [1, N]
    total = weight_row.sum()
    assert total > 0., 'Speaker mix checks failed.\nProportions of speaker mix sum to zero.'
    weight_row = weight_row / total  # normalize

    weighted = self.model.spk_embed(id_row) * weight_row[:, :, None]  # => [1, N, H]
    return weighted.sum(dim=1)  # => [1, H]
|
629 |
+
|
630 |
+
def _optimize_linguistic_graph(self, linguistic: onnx.ModelProto) -> onnx.ModelProto:
    """Fix the ``encoder_out`` output shape and simplify the encoder graph.

    :raises AssertionError: if onnxsim reports that validation failed
    """
    encoder_out_shape = (1, 'n_tokens', hparams['hidden_size'])
    onnx_helper.model_override_io_shapes(
        linguistic, output_shapes={'encoder_out': encoder_out_shape}
    )
    print(f'Running ONNX Simplifier on {self.fs2_class_name}...')
    simplified, is_valid = onnxsim.simplify(linguistic, include_subgraph=True)
    assert is_valid, 'Simplified ONNX model could not be validated'
    print(f'| optimize graph: {self.fs2_class_name}')
    return simplified
|
642 |
+
|
643 |
+
def _optimize_dur_predictor_graph(self, dur_predictor: onnx.ModelProto) -> onnx.ModelProto:
    """Fix the ``ph_dur_pred`` output shape and simplify the duration graph.

    :raises AssertionError: if onnxsim reports that validation failed
    """
    ph_dur_pred_shape = (1, 'n_tokens')
    onnx_helper.model_override_io_shapes(
        dur_predictor, output_shapes={'ph_dur_pred': ph_dur_pred_shape}
    )
    print(f'Running ONNX Simplifier on {self.dur_predictor_class_name}...')
    simplified, is_valid = onnxsim.simplify(dur_predictor, include_subgraph=True)
    assert is_valid, 'Simplified ONNX model could not be validated'
    print(f'| optimize graph: {self.dur_predictor_class_name}')
    return simplified
|
655 |
+
|
656 |
+
def _optimize_merge_pitch_predictor_graph(
        self, pitch_pre: onnx.ModelProto, pitch_predictor: onnx.ModelProto, pitch_post: onnx.ModelProto
) -> onnx.ModelProto:
    """Optimize the three pitch graphs and merge them into one model.

    :param pitch_pre: pre-process graph (outputs ``pitch_cond``, ``base_pitch``)
    :param pitch_predictor: diffusion graph (``pitch_cond`` -> ``x_pred``)
    :param pitch_post: post-process graph (``x_pred``, ``base_pitch`` -> output)
    :return: the merged model, named after the pre-process graph
    :raises AssertionError: if onnxsim reports that a validation failed
    """

    def _simplify(model):
        # Simplify and insist on a successful validation.
        simplified, is_valid = onnxsim.simplify(model, include_subgraph=True)
        assert is_valid, 'Simplified ONNX model could not be validated'
        return simplified

    def _merge(first, second, io_map):
        # Merge without prefixes, inheriting the metadata of `pitch_pre`.
        return onnx.compose.merge_models(
            first, second, io_map=io_map,
            prefix1='', prefix2='', doc_string='',
            producer_name=pitch_pre.producer_name, producer_version=pitch_pre.producer_version,
            domain=pitch_pre.domain, model_version=pitch_pre.model_version
        )

    # --- pre-process graph ---
    cond_shape = (1, 'n_frames', hparams['hidden_size'])
    onnx_helper.model_override_io_shapes(pitch_pre, output_shapes={'pitch_cond': cond_shape})
    pitch_pre = _simplify(pitch_pre)

    # --- diffusion graph ---
    # NOTE(review): the diffusion graph is exported with output 'x_pred',
    # so overriding 'pitch_pred' here presumably has no effect - confirm.
    onnx_helper.model_override_io_shapes(
        pitch_predictor, output_shapes={'pitch_pred': (1, 'n_frames')}
    )
    print(f'Running ONNX Simplifier #1 on {self.pitch_predictor_class_name}...')
    pitch_predictor = _simplify(pitch_predictor)
    diffusion_graph = pitch_predictor.graph
    onnx_helper.graph_fold_back_to_squeeze(diffusion_graph)
    onnx_helper.graph_extract_conditioner_projections(
        graph=diffusion_graph,
        op_type='Conv',
        weight_pattern=r'pitch_predictor\..*\.conditioner_projection\.weight',
        alias_prefix='/pitch_predictor/backbone/cache'
    )
    onnx_helper.graph_remove_unused_values(diffusion_graph)
    print(f'Running ONNX Simplifier #2 on {self.pitch_predictor_class_name}...')
    pitch_predictor = _simplify(pitch_predictor)

    # --- prefix node and dimension names so the graphs can be merged ---
    onnx_helper.model_add_prefixes(pitch_pre, node_prefix='/pre', ignored_pattern=r'.*embed.*')
    onnx_helper.model_add_prefixes(pitch_pre, dim_prefix='pre.', ignored_pattern='(n_tokens)|(n_notes)|(n_frames)')
    onnx_helper.model_add_prefixes(pitch_post, node_prefix='/post', ignored_pattern=None)
    onnx_helper.model_add_prefixes(pitch_post, dim_prefix='post.', ignored_pattern='n_frames')

    # --- pre-process + diffusion, then + post-process ---
    graph_name = pitch_pre.graph.name
    pre_and_diffusion = _merge(pitch_pre, pitch_predictor, [('pitch_cond', 'pitch_cond')])
    pre_and_diffusion.graph.name = graph_name
    merged = _merge(
        pre_and_diffusion, pitch_post,
        [('x_pred', 'x_pred'), ('base_pitch', 'base_pitch')]
    )
    merged.graph.name = graph_name

    print(f'| optimize graph: {self.pitch_predictor_class_name}')
    return merged
|
704 |
+
|
705 |
+
def _optimize_merge_variance_predictor_graph(
        self, var_pre: onnx.ModelProto, var_diffusion: onnx.ModelProto, var_post: onnx.ModelProto
):
    """Simplify the three variance predictor sub-models and merge them into one.

    Steps, in order:

    1. Pin the ``variance_cond`` output shape of ``var_pre`` and simplify it.
    2. Pin the ``xs_pred`` output shape of ``var_diffusion``, simplify it,
       apply the ``onnx_helper`` graph rewrites (fold back to squeeze,
       extract conditioner projections, drop unused values) and simplify
       it a second time.
    3. Simplify ``var_post``.
    4. Prefix node/value/initializer/dimension names of ``var_pre`` and
       ``var_post`` (``/pre``, ``pre.``, ``/post``, ``post.``), skipping the
       names matched by the ignore patterns below.
    5. Merge ``var_pre`` -> ``var_diffusion`` on ``variance_cond`` and the
       result -> ``var_post`` on ``xs_pred``.

    :param var_pre: model producing ``variance_cond``
    :param var_diffusion: model consuming ``variance_cond`` and producing ``xs_pred``
    :param var_post: model consuming ``xs_pred``
    :return: the merged model; its producer metadata and graph name are
        copied from ``var_pre``
    :raises AssertionError: if ``onnxsim`` reports that a simplified model
        could not be validated
    """
    not_validated = 'Simplified ONNX model could not be validated'

    # --- pre-processing sub-model ---
    onnx_helper.model_override_io_shapes(
        var_pre, output_shapes={'variance_cond': (1, 'n_frames', hparams['hidden_size'])}
    )
    var_pre, ok = onnxsim.simplify(var_pre, include_subgraph=True)
    assert ok, not_validated

    # --- diffusion sub-model ---
    variance_names = self.model.variance_prediction_list
    if len(variance_names) == 1:
        # a single predicted variance has no separate variance axis
        xs_pred_shape = (1, 'n_frames')
    else:
        xs_pred_shape = (1, len(variance_names), 'n_frames')
    onnx_helper.model_override_io_shapes(var_diffusion, output_shapes={'xs_pred': xs_pred_shape})

    class_name = self.multi_var_predictor_class_name
    print(f'Running ONNX Simplifier #1 on {class_name}...')
    var_diffusion, ok = onnxsim.simplify(var_diffusion, include_subgraph=True)
    assert ok, not_validated

    diffusion_graph = var_diffusion.graph
    onnx_helper.graph_fold_back_to_squeeze(diffusion_graph)
    onnx_helper.graph_extract_conditioner_projections(
        graph=diffusion_graph,
        op_type='Conv',
        weight_pattern=r'variance_predictor\..*\.conditioner_projection\.weight',
        alias_prefix='/variance_predictor/backbone/cache',
    )
    onnx_helper.graph_remove_unused_values(diffusion_graph)

    print(f'Running ONNX Simplifier #2 on {self.multi_var_predictor_class_name}...')
    var_diffusion, ok = onnxsim.simplify(var_diffusion, include_subgraph=True)
    assert ok, not_validated

    # --- post-processing sub-model ---
    var_post, ok = onnxsim.simplify(var_post, include_subgraph=True)
    assert ok, not_validated

    # --- name prefixing; embeddings and the variance names stay unprefixed ---
    ignored_variance_names = '|'.join(
        f'({v_name})' for v_name in self.model.variance_prediction_list
    )
    onnx_helper.model_add_prefixes(
        var_pre,
        node_prefix='/pre', value_info_prefix='/pre', initializer_prefix='/pre',
        ignored_pattern=fr'.*((embed)|{ignored_variance_names}).*',
    )
    onnx_helper.model_add_prefixes(
        var_pre, dim_prefix='pre.', ignored_pattern='(n_tokens)|(n_frames)'
    )
    onnx_helper.model_add_prefixes(
        var_post,
        node_prefix='/post', value_info_prefix='/post', initializer_prefix='/post',
        ignored_pattern=None,
    )
    onnx_helper.model_add_prefixes(var_post, dim_prefix='post.', ignored_pattern='n_frames')

    # --- merging; both merges reuse the metadata of ``var_pre`` ---
    print(f'Merging {self.multi_var_predictor_class_name} subroutines...')
    merge_meta = dict(
        prefix1='', prefix2='', doc_string='',
        producer_name=var_pre.producer_name, producer_version=var_pre.producer_version,
        domain=var_pre.domain, model_version=var_pre.model_version,
    )
    var_pre_diffusion = onnx.compose.merge_models(
        var_pre, var_diffusion, io_map=[('variance_cond', 'variance_cond')], **merge_meta
    )
    var_pre_diffusion.graph.name = var_pre.graph.name

    var_predictor = onnx.compose.merge_models(
        var_pre_diffusion, var_post, io_map=[('xs_pred', 'xs_pred')], **merge_meta
    )
    var_predictor.graph.name = var_pre.graph.name
    return var_predictor
|
767 |
+
|
768 |
+
# noinspection PyMethodMayBeStatic
|
769 |
+
def _export_spk_embed(self, path: Path, spk_embed: torch.Tensor):
    """Write a speaker embedding to ``path`` as raw bytes.

    The tensor is detached, moved to CPU, converted to a NumPy array and
    serialized with ``ndarray.tobytes()``. Only the raw element bytes are
    written; shape and dtype are not stored in the file.

    :param path: destination file; created or overwritten
    :param spk_embed: embedding tensor to export
    """
    # ``Tensor.numpy()`` refuses tensors that require grad, so detach first.
    # Serialize before opening the file so that a failed conversion does not
    # leave an empty/truncated file behind.
    payload = spk_embed.detach().cpu().numpy().tobytes()
    with open(path, 'wb') as f:
        f.write(payload)
    print(f'| export spk embed => {path}')
|
773 |
+
|
774 |
+
# noinspection PyMethodMayBeStatic
|
775 |
+
def _export_dictionary(self, path: Path):
    """Copy the dictionary file returned by ``locate_dictionary()`` to ``path``.

    The destination is printed before the copy is attempted.

    :param path: destination passed to ``shutil.copy``
    """
    print(f'| export dictionary => {path}')
    dictionary_file = locate_dictionary()
    shutil.copy(dictionary_file, path)
|
778 |
+
|
779 |
+
def _export_phonemes(self, path: Path):
    """Store the phoneme vocabulary at ``path`` using ``self.vocab.store_to_file``.

    :param path: destination handed to ``store_to_file``
    """
    vocab = self.vocab
    vocab.store_to_file(path)
    print(f'| export phonemes => {path}')
|
deployment/modules/__init__.py
ADDED
File without changes
|
deployment/modules/diffusion.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from modules.core import (
|
9 |
+
GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
def extract(a, t):
    """Index ``a`` with ``t`` and return the result reshaped to ``(1, 1, 1, 1)``.

    The reshape requires ``a[t]`` to contain exactly one element.
    """
    selected = a[t]
    return selected.reshape((1, 1, 1, 1))
|
15 |
+
|
16 |
+
|
17 |
+
# noinspection PyMethodOverriding
|
18 |
+
class GaussianDiffusionONNX(GaussianDiffusion):
    """Deployment variant of ``GaussianDiffusion`` with explicit sampling steps.

    The schedule buffers are indexed with a timestep tensor and broadcast via
    ``extract`` as ``(1, 1, 1, 1)`` tensors, and ``forward`` draws noise with
    a leading dimension of 1, so the class is written for a single sequence.
    """

    @property
    def backbone(self):
        """The denoising network stored in ``self.denoise_fn``."""
        return self.denoise_fn

    # We give up the setter for the property `backbone` because this will cause TorchScript to fail
    # @backbone.setter
    @torch.jit.unused
    def set_backbone(self, value):
        """Replace the denoising network (explicit setter, see comment above)."""
        self.denoise_fn = value

    def q_sample(self, x_start, t, noise):
        """Diffuse ``x_start`` to timestep ``t`` using the given ``noise``."""
        signal_scale = extract(self.sqrt_alphas_cumprod, t)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t)
        return signal_scale * x_start + noise_scale * noise

    def p_sample(self, x, t, cond):
        """One stochastic reverse step from timestep ``t``.

        The backbone output is combined with ``x`` into a reconstruction,
        the posterior mean is formed from the schedule coefficients, and
        fresh Gaussian noise scaled by the clipped posterior variance is
        added (suppressed when ``t == 0``).
        """
        backbone_out = self.denoise_fn(x, t, cond)
        recon = (
            extract(self.sqrt_recip_alphas_cumprod, t) * x
            - extract(self.sqrt_recipm1_alphas_cumprod, t) * backbone_out
        )
        # This is previously inherited from original DiffSinger repository
        # and disabled due to some loudness issues when speedup = 1.
        # x_recon = torch.clamp(x_recon, min=-1., max=1.)

        mean = (
            extract(self.posterior_mean_coef1, t) * recon
            + extract(self.posterior_mean_coef2, t) * x
        )
        log_variance = extract(self.posterior_log_variance_clipped, t)
        fresh_noise = torch.randn_like(x)
        # no noise when t == 0
        keep_noise = ((t > 0).float()).reshape(1, 1, 1, 1)
        return mean + keep_noise * (0.5 * log_variance).exp() * fresh_noise

    def p_sample_ddim(self, x, t, interval: int, cond):
        """One deterministic reverse step from ``t`` to ``t - interval``.

        The previous timestep is replaced by 0 when ``t - interval`` is not
        positive.
        """
        alpha_now = extract(self.alphas_cumprod, t)
        t_prev = t - interval
        alpha_prev = extract(self.alphas_cumprod, t_prev * (t_prev > 0))

        noise_pred = self.denoise_fn(x, t, cond=cond)
        noise_coef = ((1 - alpha_prev) / alpha_prev).sqrt() - ((1 - alpha_now) / alpha_now).sqrt()
        return alpha_prev.sqrt() * (x / alpha_now.sqrt() + noise_coef * noise_pred)

    def plms_get_x_pred(self, x, noise_t, t, t_prev):
        """Move ``x`` from timestep ``t`` to ``t_prev`` given a noise estimate."""
        alpha_now = extract(self.alphas_cumprod, t)
        alpha_prev = extract(self.alphas_cumprod, t_prev)
        sqrt_now, sqrt_prev = alpha_now.sqrt(), alpha_prev.sqrt()

        x_coef = 1 / (sqrt_now * (sqrt_now + sqrt_prev))
        noise_coef = 1 / (
            sqrt_now * (((1 - alpha_prev) * alpha_now).sqrt() + ((1 - alpha_now) * alpha_prev).sqrt())
        )
        delta = (alpha_prev - alpha_now) * (x_coef * x - noise_coef * noise_t)
        return x + delta

    def p_sample_plms(self, x_prev, t, interval: int, cond, noise_list: List[Tensor], stage: int):
        """One multi-step (PLMS) reverse step.

        ``stage`` selects the formula: 0 performs a two-evaluation
        predictor/corrector start, 1 and 2 use one and two previous noise
        predictions, and any other value uses three. ``noise_list`` holds the
        earlier predictions with the most recent one last.

        :return: ``(noise_pred, x_prev)`` - the raw noise prediction at ``t``
            and the updated sample
        """
        noise_pred = self.denoise_fn(x_prev, t, cond)
        t_prev = t - interval
        t_prev = t_prev * (t_prev > 0)
        if stage == 0:
            x_pred = self.plms_get_x_pred(x_prev, noise_pred, t, t_prev)
            noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond)
            noise_pred_prime = (noise_pred + noise_pred_prev) / 2.
        elif stage == 1:
            noise_pred_prime = (3. * noise_pred - noise_list[-1]) / 2.
        elif stage == 2:
            noise_pred_prime = (23. * noise_pred - 16. * noise_list[-1] + 5. * noise_list[-2]) / 12.
        else:
            noise_pred_prime = (
                55. * noise_pred - 59. * noise_list[-1] + 37. * noise_list[-2] - 9. * noise_list[-3]
            ) / 24.
        x_prev = self.plms_get_x_pred(x_prev, noise_pred_prime, t, t_prev)
        return noise_pred, x_prev

    def norm_spec(self, x):
        """Affinely map ``[spec_min, spec_max]`` onto ``[-1, 1]``."""
        half_span = (self.spec_max - self.spec_min) / 2.
        midpoint = (self.spec_max + self.spec_min) / 2.
        return (x - midpoint) / half_span

    def denorm_spec(self, x):
        """Inverse of ``norm_spec``: map ``[-1, 1]`` back onto ``[spec_min, spec_max]``."""
        half_span = (self.spec_max - self.spec_min) / 2.
        midpoint = (self.spec_max + self.spec_min) / 2.
        return x * half_span + midpoint

    def forward(self, condition, x_start=None, depth=None, steps: int = 10):
        """Run the reverse process and return the de-normalized result.

        :param condition: conditioning tensor, ``[1, T, H]``
        :param x_start: optional starting point; when given, sampling starts
            from a noised version of it instead of pure noise and ``depth``
            is required
        :param depth: multiplied by ``self.timesteps`` to obtain the starting
            step (capped at ``self.k_step``) - presumably a fraction in
            ``[0, 1]``; confirm against the caller
        :param steps: requested number of sampling steps, used to derive the
            step interval (``speedup``)
        """
        condition = condition.transpose(1, 2)  # [1, T, H] => [1, H, T]
        device = condition.device
        n_frames = condition.shape[2]

        noise = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
        if x_start is None:
            # full sampling: pick the largest registered factor not exceeding the wanted interval
            speedup = max(1, self.timesteps // steps)
            speedup = self.timestep_factors[torch.sum(self.timestep_factors <= speedup) - 1]
            step_range = torch.arange(0, self.k_step, speedup, dtype=torch.long, device=device).flip(0)[:, None]
            x = noise
        else:
            # shallow sampling: start from ``x_start`` diffused to the requested depth
            depth_int64 = min(torch.round(depth * self.timesteps).long(), self.k_step)
            speedup = max(1, depth_int64 // steps)
            depth_int64 = depth_int64 // speedup * speedup  # make depth_int64 a multiple of speedup
            step_range = torch.arange(0, depth_int64, speedup, dtype=torch.long, device=device).flip(0)[:, None]
            x_start = self.norm_spec(x_start).transpose(-2, -1)
            if self.num_feats == 1:
                x_start = x_start[:, None, :, :]
            if depth_int64 >= self.timesteps:
                x = noise
            elif depth_int64 > 0:
                x = self.q_sample(
                    x_start, torch.full((1,), depth_int64 - 1, device=device, dtype=torch.long), noise
                )
            else:
                x = x_start

        if speedup > 1:
            # accelerated sampling uses the deterministic DDIM step
            for t in step_range:
                x = self.p_sample_ddim(x, t, interval=speedup, cond=condition)
            # plms_noise_stage: int = 0
            # noise_list: List[Tensor] = []
            # for t in step_range:
            #     noise_pred, x = self.p_sample_plms(
            #         x, t, interval=speedup, cond=condition,
            #         noise_list=noise_list, stage=plms_noise_stage
            #     )
            #     if plms_noise_stage == 0:
            #         noise_list = [noise_pred]
            #         plms_noise_stage = plms_noise_stage + 1
            #     else:
            #         if plms_noise_stage >= 3:
            #             noise_list.pop(0)
            #         else:
            #             plms_noise_stage = plms_noise_stage + 1
            #         noise_list.append(noise_pred)
        else:
            for t in step_range:
                x = self.p_sample(x, t, cond=condition)

        if self.num_feats == 1:
            x = x.squeeze(1).permute(0, 2, 1)  # [B, 1, M, T] => [B, T, M]
        else:
            x = x.permute(0, 1, 3, 2)  # [B, F, M, T] => [B, F, T, M]
        x = self.denorm_spec(x)
        return x
|
162 |
+
|
163 |
+
|
164 |
+
class PitchDiffusionONNX(GaussianDiffusionONNX, PitchDiffusion):
    """Deployment variant of ``PitchDiffusion``.

    ``vmin``/``vmax``/``cmin``/``cmax`` are stored on the instance and the
    remaining construction is delegated to the class that follows
    ``PitchDiffusion`` in the MRO (``PitchDiffusion.__init__`` itself is
    bypassed).
    """

    def __init__(self, vmin: float, vmax: float,
                 cmin: float, cmax: float, repeat_bins,
                 timesteps=1000, k_step=1000,
                 backbone_type=None, backbone_args=None,
                 betas=None):
        self.vmin, self.vmax = vmin, vmax
        self.cmin, self.cmax = cmin, cmax
        # deliberately start the super() lookup *after* PitchDiffusion
        super(PitchDiffusion, self).__init__(
            vmin=vmin,
            vmax=vmax,
            repeat_bins=repeat_bins,
            timesteps=timesteps,
            k_step=k_step,
            backbone_type=backbone_type,
            backbone_args=backbone_args,
            betas=betas,
        )

    def clamp_spec(self, x):
        """Clamp ``x`` to ``[cmin, cmax]``."""
        return x.clamp(min=self.cmin, max=self.cmax)

    def denorm_spec(self, x):
        """Map ``x`` back to the ``[spec_min, spec_max]`` range and average the last axis."""
        half_span = (self.spec_max - self.spec_min) / 2.
        midpoint = (self.spec_max + self.spec_min) / 2.
        restored = x * half_span + midpoint
        return restored.mean(dim=-1)
|
190 |
+
|
191 |
+
|
192 |
+
class MultiVarianceDiffusionONNX(GaussianDiffusionONNX, MultiVarianceDiffusion):
    """Deployment variant of ``MultiVarianceDiffusion``.

    ``ranges`` is split into lower and upper bounds that are forwarded as
    ``vmin``/``vmax`` (plain scalars when there is exactly one range).
    ``clamps`` is only stored on the instance by this class.
    ``MultiVarianceDiffusion.__init__`` itself is bypassed.
    """

    def __init__(
            self, ranges: List[Tuple[float, float]],
            clamps: List[Tuple[float | None, float | None] | None],
            repeat_bins, timesteps=1000, k_step=1000,
            backbone_type=None, backbone_args=None,
            betas=None
    ):
        assert len(ranges) == len(clamps)
        self.clamps = clamps
        lower_bounds = [bounds[0] for bounds in ranges]
        upper_bounds = [bounds[1] for bounds in ranges]
        if len(ranges) == 1:
            # a single variance is described by scalars rather than 1-element lists
            vmin, vmax = lower_bounds[0], upper_bounds[0]
        else:
            vmin, vmax = lower_bounds, upper_bounds
        # deliberately start the super() lookup *after* MultiVarianceDiffusion
        super(MultiVarianceDiffusion, self).__init__(
            vmin=vmin,
            vmax=vmax,
            repeat_bins=repeat_bins,
            timesteps=timesteps,
            k_step=k_step,
            backbone_type=backbone_type,
            backbone_args=backbone_args,
            betas=betas,
        )

    def denorm_spec(self, x):
        """Map ``x`` back to the ``[spec_min, spec_max]`` range and average the last axis."""
        half_span = (self.spec_max - self.spec_min) / 2.
        midpoint = (self.spec_max + self.spec_min) / 2.
        restored = x * half_span + midpoint
        return restored.mean(dim=-1)
|
deployment/modules/fastspeech2.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from modules.commons.common_layers import NormalInitEmbedding as Embedding
|
9 |
+
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
|
10 |
+
from modules.fastspeech.variance_encoder import FastSpeech2Variance
|
11 |
+
from utils.hparams import hparams
|
12 |
+
from utils.text_encoder import PAD_INDEX
|
13 |
+
|
14 |
+
# Quantization settings for the discrete f0 embedding.
f0_bin = 256  # number of coarse pitch bins
f0_max = 1100.0  # upper f0 bound
f0_min = 50.0  # lower f0 bound
# The same bounds after the ``1127 * ln(1 + f / 700)`` transform used below.
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)


def f0_to_coarse(f0):
    """Quantize an f0 tensor into integer bins in ``[1, f0_bin - 1]``.

    ``f0`` is transformed with ``1127 * ln(1 + f0 / 700)``; positive values
    are rescaled linearly so that ``f0_min`` maps to 1 and ``f0_max`` maps to
    ``f0_bin - 1``, then everything is clipped to that range and rounded.
    Non-positive transformed values (e.g. ``f0 == 0``) end up in bin 1.

    :param f0: floating point tensor of f0 values
    :return: ``torch.long`` tensor with the same shape as ``f0``
    """
    mel = 1127 * (1 + f0 / 700).log()
    scale = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
    offset = f0_mel_min * scale - 1.
    rescaled = torch.where(mel > 0, mel * scale - offset, mel)
    # in-place clip is safe: ``rescaled`` is a fresh tensor produced by ``torch.where``
    torch.clip_(rescaled, min=1., max=float(f0_bin - 1))
    return torch.round(rescaled).long()
|
29 |
+
|
30 |
+
|
31 |
+
class LengthRegulator(nn.Module):
    """Expand per-token durations into a per-frame, 1-based token index map."""

    # noinspection PyMethodMayBeStatic
    def forward(self, dur):
        """Return, for every frame, the 1-based index of the token covering it.

        :param dur: ``[B, n_tokens]`` tensor of frame counts per token
        :return: ``[B, n_frames]`` tensor where ``n_frames`` is the largest
            total duration in the batch; frames not covered by any token
            hold 0
        """
        n_tokens = dur.shape[1]
        token_ids = torch.arange(1, n_tokens + 1, device=dur.device)[None, :, None]
        # each token occupies the half-open frame interval [span_start, span_end)
        span_end = torch.cumsum(dur, dim=1)
        span_start = F.pad(span_end, (1, -1), mode='constant', value=0)
        frame_ids = torch.arange(dur.sum(dim=1).max(), device=dur.device)[None, None]
        covers = (frame_ids >= span_start[:, :, None]) & (frame_ids < span_end[:, :, None])
        # at most one token covers a frame, so the sum selects its index
        return (token_ids * covers).sum(dim=1)
|
41 |
+
|
42 |
+
|
43 |
+
class FastSpeech2AcousticONNX(FastSpeech2Acoustic):
    """Deployment variant of ``FastSpeech2Acoustic`` producing the frame-level condition.

    Durations are expanded to frames with ``LengthRegulator`` and the
    optional pitch / variance / key-shift / speed / speaker embeddings are
    added according to ``hparams`` and to ``frozen_*`` attributes that may
    have been attached to the instance.
    """

    def __init__(self, vocab_size):
        super().__init__(vocab_size=vocab_size)

        # for temporary compatibility; will be completely removed in the future
        self.f0_embed_type = hparams.get('f0_embed_type', 'continuous')
        if self.f0_embed_type == 'discrete':
            self.pitch_embed = Embedding(300, hparams['hidden_size'], PAD_INDEX)

        self.lr = LengthRegulator()
        # remember the augmentation ranges needed to map user controls at inference time
        augmentation_args = hparams['augmentation_args'] if (
            hparams['use_key_shift_embed'] or hparams['use_speed_embed']
        ) else None
        if hparams['use_key_shift_embed']:
            self.shift_min, self.shift_max = augmentation_args['random_pitch_shifting']['range']
        if hparams['use_speed_embed']:
            self.speed_min, self.speed_max = augmentation_args['random_time_stretching']['range']

    # noinspection PyMethodOverriding
    def forward(self, tokens, durations, f0, variances: dict, gender=None, velocity=None, spk_embed=None):
        """Compute the conditioning tensor for the acoustic decoder.

        :param tokens: token ids; ids that are not positive get zero duration
        :param durations: per-token frame counts
        :param f0: per-frame f0; zeroed on frames not covered by any token
        :param variances: mapping from variance name to its per-frame tensor
        :param gender: key-shift control, clipped to ``[-1, 1]``; used only
            when key-shift embedding is enabled and no ``frozen_key_shift``
            attribute exists
        :param velocity: speed control, clipped to ``[speed_min, speed_max]``;
            a constant 1.0 is used when it is ``None``
        :param spk_embed: speaker embedding added when speaker ids are
            enabled and no ``frozen_spk_embed`` attribute exists
        :return: the accumulated condition tensor
        """
        hidden_size = hparams['hidden_size']

        txt_embed = self.txt_embed(tokens)
        durations = durations * (tokens > 0)
        frame2token = self.lr(durations)
        f0 = f0 * (frame2token > 0)
        gather_index = frame2token[..., None].repeat((1, 1, hidden_size))
        dur_embed = self.dur_embed(durations.float()[:, :, None])
        encoded = self.encoder(txt_embed, dur_embed, tokens == PAD_INDEX)
        # prepend a zero step so that index 0 (uncovered frames) gathers zeros
        encoded = F.pad(encoded, (0, 0, 1, 0))
        condition = torch.gather(encoded, 1, gather_index)

        # pitch
        if self.f0_embed_type == 'discrete':
            coarse_pitch = f0_to_coarse(f0)
            pitch_embed = self.pitch_embed(coarse_pitch)
        else:
            log_f0 = (1 + f0 / 700).log()
            pitch_embed = self.pitch_embed(log_f0[:, :, None])
        condition += pitch_embed

        # variance curves
        if self.use_variance_embeds:
            per_variance = [
                self.variance_embeds[v_name](variances[v_name][:, :, None])
                for v_name in self.variance_embed_list
            ]
            condition += torch.stack(per_variance, dim=-1).sum(-1)

        # key shift
        if hparams['use_key_shift_embed']:
            if hasattr(self, 'frozen_key_shift'):
                key_shift_embed = self.key_shift_embed(self.frozen_key_shift[:, None, None])
            else:
                gender = torch.clip(gender, min=-1., max=1.)
                # negative values scale by |shift_min|, non-negative ones by shift_max
                negative = (gender < 0.).float()
                key_shift = gender * ((1. - negative) * self.shift_max + negative * abs(self.shift_min))
                key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
            condition += key_shift_embed

        # speed
        if hparams['use_speed_embed']:
            if velocity is not None:
                velocity = torch.clip(velocity, min=self.speed_min, max=self.speed_max)
                speed_embed = self.speed_embed(velocity[:, :, None])
            else:
                unit_speed = torch.FloatTensor([1.]).to(condition.device)
                speed_embed = self.speed_embed(unit_speed[:, None, None])
            condition += speed_embed

        # speaker
        if hparams['use_spk_id']:
            if hasattr(self, 'frozen_spk_embed'):
                condition += self.frozen_spk_embed
            else:
                condition += spk_embed
        return condition
|
109 |
+
|
110 |
+
|
111 |
+
class FastSpeech2VarianceONNX(FastSpeech2Variance):
    """Deployment variant of ``FastSpeech2Variance`` split into exportable views.

    ``view_as_encoder`` and ``view_as_dur_predictor`` return deep copies of
    the module whose ``forward`` is rebound to one of the ``forward_*``
    methods and whose unused sub-module is removed.
    """

    def __init__(self, vocab_size):
        super().__init__(vocab_size=vocab_size)
        self.lr = LengthRegulator()

    def forward_encoder_word(self, tokens, word_div, word_dur):
        """Encode tokens using word-level timing.

        :param tokens: token ids
        :param word_div: number of tokens per word, expanded by
            ``LengthRegulator`` into a per-token, 1-based word index
        :param word_dur: per-word durations, gathered per token
        :return: ``(encoder_out, x_masks)`` where ``x_masks`` marks
            ``PAD_INDEX`` tokens
        """
        txt_embed = self.txt_embed(tokens)
        token2word = self.lr(word_div)
        # a token is an onset when its word index exceeds that of the previous token
        is_onset = token2word > F.pad(token2word, [1, -1])
        onset_embed = self.onset_embed(is_onset.long())
        # index 0 of the padded durations serves tokens without a word
        padded_word_dur = F.pad(word_dur, [1, 0])
        token_word_dur = torch.gather(padded_word_dur, 1, token2word)
        word_dur_embed = self.word_dur_embed(token_word_dur.float()[:, :, None])
        x_masks = tokens == PAD_INDEX
        encoder_out = self.encoder(txt_embed, onset_embed + word_dur_embed, x_masks)
        return encoder_out, x_masks

    def forward_encoder_phoneme(self, tokens, ph_dur):
        """Encode tokens using per-token durations.

        :return: ``(encoder_out, x_masks)`` where ``x_masks`` marks
            ``PAD_INDEX`` tokens
        """
        txt_embed = self.txt_embed(tokens)
        ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
        x_masks = tokens == PAD_INDEX
        encoder_out = self.encoder(txt_embed, ph_dur_embed, x_masks)
        return encoder_out, x_masks

    def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
        """Predict per-token durations from the encoder output and MIDI pitch.

        ``spk_embed`` is added to the condition only when speaker ids are
        enabled in ``hparams`` and an embedding is supplied.
        """
        dur_cond = encoder_out + self.midi_embed(ph_midi)
        if hparams['use_spk_id'] and spk_embed is not None:
            dur_cond += spk_embed
        return self.dur_predictor(dur_cond, x_masks=x_masks)

    def view_as_encoder(self):
        """Return a deep copy whose ``forward`` runs only the encoder.

        When ``self.predict_dur`` is set, the duration predictor is removed
        from the copy and the word-level encoder is used; otherwise the
        phoneme-level encoder is used.
        """
        view = copy.deepcopy(self)
        if self.predict_dur:
            del view.dur_predictor
            view.forward = view.forward_encoder_word
        else:
            view.forward = view.forward_encoder_phoneme
        return view

    def view_as_dur_predictor(self):
        """Return a deep copy without the encoder whose ``forward`` predicts durations."""
        view = copy.deepcopy(self)
        del view.encoder
        view.forward = view.forward_dur_predictor
        return view
|
deployment/modules/nsf_hifigan.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from modules.nsf_hifigan.env import AttrDict
|
4 |
+
from modules.nsf_hifigan.models import Generator
|
5 |
+
|
6 |
+
|
7 |
+
# noinspection SpellCheckingInspection
|
8 |
+
class NSFHiFiGANONNX(torch.nn.Module):
    """Thin wrapper around the NSF-HiFiGAN ``Generator``."""

    def __init__(self, attrs: dict):
        """Build the generator from ``attrs`` wrapped in an ``AttrDict``."""
        super().__init__()
        generator_config = AttrDict(attrs)
        self.generator = Generator(generator_config)

    def forward(self, mel: torch.Tensor, f0: torch.Tensor):
        """Run the generator on ``mel`` (axes 1 and 2 swapped first) and ``f0``.

        :return: the generator output with dimension 1 squeezed out
        """
        generated = self.generator(mel.transpose(1, 2), f0)
        return generated.squeeze(1)
|
deployment/modules/rectified_flow.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from modules.core import (
|
8 |
+
RectifiedFlow, PitchRectifiedFlow, MultiVarianceRectifiedFlow
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
class RectifiedFlowONNX(RectifiedFlow):
    """Deployment variant of ``RectifiedFlow`` using fixed-step Euler integration.

    ``forward`` draws noise with a leading dimension of 1, so the class is
    written for a single sequence.
    """

    @property
    def backbone(self):
        """The velocity network stored in ``self.velocity_fn``."""
        return self.velocity_fn

    # We give up the setter for the property `backbone` because this will cause TorchScript to fail
    # @backbone.setter
    @torch.jit.unused
    def set_backbone(self, value):
        """Replace the velocity network (explicit setter, see comment above)."""
        self.velocity_fn = value

    def sample_euler(self, x, t, dt: float, cond):
        """Advance ``x`` by one Euler step of size ``dt``.

        ``x`` is updated in place and also returned. The backbone receives
        the time multiplied by ``self.time_scale_factor``.
        """
        velocity = self.velocity_fn(x, t * self.time_scale_factor, cond)
        x += velocity * dt
        return x

    def norm_spec(self, x):
        """Affinely map ``[spec_min, spec_max]`` onto ``[-1, 1]``."""
        half_span = (self.spec_max - self.spec_min) / 2.
        midpoint = (self.spec_max + self.spec_min) / 2.
        return (x - midpoint) / half_span

    def denorm_spec(self, x):
        """Inverse of ``norm_spec``: map ``[-1, 1]`` back onto ``[spec_min, spec_max]``."""
        half_span = (self.spec_max - self.spec_min) / 2.
        midpoint = (self.spec_max + self.spec_min) / 2.
        return x * half_span + midpoint

    def forward(self, condition, x_end=None, depth=None, steps: int = 10):
        """Integrate the flow up to t = 1 and return the de-normalized result.

        :param condition: conditioning tensor, ``[1, T, H]``
        :param x_end: optional target estimate; when given, integration
            starts from a blend of it and noise instead of pure noise and
            ``depth`` is required
        :param depth: the start time is ``max(1 - depth, self.t_start)``
        :param steps: number of Euler steps
        """
        condition = condition.transpose(1, 2)  # [1, T, H] => [1, H, T]
        device = condition.device
        n_frames = condition.shape[2]
        noise = torch.randn((1, self.num_feats, self.out_dims, n_frames), device=device)
        if x_end is None:
            # full sampling starts from pure noise at t = 0
            t_start = 0.
            x = noise
        else:
            # shallow sampling: never start earlier than the configured ``self.t_start``
            t_start = torch.max(1 - depth, torch.tensor(self.t_start, dtype=torch.float32, device=device))
            x_end = self.norm_spec(x_end).transpose(-2, -1)
            if self.num_feats == 1:
                x_end = x_end[:, None, :, :]
            if t_start <= 0.:
                x = noise
            elif t_start >= 1.:
                x = x_end
            else:
                # linear interpolation between noise (t = 0) and ``x_end`` (t = 1)
                x = t_start * x_end + (1 - t_start) * noise

        t_width = 1. - t_start
        if t_width >= 0.:
            dt = t_width / max(1, steps)
            # step times: t_start, t_start + dt, ..., t_start + (steps - 1) * dt
            for t in torch.arange(steps, dtype=torch.long, device=device)[:, None].float() * dt + t_start:
                x = self.sample_euler(x, t, dt, condition)

        if self.num_feats == 1:
            x = x.squeeze(1).permute(0, 2, 1)  # [B, 1, M, T] => [B, T, M]
        else:
            x = x.permute(0, 1, 3, 2)  # [B, F, M, T] => [B, F, T, M]
        x = self.denorm_spec(x)
        return x
|
69 |
+
|
70 |
+
|
71 |
+
class PitchRectifiedFlowONNX(RectifiedFlowONNX, PitchRectifiedFlow):
    """Deployment variant of ``PitchRectifiedFlow``.

    ``vmin``/``vmax``/``cmin``/``cmax`` are stored on the instance and the
    remaining construction is delegated to the class that follows
    ``PitchRectifiedFlow`` in the MRO (``PitchRectifiedFlow.__init__`` itself
    is bypassed).
    """

    def __init__(self, vmin: float, vmax: float,
                 cmin: float, cmax: float, repeat_bins,
                 time_scale_factor=1000,
                 backbone_type=None, backbone_args=None):
        self.vmin, self.vmax = vmin, vmax
        self.cmin, self.cmax = cmin, cmax
        # deliberately start the super() lookup *after* PitchRectifiedFlow
        super(PitchRectifiedFlow, self).__init__(
            vmin=vmin,
            vmax=vmax,
            repeat_bins=repeat_bins,
            time_scale_factor=time_scale_factor,
            backbone_type=backbone_type,
            backbone_args=backbone_args,
        )

    def clamp_spec(self, x):
        """Clamp ``x`` to ``[cmin, cmax]``."""
        return x.clamp(min=self.cmin, max=self.cmax)

    def denorm_spec(self, x):
        """Map ``x`` back to the ``[spec_min, spec_max]`` range and average the last axis."""
        half_span = (self.spec_max - self.spec_min) / 2.
        midpoint = (self.spec_max + self.spec_min) / 2.
        restored = x * half_span + midpoint
        return restored.mean(dim=-1)
|
95 |
+
|
96 |
+
|
97 |
+
class MultiVarianceRectifiedFlowONNX(RectifiedFlowONNX, MultiVarianceRectifiedFlow):
    """Deployment variant of ``MultiVarianceRectifiedFlow``.

    ``ranges`` is split into lower and upper bounds that are forwarded as
    ``vmin``/``vmax`` (plain scalars when there is exactly one range).
    ``clamps`` is only stored on the instance by this class.
    ``MultiVarianceRectifiedFlow.__init__`` itself is bypassed.
    """

    def __init__(
            self, ranges: List[Tuple[float, float]],
            clamps: List[Tuple[float | None, float | None] | None],
            repeat_bins, time_scale_factor=1000,
            backbone_type=None, backbone_args=None
    ):
        assert len(ranges) == len(clamps)
        self.clamps = clamps
        lower_bounds = [bounds[0] for bounds in ranges]
        upper_bounds = [bounds[1] for bounds in ranges]
        if len(ranges) == 1:
            # a single variance is described by scalars rather than 1-element lists
            vmin, vmax = lower_bounds[0], upper_bounds[0]
        else:
            vmin, vmax = lower_bounds, upper_bounds
        # deliberately start the super() lookup *after* MultiVarianceRectifiedFlow
        super(MultiVarianceRectifiedFlow, self).__init__(
            vmin=vmin,
            vmax=vmax,
            repeat_bins=repeat_bins,
            time_scale_factor=time_scale_factor,
            backbone_type=backbone_type,
            backbone_args=backbone_args,
        )

    def denorm_spec(self, x):
        """Map ``x`` back to the ``[spec_min, spec_max]`` range and average the last axis."""
        half_span = (self.spec_max - self.spec_min) / 2.
        midpoint = (self.spec_max + self.spec_min) / 2.
        restored = x * half_span + midpoint
        return restored.mean(dim=-1)
|
deployment/modules/toplevel.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import Tensor
|
8 |
+
|
9 |
+
from deployment.modules.diffusion import (
|
10 |
+
GaussianDiffusionONNX, PitchDiffusionONNX, MultiVarianceDiffusionONNX
|
11 |
+
)
|
12 |
+
from deployment.modules.rectified_flow import (
|
13 |
+
RectifiedFlowONNX, PitchRectifiedFlowONNX, MultiVarianceRectifiedFlowONNX
|
14 |
+
)
|
15 |
+
from deployment.modules.fastspeech2 import FastSpeech2AcousticONNX, FastSpeech2VarianceONNX
|
16 |
+
from modules.toplevel import DiffSingerAcoustic, DiffSingerVariance
|
17 |
+
from utils.hparams import hparams
|
18 |
+
|
19 |
+
|
20 |
+
class DiffSingerAcousticONNX(DiffSingerAcoustic):
|
21 |
+
def __init__(self, vocab_size, out_dims):
|
22 |
+
super().__init__(vocab_size, out_dims)
|
23 |
+
del self.fs2
|
24 |
+
del self.diffusion
|
25 |
+
self.fs2 = FastSpeech2AcousticONNX(
|
26 |
+
vocab_size=vocab_size
|
27 |
+
)
|
28 |
+
if self.diffusion_type == 'ddpm':
|
29 |
+
self.diffusion = GaussianDiffusionONNX(
|
30 |
+
out_dims=out_dims,
|
31 |
+
num_feats=1,
|
32 |
+
timesteps=hparams['timesteps'],
|
33 |
+
k_step=hparams['K_step'],
|
34 |
+
backbone_type=self.backbone_type,
|
35 |
+
backbone_args=self.backbone_args,
|
36 |
+
spec_min=hparams['spec_min'],
|
37 |
+
spec_max=hparams['spec_max']
|
38 |
+
)
|
39 |
+
elif self.diffusion_type == 'reflow':
|
40 |
+
self.diffusion = RectifiedFlowONNX(
|
41 |
+
out_dims=out_dims,
|
42 |
+
num_feats=1,
|
43 |
+
t_start=hparams['T_start'],
|
44 |
+
time_scale_factor=hparams['time_scale_factor'],
|
45 |
+
backbone_type=self.backbone_type,
|
46 |
+
backbone_args=self.backbone_args,
|
47 |
+
spec_min=hparams['spec_min'],
|
48 |
+
spec_max=hparams['spec_max']
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
|
52 |
+
self.mel_base = hparams.get('mel_base', '10')
|
53 |
+
|
54 |
+
def ensure_mel_base(self, mel):
|
55 |
+
if self.mel_base != 'e':
|
56 |
+
# log10 mel to log mel
|
57 |
+
mel = mel * 2.30259
|
58 |
+
return mel
|
59 |
+
|
60 |
+
def forward_fs2_aux(
|
61 |
+
self,
|
62 |
+
tokens: Tensor,
|
63 |
+
durations: Tensor,
|
64 |
+
f0: Tensor,
|
65 |
+
variances: dict,
|
66 |
+
gender: Tensor = None,
|
67 |
+
velocity: Tensor = None,
|
68 |
+
spk_embed: Tensor = None
|
69 |
+
):
|
70 |
+
condition = self.fs2(
|
71 |
+
tokens, durations, f0, variances=variances,
|
72 |
+
gender=gender, velocity=velocity, spk_embed=spk_embed
|
73 |
+
)
|
74 |
+
if self.use_shallow_diffusion:
|
75 |
+
aux_mel_pred = self.aux_decoder(condition, infer=True)
|
76 |
+
return condition, aux_mel_pred
|
77 |
+
else:
|
78 |
+
return condition
|
79 |
+
|
80 |
+
def forward_shallow_diffusion(
|
81 |
+
self, condition: Tensor, x_start: Tensor,
|
82 |
+
depth, steps: int
|
83 |
+
) -> Tensor:
|
84 |
+
mel_pred = self.diffusion(condition, x_start=x_start, depth=depth, steps=steps)
|
85 |
+
return self.ensure_mel_base(mel_pred)
|
86 |
+
|
87 |
+
def forward_diffusion(self, condition: Tensor, steps: int):
|
88 |
+
mel_pred = self.diffusion(condition, steps=steps)
|
89 |
+
return self.ensure_mel_base(mel_pred)
|
90 |
+
|
91 |
+
def forward_shallow_reflow(
|
92 |
+
self, condition: Tensor, x_end: Tensor,
|
93 |
+
depth, steps: int
|
94 |
+
):
|
95 |
+
mel_pred = self.diffusion(condition, x_end=x_end, depth=depth, steps=steps)
|
96 |
+
return self.ensure_mel_base(mel_pred)
|
97 |
+
|
98 |
+
def forward_reflow(self, condition: Tensor, steps: int):
    """Call ``self.diffusion`` with ``steps`` and rescale the result via ``ensure_mel_base``."""
    return self.ensure_mel_base(self.diffusion(condition, steps=steps))
|
101 |
+
|
102 |
+
def view_as_fs2_aux(self) -> nn.Module:
    """Return a deep copy without ``diffusion`` whose ``forward`` is ``forward_fs2_aux``.

    The receiver itself is left unmodified.
    """
    view = copy.deepcopy(self)
    delattr(view, 'diffusion')
    view.forward = view.forward_fs2_aux
    return view
|
107 |
+
|
108 |
+
def view_as_diffusion(self) -> nn.Module:
    """Return a deep copy reduced to the diffusion stage.

    ``fs2`` is always removed. With shallow diffusion the auxiliary decoder
    is removed too and ``forward`` becomes ``forward_shallow_diffusion``;
    otherwise ``forward`` becomes ``forward_diffusion``.
    """
    view = copy.deepcopy(self)
    delattr(view, 'fs2')
    shallow = self.use_shallow_diffusion
    if shallow:
        delattr(view, 'aux_decoder')
    view.forward = view.forward_shallow_diffusion if shallow else view.forward_diffusion
    return view
|
117 |
+
|
118 |
+
def view_as_reflow(self) -> nn.Module:
    """Return a deep copy reduced to the rectified-flow stage.

    ``fs2`` is always removed. With shallow diffusion the auxiliary decoder
    is removed too and ``forward`` becomes ``forward_shallow_reflow``;
    otherwise ``forward`` becomes ``forward_reflow``.
    """
    view = copy.deepcopy(self)
    delattr(view, 'fs2')
    shallow = self.use_shallow_diffusion
    if shallow:
        delattr(view, 'aux_decoder')
    view.forward = view.forward_shallow_reflow if shallow else view.forward_reflow
    return view
|
127 |
+
|
128 |
+
|
129 |
+
class DiffSingerVarianceONNX(DiffSingerVariance):
    """``DiffSingerVariance`` with its sub-modules swapped for ``*ONNX`` classes.

    The ``forward_*`` methods each expose one stage of the model, and the
    ``view_as_*`` methods return deep copies of the model with the sub-modules
    that stage does not use deleted and ``forward`` rebound to that stage.
    """

    def __init__(self, vocab_size):
        """Build the parent model, then replace ``fs2`` and the enabled predictors.

        Raises:
            ValueError: pitch prediction is enabled and ``self.diffusion_type``
                is neither ``'ddpm'`` nor ``'reflow'``.
            NotImplementedError: variance prediction is enabled and
                ``self.diffusion_type`` is neither ``'ddpm'`` nor ``'reflow'``.
        """
        super().__init__(vocab_size=vocab_size)
        del self.fs2
        self.fs2 = FastSpeech2VarianceONNX(
            vocab_size=vocab_size
        )
        self.hidden_size = hparams['hidden_size']
        if self.predict_pitch:
            del self.pitch_predictor
            # Populated later by build_smooth_op().
            self.smooth: nn.Conv1d = None
            pitch_hparams = hparams['pitch_prediction_args']
            if self.diffusion_type == 'ddpm':
                self.pitch_predictor = PitchDiffusionONNX(
                    vmin=pitch_hparams['pitd_norm_min'],
                    vmax=pitch_hparams['pitd_norm_max'],
                    cmin=pitch_hparams['pitd_clip_min'],
                    cmax=pitch_hparams['pitd_clip_max'],
                    repeat_bins=pitch_hparams['repeat_bins'],
                    timesteps=hparams['timesteps'],
                    k_step=hparams['K_step'],
                    backbone_type=self.pitch_backbone_type,
                    backbone_args=self.pitch_backbone_args
                )
            elif self.diffusion_type == 'reflow':
                self.pitch_predictor = PitchRectifiedFlowONNX(
                    vmin=pitch_hparams['pitd_norm_min'],
                    vmax=pitch_hparams['pitd_norm_max'],
                    cmin=pitch_hparams['pitd_clip_min'],
                    cmax=pitch_hparams['pitd_clip_max'],
                    repeat_bins=pitch_hparams['repeat_bins'],
                    time_scale_factor=hparams['time_scale_factor'],
                    backbone_type=self.pitch_backbone_type,
                    backbone_args=self.pitch_backbone_args
                )
            else:
                raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
        if self.predict_variances:
            del self.variance_predictor
            if self.diffusion_type == 'ddpm':
                self.variance_predictor = self.build_adaptor(cls=MultiVarianceDiffusionONNX)
            elif self.diffusion_type == 'reflow':
                self.variance_predictor = self.build_adaptor(cls=MultiVarianceRectifiedFlowONNX)
            else:
                raise NotImplementedError(self.diffusion_type)

    def build_smooth_op(self, device):
        """Create the pitch-smoothing ``Conv1d`` and store it on ``self.smooth``.

        The kernel is a half-period sine window normalised to sum to 1. Its
        length is ``round(midi_smooth_width * audio_sample_rate / hop_size)``
        taken from ``hparams``, clamped to at least 1.
        """
        smooth_kernel_size = round(hparams['midi_smooth_width'] * hparams['audio_sample_rate'] / hparams['hop_size'])
        # A width that rounds to 0 cannot form a valid convolution; treat it
        # like width 1, i.e. no smoothing.
        smooth_kernel_size = max(1, smooth_kernel_size)
        smooth = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=smooth_kernel_size,
            bias=False,
            padding='same',
            padding_mode='replicate'
        ).eval()
        if smooth_kernel_size == 1:
            # The sine window of length 1 is [sin(0)] == [0]; normalising it
            # would divide 0 by 0 and produce NaN weights. Use identity instead.
            smooth_kernel = torch.ones(1, dtype=torch.float32)
        else:
            smooth_kernel = torch.sin(torch.from_numpy(
                np.linspace(0, 1, smooth_kernel_size).astype(np.float32) * np.pi
            ))
            smooth_kernel /= smooth_kernel.sum()
        smooth.weight.data = smooth_kernel[None, None]
        self.smooth = smooth.to(device)

    def embed_frozen_spk(self, encoder_out):
        """Add ``self.frozen_spk_embed`` to ``encoder_out`` in place when available.

        The addition only happens if ``hparams['use_spk_id']`` is truthy and
        the attribute ``frozen_spk_embed`` exists on this instance.
        """
        if hparams['use_spk_id'] and hasattr(self, 'frozen_spk_embed'):
            encoder_out += self.frozen_spk_embed
        return encoder_out

    def forward_linguistic_encoder_word(self, tokens, word_div, word_dur):
        """Run ``fs2.forward_encoder_word`` and apply the frozen speaker embedding."""
        encoder_out, x_masks = self.fs2.forward_encoder_word(tokens, word_div, word_dur)
        encoder_out = self.embed_frozen_spk(encoder_out)
        return encoder_out, x_masks

    def forward_linguistic_encoder_phoneme(self, tokens, ph_dur):
        """Run ``fs2.forward_encoder_phoneme`` and apply the frozen speaker embedding."""
        encoder_out, x_masks = self.fs2.forward_encoder_phoneme(tokens, ph_dur)
        encoder_out = self.embed_frozen_spk(encoder_out)
        return encoder_out, x_masks

    def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
        """Delegate to ``fs2.forward_dur_predictor`` with the same arguments."""
        return self.fs2.forward_dur_predictor(encoder_out, x_masks, ph_midi, spk_embed=spk_embed)

    def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):
        """Gather entries of ``x_src`` along dim 1 using indices from ``self.lr(x_dur)``.

        ``x_src`` is left-padded by one position along the gathered axis, so
        index 0 selects the zero padding. When ``x_dim`` is given, the indices
        are repeated ``x_dim`` times along a new trailing axis so that whole
        feature vectors are gathered.
        """
        mel2x = self.lr(x_dur)
        if x_dim is not None:
            x_src = F.pad(x_src, [0, 0, 1, 0])
            mel2x = mel2x[..., None].repeat([1, 1, x_dim])
        else:
            x_src = F.pad(x_src, [1, 0])
        x_cond = torch.gather(x_src, 1, mel2x)
        return x_cond

    def forward_pitch_preprocess(
            self, encoder_out, ph_dur,
            note_midi=None, note_rest=None, note_dur=None, note_glide=None,
            pitch=None, expr=None, retake=None, spk_embed=None
    ):
        """Build the pitch predictor condition and the smoothed base pitch.

        Requires ``self.smooth`` to have been created by ``build_smooth_op``.

        Returns:
            ``(pitch_cond, base_pitch)``.
        """
        condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
        if self.use_melody_encoder:
            if self.melody_encoder.use_glide_embed and note_glide is None:
                # No glide given: fall back to a single zero glide index.
                note_glide = torch.LongTensor([[0]]).to(encoder_out.device)
            melody_encoder_out = self.melody_encoder(
                note_midi, note_rest, note_dur,
                glide=note_glide
            )
            melody_encoder_out = self.forward_mel2x_gather(melody_encoder_out, note_dur, x_dim=self.hidden_size)
            condition += melody_encoder_out
        if expr is None:
            retake_embed = self.pitch_retake_embed(retake.long())
        else:
            retake_true_embed = self.pitch_retake_embed(
                torch.ones(1, 1, dtype=torch.long, device=encoder_out.device)
            )  # [B=1, T=1] => [B=1, T=1, H]
            retake_false_embed = self.pitch_retake_embed(
                torch.zeros(1, 1, dtype=torch.long, device=encoder_out.device)
            )  # [B=1, T=1] => [B=1, T=1, H]
            expr = (expr * retake)[:, :, None]  # [B, T, 1]
            # Blend the "retake" and "keep" embeddings, weighted by expr.
            retake_embed = expr * retake_true_embed + (1. - expr) * retake_false_embed
        pitch_cond = condition + retake_embed
        frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None)
        base_pitch = self.smooth(frame_midi_pitch)
        if self.use_melody_encoder:
            # Only frames that are NOT retaken contribute their known delta pitch.
            delta_pitch = (pitch - base_pitch) * ~retake
            pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None])
        else:
            # Retaken frames use the smoothed score pitch, the rest keep the given pitch.
            base_pitch = base_pitch * retake + pitch * ~retake
            pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
        if hparams['use_spk_id'] and spk_embed is not None:
            pitch_cond += spk_embed
        return pitch_cond, base_pitch

    def forward_pitch_reflow(
            self, pitch_cond, steps: int = 10
    ):
        """Run ``self.pitch_predictor`` on ``pitch_cond`` for ``steps`` steps."""
        x_pred = self.pitch_predictor(pitch_cond, steps=steps)
        return x_pred

    def forward_pitch_postprocess(self, x_pred, base_pitch):
        """Clamp ``x_pred`` with the pitch predictor's ``clamp_spec`` and add ``base_pitch``."""
        pitch_pred = self.pitch_predictor.clamp_spec(x_pred) + base_pitch
        return pitch_pred

    def forward_variance_preprocess(
            self, encoder_out, ph_dur, pitch,
            variances: dict = None, retake=None, spk_embed=None
    ):
        """Build the condition for the variance predictor.

        NOTE(review): ``variances`` and ``retake`` default to ``None`` but are
        used unconditionally below, so callers apparently must supply both.
        """
        condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
        variance_cond = condition + self.pitch_embed(pitch[:, :, None])
        non_retake_masks = [
            v_retake.float()  # [B, T, 1]
            for v_retake in (~retake).split(1, dim=2)
        ]
        # Each variance embedding is zeroed on the frames being retaken.
        variance_embeds = [
            self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks
            for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks)
        ]
        variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
        if hparams['use_spk_id'] and spk_embed is not None:
            variance_cond += spk_embed
        return variance_cond

    def forward_variance_reflow(self, variance_cond, steps: int = 10):
        """Run ``self.variance_predictor`` on ``variance_cond`` for ``steps`` steps."""
        xs_pred = self.variance_predictor(variance_cond, steps=steps)
        return xs_pred

    def forward_variance_postprocess(self, xs_pred):
        """Split ``xs_pred`` per predicted feature, clamp, and return a tuple."""
        if self.variance_predictor.num_feats == 1:
            xs_pred = [xs_pred]
        else:
            xs_pred = xs_pred.unbind(dim=1)
        variance_pred = self.variance_predictor.clamp_spec(xs_pred)
        return tuple(variance_pred)

    def view_as_linguistic_encoder(self):
        """Return a deep copy reduced to the linguistic encoder.

        ``forward`` takes word-level inputs when ``self.predict_dur`` is set,
        phoneme-level inputs otherwise.
        """
        model = copy.deepcopy(self)
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.fs2 = model.fs2.view_as_encoder()
        if self.predict_dur:
            model.forward = model.forward_linguistic_encoder_word
        else:
            model.forward = model.forward_linguistic_encoder_phoneme
        return model

    def view_as_dur_predictor(self):
        """Return a deep copy reduced to the duration predictor (requires ``predict_dur``)."""
        assert self.predict_dur
        model = copy.deepcopy(self)
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.fs2 = model.fs2.view_as_dur_predictor()
        model.forward = model.forward_dur_predictor
        return model

    def view_as_pitch_preprocess(self):
        """Return a deep copy whose ``forward`` is ``forward_pitch_preprocess``."""
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_preprocess
        return model

    def view_as_pitch_predictor(self):
        """Return a deep copy whose ``forward`` is ``forward_pitch_reflow`` (requires ``predict_pitch``)."""
        assert self.predict_pitch
        model = copy.deepcopy(self)
        del model.fs2
        del model.lr
        if self.use_melody_encoder:
            del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_reflow
        return model

    def view_as_pitch_postprocess(self):
        """Return a deep copy whose ``forward`` is ``forward_pitch_postprocess``."""
        model = copy.deepcopy(self)
        del model.fs2
        if self.use_melody_encoder:
            del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_pitch_postprocess
        return model

    def view_as_variance_preprocess(self):
        """Return a deep copy whose ``forward`` is ``forward_variance_preprocess``."""
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        if self.predict_variances:
            del model.variance_predictor
        model.forward = model.forward_variance_preprocess
        return model

    def view_as_variance_predictor(self):
        """Return a deep copy whose ``forward`` is ``forward_variance_reflow`` (requires ``predict_variances``)."""
        assert self.predict_variances
        model = copy.deepcopy(self)
        del model.fs2
        del model.lr
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        model.forward = model.forward_variance_reflow
        return model

    def view_as_variance_postprocess(self):
        """Return a deep copy whose ``forward`` is ``forward_variance_postprocess``."""
        model = copy.deepcopy(self)
        del model.fs2
        if self.predict_pitch:
            del model.pitch_predictor
            if self.use_melody_encoder:
                del model.melody_encoder
        model.forward = model.forward_variance_postprocess
        return model
|
dictionaries/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
*.py
|
2 |
+
*.txt
|
3 |
+
!opencpop*
|
docs/BestPractices.md
ADDED
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Best Practices
|
2 |
+
|
3 |
+
## Materials for training and using models
|
4 |
+
|
5 |
+
### Datasets
|
6 |
+
|
7 |
+
A dataset mainly includes recordings and transcriptions, which is called a _raw dataset_. Raw datasets should be organized as the following folder structure:
|
8 |
+
|
9 |
+
- my_raw_data/
|
10 |
+
- wavs/
|
11 |
+
- 001.wav
|
12 |
+
- 002.wav
|
13 |
+
- ... (more recording files)
|
14 |
+
- transcriptions.csv
|
15 |
+
|
16 |
+
In the example above, the _my_raw_data_ folder is the root directory of a raw dataset.
|
17 |
+
|
18 |
+
The _transcriptions.csv_ file contains all labels of the recordings. The common column of the CSV file is `name`, which represents all recording items by their filenames **without extension**. Elements of sequence attributes should be split by `space`. Other required columns may vary according to the category of the model you are training, and will be introduced in the following sections.
|
19 |
+
|
20 |
+
### Dictionaries
|
21 |
+
|
22 |
+
A dictionary is a .txt file, in which each line represents a mapping rule from one syllable to its phoneme sequence. The syllable and the phonemes are split by `tab`, and the phonemes are split by `space`:
|
23 |
+
|
24 |
+
```
|
25 |
+
<syllable> <phoneme1> <phoneme2> ...
|
26 |
+
```
|
27 |
+
|
28 |
+
Syllable names and phoneme names can be customized, but with the following limitations/suggestions:
|
29 |
+
|
30 |
+
- `SP` (rest), `AP` (breath) and `<PAD>` (padding) cannot be used because they are reserved.
|
31 |
+
- `-` and `+` cannot be used because they are defined as slur tags in most singing voice synthesis editors.
|
32 |
+
- Special characters including but not limited to `@`, `#`, `&`, `|`, `/`, `<`, `>`, etc. should be avoided because they may be used as special tags in the future format changes. Using them now is okay, and all modifications will be notified in advance.
|
33 |
+
- ASCII characters are preferred for the best encoding compatibility, but all UTF-8 characters are acceptable.
|
34 |
+
|
35 |
+
There are some preset dictionaries in the [dictionaries/](../dictionaries) folder. For the guidance of using a custom dictionary, see [Using custom dictionaries](#using-custom-dictionaries).
|
36 |
+
|
37 |
+
### Configuration files
|
38 |
+
|
39 |
+
A configuration file is a YAML file that defines enabled features, model hyperparameters and controls the behavior of the binarizer, trainer and inference. For more information of the configuration system and configurable attributes, see [Configuration Schemas](ConfigurationSchemas.md).
|
40 |
+
|
41 |
+
### DS files
|
42 |
+
|
43 |
+
DS files are JSON files with the _.ds_ suffix that contain phoneme sequences, phoneme durations, music scores or curve parameters. They are mainly used to run inference on models for test and evaluation purposes, and they can be used as training data in some cases. There are some example DS files in the [samples/](../samples) folder.
|
44 |
+
|
45 |
+
The current recommended way of using a model for production purposes is to use [OpenUTAU for DiffSinger](https://github.com/xunmengshe/OpenUtau). It can export DS files as well.
|
46 |
+
|
47 |
+
### Other fundamental assets
|
48 |
+
|
49 |
+
#### Vocoders
|
50 |
+
|
51 |
+
A vocoder is a model that can reconstruct the audio waveform given the low-dimensional mel-spectrogram. The vocoder is the essential dependency if you want to train an acoustic model and hear the voice on the TensorBoard.
|
52 |
+
|
53 |
+
The [DiffSinger Community Vocoders Project](https://openvpi.github.io/vocoders) provides a universal pre-trained NSF-HiFiGAN vocoder that can be used for starters of this repository. To use it, download the model (~50 MB size) from its releases and unzip it into the `checkpoints/` folder.
|
54 |
+
|
55 |
+
The pre-trained vocoder can be fine-tuned on your target dataset. It is highly recommended to do so because a fine-tuned vocoder can generate much better results on specific (seen) datasets while not requiring much computing resources. See the [vocoder training and fine-tuning repository](https://github.com/openvpi/SingingVocoders) for detailed instructions. After you get the fine-tuned vocoder checkpoint, you can configure it by `vocoder_ckpt` key in your configuration file. The fine-tuned NSF-HiFiGAN vocoder checkpoints can be exported to ONNX format like other DiffSinger user models for further production purposes.
|
56 |
+
|
57 |
+
Another unrecommended option: train an ultra-lightweight [DDSP vocoder](https://github.com/yxlllc/pc-ddsp) first by yourself, then configure it according to the relevant [instructions](https://github.com/yxlllc/pc-ddsp/blob/master/DiffSinger.md).
|
58 |
+
|
59 |
+
#### Feature extractors or auxiliary models
|
60 |
+
|
61 |
+
RMVPE is the recommended pitch extractor of this repository, which is an NN-based algorithm and requires a pre-trained model. For more information about pitch extractors and how to configure them, see [feature extraction](#pitch-extraction).
|
62 |
+
|
63 |
+
Vocal Remover (VR) is the recommended harmonic-noise separator of this repository, which is an NN-based algorithm and requires a pre-trained model. For more information about harmonic-noise separators and how to configure them, see [feature extraction](#harmonic-noise-separation).
|
64 |
+
|
65 |
+
## Overview: training acoustic models
|
66 |
+
|
67 |
+
An acoustic model takes low-level singing information as input, including (but not limited to) phoneme sequence, phoneme durations and F0 sequence. The only output of an acoustic model is the mel-spectrogram, which can be converted to waveform (the final audio) through the vocoder. Briefly speaking, an acoustic model takes in all features that are explicitly given, and produces the singing voice.
|
68 |
+
|
69 |
+
### Datasets
|
70 |
+
|
71 |
+
To train an acoustic model, you must have three columns in your transcriptions.csv: `name`, `ph_seq` and `ph_dur`, where `ph_seq` is the phoneme sequence and `ph_dur` is the phoneme duration sequence in seconds. You must have all corresponding recordings declared by the `name` column in mono, WAV format.
|
72 |
+
|
73 |
+
Training from multiple datasets in one model (so that the model is a multi-speaker model) is supported. See `speakers`, `spk_ids` and `use_spk_id` in the configuration schemas.
|
74 |
+
|
75 |
+
### Functionalities
|
76 |
+
|
77 |
+
Functionalities of acoustic models are defined by their inputs. Acoustic models have three basic and fixed inputs: phoneme sequence, phoneme duration sequence and F0 (pitch) sequence. There are three categories of additional inputs (control parameters):
|
78 |
+
|
79 |
+
- speaker IDs: if your acoustic model is a multi-speaker model, you can use different speakers in the same model, or mix their timbre and style.
|
80 |
+
- variance parameters: these curve parameters are features extracted from the recordings, and can control the timbre and style of the singing voice. See `use_energy_embed` and `use_breathiness_embed` in the configuration schemas. Please note that variance parameters **do not have default values**, so they are usually obtained from the variance model at inference time.
|
81 |
+
- transition parameters: these values represent the transition of the mel-spectrogram, and are obtained by enabling data augmentation. They are scalars at training time and sequences at inference time. See `augmentation_args`, `use_key_shift_embed` and `use_speed_embed` in the configuration schemas.
|
82 |
+
|
83 |
+
## Overview: training variance models
|
84 |
+
|
85 |
+
A variance model takes high-level music information as input, including phoneme sequence, word division, word durations and music scores. The outputs of a variance model may include phoneme durations, pitch curve and other control parameters that will be consumed by acoustic models. Briefly speaking, a variance model works as an auxiliary tool (so-called _automatic parameter generator_) for the acoustic models.
|
86 |
+
|
87 |
+
### Datasets
|
88 |
+
|
89 |
+
To train a variance model, you must have all the required attributes listed in the following table in your transcriptions.csv according to the functionalities enabled.
|
90 |
+
|
91 |
+
| | name | ph_seq | ph_dur | ph_num | note_seq | note_dur |
|
92 |
+
|:------------------------------:|:----:|:------:|:------:|:------:|:--------:|:--------:|
|
93 |
+
| phoneme duration prediction | ✓ | ✓ | ✓ | ✓ | | |
|
94 |
+
| pitch prediction | ✓ | ✓ | ✓ | | ✓ | ✓ |
|
95 |
+
| variance parameters prediction | ✓ | ✓ | ✓ | | | |
|
96 |
+
|
97 |
+
The recommended way of building a variance dataset is to extend an acoustic dataset. You may have all the recordings prepared like the acoustic dataset as well, or [use DS files in your variance datasets](#build-variance-datasets-with-ds-files).
|
98 |
+
|
99 |
+
Variance models support multi-speaker settings like acoustic models do.
|
100 |
+
|
101 |
+
### Functionalities
|
102 |
+
|
103 |
+
Functionalities of variance models are defined by their outputs. There are three main prediction modules that can be enabled/disabled independently:
|
104 |
+
|
105 |
+
- Duration Predictor: predicts the phoneme durations. See `predict_dur` in the configuration schemas.
|
106 |
+
- Pitch Predictor: predicts the pitch curve. See `predict_pitch` in the configuration schemas.
|
107 |
+
- Multi-Variance Predictor: jointly predicts other variance parameters. See `predict_energy` and `predict_breathiness` in the configuration schemas.
|
108 |
+
|
109 |
+
There may be some mutual influence between the modules above when they are enabled together. See [mutual influence between variance modules](#mutual-influence-between-variance-modules) for more details.
|
110 |
+
|
111 |
+
## Using custom dictionaries
|
112 |
+
|
113 |
+
This section is about using a custom grapheme-to-phoneme dictionary for any language(s).
|
114 |
+
|
115 |
+
### Add a dictionary
|
116 |
+
|
117 |
+
Assume that you have made a dictionary file named `my_dict.txt`. Edit your configuration file:
|
118 |
+
|
119 |
+
```yaml
|
120 |
+
dictionary: my_dict.txt
|
121 |
+
```
|
122 |
+
|
123 |
+
Then you can binarize your data as normal. The phonemes in your dataset must cover, and must only cover, the phonemes that appear in your dictionary. Otherwise, the binarizer will raise an error:
|
124 |
+
|
125 |
+
```
|
126 |
+
AssertionError: transcriptions and dictionary mismatch.
|
127 |
+
(+) ['E', 'En', 'i0', 'ir']
|
128 |
+
(-) ['AP', 'SP']
|
129 |
+
```
|
130 |
+
|
131 |
+
This means there are 4 unexpected symbols in the data labels (`ir`, `i0`, `E`, `En`) and 2 missing phonemes that are not covered by the data labels (`AP`, `SP`).
|
132 |
+
|
133 |
+
Once the coverage checks have passed, a phoneme distribution summary will be saved into your binary data directory. Below is an example.
|
134 |
+
|
135 |
+

|
136 |
+
|
137 |
+
During the binarization process, each phoneme will be assigned a unique phoneme ID according to the order of their names. There is one padding index (marked as `<PAD>`) before all real phoneme IDs.
|
138 |
+
|
139 |
+
The dictionary used to binarize the dataset will be copied to the binary data directory by the binarizer, and will be copied again to the experiment directory by the trainer. When exported to ONNX, the dictionary and the phoneme sequence ordered by IDs will be saved to the artifact directory. You do not need to carry the original dictionary file for training and inference.
|
140 |
+
|
141 |
+
### Preset dictionaries
|
142 |
+
|
143 |
+
There are currently some preset dictionaries for you to use directly:
|
144 |
+
|
145 |
+
| dictionary | filename | description |
|
146 |
+
|:------------------:|:----------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
147 |
+
| Opencpop | opencpop.txt | The original dictionary used by the Opencpop mandarin singing dataset that is fully aligned with the pinyin writing system. We copied the dictionary from [here](http://wenet.org.cn/opencpop/resources/annotationformat/), removed 5 syllables that has no occurrence in the data labels (`hm`, `hng`, `m`, `n` and `ng`) and added some aliases for some syllables (e.g. `jv` for `ju`). Due to pronunciation issues, this dictionary is deprecated and remained only for backward compatibility. |
|
148 |
+
| Opencpop extension | opencpop-extension.txt | The modified version of the opencpop dictionary, with stricter phoneme division rules for some pinyin syllables. For example, `ci` is mapped to `c i0` and `chi` is mapped to `ch ir` to distinguish with `bi` (`b i`). This dictionary is now used as the default dictionary for mandarin Chinese. There are also many new syllables for more phoneme combinations. |
|
149 |
+
|
150 |
+
### Submit or propose a new dictionary
|
151 |
+
|
152 |
+
You can submit or propose a new dictionary by raising a topic in [Discussions](https://github.com/openvpi/DiffSinger/discussions). Any dictionary to be formally supported in the main branch must match the following principles:
|
153 |
+
|
154 |
+
- Only monolingual dictionaries are accepted for now. Support for multilingual dictionaries will be designed in the future.
|
155 |
+
- All syllables and phonemes in the dictionary should have linguistic meanings. Style tags (vocal fry, falsetto, etc.) should not appear in the dictionary.
|
156 |
+
- Its syllables should be standard spelling or phonetic transcriptions (like pinyin in mandarin Chinese and romaji in Japanese) for easy integration with G2P modules.
|
157 |
+
- Its phonemes should cover all (or almost all) possible pronunciations in that language.
|
158 |
+
- Every syllable and every phoneme should have one, and only one certain pronunciation, in all or almost all situations in that language. Some slight context-based pronunciation differences are allowed as the networks can learn.
|
159 |
+
- Most native speakers/singers of that language should be able to easily cover all phonemes in the dictionary. This means the dictionary should not contain extremely rare or highly customized phonemes of some dialects or accents.
|
160 |
+
- It should not bring too much difficulty and complexity to the data labeling workflow, and it should be easy to use for end users of voicebanks.
|
161 |
+
|
162 |
+
## Build variance datasets with DS files
|
163 |
+
|
164 |
+
By default, the variance binarizer loads attributes from transcriptions.csv and searches for recording files (*.wav) to extract features and parameters. These attributes and parameters also exist in DS files, which are normally used for inference. This section introduces the required settings and important notes to build a variance dataset from DS files.
|
165 |
+
|
166 |
+
First of all, you should edit your configuration file to enable loading from DS files:
|
167 |
+
|
168 |
+
```yaml
|
169 |
+
binarization_args:
|
170 |
+
prefer_ds: true # prefer loading from DS files
|
171 |
+
```
|
172 |
+
|
173 |
+
Then you should prepare some DS files which are properly segmented. If you export DS files with OpenUTAU for DiffSinger, the DS files are already segmented according to the spaces between notes. You should put these DS files in a folder named `ds` in your raw dataset directory (besides the `wavs` folder).
|
174 |
+
|
175 |
+
The DS files should also use the same dictionary as that of your target model. The attributes required vary with your target functionalities, as listed below:
|
176 |
+
|
177 |
+
| attribute name | required by duration prediction | required by pitch prediction | required by variance parameters prediction | previous source | current source |
|
178 |
+
|:----------------------------:|:-------------------------------:|:----------------------------:|:------------------------------------------:|:---------------:|:--------------:|
|
179 |
+
| `name` | ✓ | ✓ | ✓ | CSV | CSV |
|
180 |
+
| `ph_seq` | ✓ | ✓ | ✓ | CSV | DS/CSV |
|
181 |
+
| `ph_dur` | ✓ | ✓ | ✓ | CSV | DS/CSV |
|
182 |
+
| `ph_num` | ✓ | | | CSV | DS/CSV |
|
183 |
+
| `note_seq` | | ✓ | | CSV | DS/CSV |
|
184 |
+
| `note_dur` | | ✓ | | CSV | DS/CSV |
|
185 |
+
| `f0_seq` | ✓ | ✓ | ✓ | WAV | DS/WAV |
|
186 |
+
| `energy`, `breathiness`, ... | | | ✓ | WAV | DS/WAV |
|
187 |
+
|
188 |
+
This means you only need one column in transcriptions.csv, the `name` column, to declare all DS files included in the dataset. The name pattern can be:
|
189 |
+
|
190 |
+
- Full name: `some-name` will firstly match the first segment in `some-name.ds`.
|
191 |
+
- Name with index: `some-name#0` and `some-name#1` will match segment 0 and segment 1 in `some-name.ds` if there is no match with the full name.
|
192 |
+
|
193 |
+
Though not recommended, the binarizer will still try to load attributes from transcriptions.csv or extract parameters from recordings if there are no matching DS files. In this case the full name matching logic is applied (the same as the normal binarization process).
|
194 |
+
|
195 |
+
## Choosing variance parameters
|
196 |
+
|
197 |
+
Variance parameters are a type of parameters that are significantly related to singing styles and emotions, have no default values and need to be predicted by the variance models. Choosing the proper variance parameters can obtain more controllability and expressiveness for your singing models. In this section, we are only talking about **narrowly defined variance parameters**, which are variance parameters except the pitch.
|
198 |
+
|
199 |
+
### Supported variance parameters
|
200 |
+
|
201 |
+
#### Energy
|
202 |
+
|
203 |
+
> WARNING
|
204 |
+
>
|
205 |
+
> This parameter is no longer recommended in favor of the new voicing parameter. The latter is less coupled with breathiness than energy.
|
206 |
+
|
207 |
+
Energy is defined as the RMS curve of the singing, in dB, which can control the strength of voice to a certain extent.
|
208 |
+
|
209 |
+
#### Breathiness
|
210 |
+
|
211 |
+
Breathiness is defined as the RMS curve of the aperiodic part of the singing, in dB, which can control the power of the air and unvoiced consonants in the voice.
|
212 |
+
|
213 |
+
#### Voicing
|
214 |
+
|
215 |
+
Voicing is defined as the RMS curve of the harmonic part of the singing, in dB, which can control the power of the harmonics in vowels and voiced consonants in the voice.
|
216 |
+
|
217 |
+
#### Tension
|
218 |
+
|
219 |
+
Tension is mostly related to the ratio of the base harmonic to the full harmonics, which can be used to control the strength and timbre of the voice. The ratio is calculated as
|
220 |
+
$$
|
221 |
+
r = \frac{\text{RMS}(H_{full}-H_{base})}{\text{RMS}(H_{full})}
|
222 |
+
$$
|
223 |
+
where $H_{full}$ is the full harmonics and $H_{base}$ is the base harmonic. The ratio is then mapped to the final domain via the inverse function of Sigmoid, that
|
224 |
+
$$
|
225 |
+
T = \log{\frac{r}{1-r}}
|
226 |
+
$$
|
227 |
+
where $T$ is the tension value.
|
228 |
+
|
229 |
+
### Principles of choosing multiple parameters
|
230 |
+
|
231 |
+
#### Energy, breathiness and voicing
|
232 |
+
|
233 |
+
These three parameters should **NOT** be enabled together. Energy is the RMS of the full waveform, which is the composition of the harmonic part and the aperiodic part. Therefore, these three parameters are coupled with each other.
|
234 |
+
|
235 |
+
#### Energy, voicing and tension
|
236 |
+
|
237 |
+
When voicing (or energy) is enabled, it almost fixes the loudness. However, tension sometimes relies on the implicitly predicted loudness for more expressiveness, because when a person sings with higher tension, he/she always produces a louder voice. For this reason, some people may find their models or datasets _less natural_ with tension control. To be specific, changing tension will change the timbre but keep the loudness, and changing voicing (or energy) will change the loudness but keep the timbre. This behavior can be suitable for some, but not all datasets and users. Therefore, it is highly recommended for everyone to conduct some experiments on the actual datasets used to train the model.
|
238 |
+
|
239 |
+
## Mutual influence between variance modules
|
240 |
+
|
241 |
+
In some recent experiments and researches, some mutual influence between the modules of variance models has been found. In practice, being aware of the influence and making use of it can improve accuracy and avoid instability of the model.
|
242 |
+
|
243 |
+
### Influence on the duration predictor
|
244 |
+
|
245 |
+
The duration predictor benefits from its downstream modules, like the pitch predictor and the variance predictor.
|
246 |
+
|
247 |
+
The experiments were conducted on both manually refined datasets and automatically labeled datasets, and with pitch predictors driven by both base pitch and melody encoder. All the results have shown that when either of the pitch predictor and the variance predictor is enabled together with the duration predictor, its rhythm correctness and duration accuracy significantly outperform those of a solely trained duration predictor.
|
248 |
+
|
249 |
+
Possible reason for this difference can be the lack of information carried by pure phoneme duration sequences, which may not fully represent the phoneme features in the real world. With the help of frame-level feature predictors, the encoder learns more knowledge about the voice features related to the phoneme types and durations, thus making the duration predictor produce better results.
|
250 |
+
|
251 |
+
### Influence on frame-level feature predictors
|
252 |
+
|
253 |
+
Frame-level feature predictors, including the pitch predictor and the variance predictor, have better performance when trained without enabling the duration predictor.
|
254 |
+
|
255 |
+
The experiments found that when the duration predictor is enabled, the pitch accuracy drops and the dynamics of variance parameters sometimes become unstable. And it has nothing to do with the gradients from the duration predictor, because applying a scale factor on the gradients does not make any difference even if the gradients are completely cut off.
|
256 |
+
|
257 |
+
Possible reason for this phenomenon can be the lack of direct phoneme duration input. When the duration predictor is enabled, the model takes in word durations instead of phoneme durations; when there is no duration predictor together, the phoneme duration sequence is directly taken in and passed through the attention-based linguistic encoder. With direct modeling on the phoneme duration, the frame-level predictors can have a better understanding of the context, thus producing better results.
|
258 |
+
|
259 |
+
Another set of experiments showed that there is no significant influence between the pitch predictor and the variance predictor. When they are enabled together without the duration predictor, both can converge well and produce satisfactory results. No conclusion can be drawn on this issue, and it can depend on the dataset.
|
260 |
+
|
261 |
+
### Suggested procedures of training variance models
|
262 |
+
|
263 |
+
According to the experiment results and the analysis above, the suggested procedures of training a set of variance models are listed below:
|
264 |
+
|
265 |
+
1. Train the duration predictor together with the variance predictor, and discard the variance predictor part.
|
266 |
+
2. Train the pitch predictor and the variance predictor separately or together.
|
267 |
+
3. If interested, compare across different combinations in step 2 and choose the best.
|
268 |
+
|
269 |
+
## Feature extraction
|
270 |
+
|
271 |
+
Feature extraction is the process of extracting low-level features from the recordings, which are needed as inputs for the acoustic models, or as outputs for the variance models.
|
272 |
+
|
273 |
+
### Pitch extraction
|
274 |
+
|
275 |
+
A pitch extractor estimates pitch (F0 sequence) from given recordings. F0 (fundamental frequency) is one of the most important components of singing voice that is needed by both acoustic models and variance models.
|
276 |
+
|
277 |
+
```yaml
|
278 |
+
pe: parselmouth # pitch extractor type
|
279 |
+
pe_ckpt: checkpoints/xxx/model.pt # pitch extractor model path (if it requires any)
|
280 |
+
```
|
281 |
+
|
282 |
+
#### Parselmouth
|
283 |
+
|
284 |
+
[Parselmouth](https://github.com/YannickJadoul/Parselmouth) is the default pitch extractor in this repository. It is based on DSP algorithms, runs fast on CPU and can get accurate F0 on clean and normal recordings.
|
285 |
+
|
286 |
+
To use parselmouth, simply include the following line in your configuration file:
|
287 |
+
|
288 |
+
```yaml
|
289 |
+
pe: parselmouth
|
290 |
+
```
|
291 |
+
|
292 |
+
#### RMVPE (recommended)
|
293 |
+
|
294 |
+
[RMVPE](https://github.com/Dream-High/RMVPE) (Robust Model for Vocal Pitch Estimation) is the state-of-the-art NN-based pitch estimation model for singing voice. It runs slower than parselmouth and consumes more memory, but it uses CUDA to accelerate computation (if available) and produces better results on noisy recordings and edge cases.
|
295 |
+
|
296 |
+
To enable RMVPE, download its pre-trained checkpoint from [here](https://github.com/yxlllc/RMVPE/releases), extract it into the `checkpoints/` folder and edit the configuration file:
|
297 |
+
|
298 |
+
```yaml
|
299 |
+
pe: rmvpe
|
300 |
+
pe_ckpt: checkpoints/rmvpe/model.pt
|
301 |
+
```
|
302 |
+
|
303 |
+
#### Harvest
|
304 |
+
|
305 |
+
Harvest (Harvest: A high-performance fundamental frequency estimator from speech signals) is the recommended pitch extractor from Masanori Morise's [WORLD](https://github.com/mmorise/World), a free software for high-quality speech analysis, manipulation and synthesis. It is a state-of-the-art algorithmic pitch estimator designed for speech, but has seen use in singing voice synthesis. It runs the slowest compared to the others, but provides very accurate F0 on clean and normal recordings compared to parselmouth.
|
306 |
+
|
307 |
+
To use Harvest, simply include the following line in your configuration file:
|
308 |
+
|
309 |
+
```yaml
|
310 |
+
pe: harvest
|
311 |
+
```
|
312 |
+
|
313 |
+
**Note:** It is also recommended to change the F0 detection range for Harvest in accordance with your dataset, as they are hard boundaries for this algorithm and the defaults might not suffice for most use cases. To change the F0 detection range, you may include or edit this part in the configuration file:
|
314 |
+
|
315 |
+
```yaml
|
316 |
+
f0_min: 65 # Minimum F0 to detect
|
317 |
+
f0_max: 800 # Maximum F0 to detect
|
318 |
+
```
|
319 |
+
|
320 |
+
### Harmonic-noise separation
|
321 |
+
|
322 |
+
Harmonic-noise separation is the process of separating the harmonic part and the aperiodic part of the singing voice. These parts are the fundamental components for variance parameters including breathiness, voicing and tension to be calculated from.
|
323 |
+
|
324 |
+
#### WORLD
|
325 |
+
|
326 |
+
This algorithm uses Masanori Morise's [WORLD](https://github.com/mmorise/World), a free software for high-quality speech analysis, manipulation and synthesis. It uses CPU (no CUDA required) but runs relatively slow.
|
327 |
+
|
328 |
+
To use WORLD, simply include the following line in your configuration file:
|
329 |
+
|
330 |
+
```yaml
|
331 |
+
hnsep: world
|
332 |
+
```
|
333 |
+
|
334 |
+
#### Vocal Remover (recommended)
|
335 |
+
|
336 |
+
Vocal Remover (VR) is originally a popular NN-based algorithm for music source separation that removes the vocal part from the music. This repository uses a specially trained model for harmonic-noise separation. VR extracts much cleaner harmonic parts, utilizes CUDA to accelerate computation (if available) and runs much faster than WORLD. However, it consumes more memory and should not be used with too many parallel workers.
|
337 |
+
|
338 |
+
To enable VR, download its pre-trained checkpoint from [here](https://github.com/yxlllc/vocal-remover/releases), extract it into the `checkpoints/` folder and edit the configuration file:
|
339 |
+
|
340 |
+
```yaml
|
341 |
+
hnsep: vr
|
342 |
+
hnsep_ckpt: checkpoints/vr/model.pt
|
343 |
+
```
|
344 |
+
|
345 |
+
## Shallow diffusion
|
346 |
+
|
347 |
+
Shallow diffusion is a mechanism, first introduced in the original DiffSinger [paper](https://arxiv.org/abs/2105.02446), that can improve quality and save inference time for diffusion models. Instead of starting the diffusion process from purely gaussian noise as classic diffusion does, shallow diffusion adds shallow gaussian noise to the low-quality results generated by a simple network (which is called the auxiliary decoder) to skip many unnecessary steps from the beginning. With the combination of shallow diffusion and sampling acceleration algorithms, we can get better results under the same inference speed as before, or achieve higher inference speed without quality deterioration.
|
348 |
+
|
349 |
+
Currently, acoustic models in this repository support shallow diffusion. The main switch of shallow diffusion is `use_shallow_diffusion` in the configuration file, and most arguments of shallow diffusion can be adjusted under `shallow_diffusion_args`. See [Configuration Schemas](ConfigurationSchemas.md) for more details.
|
350 |
+
|
351 |
+
### Train full shallow diffusion models from scratch
|
352 |
+
|
353 |
+
To train a full shallow diffusion model from scratch, simply introduce the following settings in your configuration file:
|
354 |
+
|
355 |
+
```yaml
|
356 |
+
use_shallow_diffusion: true
|
357 |
+
K_step: 400 # adjust according to your needs
|
358 |
+
K_step_infer: 400 # should be <= K_step
|
359 |
+
```
|
360 |
+
|
361 |
+
Please note that when shallow diffusion is enabled, only the last $K$ diffusion steps will be trained. Unlike classic diffusion models which are trained on full steps, the limit of `K_step` can make the training more efficient. However, `K_step` should not be set too small because without enough diffusion depth (steps), the low-quality auxiliary decoder results cannot be well refined. 200 ~ 400 should be the proper range of `K_step`.
|
362 |
+
|
363 |
+
The auxiliary decoder and the diffusion decoder share the same linguistic encoder, which receives gradients from both the decoders. In some experiments, it was found that gradients from the auxiliary decoder will cause mismatching between the encoder and the diffusion decoder, resulting in the latter being unable to produce reasonable results. To prevent this case, a configuration item called `aux_decoder_grad` is introduced to apply a scale factor on the gradients from the auxiliary decoder during training. To adjust this factor, introduce the following in the configuration file:
|
364 |
+
|
365 |
+
```yaml
|
366 |
+
shallow_diffusion_args:
|
367 |
+
aux_decoder_grad: 0.1 # should not be too high
|
368 |
+
```
|
369 |
+
|
370 |
+
### Train auxiliary decoder and diffusion decoder separately
|
371 |
+
|
372 |
+
Training a full shallow diffusion model can consume more memory because the auxiliary decoder is also in the training graph. In limited situations, the two decoders can be trained separately, i.e. train one decoder after another.
|
373 |
+
|
374 |
+
**STEP 1: train the diffusion decoder**
|
375 |
+
|
376 |
+
In the first stage, the linguistic encoder and the diffusion decoder are trained together, while the auxiliary decoder is left unchanged. Edit your configuration file like this:
|
377 |
+
|
378 |
+
```yaml
|
379 |
+
use_shallow_diffusion: true # make sure the main option is turned on
|
380 |
+
shallow_diffusion_args:
|
381 |
+
train_aux_decoder: false # exclude the auxiliary decoder from the training graph
|
382 |
+
train_diffusion: true # train diffusion decoder as normal
|
383 |
+
val_gt_start: true # should be true because the auxiliary decoder is not trained yet
|
384 |
+
```
|
385 |
+
|
386 |
+
Start training until `max_updates` is reached, or until you get satisfactory results on the TensorBoard.
|
387 |
+
|
388 |
+
**STEP 2: train the auxiliary decoder**
|
389 |
+
|
390 |
+
In the second stage, the auxiliary decoder is trained besides the linguistic encoder and the diffusion decoder. Edit your configuration file like this:
|
391 |
+
|
392 |
+
```yaml
|
393 |
+
shallow_diffusion_args:
|
394 |
+
train_aux_decoder: true
|
395 |
+
train_diffusion: false # exclude the diffusion decoder from the training graph
|
396 |
+
lambda_aux_mel_loss: 1.0 # no more need to limit the auxiliary loss
|
397 |
+
```
|
398 |
+
|
399 |
+
Then you should freeze the encoder to prevent it from getting updates. This is because if the encoder changes, it no longer matches with the diffusion decoder, thus making the latter unable to produce correct results again. Edit your configuration file:
|
400 |
+
|
401 |
+
```yaml
|
402 |
+
freezing_enabled: true
|
403 |
+
frozen_params:
|
404 |
+
- model.fs2 # the linguistic encoder
|
405 |
+
```
|
406 |
+
|
407 |
+
You should also manually reset your learning rate scheduler because this is a new training process for the auxiliary decoder. Possible ways are:
|
408 |
+
|
409 |
+
1. Rename the latest checkpoint to `model_ckpt_steps_0.ckpt` and remove the other checkpoints from the directory.
|
410 |
+
2. Increase the initial learning rate (if you use a scheduler that decreases the LR over training steps) so that the auxiliary decoder gets proper learning rate.
|
411 |
+
|
412 |
+
Additionally, `max_updates` should be adjusted to ensure enough training steps for the auxiliary decoder.
|
413 |
+
|
414 |
+
Once you have finished the configurations above, you can resume the training. The auxiliary decoder normally does not need many steps to train, and you can stop training when you get stable results on the TensorBoard. Because this step is much more complicated than the previous step, it is recommended to run some inference to verify if the model is trained properly after everything is finished.
|
415 |
+
|
416 |
+
### Add shallow diffusion to classic diffusion models
|
417 |
+
|
418 |
+
Actually, all classic DDPMs have the ability to be "shallow". If you want to add shallow diffusion functionality to a former classic diffusion model, the only thing you need to do is to train an auxiliary decoder for it.
|
419 |
+
|
420 |
+
Before you start, you should edit the configuration file to ensure that you use the same datasets, and that you do not remove or add any of the functionalities of the old model. Then you can configure the old checkpoint in your configuration file:
|
421 |
+
|
422 |
+
```yaml
|
423 |
+
finetune_enabled: true
|
424 |
+
finetune_ckpt_path: xxx.ckpt # path to your old checkpoint
|
425 |
+
finetune_ignored_params: [] # do not ignore any parameters
|
426 |
+
```
|
427 |
+
|
428 |
+
Then you can follow the instructions in STEP 2 of the [previous section](#train-auxiliary-decoder-and-diffusion-decoder-separately) to finish your training.
|
429 |
+
|
430 |
+
## Performance tuning
|
431 |
+
|
432 |
+
This section is about accelerating training and utilizing hardware.
|
433 |
+
|
434 |
+
### Data loader and batch sampler
|
435 |
+
|
436 |
+
The data loader loads data pieces from the binary dataset, and the batch sampler forms batches according to data lengths.
|
437 |
+
|
438 |
+
To configure the data loader, edit your configuration file:
|
439 |
+
|
440 |
+
```yaml
|
441 |
+
ds_workers: 4 # number of DataLoader workers
|
442 |
+
dataloader_prefetch_factor: 2 # load data in advance
|
443 |
+
```
|
444 |
+
|
445 |
+
To configure the batch sampler, edit your configuration file:
|
446 |
+
|
447 |
+
```yaml
|
448 |
+
sampler_frame_count_grid: 6 # lower value means higher speed but less randomness
|
449 |
+
```
|
450 |
+
|
451 |
+
For more details of the batch sampler algorithm and this configuration key, see [sampler_frame_count_grid](ConfigurationSchemas.md#sampler_frame_count_grid).
|
452 |
+
|
453 |
+
### Automatic mixed precision
|
454 |
+
|
455 |
+
Enabling automatic mixed precision (AMP) can accelerate training and save GPU memory. DiffSinger has adapted the latest version of PyTorch Lightning for AMP functionalities.
|
456 |
+
|
457 |
+
By default, the training runs in FP32 precision. To enable AMP, edit your configuration file:
|
458 |
+
|
459 |
+
```yaml
|
460 |
+
pl_trainer_precision: 16-mixed # FP16 precision
|
461 |
+
```
|
462 |
+
|
463 |
+
or
|
464 |
+
|
465 |
+
```yaml
|
466 |
+
pl_trainer_precision: bf16-mixed # BF16 precision
|
467 |
+
```
|
468 |
+
|
469 |
+
For more precision options, please check out the [official documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision).
|
470 |
+
|
471 |
+
### Training on multiple GPUs
|
472 |
+
|
473 |
+
Using distributed data parallel (DDP) can divide training tasks to multiple GPUs and synchronize gradients and weights between them. DiffSinger has adapted the latest version of PyTorch Lightning for DDP functionalities.
|
474 |
+
|
475 |
+
By default, the trainer will utilize all CUDA devices defined in the `CUDA_VISIBLE_DEVICES` environment variable (empty means using all available devices). If you want to specify which GPUs to use, edit your configuration file:
|
476 |
+
|
477 |
+
```yaml
|
478 |
+
pl_trainer_devices: [0, 1, 2, 3] # use the first 4 GPUs defined in CUDA_VISIBLE_DEVICES
|
479 |
+
```
|
480 |
+
|
481 |
+
Please note that `max_batch_size` and `max_batch_frames` are values for **each** GPU.
|
482 |
+
|
483 |
+
By default, the trainer uses NCCL as the DDP backend. If this gets stuck on your machine, try disabling P2P first via
|
484 |
+
|
485 |
+
```yaml
|
486 |
+
nccl_p2p: false # disable P2P in NCCL
|
487 |
+
```
|
488 |
+
|
489 |
+
Or if your machine does not support NCCL, you can switch to Gloo instead:
|
490 |
+
|
491 |
+
```yaml
|
492 |
+
pl_trainer_strategy:
|
493 |
+
name: ddp # must manually choose a strategy instead of 'auto'
|
494 |
+
process_group_backend: gloo # however, it has a lower performance than NCCL
|
495 |
+
```
|
496 |
+
|
497 |
+
### Gradient accumulation
|
498 |
+
|
499 |
+
Gradient accumulation means accumulating losses for several batches before each time the weights are updated. This can simulate a larger batch size with a lower GPU memory cost.
|
500 |
+
|
501 |
+
By default, the trainer calls `backward()` each time the losses are calculated through one batch of data. To enable gradient accumulation, edit your configuration file:
|
502 |
+
|
503 |
+
```yaml
|
504 |
+
accumulate_grad_batches: 4 # the actual batch size will be 4x.
|
505 |
+
```
|
506 |
+
|
507 |
+
Please note that enabling gradient accumulation will slow down training because the losses must be calculated several times before the weights are updated (1 update to the weights = 1 actual training step).
|
508 |
+
|
509 |
+
## Optimizers and learning rate schedulers
|
510 |
+
|
511 |
+
The optimizer and the learning rate scheduler can take an important role in the training process. DiffSinger uses a flexible configuration logic for these two modules.
|
512 |
+
|
513 |
+
### Basic configurations
|
514 |
+
|
515 |
+
The optimizer and learning rate scheduler used during training can be configured by their full class name and keyword arguments in the configuration file. Take the following as an example for the optimizer:
|
516 |
+
|
517 |
+
```yaml
|
518 |
+
optimizer_args:
|
519 |
+
optimizer_cls: torch.optim.AdamW # class name of optimizer
|
520 |
+
lr: 0.0004
|
521 |
+
beta1: 0.9
|
522 |
+
beta2: 0.98
|
523 |
+
weight_decay: 0
|
524 |
+
```
|
525 |
+
|
526 |
+
and for the learning rate scheduler:
|
527 |
+
|
528 |
+
```yaml
|
529 |
+
lr_scheduler_args:
|
530 |
+
scheduler_cls: torch.optim.lr_scheduler.StepLR # class name of learning rate schedule
|
531 |
+
warmup_steps: 2000
|
532 |
+
step_size: 50000
|
533 |
+
gamma: 0.5
|
534 |
+
```
|
535 |
+
|
536 |
+
Note that `optimizer_args` and `lr_scheduler_args` will be filtered by needed parameters and passed to `__init__` as keyword arguments (`kwargs`) when constructing the optimizer and scheduler. Therefore, you could specify all arguments according to your need in the configuration file to directly control the behavior of optimization and LR scheduling. It will also tolerate parameters existing in the configuration but not needed in `__init__`.
|
537 |
+
|
538 |
+
Also, note that the LR scheduler performs scheduling on the granularity of steps, not epochs.
|
539 |
+
|
540 |
+
The special case applies when a tuple is needed in `__init__`: `beta1` and `beta2` are treated separately and form a tuple in the code. You could try to pass in an array instead. (And as an experiment, AdamW does accept `[beta1, beta2]`). If there is another special treatment required, please submit an issue.
|
541 |
+
|
542 |
+
For PyTorch built-in optimizers and LR schedulers, see official [documentation](https://pytorch.org/docs/stable/optim.html) of the `torch.optim` package. If you found other optimizer and learning rate scheduler useful, you can raise a topic in [Discussions](https://github.com/openvpi/DiffSinger/discussions), raise [Issues](https://github.com/openvpi/DiffSinger/issues) or submit [PRs](https://github.com/openvpi/DiffSinger/pulls) if it introduces new codes or dependencies.
|
543 |
+
|
544 |
+
### Composite LR schedulers
|
545 |
+
|
546 |
+
Some LR schedulers like `SequentialLR` and `ChainedScheduler` may use other schedulers as arguments. Besides built-in types, there is a special design to configure these scheduler objects. See the following example.
|
547 |
+
|
548 |
+
```yaml
|
549 |
+
lr_scheduler_args:
|
550 |
+
scheduler_cls: torch.optim.lr_scheduler.SequentialLR
|
551 |
+
schedulers:
|
552 |
+
- cls: torch.optim.lr_scheduler.ExponentialLR
|
553 |
+
gamma: 0.5
|
554 |
+
- cls: torch.optim.lr_scheduler.LinearLR
|
555 |
+
- cls: torch.optim.lr_scheduler.MultiStepLR
|
556 |
+
milestones:
|
557 |
+
- 10
|
558 |
+
- 20
|
559 |
+
milestones:
|
560 |
+
- 10
|
561 |
+
- 20
|
562 |
+
```
|
563 |
+
|
564 |
+
The LR scheduler objects will be constructed recursively if `cls` is present in sub-arguments. Please note that `cls` must be a scheduler class because this is a special design.
|
565 |
+
|
566 |
+
**WARNING:** Nested `SequentialLR` and `ChainedScheduler` have unexpected behavior. **DO NOT** nest them. Also, make sure the scheduler is _chainable_ before using it in `ChainedScheduler`.
|
567 |
+
|
568 |
+
## Fine-tuning and parameter freezing
|
569 |
+
|
570 |
+
### Fine-tuning from existing checkpoints
|
571 |
+
|
572 |
+
By default, the training starts from a model from scratch with randomly initialized parameters. However, if you already have some pre-trained checkpoints, and you need to adapt them to other datasets with their functionalities unchanged, fine-tuning may save training steps and time. In general, you need to add the following structure into the configuration file:
|
573 |
+
|
574 |
+
```yaml
|
575 |
+
# take acoustic models as an example
|
576 |
+
finetune_enabled: true # the main switch to enable fine-tuning
|
577 |
+
finetune_ckpt_path: checkpoints/pretrained/model_ckpt_steps_320000.ckpt # path to your pre-trained checkpoint
|
578 |
+
finetune_ignored_params: # prefix rules to exclude specific parameters when loading the checkpoints
|
579 |
+
- model.fs2.encoder.embed_tokens # in case when the phoneme set is changed
|
580 |
+
- model.fs2.txt_embed # same as above
|
581 |
+
- model.fs2.spk_embed # in case when the speaker set is changed
|
582 |
+
finetune_strict_shapes: true # whether to raise an error when parameter shapes mismatch
|
583 |
+
```
|
584 |
+
|
585 |
+
For the pre-trained checkpoint, it must be a file saved with `torch.save`, containing a `dict` object and a `state_dict` key, like the following example:
|
586 |
+
|
587 |
+
```json5
|
588 |
+
{
|
589 |
+
"state_dict": {
|
590 |
+
"model.fs2.txt_embed": null, // torch.Tensor
|
591 |
+
"model.fs2.pitch_embed.weight": null, // torch.Tensor
|
592 |
+
"model.fs2.pitch_embed.bias": null, // torch.Tensor
|
593 |
+
// ... (other parameters)
|
594 |
+
}
|
595 |
+
// ... (other possible keys)
|
596 |
+
}
|
597 |
+
```
|
598 |
+
|
599 |
+
**IMPORTANT NOTES**:
|
600 |
+
|
601 |
+
- The pre-trained checkpoint is **loaded only once** at the beginning of the training experiment. You may interrupt the training at any time, but after this new experiment has saved its own checkpoint, the pre-trained checkpoint will not be loaded again when the training is resumed.
|
602 |
+
- Only the state dict of the checkpoint will be loaded. The optimizer state in the pre-trained checkpoint will be ignored.
|
603 |
+
- The parameter name matching is **not strict** when loading the pre-trained checkpoint. This means that missing parameters in the state dict will still be left as randomly initialized, and redundant parameters will be ignored without any warnings and errors. There are cases where the tensor shapes mismatch between the pre-trained state dict and the model - edit `finetune_strict_shapes` to change the behavior when dealing with this.
|
604 |
+
- Be careful if you want to change the functionalities when fine-tuning. Starting from a checkpoint trained under different functionalities may be even slower than training from scratch.
|
605 |
+
|
606 |
+
### Freezing model parameters
|
607 |
+
|
608 |
+
Sometimes you want to freeze part of the model during training or fine-tuning to save GPU memory, accelerate the training process or avoid catastrophic forgetting. Parameter freezing may also be useful if you want to add/remove functionalities from pre-trained checkpoints. In general, you need to add the following structure into the configuration file:
|
609 |
+
|
610 |
+
```yaml
|
611 |
+
# take acoustic models as an example
|
612 |
+
freezing_enabled: true # main switch to enable parameter freezing
|
613 |
+
frozen_params: # prefix rules to freeze specific parameters during training
|
614 |
+
- model.fs2.encoder
|
615 |
+
- model.fs2.pitch_embed
|
616 |
+
```
|
617 |
+
|
618 |
+
You may interrupt the training and change the settings above at any time. Sometimes this will cause mismatching optimizer state - and it will be discarded silently.
|
docs/ConfigurationSchemas.md
ADDED
@@ -0,0 +1,2109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration Schemas
|
2 |
+
|
3 |
+
## The configuration system
|
4 |
+
|
5 |
+
DiffSinger uses a cascading configuration system based on YAML files. All configuration files ultimately inherit from and override [configs/base.yaml](../configs/base.yaml), and each file directly overrides another file by setting the `base_config` attribute. The overriding rules are:
|
6 |
+
|
7 |
+
- Configuration keys with the same path and the same name will be replaced. Other paths and names will be merged.
|
8 |
+
- All configurations in the inheritance chain will be squashed (via the rule above) as the final configuration.
|
9 |
+
- The trainer will save the final configuration in the experiment directory, which is detached from the chain and made independent from other configuration files.
|
10 |
+
|
11 |
+
## Configurable parameters
|
12 |
+
|
13 |
+
The following are the meanings and usages of all editable keys in a configuration file.
|
14 |
+
|
15 |
+
Each configuration key (including nested keys) is described with a brief explanation and several attributes listed as follows:
|
16 |
+
|
17 |
+
| Attribute | Explanation |
|
18 |
+
|:---------------:|:------------|
|
19 |
+
| visibility | Represents what kind(s) of models and tasks this configuration belongs to. |
|
20 |
+
| scope | The scope of effects of the configuration, indicating what it can influence within the whole pipeline. Possible values are:<br>**nn** - This configuration is related to how the neural networks are formed and initialized. Modifying it will result in failure when loading or resuming from checkpoints.<br>**preprocessing** - This configuration controls how raw data pieces or inference inputs are converted to inputs of neural networks. Binarizers should be re-run if this configuration is modified.<br>**training** - This configuration describes the training procedures. Most training configurations can affect training performance, memory consumption, device utilization and loss calculation. Modifying training-only configurations will not cause severe inconsistency or errors in most situations.<br>**inference** - This configuration describes the calculation logic through the model graph. Changing it can lead to inconsistent or wrong outputs of inference or validation.<br>**others** - Other configurations not discussed above. Will have different effects according to the descriptions. |
|
21 |
+
| customizability | The level of customizability of the configuration. Possible values are:<br>**required** - This configuration **must** be set or modified according to the actual situation or condition, otherwise errors can be raised.<br>**recommended** - It is recommended to adjust this configuration according to the dataset, requirements, environment and hardware. Most functionality-related and feature-related configurations are at this level, and all configurations in this level are widely tested with different values. However, leaving it unchanged will not cause problems.<br>**normal** - There is no need to modify it as the default value is carefully tuned and widely validated. However, one can still use another value if there are some special requirements or situations.<br>**not recommended** - No other values except the default one of this configuration are tested. Modifying it will not cause errors, but may cause unpredictable or significant impacts to the pipelines.<br>**reserved** - This configuration **must not** be modified. It appears in the configuration file only for future scalability, and currently changing it will result in errors. |
|
22 |
+
| type | Value type of the configuration. Follows the syntax of Python type hints. |
|
23 |
+
| constraints | Value constraints of the configuration. |
|
24 |
+
| default | Default value of the configuration. Uses YAML value syntax. |
|
25 |
+
|
26 |
+
### accumulate_grad_batches
|
27 |
+
|
28 |
+
The number of training steps whose gradients are accumulated before each `optimizer.step()` call. 1 means no gradient accumulation.
|
29 |
+
|
30 |
+
<table><tbody>
|
31 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
32 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
33 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
34 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
35 |
+
<tr><td align="center"><b>default</b></td><td>1</td>
|
36 |
+
</tbody></table>
|
37 |
+
|
38 |
+
### audio_num_mel_bins
|
39 |
+
|
40 |
+
Number of mel channels for the mel-spectrogram.
|
41 |
+
|
42 |
+
<table><tbody>
|
43 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
44 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
|
45 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
46 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
47 |
+
<tr><td align="center"><b>default</b></td><td>128</td>
|
48 |
+
</tbody></table>
|
49 |
+
|
50 |
+
### audio_sample_rate
|
51 |
+
|
52 |
+
Sampling rate of waveforms.
|
53 |
+
|
54 |
+
<table><tbody>
|
55 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
56 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
57 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
58 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
59 |
+
<tr><td align="center"><b>default</b></td><td>44100</td>
|
60 |
+
</tbody></table>
|
61 |
+
|
62 |
+
### augmentation_args
|
63 |
+
|
64 |
+
Arguments for data augmentation.
|
65 |
+
|
66 |
+
<table><tbody>
|
67 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
68 |
+
</tbody></table>
|
69 |
+
|
70 |
+
### augmentation_args.fixed_pitch_shifting
|
71 |
+
|
72 |
+
Arguments for fixed pitch shifting augmentation.
|
73 |
+
|
74 |
+
<table><tbody>
|
75 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
76 |
+
</tbody></table>
|
77 |
+
|
78 |
+
### augmentation_args.fixed_pitch_shifting.enabled
|
79 |
+
|
80 |
+
Whether to apply fixed pitch shifting augmentation.
|
81 |
+
|
82 |
+
<table><tbody>
|
83 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
84 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
85 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
86 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
87 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
88 |
+
<tr><td align="center"><b>constraints</b></td><td>Must be false if <a href="#augmentation_argsrandom_pitch_shiftingenabled">augmentation_args.random_pitch_shifting.enabled</a> is set to true.</td>
|
89 |
+
</tbody></table>
|
90 |
+
|
91 |
+
### augmentation_args.fixed_pitch_shifting.scale
|
92 |
+
|
93 |
+
Scale ratio of each target in fixed pitch shifting augmentation.
|
94 |
+
|
95 |
+
<table><tbody>
|
96 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
97 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
98 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
99 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
100 |
+
<tr><td align="center"><b>default</b></td><td>0.5</td>
|
101 |
+
</tbody></table>
|
102 |
+
|
103 |
+
### augmentation_args.fixed_pitch_shifting.targets
|
104 |
+
|
105 |
+
Targets (in semitones) of fixed pitch shifting augmentation.
|
106 |
+
|
107 |
+
<table><tbody>
|
108 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
109 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
110 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
111 |
+
<tr><td align="center"><b>type</b></td><td>tuple</td>
|
112 |
+
<tr><td align="center"><b>default</b></td><td>[-5.0, 5.0]</td>
|
113 |
+
</tbody></table>
|
114 |
+
|
115 |
+
### augmentation_args.random_pitch_shifting
|
116 |
+
|
117 |
+
Arguments for random pitch shifting augmentation.
|
118 |
+
|
119 |
+
<table><tbody>
|
120 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
121 |
+
</tbody></table>
|
122 |
+
|
123 |
+
### augmentation_args.random_pitch_shifting.enabled
|
124 |
+
|
125 |
+
Whether to apply random pitch shifting augmentation.
|
126 |
+
|
127 |
+
<table><tbody>
|
128 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
129 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
130 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
131 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
132 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
133 |
+
<tr><td align="center"><b>constraints</b></td><td>Must be false if <a href="#augmentation_argsfixed_pitch_shiftingenabled">augmentation_args.fixed_pitch_shifting.enabled</a> is set to true.</td>
|
134 |
+
</tbody></table>
|
135 |
+
|
136 |
+
### augmentation_args.random_pitch_shifting.range
|
137 |
+
|
138 |
+
Range of the random pitch shifting (in semitones).
|
139 |
+
|
140 |
+
<table><tbody>
|
141 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
142 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
143 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
144 |
+
<tr><td align="center"><b>type</b></td><td>tuple</td>
|
145 |
+
<tr><td align="center"><b>default</b></td><td>[-5.0, 5.0]</td>
|
146 |
+
</tbody></table>
|
147 |
+
|
148 |
+
### augmentation_args.random_pitch_shifting.scale
|
149 |
+
|
150 |
+
Scale ratio of the random pitch shifting augmentation.
|
151 |
+
|
152 |
+
<table><tbody>
|
153 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
154 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
155 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
156 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
157 |
+
<tr><td align="center"><b>default</b></td><td>0.75</td>
|
158 |
+
</tbody></table>
|
159 |
+
|
160 |
+
### augmentation_args.random_time_stretching
|
161 |
+
|
162 |
+
Arguments for random time stretching augmentation.
|
163 |
+
|
164 |
+
<table><tbody>
|
165 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
166 |
+
</tbody></table>
|
167 |
+
|
168 |
+
### augmentation_args.random_time_stretching.enabled
|
169 |
+
|
170 |
+
Whether to apply random time stretching augmentation.
|
171 |
+
|
172 |
+
<table><tbody>
|
173 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
174 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
175 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
176 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
177 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
178 |
+
</tbody></table>
|
179 |
+
|
180 |
+
### augmentation_args.random_time_stretching.range
|
181 |
+
|
182 |
+
Range of random time stretching factors.
|
183 |
+
|
184 |
+
<table><tbody>
|
185 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
186 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
187 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
188 |
+
<tr><td align="center"><b>type</b></td><td>tuple</td>
|
189 |
+
<tr><td align="center"><b>default</b></td><td>[0.5, 2]</td>
|
190 |
+
</tbody></table>
|
191 |
+
|
192 |
+
### augmentation_args.random_time_stretching.scale
|
193 |
+
|
194 |
+
Scale ratio of random time stretching augmentation.
|
195 |
+
|
196 |
+
<table><tbody>
|
197 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
198 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
199 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
200 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
201 |
+
<tr><td align="center"><b>default</b></td><td>0.75</td>
|
202 |
+
</tbody></table>
|
203 |
+
|
204 |
+
### backbone_args
|
205 |
+
|
206 |
+
Keyword arguments for the backbone of the main decoder module.
|
207 |
+
|
208 |
+
<table><tbody>
|
209 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
210 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
211 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
212 |
+
</tbody></table>
|
213 |
+
|
214 |
+
Some available arguments are listed below.
|
215 |
+
|
216 |
+
| argument name | for backbone type | description |
|
217 |
+
|:---------------------:|:-----------------:|:-----------------------------------------------------------------------------------------------------------:|
|
218 |
+
| num_layers | wavenet/lynxnet | Number of layer blocks, or depth of the network |
|
219 |
+
| num_channels | wavenet/lynxnet | Number of channels, or width of the network |
|
220 |
+
| dilation_cycle_length | wavenet | Length k of the cycle $2^0, 2^1, \ldots, 2^k$ of convolution dilation factors through WaveNet residual blocks. |
|
221 |
+
|
222 |
+
### backbone_type
|
223 |
+
|
224 |
+
Backbone type of the main decoder/predictor module.
|
225 |
+
|
226 |
+
<table><tbody>
|
227 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
228 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
229 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
230 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
231 |
+
<tr><td align="center"><b>default</b></td><td>lynxnet</td>
|
232 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'wavenet', 'lynxnet'.</td>
|
233 |
+
</tbody></table>
|
234 |
+
|
235 |
+
### base_config
|
236 |
+
|
237 |
+
Path(s) of other config files that the current config is based on and will override.
|
238 |
+
|
239 |
+
<table><tbody>
|
240 |
+
<tr><td align="center"><b>scope</b></td><td>others</td>
|
241 |
+
<tr><td align="center"><b>type</b></td><td>Union[str, list]</td>
|
242 |
+
</tbody></table>
|
243 |
+
|
244 |
+
### binarization_args
|
245 |
+
|
246 |
+
Arguments for binarizers.
|
247 |
+
|
248 |
+
<table><tbody>
|
249 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
250 |
+
</tbody></table>
|
251 |
+
|
252 |
+
### binarization_args.num_workers
|
253 |
+
|
254 |
+
Number of worker subprocesses when running binarizers. More workers can speed up the preprocessing but will consume more memory. 0 means the main process does everything.
|
255 |
+
|
256 |
+
<table><tbody>
|
257 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
258 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
259 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
260 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
261 |
+
<tr><td align="center"><b>default</b></td><td>1</td>
|
262 |
+
</tbody></table>
|
263 |
+
|
264 |
+
### binarization_args.prefer_ds
|
265 |
+
|
266 |
+
Whether to prefer loading attributes and parameters from DS files.
|
267 |
+
|
268 |
+
<table><tbody>
|
269 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
270 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
271 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
272 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
273 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
274 |
+
</tbody></table>
|
275 |
+
|
276 |
+
### binarization_args.shuffle
|
277 |
+
|
278 |
+
Whether binarized dataset will be shuffled or not.
|
279 |
+
|
280 |
+
<table><tbody>
|
281 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
282 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
283 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
284 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
285 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
286 |
+
</tbody></table>
|
287 |
+
|
288 |
+
### binarizer_cls
|
289 |
+
|
290 |
+
Binarizer class name.
|
291 |
+
|
292 |
+
<table><tbody>
|
293 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
294 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
295 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
296 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
297 |
+
</tbody></table>
|
298 |
+
|
299 |
+
### binary_data_dir
|
300 |
+
|
301 |
+
Path to the binarized dataset.
|
302 |
+
|
303 |
+
<table><tbody>
|
304 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
305 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing, training</td>
|
306 |
+
<tr><td align="center"><b>customizability</b></td><td>required</td>
|
307 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
308 |
+
</tbody></table>
|
309 |
+
|
310 |
+
### breathiness_db_max
|
311 |
+
|
312 |
+
Maximum breathiness value in dB used for normalization to [-1, 1].
|
313 |
+
|
314 |
+
<table><tbody>
|
315 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
316 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
317 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
318 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
319 |
+
<tr><td align="center"><b>default</b></td><td>-20.0</td>
|
320 |
+
</tbody></table>
|
321 |
+
|
322 |
+
### breathiness_db_min
|
323 |
+
|
324 |
+
Minimum breathiness value in dB used for normalization to [-1, 1].
|
325 |
+
|
326 |
+
<table><tbody>
|
327 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
328 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
329 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
330 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
331 |
+
<tr><td align="center"><b>default</b></td><td>-96.0</td>
|
332 |
+
</tbody></table>
|
333 |
+
|
334 |
+
### breathiness_smooth_width
|
335 |
+
|
336 |
+
Length of sinusoidal smoothing convolution kernel (in seconds) on extracted breathiness curve.
|
337 |
+
|
338 |
+
<table><tbody>
|
339 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
340 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
341 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
342 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
343 |
+
<tr><td align="center"><b>default</b></td><td>0.12</td>
|
344 |
+
</tbody></table>
|
345 |
+
|
346 |
+
### clip_grad_norm
|
347 |
+
|
348 |
+
The value at which to clip gradients. Equivalent to `gradient_clip_val` in `lightning.pytorch.Trainer`.
|
349 |
+
|
350 |
+
<table><tbody>
|
351 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
352 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
353 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
354 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
355 |
+
<tr><td align="center"><b>default</b></td><td>1</td>
|
356 |
+
</tbody></table>
|
357 |
+
|
358 |
+
### dataloader_prefetch_factor
|
359 |
+
|
360 |
+
Number of batches loaded in advance by each `torch.utils.data.DataLoader` worker.
|
361 |
+
|
362 |
+
<table><tbody>
|
363 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
364 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
365 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
366 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
367 |
+
<tr><td align="center"><b>default</b></td><td>2</td>
|
368 |
+
</tbody></table>
|
369 |
+
|
370 |
+
### dataset_size_key
|
371 |
+
|
372 |
+
The key that indexes the binarized metadata to be used as the `sizes` when batching by size.
|
373 |
+
|
374 |
+
<table><tbody>
|
375 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
376 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
377 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
378 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
379 |
+
<tr><td align="center"><b>default</b></td><td>lengths</td>
|
380 |
+
</tbody></table>
|
381 |
+
|
382 |
+
### dictionary
|
383 |
+
|
384 |
+
Path to the word-phoneme mapping dictionary file. Training data must fully cover phonemes in the dictionary.
|
385 |
+
|
386 |
+
<table><tbody>
|
387 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
388 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
389 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
390 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
391 |
+
</tbody></table>
|
392 |
+
|
393 |
+
### diff_accelerator
|
394 |
+
|
395 |
+
DDPM sampling acceleration method. The following methods are currently available:
|
396 |
+
|
397 |
+
- DDIM: the DDIM method from [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)
|
398 |
+
- PNDM: the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778)
|
399 |
+
- DPM-Solver++ adapted from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://github.com/LuChengTHU/dpm-solver)
|
400 |
+
- UniPC adapted from [UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models](https://github.com/wl-zhao/UniPC)
|
401 |
+
|
402 |
+
<table><tbody>
|
403 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
404 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
405 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
406 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
407 |
+
<tr><td align="center"><b>default</b></td><td>dpm-solver</td>
|
408 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'ddim', 'pndm', 'dpm-solver', 'unipc'.</td>
|
409 |
+
</tbody></table>
|
410 |
+
|
411 |
+
### diff_speedup
|
412 |
+
|
413 |
+
DDPM sampling speed-up ratio. 1 means no speeding up.
|
414 |
+
|
415 |
+
<table><tbody>
|
416 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
417 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
418 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
419 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
420 |
+
<tr><td align="center"><b>default</b></td><td>10</td>
|
421 |
+
<tr><td align="center"><b>constraints</b></td><td>Must be a factor of <a href="#K_step">K_step</a>.</td>
|
422 |
+
</tbody></table>
|
423 |
+
|
424 |
+
### diffusion_type
|
425 |
+
|
426 |
+
The type of ODE-based generative model algorithm. The following models are currently available:
|
427 |
+
|
428 |
+
- Denoising Diffusion Probabilistic Models (DDPM) from [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
|
429 |
+
- Rectified Flow from [Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow](https://arxiv.org/abs/2209.03003)
|
430 |
+
|
431 |
+
<table><tbody>
|
432 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
433 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
434 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
435 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
436 |
+
<tr><td align="center"><b>default</b></td><td>reflow</td>
|
437 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'ddpm', 'reflow'.</td>
|
438 |
+
</tbody></table>
|
439 |
+
|
440 |
+
### dropout
|
441 |
+
|
442 |
+
Dropout rate in some FastSpeech2 modules.
|
443 |
+
|
444 |
+
<table><tbody>
|
445 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
446 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
447 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
448 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
449 |
+
<tr><td align="center"><b>default</b></td><td>0.1</td>
|
450 |
+
</tbody></table>
|
451 |
+
|
452 |
+
### ds_workers
|
453 |
+
|
454 |
+
Number of workers of `torch.utils.data.DataLoader`.
|
455 |
+
|
456 |
+
<table><tbody>
|
457 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
458 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
459 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
460 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
461 |
+
<tr><td align="center"><b>default</b></td><td>4</td>
|
462 |
+
</tbody></table>
|
463 |
+
|
464 |
+
### dur_prediction_args
|
465 |
+
|
466 |
+
Arguments for phoneme duration prediction.
|
467 |
+
|
468 |
+
<table><tbody>
|
469 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
470 |
+
</tbody></table>
|
471 |
+
|
472 |
+
### dur_prediction_args.arch
|
473 |
+
|
474 |
+
Architecture of duration predictor.
|
475 |
+
|
476 |
+
<table><tbody>
|
477 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
478 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
479 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
480 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
481 |
+
<tr><td align="center"><b>default</b></td><td>fs2</td>
|
482 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'fs2'.</td>
|
483 |
+
</tbody></table>
|
484 |
+
|
485 |
+
### dur_prediction_args.dropout
|
486 |
+
|
487 |
+
Dropout rate in duration predictor of FastSpeech2.
|
488 |
+
|
489 |
+
<table><tbody>
|
490 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
491 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
492 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
493 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
494 |
+
<tr><td align="center"><b>default</b></td><td>0.1</td>
|
495 |
+
</tbody></table>
|
496 |
+
|
497 |
+
### dur_prediction_args.hidden_size
|
498 |
+
|
499 |
+
Dimensions of hidden layers in duration predictor of FastSpeech2.
|
500 |
+
|
501 |
+
<table><tbody>
|
502 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
503 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
504 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
505 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
506 |
+
<tr><td align="center"><b>default</b></td><td>512</td>
|
507 |
+
</tbody></table>
|
508 |
+
|
509 |
+
### dur_prediction_args.kernel_size
|
510 |
+
|
511 |
+
Kernel size of convolution layers of duration predictor of FastSpeech2.
|
512 |
+
|
513 |
+
<table><tbody>
|
514 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
515 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
516 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
517 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
518 |
+
<tr><td align="center"><b>default</b></td><td>3</td>
|
519 |
+
</tbody></table>
|
520 |
+
|
521 |
+
### dur_prediction_args.lambda_pdur_loss
|
522 |
+
|
523 |
+
Coefficient of single phone duration loss when calculating joint duration loss.
|
524 |
+
|
525 |
+
<table><tbody>
|
526 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
527 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
528 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
529 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
530 |
+
<tr><td align="center"><b>default</b></td><td>0.3</td>
|
531 |
+
</tbody></table>
|
532 |
+
|
533 |
+
### dur_prediction_args.lambda_sdur_loss
|
534 |
+
|
535 |
+
Coefficient of sentence duration loss when calculating joint duration loss.
|
536 |
+
|
537 |
+
<table><tbody>
|
538 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
539 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
540 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
541 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
542 |
+
<tr><td align="center"><b>default</b></td><td>3.0</td>
|
543 |
+
</tbody></table>
|
544 |
+
|
545 |
+
### dur_prediction_args.lambda_wdur_loss
|
546 |
+
|
547 |
+
Coefficient of word duration loss when calculating joint duration loss.
|
548 |
+
|
549 |
+
<table><tbody>
|
550 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
551 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
552 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
553 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
554 |
+
<tr><td align="center"><b>default</b></td><td>1.0</td>
|
555 |
+
</tbody></table>
|
556 |
+
|
557 |
+
### dur_prediction_args.log_offset
|
558 |
+
|
559 |
+
Offset for log domain duration loss calculation, where the following transformation is applied:
|
560 |
+
$$
|
561 |
+
D' = \ln{(D+d)}
|
562 |
+
$$
|
563 |
+
with the offset value $d$.
|
564 |
+
|
565 |
+
<table><tbody>
|
566 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
567 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
568 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
569 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
570 |
+
<tr><td align="center"><b>default</b></td><td>1.0</td>
|
571 |
+
</tbody></table>
|
572 |
+
|
573 |
+
### dur_prediction_args.loss_type
|
574 |
+
|
575 |
+
Underlying loss type of duration loss.
|
576 |
+
|
577 |
+
<table><tbody>
|
578 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
579 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
580 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
581 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
582 |
+
<tr><td align="center"><b>default</b></td><td>mse</td>
|
583 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'mse', 'huber'.</td>
|
584 |
+
</tbody></table>
|
585 |
+
|
586 |
+
### dur_prediction_args.num_layers
|
587 |
+
|
588 |
+
Number of duration predictor layers.
|
589 |
+
|
590 |
+
<table><tbody>
|
591 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
592 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
593 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
594 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
595 |
+
<tr><td align="center"><b>default</b></td><td>5</td>
|
596 |
+
</tbody></table>
|
597 |
+
|
598 |
+
### enc_ffn_kernel_size
|
599 |
+
|
600 |
+
Convolution kernel size of TransformerFFNLayer in FastSpeech2 encoder.
|
601 |
+
|
602 |
+
<table><tbody>
|
603 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
604 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
605 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
606 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
607 |
+
<tr><td align="center"><b>default</b></td><td>9</td>
|
608 |
+
</tbody></table>
|
609 |
+
|
610 |
+
### enc_layers
|
611 |
+
|
612 |
+
Number of FastSpeech2 encoder layers.
|
613 |
+
|
614 |
+
<table><tbody>
|
615 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
616 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
617 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
618 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
619 |
+
<tr><td align="center"><b>default</b></td><td>4</td>
|
620 |
+
</tbody></table>
|
621 |
+
|
622 |
+
### energy_db_max
|
623 |
+
|
624 |
+
Maximum energy value in dB used for normalization to [-1, 1].
|
625 |
+
|
626 |
+
<table><tbody>
|
627 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
628 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
629 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
630 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
631 |
+
<tr><td align="center"><b>default</b></td><td>-12.0</td>
|
632 |
+
</tbody></table>
|
633 |
+
|
634 |
+
### energy_db_min
|
635 |
+
|
636 |
+
Minimum energy value in dB used for normalization to [-1, 1].
|
637 |
+
|
638 |
+
<table><tbody>
|
639 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
640 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
641 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
642 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
643 |
+
<tr><td align="center"><b>default</b></td><td>-96.0</td>
|
644 |
+
</tbody></table>
|
645 |
+
|
646 |
+
### energy_smooth_width
|
647 |
+
|
648 |
+
Length of sinusoidal smoothing convolution kernel (in seconds) on extracted energy curve.
|
649 |
+
|
650 |
+
<table><tbody>
|
651 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
652 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
653 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
654 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
655 |
+
<tr><td align="center"><b>default</b></td><td>0.12</td>
|
656 |
+
</tbody></table>
|
657 |
+
|
658 |
+
### f0_max
|
659 |
+
|
660 |
+
Maximum fundamental frequency (F0) in Hz for pitch extraction.
|
661 |
+
|
662 |
+
<table><tbody>
|
663 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
664 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
665 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
666 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
667 |
+
<tr><td align="center"><b>default</b></td><td>1100</td>
|
668 |
+
</tbody></table>
|
669 |
+
|
670 |
+
### f0_min
|
671 |
+
|
672 |
+
Minimum fundamental frequency (F0) in Hz for pitch extraction.
|
673 |
+
|
674 |
+
<table><tbody>
|
675 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
676 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
677 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
678 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
679 |
+
<tr><td align="center"><b>default</b></td><td>65</td>
|
680 |
+
</tbody></table>
|
681 |
+
|
682 |
+
### ffn_act
|
683 |
+
|
684 |
+
Activation function of TransformerFFNLayer in FastSpeech2 encoder:
|
685 |
+
|
686 |
+
- `torch.nn.ReLU` if 'relu'
|
687 |
+
- `torch.nn.GELU` if 'gelu'
|
688 |
+
- `torch.nn.SiLU` if 'swish'
|
689 |
+
|
690 |
+
<table><tbody>
|
691 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
692 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
693 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
694 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
695 |
+
<tr><td align="center"><b>default</b></td><td>gelu</td>
|
696 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'relu', 'gelu', 'swish'.</td>
|
697 |
+
</tbody></table>
|
698 |
+
|
699 |
+
### fft_size
|
700 |
+
|
701 |
+
FFT size (number of points of the Fast Fourier Transform) for mel extraction.
|
702 |
+
|
703 |
+
<table><tbody>
|
704 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
705 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
706 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
707 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
708 |
+
<tr><td align="center"><b>default</b></td><td>2048</td>
|
709 |
+
</tbody></table>
|
710 |
+
|
711 |
+
### finetune_enabled
|
712 |
+
|
713 |
+
Whether to finetune from a pretrained model.
|
714 |
+
|
715 |
+
<table><tbody>
|
716 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
717 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
718 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
719 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
720 |
+
<tr><td align="center"><b>default</b></td><td>False</td>
|
721 |
+
</tbody></table>
|
722 |
+
|
723 |
+
### finetune_ckpt_path
|
724 |
+
|
725 |
+
Path to the pretrained model for finetuning.
|
726 |
+
|
727 |
+
<table><tbody>
|
728 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
729 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
730 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
731 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
732 |
+
<tr><td align="center"><b>default</b></td><td>null</td>
|
733 |
+
</tbody></table>
|
734 |
+
|
735 |
+
### finetune_ignored_params
|
736 |
+
|
737 |
+
Prefixes of parameter key names in the state dict of the pretrained model that need to be dropped before finetuning.
|
738 |
+
|
739 |
+
<table><tbody>
|
740 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
741 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
742 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
743 |
+
<tr><td align="center"><b>type</b></td><td>list</td>
|
744 |
+
</tbody></table>
|
745 |
+
|
746 |
+
### finetune_strict_shapes
|
747 |
+
|
748 |
+
Whether to raise error if the tensor shapes of any parameter of the pretrained model and the target model mismatch. If set to `False`, parameters with mismatching shapes will be skipped.
|
749 |
+
|
750 |
+
<table><tbody>
|
751 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
752 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
753 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
754 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
755 |
+
<tr><td align="center"><b>default</b></td><td>True</td>
|
756 |
+
</tbody></table>
|
757 |
+
|
758 |
+
### fmax
|
759 |
+
|
760 |
+
Maximum frequency of mel extraction.
|
761 |
+
|
762 |
+
<table><tbody>
|
763 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
764 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
765 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
766 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
767 |
+
<tr><td align="center"><b>default</b></td><td>16000</td>
|
768 |
+
</tbody></table>
|
769 |
+
|
770 |
+
### fmin
|
771 |
+
|
772 |
+
Minimum frequency of mel extraction.
|
773 |
+
|
774 |
+
<table><tbody>
|
775 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
776 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
777 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
778 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
779 |
+
<tr><td align="center"><b>default</b></td><td>40</td>
|
780 |
+
</tbody></table>
|
781 |
+
|
782 |
+
### freezing_enabled
|
783 |
+
|
784 |
+
Whether to enable parameter freezing during training.
|
785 |
+
|
786 |
+
<table><tbody>
|
787 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
788 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
789 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
790 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
791 |
+
<tr><td align="center"><b>default</b></td><td>False</td>
|
792 |
+
</tbody></table>
|
793 |
+
|
794 |
+
### frozen_params
|
795 |
+
|
796 |
+
Parameter name prefixes to freeze during training.
|
797 |
+
|
798 |
+
<table><tbody>
|
799 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
800 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
801 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
802 |
+
<tr><td align="center"><b>type</b></td><td>list</td>
|
803 |
+
<tr><td align="center"><b>default</b></td><td>[]</td>
|
804 |
+
</tbody></table>
|
805 |
+
|
806 |
+
### glide_embed_scale
|
807 |
+
|
808 |
+
The scale factor to be multiplied on the glide embedding values for melody encoder.
|
809 |
+
|
810 |
+
<table><tbody>
|
811 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
812 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
813 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
814 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
815 |
+
<tr><td align="center"><b>default</b></td><td>11.313708498984760</td>
|
816 |
+
</tbody></table>
|
817 |
+
|
818 |
+
### glide_types
|
819 |
+
|
820 |
+
Type names of glide notes.
|
821 |
+
|
822 |
+
<table><tbody>
|
823 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
824 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
825 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
826 |
+
<tr><td align="center"><b>type</b></td><td>list</td>
|
827 |
+
<tr><td align="center"><b>default</b></td><td>[up, down]</td>
|
828 |
+
</tbody></table>
|
829 |
+
|
830 |
+
### hidden_size
|
831 |
+
|
832 |
+
Dimension of hidden layers of FastSpeech2, token and parameter embeddings, and diffusion condition.
|
833 |
+
|
834 |
+
<table><tbody>
|
835 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
836 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
837 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
838 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
839 |
+
<tr><td align="center"><b>default</b></td><td>256</td>
|
840 |
+
</tbody></table>
|
841 |
+
|
842 |
+
### hnsep
|
843 |
+
|
844 |
+
Harmonic-noise separation algorithm type.
|
845 |
+
|
846 |
+
<table><tbody>
|
847 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
848 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
849 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
850 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
851 |
+
<tr><td align="center"><b>default</b></td><td>world</td>
|
852 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'world', 'vr'.</td>
|
853 |
+
</tbody></table>
|
854 |
+
|
855 |
+
### hnsep_ckpt
|
856 |
+
|
857 |
+
Checkpoint or model path of NN-based harmonic-noise separator.
|
858 |
+
|
859 |
+
<table><tbody>
|
860 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
861 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
862 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
863 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
864 |
+
</tbody></table>
|
865 |
+
|
866 |
+
### hop_size
|
867 |
+
|
868 |
+
Hop size or step length (in number of waveform samples) of mel and feature extraction.
|
869 |
+
|
870 |
+
<table><tbody>
|
871 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
872 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
873 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
874 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
875 |
+
<tr><td align="center"><b>default</b></td><td>512</td>
|
876 |
+
</tbody></table>
|
877 |
+
|
878 |
+
### lambda_aux_mel_loss
|
879 |
+
|
880 |
+
Coefficient of aux mel loss when calculating total loss of acoustic model with shallow diffusion.
|
881 |
+
|
882 |
+
<table><tbody>
|
883 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
884 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
885 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
886 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
887 |
+
<tr><td align="center"><b>default</b></td><td>0.2</td>
|
888 |
+
</tbody></table>
|
889 |
+
|
890 |
+
### lambda_dur_loss
|
891 |
+
|
892 |
+
Coefficient of duration loss when calculating total loss of variance model.
|
893 |
+
|
894 |
+
<table><tbody>
|
895 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
896 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
897 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
898 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
899 |
+
<tr><td align="center"><b>default</b></td><td>1.0</td>
|
900 |
+
</tbody></table>
|
901 |
+
|
902 |
+
### lambda_pitch_loss
|
903 |
+
|
904 |
+
Coefficient of pitch loss when calculating total loss of variance model.
|
905 |
+
|
906 |
+
<table><tbody>
|
907 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
908 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
909 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
910 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
911 |
+
<tr><td align="center"><b>default</b></td><td>1.0</td>
|
912 |
+
</tbody></table>
|
913 |
+
|
914 |
+
### lambda_var_loss
|
915 |
+
|
916 |
+
Coefficient of variance loss (all variance parameters other than pitch, like energy, breathiness, etc.) when calculating total loss of variance model.
|
917 |
+
|
918 |
+
<table><tbody>
|
919 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
920 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
921 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
922 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
923 |
+
<tr><td align="center"><b>default</b></td><td>1.0</td>
|
924 |
+
</tbody></table>
|
925 |
+
|
926 |
+
### K_step
|
927 |
+
|
928 |
+
Maximum number of DDPM steps used by shallow diffusion.
|
929 |
+
|
930 |
+
<table><tbody>
|
931 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
932 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
933 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
934 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
935 |
+
<tr><td align="center"><b>default</b></td><td>400</td>
|
936 |
+
</tbody></table>
|
937 |
+
|
938 |
+
### K_step_infer
|
939 |
+
|
940 |
+
Number of DDPM steps used during shallow diffusion inference. Normally set to the same value as [K_step](#K_step).
|
941 |
+
|
942 |
+
<table><tbody>
|
943 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
944 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
945 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
946 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
947 |
+
<tr><td align="center"><b>default</b></td><td>400</td>
|
948 |
+
<tr><td align="center"><b>constraints</b></td><td>Should be no larger than K_step.</td>
|
949 |
+
</tbody></table>
|
950 |
+
|
951 |
+
### log_interval
|
952 |
+
|
953 |
+
Controls how often to log within training steps. Equivalent to `log_every_n_steps` in `lightning.pytorch.Trainer`.
|
954 |
+
|
955 |
+
<table><tbody>
|
956 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
957 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
958 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
959 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
960 |
+
<tr><td align="center"><b>default</b></td><td>100</td>
|
961 |
+
</tbody></table>
|
962 |
+
|
963 |
+
### lr_scheduler_args
|
964 |
+
|
965 |
+
Arguments of learning rate scheduler. Keys will be used as keyword arguments of the `__init__()` method of [lr_scheduler_args.scheduler_cls](#lr_scheduler_argsscheduler_cls).
|
966 |
+
|
967 |
+
<table><tbody>
|
968 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
969 |
+
</tbody></table>
|
970 |
+
|
971 |
+
### lr_scheduler_args.scheduler_cls
|
972 |
+
|
973 |
+
Learning rate scheduler class name.
|
974 |
+
|
975 |
+
<table><tbody>
|
976 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
977 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
978 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
979 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
980 |
+
<tr><td align="center"><b>default</b></td><td>torch.optim.lr_scheduler.StepLR</td>
|
981 |
+
</tbody></table>
|
982 |
+
|
983 |
+
### main_loss_log_norm
|
984 |
+
|
985 |
+
Whether to use log-normalized weight for the main loss. This is similar to the method in the Stable Diffusion 3 paper [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206).
|
986 |
+
|
987 |
+
<table><tbody>
|
988 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
989 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
990 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
991 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
992 |
+
</tbody></table>
|
993 |
+
|
994 |
+
### main_loss_type
|
995 |
+
|
996 |
+
Loss type of the main decoder/predictor.
|
997 |
+
|
998 |
+
<table><tbody>
|
999 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1000 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1001 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1002 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1003 |
+
<tr><td align="center"><b>default</b></td><td>l2</td>
|
1004 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'l1', 'l2'.</td>
|
1005 |
+
</tbody></table>
|
1006 |
+
|
1007 |
+
### max_batch_frames
|
1008 |
+
|
1009 |
+
Maximum number of data frames in each training batch. Used to dynamically control the batch size.
|
1010 |
+
|
1011 |
+
<table><tbody>
|
1012 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1013 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1014 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1015 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1016 |
+
<tr><td align="center"><b>default</b></td><td>80000</td>
|
1017 |
+
</tbody></table>
|
1018 |
+
|
1019 |
+
### max_batch_size
|
1020 |
+
|
1021 |
+
The maximum training batch size.
|
1022 |
+
|
1023 |
+
<table><tbody>
|
1024 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1025 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1026 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1027 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1028 |
+
<tr><td align="center"><b>default</b></td><td>48</td>
|
1029 |
+
</tbody></table>
|
1030 |
+
|
1031 |
+
### max_beta
|
1032 |
+
|
1033 |
+
Max beta of the DDPM noise schedule.
|
1034 |
+
|
1035 |
+
<table><tbody>
|
1036 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1037 |
+
<tr><td align="center"><b>scope</b></td><td>nn, inference</td>
|
1038 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1039 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1040 |
+
<tr><td align="center"><b>default</b></td><td>0.02</td>
|
1041 |
+
</tbody></table>
|
1042 |
+
|
1043 |
+
### max_updates
|
1044 |
+
|
1045 |
+
Stop training after this number of steps. Equivalent to `max_steps` in `lightning.pytorch.Trainer`.
|
1046 |
+
|
1047 |
+
<table><tbody>
|
1048 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1049 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1050 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1051 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1052 |
+
<tr><td align="center"><b>default</b></td><td>320000</td>
|
1053 |
+
</tbody></table>
|
1054 |
+
|
1055 |
+
### max_val_batch_frames
|
1056 |
+
|
1057 |
+
Maximum number of data frames in each validation batch.
|
1058 |
+
|
1059 |
+
<table><tbody>
|
1060 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1061 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1062 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1063 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1064 |
+
<tr><td align="center"><b>default</b></td><td>60000</td>
|
1065 |
+
</tbody></table>
|
1066 |
+
|
1067 |
+
### max_val_batch_size
|
1068 |
+
|
1069 |
+
The maximum validation batch size.
|
1070 |
+
|
1071 |
+
<table><tbody>
|
1072 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1073 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1074 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1075 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1076 |
+
<tr><td align="center"><b>default</b></td><td>1</td>
|
1077 |
+
</tbody></table>
|
1078 |
+
|
1079 |
+
### mel_base
|
1080 |
+
|
1081 |
+
The logarithmic base of mel spectrogram calculation.
|
1082 |
+
|
1083 |
+
**WARNING: Since v2.4.0 release, this value is no longer configurable for preprocessing new datasets.**
|
1084 |
+
|
1085 |
+
<table><tbody>
|
1086 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1087 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
1088 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
1089 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1090 |
+
<tr><td align="center"><b>default</b></td><td>e</td>
|
1091 |
+
</tbody></table>
|
1092 |
+
|
1093 |
+
### mel_vmax
|
1094 |
+
|
1095 |
+
Maximum mel spectrogram heatmap value for TensorBoard plotting.
|
1096 |
+
|
1097 |
+
<table><tbody>
|
1098 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1099 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1100 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1101 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1102 |
+
<tr><td align="center"><b>default</b></td><td>1.5</td>
|
1103 |
+
</tbody></table>
|
1104 |
+
|
1105 |
+
### mel_vmin
|
1106 |
+
|
1107 |
+
Minimum mel spectrogram heatmap value for TensorBoard plotting.
|
1108 |
+
|
1109 |
+
<table><tbody>
|
1110 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1111 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1112 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1113 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1114 |
+
<tr><td align="center"><b>default</b></td><td>-6.0</td>
|
1115 |
+
</tbody></table>
|
1116 |
+
|
1117 |
+
### melody_encoder_args
|
1118 |
+
|
1119 |
+
Arguments for melody encoder. Available sub-keys: `hidden_size`, `enc_layers`, `enc_ffn_kernel_size`, `ffn_act`, `dropout`, `num_heads`, `use_pos_embed`, `rel_pos`. If any of these parameters is not present under this configuration key, its value is inherited from the linguistic encoder.
|
1120 |
+
|
1121 |
+
<table><tbody>
|
1122 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
1123 |
+
</tbody></table>
|
1124 |
+
|
1125 |
+
### midi_smooth_width
|
1126 |
+
|
1127 |
+
Length of sinusoidal smoothing convolution kernel (in seconds) on the step function representing MIDI sequence for base pitch calculation.
|
1128 |
+
|
1129 |
+
<table><tbody>
|
1130 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1131 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
1132 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1133 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1134 |
+
<tr><td align="center"><b>default</b></td><td>0.06</td>
|
1135 |
+
</tbody></table>
|
1136 |
+
|
1137 |
+
### nccl_p2p
|
1138 |
+
|
1139 |
+
Whether to enable P2P when using NCCL as the backend. Set it to `false` if the training process gets stuck at startup.
|
1140 |
+
|
1141 |
+
<table><tbody>
|
1142 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1143 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1144 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1145 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1146 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
1147 |
+
</tbody></table>
|
1148 |
+
|
1149 |
+
### num_ckpt_keep
|
1150 |
+
|
1151 |
+
Number of newest checkpoints kept during training.
|
1152 |
+
|
1153 |
+
<table><tbody>
|
1154 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1155 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1156 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1157 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1158 |
+
<tr><td align="center"><b>default</b></td><td>5</td>
|
1159 |
+
</tbody></table>
|
1160 |
+
|
1161 |
+
### num_heads
|
1162 |
+
|
1163 |
+
The number of attention heads of `torch.nn.MultiheadAttention` in FastSpeech2 encoder.
|
1164 |
+
|
1165 |
+
<table><tbody>
|
1166 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1167 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
1168 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1169 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1170 |
+
<tr><td align="center"><b>default</b></td><td>2</td>
|
1171 |
+
</tbody></table>
|
1172 |
+
|
1173 |
+
### num_sanity_val_steps
|
1174 |
+
|
1175 |
+
Number of sanity validation steps at the beginning.
|
1176 |
+
|
1177 |
+
<table><tbody>
|
1178 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1179 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1180 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
1181 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1182 |
+
<tr><td align="center"><b>default</b></td><td>1</td>
|
1183 |
+
</tbody></table>
|
1184 |
+
|
1185 |
+
### num_spk
|
1186 |
+
|
1187 |
+
Maximum number of speakers in multi-speaker models.
|
1188 |
+
|
1189 |
+
<table><tbody>
|
1190 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1191 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
1192 |
+
<tr><td align="center"><b>customizability</b></td><td>required</td>
|
1193 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1194 |
+
<tr><td align="center"><b>default</b></td><td>1</td>
|
1195 |
+
</tbody></table>
|
1196 |
+
|
1197 |
+
### num_valid_plots
|
1198 |
+
|
1199 |
+
Number of validation plots in each validation. Plots will be chosen from the start of the validation set.
|
1200 |
+
|
1201 |
+
<table><tbody>
|
1202 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1203 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1204 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1205 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1206 |
+
<tr><td align="center"><b>default</b></td><td>10</td>
|
1207 |
+
</tbody></table>
|
1208 |
+
|
1209 |
+
### optimizer_args
|
1210 |
+
|
1211 |
+
Arguments of optimizer. Keys will be used as keyword arguments of the `__init__()` method of [optimizer_args.optimizer_cls](#optimizer_argsoptimizer_cls).
|
1212 |
+
|
1213 |
+
<table><tbody>
|
1214 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
1215 |
+
</tbody></table>
|
1216 |
+
|
1217 |
+
### optimizer_args.optimizer_cls
|
1218 |
+
|
1219 |
+
Optimizer class name
|
1220 |
+
|
1221 |
+
<table><tbody>
|
1222 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1223 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1224 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
1225 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1226 |
+
<tr><td align="center"><b>default</b></td><td>torch.optim.AdamW</td>
|
1227 |
+
</tbody></table>
|
1228 |
+
|
1229 |
+
### pe
|
1230 |
+
|
1231 |
+
Pitch extraction algorithm type.
|
1232 |
+
|
1233 |
+
<table><tbody>
|
1234 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1235 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
1236 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1237 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1238 |
+
<tr><td align="center"><b>default</b></td><td>parselmouth</td>
|
1239 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'parselmouth', 'rmvpe', 'harvest'.</td>
|
1240 |
+
</tbody></table>
|
1241 |
+
|
1242 |
+
### pe_ckpt
|
1243 |
+
|
1244 |
+
Checkpoint or model path of NN-based pitch extractor.
|
1245 |
+
|
1246 |
+
<table><tbody>
|
1247 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1248 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
1249 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1250 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1251 |
+
</tbody></table>
|
1252 |
+
|
1253 |
+
### permanent_ckpt_interval
|
1254 |
+
|
1255 |
+
The interval (in number of training steps) of permanent checkpoints. Permanent checkpoints will not be removed even if they are not the newest ones.
|
1256 |
+
|
1257 |
+
<table><tbody>
|
1258 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1259 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1260 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1261 |
+
<tr><td align="center"><b>default</b></td><td>40000</td>
|
1262 |
+
</tbody></table>
|
1263 |
+
|
1264 |
+
### permanent_ckpt_start
|
1265 |
+
|
1266 |
+
Checkpoints will be marked as permanent every [permanent_ckpt_interval](#permanent_ckpt_interval) training steps after this number of training steps.
|
1267 |
+
|
1268 |
+
<table><tbody>
|
1269 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1270 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1271 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1272 |
+
<tr><td align="center"><b>default</b></td><td>120000</td>
|
1273 |
+
</tbody></table>
|
1274 |
+
|
1275 |
+
### pitch_prediction_args
|
1276 |
+
|
1277 |
+
Arguments for pitch prediction.
|
1278 |
+
|
1279 |
+
<table><tbody>
|
1280 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
1281 |
+
</tbody></table>
|
1282 |
+
|
1283 |
+
### pitch_prediction_args.backbone_args
|
1284 |
+
|
1285 |
+
Equivalent to [backbone_args](#backbone_args) but only for the pitch predictor model. If not set, use the root backbone type.
|
1286 |
+
|
1287 |
+
<table><tbody>
|
1288 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1289 |
+
</tbody></table>
|
1290 |
+
|
1291 |
+
### pitch_prediction_args.backbone_type
|
1292 |
+
|
1293 |
+
Equivalent to [backbone_type](#backbone_type) but only for the pitch predictor model.
|
1294 |
+
|
1295 |
+
<table><tbody>
|
1296 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1297 |
+
<tr><td align="center"><b>default</b></td><td>wavenet</td>
|
1298 |
+
</tbody></table>
|
1299 |
+
|
1300 |
+
### pitch_prediction_args.pitd_clip_max
|
1301 |
+
|
1302 |
+
Maximum clipping value (in semitones) of pitch delta between actual pitch and base pitch.
|
1303 |
+
|
1304 |
+
<table><tbody>
|
1305 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1306 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1307 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1308 |
+
<tr><td align="center"><b>default</b></td><td>12.0</td>
|
1309 |
+
</tbody></table>
|
1310 |
+
|
1311 |
+
### pitch_prediction_args.pitd_clip_min
|
1312 |
+
|
1313 |
+
Minimum clipping value (in semitones) of pitch delta between actual pitch and base pitch.
|
1314 |
+
|
1315 |
+
<table><tbody>
|
1316 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1317 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1318 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1319 |
+
<tr><td align="center"><b>default</b></td><td>-12.0</td>
|
1320 |
+
</tbody></table>
|
1321 |
+
|
1322 |
+
### pitch_prediction_args.pitd_norm_max
|
1323 |
+
|
1324 |
+
Maximum pitch delta value in semitones used for normalization to [-1, 1].
|
1325 |
+
|
1326 |
+
<table><tbody>
|
1327 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1328 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1329 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1330 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1331 |
+
<tr><td align="center"><b>default</b></td><td>8.0</td>
|
1332 |
+
</tbody></table>
|
1333 |
+
|
1334 |
+
### pitch_prediction_args.pitd_norm_min
|
1335 |
+
|
1336 |
+
Minimum pitch delta value in semitones used for normalization to [-1, 1].
|
1337 |
+
|
1338 |
+
<table><tbody>
|
1339 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1340 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1341 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1342 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1343 |
+
<tr><td align="center"><b>default</b></td><td>-8.0</td>
|
1344 |
+
</tbody></table>
|
1345 |
+
|
1346 |
+
### pitch_prediction_args.repeat_bins
|
1347 |
+
|
1348 |
+
Number of repeating bins in the pitch predictor.
|
1349 |
+
|
1350 |
+
<table><tbody>
|
1351 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1352 |
+
<tr><td align="center"><b>scope</b></td><td>nn, inference</td>
|
1353 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1354 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1355 |
+
<tr><td align="center"><b>default</b></td><td>64</td>
|
1356 |
+
</tbody></table>
|
1357 |
+
|
1358 |
+
### pl_trainer_accelerator
|
1359 |
+
|
1360 |
+
Type of Lightning trainer hardware accelerator.
|
1361 |
+
|
1362 |
+
<table><tbody>
|
1363 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1364 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1365 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1366 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1367 |
+
<tr><td align="center"><b>default</b></td><td>auto</td>
|
1368 |
+
<tr><td align="center"><b>constraints</b></td><td>See <a href="https://lightning.ai/docs/pytorch/stable/extensions/accelerator.html?highlight=accelerator">Accelerator — PyTorch Lightning 2.X.X documentation</a> for available values.</td>
|
1369 |
+
</tbody></table>
|
1370 |
+
|
1371 |
+
### pl_trainer_devices
|
1372 |
+
|
1373 |
+
Determines which device(s) the model is trained on.
|
1374 |
+
|
1375 |
+
'auto' will utilize all visible devices defined with the `CUDA_VISIBLE_DEVICES` environment variable, or utilize all available devices if that variable is not set. Otherwise, it behaves like `CUDA_VISIBLE_DEVICES` which can filter out visible devices.
|
1376 |
+
|
1377 |
+
<table><tbody>
|
1378 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1379 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1380 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1381 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1382 |
+
<tr><td align="center"><b>default</b></td><td>auto</td>
|
1383 |
+
</tbody></table>
|
1384 |
+
|
1385 |
+
### pl_trainer_precision
|
1386 |
+
|
1387 |
+
The computation precision of training.
|
1388 |
+
|
1389 |
+
<table><tbody>
|
1390 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1391 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1392 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1393 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1394 |
+
<tr><td align="center"><b>default</b></td><td>16-mixed</td>
|
1395 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from '32-true', 'bf16-mixed', '16-mixed'. See more possible values at <a href="https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api">Trainer — PyTorch Lightning 2.X.X documentation</a>.</td>
|
1396 |
+
</tbody></table>
|
1397 |
+
|
1398 |
+
### pl_trainer_num_nodes
|
1399 |
+
|
1400 |
+
Number of nodes in the training cluster of Lightning trainer.
|
1401 |
+
|
1402 |
+
<table><tbody>
|
1403 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1404 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1405 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
1406 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1407 |
+
<tr><td align="center"><b>default</b></td><td>1</td>
|
1408 |
+
</tbody></table>
|
1409 |
+
|
1410 |
+
### pl_trainer_strategy
|
1411 |
+
|
1412 |
+
Arguments of Lightning Strategy. Values will be used as keyword arguments when constructing the Strategy object.
|
1413 |
+
|
1414 |
+
<table><tbody>
|
1415 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
1416 |
+
</tbody></table>
|
1417 |
+
|
1418 |
+
### pl_trainer_strategy.name
|
1419 |
+
|
1420 |
+
Strategy name for the Lightning trainer.
|
1421 |
+
|
1422 |
+
<table><tbody>
|
1423 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1424 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1425 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
1426 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1427 |
+
<tr><td align="center"><b>default</b></td><td>auto</td>
|
1428 |
+
</tbody></table>
|
1429 |
+
|
1430 |
+
### predict_breathiness
|
1431 |
+
|
1432 |
+
Whether to enable breathiness prediction.
|
1433 |
+
|
1434 |
+
<table><tbody>
|
1435 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1436 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
|
1437 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1438 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1439 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1440 |
+
</tbody></table>
|
1441 |
+
|
1442 |
+
### predict_dur
|
1443 |
+
|
1444 |
+
Whether to enable phoneme duration prediction.
|
1445 |
+
|
1446 |
+
<table><tbody>
|
1447 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1448 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
|
1449 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1450 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1451 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
1452 |
+
</tbody></table>
|
1453 |
+
|
1454 |
+
### predict_energy
|
1455 |
+
|
1456 |
+
Whether to enable energy prediction.
|
1457 |
+
|
1458 |
+
<table><tbody>
|
1459 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1460 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
|
1461 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1462 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1463 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1464 |
+
</tbody></table>
|
1465 |
+
|
1466 |
+
### predict_pitch
|
1467 |
+
|
1468 |
+
Whether to enable pitch prediction.
|
1469 |
+
|
1470 |
+
<table><tbody>
|
1471 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1472 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
|
1473 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1474 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1475 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
1476 |
+
</tbody></table>
|
1477 |
+
|
1478 |
+
### predict_tension
|
1479 |
+
|
1480 |
+
Whether to enable tension prediction.
|
1481 |
+
|
1482 |
+
<table><tbody>
|
1483 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1484 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
|
1485 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1486 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1487 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
1488 |
+
</tbody></table>
|
1489 |
+
|
1490 |
+
### predict_voicing
|
1491 |
+
|
1492 |
+
Whether to enable voicing prediction.
|
1493 |
+
|
1494 |
+
<table><tbody>
|
1495 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1496 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, training, inference</td>
|
1497 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1498 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1499 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
1500 |
+
</tbody></table>
|
1501 |
+
|
1502 |
+
### raw_data_dir
|
1503 |
+
|
1504 |
+
Path(s) to the raw dataset including wave files, transcriptions, etc.
|
1505 |
+
|
1506 |
+
<table><tbody>
|
1507 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1508 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
1509 |
+
<tr><td align="center"><b>customizability</b></td><td>required</td>
|
1510 |
+
<tr><td align="center"><b>type</b></td><td>str, List[str]</td>
|
1511 |
+
</tbody></table>
|
1512 |
+
|
1513 |
+
### rel_pos
|
1514 |
+
|
1515 |
+
Whether to use relative positional encoding in FastSpeech2 module.
|
1516 |
+
|
1517 |
+
<table><tbody>
|
1518 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1519 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
1520 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1521 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1522 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
1523 |
+
</tbody></table>
|
1524 |
+
|
1525 |
+
### sampler_frame_count_grid
|
1526 |
+
|
1527 |
+
The batch sampler applies an algorithm called _sorting by similar length_ when collecting batches. Data samples are first grouped by their approximate lengths before they get shuffled within each group. Assuming this value is set to $L_{grid}$, the approximate length of a data sample with real length $L_{real}$ can be calculated through the following expression:
|
1528 |
+
|
1529 |
+
$$
|
1530 |
+
L_{approx} = \lfloor\frac{L_{real}}{L_{grid}}\rfloor\cdot L_{grid}
|
1531 |
+
$$
|
1532 |
+
|
1533 |
+
Training performance on some datasets may be very sensitive to this value. Change it to 1 (completely sorted by length without shuffling) to get the best performance in theory.
|
1534 |
+
|
1535 |
+
<table><tbody>
|
1536 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1537 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1538 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1539 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1540 |
+
<tr><td align="center"><b>default</b></td><td>6</td>
|
1541 |
+
</tbody></table>
|
1542 |
+
|
1543 |
+
### sampling_algorithm
|
1544 |
+
|
1545 |
+
The algorithm to solve the ODE of Rectified Flow. The following methods are currently available:
|
1546 |
+
|
1547 |
+
- Euler: The Euler method.
|
1548 |
+
- Runge-Kutta (order 2): The 2nd-order Runge-Kutta method.
|
1549 |
+
- Runge-Kutta (order 4): The 4th-order Runge-Kutta method.
|
1550 |
+
- Runge-Kutta (order 5): The 5th-order Runge-Kutta method.
|
1551 |
+
|
1552 |
+
<table><tbody>
|
1553 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1554 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1555 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1556 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1557 |
+
<tr><td align="center"><b>default</b></td><td>euler</td>
|
1558 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'euler', 'rk2', 'rk4', 'rk5'.</td>
|
1559 |
+
</tbody></table>
|
1560 |
+
|
1561 |
+
### sampling_steps
|
1562 |
+
|
1563 |
+
The total number of sampling steps used to solve the ODE of Rectified Flow. Note that this value may not be equal to the NFE (Number of Function Evaluations) because some methods may require more than one function evaluation per step.
|
1564 |
+
|
1565 |
+
<table><tbody>
|
1566 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1567 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1568 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1569 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1570 |
+
<tr><td align="center"><b>default</b></td><td>20</td>
|
1571 |
+
</tbody></table>
|
1572 |
+
|
1573 |
+
### schedule_type
|
1574 |
+
|
1575 |
+
The DDPM schedule type.
|
1576 |
+
|
1577 |
+
<table><tbody>
|
1578 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1579 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
1580 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1581 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1582 |
+
<tr><td align="center"><b>default</b></td><td>linear</td>
|
1583 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'linear', 'cosine'.</td>
|
1584 |
+
</tbody></table>
|
1585 |
+
|
1586 |
+
### shallow_diffusion_args
|
1587 |
+
|
1588 |
+
Arguments for shallow diffusion.
|
1589 |
+
|
1590 |
+
<table><tbody>
|
1591 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
1592 |
+
</tbody></table>
|
1593 |
+
|
1594 |
+
### shallow_diffusion_args.aux_decoder_arch
|
1595 |
+
|
1596 |
+
Architecture type of the auxiliary decoder.
|
1597 |
+
|
1598 |
+
<table><tbody>
|
1599 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1600 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
1601 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
1602 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1603 |
+
<tr><td align="center"><b>default</b></td><td>convnext</td>
|
1604 |
+
<tr><td align="center"><b>constraints</b></td><td>Choose from 'convnext'.</td>
|
1605 |
+
</tbody></table>
|
1606 |
+
|
1607 |
+
### shallow_diffusion_args.aux_decoder_args
|
1608 |
+
|
1609 |
+
Keyword arguments for dynamically constructing the auxiliary decoder.
|
1610 |
+
|
1611 |
+
<table><tbody>
|
1612 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1613 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
1614 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
1615 |
+
</tbody></table>
|
1616 |
+
|
1617 |
+
### shallow_diffusion_args.aux_decoder_grad
|
1618 |
+
|
1619 |
+
Scale factor of the gradients from the auxiliary decoder to the encoder.
|
1620 |
+
|
1621 |
+
<table><tbody>
|
1622 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1623 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1624 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1625 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1626 |
+
<tr><td align="center"><b>default</b></td><td>0.1</td>
|
1627 |
+
</tbody></table>
|
1628 |
+
|
1629 |
+
### shallow_diffusion_args.train_aux_decoder
|
1630 |
+
|
1631 |
+
Whether to forward and backward the auxiliary decoder during training. If set to `false`, the auxiliary decoder hangs in the memory and does not get any updates.
|
1632 |
+
|
1633 |
+
<table><tbody>
|
1634 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1635 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1636 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1637 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1638 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
1639 |
+
</tbody></table>
|
1640 |
+
|
1641 |
+
### shallow_diffusion_args.train_diffusion
|
1642 |
+
|
1643 |
+
Whether to forward and backward the diffusion (main) decoder during training. If set to `false`, the diffusion decoder hangs in the memory and does not get any updates.
|
1644 |
+
|
1645 |
+
<table><tbody>
|
1646 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1647 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1648 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1649 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1650 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
1651 |
+
</tbody></table>
|
1652 |
+
|
1653 |
+
### shallow_diffusion_args.val_gt_start
|
1654 |
+
|
1655 |
+
Whether to use the ground truth as `x_start` in the shallow diffusion validation process. If set to `true`, Gaussian noise is added to the ground truth before shallow diffusion is performed; otherwise the noise is added to the output of the auxiliary decoder. This option is useful when the auxiliary decoder has not been trained yet.
|
1656 |
+
|
1657 |
+
<table><tbody>
|
1658 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1659 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1660 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1661 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1662 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1663 |
+
</tbody></table>
|
1664 |
+
|
1665 |
+
### sort_by_len
|
1666 |
+
|
1667 |
+
Whether to apply the _sorting by similar length_ algorithm described in [sampler_frame_count_grid](#sampler_frame_count_grid). Turning off this option may slow down training because sorting by length can better utilize the computing resources.
|
1668 |
+
|
1669 |
+
<table><tbody>
|
1670 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1671 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1672 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1673 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1674 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
1675 |
+
</tbody></table>
|
1676 |
+
|
1677 |
+
### speakers
|
1678 |
+
|
1679 |
+
The names of speakers in a multi-speaker model. Speaker names are mapped to speaker indexes and stored into spk_map.json when preprocessing.
|
1680 |
+
|
1681 |
+
<table><tbody>
|
1682 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1683 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
1684 |
+
<tr><td align="center"><b>customizability</b></td><td>required</td>
|
1685 |
+
<tr><td align="center"><b>type</b></td><td>list</td>
|
1686 |
+
</tbody></table>
|
1687 |
+
|
1688 |
+
### spk_ids
|
1689 |
+
|
1690 |
+
The IDs of speakers in a multi-speaker model. If an empty list is given, speaker IDs will be automatically generated as $0,1,2,...,N_{spk}-1$. IDs may be duplicated or non-contiguous.
|
1691 |
+
|
1692 |
+
<table><tbody>
|
1693 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1694 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
1695 |
+
<tr><td align="center"><b>customizability</b></td><td>required</td>
|
1696 |
+
<tr><td align="center"><b>type</b></td><td>List[int]</td>
|
1697 |
+
<tr><td align="center"><b>default</b></td><td>[]</td>
|
1698 |
+
</tbody></table>
|
1699 |
+
|
1700 |
+
### spec_min
|
1701 |
+
|
1702 |
+
Minimum mel spectrogram value used for normalization to [-1, 1]. Different mel bins can have different minimum values.
|
1703 |
+
|
1704 |
+
<table><tbody>
|
1705 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1706 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1707 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1708 |
+
<tr><td align="center"><b>type</b></td><td>List[float]</td>
|
1709 |
+
<tr><td align="center"><b>default</b></td><td>[-5.0]</td>
|
1710 |
+
</tbody></table>
|
1711 |
+
|
1712 |
+
### spec_max
|
1713 |
+
|
1714 |
+
Maximum mel spectrogram value used for normalization to [-1, 1]. Different mel bins can have different maximum values.
|
1715 |
+
|
1716 |
+
<table><tbody>
|
1717 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1718 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1719 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1720 |
+
<tr><td align="center"><b>type</b></td><td>List[float]</td>
|
1721 |
+
<tr><td align="center"><b>default</b></td><td>[0.0]</td>
|
1722 |
+
</tbody></table>
|
1723 |
+
|
1724 |
+
### T_start
|
1725 |
+
|
1726 |
+
The starting value of time $t$ in the Rectified Flow ODE which applies on $t \in (T_{start}, 1)$.
|
1727 |
+
|
1728 |
+
<table><tbody>
|
1729 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1730 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1731 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1732 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1733 |
+
<tr><td align="center"><b>default</b></td><td>0.4</td>
|
1734 |
+
</tbody></table>
|
1735 |
+
|
1736 |
+
### T_start_infer
|
1737 |
+
|
1738 |
+
The starting value of time $t$ in the ODE during shallow Rectified Flow inference. Normally set to the same value as [T_start](#T_start).
|
1739 |
+
|
1740 |
+
<table><tbody>
|
1741 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1742 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1743 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1744 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1745 |
+
<tr><td align="center"><b>default</b></td><td>0.4</td>
|
1746 |
+
<tr><td align="center"><b>constraints</b></td><td>Should be no less than T_start.</td>
|
1747 |
+
</tbody></table>
|
1748 |
+
|
1749 |
+
### task_cls
|
1750 |
+
|
1751 |
+
Task trainer class name.
|
1752 |
+
|
1753 |
+
<table><tbody>
|
1754 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1755 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1756 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
1757 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
1758 |
+
</tbody></table>
|
1759 |
+
|
1760 |
+
### tension_logit_max
|
1761 |
+
|
1762 |
+
Maximum tension logit value used for normalization to [-1, 1]. Logit is the inverse function of Sigmoid:
|
1763 |
+
|
1764 |
+
$$
|
1765 |
+
f(x) = \ln\frac{x}{1-x}
|
1766 |
+
$$
|
1767 |
+
|
1768 |
+
<table><tbody>
|
1769 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1770 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1771 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1772 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1773 |
+
<tr><td align="center"><b>default</b></td><td>10.0</td>
|
1774 |
+
</tbody></table>
|
1775 |
+
|
1776 |
+
### tension_logit_min
|
1777 |
+
|
1778 |
+
Minimum tension logit value used for normalization to [-1, 1]. Logit is the inverse function of Sigmoid:
|
1779 |
+
|
1780 |
+
$$
|
1781 |
+
f(x) = \ln\frac{x}{1-x}
|
1782 |
+
$$
|
1783 |
+
|
1784 |
+
<table><tbody>
|
1785 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1786 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
1787 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1788 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1789 |
+
<tr><td align="center"><b>default</b></td><td>-10.0</td>
|
1790 |
+
</tbody></table>
|
1791 |
+
|
1792 |
+
### tension_smooth_width
|
1793 |
+
|
1794 |
+
Length of sinusoidal smoothing convolution kernel (in seconds) on extracted tension curve.
|
1795 |
+
|
1796 |
+
<table><tbody>
|
1797 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1798 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
1799 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1800 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1801 |
+
<tr><td align="center"><b>default</b></td><td>0.12</td>
|
1802 |
+
</tbody></table>
|
1803 |
+
|
1804 |
+
### test_prefixes
|
1805 |
+
|
1806 |
+
List of data item names or name prefixes for the validation set. For each string `s` in the list:
|
1807 |
+
|
1808 |
+
- If `s` equals an actual item name, add that item to the validation set.
|
1809 |
+
- If `s` does not equal any item name, add all items whose names start with `s` to the validation set.
|
1810 |
+
|
1811 |
+
For multi-speaker combined datasets, "ds_id:name_prefix" can be used to apply the rules above within data from a specific sub-dataset, where ds_id represents the dataset index.
|
1812 |
+
|
1813 |
+
<table><tbody>
|
1814 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1815 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
1816 |
+
<tr><td align="center"><b>customizability</b></td><td>required</td>
|
1817 |
+
<tr><td align="center"><b>type</b></td><td>list</td>
|
1818 |
+
</tbody></table>
|
1819 |
+
|
1820 |
+
### time_scale_factor
|
1821 |
+
|
1822 |
+
The scale factor by which the time $t$ of Rectified Flow is multiplied before it is embedded into the model.
|
1823 |
+
|
1824 |
+
<table><tbody>
|
1825 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1826 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
1827 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1828 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
1829 |
+
<tr><td align="center"><b>default</b></td><td>1000</td>
|
1830 |
+
</tbody></table>
|
1831 |
+
|
1832 |
+
### timesteps
|
1833 |
+
|
1834 |
+
Total number of DDPM steps.
|
1835 |
+
|
1836 |
+
<table><tbody>
|
1837 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1838 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
1839 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1840 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1841 |
+
<tr><td align="center"><b>default</b></td><td>1000</td>
|
1842 |
+
</tbody></table>
|
1843 |
+
|
1844 |
+
### use_breathiness_embed
|
1845 |
+
|
1846 |
+
Whether to accept and embed breathiness values into the model.
|
1847 |
+
|
1848 |
+
<table><tbody>
|
1849 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1850 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
|
1851 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1852 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1853 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1854 |
+
</tbody></table>
|
1855 |
+
|
1856 |
+
### use_energy_embed
|
1857 |
+
|
1858 |
+
Whether to accept and embed energy values into the model.
|
1859 |
+
|
1860 |
+
<table><tbody>
|
1861 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1862 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
|
1863 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1864 |
+
<tr><td align="center"><b>type</b></td><td>boolean</td>
|
1865 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1866 |
+
</tbody></table>
|
1867 |
+
|
1868 |
+
### use_glide_embed
|
1869 |
+
|
1870 |
+
Whether to accept and embed glide types in melody encoder.
|
1871 |
+
|
1872 |
+
<table><tbody>
|
1873 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1874 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
|
1875 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1876 |
+
<tr><td align="center"><b>type</b></td><td>boolean</td>
|
1877 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1878 |
+
<tr><td align="center"><b>constraints</b></td><td>Only takes effect when the melody encoder is enabled.</td>
|
1879 |
+
</tbody></table>
|
1880 |
+
|
1881 |
+
### use_key_shift_embed
|
1882 |
+
|
1883 |
+
Whether to embed key shifting values introduced by random pitch shifting augmentation.
|
1884 |
+
|
1885 |
+
<table><tbody>
|
1886 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1887 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
|
1888 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1889 |
+
<tr><td align="center"><b>type</b></td><td>boolean</td>
|
1890 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1891 |
+
<tr><td align="center"><b>constraints</b></td><td>Must be true if random pitch shifting is enabled.</td>
|
1892 |
+
</tbody></table>
|
1893 |
+
|
1894 |
+
### use_melody_encoder
|
1895 |
+
|
1896 |
+
Whether to enable melody encoder for the pitch predictor.
|
1897 |
+
|
1898 |
+
<table><tbody>
|
1899 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
1900 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
1901 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1902 |
+
<tr><td align="center"><b>type</b></td><td>boolean</td>
|
1903 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1904 |
+
</tbody></table>
|
1905 |
+
|
1906 |
+
### use_pos_embed
|
1907 |
+
|
1908 |
+
Whether to use SinusoidalPositionalEmbedding in FastSpeech2 encoder.
|
1909 |
+
|
1910 |
+
<table><tbody>
|
1911 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1912 |
+
<tr><td align="center"><b>scope</b></td><td>nn</td>
|
1913 |
+
<tr><td align="center"><b>customizability</b></td><td>not recommended</td>
|
1914 |
+
<tr><td align="center"><b>type</b></td><td>boolean</td>
|
1915 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
1916 |
+
</tbody></table>
|
1917 |
+
|
1918 |
+
### use_shallow_diffusion
|
1919 |
+
|
1920 |
+
Whether to use shallow diffusion.
|
1921 |
+
|
1922 |
+
<table><tbody>
|
1923 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1924 |
+
<tr><td align="center"><b>scope</b></td><td>nn, inference</td>
|
1925 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1926 |
+
<tr><td align="center"><b>type</b></td><td>boolean</td>
|
1927 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1928 |
+
</tbody></table>
|
1929 |
+
|
1930 |
+
### use_speed_embed
|
1931 |
+
|
1932 |
+
Whether to embed speed values introduced by random time stretching augmentation.
|
1933 |
+
|
1934 |
+
<table><tbody>
|
1935 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1936 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
|
1937 |
+
<tr><td align="center"><b>type</b></td><td>boolean</td>
|
1938 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1939 |
+
<tr><td align="center"><b>constraints</b></td><td>Must be true if random time stretching is enabled.</td>
|
1940 |
+
</tbody></table>
|
1941 |
+
|
1942 |
+
### use_spk_id
|
1943 |
+
|
1944 |
+
Whether to embed the speaker id from a multi-speaker dataset.
|
1945 |
+
|
1946 |
+
<table><tbody>
|
1947 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
1948 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
|
1949 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1950 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1951 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1952 |
+
</tbody></table>
|
1953 |
+
|
1954 |
+
### use_tension_embed
|
1955 |
+
|
1956 |
+
Whether to accept and embed tension values into the model.
|
1957 |
+
|
1958 |
+
<table><tbody>
|
1959 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1960 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
|
1961 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1962 |
+
<tr><td align="center"><b>type</b></td><td>boolean</td>
|
1963 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1964 |
+
</tbody></table>
|
1965 |
+
|
1966 |
+
### use_voicing_embed
|
1967 |
+
|
1968 |
+
Whether to accept and embed voicing values into the model.
|
1969 |
+
|
1970 |
+
<table><tbody>
|
1971 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1972 |
+
<tr><td align="center"><b>scope</b></td><td>nn, preprocessing, inference</td>
|
1973 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1974 |
+
<tr><td align="center"><b>type</b></td><td>boolean</td>
|
1975 |
+
<tr><td align="center"><b>default</b></td><td>false</td>
|
1976 |
+
</tbody></table>
|
1977 |
+
|
1978 |
+
### val_check_interval
|
1979 |
+
|
1980 |
+
Interval (in number of training steps) between validation checks.
|
1981 |
+
|
1982 |
+
<table><tbody>
|
1983 |
+
<tr><td align="center"><b>visibility</b></td><td>all</td>
|
1984 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1985 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
1986 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
1987 |
+
<tr><td align="center"><b>default</b></td><td>2000</td>
|
1988 |
+
</tbody></table>
|
1989 |
+
|
1990 |
+
### val_with_vocoder
|
1991 |
+
|
1992 |
+
Whether to load and use the vocoder to generate audio during validation. Validation audio will not be available if this option is disabled.
|
1993 |
+
|
1994 |
+
<table><tbody>
|
1995 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
1996 |
+
<tr><td align="center"><b>scope</b></td><td>training</td>
|
1997 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
1998 |
+
<tr><td align="center"><b>type</b></td><td>bool</td>
|
1999 |
+
<tr><td align="center"><b>default</b></td><td>true</td>
|
2000 |
+
</tbody></table>
|
2001 |
+
|
2002 |
+
### variances_prediction_args
|
2003 |
+
|
2004 |
+
Arguments for prediction of variance parameters other than pitch, like energy, breathiness, etc.
|
2005 |
+
|
2006 |
+
<table><tbody>
|
2007 |
+
<tr><td align="center"><b>type</b></td><td>dict</td>
|
2008 |
+
</tbody></table>
|
2009 |
+
|
2010 |
+
### variances_prediction_args.backbone_args
|
2011 |
+
|
2012 |
+
Equivalent to [backbone_args](#backbone_args) but only for the multi-variance predictor.
|
2013 |
+
|
2014 |
+
<table><tbody>
|
2015 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
2016 |
+
</tbody></table>
|
2017 |
+
|
2018 |
+
### variances_prediction_args.backbone_type
|
2019 |
+
|
2020 |
+
Equivalent to [backbone_type](#backbone_type) but only for the multi-variance predictor model. If not set, use the root backbone type.
|
2021 |
+
|
2022 |
+
<table><tbody>
|
2023 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
2024 |
+
<tr><td align="center"><b>default</b></td><td>wavenet</td>
|
2025 |
+
</tbody></table>
|
2026 |
+
|
2027 |
+
### variances_prediction_args.total_repeat_bins
|
2028 |
+
|
2029 |
+
Total number of repeating bins in the multi-variance predictor. Repeating bins are distributed evenly to each variance parameter.
|
2030 |
+
|
2031 |
+
<table><tbody>
|
2032 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
2033 |
+
<tr><td align="center"><b>scope</b></td><td>nn, inference</td>
|
2034 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
2035 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
2036 |
+
<tr><td align="center"><b>default</b></td><td>48</td>
|
2037 |
+
</tbody></table>
|
2038 |
+
|
2039 |
+
### vocoder
|
2040 |
+
|
2041 |
+
The vocoder class name.
|
2042 |
+
|
2043 |
+
<table><tbody>
|
2044 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
2045 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing, training, inference</td>
|
2046 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
2047 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
2048 |
+
<tr><td align="center"><b>default</b></td><td>NsfHifiGAN</td>
|
2049 |
+
</tbody></table>
|
2050 |
+
|
2051 |
+
### vocoder_ckpt
|
2052 |
+
|
2053 |
+
Path of the vocoder model.
|
2054 |
+
|
2055 |
+
<table><tbody>
|
2056 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic</td>
|
2057 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing, training, inference</td>
|
2058 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
2059 |
+
<tr><td align="center"><b>type</b></td><td>str</td>
|
2060 |
+
<tr><td align="center"><b>default</b></td><td>checkpoints/nsf_hifigan/model</td>
|
2061 |
+
</tbody></table>
|
2062 |
+
|
2063 |
+
### voicing_db_max
|
2064 |
+
|
2065 |
+
Maximum voicing value in dB used for normalization to [-1, 1].
|
2066 |
+
|
2067 |
+
<table><tbody>
|
2068 |
+
<tr><td align="center"><b>visibility</b></td><td>variance</td>
|
2069 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
2070 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
2071 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
2072 |
+
<tr><td align="center"><b>default</b></td><td>-20.0</td>
|
2073 |
+
</tbody></table>
|
2074 |
+
|
2075 |
+
### voicing_db_min
|
2076 |
+
|
2077 |
+
Minimum voicing value in dB used for normalization to [-1, 1].
|
2078 |
+
|
2079 |
+
<table><tbody>
|
2080 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
2081 |
+
<tr><td align="center"><b>scope</b></td><td>inference</td>
|
2082 |
+
<tr><td align="center"><b>customizability</b></td><td>recommended</td>
|
2083 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
2084 |
+
<tr><td align="center"><b>default</b></td><td>-96.0</td>
|
2085 |
+
</tbody></table>
|
2086 |
+
|
2087 |
+
### voicing_smooth_width
|
2088 |
+
|
2089 |
+
Length of sinusoidal smoothing convolution kernel (in seconds) on extracted voicing curve.
|
2090 |
+
|
2091 |
+
<table><tbody>
|
2092 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
2093 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
2094 |
+
<tr><td align="center"><b>customizability</b></td><td>normal</td>
|
2095 |
+
<tr><td align="center"><b>type</b></td><td>float</td>
|
2096 |
+
<tr><td align="center"><b>default</b></td><td>0.12</td>
|
2097 |
+
</tbody></table>
|
2098 |
+
|
2099 |
+
### win_size
|
2100 |
+
|
2101 |
+
Window size for mel or feature extraction.
|
2102 |
+
|
2103 |
+
<table><tbody>
|
2104 |
+
<tr><td align="center"><b>visibility</b></td><td>acoustic, variance</td>
|
2105 |
+
<tr><td align="center"><b>scope</b></td><td>preprocessing</td>
|
2106 |
+
<tr><td align="center"><b>customizability</b></td><td>reserved</td>
|
2107 |
+
<tr><td align="center"><b>type</b></td><td>int</td>
|
2108 |
+
<tr><td align="center"><b>default</b></td><td>2048</td>
|
2109 |
+
</tbody></table>
|
docs/GettingStarted.md
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Getting Started
|
2 |
+
|
3 |
+
## Installation
|
4 |
+
|
5 |
+
### Environments and dependencies
|
6 |
+
|
7 |
+
DiffSinger requires Python 3.8 or later. We strongly recommend you create a virtual environment via Conda or venv before installing dependencies.
|
8 |
+
|
9 |
+
1. Install the latest PyTorch following the [official instructions](https://pytorch.org/get-started/locally/) according to your OS and hardware.
|
10 |
+
|
11 |
+
2. Install other dependencies via the following command:
|
12 |
+
|
13 |
+
```bash
|
14 |
+
pip install -r requirements.txt
|
15 |
+
```
|
16 |
+
|
17 |
+
### Materials and assets
|
18 |
+
|
19 |
+
Some essential materials and assets are needed before continuing with this repository. See [materials for training and using models](BestPractices.md#materials-for-training-and-using-models) for detailed instructions.
|
20 |
+
|
21 |
+
## Configuration
|
22 |
+
|
23 |
+
Every model needs a configuration file to run preprocessing, training, inference and deployment. Templates of configuration files are in [configs/templates](../configs/templates). Please **copy** the templates to your own data directory before you edit them.
|
24 |
+
|
25 |
+
Before you continue, it is highly recommended to read through [Best Practices](BestPractices.md), which is a more detailed tutorial on how to configure your experiments.
|
26 |
+
|
27 |
+
For more details about configurable parameters, see [Configuration Schemas](ConfigurationSchemas.md).
|
28 |
+
|
29 |
+
> Tips: to see which parameters are required or recommended to be edited, you can search by _customizability_ in the configuration schemas.
|
30 |
+
|
31 |
+
## Preprocessing
|
32 |
+
|
33 |
+
Raw data pieces and transcriptions should be binarized into dataset files before training. Before doing this step, please ensure all required configurations like `raw_data_dir` and `binary_data_dir` are set properly, and all your desired functionalities and features are enabled and configured.
|
34 |
+
|
35 |
+
Assume that you have a configuration file called `my_config.yaml`. Run:
|
36 |
+
|
37 |
+
```bash
|
38 |
+
python scripts/binarize.py --config my_config.yaml
|
39 |
+
```
|
40 |
+
|
41 |
+
Preprocessing can be accelerated through multiprocessing. See [binarization_args.num_workers](ConfigurationSchemas.md#binarization_args.num_workers) for more explanations.
|
42 |
+
|
43 |
+
## Training
|
44 |
+
|
45 |
+
Assume that you have a configuration file called `my_config.yaml` and the name of your model is `my_experiment`. Run:
|
46 |
+
|
47 |
+
```bash
|
48 |
+
python scripts/train.py --config my_config.yaml --exp_name my_experiment --reset
|
49 |
+
```
|
50 |
+
|
51 |
+
Checkpoints will be saved in the `checkpoints/my_experiment/` directory. If you interrupt the program and run the above command again, training resumes automatically from the latest checkpoint.
|
52 |
+
|
53 |
+
For more suggestions related to training performance, see [performance tuning](BestPractices.md#performance-tuning).
|
54 |
+
|
55 |
+
### TensorBoard
|
56 |
+
|
57 |
+
Run the following command to start the TensorBoard:
|
58 |
+
|
59 |
+
```bash
|
60 |
+
tensorboard --logdir checkpoints/
|
61 |
+
```
|
62 |
+
|
63 |
+
> NOTICE
|
64 |
+
>
|
65 |
+
> If you are training a model with multiple GPUs (DDP), please add `--reload_multifile=true` option when launching TensorBoard, otherwise it may not update properly.
|
66 |
+
|
67 |
+
## Inference
|
68 |
+
|
69 |
+
Inference of DiffSinger is based on DS files. Assume that you have a DS file named `my_song.ds` and your model is named `my_experiment`.
|
70 |
+
|
71 |
+
If your model is a variance model, run:
|
72 |
+
|
73 |
+
```bash
|
74 |
+
python scripts/infer.py variance my_song.ds --exp my_experiment
|
75 |
+
```
|
76 |
+
|
77 |
+
or run
|
78 |
+
|
79 |
+
```bash
|
80 |
+
python scripts/infer.py variance --help
|
81 |
+
```
|
82 |
+
|
83 |
+
for more configurable options.
|
84 |
+
|
85 |
+
If your model is an acoustic model, run:
|
86 |
+
|
87 |
+
```bash
|
88 |
+
python scripts/infer.py acoustic my_song.ds --exp my_experiment
|
89 |
+
```
|
90 |
+
|
91 |
+
or run
|
92 |
+
|
93 |
+
```bash
|
94 |
+
python scripts/infer.py acoustic --help
|
95 |
+
```
|
96 |
+
|
97 |
+
for more configurable options.
|
98 |
+
|
99 |
+
## Deployment
|
100 |
+
|
101 |
+
DiffSinger uses [ONNX](https://onnx.ai/) as the deployment format.
|
102 |
+
|
103 |
+
Due to TorchScript issues, exporting to ONNX now requires PyTorch **1.13**. Please ensure the correct dependencies through the following steps:
|
104 |
+
|
105 |
+
1. Create a new separate environment for exporting ONNX.
|
106 |
+
|
107 |
+
2. Install PyTorch 1.13 following the [official instructions](https://pytorch.org/get-started/previous-versions/). A CPU-only version is enough.
|
108 |
+
|
109 |
+
3. Install other dependencies via the following command:
|
110 |
+
|
111 |
+
```bash
|
112 |
+
pip install -r requirements-onnx.txt
|
113 |
+
```
|
114 |
+
|
115 |
+
Assume that you have a model named `my_experiment`.
|
116 |
+
|
117 |
+
If your model is a variance model, run:
|
118 |
+
|
119 |
+
```bash
|
120 |
+
python scripts/export.py variance --exp my_experiment
|
121 |
+
```
|
122 |
+
|
123 |
+
or run
|
124 |
+
|
125 |
+
```bash
|
126 |
+
python scripts/export.py variance --help
|
127 |
+
```
|
128 |
+
|
129 |
+
for more configurable options.
|
130 |
+
|
131 |
+
If your model is an acoustic model, run:
|
132 |
+
|
133 |
+
```bash
|
134 |
+
python scripts/export.py acoustic --exp my_experiment
|
135 |
+
```
|
136 |
+
|
137 |
+
or run
|
138 |
+
|
139 |
+
```bash
|
140 |
+
python scripts/export.py acoustic --help
|
141 |
+
```
|
142 |
+
|
143 |
+
for more configurable options.
|
144 |
+
|
145 |
+
To export an NSF-HiFiGAN vocoder checkpoint, run:
|
146 |
+
|
147 |
+
```bash
|
148 |
+
python scripts/export.py nsf-hifigan --config CONFIG --ckpt CKPT
|
149 |
+
```
|
150 |
+
|
151 |
+
where `CONFIG` is a configuration file that has configured the same mel parameters as the vocoder (can be configs/acoustic.yaml for most cases) and `CKPT` is the path of the checkpoint to be exported.
|
152 |
+
|
153 |
+
For more configurable options, run
|
154 |
+
|
155 |
+
```bash
|
156 |
+
python scripts/export.py nsf-hifigan --help
|
157 |
+
```
|
158 |
+
|
159 |
+
## Other utilities
|
160 |
+
|
161 |
+
There are other useful CLI tools in the [scripts/](../scripts) directory not mentioned above:
|
162 |
+
|
163 |
+
- drop_spk.py - delete speaker embeddings from checkpoints (for data security reasons when distributing models)
|
164 |
+
- vocoder.py - bypass the acoustic model and only run the vocoder on given mel-spectrograms
|
docs/resources/arch-acoustic.drawio
ADDED
The diff for this file is too large to render.
See raw diff
|
|
docs/resources/arch-overview.drawio
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<mxfile host="Electron" modified="2023-07-06T16:23:00.268Z" agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/21.5.1 Chrome/112.0.5615.204 Electron/24.6.0 Safari/537.36" etag="Jj5RvJnN6RNPHxq1C0K-" version="21.5.1" type="device">
|
2 |
+
<diagram name="第 1 页" id="YZcYkKBV4jQxwUdL2UfG">
|
3 |
+
<mxGraphModel dx="1270" dy="914" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
|
4 |
+
<root>
|
5 |
+
<mxCell id="0" />
|
6 |
+
<mxCell id="1" parent="0" />
|
7 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-8" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.25;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-1" target="FjPlvFc2R7vJgaK0KZA3-1" edge="1">
|
8 |
+
<mxGeometry relative="1" as="geometry" />
|
9 |
+
</mxCell>
|
10 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-9" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-1" target="FjPlvFc2R7vJgaK0KZA3-3" edge="1">
|
11 |
+
<mxGeometry relative="1" as="geometry" />
|
12 |
+
</mxCell>
|
13 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-10" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.75;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-1" target="FjPlvFc2R7vJgaK0KZA3-6" edge="1">
|
14 |
+
<mxGeometry relative="1" as="geometry" />
|
15 |
+
</mxCell>
|
16 |
+
<mxCell id="gMzeuS7BrVqeg4yIgOCW-1" value="<font face="Times New Roman" style="font-size: 20px;">Variance Model</font>" style="rounded=1;whiteSpace=wrap;html=1;strokeWidth=1.5;fillColor=#dae8fc;strokeColor=#6c8ebf;" parent="1" vertex="1">
|
17 |
+
<mxGeometry x="232.97" y="720" width="300" height="50" as="geometry" />
|
18 |
+
</mxCell>
|
19 |
+
<mxCell id="gMzeuS7BrVqeg4yIgOCW-3" value="<font face="Times New Roman" style="font-size: 16px;">Phoneme</font>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
|
20 |
+
<mxGeometry x="267.97" y="800" width="80" height="30" as="geometry" />
|
21 |
+
</mxCell>
|
22 |
+
<mxCell id="gMzeuS7BrVqeg4yIgOCW-4" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.25;entryY=1;entryDx=0;entryDy=0;strokeWidth=1;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-3" target="gMzeuS7BrVqeg4yIgOCW-1" edge="1">
|
23 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
24 |
+
<mxPoint x="562.97" y="620" as="sourcePoint" />
|
25 |
+
<mxPoint x="612.97" y="570" as="targetPoint" />
|
26 |
+
</mxGeometry>
|
27 |
+
</mxCell>
|
28 |
+
<mxCell id="gMzeuS7BrVqeg4yIgOCW-6" value="<font face="Times New Roman" style="font-size: 16px;">Word</font>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
|
29 |
+
<mxGeometry x="342.97" y="800" width="80" height="30" as="geometry" />
|
30 |
+
</mxCell>
|
31 |
+
<mxCell id="gMzeuS7BrVqeg4yIgOCW-7" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeWidth=1;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-6" target="gMzeuS7BrVqeg4yIgOCW-1" edge="1">
|
32 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
33 |
+
<mxPoint x="657.97" y="635" as="sourcePoint" />
|
34 |
+
<mxPoint x="402.97" y="785" as="targetPoint" />
|
35 |
+
</mxGeometry>
|
36 |
+
</mxCell>
|
37 |
+
<mxCell id="gMzeuS7BrVqeg4yIgOCW-8" value="<font face="Times New Roman" style="font-size: 16px;">MIDI</font>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
|
38 |
+
<mxGeometry x="417.97" y="800" width="80" height="30" as="geometry" />
|
39 |
+
</mxCell>
|
40 |
+
<mxCell id="gMzeuS7BrVqeg4yIgOCW-9" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.75;entryY=1;entryDx=0;entryDy=0;strokeWidth=1;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-8" target="gMzeuS7BrVqeg4yIgOCW-1" edge="1">
|
41 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
42 |
+
<mxPoint x="717.97" y="635" as="sourcePoint" />
|
43 |
+
<mxPoint x="462.97" y="785" as="targetPoint" />
|
44 |
+
</mxGeometry>
|
45 |
+
</mxCell>
|
46 |
+
<mxCell id="gMzeuS7BrVqeg4yIgOCW-10" value="<font face="Times New Roman" style="font-size: 20px;">Acoustic Model</font>" style="rounded=1;whiteSpace=wrap;html=1;strokeWidth=1.5;fillColor=#dae8fc;strokeColor=#6c8ebf;" parent="1" vertex="1">
|
47 |
+
<mxGeometry x="232.97" y="580" width="300" height="50" as="geometry" />
|
48 |
+
</mxCell>
|
49 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-1" value="<font face="Times New Roman" style="font-size: 16px;">Duration</font>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
|
50 |
+
<mxGeometry x="268.97" y="660" width="78" height="30" as="geometry" />
|
51 |
+
</mxCell>
|
52 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-2" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;strokeWidth=1;entryX=0.25;entryY=1;entryDx=0;entryDy=0;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-1" target="gMzeuS7BrVqeg4yIgOCW-10" edge="1">
|
53 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
54 |
+
<mxPoint x="442.97" y="670" as="sourcePoint" />
|
55 |
+
<mxPoint x="312.97" y="630" as="targetPoint" />
|
56 |
+
</mxGeometry>
|
57 |
+
</mxCell>
|
58 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-3" value="<font face="Times New Roman" style="font-size: 16px;">Pitch</font>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
|
59 |
+
<mxGeometry x="343.97" y="660" width="78" height="30" as="geometry" />
|
60 |
+
</mxCell>
|
61 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-4" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;strokeWidth=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-3" target="gMzeuS7BrVqeg4yIgOCW-10" edge="1">
|
62 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
63 |
+
<mxPoint x="526.97" y="680" as="sourcePoint" />
|
64 |
+
<mxPoint x="391.97" y="630" as="targetPoint" />
|
65 |
+
</mxGeometry>
|
66 |
+
</mxCell>
|
67 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-5" value="" style="endArrow=classic;html=1;rounded=0;exitX=0;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.75;entryDx=0;entryDy=0;edgeStyle=orthogonalEdgeStyle;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-3" target="gMzeuS7BrVqeg4yIgOCW-10" edge="1">
|
68 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
69 |
+
<mxPoint x="432.97" y="710" as="sourcePoint" />
|
70 |
+
<mxPoint x="482.97" y="660" as="targetPoint" />
|
71 |
+
<Array as="points">
|
72 |
+
<mxPoint x="213" y="815" />
|
73 |
+
<mxPoint x="213" y="618" />
|
74 |
+
<mxPoint x="233" y="618" />
|
75 |
+
</Array>
|
76 |
+
</mxGeometry>
|
77 |
+
</mxCell>
|
78 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-6" value="<font face="Times New Roman" style="">Variance Parameters<br><font style="font-size: 10px;">(energy, breathiness, etc.)</font><br></font>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
|
79 |
+
<mxGeometry x="392.97" y="660" width="130" height="30" as="geometry" />
|
80 |
+
</mxCell>
|
81 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-7" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;strokeWidth=1;entryX=0.75;entryY=1;entryDx=0;entryDy=0;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-6" target="gMzeuS7BrVqeg4yIgOCW-10" edge="1">
|
82 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
83 |
+
<mxPoint x="625.97" y="700" as="sourcePoint" />
|
84 |
+
<mxPoint x="481.97" y="640" as="targetPoint" />
|
85 |
+
</mxGeometry>
|
86 |
+
</mxCell>
|
87 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-11" value="<font face="Times New Roman" style="">Transformation Parameters<br><font style="font-size: 10px;">(gender &amp; velocity)</font><br></font>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
|
88 |
+
<mxGeometry x="521.94" y="650" width="92.06" height="50" as="geometry" />
|
89 |
+
</mxCell>
|
90 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-13" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=1;entryY=0.75;entryDx=0;entryDy=0;edgeStyle=orthogonalEdgeStyle;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-11" target="gMzeuS7BrVqeg4yIgOCW-10" edge="1">
|
91 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
92 |
+
<mxPoint x="412.97" y="690" as="sourcePoint" />
|
93 |
+
<mxPoint x="462.97" y="640" as="targetPoint" />
|
94 |
+
</mxGeometry>
|
95 |
+
</mxCell>
|
96 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-22" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-18" target="FjPlvFc2R7vJgaK0KZA3-21" edge="1">
|
97 |
+
<mxGeometry relative="1" as="geometry" />
|
98 |
+
</mxCell>
|
99 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-18" value="<font face="Times New Roman" style="font-size: 16px;">Mel-spectrogram<br></font>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
|
100 |
+
<mxGeometry x="315.97" y="520" width="134" height="30" as="geometry" />
|
101 |
+
</mxCell>
|
102 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-19" value="" style="endArrow=classic;html=1;rounded=0;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="gMzeuS7BrVqeg4yIgOCW-10" target="FjPlvFc2R7vJgaK0KZA3-18" edge="1">
|
103 |
+
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
104 |
+
<mxPoint x="532.97" y="570" as="sourcePoint" />
|
105 |
+
<mxPoint x="582.97" y="520" as="targetPoint" />
|
106 |
+
</mxGeometry>
|
107 |
+
</mxCell>
|
108 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-33" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" parent="1" source="FjPlvFc2R7vJgaK0KZA3-21" target="FjPlvFc2R7vJgaK0KZA3-32" edge="1">
|
109 |
+
<mxGeometry relative="1" as="geometry" />
|
110 |
+
</mxCell>
|
111 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-21" value="<font face="Times New Roman" style="font-size: 20px;">Vocoder</font>" style="rounded=1;whiteSpace=wrap;html=1;strokeWidth=1.5;fillColor=#e1d5e7;strokeColor=#9673a6;" parent="1" vertex="1">
|
112 |
+
<mxGeometry x="292.94" y="440.13" width="180.06" height="50" as="geometry" />
|
113 |
+
</mxCell>
|
114 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-32" value="" style="shape=image;verticalLabelPosition=bottom;labelBackgroundColor=default;verticalAlign=top;aspect=fixed;imageAspect=0;image=data:image/png,;" parent="1" vertex="1">
|
115 |
+
<mxGeometry x="236.79" y="370" width="292.36" height="42.48" as="geometry" />
|
116 |
+
</mxCell>
|
117 |
+
<mxCell id="FjPlvFc2R7vJgaK0KZA3-34" value="<font face="Times New Roman" style="font-size: 16px;">Waveform<br></font>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
|
118 |
+
<mxGeometry x="315.97" y="340" width="134" height="30" as="geometry" />
|
119 |
+
</mxCell>
|
120 |
+
</root>
|
121 |
+
</mxGraphModel>
|
122 |
+
</diagram>
|
123 |
+
</mxfile>
|
docs/resources/arch-overview.jpg
ADDED
![]() |
docs/resources/arch-variance.drawio
ADDED
The diff for this file is too large to render.
See raw diff
|
|
inference/dpm_solver_pytorch.py
ADDED
@@ -0,0 +1,1305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
class NoiseScheduleVP:
    """Variance-preserving (VP) noise schedule exposing alpha_t, sigma_t and lambda_t."""

    def __init__(
            self,
            schedule='discrete',
            betas=None,
            alphas_cumprod=None,
            continuous_beta_0=0.1,
            continuous_beta_1=20.,
            dtype=torch.float32,
    ):
        r"""Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
        We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
        ***

        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).

        1. For discrete-time DPMs:

            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.

            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)

            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
            If both are given, `betas` takes precedence.

            **Important**: Please pay special attention to the argument `alphas_cumprod`:
                The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).

        2. For continuous-time DPMs:

            We support the linear VPSDE for the continuous time setting. The default hyperparameters for the noise
            schedule are the default settings in Yang Song's ScoreSDE:

            Args:
                continuous_beta_0: A `float` number. The smallest beta for the linear schedule.
                continuous_beta_1: A `float` number. The largest beta for the linear schedule.

            The ending time of the forward process is fixed to T = 1.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' for continuous-time DPMs.
            dtype: The torch dtype the discrete-time lookup arrays (`log_alpha_array`, `t_array`) are cast to.
        Returns:
            A wrapper object of the forward SDE (VP type).
        Raises:
            ValueError: If `schedule` is unsupported, or if `schedule == 'discrete'` and neither
                `betas` nor `alphas_cumprod` is given.

        ===============================================================

        Example:

        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)

        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)

        """

        if schedule not in ['discrete', 'linear']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear'".format(schedule))

        self.schedule = schedule
        # Both schedule kinds use the normalized time interval ending at T = 1.
        self.T = 1.
        if schedule == 'discrete':
            if betas is not None:
                # log(alpha_{t_n}) = 0.5 * log(prod_k (1 - beta_k)) = 0.5 * sum_k log(1 - beta_k)
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            elif alphas_cumprod is not None:
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            else:
                # Explicit error instead of `assert`, which is stripped under `python -O`.
                raise ValueError("For the 'discrete' schedule, one of `betas` or `alphas_cumprod` must be provided.")
            # Stored as (1, N) rows; the clipping step may shorten the schedule.
            self.log_alpha_array = self.numerical_clip_alpha(log_alphas).reshape((1, -1,)).to(dtype=dtype)
            self.total_N = self.log_alpha_array.shape[1]
            # t_i = (i + 1) / N for i = 0, ..., N - 1 (the leading 0 of the linspace is dropped).
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1

    def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1):
        """
        For some beta schedules such as the cosine schedule, the log-SNR has numerical issues.
        We drop the trailing entries (near t=T) whose half-logSNR is below `clipped_lambda` (-5.1 by default)
        to ensure stability.
        Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE.

        Args:
            log_alphas: A 1-D pytorch tensor of log(alpha_{t_n}) values.
            clipped_lambda: A `float`. The smallest half-logSNR that is kept.
        Returns:
            `log_alphas` with the clipped tail removed (unchanged when nothing falls below the threshold).
        """
        log_sigmas = 0.5 * torch.log(1. - torch.exp(2. * log_alphas))
        lambs = log_alphas - log_sigmas
        # The flipped lambdas are searched as an ascending sequence, so the insertion
        # point counts how many trailing entries lie strictly below the threshold.
        idx = int(torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda))
        if idx > 0:
            log_alphas = log_alphas[:-idx]
        return log_alphas

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'discrete':
            # `interpolate_fn` is defined elsewhere in this file; per the constructor
            # notes it is presumably a piecewise linear interpolation over (t_array, log_alpha_array).
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t = sqrt(1 - alpha_t^2) of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            # Closed-form root of the quadratic in t implied by the linear schedule.
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            # log(alpha) = -0.5 * log(1 + exp(-2 * lambda)), then look t up in the flipped
            # (log_alpha -> t) table via the file-level `interpolate_fn` helper.
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
|
168 |
+
|
169 |
+
|
170 |
+
def model_wrapper(
        model,
        noise_schedule,
        model_type="noise",
        model_kwargs=None,
        guidance_type="uncond",
        condition=None,
        unconditional_condition=None,
        guidance_scale=1.,
        classifier_fn=None,
        classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.

    We support four types of the diffusion model by setting `model_type`:

        1. "noise": noise prediction model. (Trained by predicting noise).

        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).

        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The derivation of the "v" prediction is detailed in Appendix D of [1], and it is used in Imagen-Video [2].

            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).

        4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follow a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```

    We support three types of guided sampling by DPMs by setting `guidance_type`:
        1. "uncond": unconditional sampling by DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

            The input `classifier_fn` has the following format:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``

            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
            ``
            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.

            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).


    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to N - 1)
    or continuous-time labels (i.e. epsilon to T).

    We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
    ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict` or None. A dict for the other inputs of the model function
                    (None means no extra inputs).
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
                    Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
                    Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict` or None. A dict for the other inputs of the classifier function
                    (None means no extra inputs).
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    Raises:
        ValueError: If `model_type` or `guidance_type` is not one of the supported values, or (when the
            returned function is called) if "classifier" guidance is requested without a `classifier_fn`.
    """
    # Validate up front with real exceptions; `assert` would vanish under `python -O`
    # and an unknown type would then make the wrapped model silently return None.
    if model_type not in ("noise", "x_start", "v", "score"):
        raise ValueError(
            "Unsupported model_type {}, need to be 'noise' or 'x_start' or 'v' or 'score'".format(model_type))
    if guidance_type not in ("uncond", "classifier", "classifier-free"):
        raise ValueError(
            "Unsupported guidance_type {}, need to be 'uncond' or 'classifier' or 'classifier-free'".format(guidance_type))

    # None replaces the former mutable `{}` defaults; a dict passed by the caller is used as-is.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, N - 1],
        where N is `noise_schedule.total_N`.
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * noise_schedule.total_N
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """
        Evaluate `model` and convert its output to a noise prediction according to `model_type`.
        """
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        # `expand_dims` is a helper defined elsewhere in this file; presumably it appends
        # singleton dims so the per-sample coefficients broadcast against `x`.
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -expand_dims(sigma_t, x.dim()) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            if classifier_fn is None:
                raise ValueError("guidance_type 'classifier' requires `classifier_fn` to be provided.")
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Evaluate the unconditional and conditional branches in one batched call.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
|
335 |
+
|
336 |
+
|
337 |
+
class DPM_Solver:
|
338 |
+
def __init__(
|
339 |
+
self,
|
340 |
+
model_fn,
|
341 |
+
noise_schedule,
|
342 |
+
algorithm_type="dpmsolver++",
|
343 |
+
correcting_x0_fn=None,
|
344 |
+
correcting_xt_fn=None,
|
345 |
+
thresholding_max_val=1.,
|
346 |
+
dynamic_thresholding_ratio=0.995,
|
347 |
+
):
|
348 |
+
"""Construct a DPM-Solver.
|
349 |
+
|
350 |
+
We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
|
351 |
+
|
352 |
+
We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
|
353 |
+
can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
|
354 |
+
dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
|
355 |
+
DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
|
356 |
+
DPMs (such as stable-diffusion).
|
357 |
+
|
358 |
+
To support advanced algorithms in image-to-image applications, we also support corrector functions for
|
359 |
+
both x0 and xt.
|
360 |
+
|
361 |
+
Args:
|
362 |
+
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
|
363 |
+
``
|
364 |
+
def model_fn(x, t_continuous):
|
365 |
+
return noise
|
366 |
+
``
|
367 |
+
The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
|
368 |
+
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
369 |
+
algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
|
370 |
+
correcting_x0_fn: A `str` or a function with the following format:
|
371 |
+
```
|
372 |
+
def correcting_x0_fn(x0, t):
|
373 |
+
x0_new = ...
|
374 |
+
return x0_new
|
375 |
+
```
|
376 |
+
This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
|
377 |
+
```
|
378 |
+
x0_pred = data_pred_model(xt, t)
|
379 |
+
if correcting_x0_fn is not None:
|
380 |
+
x0_pred = correcting_x0_fn(x0_pred, t)
|
381 |
+
xt_1 = update(x0_pred, xt, t)
|
382 |
+
```
|
383 |
+
If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
|
384 |
+
correcting_xt_fn: A function with the following format:
|
385 |
+
```
|
386 |
+
def correcting_xt_fn(xt, t, step):
|
387 |
+
x_new = ...
|
388 |
+
return x_new
|
389 |
+
```
|
390 |
+
This function is to correct the intermediate samples xt at each sampling step. e.g.,
|
391 |
+
```
|
392 |
+
xt = ...
|
393 |
+
xt = correcting_xt_fn(xt, t, step)
|
394 |
+
```
|
395 |
+
thresholding_max_val: A `float`. The max value for thresholding.
|
396 |
+
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
|
397 |
+
dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
|
398 |
+
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
|
399 |
+
|
400 |
+
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
|
401 |
+
Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
|
402 |
+
with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
|
403 |
+
"""
|
404 |
+
self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
|
405 |
+
self.noise_schedule = noise_schedule
|
406 |
+
assert algorithm_type in ["dpmsolver", "dpmsolver++"]
|
407 |
+
self.algorithm_type = algorithm_type
|
408 |
+
if correcting_x0_fn == "dynamic_thresholding":
|
409 |
+
self.correcting_x0_fn = self.dynamic_thresholding_fn
|
410 |
+
else:
|
411 |
+
self.correcting_x0_fn = correcting_x0_fn
|
412 |
+
self.correcting_xt_fn = correcting_xt_fn
|
413 |
+
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
|
414 |
+
self.thresholding_max_val = thresholding_max_val
|
415 |
+
|
416 |
+
def dynamic_thresholding_fn(self, x0, t):
    """
    Apply the dynamic thresholding method to a data prediction `x0`.

    For every sample, the `dynamic_thresholding_ratio` quantile of |x0| is taken, raised to at
    least `thresholding_max_val`, and `x0` is clamped to that bound and divided by it.
    The time argument `t` is accepted for the corrector interface but is not used here.
    """
    ratio = self.dynamic_thresholding_ratio
    # One row of absolute values per batch element.
    flat_abs = torch.abs(x0).reshape((x0.shape[0], -1))
    per_sample = torch.quantile(flat_abs, ratio, dim=1)
    floor = self.thresholding_max_val * torch.ones_like(per_sample).to(per_sample.device)
    # `expand_dims` is defined elsewhere in this file; presumably it appends singleton
    # dims so the per-sample bound broadcasts against `x0` -- confirm against the helper.
    bound = expand_dims(torch.maximum(per_sample, floor), x0.dim())
    return torch.clamp(x0, -bound, bound) / bound
|
426 |
+
|
427 |
+
def noise_prediction_fn(self, x, t):
    """
    Evaluate the wrapped noise prediction model at sample `x` and time `t`.
    """
    predict_noise = self.model
    return predict_noise(x, t)
|
432 |
+
|
433 |
+
def data_prediction_fn(self, x, t):
    """
    Predict the clean data x0 from the noise prediction, then apply the optional x0 corrector.

    Uses x0 = (x - sigma_t * noise) / alpha_t with alpha_t and sigma_t taken from the noise schedule.
    """
    schedule = self.noise_schedule
    eps = self.noise_prediction_fn(x, t)
    alpha_t = schedule.marginal_alpha(t)
    sigma_t = schedule.marginal_std(t)
    x0 = (x - sigma_t * eps) / alpha_t
    corrector = self.correcting_x0_fn
    if corrector is None:
        return x0
    return corrector(x0, t)
|
443 |
+
|
444 |
+
def model_fn(self, x, t):
    """
    Evaluate the model in the parameterization the configured algorithm uses:
    the data prediction for "dpmsolver++", the noise prediction otherwise.
    """
    uses_data_prediction = self.algorithm_type == "dpmsolver++"
    predictor = self.data_prediction_fn if uses_data_prediction else self.noise_prediction_fn
    return predictor(x, t)
|
452 |
+
|
453 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device):
    """Compute the intermediate time steps for sampling.

    Args:
        skip_type: A `str`. The type for the spacing of the time steps. We support three types:
            - 'logSNR': uniform logSNR for the time steps.
            - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
            - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        N: A `int`. The total number of the spacing of the time steps.
        device: A torch device.
    Returns:
        A pytorch tensor of the time steps, with the shape (N + 1,).
    Raises:
        ValueError: If `skip_type` is not one of the supported spacings.
    """
    num_points = N + 1

    if skip_type == 'time_uniform':
        return torch.linspace(t_T, t_0, num_points).to(device)

    if skip_type == 'time_quadratic':
        # Uniform grid in sqrt(t), mapped back by squaring.
        exponent = 2
        root_T = t_T**(1. / exponent)
        root_0 = t_0**(1. / exponent)
        return torch.linspace(root_T, root_0, num_points).pow(exponent).to(device)

    if skip_type == 'logSNR':
        # Uniform grid in lambda (half-logSNR), mapped back to time by the schedule.
        schedule = self.noise_schedule
        lambda_T = schedule.marginal_lambda(torch.tensor(t_T).to(device))
        lambda_0 = schedule.marginal_lambda(torch.tensor(t_0).to(device))
        lambda_grid = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), num_points).to(device)
        return schedule.inverse_lambda(lambda_grid)

    raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
481 |
+
|
482 |
+
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
    """
    Get the order of each step for sampling by the singlestep DPM-Solver.

    We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
    Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
        - If order == 1:
            We take `steps` of DPM-Solver-1 (i.e. DDIM).
        - If order == 2:
            - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
            - If steps % 2 == 0, we use K steps of DPM-Solver-2.
            - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
        - If order == 3:
            - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
            - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
            - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
            - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.

    ============================================
    Args:
        order: A `int`. The max order for the solver (1, 2 or 3).
        steps: A `int`. The total number of function evaluations (NFE).
        skip_type: A `str`. The type for the spacing of the time steps. We support three types:
            - 'logSNR': uniform logSNR for the time steps.
            - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
            - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        device: A torch device.
    Returns:
        timesteps_outer: A pytorch tensor with the boundary times of the solver steps,
            with the shape (len(orders) + 1,).
        orders: A list of the solver order of each step.
    Raises:
        ValueError: If `order` is not 1, 2 or 3.
    """
    if order == 3:
        full, rem = divmod(steps, 3)
        if rem == 0:
            orders = [3,] * (full - 1) + [2, 1]
        elif rem == 1:
            orders = [3,] * full + [1]
        else:
            orders = [3,] * full + [2]
    elif order == 2:
        full, rem = divmod(steps, 2)
        orders = [2,] * full + ([1] if rem else [])
    elif order == 1:
        orders = [1,] * steps
    else:
        raise ValueError("'order' must be '1' or '2' or '3'.")

    # One outer interval per solver step. For order 1 this was previously hard-coded
    # to 1, so the 'logSNR' spacing returned only two time points for `steps` steps.
    K = len(orders)
    if skip_type == 'logSNR':
        # To reproduce the results in DPM-Solver paper
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
    else:
        # Build the fine grid with one point per function evaluation and keep the
        # points where each solver step begins/ends (cumulative sum of the orders).
        boundaries = torch.cumsum(torch.tensor([0,] + orders), 0).to(device)
        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[boundaries]
    return timesteps_outer, orders
|
540 |
+
|
541 |
+
def denoise_to_zero_fn(self, x, s):
    """
    Final denoising step applied to `x` at time `s`.

    This is equivalent to solving the ODE from lambda_s to infinity with a
    first-order discretization, which reduces to the data prediction itself.

    Args:
        x: The current sample at time `s`.
        s: The current time.
    Returns:
        Whatever `self.data_prediction_fn(x, s)` returns.
    """
    denoised = self.data_prediction_fn(x, s)
    return denoised
|
546 |
+
|
547 |
+
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
    """
    DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s`.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
            If `return_intermediate` is true, the tuple `(x_t, {'model_s': model_s})` is returned instead.
    """
    ns = self.noise_schedule
    lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
    h = lambda_t - lambda_s
    log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
    sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)

    # The model evaluation does not depend on the algorithm type, so it is done once here
    # instead of being repeated in each branch.
    if model_s is None:
        model_s = self.model_fn(x, s)

    if self.algorithm_type == "dpmsolver++":
        # "dpmsolver++" branch: the update is scaled by alpha_t and uses expm1(-h).
        alpha_t = torch.exp(log_alpha_t)
        phi_1 = torch.expm1(-h)
        x_t = (
            sigma_t / sigma_s * x
            - alpha_t * phi_1 * model_s
        )
    else:
        # Any other algorithm type: the update is scaled by sigma_t and uses expm1(h).
        phi_1 = torch.expm1(h)
        x_t = (
            torch.exp(log_alpha_t - log_alpha_s) * x
            - (sigma_t * phi_1) * model_s
        )

    if return_intermediate:
        return x_t, {'model_s': model_s}
    return x_t
|
593 |
+
|
594 |
+
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'):
    """
    Singlestep solver DPM-Solver-2 from time `s` to time `t`.

    One intermediate time `s1` is placed at `lambda_s + r1 * h` in lambda space, the model is
    evaluated there, and the difference `model_s1 - model_s` supplies the second-order correction.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        r1: A `float`. The hyperparameter of the second-order solver. `None` falls back to 0.5.
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
            With `return_intermediate`, the tuple `(x_t, {'model_s': ..., 'model_s1': ...})`.
    Raises:
        ValueError: if `solver_type` is neither 'dpmsolver' nor 'taylor'.
    """
    if solver_type not in ['dpmsolver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
    if r1 is None:
        r1 = 0.5

    schedule = self.noise_schedule
    lambda_s = schedule.marginal_lambda(s)
    lambda_t = schedule.marginal_lambda(t)
    h = lambda_t - lambda_s
    # Intermediate time, chosen in lambda space and mapped back to a time value.
    s1 = schedule.inverse_lambda(lambda_s + r1 * h)

    log_alpha_s = schedule.marginal_log_mean_coeff(s)
    log_alpha_s1 = schedule.marginal_log_mean_coeff(s1)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_s = schedule.marginal_std(s)
    sigma_s1 = schedule.marginal_std(s1)
    sigma_t = schedule.marginal_std(t)
    alpha_s1 = torch.exp(log_alpha_s1)
    alpha_t = torch.exp(log_alpha_t)

    use_plus_plus = self.algorithm_type == "dpmsolver++"
    use_taylor = solver_type == 'taylor'

    if use_plus_plus:
        phi_11 = torch.expm1(-r1 * h)
        phi_1 = torch.expm1(-h)

        if model_s is None:
            model_s = self.model_fn(x, s)
        x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
        model_s1 = self.model_fn(x_s1, s1)

        # First-order part shared by both solver types.
        first_order = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s
        delta_model = model_s1 - model_s
        if use_taylor:
            x_t = first_order + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * delta_model
        else:
            x_t = first_order - (0.5 / r1) * (alpha_t * phi_1) * delta_model
    else:
        phi_11 = torch.expm1(r1 * h)
        phi_1 = torch.expm1(h)

        if model_s is None:
            model_s = self.model_fn(x, s)
        x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
        model_s1 = self.model_fn(x_s1, s1)

        # First-order part shared by both solver types.
        first_order = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
        delta_model = model_s1 - model_s
        if use_taylor:
            x_t = first_order - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * delta_model
        else:
            x_t = first_order - (0.5 / r1) * (sigma_t * phi_1) * delta_model

    if not return_intermediate:
        return x_t
    return x_t, {'model_s': model_s, 'model_s1': model_s1}
|
674 |
+
|
675 |
+
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'):
    """
    Singlestep solver DPM-Solver-3 from time `s` to time `t`.

    Two intermediate times `s1` and `s2` are placed at `lambda_s + r1 * h` and
    `lambda_s + r2 * h` in lambda space; the model values there provide the higher-order terms.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        r1: A `float`. The hyperparameter of the third-order solver. `None` falls back to 1/3.
        r2: A `float`. The hyperparameter of the third-order solver. `None` falls back to 2/3.
        model_s: A pytorch tensor. The model function evaluated at time `s`.
            If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
        model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
            If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
        return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
            With `return_intermediate`, the tuple
            `(x_t, {'model_s': ..., 'model_s1': ..., 'model_s2': ...})`.
    Raises:
        ValueError: if `solver_type` is neither 'dpmsolver' nor 'taylor'.
    """
    if solver_type not in ['dpmsolver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
    if r1 is None:
        r1 = 1. / 3.
    if r2 is None:
        r2 = 2. / 3.

    schedule = self.noise_schedule
    lambda_s = schedule.marginal_lambda(s)
    lambda_t = schedule.marginal_lambda(t)
    h = lambda_t - lambda_s
    # Intermediate times, chosen in lambda space and mapped back to time values.
    s1 = schedule.inverse_lambda(lambda_s + r1 * h)
    s2 = schedule.inverse_lambda(lambda_s + r2 * h)

    log_alpha_s = schedule.marginal_log_mean_coeff(s)
    log_alpha_s1 = schedule.marginal_log_mean_coeff(s1)
    log_alpha_s2 = schedule.marginal_log_mean_coeff(s2)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_s = schedule.marginal_std(s)
    sigma_s1 = schedule.marginal_std(s1)
    sigma_s2 = schedule.marginal_std(s2)
    sigma_t = schedule.marginal_std(t)
    alpha_s1 = torch.exp(log_alpha_s1)
    alpha_s2 = torch.exp(log_alpha_s2)
    alpha_t = torch.exp(log_alpha_t)

    use_plus_plus = self.algorithm_type == "dpmsolver++"
    use_taylor = solver_type == 'taylor'

    if use_plus_plus:
        phi_11 = torch.expm1(-r1 * h)
        phi_12 = torch.expm1(-r2 * h)
        phi_1 = torch.expm1(-h)
        phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
        phi_2 = phi_1 / h + 1.
        phi_3 = phi_2 / h - 0.5

        if model_s is None:
            model_s = self.model_fn(x, s)
        if model_s1 is None:
            # Only needed when the caller did not supply the model value at s1.
            x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
            model_s1 = self.model_fn(x_s1, s1)
        x_s2 = (
            (sigma_s2 / sigma_s) * x
            - (alpha_s2 * phi_12) * model_s
            + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
        )
        model_s2 = self.model_fn(x_s2, s2)

        first_order = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s
        if use_taylor:
            D1_0 = (1. / r1) * (model_s1 - model_s)
            D1_1 = (1. / r2) * (model_s2 - model_s)
            D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
            D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
            x_t = first_order + (alpha_t * phi_2) * D1 - (alpha_t * phi_3) * D2
        else:
            x_t = first_order + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
    else:
        phi_11 = torch.expm1(r1 * h)
        phi_12 = torch.expm1(r2 * h)
        phi_1 = torch.expm1(h)
        phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
        phi_2 = phi_1 / h - 1.
        phi_3 = phi_2 / h - 0.5

        if model_s is None:
            model_s = self.model_fn(x, s)
        if model_s1 is None:
            # Only needed when the caller did not supply the model value at s1.
            x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
            model_s1 = self.model_fn(x_s1, s1)
        x_s2 = (
            (torch.exp(log_alpha_s2 - log_alpha_s)) * x
            - (sigma_s2 * phi_12) * model_s
            - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
        )
        model_s2 = self.model_fn(x_s2, s2)

        first_order = (torch.exp(log_alpha_t - log_alpha_s)) * x - (sigma_t * phi_1) * model_s
        if use_taylor:
            D1_0 = (1. / r1) * (model_s1 - model_s)
            D1_1 = (1. / r2) * (model_s2 - model_s)
            D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
            D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
            x_t = first_order - (sigma_t * phi_2) * D1 - (sigma_t * phi_3) * D2
        else:
            x_t = first_order - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)

    if not return_intermediate:
        return x_t
    return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
|
795 |
+
|
796 |
+
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
    """
    Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.

    Only the last two entries of `model_prev_list` and `t_prev_list` are read; their
    difference, rescaled by the ratio of the lambda step sizes, gives the second-order term.

    Args:
        x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
        t: A pytorch tensor. The ending time, with the shape (1,).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if `solver_type` is neither 'dpmsolver' nor 'taylor'.
    """
    if solver_type not in ['dpmsolver', 'taylor']:
        raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))

    schedule = self.noise_schedule
    older_model, latest_model = model_prev_list[-2], model_prev_list[-1]
    older_t, latest_t = t_prev_list[-2], t_prev_list[-1]

    lambda_older = schedule.marginal_lambda(older_t)
    lambda_latest = schedule.marginal_lambda(latest_t)
    lambda_t = schedule.marginal_lambda(t)
    log_alpha_latest = schedule.marginal_log_mean_coeff(latest_t)
    log_alpha_t = schedule.marginal_log_mean_coeff(t)
    sigma_latest = schedule.marginal_std(latest_t)
    sigma_t = schedule.marginal_std(t)
    alpha_t = torch.exp(log_alpha_t)

    h_0 = lambda_latest - lambda_older
    h = lambda_t - lambda_latest
    r0 = h_0 / h
    # Finite difference of the two stored model values, rescaled to the current step size.
    D1_0 = (1. / r0) * (latest_model - older_model)

    use_taylor = solver_type == 'taylor'
    if self.algorithm_type == "dpmsolver++":
        phi_1 = torch.expm1(-h)
        first_order = (sigma_t / sigma_latest) * x - (alpha_t * phi_1) * latest_model
        if use_taylor:
            x_t = first_order + (alpha_t * (phi_1 / h + 1.)) * D1_0
        else:
            x_t = first_order - 0.5 * (alpha_t * phi_1) * D1_0
    else:
        phi_1 = torch.expm1(h)
        first_order = (torch.exp(log_alpha_t - log_alpha_latest)) * x - (sigma_t * phi_1) * latest_model
        if use_taylor:
            x_t = first_order - (sigma_t * (phi_1 / h - 1.)) * D1_0
        else:
            x_t = first_order - 0.5 * (sigma_t * phi_1) * D1_0
    return x_t
|
853 |
+
|
854 |
+
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
    """
    Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.

    The three most recent entries of `model_prev_list` / `t_prev_list` are used. Longer
    histories are accepted and only their tail is read, matching the second-order multistep
    update which also reads from the end of the lists.

    Args:
        x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
        model_prev_list: A list of pytorch tensor. The previous computed model values (at least three).
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
        t: A pytorch tensor. The ending time, with the shape (1,).
        solver_type: either 'dpmsolver' or 'taylor'. Accepted for signature parity with the other
            update methods; this method does not read it.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if fewer than three previous values/times are supplied (unpacking fails).
    """
    ns = self.noise_schedule
    # Take the tail so that a history longer than three entries does not break unpacking.
    model_prev_2, model_prev_1, model_prev_0 = model_prev_list[-3:]
    t_prev_2, t_prev_1, t_prev_0 = t_prev_list[-3:]
    lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
    log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    alpha_t = torch.exp(log_alpha_t)

    h_1 = lambda_prev_1 - lambda_prev_2
    h_0 = lambda_prev_0 - lambda_prev_1
    h = lambda_t - lambda_prev_0
    r0, r1 = h_0 / h, h_1 / h
    # Divided differences of the stored model values, rescaled to the current step size.
    D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
    D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
    if self.algorithm_type == "dpmsolver++":
        phi_1 = torch.expm1(-h)
        phi_2 = phi_1 / h + 1.
        phi_3 = phi_2 / h - 0.5
        x_t = (
            (sigma_t / sigma_prev_0) * x
            - (alpha_t * phi_1) * model_prev_0
            + (alpha_t * phi_2) * D1
            - (alpha_t * phi_3) * D2
        )
    else:
        phi_1 = torch.expm1(h)
        phi_2 = phi_1 / h - 1.
        phi_3 = phi_2 / h - 0.5
        x_t = (
            (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
            - (sigma_t * phi_1) * model_prev_0
            - (sigma_t * phi_2) * D1
            - (sigma_t * phi_3) * D2
        )
    return x_t
|
905 |
+
|
906 |
+
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None):
    """
    Singlestep DPM-Solver with the order `order` from time `s` to time `t`.

    Dispatches to the first-, second- or third-order singlestep update on `self`.

    Args:
        x: A pytorch tensor. The initial value at time `s`.
        s: A pytorch tensor. The starting time, with the shape (1,).
        t: A pytorch tensor. The ending time, with the shape (1,).
        order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
        return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
            Not forwarded for order 1.
        r1: A `float`. The hyperparameter of the second-order or third-order solver.
        r2: A `float`. The hyperparameter of the third-order solver.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if `order` is not 1, 2 or 3.
    """
    if order == 1:
        return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
    if order == 2:
        return self.singlestep_dpm_solver_second_update(
            x, s, t,
            return_intermediate=return_intermediate,
            solver_type=solver_type,
            r1=r1,
        )
    if order == 3:
        return self.singlestep_dpm_solver_third_update(
            x, s, t,
            return_intermediate=return_intermediate,
            solver_type=solver_type,
            r1=r1,
            r2=r2,
        )
    raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
931 |
+
|
932 |
+
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
    """
    Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.

    Order 1 reuses the first-order update with the most recent stored time and model value;
    orders 2 and 3 dispatch to the corresponding multistep update on `self`.

    Args:
        x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
        model_prev_list: A list of pytorch tensor. The previous computed model values.
        t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
        t: A pytorch tensor. The ending time, with the shape (1,).
        order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
            Not forwarded for order 1.
    Returns:
        x_t: A pytorch tensor. The approximated solution at time `t`.
    Raises:
        ValueError: if `order` is not 1, 2 or 3.
    """
    if order == 1:
        latest_t = t_prev_list[-1]
        latest_model = model_prev_list[-1]
        return self.dpm_solver_first_update(x, latest_t, t, model_s=latest_model)
    if order == 2:
        return self.multistep_dpm_solver_second_update(
            x, model_prev_list, t_prev_list, t, solver_type=solver_type)
    if order == 3:
        return self.multistep_dpm_solver_third_update(
            x, model_prev_list, t_prev_list, t, solver_type=solver_type)
    raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
955 |
+
|
956 |
+
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'):
    """
    The adaptive step size solver based on singlestep DPM-Solver.

    Each iteration takes one step with a lower-order and a higher-order singlestep update,
    compares them to estimate the error `E`, accepts the higher-order result when `E <= 1`,
    and then rescales the lambda step size. The evaluation counter is increased by `order`
    on every iteration (accepted or not) and printed once the loop finishes.

    Args:
        x: A pytorch tensor. The initial value at time `t_T`.
        order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        h_init: A `float`. The initial step size (for logSNR).
        atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
        rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
        theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
        t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
            current time and `t_0` is less than `t_err`. The default setting is 1e-5.
        solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
            The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
    Returns:
        x_0: A pytorch tensor. The approximated solution at time `t_0`.
    Raises:
        ValueError: if `order` is neither 2 nor 3.

    [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
    """
    ns = self.noise_schedule
    s = t_T * torch.ones((1,)).to(x)
    lambda_s = ns.marginal_lambda(s)
    lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
    h = h_init * torch.ones_like(s).to(x)
    x_prev = x
    nfe = 0

    if order == 2:
        r1 = 0.5

        def lower_update(x_cur, s_cur, t_next):
            return self.dpm_solver_first_update(x_cur, s_cur, t_next, return_intermediate=True)

        def higher_update(x_cur, s_cur, t_next, **cached):
            return self.singlestep_dpm_solver_second_update(
                x_cur, s_cur, t_next, r1=r1, solver_type=solver_type, **cached)
    elif order == 3:
        r1, r2 = 1. / 3., 2. / 3.

        def lower_update(x_cur, s_cur, t_next):
            return self.singlestep_dpm_solver_second_update(
                x_cur, s_cur, t_next, r1=r1, return_intermediate=True, solver_type=solver_type)

        def higher_update(x_cur, s_cur, t_next, **cached):
            return self.singlestep_dpm_solver_third_update(
                x_cur, s_cur, t_next, r1=r1, r2=r2, solver_type=solver_type, **cached)
    else:
        raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))

    def rms_per_row(v):
        # Root-mean-square over everything except the leading dimension.
        flat = v.reshape((v.shape[0], -1))
        return torch.sqrt(torch.square(flat).mean(dim=-1, keepdim=True))

    while torch.abs((s - t_0)).mean() > t_err:
        t = ns.inverse_lambda(lambda_s + h)
        # The lower-order step returns its model evaluations so the higher-order step can reuse them.
        x_lower, cached_evals = lower_update(x, s, t)
        x_higher = higher_update(x, s, t, **cached_evals)

        abs_floor = torch.ones_like(x).to(x) * atol
        rel_scale = rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))
        delta = torch.max(abs_floor, rel_scale)
        E = rms_per_row((x_higher - x_lower) / delta).max()

        if torch.all(E <= 1.):
            # Step accepted: advance the state and the current lambda.
            x, s, x_prev = x_higher, t, x_lower
            lambda_s = ns.marginal_lambda(s)

        # New step size, capped by the remaining lambda distance to the end time.
        proposed = theta * h * torch.float_power(E, -1. / order).float()
        h = torch.min(proposed, lambda_0 - lambda_s)
        nfe += order
    print('adaptive solver nfe', nfe)
    return x
|
1011 |
+
|
1012 |
+
def add_noise(self, x, t, noise=None):
    """
    Compute the noised input xt = alpha_t * x + sigma_t * noise.

    Args:
        x: A `torch.Tensor` with shape `(batch_size, *shape)`.
        t: A `torch.Tensor` with shape `(t_size,)`.
        noise: Optional noise tensor to use. When None, standard normal noise of shape
            `(t_size, *x.shape)` is drawn on `x.device`.
    Returns:
        xt with shape `(t_size, batch_size, *shape)`; the leading dimension is squeezed
        away when `t_size` is 1.
    """
    schedule = self.noise_schedule
    alpha_t = schedule.marginal_alpha(t)
    sigma_t = schedule.marginal_std(t)
    num_times = t.shape[0]
    if noise is None:
        noise = torch.randn((num_times, *x.shape), device=x.device)
    # Prepend a dimension so `x` lines up with the per-time coefficients.
    x = x.reshape((-1, *x.shape))
    signal_part = expand_dims(alpha_t, x.dim()) * x
    noise_part = expand_dims(sigma_t, x.dim()) * noise
    xt = signal_part + noise_part
    return xt.squeeze(0) if t.shape[0] == 1 else xt
|
1031 |
+
|
1032 |
+
def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
            method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
            atol=0.0078, rtol=0.05, return_intermediate=False,
            ):
    """
    Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
    For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.

    When `t_start` is None it defaults to `1 / self.noise_schedule.total_N`; when `t_end` is
    None it defaults to `self.noise_schedule.T`. Every other argument is forwarded unchanged
    to `self.sample`, whose result is returned.
    """
    if t_start is None:
        t_0 = 1. / self.noise_schedule.total_N
    else:
        t_0 = t_start
    if t_end is None:
        t_T = self.noise_schedule.T
    else:
        t_T = t_end
    assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
    forwarded = dict(
        steps=steps,
        t_start=t_0,
        t_end=t_T,
        order=order,
        skip_type=skip_type,
        method=method,
        lower_order_final=lower_order_final,
        denoise_to_zero=denoise_to_zero,
        solver_type=solver_type,
        atol=atol,
        rtol=rtol,
        return_intermediate=return_intermediate,
    )
    return self.sample(x, **forwarded)
|
1046 |
+
|
1047 |
+
def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
    method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
    atol=0.0078, rtol=0.05, return_intermediate=False,
):
    """
    Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.

    =====================================================

    We support the following algorithms for both noise prediction model and data prediction model:
        - 'singlestep':
            Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
            We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
            The total number of function evaluations (NFE) == `steps`.
            Given a fixed NFE == `steps`, the sampling procedure is:
                - If `order` == 1:
                    - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
                - If `order` == 2:
                    - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
                    - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
                    - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                - If `order` == 3:
                    - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                    - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                    - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
                    - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
        - 'multistep':
            Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
            We initialize the first `order` values by lower order multistep solvers.
            Given a fixed NFE == `steps`, the sampling procedure is:
                Denote K = steps.
                - If `order` == 1:
                    - We use K steps of DPM-Solver-1 (i.e. DDIM).
                - If `order` == 2:
                    - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
                - If `order` == 3:
                    - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
        - 'singlestep_fixed':
            Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
            We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
        - 'adaptive':
            Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
            We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
            You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
            (NFE) and the sample quality.
                - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
                - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.

    =====================================================

    Some advice for choosing the algorithm:
        - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
            Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
            e.g., DPM-Solver:
                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                        skip_type='time_uniform', method='singlestep')
            e.g., DPM-Solver++:
                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                        skip_type='time_uniform', method='singlestep')
        - For **guided sampling with large guidance scale** by DPMs:
            Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
            e.g.
                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
                        skip_type='time_uniform', method='multistep')

    We support three types of `skip_type`:
        - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**
        - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
        - 'time_quadratic': quadratic time for the time steps.

    =====================================================
    Args:
        x: A pytorch tensor. The initial value at time `t_start`
            e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
        steps: A `int`. The total number of function evaluations (NFE).
        t_start: A `float`. The starting time of the sampling.
            If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
        t_end: A `float`. The ending time of the sampling.
            If `t_end` is None, we use 1. / self.noise_schedule.total_N.
            e.g. if total_N == 1000, we have `t_end` == 1e-3.
            For discrete-time DPMs:
                - We recommend `t_end` == 1. / self.noise_schedule.total_N.
            For continuous-time DPMs:
                - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
        order: A `int`. The order of DPM-Solver.
        skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
        method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
        denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
            Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).

            This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
            score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
            for diffusion models sampling by diffusion SDEs for low-resolution images
            (such as CIFAR-10). However, we observed that such trick does not matter for
            high-resolution images. As it needs an additional NFE, we do not recommend
            it for high-resolution images.
        lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
            Only takes effect for `method=multistep` and `steps < 10`. We empirically find that
            this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
            (especially for steps <= 10). So we recommend to set it to be `True`.
        solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
        atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
        rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
        return_intermediate: A `bool`. Whether to save the xt at each step.
            When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
    Returns:
        x_end: A pytorch tensor. The approximated solution at time `t_end`.
            When `return_intermediate` is `True`, the tuple `(x_end, intermediates)` is
            returned instead, where `intermediates` is the list of saved xt values.
    Raises:
        AssertionError: if the time range is not positive, if the adaptive method is combined
            with `return_intermediate` or a `correcting_xt_fn`, or if `steps < order` for 'multistep'.
        ValueError: if `method` is not one of the supported names.
    """
    t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
    t_T = self.noise_schedule.T if t_start is None else t_start
    assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
    if return_intermediate:
        assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
    if self.correcting_xt_fn is not None:
        assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
    device = x.device
    intermediates = []
    # Index of the last executed solver step. It starts at -1 so that the final
    # `step + 1` below is well defined (== 0) even when no step runs, e.g.
    # 'singlestep_fixed' with `steps < order` yields an empty schedule.
    step = -1
    with torch.no_grad():
        if method == 'adaptive':
            x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
        elif method == 'multistep':
            assert steps >= order
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            assert timesteps.shape[0] - 1 == steps
            # Init the initial values.
            step = 0
            t = timesteps[step]
            t_prev_list = [t]
            # The model is evaluated on the uncorrected x; the correction is applied afterwards.
            model_prev_list = [self.model_fn(x, t)]
            if self.correcting_xt_fn is not None:
                x = self.correcting_xt_fn(x, t, step)
            if return_intermediate:
                intermediates.append(x)
            # Init the first `order` values by lower order multistep DPM-Solver.
            for step in range(1, order):
                t = timesteps[step]
                x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
                t_prev_list.append(t)
                model_prev_list.append(self.model_fn(x, t))
            # Compute the remaining values by `order`-th order multistep DPM-Solver.
            for step in range(order, steps + 1):
                t = timesteps[step]
                # We only use lower order for steps < 10
                if lower_order_final and steps < 10:
                    step_order = min(order, steps + 1 - step)
                else:
                    step_order = order
                x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
                # Slide the history window one slot to the left.
                for i in range(order - 1):
                    t_prev_list[i] = t_prev_list[i + 1]
                    model_prev_list[i] = model_prev_list[i + 1]
                t_prev_list[-1] = t
                # We do not need to evaluate the final model value.
                if step < steps:
                    model_prev_list[-1] = self.model_fn(x, t)
        elif method in ['singlestep', 'singlestep_fixed']:
            if method == 'singlestep':
                timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
            elif method == 'singlestep_fixed':
                K = steps // order
                orders = [order,] * K
                timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
            # `cur_order` is the order of this particular step; the `order`
            # argument itself is deliberately left untouched.
            for step, cur_order in enumerate(orders):
                s, t = timesteps_outer[step], timesteps_outer[step + 1]
                timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=cur_order, device=device)
                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
                h = lambda_inner[-1] - lambda_inner[0]
                r1 = None if cur_order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
                r2 = None if cur_order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
                x = self.singlestep_dpm_solver_update(x, s, t, cur_order, solver_type=solver_type, r1=r1, r2=r2)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
        else:
            raise ValueError("Got wrong method {}".format(method))
        if denoise_to_zero:
            t = torch.ones((1,)).to(device) * t_0
            x = self.denoise_to_zero_fn(x, t)
            if self.correcting_xt_fn is not None:
                x = self.correcting_xt_fn(x, t, step + 1)
            if return_intermediate:
                intermediates.append(x)
    if return_intermediate:
        return x, intermediates
    else:
        return x
|
1246 |
+
|
1247 |
+
|
1248 |
+
|
1249 |
+
#############################################################
|
1250 |
+
# other utility functions
|
1251 |
+
#############################################################
|
1252 |
+
|
1253 |
+
def interpolate_fn(x, xp, yp):
    """
    Evaluate a piecewise linear function y = f(x) defined by the keypoints (xp, yp).

    The computation uses only differentiable tensor operations (i.e. it is applicable for autograd).
    f(x) is defined on the whole x-axis: for x outside the range of xp, the outermost
    segment of xp is extended linearly.

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    dev = x.device

    # Put every query in front of its channel's keypoints and sort; the place
    # where the query (original index 0) lands identifies its segment.
    tiled_xp = xp.unsqueeze(0).repeat((batch, 1, 1))
    merged = torch.cat([x.unsqueeze(2), tiled_xp], dim=2)
    merged_sorted, perm = torch.sort(merged, dim=2)
    rank = torch.argmin(perm, dim=2)
    left = rank - 1
    is_below = torch.eq(rank, 0)
    is_above = torch.eq(rank, num_keys)

    # Segment end points, as positions inside the sorted (query + keypoints) array.
    lo = torch.where(is_above, torch.tensor(num_keys - 2, device=dev), left)
    lo = torch.where(is_below, torch.tensor(1, device=dev), lo)
    # For an interior query, skip over the query's own slot to reach the right keypoint.
    hi = torch.where(torch.eq(lo, left), lo + 2, lo + 1)
    x_lo = torch.gather(merged_sorted, dim=2, index=lo.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(merged_sorted, dim=2, index=hi.unsqueeze(2)).squeeze(2)

    # The same segment, as positions inside the keypoint arrays themselves.
    seg = torch.where(is_above, torch.tensor(num_keys - 2, device=dev), left)
    seg = torch.where(is_below, torch.tensor(0, device=dev), seg)
    tiled_yp = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(tiled_yp, dim=2, index=seg.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(tiled_yp, dim=2, index=(seg + 1).unsqueeze(2)).squeeze(2)

    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
|
1293 |
+
|
1294 |
+
|
1295 |
+
def expand_dims(v, dims):
    """
    Append trailing singleton axes to `v` until it has `dims` dimensions.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: a `int`. The total number of dimensions of the result.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    # Each `None` in the index inserts one new axis of size 1 after the existing ones.
    trailing_axes = (None,) * (dims - 1)
    index = (...,) + trailing_axes
    return v[index]
|