diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..2bf35c43ed664551d57926e08095cedb08191769
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..0da7c02e5dda467db394050b263aff53bdc12ee4 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+test.wav filter=lfs diff=lfs merge=lfs -text
+vad/assets/silero_vad.jit filter=lfs diff=lfs merge=lfs -text
diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000000000000000000000000000000000000..54c5196a2b9da061074225f39dc40aed04fec0b9
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.10.9
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..6ebe0ff106911262124772a8f199fe9c9e68e585
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 PlayVoice
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 68b3932323aa60b5afa11d18fc39150f78d0e676..056a88d5d2640c34cafdabc6a8bb1f41148e973b 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,500 @@
---
-title: Sovits Test
-emoji: 📚
+title: Whisper Vits SVC
+emoji: 🎵
+python_version: 3.10.12
colorFrom: blue
colorTo: purple
sdk: gradio
-sdk_version: 5.8.0
-app_file: app.py
+sdk_version: 5.7.1
+app_file: main.py
pinned: false
+license: mit
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# Variational Inference with adversarial learning for end-to-end Singing Voice Conversion based on VITS
+
+[Hugging Face Spaces](https://huggingface.co/spaces/maxmax20160403/sovits5.0)
+
+[中文文档](./README_ZH.md)
+
+The [bigvgan-mix-v2](https://github.com/PlayVoice/whisper-vits-svc/tree/bigvgan-mix-v2) branch has good audio quality
+
+The [RoFormer-HiFTNet](https://github.com/PlayVoice/whisper-vits-svc/tree/RoFormer-HiFTNet) branch has fast inference speed
+
+No more upgrades are planned
+
+
+
+- This project targets deep learning beginners; basic knowledge of Python and PyTorch is a prerequisite;
+- This project aims to help deep learning beginners move beyond dry theory and master the basics of deep learning through hands-on practice;
+- This project does not support real-time voice conversion (whisper would need to be replaced if real-time conversion is what you are looking for);
+- This project will not develop one-click packages for other purposes;
+
+
+
+- A minimum VRAM requirement of 6GB for training
+
+- Support for multiple speakers
+
+- Create unique speakers through speaker mixing
+
+- It can even convert voices with light accompaniment
+
+- You can edit F0 using Excel
+
+https://github.com/PlayVoice/so-vits-svc-5.0/assets/16432329/6a09805e-ab93-47fe-9a14-9cbc1e0e7c3a
+
+Powered by [@ShadowVap](https://space.bilibili.com/491283091)
+
+## Model properties
+
+| Feature | From | Status | Function |
+| :--- | :--- | :--- | :--- |
+| whisper | OpenAI | ✅ | strong noise immunity |
+| bigvgan | NVIDIA | ✅ | anti-aliasing and snake activation; clearer formants and noticeably better sound quality |
+| natural speech | Microsoft | ✅ | reduce mispronunciation |
+| neural source-filter | Xin Wang | ✅ | solve the problem of audio F0 discontinuity |
+| pitch quantization | Xin Wang | ✅ | quantize the F0 for embedding |
+| speaker encoder | Google | ✅ | Timbre Encoding and Clustering |
+| GRL for speaker | Ubisoft | ✅ | prevent the encoder from leaking timbre |
+| SNAC | Samsung | ✅ | One Shot Clone of VITS |
+| SCLN | Microsoft | ✅ | Improve Clone |
+| Diffusion | HuaWei | ✅ | Improve sound quality |
+| PPG perturbation | this project | ✅ | improved noise immunity and timbre removal |
+| HuBERT perturbation | this project | ✅ | improved noise immunity and timbre removal |
+| VAE perturbation | this project | ✅ | Improve sound quality |
+| MIX encoder | this project | ✅ | Improve conversion stability |
+| USP infer | this project | ✅ | Improve conversion stability |
+| HiFTNet | Columbia University | ✅ | NSF-iSTFTNet for speed up |
+| RoFormer | Zhuiyi Technology | ✅ | Rotary Positional Embeddings |
+
+Due to the use of data perturbation, training takes longer than in comparable projects.
+
+**USP: unvoiced and silent segments still carry pitch at inference time**
+
+
+## Why mix
+
+
+
+## Plug-In-Diffusion
+
+
+
+## Setup Environment
+
+1. Install [PyTorch](https://pytorch.org/get-started/locally/).
+
+2. Install project dependencies
+ ```shell
+ pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
+ ```
+   **Note: whisper is already built in; do not install it separately, otherwise it will cause conflicts and errors**
+3. Download the Timbre Encoder: [Speaker-Encoder by @mueller91](https://drive.google.com/drive/folders/15oeBYf6Qn1edONkVLXe82MzdIi3O_9m3), put `best_model.pth.tar` into `speaker_pretrain/`.
+
+4. Download the whisper model [whisper-large-v2](https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt). Make sure to download `large-v2.pt` and put it into `whisper_pretrain/`.
+
+5. Download the [hubert_soft model](https://github.com/bshall/hubert/releases/tag/v0.1) and put `hubert-soft-0d54a1f4.pt` into `hubert_pretrain/`.
+
+6. Download the pitch extractor [crepe full](https://github.com/maxrmorrison/torchcrepe/tree/master/torchcrepe/assets) and put `full.pth` into `crepe/assets`.
+
+   **Note: crepe `full.pth` is 84.9 MB, not 6 KB; verify the file size after downloading**
+
+7. Download the pretrained model [sovits5.0.pretrain.pth](https://github.com/PlayVoice/so-vits-svc-5.0/releases/tag/5.0/), put it into `vits_pretrain/`, and run a quick inference test (an asset-check sketch also follows below):
+ ```shell
+ python svc_inference.py --config configs/base.yaml --model ./vits_pretrain/sovits5.0.pretrain.pth --spk ./configs/singers/singer0001.npy --wave test.wav
+ ```
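+
+   Before running the test command, you can quickly verify that the assets from steps 3 to 7 are in place. This is just a convenience check; the file names mirror the download instructions above.
+   ```python
+   import os
+
+   required = [
+       "speaker_pretrain/best_model.pth.tar",
+       "whisper_pretrain/large-v2.pt",
+       "hubert_pretrain/hubert-soft-0d54a1f4.pt",
+       "crepe/assets/full.pth",
+       "vits_pretrain/sovits5.0.pretrain.pth",
+   ]
+   for path in required:
+       # a missing or tiny file usually means the download did not finish
+       print(("ok     " if os.path.isfile(path) else "MISSING"), path)
+   ```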
+
+## Dataset preparation
+
+Necessary pre-processing:
+1. Separate voice and accompaniment with [UVR](https://github.com/Anjok07/ultimatevocalremovergui) (skip if no accompaniment)
+2. Cut the audio into shorter clips with [slicer](https://github.com/flutydeer/audio-slicer); whisper requires inputs shorter than 30 seconds.
+3. Manually check the generated clips and remove clips shorter than 2 seconds or with obvious noise (a small duration-check sketch follows the directory tree below).
+4. Adjust loudness if necessary; Adobe Audition is recommended.
+5. Put the dataset into the `dataset_raw` directory following the structure below.
+```
+dataset_raw
+├───speaker0
+│ ├───000001.wav
+│ ├───...
+│ └───000xxx.wav
+└───speaker1
+ ├───000001.wav
+ ├───...
+ └───000xxx.wav
+```
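+
+A quick way to catch clips that violate the length limits above (a convenience sketch using `soundfile`; adjust the root path if your layout differs):
+```python
+import os
+import soundfile
+
+def check_durations(root="dataset_raw", min_s=2.0, max_s=30.0):
+    for speaker in sorted(os.listdir(root)):
+        spk_dir = os.path.join(root, speaker)
+        if not os.path.isdir(spk_dir):
+            continue
+        for name in sorted(os.listdir(spk_dir)):
+            if not name.endswith(".wav"):
+                continue
+            info = soundfile.info(os.path.join(spk_dir, name))
+            duration = info.frames / info.samplerate
+            if duration < min_s or duration >= max_s:
+                print(f"{speaker}/{name}: {duration:.1f}s")
+
+check_durations()
+```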
+
+## Data preprocessing
+```shell
+python svc_preprocessing.py -t 2
+```
+`-t`: number of threads; it should not exceed the CPU core count, and 2 is usually enough.
+After preprocessing you will get output with the following structure.
+```
+data_svc/
+└── waves-16k
+│ └── speaker0
+│ │ ├── 000001.wav
+│ │ └── 000xxx.wav
+│ └── speaker1
+│ ├── 000001.wav
+│ └── 000xxx.wav
+└── waves-32k
+│ └── speaker0
+│ │ ├── 000001.wav
+│ │ └── 000xxx.wav
+│ └── speaker1
+│ ├── 000001.wav
+│ └── 000xxx.wav
+└── pitch
+│ └── speaker0
+│ │ ├── 000001.pit.npy
+│ │ └── 000xxx.pit.npy
+│ └── speaker1
+│ ├── 000001.pit.npy
+│ └── 000xxx.pit.npy
+└── hubert
+│ └── speaker0
+│ │ ├── 000001.vec.npy
+│ │ └── 000xxx.vec.npy
+│ └── speaker1
+│ ├── 000001.vec.npy
+│ └── 000xxx.vec.npy
+└── whisper
+│ └── speaker0
+│ │ ├── 000001.ppg.npy
+│ │ └── 000xxx.ppg.npy
+│ └── speaker1
+│ ├── 000001.ppg.npy
+│ └── 000xxx.ppg.npy
+└── speaker
+│ └── speaker0
+│ │ ├── 000001.spk.npy
+│ │ └── 000xxx.spk.npy
+│ └── speaker1
+│ ├── 000001.spk.npy
+│ └── 000xxx.spk.npy
+└── singer
+│ ├── speaker0.spk.npy
+│ └── speaker1.spk.npy
+|
+└── indexes
+ ├── speaker0
+ │ ├── some_prefix_hubert.index
+ │ └── some_prefix_whisper.index
+ └── speaker1
+ ├── hubert.index
+ └── whisper.index
+```
+
+1. Re-sampling
+ - Generate audio with a sampling rate of 16000Hz in `./data_svc/waves-16k`
+ ```
+ python prepare/preprocess_a.py -w ./dataset_raw -o ./data_svc/waves-16k -s 16000
+ ```
+
+ - Generate audio with a sampling rate of 32000Hz in `./data_svc/waves-32k`
+ ```
+ python prepare/preprocess_a.py -w ./dataset_raw -o ./data_svc/waves-32k -s 32000
+ ```
+2. Use 16K audio to extract pitch
+ ```
+ python prepare/preprocess_crepe.py -w data_svc/waves-16k/ -p data_svc/pitch
+ ```
+3. Use 16K audio to extract ppg
+ ```
+ python prepare/preprocess_ppg.py -w data_svc/waves-16k/ -p data_svc/whisper
+ ```
+4. Use 16K audio to extract hubert
+ ```
+ python prepare/preprocess_hubert.py -w data_svc/waves-16k/ -v data_svc/hubert
+ ```
+5. Use 16k audio to extract timbre code
+ ```
+ python prepare/preprocess_speaker.py data_svc/waves-16k/ data_svc/speaker
+ ```
+6. Extract the average of the timbre codes; it is used for inference, and it can also replace the per-utterance timbre when generating the training index, acting as the speaker's unified timbre (useful when the timbre varies little across the data)
+ ```
+ python prepare/preprocess_speaker_ave.py data_svc/speaker/ data_svc/singer
+ ```
+7. Use 32k audio to extract the linear spectrum
+ ```
+ python prepare/preprocess_spec.py -w data_svc/waves-32k/ -s data_svc/specs
+ ```
+8. Use 32k audio to generate training index
+ ```
+ python prepare/preprocess_train.py
+ ```
+9. Training file debugging (the generated feature files can also be sanity-checked with the sketch below)
+ ```
+ python prepare/preprocess_zzz.py
+ ```
+
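+As a quick sanity check after preprocessing, you can load one utterance's features and print their shapes (a sketch; file names follow the directory tree above, and the exact shapes depend on the extractors):
+```python
+import numpy as np
+
+stem, speaker = "000001", "speaker0"
+paths = {
+    "pitch":   f"data_svc/pitch/{speaker}/{stem}.pit.npy",
+    "hubert":  f"data_svc/hubert/{speaker}/{stem}.vec.npy",
+    "whisper": f"data_svc/whisper/{speaker}/{stem}.ppg.npy",
+    "speaker": f"data_svc/speaker/{speaker}/{stem}.spk.npy",
+    "singer":  f"data_svc/singer/{speaker}.spk.npy",
+}
+for kind, path in paths.items():
+    arr = np.load(path)
+    print(f"{kind:8s} {arr.shape} {arr.dtype}")
+```
+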
+## Train
+1. If you are fine-tuning from the pre-trained model, download [sovits5.0.pretrain.pth](https://github.com/PlayVoice/so-vits-svc-5.0/releases/tag/5.0), place it under the project root so that the path below resolves, and set this line
+ ```
+ pretrain: "./vits_pretrain/sovits5.0.pretrain.pth"
+ ```
+   in `configs/base.yaml`, and adjust the learning rate appropriately, e.g. 5e-5 (a small YAML-editing sketch follows this list).
+
+   `batch_size`: for a GPU with 6 GB of VRAM, 6 is the recommended value; 8 will work, but each step will be much slower.
+2. Start training
+ ```
+ python svc_trainer.py -c configs/base.yaml -n sovits5.0
+ ```
+3. Resume training
+ ```
+ python svc_trainer.py -c configs/base.yaml -n sovits5.0 -p chkpt/sovits5.0/sovits5.0_***.pt
+ ```
+4. Log visualization
+ ```
+ tensorboard --logdir logs/
+ ```
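+
+If you prefer to script the fine-tuning setup instead of editing `configs/base.yaml` by hand, a small sketch using `ruamel.yaml` (the library used by the project's WebUI) could look like this; the key names match `configs/base.yaml`:
+```python
+from ruamel.yaml import YAML
+
+yaml = YAML()
+yaml.preserve_quotes = True
+with open("configs/base.yaml") as f:
+    cfg = yaml.load(f)
+
+cfg["train"]["pretrain"] = "./vits_pretrain/sovits5.0.pretrain.pth"
+cfg["train"]["learning_rate"] = 5e-5   # lower rate for fine-tuning
+cfg["train"]["batch_size"] = 6         # recommended for 6 GB of VRAM
+
+with open("configs/base.yaml", "w") as f:
+    yaml.dump(cfg, f)
+```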
+
+
+
+
+
+## Inference
+
+1. Export the inference model: text encoder, flow network, and decoder (the discriminator and posterior encoder are only used during training)
+ ```
+ python svc_export.py --config configs/base.yaml --checkpoint_path chkpt/sovits5.0/***.pt
+ ```
+2. Inference
+ - if there is no need to adjust `f0`, just run the following command.
+ ```
+ python svc_inference.py --config configs/base.yaml --model sovits5.0.pth --spk ./data_svc/singer/your_singer.spk.npy --wave test.wav --shift 0
+ ```
+   - if `f0` will be adjusted manually, follow the steps below (a Python sketch chaining them appears at the end of this section):
+     1. use whisper to extract the content encoding, generating `test.ppg.npy`.
+ ```
+ python whisper/inference.py -w test.wav -p test.ppg.npy
+ ```
+     2. use hubert to extract the content vector (run it separately rather than via one-click inference to reduce GPU memory usage), generating `test.vec.npy`.
+ ```
+ python hubert/inference.py -w test.wav -v test.vec.npy
+ ```
+     3. extract the F0 parameter to CSV text format, open the CSV file in Excel, and manually correct any wrong F0 values by referring to Audition or SonicVisualiser.
+ ```
+ python pitch/inference.py -w test.wav -p test.csv
+ ```
+ 4. final inference
+ ```
+ python svc_inference.py --config configs/base.yaml --model sovits5.0.pth --spk ./data_svc/singer/your_singer.spk.npy --wave test.wav --ppg test.ppg.npy --vec test.vec.npy --pit test.csv --shift 0
+ ```
+3. Notes
+
+   - when `--ppg` is specified, repeated extraction of the content encoding is avoided when the same audio is inferred multiple times; if it is not specified, it will be extracted automatically;
+
+   - when `--vec` is specified, repeated extraction of the content vector is avoided when the same audio is inferred multiple times; if it is not specified, it will be extracted automatically;
+
+   - when `--pit` is specified, the manually tuned F0 parameter is loaded; if it is not specified, it will be extracted automatically;
+
+   - the output file `svc_out.wav` is generated in the current directory.
+
+4. Argument reference
+
+ | args |--config | --model | --spk | --wave | --ppg | --vec | --pit | --shift |
+ | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+ | name | config path | model path | speaker | wave input | wave ppg | wave hubert | wave pitch | pitch shift |
+
+5. Post-processing with VAD
+```
+python svc_inference_post.py --ref test.wav --svc svc_out.wav --out svc_out_post.wav
+```
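+
+The manual steps from item 2 can also be chained from Python once `test.csv` has been edited. This is just a convenience sketch wrapping the commands shown above (`your_singer.spk.npy` is a placeholder):
+```python
+import subprocess
+
+def run(*args):
+    subprocess.run(["python", *args], check=True)
+
+def extract(wave="test.wav"):
+    # whisper ppg, hubert vector and F0, matching steps 1 to 3 above
+    run("whisper/inference.py", "-w", wave, "-p", "test.ppg.npy")
+    run("hubert/inference.py", "-w", wave, "-v", "test.vec.npy")
+    run("pitch/inference.py", "-w", wave, "-p", "test.csv")
+
+def synthesize(wave="test.wav", spk="./data_svc/singer/your_singer.spk.npy", shift=0):
+    run("svc_inference.py", "--config", "configs/base.yaml", "--model", "sovits5.0.pth",
+        "--spk", spk, "--wave", wave, "--ppg", "test.ppg.npy", "--vec", "test.vec.npy",
+        "--pit", "test.csv", "--shift", str(shift))
+
+extract()       # then edit test.csv in Excel if needed
+synthesize()    # writes svc_out.wav to the current directory
+```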
+
+## Train Feature Retrieval Index (Optional)
+
+To increase the stability of the generated timbre, you can use the method described in the
+[Retrieval-based-Voice-Conversion](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/main/docs/en/README.en.md)
+repository. This method consists of 2 steps:
+
+1. Training the retrieval index on hubert and whisper features
+ Run training with default settings:
+ ```
+ python svc_train_retrieval.py
+ ```
+
+   If the number of vectors is more than 200_000, they will be compressed to 10_000 centroids using the MiniBatchKMeans algorithm.
+   You can change these settings using command-line options:
+ ```
+ usage: crate faiss indexes for feature retrieval [-h] [--debug] [--prefix PREFIX] [--speakers SPEAKERS [SPEAKERS ...]] [--compress-features-after COMPRESS_FEATURES_AFTER]
+ [--n-clusters N_CLUSTERS] [--n-parallel N_PARALLEL]
+
+ options:
+ -h, --help show this help message and exit
+ --debug
+ --prefix PREFIX add prefix to index filename
+ --speakers SPEAKERS [SPEAKERS ...]
+ speaker names to create an index. By default all speakers are from data_svc
+ --compress-features-after COMPRESS_FEATURES_AFTER
+ If the number of features is greater than the value compress feature vectors using MiniBatchKMeans.
+ --n-clusters N_CLUSTERS
+ Number of centroids to which features will be compressed
+ --n-parallel N_PARALLEL
+ Nuber of parallel job of MinibatchKmeans. Default is cpus-1
+ ```
+   Compressing the training vectors can speed up index lookup, but it reduces retrieval quality.
+   Use compression only if you really have a lot of vectors (a rough index-building sketch follows this list).
+
+ The resulting indexes will be stored in the "indexes" folder as:
+ ```
+ data_svc
+ ...
+ └── indexes
+ ├── speaker0
+ │ ├── some_prefix_hubert.index
+ │ └── some_prefix_whisper.index
+ └── speaker1
+ ├── hubert.index
+ └── whisper.index
+ ```
+2. At the inference stage, the n closest features are mixed into the VITS model output in a certain proportion.
+   Enable feature retrieval with:
+ ```
+ python svc_inference.py --config configs/base.yaml --model sovits5.0.pth --spk ./data_svc/singer/your_singer.spk.npy --wave test.wav --shift 0 \
+ --enable-retrieval \
+ --retrieval-ratio 0.5 \
+ --n-retrieval-vectors 3
+ ```
+   For a better retrieval effect, you can try different values of the parameters `--retrieval-ratio` and `--n-retrieval-vectors`
+
+ If you have multiple sets of indexes, you can specify a specific set via the parameter: `--retrieval-index-prefix`
+
+ You can explicitly specify the paths to the hubert and whisper indexes using the parameters: `--hubert-index-path` and `--whisper-index-path`
+
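+For reference, the rough shape of step 1 (building an index per speaker) could look like the sketch below. This is an illustration only: it assumes the per-utterance `.npy` frames under `data_svc/` as inputs and uses a plain L2 flat index, which may differ from what `svc_train_retrieval.py` actually builds.
+```python
+import os
+import numpy as np
+import faiss
+from sklearn.cluster import MiniBatchKMeans
+
+def build_index(feature_dir, out_path, compress_after=200_000, n_clusters=10_000):
+    # stack all frames of one speaker into a single (N, dim) float32 matrix
+    feats = np.concatenate([
+        np.load(os.path.join(feature_dir, f)).astype(np.float32)
+        for f in sorted(os.listdir(feature_dir)) if f.endswith(".npy")
+    ])
+    if len(feats) > compress_after:
+        # compress to cluster centroids, as described above
+        feats = MiniBatchKMeans(n_clusters=n_clusters).fit(feats).cluster_centers_.astype(np.float32)
+    index = faiss.IndexFlatL2(feats.shape[1])
+    index.add(feats)
+    os.makedirs(os.path.dirname(out_path), exist_ok=True)
+    faiss.write_index(index, out_path)
+
+build_index("data_svc/hubert/speaker0", "data_svc/indexes/speaker0/hubert.index")
+```
+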
+
+## Create singer
+Named by pure coincidence: average -> ave -> eva; Eve (eva) represents conception and reproduction.
+
+```
+python svc_eva.py
+```
+
+```python
+eva_conf = {
+ './configs/singers/singer0022.npy': 0,
+ './configs/singers/singer0030.npy': 0,
+ './configs/singers/singer0047.npy': 0.5,
+ './configs/singers/singer0051.npy': 0.5,
+}
+```
+
+The generated singer file will be `eva.spk.npy`.
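+
+Conceptually, mixing amounts to a weighted average of the speaker embeddings (an assumption about what `svc_eva.py` does internally; a minimal equivalent sketch):
+```python
+import numpy as np
+
+eva_conf = {
+    './configs/singers/singer0047.npy': 0.5,
+    './configs/singers/singer0051.npy': 0.5,
+}
+# weighted sum of the selected speaker embeddings
+mix = sum(weight * np.load(path) for path, weight in eva_conf.items() if weight > 0)
+np.save("eva.spk.npy", np.asarray(mix, dtype=np.float32))
+```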
+
+## Data set
+
+| Name | URL |
+| :--- | :--- |
+|KiSing |http://shijt.site/index.php/2021/05/16/kising-the-first-open-source-mandarin-singing-voice-synthesis-corpus/|
+|PopCS |https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/apply_form.md|
+|opencpop |https://wenet.org.cn/opencpop/download/|
+|Multi-Singer |https://github.com/Multi-Singer/Multi-Singer.github.io|
+|M4Singer |https://github.com/M4Singer/M4Singer/blob/master/apply_form.md|
+|CSD |https://zenodo.org/record/4785016#.YxqrTbaOMU4|
+|KSS |https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset|
+|JVS MuSic |https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_music|
+|PJS |https://sites.google.com/site/shinnosuketakamichi/research-topics/pjs_corpus|
+|JUST Song |https://sites.google.com/site/shinnosuketakamichi/publication/jsut-song|
+|MUSDB18 |https://sigsep.github.io/datasets/musdb.html#musdb18-compressed-stems|
+|DSD100 |https://sigsep.github.io/datasets/dsd100.html|
+|Aishell-3 |http://www.aishelltech.com/aishell_3|
+|VCTK |https://datashare.ed.ac.uk/handle/10283/2651|
+|Korean Songs |http://urisori.co.kr/urisori-en/doku.php/|
+
+## Code sources and references
+
+https://github.com/facebookresearch/speech-resynthesis [paper](https://arxiv.org/abs/2104.00355)
+
+https://github.com/jaywalnut310/vits [paper](https://arxiv.org/abs/2106.06103)
+
+https://github.com/openai/whisper/ [paper](https://arxiv.org/abs/2212.04356)
+
+https://github.com/NVIDIA/BigVGAN [paper](https://arxiv.org/abs/2206.04658)
+
+https://github.com/mindslab-ai/univnet [paper](https://arxiv.org/abs/2106.07889)
+
+https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/tree/master/project/01-nsf
+
+https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS
+
+https://github.com/brentspell/hifi-gan-bwe
+
+https://github.com/mozilla/TTS
+
+https://github.com/bshall/soft-vc
+
+https://github.com/maxrmorrison/torchcrepe
+
+https://github.com/MoonInTheRiver/DiffSinger
+
+https://github.com/OlaWod/FreeVC [paper](https://arxiv.org/abs/2210.15418)
+
+https://github.com/yl4579/HiFTNet [paper](https://arxiv.org/abs/2309.09493)
+
+[Autoregressive neural f0 model for statistical parametric speech synthesis](https://web.archive.org/web/20210718024752id_/https://ieeexplore.ieee.org/ielx7/6570655/8356719/08341752.pdf)
+
+[One-shot Voice Conversion by Separating Speaker and Content Representations with Instance Normalization](https://arxiv.org/abs/1904.05742)
+
+[SNAC : Speaker-normalized Affine Coupling Layer in Flow-based Architecture for Zero-Shot Multi-Speaker Text-to-Speech](https://github.com/hcy71o/SNAC)
+
+[Adapter-Based Extension of Multi-Speaker Text-to-Speech Model for New Speakers](https://arxiv.org/abs/2211.00585)
+
+[AdaSpeech: Adaptive Text to Speech for Custom Voice](https://arxiv.org/pdf/2103.00993.pdf)
+
+[AdaVITS: Tiny VITS for Low Computing Resource Speaker Adaptation](https://arxiv.org/pdf/2206.00208.pdf)
+
+[Cross-Speaker Prosody Transfer on Any Text for Expressive Speech Synthesis](https://github.com/ubisoft/ubisoft-laforge-daft-exprt)
+
+[Learn to Sing by Listening: Building Controllable Virtual Singer by Unsupervised Learning from Voice Recordings](https://arxiv.org/abs/2305.05401)
+
+[Adversarial Speaker Disentanglement Using Unannotated External Data for Self-supervised Representation Based Voice Conversion](https://arxiv.org/pdf/2305.09167.pdf)
+
+[Multilingual Speech Synthesis and Cross-Language Voice Cloning: GRL](https://arxiv.org/abs/1907.04448)
+
+[RoFormer: Enhanced Transformer with rotary position embedding](https://arxiv.org/abs/2104.09864)
+
+## Method of Preventing Timbre Leakage Based on Data Perturbation
+
+https://github.com/auspicious3000/contentvec/blob/main/contentvec/data/audio/audio_utils_1.py
+
+https://github.com/revsic/torch-nansy/blob/main/utils/augment/praat.py
+
+https://github.com/revsic/torch-nansy/blob/main/utils/augment/peq.py
+
+https://github.com/biggytruck/SpeechSplit2/blob/main/utils.py
+
+https://github.com/OlaWod/FreeVC/blob/main/preprocess_sr.py
+
+## Contributors
+
+
+
+
+
+## Thanks to
+
+https://github.com/Francis-Komizu/Sovits
+
+## Relevant Projects
+- [LoRA-SVC](https://github.com/PlayVoice/lora-svc): decoder only svc
+- [Grad-SVC](https://github.com/PlayVoice/Grad-SVC): diffusion based svc
+
+## Original evidence
+2022.04.12 https://mp.weixin.qq.com/s/autNBYCsG4_SvWt2-Ll_zA
+
+2022.04.22 https://github.com/PlayVoice/VI-SVS
+
+2022.07.26 https://mp.weixin.qq.com/s/qC4TJy-4EVdbpvK2cQb1TA
+
+2022.09.08 https://github.com/PlayVoice/VI-SVC
+
+## Copied by svc-develop-team/so-vits-svc
+
diff --git a/README_ZH.md b/README_ZH.md
new file mode 100644
index 0000000000000000000000000000000000000000..cf9571210c7a3d319a65dae3c4a449f6473a3bd9
--- /dev/null
+++ b/README_ZH.md
@@ -0,0 +1,418 @@
+
+# Variational Inference with adversarial learning for end-to-end Singing Voice Conversion based on VITS
+
+[Hugging Face Spaces](https://huggingface.co/spaces/maxmax20160403/sovits5.0)
+
+
+### 本项目使用简洁明了的代码结构,用于深度学习技术的研究
+### 基于学习的目的,本项目并不追求效果极限、而更多的为学生笔记本考虑,采用了低配置参数、最终预训练模型为202M(包括生成器和判别器,且为float32模型),远远小于同类项目模型大小
+### 如果你寻找的是直接可用的项目,本项目并不适合你
+
+- 本项目的目标群体是:深度学习初学者,具备Python和PyTorch的基本操作是使用本项目的前置条件;
+- 本项目旨在帮助深度学习初学者,摆脱枯燥的纯理论学习,通过与实践结合,熟练掌握深度学习基本知识;
+- 本项目不支持实时变声;(支持需要换掉whisper)
+- 本项目不会开发用于其他用途的一键包
+### 代码详解课程
+- 1-整体框架 https://www.bilibili.com/video/BV1Tj411e7pQ
+- 2-数据准备和预处理 https://www.bilibili.com/video/BV1uj411v7zW
+- 3-先验后验编码器 https://www.bilibili.com/video/BV1Be411Q7r5
+- 4-decoder部分 https://www.bilibili.com/video/BV19u4y1b73U
+- 5-蛇形激活函数 https://www.bilibili.com/video/BV1HN4y1D7AR
+- 6-Flow部分 https://www.bilibili.com/video/BV1ju411F7Fs
+- 7-训练及损失函数部分 https://www.bilibili.com/video/BV1qw411W73B
+- 8-训练推理以及基频矫正 https://www.bilibili.com/video/BV1eb4y1u7ER
+
+
+
+- 【无 泄漏】支持多发音人
+
+- 【捏 音色】创造独有发音人
+
+- 【带 伴奏】也能进行转换,轻度伴奏
+
+- 【用 Excel】进行原始调教,纯手工
+
+https://github.com/PlayVoice/so-vits-svc-5.0/assets/16432329/63858332-cc0d-40e1-a216-6fe8bf638f7c
+
+Powered by [@ShadowVap](https://space.bilibili.com/491283091)
+
+## 模型特点:
+
+| Feature | From | Status | Function |
+| :--- | :--- | :--- | :--- |
+| whisper | OpenAI | ✅ | 强大的抗噪能力 |
+| bigvgan | NVIDIA | ✅ | 抗锯齿与蛇形激活,共振峰更清晰,提升音质明显 |
+| natural speech | Microsoft | ✅ | 减少发音错误 |
+| neural source-filter | NII | ✅ | 解决断音问题 |
+| speaker encoder | Google | ✅ | 音色编码与聚类 |
+| GRL for speaker | Ubisoft |✅ | 对抗去音色 |
+| SNAC | Samsung | ✅ | VITS 一句话克隆 |
+| SCLN | Microsoft | ✅ | 改善克隆 |
+| PPG perturbation | 本项目 | ✅ | 提升抗噪性和去音色 |
+| HuBERT perturbation | 本项目 | ✅ | 提升抗噪性和去音色 |
+| VAE perturbation | 本项目 | ✅ | 提升音质 |
+| Mix encoder | 本项目 | ✅ | 提升转换稳定性 |
+| USP 推理 | 本项目 | ✅ | 提升转换稳定性 |
+
+**USP : 即使unvoice和silence在推理的时候,也有Pitch,这个Pitch平滑连接voice段**
+
+
+## 为什么要mix
+
+
+
+## 安装环境
+
+1. 安装[PyTorch](https://pytorch.org/get-started/locally/)
+
+2. 安装项目依赖
+ ```
+ pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
+ ```
+ **注意:不能额外安装whisper,否则会和代码内置whisper冲突**
+
+3. 下载[音色编码器](https://drive.google.com/drive/folders/15oeBYf6Qn1edONkVLXe82MzdIi3O_9m3), 把`best_model.pth.tar`放到`speaker_pretrain/`里面 (**不要解压**)
+
+4. 下载[whisper-large-v2模型](https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt),把`large-v2.pt`放到`whisper_pretrain/`里面
+
+5. 下载[hubert_soft模型](https://github.com/bshall/hubert/releases/tag/v0.1),把`hubert-soft-0d54a1f4.pt`放到`hubert_pretrain/`里面
+
+6. 下载音高提取模型[crepe full](https://github.com/maxrmorrison/torchcrepe/tree/master/torchcrepe/assets),把`full.pth`放到`crepe/assets`里面
+
+ **注意:full.pth为84.9M,请确认文件大小无误**
+
+7. 下载[sovits5.0.pretrain.pth](https://github.com/PlayVoice/so-vits-svc-5.0/releases/tag/5.0/), 把它放到`vits_pretrain/`里面,推理测试
+
+ > python svc_inference.py --config configs/base.yaml --model ./vits_pretrain/sovits5.0.pretrain.pth --spk ./configs/singers/singer0001.npy --wave test.wav
+
+## 数据集准备
+1. 人声分离,如果数据集没有BGM直接跳过此步骤(推荐使用[UVR](https://github.com/Anjok07/ultimatevocalremovergui)中的3_HP-Vocal-UVR模型或者htdemucs_ft模型抠出数据集中的人声)
+2. 用[slicer](https://github.com/flutydeer/audio-slicer)剪切音频,whisper要求为小于30秒(建议丢弃不足2秒的音频,短音频大多没有音素,有可能会影响训练效果)
+3. 手动筛选经过第1步和第2步处理过的音频,裁剪或者丢弃杂音明显的音频,如果数据集没有BGM直接跳过此步骤
+4. 用Adobe Audition进行响度平衡处理
+5. 按下面文件结构,将数据集放入dataset_raw目录
+```shell
+dataset_raw
+├───speaker0
+│ ├───000001.wav
+│ ├───...
+│ └───000xxx.wav
+└───speaker1
+ ├───000001.wav
+ ├───...
+ └───000xxx.wav
+```
+
+## 数据预处理
+
+```shell
+python svc_preprocessing.py -t 2
+```
+-t:指定线程数,必须是正整数且不得超过CPU总核心数,一般写2就可以了
+
+预处理完成后文件夹结构如下面所示
+```shell
+data_svc/
+└── waves-16k
+│ └── speaker0
+│ │ ├── 000001.wav
+│ │ └── 000xxx.wav
+│ └── speaker1
+│ ├── 000001.wav
+│ └── 000xxx.wav
+└── waves-32k
+│ └── speaker0
+│ │ ├── 000001.wav
+│ │ └── 000xxx.wav
+│ └── speaker1
+│ ├── 000001.wav
+│ └── 000xxx.wav
+└── pitch
+│ └── speaker0
+│ │ ├── 000001.pit.npy
+│ │ └── 000xxx.pit.npy
+│ └── speaker1
+│ ├── 000001.pit.npy
+│ └── 000xxx.pit.npy
+└── hubert
+│ └── speaker0
+│ │ ├── 000001.vec.npy
+│ │ └── 000xxx.vec.npy
+│ └── speaker1
+│ ├── 000001.vec.npy
+│ └── 000xxx.vec.npy
+└── whisper
+│ └── speaker0
+│ │ ├── 000001.ppg.npy
+│ │ └── 000xxx.ppg.npy
+│ └── speaker1
+│ ├── 000001.ppg.npy
+│ └── 000xxx.ppg.npy
+└── speaker
+│ └── speaker0
+│ │ ├── 000001.spk.npy
+│ │ └── 000xxx.spk.npy
+│ └── speaker1
+│ ├── 000001.spk.npy
+│ └── 000xxx.spk.npy
+└── singer
+ ├── speaker0.spk.npy
+ └── speaker1.spk.npy
+```
+
+如果您有编程基础,推荐逐步完成数据处理,这也有利于学习内部工作原理
+
+- 1, 重采样
+
+ 生成采样率16000Hz音频, 存储路径为:./data_svc/waves-16k
+
+ > python prepare/preprocess_a.py -w ./dataset_raw -o ./data_svc/waves-16k -s 16000
+
+ 生成采样率32000Hz音频, 存储路径为:./data_svc/waves-32k
+
+ > python prepare/preprocess_a.py -w ./dataset_raw -o ./data_svc/waves-32k -s 32000
+
+- 2, 使用16K音频,提取音高
+
+ > python prepare/preprocess_crepe.py -w data_svc/waves-16k/ -p data_svc/pitch
+
+- 3, 使用16k音频,提取内容编码(whisper ppg)
+ > python prepare/preprocess_ppg.py -w data_svc/waves-16k/ -p data_svc/whisper
+
+- 4, 使用16k音频,提取内容编码(hubert)
+ > python prepare/preprocess_hubert.py -w data_svc/waves-16k/ -v data_svc/hubert
+
+- 5, 使用16k音频,提取音色编码
+ > python prepare/preprocess_speaker.py data_svc/waves-16k/ data_svc/speaker
+
+- 6, 提取音色编码均值;用于推理,也可作为发音人统一音色用于生成训练索引(数据音色变化不大的情况下)
+ > python prepare/preprocess_speaker_ave.py data_svc/speaker/ data_svc/singer
+
+- 7, 使用32k音频,提取线性谱
+ > python prepare/preprocess_spec.py -w data_svc/waves-32k/ -s data_svc/specs
+
+- 8, 使用32k音频,生成训练索引
+ > python prepare/preprocess_train.py
+
+- 9, 训练文件调试
+ > python prepare/preprocess_zzz.py
+
+## 训练
+0. 参数调整
+ 如果基于预训练模型微调,需要下载预训练模型[sovits5.0.pretrain.pth](https://github.com/PlayVoice/so-vits-svc-5.0/releases/tag/5.0)并且放在项目根目录下面
+ 并且修改`configs/base.yaml`的参数`pretrain: "./vits_pretrain/sovits5.0.pretrain.pth"`,并适当调小学习率(建议从5e-5开始尝试)
+ **learning_rate & batch_size & accum_step 为三个紧密相关的参数,需要仔细调节**
+ **batch_size 乘以 accum_step 通常等于 16 或 32,对于低显存GPU,可以尝试 batch_size = 4,accum_step = 4**
+
+1. 开始训练
+ ```
+ python svc_trainer.py -c configs/base.yaml -n sovits5.0
+ ```
+2. 恢复训练
+ ```
+ python svc_trainer.py -c configs/base.yaml -n sovits5.0 -p chkpt/sovits5.0/sovits5.0_***.pt
+ ```
+3. 训练日志可视化
+ ```
+ tensorboard --logdir logs/
+ ```
+
+
+
+
+
+## 推理
+1. 导出推理模型:文本编码器,Flow网络,Decoder网络;判别器和后验编码器等只在训练中使用
+ ```
+ python svc_export.py --config configs/base.yaml --checkpoint_path chkpt/sovits5.0/***.pt
+ ```
+2. 推理
+- 如果不想手动调整f0,只需要最终的推理结果,运行下面的命令即可
+ ```
+ python svc_inference.py --config configs/base.yaml --model sovits5.0.pth --spk ./data_svc/singer/修改成对应的名称.npy --wave test.wav --shift 0
+ ```
+- 如果需要手动调整f0,依据下面的流程操作
+
+ - 使用whisper提取内容编码,生成test.ppg.npy
+ ```
+ python whisper/inference.py -w test.wav -p test.ppg.npy
+ ```
+
+ - 使用hubert提取内容编码,生成test.vec.npy
+ ```
+ python hubert/inference.py -w test.wav -v test.vec.npy
+ ```
+
+ - 提取csv文本格式F0参数,用Excel打开csv文件,对照Audition或者SonicVisualiser手动修改错误的F0
+ ```
+ python pitch/inference.py -w test.wav -p test.csv
+ ```
+ - 最终推理
+ ```
+ python svc_inference.py --config configs/base.yaml --model sovits5.0.pth --spk ./data_svc/singer/修改成对应的名称.npy --wave test.wav --ppg test.ppg.npy --vec test.vec.npy --pit test.csv --shift 0
+ ```
+
+3. 一些注意点
+ 当指定--ppg后,多次推理同一个音频时,可以避免重复提取音频内容编码;没有指定,也会自动提取
+
+ 当指定--vec后,多次推理同一个音频时,可以避免重复提取音频内容编码;没有指定,也会自动提取
+
+ 当指定--pit后,可以加载手工调教的F0参数;没有指定,也会自动提取
+
+ 生成文件在当前目录svc_out.wav
+
+ | args | --config | --model | --spk | --wave | --ppg | --vec | --pit | --shift |
+ | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+ | name | 配置文件 | 模型文件 | 音色文件 | 音频文件 | ppg内容 | hubert内容 | 音高内容 | 升降调 |
+
+4. 去噪后处理
+```
+python svc_inference_post.py --ref test.wav --svc svc_out.wav --out svc_out_post.wav
+```
+
+## 两种训练模式
+- 分散模式:训练索引中,音色文件使用音频音色
+- 统一模式:训练索引中,音色文件使用发音人音色
+
+**问题:哪种情况下,哪个模式更好**
+
+## 模型融合
+```
+python svc_merge.py --model1 模型1.pt --model2 模型2.pt --rate 模型1占比(0~1)
+```
+对不同epoch的模型进行融合,可以获得比较平均的性能、削弱过拟合
+
+例如:python svc_merge.py --model1 chkpt\sovits5.0\sovits5.0_1045.pt --model2 chkpt\sovits5.0\sovits5.0_1050.pt --rate 0.4
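+
+作为参考,模型融合大致相当于对两个检查点中的同名张量做线性插值。下面只是一个示意草图(假设检查点可以直接逐张量插值,真实的 svc_merge.py 内部结构可能不同):
+```python
+import torch
+
+def merge_checkpoints(path1, path2, rate, out="sovits5.0_merge.pt"):
+    # rate 为模型1所占比例 (0~1)
+    ckpt1 = torch.load(path1, map_location="cpu")
+    ckpt2 = torch.load(path2, map_location="cpu")
+    merged = {
+        k: rate * v + (1.0 - rate) * ckpt2[k]
+        if torch.is_tensor(v) and torch.is_floating_point(v) else v
+        for k, v in ckpt1.items()
+    }
+    torch.save(merged, out)
+
+merge_checkpoints("chkpt/sovits5.0/sovits5.0_1045.pt", "chkpt/sovits5.0/sovits5.0_1050.pt", 0.4)
+```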
+
+## 捏音色
+纯属巧合的取名:average -> ave -> eva,夏娃代表者孕育和繁衍
+```
+python svc_eva.py
+```
+```python
+eva_conf = {
+ './configs/singers/singer0022.npy': 0,
+ './configs/singers/singer0030.npy': 0,
+ './configs/singers/singer0047.npy': 0.5,
+ './configs/singers/singer0051.npy': 0.5,
+}
+```
+
+生成的音色文件为:eva.spk.npy
+
+## 数据集
+
+| Name | URL |
+| :--- | :--- |
+|KiSing |http://shijt.site/index.php/2021/05/16/kising-the-first-open-source-mandarin-singing-voice-synthesis-corpus/|
+|PopCS |https://github.com/MoonInTheRiver/DiffSinger/blob/master/resources/apply_form.md|
+|opencpop |https://wenet.org.cn/opencpop/download/|
+|Multi-Singer |https://github.com/Multi-Singer/Multi-Singer.github.io|
+|M4Singer |https://github.com/M4Singer/M4Singer/blob/master/apply_form.md|
+|CSD |https://zenodo.org/record/4785016#.YxqrTbaOMU4|
+|KSS |https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset|
+|JVS MuSic |https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_music|
+|PJS |https://sites.google.com/site/shinnosuketakamichi/research-topics/pjs_corpus|
+|JUST Song |https://sites.google.com/site/shinnosuketakamichi/publication/jsut-song|
+|MUSDB18 |https://sigsep.github.io/datasets/musdb.html#musdb18-compressed-stems|
+|DSD100 |https://sigsep.github.io/datasets/dsd100.html|
+|Aishell-3 |http://www.aishelltech.com/aishell_3|
+|VCTK |https://datashare.ed.ac.uk/handle/10283/2651|
+|Korean Songs |http://urisori.co.kr/urisori-en/doku.php/|
+
+## 代码来源和参考文献
+
+https://github.com/facebookresearch/speech-resynthesis [paper](https://arxiv.org/abs/2104.00355)
+
+https://github.com/jaywalnut310/vits [paper](https://arxiv.org/abs/2106.06103)
+
+https://github.com/openai/whisper/ [paper](https://arxiv.org/abs/2212.04356)
+
+https://github.com/NVIDIA/BigVGAN [paper](https://arxiv.org/abs/2206.04658)
+
+https://github.com/mindslab-ai/univnet [paper](https://arxiv.org/abs/2106.07889)
+
+https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/tree/master/project/01-nsf
+
+https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS
+
+https://github.com/brentspell/hifi-gan-bwe
+
+https://github.com/mozilla/TTS
+
+https://github.com/bshall/soft-vc
+
+https://github.com/maxrmorrison/torchcrepe
+
+https://github.com/MoonInTheRiver/DiffSinger
+
+https://github.com/OlaWod/FreeVC [paper](https://arxiv.org/abs/2210.15418)
+
+https://github.com/yl4579/HiFTNet [paper](https://arxiv.org/abs/2309.09493)
+
+[One-shot Voice Conversion by Separating Speaker and Content Representations with Instance Normalization](https://arxiv.org/abs/1904.05742)
+
+[SNAC : Speaker-normalized Affine Coupling Layer in Flow-based Architecture for Zero-Shot Multi-Speaker Text-to-Speech](https://github.com/hcy71o/SNAC)
+
+[Adapter-Based Extension of Multi-Speaker Text-to-Speech Model for New Speakers](https://arxiv.org/abs/2211.00585)
+
+[AdaSpeech: Adaptive Text to Speech for Custom Voice](https://arxiv.org/pdf/2103.00993.pdf)
+
+[AdaVITS: Tiny VITS for Low Computing Resource Speaker Adaptation](https://arxiv.org/pdf/2206.00208.pdf)
+
+[Cross-Speaker Prosody Transfer on Any Text for Expressive Speech Synthesis](https://github.com/ubisoft/ubisoft-laforge-daft-exprt)
+
+[Learn to Sing by Listening: Building Controllable Virtual Singer by Unsupervised Learning from Voice Recordings](https://arxiv.org/abs/2305.05401)
+
+[Adversarial Speaker Disentanglement Using Unannotated External Data for Self-supervised Representation Based Voice Conversion](https://arxiv.org/pdf/2305.09167.pdf)
+
+[Multilingual Speech Synthesis and Cross-Language Voice Cloning: GRL](https://arxiv.org/abs/1907.04448)
+
+[RoFormer: Enhanced Transformer with rotary position embedding](https://arxiv.org/abs/2104.09864)
+
+## 基于数据扰动防止音色泄露的方法
+
+https://github.com/auspicious3000/contentvec/blob/main/contentvec/data/audio/audio_utils_1.py
+
+https://github.com/revsic/torch-nansy/blob/main/utils/augment/praat.py
+
+https://github.com/revsic/torch-nansy/blob/main/utils/augment/peq.py
+
+https://github.com/biggytruck/SpeechSplit2/blob/main/utils.py
+
+https://github.com/OlaWod/FreeVC/blob/main/preprocess_sr.py
+
+## 贡献者
+
+
+
+
+
+## 特别感谢
+
+https://github.com/Francis-Komizu/Sovits
+
+## 原创过程
+2022.04.12 https://mp.weixin.qq.com/s/autNBYCsG4_SvWt2-Ll_zA
+
+2022.04.22 https://github.com/PlayVoice/VI-SVS
+
+2022.07.26 https://mp.weixin.qq.com/s/qC4TJy-4EVdbpvK2cQb1TA
+
+2022.09.08 https://github.com/PlayVoice/VI-SVC
+
+## 被这个项目拷贝:svc-develop-team/so-vits-svc
+
+
+
+
+
+
+## Rcell对拷贝的真实回应
+
+
diff --git a/__pycache__/svc_inference.cpython-310.pyc b/__pycache__/svc_inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a79eaa7c24684227ac10306cb42310722ef8672a
Binary files /dev/null and b/__pycache__/svc_inference.cpython-310.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e8eb3051352a1e2f93eaebcdc8681b5e6799fb
--- /dev/null
+++ b/app.py
@@ -0,0 +1,444 @@
+import os
+import subprocess
+import yaml
+import sys
+import webbrowser
+import gradio as gr
+from ruamel.yaml import YAML
+import shutil
+import soundfile
+import shlex
+import locale
+
+class WebUI:
+ def __init__(self):
+ self.train_config_path = 'configs/train.yaml'
+ self.info = Info()
+ self.names = []
+ self.names2 = []
+ self.voice_names = []
+ self.base_config_path = 'configs/base.yaml'
+ if not os.path.exists(self.train_config_path):
+ shutil.copyfile(self.base_config_path, self.train_config_path)
+ print(i18n("初始化成功"))
+ else:
+ print(i18n("就绪"))
+ self.main_ui()
+
+ def main_ui(self):
+ with gr.Blocks(theme=gr.themes.Base(primary_hue=gr.themes.colors.green)) as ui:
+
+ gr.Markdown('# so-vits-svc5.0 WebUI')
+
+ with gr.Tab(i18n("预处理-训练")):
+
+ with gr.Accordion(i18n('训练说明'), open=False):
+
+ gr.Markdown(self.info.train)
+
+ gr.Markdown(i18n('### 预处理参数设置'))
+
+ with gr.Row():
+
+ self.model_name = gr.Textbox(value='sovits5.0', label='model', info=i18n('模型名称'), interactive=True) #建议设置为不可修改
+
+ self.f0_extractor = gr.Textbox(value='crepe', label='f0_extractor', info=i18n('f0提取器'), interactive=False)
+
+ self.thread_count = gr.Slider(minimum=1, maximum=os.cpu_count(), step=1, value=2, label='thread_count', info=i18n('预处理线程数'), interactive=True)
+
+ gr.Markdown(i18n('### 训练参数设置'))
+
+ with gr.Row():
+
+ self.learning_rate = gr.Number(value=5e-5, label='learning_rate', info=i18n('学习率'), interactive=True)
+
+ self.batch_size = gr.Slider(minimum=1, maximum=50, step=1, value=6, label='batch_size', info=i18n('批大小'), interactive=True)
+
+ with gr.Row():
+
+ self.info_interval = gr.Number(value=50, label='info_interval', info=i18n('训练日志记录间隔(step)'), interactive=True)
+
+ self.eval_interval = gr.Number(value=1, label='eval_interval', info=i18n('验证集验证间隔(epoch)'), interactive=True)
+
+ self.save_interval = gr.Number(value=5, label='save_interval', info=i18n('检查点保存间隔(epoch)'), interactive=True)
+
+ self.keep_ckpts = gr.Number(value=0, label='keep_ckpts', info=i18n('保留最新的检查点文件(0保存全部)'),interactive=True)
+
+ with gr.Row():
+
+ self.slow_model = gr.Checkbox(label=i18n("是否添加底模"), value=True, interactive=True)
+
+ gr.Markdown(i18n('### 开始训练'))
+
+ with gr.Row():
+
+ self.bt_open_dataset_folder = gr.Button(value=i18n('打开数据集文件夹'))
+
+ self.bt_onekey_train = gr.Button(i18n('一键训练'), variant="primary")
+
+ self.bt_tb = gr.Button(i18n('启动Tensorboard'), variant="primary")
+
+ gr.Markdown(i18n('### 恢复训练'))
+
+ with gr.Row():
+
+ self.resume_model = gr.Dropdown(choices=sorted(self.names), label='Resume training progress from checkpoints', info=i18n('从检查点恢复训练进度'), interactive=True)
+
+ with gr.Column():
+
+ self.bt_refersh = gr.Button(i18n('刷新'))
+
+ self.bt_resume_train = gr.Button(i18n('恢复训练'), variant="primary")
+
+ with gr.Tab(i18n("推理")):
+
+ with gr.Accordion(i18n('推理说明'), open=False):
+
+ gr.Markdown(self.info.inference)
+
+ gr.Markdown(i18n('### 推理参数设置'))
+
+ with gr.Row():
+
+ with gr.Column():
+
+ self.keychange = gr.Slider(-24, 24, value=0, step=1, label=i18n('变调'))
+
+ self.file_list = gr.Markdown(value="", label=i18n("文件列表"))
+
+ with gr.Row():
+
+ self.resume_model2 = gr.Dropdown(choices=sorted(self.names2), label='Select the model you want to export',
+ info=i18n('选择要导出的模型'), interactive=True)
+ with gr.Column():
+
+ self.bt_refersh2 = gr.Button(value=i18n('刷新模型和音色'))
+
+
+ self.bt_out_model = gr.Button(value=i18n('导出模型'), variant="primary")
+
+ with gr.Row():
+
+ self.resume_voice = gr.Dropdown(choices=sorted(self.voice_names), label='Select the sound file',
+ info=i18n('选择音色文件'), interactive=True)
+
+ with gr.Row():
+
+ self.input_wav = gr.Audio(type='filepath', label=i18n('选择待转换音频'), source='upload')
+
+ with gr.Row():
+
+ self.bt_infer = gr.Button(value=i18n('开始转换'), variant="primary")
+
+ with gr.Row():
+
+ self.output_wav = gr.Audio(label=i18n('输出音频'), interactive=False)
+
+ self.bt_open_dataset_folder.click(fn=self.openfolder)
+ self.bt_onekey_train.click(fn=self.onekey_training,inputs=[self.model_name, self.thread_count,self.learning_rate,self.batch_size, self.info_interval, self.eval_interval,self.save_interval, self.keep_ckpts, self.slow_model])
+ self.bt_out_model.click(fn=self.out_model, inputs=[self.model_name, self.resume_model2])
+ self.bt_tb.click(fn=self.tensorboard)
+ self.bt_refersh.click(fn=self.refresh_model, inputs=[self.model_name], outputs=[self.resume_model])
+ self.bt_resume_train.click(fn=self.resume_train, inputs=[self.model_name, self.resume_model, self.learning_rate,self.batch_size, self.info_interval, self.eval_interval,self.save_interval, self.keep_ckpts, self.slow_model])
+ self.bt_infer.click(fn=self.inference, inputs=[self.input_wav, self.resume_voice, self.keychange], outputs=[self.output_wav])
+ self.bt_refersh2.click(fn=self.refresh_model_and_voice, inputs=[self.model_name],outputs=[self.resume_model2, self.resume_voice])
+
+ ui.launch(inbrowser=True, server_port=2333, share=True)
+
+ def openfolder(self):
+
+ try:
+ if sys.platform.startswith('win'):
+ os.startfile('dataset_raw')
+ elif sys.platform.startswith('linux'):
+ subprocess.call(['xdg-open', 'dataset_raw'])
+ elif sys.platform.startswith('darwin'):
+ subprocess.call(['open', 'dataset_raw'])
+ else:
+ print(i18n('打开文件夹失败!'))
+ except BaseException:
+ print(i18n('打开文件夹失败!'))
+
+ def preprocessing(self, thread_count):
+ print(i18n('开始预处理'))
+ train_process = subprocess.Popen('python -u svc_preprocessing.py -t ' + str(thread_count), stdout=subprocess.PIPE)
+ while train_process.poll() is None:
+ output = train_process.stdout.readline().decode('utf-8')
+ print(output, end='')
+
+ def create_config(self, model_name, learning_rate, batch_size, info_interval, eval_interval, save_interval,
+ keep_ckpts, slow_model):
+ yaml = YAML()
+ yaml.preserve_quotes = True
+ yaml.width = 1024
+ with open("configs/train.yaml", "r") as f:
+ config = yaml.load(f)
+ config['train']['model'] = model_name
+ config['train']['learning_rate'] = learning_rate
+ config['train']['batch_size'] = batch_size
+ config["log"]["info_interval"] = int(info_interval)
+ config["log"]["eval_interval"] = int(eval_interval)
+ config["log"]["save_interval"] = int(save_interval)
+ config["log"]["keep_ckpts"] = int(keep_ckpts)
+ if slow_model:
+            config["train"]["pretrain"] = "vits_pretrain/sovits5.0.pretrain.pth"
+ else:
+ config["train"]["pretrain"] = ""
+ with open("configs/train.yaml", "w") as f:
+ yaml.dump(config, f)
+ return f"{config['log']}"
+
+ def training(self, model_name):
+ print(i18n('开始训练'))
+ train_process = subprocess.Popen('python -u svc_trainer.py -c ' + self.train_config_path + ' -n ' + str(model_name), stdout=subprocess.PIPE, creationflags=subprocess.CREATE_NEW_CONSOLE)
+ while train_process.poll() is None:
+ output = train_process.stdout.readline().decode('utf-8')
+ print(output, end='')
+
+ def onekey_training(self, model_name, thread_count, learning_rate, batch_size, info_interval, eval_interval, save_interval, keep_ckpts, slow_model):
+ print(self, model_name, thread_count, learning_rate, batch_size, info_interval, eval_interval,
+ save_interval, keep_ckpts)
+ self.create_config(model_name, learning_rate, batch_size, info_interval, eval_interval, save_interval, keep_ckpts, slow_model)
+ self.preprocessing(thread_count)
+ self.training(model_name)
+
+ def out_model(self, model_name, resume_model2):
+ print(i18n('开始导出模型'))
+ try:
+ subprocess.Popen('python -u svc_export.py -c {} -p "chkpt/{}/{}"'.format(self.train_config_path, model_name, resume_model2),stdout=subprocess.PIPE)
+ print(i18n('导出模型成功'))
+ except Exception as e:
+ print(i18n("出现错误:"), e)
+
+
+ def tensorboard(self):
+ if sys.platform.startswith('win'):
+ tb_process = subprocess.Popen('tensorboard --logdir=logs --port=6006', stdout=subprocess.PIPE)
+ webbrowser.open("http://localhost:6006")
+ else:
+ p1 = subprocess.Popen(["ps", "-ef"], stdout=subprocess.PIPE) #ps -ef | grep tensorboard | awk '{print $2}' | xargs kill -9
+ p2 = subprocess.Popen(["grep", "tensorboard"], stdin=p1.stdout, stdout=subprocess.PIPE)
+ p3 = subprocess.Popen(["awk", "{print $2}"], stdin=p2.stdout, stdout=subprocess.PIPE)
+ p4 = subprocess.Popen(["xargs", "kill", "-9"], stdin=p3.stdout)
+ p1.stdout.close()
+ p2.stdout.close()
+ p3.stdout.close()
+ p4.communicate()
+ tb_process = subprocess.Popen('tensorboard --logdir=logs --port=6007', stdout=subprocess.PIPE) # AutoDL端口设置为6007
+ while tb_process.poll() is None:
+ output = tb_process.stdout.readline().decode('utf-8')
+ print(output)
+
+ def refresh_model(self, model_name):
+ self.script_dir = os.path.dirname(os.path.abspath(__file__))
+ self.model_root = os.path.join(self.script_dir, f"chkpt/{model_name}")
+ self.names = []
+ try:
+ for self.name in os.listdir(self.model_root):
+ if self.name.endswith(".pt"):
+ self.names.append(self.name)
+ return {"choices": sorted(self.names), "__type__": "update"}
+ except FileNotFoundError:
+ return {"label": i18n("缺少模型文件"), "__type__": "update"}
+
+ def refresh_model2(self, model_name):
+ self.script_dir = os.path.dirname(os.path.abspath(__file__))
+ self.model_root = os.path.join(self.script_dir, f"chkpt/{model_name}")
+ self.names2 = []
+ try:
+ for self.name in os.listdir(self.model_root):
+ if self.name.endswith(".pt"):
+ self.names2.append(self.name)
+ return {"choices": sorted(self.names2), "__type__": "update"}
+ except FileNotFoundError:
+ return {"label": i18n("缺少模型文件"), "__type__": "update"}
+
+ def refresh_voice(self):
+ self.script_dir = os.path.dirname(os.path.abspath(__file__))
+ self.model_root = os.path.join(self.script_dir, "data_svc/singer")
+ self.voice_names = []
+ try:
+ for self.name in os.listdir(self.model_root):
+ if self.name.endswith(".npy"):
+ self.voice_names.append(self.name)
+ return {"choices": sorted(self.voice_names), "__type__": "update"}
+ except FileNotFoundError:
+ return {"label": i18n("缺少文件"), "__type__": "update"}
+
+ def refresh_model_and_voice(self, model_name):
+ model_update = self.refresh_model2(model_name)
+ voice_update = self.refresh_voice()
+ return model_update, voice_update
+
+ def resume_train(self, model_name, resume_model ,learning_rate, batch_size, info_interval, eval_interval, save_interval, keep_ckpts, slow_model):
+ print(i18n('开始恢复训练'))
+ self.create_config(model_name, learning_rate, batch_size, info_interval, eval_interval, save_interval,keep_ckpts, slow_model)
+ train_process = subprocess.Popen('python -u svc_trainer.py -c {} -n {} -p "chkpt/{}/{}"'.format(self.train_config_path, model_name, model_name, resume_model), stdout=subprocess.PIPE, creationflags=subprocess.CREATE_NEW_CONSOLE)
+ while train_process.poll() is None:
+ output = train_process.stdout.readline().decode('utf-8')
+ print(output, end='')
+
+ def inference(self, input, resume_voice, keychange):
+ if os.path.exists("test.wav"):
+ os.remove("test.wav")
+ print(i18n("已清理残留文件"))
+ else:
+ print(i18n("无需清理残留文件"))
+ self.train_config_path = 'configs/train.yaml'
+ print(i18n('开始推理'))
+        shutil.copy(input, ".")
+        input_name = os.path.basename(input)
+        # convert a non-wav upload to wav before renaming it to test.wav
+        if not input_name.endswith(".wav"):
+            data, samplerate = soundfile.read(input_name)
+            wav_name = input_name.rsplit(".", 1)[0] + ".wav"
+            soundfile.write(wav_name, data, samplerate)
+            input_name = wav_name
+        os.rename(input_name, "test.wav")
+ train_config_path = shlex.quote(self.train_config_path)
+ keychange = shlex.quote(str(keychange))
+ cmd = ["python", "-u", "svc_inference.py", "--config", train_config_path, "--model", "sovits5.0.pth", "--spk",
+ f"data_svc/singer/{resume_voice}", "--wave", "test.wav", "--shift", keychange]
+ train_process = subprocess.run(cmd, shell=False, capture_output=True, text=True)
+ print(train_process.stdout)
+ print(train_process.stderr)
+ print(i18n("推理成功"))
+ return "svc_out.wav"
+
+class Info:
+ def __init__(self) -> None:
+ self.train = i18n('### 2023.7.11|[@OOPPEENN](https://github.com/OOPPEENN)第一次编写|[@thestmitsuk](https://github.com/thestmitsuki)二次补完')
+
+ self.inference = i18n('### 2023.7.11|[@OOPPEENN](https://github.com/OOPPEENN)第一次编写|[@thestmitsuk](https://github.com/thestmitsuki)二次补完')
+
+
+LANGUAGE_LIST = ['zh_CN', 'en_US']
+LANGUAGE_ALL = {
+ 'zh_CN': {
+ 'SUPER': 'END',
+ 'LANGUAGE': 'zh_CN',
+ '初始化成功': '初始化成功',
+ '就绪': '就绪',
+ '预处理-训练': '预处理-训练',
+ '训练说明': '训练说明',
+ '### 预处理参数设置': '### 预处理参数设置',
+ '模型名称': '模型名称',
+ 'f0提取器': 'f0提取器',
+ '预处理线程数': '预处理线程数',
+ '### 训练参数设置': '### 训练参数设置',
+ '学习率': '学习率',
+ '批大小': '批大小',
+ '训练日志记录间隔(step)': '训练日志记录间隔(step)',
+ '验证集验证间隔(epoch)': '验证集验证间隔(epoch)',
+ '检查点保存间隔(epoch)': '检查点保存间隔(epoch)',
+ '保留最新的检查点文件(0保存全部)': '保留最新的检查点文件(0保存全部)',
+ '是否添加底模': '是否添加底模',
+ '### 开始训练': '### 开始训练',
+ '打开数据集文件夹': '打开数据集文件夹',
+ '一键训练': '一键训练',
+ '启动Tensorboard': '启动Tensorboard',
+ '### 恢复训练': '### 恢复训练',
+ '从检查点恢复训练进度': '从检查点恢复训练进度',
+ '刷新': '刷新',
+ '恢复训练': '恢复训练',
+ '推理': '推理',
+ '推理说明': '推理说明',
+ '### 推理参数设置': '### 推理参数设置',
+ '变调': '变调',
+ '文件列表': '文件列表',
+ '选择要导出的模型': '选择要导出的模型',
+ '刷新模型和音色': '刷新模型和音色',
+ '导出模型': '导出模型',
+ '选择音色文件': '选择音色文件',
+ '选择待转换音频': '选择待转换音频',
+ '开始转换': '开始转换',
+ '输出音频': '输出音频',
+ '打开文件夹失败!': '打开文件夹失败!',
+ '开始预处理': '开始预处理',
+ '开始训练': '开始训练',
+ '开始导出模型': '开始导出模型',
+ '导出模型成功': '导出模型成功',
+ '出现错误:': '出现错误:',
+ '缺少模型文件': '缺少模型文件',
+ '缺少文件': '缺少文件',
+ '已清理残留文件': '已清理残留文件',
+ '无需清理残留文件': '无需清理残留文件',
+ '开始推理': '开始推理',
+ '推理成功': '推理成功',
+ '### 2023.7.11|[@OOPPEENN](https://github.com/OOPPEENN)第一次编写|[@thestmitsuk](https://github.com/thestmitsuki)二次补完': '### 2023.7.11|[@OOPPEENN](https://github.com/OOPPEENN)第一次编写|[@thestmitsuk](https://github.com/thestmitsuki)二次补完'
+ },
+ 'en_US': {
+ 'SUPER': 'zh_CN',
+ 'LANGUAGE': 'en_US',
+ '初始化成功': 'Initialization successful',
+ '就绪': 'Ready',
+ '预处理-训练': 'Preprocessing-Training',
+ '训练说明': 'Training instructions',
+ '### 预处理参数设置': '### Preprocessing parameter settings',
+ '模型名称': 'Model name',
+ 'f0提取器': 'f0 extractor',
+ '预处理线程数': 'Preprocessing thread number',
+ '### 训练参数设置': '### Training parameter settings',
+ '学习率': 'Learning rate',
+ '批大小': 'Batch size',
+ '训练日志记录间隔(step)': 'Training log recording interval (step)',
+ '验证集验证间隔(epoch)': 'Validation set validation interval (epoch)',
+ '检查点保存间隔(epoch)': 'Checkpoint save interval (epoch)',
+ '保留最新的检查点文件(0保存全部)': 'Keep the latest checkpoint file (0 save all)',
+ '是否添加底模': 'Whether to add the base model',
+ '### 开始训练': '### Start training',
+ '打开数据集文件夹': 'Open the dataset folder',
+ '一键训练': 'One-click training',
+ '启动Tensorboard': 'Start Tensorboard',
+ '### 恢复训练': '### Resume training',
+ '从检查点恢复训练进度': 'Restore training progress from checkpoint',
+ '刷新': 'Refresh',
+ '恢复训练': 'Resume training',
+ "推理": "Inference",
+ "推理说明": "Inference instructions",
+ "### 推理参数设置": "### Inference parameter settings",
+ "变调": "Pitch shift",
+ "文件列表": "File list",
+ "选择要导出的模型": "Select the model to export",
+ "刷新模型和音色": "Refresh model and timbre",
+ "导出模型": "Export model",
+ "选择音色文件": "Select timbre file",
+ "选择待转换音频": "Select audio to be converted",
+ "开始转换": "Start conversion",
+ "输出音频": "Output audio",
+ "打开文件夹失败!": "Failed to open folder!",
+ "开始预处理": "Start preprocessing",
+ "开始训练": "Start training",
+ "开始导出模型": "Start exporting model",
+ "导出模型成功": "Model exported successfully",
+ "出现错误:": "An error occurred:",
+ "缺少模型文件": "Missing model file",
+ '缺少文件': 'Missing file',
+ "已清理残留文件": "Residual files cleaned up",
+ "无需清理残留文件": "No need to clean up residual files",
+ "开始推理": "Start inference",
+ '### 2023.7.11|[@OOPPEENN](https://github.com/OOPPEENN)第一次编写|[@thestmitsuk](https://github.com/thestmitsuki)二次补完': '### 2023.7.11|[@OOPPEENN](https://github.com/OOPPEENN)first writing|[@thestmitsuk](https://github.com/thestmitsuki)second completion'
+ }
+}
+
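+# Minimal i18n helper: zh_CN is loaded as the base string table, then the SUPER chain
+# overlays the selected language's strings on top of it.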
+class I18nAuto:
+ def __init__(self, language=None):
+ self.language_list = LANGUAGE_LIST
+ self.language_all = LANGUAGE_ALL
+ self.language_map = {}
+ self.language = language or locale.getdefaultlocale()[0]
+ if self.language not in self.language_list:
+ self.language = 'zh_CN'
+ self.read_language(self.language_all['zh_CN'])
+ while self.language_all[self.language]['SUPER'] != 'END':
+ self.read_language(self.language_all[self.language])
+ self.language = self.language_all[self.language]['SUPER']
+
+ def read_language(self, lang_dict: dict):
+ self.language_map.update(lang_dict)
+
+ def __call__(self, key):
+ return self.language_map[key]
+
+if __name__ == "__main__":
+ i18n = I18nAuto()
+ webui = WebUI()
diff --git a/colab.ipynb b/colab.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..19943fa73562ea14b8775c9f56a92f0be663ab10
--- /dev/null
+++ b/colab.ipynb
@@ -0,0 +1,374 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SggegFslkbbK"
+ },
+ "source": [
+ "https://github.com/PlayVoice/so-vits-svc-5.0/\n",
+ "\n",
+ "↑原仓库\n",
+ "\n",
+ "*《colab保持连接的方法》*https://zhuanlan.zhihu.com/p/144629818\n",
+ "\n",
+ "预览版本,可使用预设模型进行推理"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "M1MdDryJP73G"
+ },
+ "source": [
+ "# **环境配置&必要文件下载**\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xfJWCr_EkO2i"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 看看抽了个啥卡~~基本都是T4~~\n",
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nMspj8t3knR6"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 克隆github仓库\n",
+ "!git clone https://github.com/PlayVoice/so-vits-svc-5.0/ -b bigvgan-mix-v2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Kj2j81K6kubj"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 安装依赖&下载必要文件\n",
+ "%cd /content/so-vits-svc-5.0\n",
+ "\n",
+ "!pip install -r requirements.txt\n",
+ "!pip install --upgrade pip setuptools numpy numba\n",
+ "\n",
+ "!wget -P hubert_pretrain/ https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt\n",
+ "!wget -P whisper_pretrain/ https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt\n",
+ "!wget -P speaker_pretrain/ https://github.com/PlayVoice/so-vits-svc-5.0/releases/download/dependency/best_model.pth.tar\n",
+ "!wget -P crepe/assets https://github.com/PlayVoice/so-vits-svc-5.0/releases/download/dependency/full.pth\n",
+ "!wget -P vits_pretrain https://github.com/PlayVoice/so-vits-svc-5.0/releases/download/5.0/sovits5.0.pretrain.pth"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "v9zHS9VXly9b"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 加载Google云端硬盘\n",
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hZ5KH8NgQ7os"
+ },
+ "source": [
+ "# 包含多说话人的推理预览"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2o6m3D0IsphU"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 提取内容编码\n",
+ "\n",
+ "#@markdown **将处理好的\" .wav \"输入源文件上传到云盘根目录,并修改以下选项**\n",
+ "\n",
+ "#@markdown **\" .wav \"文件【文件名】**\n",
+ "input = \"\\u30AE\\u30BF\\u30FC\\u3068\\u5B64\\u72EC\\u3068\\u84BC\\u3044\\u60D1\\u661F\" #@param {type:\"string\"}\n",
+ "input_path = \"/content/drive/MyDrive/\"\n",
+ "input_name = input_path + input\n",
+ "!PYTHONPATH=. python whisper/inference.py -w {input_name}.wav -p test.ppg.npy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "A7nvX5mRlwJ7"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 推理\n",
+ "\n",
+ "#@markdown **将处理好的\" .wav \"输入源文件上传到云盘根目录,并修改以下选项**\n",
+ "\n",
+ "#@markdown **\" .wav \"文件【文件名】**\n",
+ "input = \"\\u30AE\\u30BF\\u30FC\\u3068\\u5B64\\u72EC\\u3068\\u84BC\\u3044\\u60D1\\u661F\" #@param {type:\"string\"}\n",
+ "input_path = \"/content/drive/MyDrive/\"\n",
+ "input_name = input_path + input\n",
+ "#@markdown **指定说话人(0001~0056)(推荐0022、0030、0047、0051)**\n",
+ "speaker = \"0002\" #@param {type:\"string\"}\n",
+ "!PYTHONPATH=. python svc_inference.py --config configs/base.yaml --model vits_pretrain/sovits5.0.pretrain.pth --spk ./configs/singers/singer{speaker}.npy --wave {input_name}.wav --ppg test.ppg.npy"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "F8oerogXyd3u"
+ },
+ "source": [
+ "推理结果保存在根目录,文件名为svc_out.wav"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qKX17GElPuso"
+ },
+ "source": [
+ "# 训练"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "sVe0lEGWQBLU"
+ },
+ "source": [
+ "将音频剪裁为小于30秒的音频段,响度匹配并修改为单声道,预处理时会进行重采样所以对采样率无要求。(但是降低采样率的操作会降低你的数据质量)\n",
+ "\n",
+ "**使用Adobe Audition™的响度匹配功能可以一次性完成重采样修改声道和响度匹配。**\n",
+ "\n",
+ "之后将音频文件保存为以下文件结构:\n",
+ "```\n",
+ "dataset_raw\n",
+ "├───speaker0\n",
+ "│ ├───xxx1-xxx1.wav\n",
+ "│ ├───...\n",
+ "│ └───Lxx-0xx8.wav\n",
+ "└───speaker1\n",
+ " ├───xx2-0xxx2.wav\n",
+ " ├───...\n",
+ " └───xxx7-xxx007.wav\n",
+ "```\n",
+ "\n",
+ "打包为zip格式,命名为data.zip,上传到网盘根目录。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "vC8IthV8VYgy"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 从云盘获取数据集\n",
+ "!unzip -d /content/so-vits-svc-5.0/ /content/drive/MyDrive/data.zip #自行修改路径与文件名"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "J101PiFUSL1N"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 重采样\n",
+ "# 生成采样率16000Hz音频, 存储路径为:./data_svc/waves-16k\n",
+ "!python prepare/preprocess_a.py -w ./dataset_raw -o ./data_svc/waves-16k -s 16000\n",
+ "# 生成采样率32000Hz音频, 存储路径为:./data_svc/waves-32k\n",
+ "!python prepare/preprocess_a.py -w ./dataset_raw -o ./data_svc/waves-32k -s 32000"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ZpxeYJCBSbgf"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 提取f0\n",
+ "!python prepare/preprocess_f0.py -w data_svc/waves-16k/ -p data_svc/pitch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7VasDGhDSlP5"
+ },
+ "outputs": [],
+ "source": [
+    "#@title 使用16k音频,提取内容编码(whisper ppg)\n",
+ "!PYTHONPATH=. python prepare/preprocess_ppg.py -w data_svc/waves-16k/ -p data_svc/whisper"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "#@title 使用16k音频,提取内容编码(hubert)\n",
+ "!PYTHONPATH=. python prepare/preprocess_hubert.py -w data_svc/waves-16k/ -v data_svc/hubert"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ovRqQUINSoII"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 提取音色特征\n",
+ "!PYTHONPATH=. python prepare/preprocess_speaker.py data_svc/waves-16k/ data_svc/speaker"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "s8Ba8Fd1bzzX"
+ },
+ "outputs": [],
+ "source": [
+    "# (fixes errors related to \".ipynb_checkpoints\")\n",
+    "!rm -rf $(find . -type d -name .ipynb_checkpoints)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ic9q599_b0Ae"
+ },
+ "outputs": [],
+ "source": [
+    "# (fixes errors related to \".ipynb_checkpoints\")\n",
+ "!rm -rf .ipynb_checkpoints\n",
+ "!find . -name \".ipynb_checkpoints\" -exec rm -rf {} \\;"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QamG3_B6o3vF"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Extract the average speaker embedding\n",
+ "!PYTHONPATH=. python prepare/preprocess_speaker_ave.py data_svc/speaker/ data_svc/singer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3wBmyQHvSs6K"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Extract spectrograms\n",
+ "!PYTHONPATH=. python prepare/preprocess_spec.py -w data_svc/waves-32k/ -s data_svc/specs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "tUcljCLbS5O3"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Generate the training index\n",
+ "!python prepare/preprocess_train.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "30fXnscFS7Wo"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Sanity-check the training files\n",
+ "!PYTHONPATH=. python prepare/preprocess_zzz.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "hacR8qDFVOWo"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Set up model backup\n",
+    "#@markdown **Whether to back up checkpoints to Drive. Colab sessions can be killed at any time, so backing up is recommended; by default checkpoints are saved to the Sovits5.0 folder in the Drive root**\n",
+ "Save_to_drive = True #@param {type:\"boolean\"}\n",
+ "if Save_to_drive:\n",
+ " !mkdir -p /content/so-vits-svc-5.0/chkpt/\n",
+ " !rm -rf /content/so-vits-svc-5.0/chkpt/\n",
+ " !mkdir -p /content/drive/MyDrive/Sovits5.0\n",
+    "  !ln -s /content/drive/MyDrive/Sovits5.0 /content/so-vits-svc-5.0/chkpt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5BIiKIAoU3Kd"
+ },
+ "outputs": [],
+ "source": [
+    "#@title Start training\n",
+ "%load_ext tensorboard\n",
+ "%tensorboard --logdir /content/so-vits-svc-5.0/logs/\n",
+ "\n",
+ "!PYTHONPATH=. python svc_trainer.py -c configs/base.yaml -n sovits5.0"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": []
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/configs/base.yaml b/configs/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dbf59f59eb764196566ed2479c8472ce3a1cdeb1
--- /dev/null
+++ b/configs/base.yaml
@@ -0,0 +1,72 @@
+train:
+ model: "sovits"
+ seed: 1234
+ epochs: 10000
+ learning_rate: 5e-5
+ betas: [0.8, 0.99]
+ lr_decay: 0.999875
+ eps: 1e-9
+ batch_size: 8
+ accum_step: 2
+ c_stft: 9
+ c_mel: 1.
+ c_kl: 0.2
+ port: 8001
+ pretrain: "./vits_pretrain/sovits5.0.pretrain.pth"
+#############################
+data:
+ training_files: "files/train.txt"
+ validation_files: "files/valid.txt"
+  segment_size: 8000 # WARNING: based on hop_length
+ max_wav_value: 32768.0
+ sampling_rate: 32000
+ filter_length: 1024
+ hop_length: 320
+ win_length: 1024
+ mel_channels: 100
+ mel_fmin: 50.0
+ mel_fmax: 16000.0
+#############################
+vits:
+ ppg_dim: 1280
+ vec_dim: 256
+ spk_dim: 256
+ gin_channels: 256
+ inter_channels: 192
+ hidden_channels: 192
+ filter_channels: 640
+#############################
+gen:
+ upsample_input: 192
+ upsample_rates: [5,4,4,2,2]
+ upsample_kernel_sizes: [15,8,8,4,4]
+ upsample_initial_channel: 320
+ resblock_kernel_sizes: [3,7,11]
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
+#############################
+mpd:
+ periods: [2,3,5,7,11]
+ kernel_size: 5
+ stride: 3
+ use_spectral_norm: False
+ lReLU_slope: 0.2
+#############################
+mrd:
+ resolutions: "[(1024, 120, 600), (2048, 240, 1200), (4096, 480, 2400), (512, 50, 240)]" # (filter_length, hop_length, win_length)
+ use_spectral_norm: False
+ lReLU_slope: 0.2
+#############################
+log:
+ info_interval: 100
+ eval_interval: 1
+ save_interval: 5
+ num_audio: 6
+ pth_dir: 'chkpt'
+ log_dir: 'logs'
+ keep_ckpts: 0
+#############################
+dist_config:
+ dist_backend: "nccl"
+ dist_url: "tcp://localhost:54321"
+ world_size: 1
+
diff --git a/configs/singers/singer0001.npy b/configs/singers/singer0001.npy
new file mode 100644
index 0000000000000000000000000000000000000000..9352a330a6ac78e14129c5062d2235c05b15668c
--- /dev/null
+++ b/configs/singers/singer0001.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2879921d43bdbf11fc5d6ac91f434f905a2c5e59d75368bfbf3c6bdbddcb3cf
+size 1152
diff --git a/configs/singers/singer0002.npy b/configs/singers/singer0002.npy
new file mode 100644
index 0000000000000000000000000000000000000000..b8ccb3f218758254f2971a3dbeaa5340e7377c7f
--- /dev/null
+++ b/configs/singers/singer0002.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbe5c7925c2fdb514e2c5b450de1d2737ec7f86f1c65eeb488c1888c0b9a7069
+size 1152
diff --git a/configs/singers/singer0003.npy b/configs/singers/singer0003.npy
new file mode 100644
index 0000000000000000000000000000000000000000..3a92f50cbd7336703910831f03ac7d1cb029a90e
--- /dev/null
+++ b/configs/singers/singer0003.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5665126aeb6c6fab89c79b90debf2ce2e64b321076dcb414089eff8848eac8b4
+size 1152
diff --git a/configs/singers/singer0004.npy b/configs/singers/singer0004.npy
new file mode 100644
index 0000000000000000000000000000000000000000..6ef48a0a0cb042ff8c419f747261c579aa66520e
--- /dev/null
+++ b/configs/singers/singer0004.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79f0fe5993e9adcaeae25b0fa68265d40c9c1b5539ca12d6e438477de2177819
+size 1152
diff --git a/configs/singers/singer0005.npy b/configs/singers/singer0005.npy
new file mode 100644
index 0000000000000000000000000000000000000000..ebe4251d2aef83c2c9db470bec9cdff8cf97e769
--- /dev/null
+++ b/configs/singers/singer0005.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1158fb447929cf9400a31675cf9992fd3ed7558e061562189d9e6bf56d83fb2a
+size 1152
diff --git a/configs/singers/singer0006.npy b/configs/singers/singer0006.npy
new file mode 100644
index 0000000000000000000000000000000000000000..6336c044c3b98765c5bc1f9121ef36465bcaf79e
--- /dev/null
+++ b/configs/singers/singer0006.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06c1fd3a9afaa7944e4b81b7ca787e667b0dae8c7e90c6d24177245449f4e940
+size 1152
diff --git a/configs/singers/singer0007.npy b/configs/singers/singer0007.npy
new file mode 100644
index 0000000000000000000000000000000000000000..b401dcf6e4f774798010ccbf27cd8622943e7174
--- /dev/null
+++ b/configs/singers/singer0007.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:36611b9e57545332b9fb97fd35a356fbe8d60258f2f5e2232168481bb6dfab5b
+size 1152
diff --git a/configs/singers/singer0008.npy b/configs/singers/singer0008.npy
new file mode 100644
index 0000000000000000000000000000000000000000..f28df113963e42afba4d76af39ec150a66de406c
--- /dev/null
+++ b/configs/singers/singer0008.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8584ad6f3569a1307082cd410085d9a562807e962274b89b72487c7bc79124d4
+size 1152
diff --git a/configs/singers/singer0009.npy b/configs/singers/singer0009.npy
new file mode 100644
index 0000000000000000000000000000000000000000..15d125808f4475d4b9e9a2714839b61f36c2ae1c
--- /dev/null
+++ b/configs/singers/singer0009.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b069db4e3e5ca389ffba974c74eab46caf4c60545773e5f7e5e253310619073e
+size 1152
diff --git a/configs/singers/singer0010.npy b/configs/singers/singer0010.npy
new file mode 100644
index 0000000000000000000000000000000000000000..bda76fe1913f4310bb2dab2a3eb411454aac8d12
--- /dev/null
+++ b/configs/singers/singer0010.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d4d92735e4bac1618e89198d113013db09061b6c1f74ba0c500b70b097cd407
+size 1152
diff --git a/configs/singers/singer0011.npy b/configs/singers/singer0011.npy
new file mode 100644
index 0000000000000000000000000000000000000000..0fd56c80b7fc42933a3c1b2a0c37e2e7291e44f8
--- /dev/null
+++ b/configs/singers/singer0011.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:942388b4276dc06ee365f59c324ce1642e4bf810dcc99992739787e3b9ad135d
+size 1152
diff --git a/configs/singers/singer0012.npy b/configs/singers/singer0012.npy
new file mode 100644
index 0000000000000000000000000000000000000000..54261d088430996e5a0abf019c47632031bb8886
--- /dev/null
+++ b/configs/singers/singer0012.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3411efcf4ee4f534cea2b742c2eca166ae971efbceab21fb41b77b8923a1ba3a
+size 1152
diff --git a/configs/singers/singer0013.npy b/configs/singers/singer0013.npy
new file mode 100644
index 0000000000000000000000000000000000000000..3eedaf46d7a1fe7865d7c2783375e0f4a010f154
--- /dev/null
+++ b/configs/singers/singer0013.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e8e30cd1bce61405db194278dd7bf207d16abf656dd22f9a20f29e3657674f3
+size 1152
diff --git a/configs/singers/singer0014.npy b/configs/singers/singer0014.npy
new file mode 100644
index 0000000000000000000000000000000000000000..602e8f6203eb05962b3102f319e0ec99db9c097b
--- /dev/null
+++ b/configs/singers/singer0014.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f9cc8200753b4ba7605c9a13bf454b100025965135c5d816f7440ec53a2e6dd4
+size 1152
diff --git a/configs/singers/singer0015.npy b/configs/singers/singer0015.npy
new file mode 100644
index 0000000000000000000000000000000000000000..dfe824316e9fa153390c87f5d1fd1d4b34caab3b
--- /dev/null
+++ b/configs/singers/singer0015.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dcb58688e51dbdeb22e5dd85d27ff3904c4594c78420b8e9c9ab481adbecc5fe
+size 1152
diff --git a/configs/singers/singer0016.npy b/configs/singers/singer0016.npy
new file mode 100644
index 0000000000000000000000000000000000000000..5ce37e18e3b70a9c149e29743f227b5fd4cfdffd
--- /dev/null
+++ b/configs/singers/singer0016.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66a3c6162b8c937e9e8bbdc806b873866afce4b110664831642f7b41922bbf39
+size 1152
diff --git a/configs/singers/singer0017.npy b/configs/singers/singer0017.npy
new file mode 100644
index 0000000000000000000000000000000000000000..4104cb371c4e57ca071ec299e39e7320cc3e4569
--- /dev/null
+++ b/configs/singers/singer0017.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84782c98c930bd980f350837f4b3e8e193c49ef46aef9f92471c6136659975a9
+size 1152
diff --git a/configs/singers/singer0018.npy b/configs/singers/singer0018.npy
new file mode 100644
index 0000000000000000000000000000000000000000..fc43cc1750632008cc8057ce24b39e759bb3a047
--- /dev/null
+++ b/configs/singers/singer0018.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:731ebafda06aecedfd79941978149a0f87595f04e24eab7ed5300defe9070fc0
+size 1152
diff --git a/configs/singers/singer0019.npy b/configs/singers/singer0019.npy
new file mode 100644
index 0000000000000000000000000000000000000000..5e32ca3dca1e975ba5bae37b730fb3a764fc8595
--- /dev/null
+++ b/configs/singers/singer0019.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d88e620994e4413c4c58ffb9239ef46ded60ff3eab0715c7af96cbe4092198f
+size 1152
diff --git a/configs/singers/singer0020.npy b/configs/singers/singer0020.npy
new file mode 100644
index 0000000000000000000000000000000000000000..88a0e64f47ac03688db8b45329f4de9554f86835
--- /dev/null
+++ b/configs/singers/singer0020.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e5abaabe5457a20161351dcf5f8737d63a2a92fb1de1842ea9e92e47b9ca6fe
+size 1152
diff --git a/configs/singers/singer0021.npy b/configs/singers/singer0021.npy
new file mode 100644
index 0000000000000000000000000000000000000000..d80f97eac1be1779490a37a503aaac6ac1f5d130
--- /dev/null
+++ b/configs/singers/singer0021.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d7f99c92c89a44c1f2dd0688f033f0593c8c88b0537b092928bfbaa63a8d3e9
+size 1152
diff --git a/configs/singers/singer0022.npy b/configs/singers/singer0022.npy
new file mode 100644
index 0000000000000000000000000000000000000000..64dbba2610e35274f13da90702f608f283ddb0f4
--- /dev/null
+++ b/configs/singers/singer0022.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33becb1da48b12ba4957a0ef0b25bbd51e100d5762ebc4c7d381f6b957e682a2
+size 1152
diff --git a/configs/singers/singer0023.npy b/configs/singers/singer0023.npy
new file mode 100644
index 0000000000000000000000000000000000000000..6cb1c218971618be8de47f9c8daec658d4578531
--- /dev/null
+++ b/configs/singers/singer0023.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f49cbaf3f7653f48f80854a513a334f31dca719a09cca66e257995ce4a741a9
+size 1152
diff --git a/configs/singers/singer0024.npy b/configs/singers/singer0024.npy
new file mode 100644
index 0000000000000000000000000000000000000000..0ca92912b8959fc8743b9628f08c7595b6eb94f9
--- /dev/null
+++ b/configs/singers/singer0024.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92ed584994d56473c8bab0799d213e927c5a2928facef2b93a2f95f764d868b4
+size 1152
diff --git a/configs/singers/singer0025.npy b/configs/singers/singer0025.npy
new file mode 100644
index 0000000000000000000000000000000000000000..05bd93acfad034e7e83e7965c4ce4557a4b7a277
--- /dev/null
+++ b/configs/singers/singer0025.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14b7e1f55393d5beaa2f3bbd0ef7f2be7e108993c680acb265ff24df19f7062b
+size 1152
diff --git a/configs/singers/singer0026.npy b/configs/singers/singer0026.npy
new file mode 100644
index 0000000000000000000000000000000000000000..cddbd2fddc6a51caf1a2f111f3bd6b1d3dbe2c18
--- /dev/null
+++ b/configs/singers/singer0026.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92ecc9aa68f136960c00e98aaca16e92c38960bc7eb9687aee90190972974726
+size 1152
diff --git a/configs/singers/singer0027.npy b/configs/singers/singer0027.npy
new file mode 100644
index 0000000000000000000000000000000000000000..aedcbf0c48465cb3273ba46d180711901ca911ea
--- /dev/null
+++ b/configs/singers/singer0027.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f5a8a1c2a445179d38664fb55c84ee9a36350beee50efa9f850d29b394447bfa
+size 1152
diff --git a/configs/singers/singer0028.npy b/configs/singers/singer0028.npy
new file mode 100644
index 0000000000000000000000000000000000000000..788e6fdeb897960c0b1e203f9750cd2ae3969975
--- /dev/null
+++ b/configs/singers/singer0028.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b79b8266c8d368dc99f49a347b2631e1e5cfb44056b5a9ab4470b42f9851ee35
+size 1152
diff --git a/configs/singers/singer0029.npy b/configs/singers/singer0029.npy
new file mode 100644
index 0000000000000000000000000000000000000000..0340a7adc7bde1ada08e8e1962ce94c6f72a9d2f
--- /dev/null
+++ b/configs/singers/singer0029.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60fa5fd9e8ba14d7f6d67304842f16382f7d2e739969bde9551222ff8c282775
+size 1152
diff --git a/configs/singers/singer0030.npy b/configs/singers/singer0030.npy
new file mode 100644
index 0000000000000000000000000000000000000000..3597a4cf4e540f6aaae2f5f0ecd6f21d52a15658
--- /dev/null
+++ b/configs/singers/singer0030.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f5070e4196c91fa713aed20aedb2a570a7b2ad8301ee61f59821dafaea3c6a7
+size 1152
diff --git a/configs/singers/singer0031.npy b/configs/singers/singer0031.npy
new file mode 100644
index 0000000000000000000000000000000000000000..73be80545df0ca0a007ce914569e088082db62d8
--- /dev/null
+++ b/configs/singers/singer0031.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:47f4f8c065be1c5448c1b80e5c99087e7357cf1f8a8a55f2d844ccf1ca4931e6
+size 1152
diff --git a/configs/singers/singer0032.npy b/configs/singers/singer0032.npy
new file mode 100644
index 0000000000000000000000000000000000000000..09d9d322a3232c0ab60cd05ea0257534b05e3c35
--- /dev/null
+++ b/configs/singers/singer0032.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:019f40cf49cb7ccb44fb9c6a9f6345e84f837185a1642623144b4e2969c8738b
+size 1152
diff --git a/configs/singers/singer0033.npy b/configs/singers/singer0033.npy
new file mode 100644
index 0000000000000000000000000000000000000000..a6efd8966cb9e1f2ed87cbc103cdb70ceb279213
--- /dev/null
+++ b/configs/singers/singer0033.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e05e212c93fc9e7b13174dd76721ee891bb4ea8bb1638a4c43523ed65d30f67
+size 1152
diff --git a/configs/singers/singer0034.npy b/configs/singers/singer0034.npy
new file mode 100644
index 0000000000000000000000000000000000000000..1c23504832579cc0c9c3a7b505a7d6b8cc1efd81
--- /dev/null
+++ b/configs/singers/singer0034.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:715a089dd9b3e5cbf021b0f41055f59208911e49cccf375ecf8b82544f325c3d
+size 1152
diff --git a/configs/singers/singer0035.npy b/configs/singers/singer0035.npy
new file mode 100644
index 0000000000000000000000000000000000000000..894595cd6678a1befdfb843a02ee09ad7badc03c
--- /dev/null
+++ b/configs/singers/singer0035.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9af8cd05182ec53ff573bce53dad049759bea1de5656915f414910eaf47f61ed
+size 1152
diff --git a/configs/singers/singer0036.npy b/configs/singers/singer0036.npy
new file mode 100644
index 0000000000000000000000000000000000000000..de86320c15d7e9e3162edc9eadb783fb306b79c3
--- /dev/null
+++ b/configs/singers/singer0036.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cec474244d86acfd24d6abf7e033b24b40b838cba2fcd3b4d0e5611313d67ef
+size 1152
diff --git a/configs/singers/singer0037.npy b/configs/singers/singer0037.npy
new file mode 100644
index 0000000000000000000000000000000000000000..36488b5b3c83fca4e3617aaefe5138002c150cd9
--- /dev/null
+++ b/configs/singers/singer0037.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:316e3435d373e352fe95fcb2ec0ab1c8afdeb270ce9f13c940ba91187eecdcf3
+size 1152
diff --git a/configs/singers/singer0038.npy b/configs/singers/singer0038.npy
new file mode 100644
index 0000000000000000000000000000000000000000..9c234763efb8227d82a4770dfc0b5b885d124f13
--- /dev/null
+++ b/configs/singers/singer0038.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e6458e251512dab86abce504490de6762f9c2de66ddbc853c24c3d05eb39c96
+size 1152
diff --git a/configs/singers/singer0039.npy b/configs/singers/singer0039.npy
new file mode 100644
index 0000000000000000000000000000000000000000..64b2bc8072f901df785c32d47b5926e98179ee50
--- /dev/null
+++ b/configs/singers/singer0039.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2e484ae33eef7ac92dd784e9e3b9bca7e6c0838d50b43c674da47620f281f20
+size 1152
diff --git a/configs/singers/singer0040.npy b/configs/singers/singer0040.npy
new file mode 100644
index 0000000000000000000000000000000000000000..96dd086113c61fac151de03b22a12daf768d7f41
--- /dev/null
+++ b/configs/singers/singer0040.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b3a104163ad4cf87caff70b845b2c3e70190ce430a8f21247d350ef102071dc
+size 1152
diff --git a/configs/singers/singer0041.npy b/configs/singers/singer0041.npy
new file mode 100644
index 0000000000000000000000000000000000000000..265f3dc086d595a0675e38fd35f5c433cc3dbcef
--- /dev/null
+++ b/configs/singers/singer0041.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:962ba35045f952562bbf68239c8abfda4e1888118fae7ef19814282abee2d28e
+size 1152
diff --git a/configs/singers/singer0042.npy b/configs/singers/singer0042.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7b13c99f9ef87b7e64bcdac1216f772239946b4f
--- /dev/null
+++ b/configs/singers/singer0042.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0cf0871ba7e939c90f2f027862f80e11151d8b1a21b6624ee05f184d024b35a3
+size 1152
diff --git a/configs/singers/singer0043.npy b/configs/singers/singer0043.npy
new file mode 100644
index 0000000000000000000000000000000000000000..11b3e4a998bb8eb219aa7b2bbfbd36234293e6e4
--- /dev/null
+++ b/configs/singers/singer0043.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9494ed20a9b095d19cce17619a6372ba3371f980c643078ffda8649a30ac2f8b
+size 1152
diff --git a/configs/singers/singer0044.npy b/configs/singers/singer0044.npy
new file mode 100644
index 0000000000000000000000000000000000000000..a12211417bf2d237b0c164c2d275ed95bd3ff175
--- /dev/null
+++ b/configs/singers/singer0044.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c12949caa6176fbe5f323cf643d29eef14af9a3ee03be27c938d8bb6fc2922f1
+size 1152
diff --git a/configs/singers/singer0045.npy b/configs/singers/singer0045.npy
new file mode 100644
index 0000000000000000000000000000000000000000..04962d6066594759e2fd3ec9d69687f6179d8e74
--- /dev/null
+++ b/configs/singers/singer0045.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:222adf210792d1b2745ef98b717f57e0d309d8176e9b59ff56063c1e2001728d
+size 1152
diff --git a/configs/singers/singer0046.npy b/configs/singers/singer0046.npy
new file mode 100644
index 0000000000000000000000000000000000000000..74976cf00a14a550f4189100409ba940e5f81c7c
--- /dev/null
+++ b/configs/singers/singer0046.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6070f66f028928114493a363e4117636afefb1a094c54ffc01f89ef261ad1882
+size 1152
diff --git a/configs/singers/singer0047.npy b/configs/singers/singer0047.npy
new file mode 100644
index 0000000000000000000000000000000000000000..50304b9bb81a71beb480c2ce786a2c8ff0aa8db5
--- /dev/null
+++ b/configs/singers/singer0047.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c8bb2fb993a55d13996df74463408c8e7e8d5e24b391887813e2ac1c204c9c4
+size 1152
diff --git a/configs/singers/singer0048.npy b/configs/singers/singer0048.npy
new file mode 100644
index 0000000000000000000000000000000000000000..71f0fbde409976dbc079d27b957d269d3dd59129
--- /dev/null
+++ b/configs/singers/singer0048.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b33ee26125ae840494dc2cb3839f7a2f6b48571c15ebc7f0aa9f2b0fef5022e
+size 1152
diff --git a/configs/singers/singer0049.npy b/configs/singers/singer0049.npy
new file mode 100644
index 0000000000000000000000000000000000000000..00eb5f5b705f1236aebc44ccc40d87fe071e12b0
--- /dev/null
+++ b/configs/singers/singer0049.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9a8d97dd4d320e4049c39e112587416d06aa70ec52c05417519bac70fe76556
+size 1152
diff --git a/configs/singers/singer0050.npy b/configs/singers/singer0050.npy
new file mode 100644
index 0000000000000000000000000000000000000000..23f9815b6ed783ba964a8bed34e1a7353fd3873c
--- /dev/null
+++ b/configs/singers/singer0050.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8fc8cf73923c6a567a134bffa037b3c9d1dcfde75d5a976238df222d91517d9f
+size 1152
diff --git a/configs/singers/singer0051.npy b/configs/singers/singer0051.npy
new file mode 100644
index 0000000000000000000000000000000000000000..1ffac1a5057ba9bd5b8dc77ab4d8ddb1f02c8333
--- /dev/null
+++ b/configs/singers/singer0051.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a142fdcdb2fa1e69d09df6de3f96bf863038da4e27f51320adb5483cb4f5d306
+size 1152
diff --git a/configs/singers/singer0052.npy b/configs/singers/singer0052.npy
new file mode 100644
index 0000000000000000000000000000000000000000..ce1c360dd3c195b30157c33bb49138d55b72c586
--- /dev/null
+++ b/configs/singers/singer0052.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2fb5dcbe59a636f84c9522a3278b601358584fdc340dff8db06eaa544fddd4b
+size 1152
diff --git a/configs/singers/singer0053.npy b/configs/singers/singer0053.npy
new file mode 100644
index 0000000000000000000000000000000000000000..e2328ccf8b490bf37b622bf663292cd0c32f92b8
--- /dev/null
+++ b/configs/singers/singer0053.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:893f50a74fdf3e7f27debc52fd739cd96a147ae2dcbeb34e2fb8fd328fa698a5
+size 1152
diff --git a/configs/singers/singer0054.npy b/configs/singers/singer0054.npy
new file mode 100644
index 0000000000000000000000000000000000000000..aa47e0a5d03bc259e93ce2fe0cb8cef9ddc01a2b
--- /dev/null
+++ b/configs/singers/singer0054.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:151b0206dd960b1304f8d752e0502e9c7b0260326f7b932b278773aa0c5bb3ef
+size 1152
diff --git a/configs/singers/singer0055.npy b/configs/singers/singer0055.npy
new file mode 100644
index 0000000000000000000000000000000000000000..944ff6e36130ab3aaa2aea7f89edc4a50fff960d
--- /dev/null
+++ b/configs/singers/singer0055.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa92ab16df7f82f5bde66501629310a5a250ff61d5f467a7ff2b0c97fe6d8066
+size 1152
diff --git a/configs/singers/singer0056.npy b/configs/singers/singer0056.npy
new file mode 100644
index 0000000000000000000000000000000000000000..79f339eb5ef8319b5410749e19a5c647d7c2e81f
--- /dev/null
+++ b/configs/singers/singer0056.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07022464dcb678dddd3957ea431e53b4c79fd1904927d2622faef5c521da1b5e
+size 1152
diff --git a/configs/singers_sample/22-wave-girl/031.wav b/configs/singers_sample/22-wave-girl/031.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6dadf271930958b6a55bb1bb15cab164e381c0ab
Binary files /dev/null and b/configs/singers_sample/22-wave-girl/031.wav differ
diff --git a/configs/singers_sample/22-wave-girl/032.wav b/configs/singers_sample/22-wave-girl/032.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ee782d84355f46d3fb685a3d8342315887d7dd08
Binary files /dev/null and b/configs/singers_sample/22-wave-girl/032.wav differ
diff --git a/configs/singers_sample/22-wave-girl/033.wav b/configs/singers_sample/22-wave-girl/033.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2525fa03d17d5281591e0fd4e0f35e0f8832efb8
Binary files /dev/null and b/configs/singers_sample/22-wave-girl/033.wav differ
diff --git a/configs/singers_sample/22-wave-girl/034.wav b/configs/singers_sample/22-wave-girl/034.wav
new file mode 100644
index 0000000000000000000000000000000000000000..12f8abacf3f69b33c4771bda9d4bd36d72c14a1d
Binary files /dev/null and b/configs/singers_sample/22-wave-girl/034.wav differ
diff --git a/configs/singers_sample/22-wave-girl/035.wav b/configs/singers_sample/22-wave-girl/035.wav
new file mode 100644
index 0000000000000000000000000000000000000000..67061ed6d9f01a77be336295427152235727f2f3
Binary files /dev/null and b/configs/singers_sample/22-wave-girl/035.wav differ
diff --git a/configs/singers_sample/30-wave-boy/010.wav b/configs/singers_sample/30-wave-boy/010.wav
new file mode 100644
index 0000000000000000000000000000000000000000..f155cd942d81afb15b154726dafabc8d52c8e039
Binary files /dev/null and b/configs/singers_sample/30-wave-boy/010.wav differ
diff --git a/configs/singers_sample/30-wave-boy/011.wav b/configs/singers_sample/30-wave-boy/011.wav
new file mode 100644
index 0000000000000000000000000000000000000000..4c97631687ff52e88559a46da777230332b761d3
Binary files /dev/null and b/configs/singers_sample/30-wave-boy/011.wav differ
diff --git a/configs/singers_sample/30-wave-boy/012.wav b/configs/singers_sample/30-wave-boy/012.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ef425b68fda074ab2a269f2c52defe166169b7d8
Binary files /dev/null and b/configs/singers_sample/30-wave-boy/012.wav differ
diff --git a/configs/singers_sample/30-wave-boy/013.wav b/configs/singers_sample/30-wave-boy/013.wav
new file mode 100644
index 0000000000000000000000000000000000000000..a88cf259b8be0cba850407986e32542256a2a9ee
Binary files /dev/null and b/configs/singers_sample/30-wave-boy/013.wav differ
diff --git a/configs/singers_sample/30-wave-boy/014.wav b/configs/singers_sample/30-wave-boy/014.wav
new file mode 100644
index 0000000000000000000000000000000000000000..7cf8a332d01eec2c56cd34366a428bee36dc004a
Binary files /dev/null and b/configs/singers_sample/30-wave-boy/014.wav differ
diff --git a/configs/singers_sample/30-wave-boy/015.wav b/configs/singers_sample/30-wave-boy/015.wav
new file mode 100644
index 0000000000000000000000000000000000000000..d1bc022dca24ed3834b766a5f8307bd4739253a9
Binary files /dev/null and b/configs/singers_sample/30-wave-boy/015.wav differ
diff --git a/configs/singers_sample/47-wave-girl/020.wav b/configs/singers_sample/47-wave-girl/020.wav
new file mode 100644
index 0000000000000000000000000000000000000000..4176e18fb71d802f054765a58c1ca9a2935c9e22
Binary files /dev/null and b/configs/singers_sample/47-wave-girl/020.wav differ
diff --git a/configs/singers_sample/47-wave-girl/021.wav b/configs/singers_sample/47-wave-girl/021.wav
new file mode 100644
index 0000000000000000000000000000000000000000..9891ceecbf5563590ab88fec7c0bed19808b6b1a
Binary files /dev/null and b/configs/singers_sample/47-wave-girl/021.wav differ
diff --git a/configs/singers_sample/47-wave-girl/022.wav b/configs/singers_sample/47-wave-girl/022.wav
new file mode 100644
index 0000000000000000000000000000000000000000..a98fc13df8bdde03a1c11df7918c2011b4b48342
Binary files /dev/null and b/configs/singers_sample/47-wave-girl/022.wav differ
diff --git a/configs/singers_sample/47-wave-girl/023.wav b/configs/singers_sample/47-wave-girl/023.wav
new file mode 100644
index 0000000000000000000000000000000000000000..614edb86cec4a7e61fed0e3e5d590b4ce375e914
Binary files /dev/null and b/configs/singers_sample/47-wave-girl/023.wav differ
diff --git a/configs/singers_sample/47-wave-girl/024.wav b/configs/singers_sample/47-wave-girl/024.wav
new file mode 100644
index 0000000000000000000000000000000000000000..83c3fce417a04c3928bae789f760ddf1c233cd86
Binary files /dev/null and b/configs/singers_sample/47-wave-girl/024.wav differ
diff --git a/configs/singers_sample/47-wave-girl/025.wav b/configs/singers_sample/47-wave-girl/025.wav
new file mode 100644
index 0000000000000000000000000000000000000000..1e6d11027c8a2d120d0d457a0501c7a6b6902600
Binary files /dev/null and b/configs/singers_sample/47-wave-girl/025.wav differ
diff --git a/configs/singers_sample/51-wave-boy/006.wav b/configs/singers_sample/51-wave-boy/006.wav
new file mode 100644
index 0000000000000000000000000000000000000000..f2b140dbd1207bbc81d5cb207f052faa59c573a5
Binary files /dev/null and b/configs/singers_sample/51-wave-boy/006.wav differ
diff --git a/configs/singers_sample/51-wave-boy/007.wav b/configs/singers_sample/51-wave-boy/007.wav
new file mode 100644
index 0000000000000000000000000000000000000000..76115acb255395428ecc063fb63616137c4f58b0
Binary files /dev/null and b/configs/singers_sample/51-wave-boy/007.wav differ
diff --git a/configs/singers_sample/51-wave-boy/008.wav b/configs/singers_sample/51-wave-boy/008.wav
new file mode 100644
index 0000000000000000000000000000000000000000..117f33965675781598fe5200a1d73e64db797936
Binary files /dev/null and b/configs/singers_sample/51-wave-boy/008.wav differ
diff --git a/configs/singers_sample/51-wave-boy/009.wav b/configs/singers_sample/51-wave-boy/009.wav
new file mode 100644
index 0000000000000000000000000000000000000000..b38ba32ccb0b95ab53c4a55477d70c408970b563
Binary files /dev/null and b/configs/singers_sample/51-wave-boy/009.wav differ
diff --git a/configs/singers_sample/51-wave-boy/010.wav b/configs/singers_sample/51-wave-boy/010.wav
new file mode 100644
index 0000000000000000000000000000000000000000..99a43ba2013a126103bd167689b06f819fef8573
Binary files /dev/null and b/configs/singers_sample/51-wave-boy/010.wav differ
diff --git a/crepe/LICENSE.txt b/crepe/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..efc01ae87f6cc931d539ee9672a4e00aa583814c
--- /dev/null
+++ b/crepe/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Max Morrison
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/crepe/README.md b/crepe/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..296537c8aee47545f5600a6e7d84731d535e84d8
--- /dev/null
+++ b/crepe/README.md
@@ -0,0 +1,223 @@
+# torchcrepe
+
+
+[](https://pypi.python.org/pypi/torchcrepe)
+[](https://opensource.org/licenses/MIT)
+[](https://pepy.tech/project/torchcrepe)
+
+
+
+Pytorch implementation of the CREPE [1] pitch tracker. The original Tensorflow
+implementation can be found [here](https://github.com/marl/crepe/). The
+provided model weights were obtained by converting the "tiny" and "full" models
+using [MMdnn](https://github.com/microsoft/MMdnn), an open-source model
+management framework.
+
+
+## Installation
+Perform the system-dependent PyTorch install using the instructions found
+[here](https://pytorch.org/).
+
+`pip install torchcrepe`
+
+
+## Usage
+
+### Computing pitch and periodicity from audio
+
+
+```python
+import torchcrepe
+
+
+# Load audio
+audio, sr = torchcrepe.load.audio( ... )
+
+# Here we'll use a 5 millisecond hop length
+hop_length = int(sr / 200.)
+
+# Provide a sensible frequency range for your domain (upper limit is 2006 Hz)
+# This would be a reasonable range for speech
+fmin = 50
+fmax = 550
+
+# Select a model capacity--one of "tiny" or "full"
+model = 'tiny'
+
+# Choose a device to use for inference
+device = 'cuda:0'
+
+# Pick a batch size that doesn't cause memory errors on your gpu
+batch_size = 2048
+
+# Compute pitch using first gpu
+pitch = torchcrepe.predict(audio,
+ sr,
+ hop_length,
+ fmin,
+ fmax,
+ model,
+ batch_size=batch_size,
+ device=device)
+```
+
+A periodicity metric similar to the Crepe confidence score can also be
+extracted by passing `return_periodicity=True` to `torchcrepe.predict`.
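+
+For example, a minimal sketch reusing the variables defined above:
+
+```python
+# Returns both the pitch contour and a per-frame periodicity estimate
+pitch, periodicity = torchcrepe.predict(audio,
+                                        sr,
+                                        hop_length,
+                                        fmin,
+                                        fmax,
+                                        model,
+                                        batch_size=batch_size,
+                                        device=device,
+                                        return_periodicity=True)
+```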
+
+
+### Decoding
+
+By default, `torchcrepe` uses Viterbi decoding on the softmax of the network
+output. This differs from the original implementation, which uses a
+weighted average near the argmax of binary cross-entropy probabilities.
+The argmax operation can cause double/half frequency errors. These can be
+removed by penalizing large pitch jumps via Viterbi decoding. The `decode`
+submodule provides some options for decoding.
+
+```python
+# Decode using viterbi decoding (default)
+torchcrepe.predict(..., decoder=torchcrepe.decode.viterbi)
+
+# Decode using weighted argmax (as in the original implementation)
+torchcrepe.predict(..., decoder=torchcrepe.decode.weighted_argmax)
+
+# Decode using argmax
+torchcrepe.predict(..., decoder=torchcrepe.decode.argmax)
+```
+
+
+### Filtering and thresholding
+
+When periodicity is low, the pitch is less reliable. For some problems, it
+makes sense to mask these less reliable pitch values. However, the periodicity
+can be noisy and the pitch has quantization artifacts. `torchcrepe` provides
+submodules `filter` and `threshold` for this purpose. The filter and threshold
+parameters should be tuned to your data. For clean speech, a 10-20 millisecond
+window with a threshold of 0.21 has worked.
+
+```python
+# We'll use a 15 millisecond window assuming a hop length of 5 milliseconds
+win_length = 3
+
+# Median filter noisy confidence value
+periodicity = torchcrepe.filter.median(periodicity, win_length)
+
+# Remove inharmonic regions
+pitch = torchcrepe.threshold.At(.21)(pitch, periodicity)
+
+# Optionally smooth pitch to remove quantization artifacts
+pitch = torchcrepe.filter.mean(pitch, win_length)
+```
+
+For more fine-grained control over pitch thresholding, see
+`torchcrepe.threshold.Hysteresis`. This is especially useful for removing
+spurious voiced regions caused by noise in the periodicity values, but
+has more parameters and may require more manual tuning to your data.
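+
+A minimal sketch, assuming the default `Hysteresis` parameters are an acceptable
+starting point (they typically need tuning for your data):
+
+```python
+# Hysteresis thresholding is less prone to fragmenting a voiced region
+# on brief dips in periodicity than a single fixed threshold
+pitch = torchcrepe.threshold.Hysteresis()(pitch, periodicity)
+```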
+
+CREPE was not trained on silent audio. Therefore, it sometimes assigns high
+confidence to pitch bins in silent regions. You can use
+`torchcrepe.threshold.Silence` to manually set the periodicity in silent
+regions to zero.
+
+```python
+periodicity = torchcrepe.threshold.Silence(-60.)(periodicity,
+ audio,
+ sr,
+ hop_length)
+```
+
+
+### Computing the CREPE model output activations
+
+```python
+batch = next(torchcrepe.preprocess(audio, sr, hop_length))
+probabilities = torchcrepe.infer(batch)
+```
+
+
+### Computing the CREPE embedding space
+
+As in Differentiable Digital Signal Processing [2], this uses the output of the
+fifth max-pooling layer as a pretrained pitch embedding
+
+```python
+embeddings = torchcrepe.embed(audio, sr, hop_length)
+```
+
+### Computing from files
+
+`torchcrepe` defines the following convenience functions for predicting
+directly from audio files on disk. Each of these functions also takes
+a `device` argument that can be used for device placement (e.g.,
+`device='cuda:0'`).
+
+```python
+torchcrepe.predict_from_file(audio_file, ...)
+torchcrepe.predict_from_file_to_file(
+ audio_file, output_pitch_file, output_periodicity_file, ...)
+torchcrepe.predict_from_files_to_files(
+ audio_files, output_pitch_files, output_periodicity_files, ...)
+
+torchcrepe.embed_from_file(audio_file, ...)
+torchcrepe.embed_from_file_to_file(audio_file, output_file, ...)
+torchcrepe.embed_from_files_to_files(audio_files, output_files, ...)
+```
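+
+For instance, a sketch of a single-file prediction; the file names below are
+placeholders, and the outputs are saved with `torch.save`:
+
+```python
+torchcrepe.predict_from_file_to_file(
+    'example.wav',                                     # hypothetical input file
+    'example.pitch.pt',                                # where the pitch tensor is saved
+    output_periodicity_file='example.periodicity.pt',  # optional periodicity output
+    fmin=50.,
+    fmax=550.,
+    model='tiny',
+    device='cuda:0')
+```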
+
+### Command-line interface
+
+```bash
+usage: python -m torchcrepe
+ [-h]
+ --audio_files AUDIO_FILES [AUDIO_FILES ...]
+ --output_files OUTPUT_FILES [OUTPUT_FILES ...]
+ [--hop_length HOP_LENGTH]
+ [--output_periodicity_files OUTPUT_PERIODICITY_FILES [OUTPUT_PERIODICITY_FILES ...]]
+ [--embed]
+ [--fmin FMIN]
+ [--fmax FMAX]
+ [--model MODEL]
+ [--decoder DECODER]
+ [--gpu GPU]
+ [--no_pad]
+
+optional arguments:
+ -h, --help show this help message and exit
+ --audio_files AUDIO_FILES [AUDIO_FILES ...]
+ The audio file to process
+ --output_files OUTPUT_FILES [OUTPUT_FILES ...]
+ The file to save pitch or embedding
+ --hop_length HOP_LENGTH
+ The hop length of the analysis window
+ --output_periodicity_files OUTPUT_PERIODICITY_FILES [OUTPUT_PERIODICITY_FILES ...]
+ The file to save periodicity
+ --embed Performs embedding instead of pitch prediction
+ --fmin FMIN The minimum frequency allowed
+ --fmax FMAX The maximum frequency allowed
+ --model MODEL The model capacity. One of "tiny" or "full"
+ --decoder DECODER The decoder to use. One of "argmax", "viterbi", or
+ "weighted_argmax"
+ --gpu GPU The gpu to perform inference on
+ --no_pad Whether to pad the audio
+```
+
+
+## Tests
+
+The module tests can be run as follows.
+
+```bash
+pip install pytest
+pytest
+```
+
+
+## References
+[1] J. W. Kim, J. Salamon, P. Li, and J. P. Bello, “Crepe: A
+Convolutional Representation for Pitch Estimation,” in 2018 IEEE
+International Conference on Acoustics, Speech and Signal
+Processing (ICASSP).
+
+[2] J. H. Engel, L. Hantrakul, C. Gu, and A. Roberts,
+“DDSP: Differentiable Digital Signal Processing,” in
+2020 International Conference on Learning
+Representations (ICLR).
diff --git a/crepe/__init__.py b/crepe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f78e20d4a4a07cb7dfc37df643d96a34a4486ccd
--- /dev/null
+++ b/crepe/__init__.py
@@ -0,0 +1,8 @@
+from . import decode
+from .core import *
+from .model import Crepe
+from . import convert
+from . import filter
+from . import load
+from . import loudness
+from . import threshold
diff --git a/crepe/__main__.py b/crepe/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d1a3120adea147778bc5829ae9b8037bed8efd0
--- /dev/null
+++ b/crepe/__main__.py
@@ -0,0 +1,148 @@
+import argparse
+import os
+import warnings
+
+import crepe
+
+
+###############################################################################
+# Entry point
+###############################################################################
+
+
+def parse_args():
+ """Parse command-line arguments"""
+ parser = argparse.ArgumentParser()
+
+ # Required arguments
+ parser.add_argument(
+ '--audio_files',
+ nargs='+',
+ required=True,
+ help='The audio file to process')
+ parser.add_argument(
+ '--output_files',
+ nargs='+',
+ required=True,
+ help='The file to save pitch or embedding')
+ parser.add_argument(
+ '--hop_length',
+ type=int,
+ help='The hop length of the analysis window')
+
+ # Optionally save harmonicity [DEPRECATED]
+ parser.add_argument(
+ '--output_harmonicity_files',
+ nargs='+',
+ help='The file to save harmonicity')
+ # Optionally save periodicity
+ parser.add_argument(
+ '--output_periodicity_files',
+ nargs='+',
+ help='The files to save periodicity')
+
+ # Optionally create embedding instead of pitch contour
+ parser.add_argument(
+ '--embed',
+ action='store_true',
+ help='Performs embedding instead of pitch prediction')
+
+ # Optional arguments
+ parser.add_argument(
+ '--fmin',
+ default=50.,
+ type=float,
+ help='The minimum frequency allowed')
+ parser.add_argument(
+ '--fmax',
+ default=crepe.MAX_FMAX,
+ type=float,
+ help='The maximum frequency allowed')
+ parser.add_argument(
+ '--model',
+ default='full',
+ help='The model capacity. One of "tiny" or "full"')
+ parser.add_argument(
+ '--decoder',
+ default='viterbi',
+ help='The decoder to use. One of "argmax", "viterbi", or ' +
+ '"weighted_argmax"')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ help='The number of frames per batch')
+ parser.add_argument(
+ '--gpu',
+ type=int,
+ help='The gpu to perform inference on')
+ parser.add_argument(
+ '--no_pad',
+ action='store_true',
+ help='Whether to pad the audio')
+
+ return parser.parse_args()
+
+
+def make_parent_directory(file):
+ """Create parent directory for file if it does not already exist"""
+ parent = os.path.dirname(os.path.abspath(file))
+ os.makedirs(parent, exist_ok=True)
+
+
+def main():
+ # Parse command-line arguments
+ args = parse_args()
+
+ # Deprecate output_harmonicity_files
+ if args.output_harmonicity_files is not None:
+ message = (
+ 'The crepe output_harmonicity_files argument is deprecated and '
+ 'will be removed in a future release. Please use '
+ 'output_periodicity_files. Rationale: if network confidence measured '
+ 'harmonic content, the value would be low for non-harmonic, periodic '
+ 'sounds (e.g., sine waves). But this is not observed.')
+ warnings.warn(message, DeprecationWarning)
+ args.output_periodicity_files = args.output_harmonicity_files
+
+    # Ensure output directories exist
+ [make_parent_directory(file) for file in args.output_files]
+ if args.output_periodicity_files is not None:
+ [make_parent_directory(file) for file in args.output_periodicity_files]
+
+ # Get inference device
+ device = 'cpu' if args.gpu is None else f'cuda:{args.gpu}'
+
+ # Get decoder
+ if args.decoder == 'argmax':
+ decoder = crepe.decode.argmax
+ elif args.decoder == 'weighted_argmax':
+ decoder = crepe.decode.weighted_argmax
+    elif args.decoder == 'viterbi':
+        decoder = crepe.decode.viterbi
+    else:
+        raise ValueError(f'Decoder {args.decoder} is not supported')
+
+ # Infer pitch or embedding and save to disk
+ if args.embed:
+ crepe.embed_from_files_to_files(args.audio_files,
+ args.output_files,
+ args.hop_length,
+ args.model,
+ args.batch_size,
+ device,
+ not args.no_pad)
+ else:
+ crepe.predict_from_files_to_files(args.audio_files,
+ args.output_files,
+ None,
+ args.output_periodicity_files,
+ args.hop_length,
+ args.fmin,
+ args.fmax,
+ args.model,
+ decoder,
+ args.batch_size,
+ device,
+ not args.no_pad)
+
+
+# Run module entry point
+main()
diff --git a/crepe/assets/tiny.pth b/crepe/assets/tiny.pth
new file mode 100644
index 0000000000000000000000000000000000000000..79d10d896a956c54dee45257cfe6bf87425bbdf5
--- /dev/null
+++ b/crepe/assets/tiny.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4993eea36ed1a0ad9ac549c740dae5265b049ce72004f00c2f59e01c0be8432
+size 1962363
diff --git a/crepe/convert.py b/crepe/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ace1e111bb1c824894af50125c60a73af9bc20
--- /dev/null
+++ b/crepe/convert.py
@@ -0,0 +1,57 @@
+import scipy.stats
+import torch
+
+import crepe
+
+
+###############################################################################
+# Pitch unit conversions
+###############################################################################
+
+
+def bins_to_cents(bins):
+ """Converts pitch bins to cents"""
+ cents = crepe.CENTS_PER_BIN * bins + 1997.3794084376191
+
+ # Trade quantization error for noise
+ return dither(cents)
+
+
+def bins_to_frequency(bins):
+ """Converts pitch bins to frequency in Hz"""
+ return cents_to_frequency(bins_to_cents(bins))
+
+
+def cents_to_bins(cents, quantize_fn=torch.floor):
+ """Converts cents to pitch bins"""
+ bins = (cents - 1997.3794084376191) / crepe.CENTS_PER_BIN
+ return quantize_fn(bins).int()
+
+
+def cents_to_frequency(cents):
+ """Converts cents to frequency in Hz"""
+ return 10 * 2 ** (cents / 1200)
+
+
+def frequency_to_bins(frequency, quantize_fn=torch.floor):
+ """Convert frequency in Hz to pitch bins"""
+ return cents_to_bins(frequency_to_cents(frequency), quantize_fn)
+
+
+def frequency_to_cents(frequency):
+ """Convert frequency in Hz to cents"""
+ return 1200 * torch.log2(frequency / 10.)
+
+
+###############################################################################
+# Utilities
+###############################################################################
+
+
+def dither(cents):
+ """Dither the predicted pitch in cents to remove quantization error"""
+ noise = scipy.stats.triang.rvs(c=0.5,
+ loc=-crepe.CENTS_PER_BIN,
+ scale=2 * crepe.CENTS_PER_BIN,
+ size=cents.size())
+ return cents + cents.new_tensor(noise)
diff --git a/crepe/core.py b/crepe/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa7f0dd8e794ac3475a69cf7c80dc880a5f1598d
--- /dev/null
+++ b/crepe/core.py
@@ -0,0 +1,738 @@
+import warnings
+
+import numpy as np
+import resampy
+import torch
+import tqdm
+
+import crepe
+
+
+__all__ = ['CENTS_PER_BIN',
+ 'MAX_FMAX',
+ 'PITCH_BINS',
+ 'SAMPLE_RATE',
+ 'WINDOW_SIZE',
+ 'UNVOICED',
+ 'embed',
+ 'embed_from_file',
+ 'embed_from_file_to_file',
+ 'embed_from_files_to_files',
+ 'infer',
+ 'predict',
+ 'predict_from_file',
+ 'predict_from_file_to_file',
+ 'predict_from_files_to_files',
+ 'preprocess',
+ 'postprocess',
+ 'resample']
+
+
+###############################################################################
+# Constants
+###############################################################################
+
+
+CENTS_PER_BIN = 20 # cents
+MAX_FMAX = 2006. # hz
+PITCH_BINS = 360
+SAMPLE_RATE = 16000 # hz
+WINDOW_SIZE = 1024 # samples
+UNVOICED = np.nan
+
+
+###############################################################################
+# Crepe pitch prediction
+###############################################################################
+
+
+def predict(audio,
+ sample_rate,
+ hop_length=None,
+ fmin=50.,
+ fmax=MAX_FMAX,
+ model='full',
+ decoder=crepe.decode.viterbi,
+ return_harmonicity=False,
+ return_periodicity=False,
+ batch_size=None,
+ device='cpu',
+ pad=True):
+ """Performs pitch estimation
+
+ Arguments
+ audio (torch.tensor [shape=(1, time)])
+ The audio signal
+ sample_rate (int)
+ The sampling rate in Hz
+ hop_length (int)
+ The hop_length in samples
+ fmin (float)
+ The minimum allowable frequency in Hz
+ fmax (float)
+ The maximum allowable frequency in Hz
+ model (string)
+ The model capacity. One of 'full' or 'tiny'.
+ decoder (function)
+ The decoder to use. See decode.py for decoders.
+ return_harmonicity (bool) [DEPRECATED]
+ Whether to also return the network confidence
+ return_periodicity (bool)
+ Whether to also return the network confidence
+ batch_size (int)
+ The number of frames per batch
+ device (string)
+ The device used to run inference
+ pad (bool)
+ Whether to zero-pad the audio
+
+ Returns
+ pitch (torch.tensor [shape=(1, 1 + int(time // hop_length))])
+ (Optional) periodicity (torch.tensor
+ [shape=(1, 1 + int(time // hop_length))])
+ """
+ # Deprecate return_harmonicity
+ if return_harmonicity:
+ message = (
+ 'The crepe return_harmonicity argument is deprecated and '
+ 'will be removed in a future release. Please use '
+ 'return_periodicity. Rationale: if network confidence measured '
+ 'harmonics, the value would be low for non-harmonic, periodic '
+ 'sounds (e.g., sine waves). But this is not observed.')
+ warnings.warn(message, DeprecationWarning)
+ return_periodicity = return_harmonicity
+
+ results = []
+
+ # Postprocessing breaks gradients, so just don't compute them
+ with torch.no_grad():
+
+ # Preprocess audio
+ generator = preprocess(audio,
+ sample_rate,
+ hop_length,
+ batch_size,
+ device,
+ pad)
+ for frames in generator:
+
+ # Infer independent probabilities for each pitch bin
+ probabilities = infer(frames, model)
+
+ # shape=(batch, 360, time / hop_length)
+ probabilities = probabilities.reshape(
+ audio.size(0), -1, PITCH_BINS).transpose(1, 2)
+
+ # Convert probabilities to F0 and periodicity
+ result = postprocess(probabilities,
+ fmin,
+ fmax,
+ decoder,
+ return_harmonicity,
+ return_periodicity)
+
+ # Place on same device as audio to allow very long inputs
+ if isinstance(result, tuple):
+ result = (result[0].to(audio.device),
+ result[1].to(audio.device))
+ else:
+ result = result.to(audio.device)
+
+ results.append(result)
+
+ # Split pitch and periodicity
+ if return_periodicity:
+ pitch, periodicity = zip(*results)
+ return torch.cat(pitch, 1), torch.cat(periodicity, 1)
+
+ # Concatenate
+ return torch.cat(results, 1)
+
+
+def predict_from_file(audio_file,
+ hop_length=None,
+ fmin=50.,
+ fmax=MAX_FMAX,
+ model='full',
+ decoder=crepe.decode.viterbi,
+ return_harmonicity=False,
+ return_periodicity=False,
+ batch_size=None,
+ device='cpu',
+ pad=True):
+ """Performs pitch estimation from file on disk
+
+ Arguments
+ audio_file (string)
+ The file to perform pitch tracking on
+ hop_length (int)
+ The hop_length in samples
+ fmin (float)
+ The minimum allowable frequency in Hz
+ fmax (float)
+ The maximum allowable frequency in Hz
+ model (string)
+ The model capacity. One of 'full' or 'tiny'.
+ decoder (function)
+ The decoder to use. See decode.py for decoders.
+ return_harmonicity (bool) [DEPRECATED]
+ Whether to also return the network confidence
+ return_periodicity (bool)
+ Whether to also return the network confidence
+ batch_size (int)
+ The number of frames per batch
+ device (string)
+ The device used to run inference
+ pad (bool)
+ Whether to zero-pad the audio
+
+ Returns
+ pitch (torch.tensor [shape=(1, 1 + int(time // hop_length))])
+ (Optional) periodicity (torch.tensor
+ [shape=(1, 1 + int(time // hop_length))])
+ """
+ # Load audio
+ audio, sample_rate = crepe.load.audio(audio_file)
+
+ # Predict
+ return predict(audio,
+ sample_rate,
+ hop_length,
+ fmin,
+ fmax,
+ model,
+ decoder,
+ return_harmonicity,
+ return_periodicity,
+ batch_size,
+ device,
+ pad)
+
+
+def predict_from_file_to_file(audio_file,
+ output_pitch_file,
+ output_harmonicity_file=None,
+ output_periodicity_file=None,
+ hop_length=None,
+ fmin=50.,
+ fmax=MAX_FMAX,
+ model='full',
+ decoder=crepe.decode.viterbi,
+ batch_size=None,
+ device='cpu',
+ pad=True):
+ """Performs pitch estimation from file on disk
+
+ Arguments
+ audio_file (string)
+ The file to perform pitch tracking on
+ output_pitch_file (string)
+ The file to save predicted pitch
+ output_harmonicity_file (string or None) [DEPRECATED]
+ The file to save predicted harmonicity
+ output_periodicity_file (string or None)
+ The file to save predicted periodicity
+ hop_length (int)
+ The hop_length in samples
+ fmin (float)
+ The minimum allowable frequency in Hz
+ fmax (float)
+ The maximum allowable frequency in Hz
+ model (string)
+ The model capacity. One of 'full' or 'tiny'.
+ decoder (function)
+ The decoder to use. See decode.py for decoders.
+ batch_size (int)
+ The number of frames per batch
+ device (string)
+ The device used to run inference
+ pad (bool)
+ Whether to zero-pad the audio
+ """
+ # Deprecate output_harmonicity_file
+ if output_harmonicity_file is not None:
+ message = (
+ 'The crepe output_harmonicity_file argument is deprecated and '
+ 'will be removed in a future release. Please use '
+ 'output_periodicity_file. Rationale: if network confidence measured '
+ 'harmonic content, the value would be low for non-harmonic, periodic '
+ 'sounds (e.g., sine waves). But this is not observed.')
+ warnings.warn(message, DeprecationWarning)
+ output_periodicity_file = output_harmonicity_file
+
+ # Predict from file
+ prediction = predict_from_file(audio_file,
+ hop_length,
+ fmin,
+ fmax,
+ model,
+ decoder,
+ False,
+ output_periodicity_file is not None,
+ batch_size,
+ device,
+ pad)
+
+ # Save to disk
+ if output_periodicity_file is not None:
+ torch.save(prediction[0].detach(), output_pitch_file)
+ torch.save(prediction[1].detach(), output_periodicity_file)
+ else:
+ torch.save(prediction.detach(), output_pitch_file)
+
+
+def predict_from_files_to_files(audio_files,
+ output_pitch_files,
+ output_harmonicity_files=None,
+ output_periodicity_files=None,
+ hop_length=None,
+ fmin=50.,
+ fmax=MAX_FMAX,
+ model='full',
+ decoder=crepe.decode.viterbi,
+ batch_size=None,
+ device='cpu',
+ pad=True):
+ """Performs pitch estimation from files on disk without reloading model
+
+ Arguments
+ audio_files (list[string])
+ The files to perform pitch tracking on
+ output_pitch_files (list[string])
+ The files to save predicted pitch
+ output_harmonicity_files (list[string] or None) [DEPRECATED]
+ The files to save predicted harmonicity
+ output_periodicity_files (list[string] or None)
+ The files to save predicted periodicity
+ hop_length (int)
+ The hop_length in samples
+ fmin (float)
+ The minimum allowable frequency in Hz
+ fmax (float)
+ The maximum allowable frequency in Hz
+ model (string)
+ The model capacity. One of 'full' or 'tiny'.
+ decoder (function)
+ The decoder to use. See decode.py for decoders.
+ batch_size (int)
+ The number of frames per batch
+ device (string)
+ The device used to run inference
+ pad (bool)
+ Whether to zero-pad the audio
+ """
+ # Deprecate output_harmonicity_files
+ if output_harmonicity_files is not None:
+ message = (
+ 'The crepe output_harmonicity_files argument is deprecated and '
+ 'will be removed in a future release. Please use '
+ 'output_periodicity_files. Rationale: if network confidence measured '
+ 'harmonic content, the value would be low for non-harmonic, periodic '
+ 'sounds (e.g., sine waves). But this is not observed.')
+ warnings.warn(message, DeprecationWarning)
+ output_periodicity_files = output_harmonicity_files
+
+ if output_periodicity_files is None:
+ output_periodicity_files = len(audio_files) * [None]
+
+ # Setup iterator
+ iterator = zip(audio_files, output_pitch_files, output_periodicity_files)
+ iterator = tqdm.tqdm(iterator, desc='crepe', dynamic_ncols=True)
+ for audio_file, output_pitch_file, output_periodicity_file in iterator:
+
+ # Predict a file
+ predict_from_file_to_file(audio_file,
+ output_pitch_file,
+ None,
+ output_periodicity_file,
+ hop_length,
+ fmin,
+ fmax,
+ model,
+ decoder,
+ batch_size,
+ device,
+ pad)
+
+###############################################################################
+# Crepe pitch embedding
+###############################################################################
+
+
+def embed(audio,
+ sample_rate,
+ hop_length=None,
+ model='full',
+ batch_size=None,
+ device='cpu',
+ pad=True):
+ """Embeds audio to the output of CREPE's fifth maxpool layer
+
+ Arguments
+ audio (torch.tensor [shape=(1, time)])
+ The audio signals
+ sample_rate (int)
+ The sampling rate in Hz
+ hop_length (int)
+ The hop_length in samples
+ model (string)
+ The model capacity. One of 'full' or 'tiny'.
+ batch_size (int)
+ The number of frames per batch
+ device (string)
+ The device to run inference on
+ pad (bool)
+ Whether to zero-pad the audio
+
+ Returns
+ embedding (torch.tensor [shape=(1,
+ 1 + int(time // hop_length), 32, -1)])
+ """
+ results = []
+
+ # Preprocess audio
+ generator = preprocess(audio,
+ sample_rate,
+ hop_length,
+ batch_size,
+ device,
+ pad)
+ for frames in generator:
+
+ # Infer pitch embeddings
+ embedding = infer(frames, model, embed=True)
+
+ # shape=(batch, time / hop_length, 32, embedding_size)
+ result = embedding.reshape(audio.size(0), frames.size(0), 32, -1)
+
+ # Place on same device as audio. This allows for large inputs.
+ results.append(result.to(audio.device))
+
+ # Concatenate
+ return torch.cat(results, 1)
+
+
+def embed_from_file(audio_file,
+ hop_length=None,
+ model='full',
+ batch_size=None,
+ device='cpu',
+ pad=True):
+ """Embeds audio from disk to the output of CREPE's fifth maxpool layer
+
+ Arguments
+ audio_file (string)
+ The wav file containing the audio to embed
+ hop_length (int)
+ The hop_length in samples
+ model (string)
+ The model capacity. One of 'full' or 'tiny'.
+ batch_size (int)
+ The number of frames per batch
+ device (string)
+ The device to run inference on
+ pad (bool)
+ Whether to zero-pad the audio
+
+ Returns
+ embedding (torch.tensor [shape=(1,
+ 1 + int(time // hop_length), 32, -1)])
+ """
+ # Load audio
+ audio, sample_rate = crepe.load.audio(audio_file)
+
+ # Embed
+ return embed(audio,
+ sample_rate,
+ hop_length,
+ model,
+ batch_size,
+ device,
+ pad)
+
+
+def embed_from_file_to_file(audio_file,
+ output_file,
+ hop_length=None,
+ model='full',
+ batch_size=None,
+ device='cpu',
+ pad=True):
+ """Embeds audio from disk and saves to disk
+
+ Arguments
+ audio_file (string)
+ The wav file containing the audio to embed
+        output_file (string)
+            The file to save the embedding
+        hop_length (int)
+            The hop_length in samples
+ model (string)
+ The model capacity. One of 'full' or 'tiny'.
+ batch_size (int)
+ The number of frames per batch
+ device (string)
+ The device to run inference on
+ pad (bool)
+ Whether to zero-pad the audio
+ """
+ # No use computing gradients if we're just saving to file
+ with torch.no_grad():
+
+ # Embed
+ embedding = embed_from_file(audio_file,
+ hop_length,
+ model,
+ batch_size,
+ device,
+ pad)
+
+ # Save to disk
+ torch.save(embedding.detach(), output_file)
+
+
+def embed_from_files_to_files(audio_files,
+ output_files,
+ hop_length=None,
+ model='full',
+ batch_size=None,
+ device='cpu',
+ pad=True):
+ """Embeds audio from disk and saves to disk without reloading model
+
+ Arguments
+ audio_files (list[string])
+ The wav files containing the audio to embed
+ output_files (list[string])
+ The files to save the embeddings
+ hop_length (int)
+ The hop_length in samples
+ model (string)
+ The model capacity. One of 'full' or 'tiny'.
+ batch_size (int)
+ The number of frames per batch
+ device (string)
+ The device to run inference on
+ pad (bool)
+ Whether to zero-pad the audio
+ """
+ # Setup iterator
+ iterator = zip(audio_files, output_files)
+ iterator = tqdm.tqdm(iterator, desc='crepe', dynamic_ncols=True)
+ for audio_file, output_file in iterator:
+
+ # Embed a file
+ embed_from_file_to_file(audio_file,
+ output_file,
+ hop_length,
+ model,
+ batch_size,
+ device,
+ pad)
+
+
+###############################################################################
+# Components for step-by-step prediction
+###############################################################################
+
+
+def infer(frames, model='full', embed=False):
+ """Forward pass through the model
+
+ Arguments
+ frames (torch.tensor [shape=(time / hop_length, 1024)])
+ The network input
+ model (string)
+ The model capacity. One of 'full' or 'tiny'.
+ embed (bool)
+ Whether to stop inference at the intermediate embedding layer
+
+ Returns
+ logits (torch.tensor [shape=(1 + int(time // hop_length), 360)]) OR
+ embedding (torch.tensor [shape=(1 + int(time // hop_length),
+ embedding_size)])
+ """
+ # Load the model if necessary
+ if not hasattr(infer, 'model') or not hasattr(infer, 'capacity') or \
+ (hasattr(infer, 'capacity') and infer.capacity != model):
+ crepe.load.model(frames.device, model)
+
+ # Move model to correct device (no-op if devices are the same)
+ infer.model = infer.model.to(frames.device)
+
+ # Apply model
+ return infer.model(frames, embed=embed)
+
+
+def postprocess(probabilities,
+ fmin=0.,
+ fmax=MAX_FMAX,
+ decoder=crepe.decode.viterbi,
+ return_harmonicity=False,
+ return_periodicity=False):
+ """Convert model output to F0 and periodicity
+
+ Arguments
+ probabilities (torch.tensor [shape=(1, 360, time / hop_length)])
+ The probabilities for each pitch bin inferred by the network
+ fmin (float)
+ The minimum allowable frequency in Hz
+ fmax (float)
+ The maximum allowable frequency in Hz
+        decoder (function)
+            The decoder to use. See decode.py for decoders.
+ return_harmonicity (bool) [DEPRECATED]
+ Whether to also return the network confidence
+ return_periodicity (bool)
+ Whether to also return the network confidence
+
+ Returns
+ pitch (torch.tensor [shape=(1, 1 + int(time // hop_length))])
+ periodicity (torch.tensor [shape=(1, 1 + int(time // hop_length))])
+ """
+ # Sampling is non-differentiable, so remove from graph
+ probabilities = probabilities.detach()
+
+ # Convert frequency range to pitch bin range
+ minidx = crepe.convert.frequency_to_bins(torch.tensor(fmin))
+ maxidx = crepe.convert.frequency_to_bins(torch.tensor(fmax),
+ torch.ceil)
+
+ # Remove frequencies outside of allowable range
+ probabilities[:, :minidx] = -float('inf')
+ probabilities[:, maxidx:] = -float('inf')
+
+ # Perform argmax or viterbi sampling
+ bins, pitch = decoder(probabilities)
+
+ # Deprecate return_harmonicity
+ if return_harmonicity:
+ message = (
+ 'The crepe return_harmonicity argument is deprecated and '
+ 'will be removed in a future release. Please use '
+ 'return_periodicity. Rationale: if network confidence measured '
+ 'harmonics, the value would be low for non-harmonic, periodic '
+ 'sounds (e.g., sine waves). But this is not observed.')
+ warnings.warn(message, DeprecationWarning)
+ return_periodicity = return_harmonicity
+
+ if not return_periodicity:
+ return pitch
+
+ # Compute periodicity from probabilities and decoded pitch bins
+ return pitch, periodicity(probabilities, bins)
+
+
+def preprocess(audio,
+ sample_rate,
+ hop_length=None,
+ batch_size=None,
+ device='cpu',
+ pad=True):
+ """Convert audio to model input
+
+ Arguments
+ audio (torch.tensor [shape=(1, time)])
+ The audio signals
+ sample_rate (int)
+ The sampling rate in Hz
+ hop_length (int)
+ The hop_length in samples
+ batch_size (int)
+ The number of frames per batch
+ device (string)
+ The device to run inference on
+ pad (bool)
+ Whether to zero-pad the audio
+
+ Returns
+ frames (torch.tensor [shape=(1 + int(time // hop_length), 1024)])
+ """
+ # Default hop length of 10 ms
+ hop_length = sample_rate // 100 if hop_length is None else hop_length
+
+ # Resample
+ if sample_rate != SAMPLE_RATE:
+ audio = resample(audio, sample_rate)
+ hop_length = int(hop_length * SAMPLE_RATE / sample_rate)
+
+    # Maybe pad and compute the total number of frames
+ if pad:
+ total_frames = 1 + int(audio.size(1) // hop_length)
+ audio = torch.nn.functional.pad(
+ audio,
+ (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
+ else:
+ total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
+
+ # Default to running all frames in a single batch
+ batch_size = total_frames if batch_size is None else batch_size
+
+ # Generate batches
+ for i in range(0, total_frames, batch_size):
+
+ # Batch indices
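+        # Each slice covers (batch_size - 1) hops plus one analysis window,
+        # i.e. enough audio for batch_size frames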
+ start = max(0, i * hop_length)
+ end = min(audio.size(1),
+ (i + batch_size - 1) * hop_length + WINDOW_SIZE)
+
+ # Chunk
+ frames = torch.nn.functional.unfold(
+ audio[:, None, None, start:end],
+ kernel_size=(1, WINDOW_SIZE),
+ stride=(1, hop_length))
+
+        # shape=(1 + int(time / hop_length), 1024)
+ frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE)
+
+ # Place on device
+ frames = frames.to(device)
+
+ # Mean-center
+ frames -= frames.mean(dim=1, keepdim=True)
+
+ # Scale
+ # Note: during silent frames, this produces very large values. But
+ # this seems to be what the network expects.
+ frames /= torch.max(torch.tensor(1e-10, device=frames.device),
+ frames.std(dim=1, keepdim=True))
+
+ yield frames
+
+
+###############################################################################
+# Utilities
+###############################################################################
+
+
+def periodicity(probabilities, bins):
+ """Computes the periodicity from the network output and pitch bins"""
+ # shape=(batch * time / hop_length, 360)
+ probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
+
+ # shape=(batch * time / hop_length, 1)
+ bins_stacked = bins.reshape(-1, 1).to(torch.int64)
+
+ # Use maximum logit over pitch bins as periodicity
+ periodicity = probs_stacked.gather(1, bins_stacked)
+
+ # shape=(batch, time / hop_length)
+ return periodicity.reshape(probabilities.size(0), probabilities.size(2))
+
+
+def resample(audio, sample_rate):
+ """Resample audio"""
+ # Store device for later placement
+ device = audio.device
+
+ # Convert to numpy
+ audio = audio.detach().cpu().numpy().squeeze(0)
+
+ # Resample
+ # We have to use resampy if we want numbers to match Crepe
+ audio = resampy.resample(audio, sample_rate, SAMPLE_RATE)
+
+ # Convert to pytorch
+ return torch.tensor(audio, device=device).unsqueeze(0)
diff --git a/crepe/decode.py b/crepe/decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..559e566b8e2c09fb7634c6ac9ce867731295901b
--- /dev/null
+++ b/crepe/decode.py
@@ -0,0 +1,80 @@
+import librosa
+import numpy as np
+import torch
+
+import crepe
+
+
+###############################################################################
+# Probability sequence decoding methods
+###############################################################################
+
+
+def argmax(logits):
+ """Sample observations by taking the argmax"""
+ bins = logits.argmax(dim=1)
+
+ # Convert to frequency in Hz
+ return bins, crepe.convert.bins_to_frequency(bins)
+
+
+def weighted_argmax(logits):
+ """Sample observations using weighted sum near the argmax"""
+ # Find center of analysis window
+ bins = logits.argmax(dim=1)
+
+ # Find bounds of analysis window
+ start = torch.max(torch.tensor(0, device=logits.device), bins - 4)
+ end = torch.min(torch.tensor(logits.size(1), device=logits.device), bins + 5)
+
+ # Mask out everything outside of window
+ for batch in range(logits.size(0)):
+ for time in range(logits.size(2)):
+ logits[batch, :start[batch, time], time] = -float('inf')
+ logits[batch, end[batch, time]:, time] = -float('inf')
+
+ # Construct weights
+ if not hasattr(weighted_argmax, 'weights'):
+ weights = crepe.convert.bins_to_cents(torch.arange(360))
+ weighted_argmax.weights = weights[None, :, None]
+
+ # Ensure devices are the same (no-op if they are)
+ weighted_argmax.weights = weighted_argmax.weights.to(logits.device)
+
+ # Convert to probabilities
+ with torch.no_grad():
+ probs = torch.sigmoid(logits)
+
+ # Apply weights
+ cents = (weighted_argmax.weights * probs).sum(dim=1) / probs.sum(dim=1)
+
+ # Convert to frequency in Hz
+ return bins, crepe.convert.cents_to_frequency(cents)
+
+
+def viterbi(logits):
+ """Sample observations using viterbi decoding"""
+ # Create viterbi transition matrix
+ if not hasattr(viterbi, 'transition'):
+ xx, yy = np.meshgrid(range(360), range(360))
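+        # Allow transitions only within +/- 11 bins; transition probability
+        # decreases linearly with bin distance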
+ transition = np.maximum(12 - abs(xx - yy), 0)
+ transition = transition / transition.sum(axis=1, keepdims=True)
+ viterbi.transition = transition
+
+ # Normalize logits
+ with torch.no_grad():
+ probs = torch.nn.functional.softmax(logits, dim=1)
+
+ # Convert to numpy
+ sequences = probs.cpu().numpy()
+
+ # Perform viterbi decoding
+ bins = np.array([
+ librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64)
+ for sequence in sequences])
+
+ # Convert to pytorch
+ bins = torch.tensor(bins, device=probs.device)
+
+ # Convert to frequency in Hz
+ return bins, crepe.convert.bins_to_frequency(bins)
diff --git a/crepe/filter.py b/crepe/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd62ef59c7e2c7dd0c2544ae17b5ef60d0b642f6
--- /dev/null
+++ b/crepe/filter.py
@@ -0,0 +1,195 @@
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+###############################################################################
+# Sequence filters
+###############################################################################
+
+
+def mean(signals, win_length=9):
+ """Averave filtering for signals containing nan values
+
+ Arguments
+ signals (torch.tensor (shape=(batch, time)))
+ The signals to filter
+ win_length
+ The size of the analysis window
+
+ Returns
+ filtered (torch.tensor (shape=(batch, time)))
+ """
+
+ assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"
+ signals = signals.unsqueeze(1)
+
+    # Zero out masked (NaN) elements before pooling
+ mask = ~torch.isnan(signals)
+ masked_x = torch.where(mask, signals, torch.zeros_like(signals))
+
+ # Create a ones kernel with the same number of channels as the input tensor
+ ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
+
+ # Perform sum pooling
+ sum_pooled = F.conv1d(
+ masked_x,
+ ones_kernel,
+ stride=1,
+ padding=win_length // 2,
+ )
+
+ # Count the non-masked (valid) elements in each pooling window
+ valid_count = F.conv1d(
+ mask.float(),
+ ones_kernel,
+ stride=1,
+ padding=win_length // 2,
+ )
+ valid_count = valid_count.clamp(min=1) # Avoid division by zero
+
+ # Perform masked average pooling
+ avg_pooled = sum_pooled / valid_count
+
+ # Fill zero values with NaNs
+ avg_pooled[avg_pooled == 0] = float("nan")
+
+ return avg_pooled.squeeze(1)
+
+
+def median(signals, win_length):
+ """Median filtering for signals containing nan values
+
+ Arguments
+ signals (torch.tensor (shape=(batch, time)))
+ The signals to filter
+ win_length
+ The size of the analysis window
+
+ Returns
+ filtered (torch.tensor (shape=(batch, time)))
+ """
+
+ assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"
+ signals = signals.unsqueeze(1)
+
+ mask = ~torch.isnan(signals)
+ masked_x = torch.where(mask, signals, torch.zeros_like(signals))
+ padding = win_length // 2
+
+ x = F.pad(masked_x, (padding, padding), mode="reflect")
+ mask = F.pad(mask.float(), (padding, padding), mode="constant", value=0)
+
+ x = x.unfold(2, win_length, 1)
+ mask = mask.unfold(2, win_length, 1)
+
+ x = x.contiguous().view(x.size()[:3] + (-1,))
+ mask = mask.contiguous().view(mask.size()[:3] + (-1,))
+
+ # Combine the mask with the input tensor
+ x_masked = torch.where(mask.bool(), x.double(), float("inf")).to(x)
+
+ # Sort the masked tensor along the last dimension
+ x_sorted, _ = torch.sort(x_masked, dim=-1)
+
+ # Compute the count of non-masked (valid) values
+ valid_count = mask.sum(dim=-1)
+
+ # Calculate the index of the median value for each pooling window
+ median_idx = ((valid_count - 1) // 2).clamp(min=0)
+
+ # Gather the median values using the calculated indices
+ median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1)
+
+ # Fill infinite values with NaNs
+ median_pooled[torch.isinf(median_pooled)] = float("nan")
+
+ return median_pooled.squeeze(1)
+
+
+###############################################################################
+# Utilities
+###############################################################################
+
+
+def nanfilter(signals, win_length, filter_fn):
+ """Filters a sequence, ignoring nan values
+
+ Arguments
+ signals (torch.tensor (shape=(batch, time)))
+ The signals to filter
+ win_length
+ The size of the analysis window
+ filter_fn (function)
+ The function to use for filtering
+
+ Returns
+ filtered (torch.tensor (shape=(batch, time)))
+ """
+ # Output buffer
+ filtered = torch.empty_like(signals)
+
+ # Loop over frames
+ for i in range(signals.size(1)):
+
+ # Get analysis window bounds
+ start = max(0, i - win_length // 2)
+ end = min(signals.size(1), i + win_length // 2 + 1)
+
+ # Apply filter to window
+ filtered[:, i] = filter_fn(signals[:, start:end])
+
+ return filtered
+
+
+def nanmean(signals):
+ """Computes the mean, ignoring nans
+
+ Arguments
+ signals (torch.tensor [shape=(batch, time)])
+ The signals to filter
+
+ Returns
+ filtered (torch.tensor [shape=(batch, time)])
+ """
+ signals = signals.clone()
+
+ # Find nans
+ nans = torch.isnan(signals)
+
+ # Set nans to 0.
+ signals[nans] = 0.
+
+ # Compute average
+ return signals.sum(dim=1) / (~nans).float().sum(dim=1)
+
+
+def nanmedian(signals):
+ """Computes the median, ignoring nans
+
+ Arguments
+ signals (torch.tensor [shape=(batch, time)])
+ The signals to filter
+
+ Returns
+ filtered (torch.tensor [shape=(batch, time)])
+ """
+ # Find nans
+ nans = torch.isnan(signals)
+
+ # Compute median for each slice
+ medians = [nanmedian1d(signal[~nan]) for signal, nan in zip(signals, nans)]
+
+ # Stack results
+ return torch.tensor(medians, dtype=signals.dtype, device=signals.device)
+
+
+def nanmedian1d(signal):
+ """Computes the median. If signal is empty, returns torch.nan
+
+ Arguments
+ signal (torch.tensor [shape=(time,)])
+
+ Returns
+ median (torch.tensor [shape=(1,)])
+ """
+ return torch.median(signal) if signal.numel() else np.nan
diff --git a/crepe/load.py b/crepe/load.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb5a3c355b31f0495721d6dcfc4fbc57927c4f91
--- /dev/null
+++ b/crepe/load.py
@@ -0,0 +1,36 @@
+import os
+
+import numpy as np
+import torch
+import crepe
+from scipy.io import wavfile
+
+
+def audio(filename):
+ """Load audio from disk"""
+ sample_rate, audio = wavfile.read(filename)
+
+ # Convert to float32
+ if audio.dtype == np.int16:
+ audio = audio.astype(np.float32) / np.iinfo(np.int16).max
+
+ # PyTorch is not compatible with non-writeable arrays, so we make a copy
+ return torch.tensor(np.copy(audio))[None], sample_rate
+
+
+def model(device, capacity='full'):
+ """Preloads model from disk"""
+ # Bind model and capacity
+ crepe.infer.capacity = capacity
+ crepe.infer.model = crepe.Crepe(capacity)
+
+ # Load weights
+ file = os.path.join(os.path.dirname(__file__), 'assets', f'{capacity}.pth')
+ crepe.infer.model.load_state_dict(
+ torch.load(file, map_location=device))
+
+ # Place on device
+ crepe.infer.model = crepe.infer.model.to(torch.device(device))
+
+ # Eval mode
+ crepe.infer.model.eval()
diff --git a/crepe/loudness.py b/crepe/loudness.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6f5c4a648b6adfa7c0a0c8988f4ae0bfd7b051d
--- /dev/null
+++ b/crepe/loudness.py
@@ -0,0 +1,78 @@
+import warnings
+
+import librosa
+import numpy as np
+import resampy
+import torch
+
+import crepe
+
+
+###############################################################################
+# Constants
+###############################################################################
+
+
+# Minimum decibel level
+MIN_DB = -100.
+
+# Reference decibel level
+REF_DB = 20.
+
+
+###############################################################################
+# A-weighted loudness
+###############################################################################
+
+
+def a_weighted(audio, sample_rate, hop_length=None, pad=True):
+ """Retrieve the per-frame loudness"""
+ # Save device
+ device = audio.device
+
+ # Default hop length of 10 ms
+ hop_length = sample_rate // 100 if hop_length is None else hop_length
+
+ # Convert to numpy
+ audio = audio.detach().cpu().numpy().squeeze(0)
+
+ # Resample
+ if sample_rate != crepe.SAMPLE_RATE:
+ audio = resampy.resample(audio, sample_rate, crepe.SAMPLE_RATE)
+ hop_length = int(hop_length * crepe.SAMPLE_RATE / sample_rate)
+
+ # Cache weights
+ if not hasattr(a_weighted, 'weights'):
+ a_weighted.weights = perceptual_weights()
+
+ # Take stft
+ stft = librosa.stft(audio,
+ n_fft=crepe.WINDOW_SIZE,
+ hop_length=hop_length,
+ win_length=crepe.WINDOW_SIZE,
+ center=pad,
+ pad_mode='constant')
+
+ # Compute magnitude on db scale
+ db = librosa.amplitude_to_db(np.abs(stft))
+
+ # Apply A-weighting
+ weighted = db + a_weighted.weights
+
+ # Threshold
+ weighted[weighted < MIN_DB] = MIN_DB
+
+ # Average over weighted frequencies
+ return torch.from_numpy(weighted.mean(axis=0)).float().to(device)[None]
+
+
+def perceptual_weights():
+ """A-weighted frequency-dependent perceptual loudness weights"""
+ frequencies = librosa.fft_frequencies(sr=crepe.SAMPLE_RATE,
+ n_fft=crepe.WINDOW_SIZE)
+
+ # A warning is raised for nearly inaudible frequencies, but it ends up
+ # defaulting to -100 db. That default is fine for our purposes.
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', RuntimeWarning)
+ return librosa.A_weighting(frequencies)[:, None] - REF_DB
diff --git a/crepe/model.py b/crepe/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1c1a5b687773211d77e89d096e0e0189014ac54
--- /dev/null
+++ b/crepe/model.py
@@ -0,0 +1,134 @@
+import functools
+
+import torch
+import torch.nn.functional as F
+
+import crepe
+
+
+###########################################################################
+# Model definition
+###########################################################################
+
+
+class Crepe(torch.nn.Module):
+ """Crepe model definition"""
+
+ def __init__(self, model='full'):
+ super().__init__()
+
+ # Model-specific layer parameters
+ if model == 'full':
+ in_channels = [1, 1024, 128, 128, 128, 256]
+ out_channels = [1024, 128, 128, 128, 256, 512]
+ self.in_features = 2048
+ elif model == 'tiny':
+ in_channels = [1, 128, 16, 16, 16, 32]
+ out_channels = [128, 16, 16, 16, 32, 64]
+ self.in_features = 256
+ else:
+ raise ValueError(f'Model {model} is not supported')
+
+ # Shared layer parameters
+ kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
+ strides = [(4, 1)] + 5 * [(1, 1)]
+
+ # Overload with eps and momentum conversion given by MMdnn
+ batch_norm_fn = functools.partial(torch.nn.BatchNorm2d,
+ eps=0.0010000000474974513,
+ momentum=0.0)
+
+ # Layer definitions
+ self.conv1 = torch.nn.Conv2d(
+ in_channels=in_channels[0],
+ out_channels=out_channels[0],
+ kernel_size=kernel_sizes[0],
+ stride=strides[0])
+ self.conv1_BN = batch_norm_fn(
+ num_features=out_channels[0])
+
+ self.conv2 = torch.nn.Conv2d(
+ in_channels=in_channels[1],
+ out_channels=out_channels[1],
+ kernel_size=kernel_sizes[1],
+ stride=strides[1])
+ self.conv2_BN = batch_norm_fn(
+ num_features=out_channels[1])
+
+ self.conv3 = torch.nn.Conv2d(
+ in_channels=in_channels[2],
+ out_channels=out_channels[2],
+ kernel_size=kernel_sizes[2],
+ stride=strides[2])
+ self.conv3_BN = batch_norm_fn(
+ num_features=out_channels[2])
+
+ self.conv4 = torch.nn.Conv2d(
+ in_channels=in_channels[3],
+ out_channels=out_channels[3],
+ kernel_size=kernel_sizes[3],
+ stride=strides[3])
+ self.conv4_BN = batch_norm_fn(
+ num_features=out_channels[3])
+
+ self.conv5 = torch.nn.Conv2d(
+ in_channels=in_channels[4],
+ out_channels=out_channels[4],
+ kernel_size=kernel_sizes[4],
+ stride=strides[4])
+ self.conv5_BN = batch_norm_fn(
+ num_features=out_channels[4])
+
+ self.conv6 = torch.nn.Conv2d(
+ in_channels=in_channels[5],
+ out_channels=out_channels[5],
+ kernel_size=kernel_sizes[5],
+ stride=strides[5])
+ self.conv6_BN = batch_norm_fn(
+ num_features=out_channels[5])
+
+ self.classifier = torch.nn.Linear(
+ in_features=self.in_features,
+ out_features=crepe.PITCH_BINS)
+
+ def forward(self, x, embed=False):
+ # Forward pass through first five layers
+ x = self.embed(x)
+
+ if embed:
+ return x
+
+ # Forward pass through layer six
+ x = self.layer(x, self.conv6, self.conv6_BN)
+
+ # shape=(batch, self.in_features)
+ x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)
+
+ # Compute logits
+ return torch.sigmoid(self.classifier(x))
+
+ ###########################################################################
+ # Forward pass utilities
+ ###########################################################################
+
+ def embed(self, x):
+ """Map input audio to pitch embedding"""
+ # shape=(batch, 1, 1024, 1)
+ x = x[:, None, :, None]
+
+ # Forward pass through first five layers
+ x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
+ x = self.layer(x, self.conv2, self.conv2_BN)
+ x = self.layer(x, self.conv3, self.conv3_BN)
+ x = self.layer(x, self.conv4, self.conv4_BN)
+ x = self.layer(x, self.conv5, self.conv5_BN)
+
+ return x
+
+ def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
+ """Forward pass through one layer"""
+ x = F.pad(x, padding)
+ x = conv(x)
+ x = F.relu(x)
+ x = batch_norm(x)
+ return F.max_pool2d(x, (2, 1), (2, 1))
diff --git a/crepe/threshold.py b/crepe/threshold.py
new file mode 100644
index 0000000000000000000000000000000000000000..85d6ec9bef2d03b0eb101c6b7fa4f3464cdf1554
--- /dev/null
+++ b/crepe/threshold.py
@@ -0,0 +1,134 @@
+import numpy as np
+import torch
+
+import crepe
+
+
+###############################################################################
+# Pitch thresholding methods
+###############################################################################
+
+
+class At:
+ """Simple thresholding at a specified probability value"""
+
+ def __init__(self, value):
+ self.value = value
+
+ def __call__(self, pitch, periodicity):
+ # Make a copy to prevent in-place modification
+ pitch = torch.clone(pitch)
+
+ # Threshold
+ pitch[periodicity < self.value] = crepe.UNVOICED
+ return pitch
+
+
+class Hysteresis:
+ """Hysteresis thresholding"""
+
+ def __init__(self,
+ lower_bound=.19,
+ upper_bound=.31,
+ width=.2,
+ stds=1.7,
+ return_threshold=False):
+ self.lower_bound = lower_bound
+ self.upper_bound = upper_bound
+ self.width = width
+ self.stds = stds
+ self.return_threshold = return_threshold
+
+ def __call__(self, pitch, periodicity):
+ # Save output device
+ device = pitch.device
+
+ # Perform hysteresis in log-2 space
+ pitch = torch.log2(pitch).detach().flatten().cpu().numpy()
+
+ # Flatten periodicity
+ periodicity = periodicity.flatten().cpu().numpy()
+
+ # Ignore confidently unvoiced pitch
+ pitch[periodicity < self.lower_bound] = crepe.UNVOICED
+
+ # Whiten pitch
+ mean, std = np.nanmean(pitch), np.nanstd(pitch)
+ pitch = (pitch - mean) / std
+
+ # Require high confidence to make predictions far from the mean
+ parabola = self.width * pitch ** 2 - self.width * self.stds ** 2
+ threshold = \
+ self.lower_bound + np.clip(parabola, 0, 1 - self.lower_bound)
+ threshold[np.isnan(threshold)] = self.lower_bound
+
+ # Apply hysteresis to prevent short, unconfident voiced regions
+ i = 0
+ while i < len(periodicity) - 1:
+
+ # Detect unvoiced to voiced transition
+ if periodicity[i] < threshold[i] and \
+ periodicity[i + 1] > threshold[i + 1]:
+
+ # Grow region until next unvoiced or end of array
+ start, end, keep = i + 1, i + 1, False
+ while end < len(periodicity) and \
+ periodicity[end] > threshold[end]:
+ if periodicity[end] > self.upper_bound:
+ keep = True
+ end += 1
+
+ # Force unvoiced if we didn't pass the confidence required by
+ # the hysteresis
+ if not keep:
+ threshold[start:end] = 1
+
+ i = end
+
+ else:
+ i += 1
+
+ # Remove pitch with low periodicity
+ pitch[periodicity < threshold] = crepe.UNVOICED
+
+ # Unwhiten
+ pitch = pitch * std + mean
+
+ # Convert to Hz
+ pitch = torch.tensor(2 ** pitch, device=device)[None, :]
+
+ # Optionally return threshold
+ if self.return_threshold:
+ return pitch, torch.tensor(threshold, device=device)
+
+ return pitch
+
+
+###############################################################################
+# Periodicity thresholding methods
+###############################################################################
+
+
+class Silence:
+ """Set periodicity to zero in silent regions"""
+
+ def __init__(self, value=-60):
+ self.value = value
+
+ def __call__(self,
+ periodicity,
+ audio,
+ sample_rate=crepe.SAMPLE_RATE,
+ hop_length=None,
+ pad=True):
+ # Don't modify in-place
+ periodicity = torch.clone(periodicity)
+
+ # Compute loudness
+ loudness = crepe.loudness.a_weighted(
+ audio, sample_rate, hop_length, pad)
+
+ # Threshold silence
+ periodicity[loudness < self.value] = 0.
+
+ return periodicity
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6986a80645a34b56943e444e730047d256317745
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,201 @@
+name: so-vits-svc-5.0
+channels:
+ - pytorch
+ - anaconda
+ - nvidia
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - blas=1.0=mkl
+ - brotlipy=0.7.0=py311h5eee18b_1002
+ - bzip2=1.0.8=h7b6447c_0
+ - ca-certificates=2023.08.22=h06a4308_0
+ - certifi=2023.7.22=py311h06a4308_0
+ - cffi=1.15.1=py311h5eee18b_3
+ - cryptography=41.0.3=py311hdda0065_0
+ - cuda-cudart=11.7.99=0
+ - cuda-cupti=11.7.101=0
+ - cuda-libraries=11.7.1=0
+ - cuda-nvrtc=11.7.99=0
+ - cuda-nvtx=11.7.91=0
+ - cuda-runtime=11.7.1=0
+ - cudatoolkit=11.3.1=h2bc3f7f_2
+ - ffmpeg=4.3=hf484d3e_0
+ - filelock=3.9.0=py311h06a4308_0
+ - freetype=2.12.1=h4a9f257_0
+ - giflib=5.2.1=h5eee18b_3
+ - gmp=6.2.1=h295c915_3
+ - gmpy2=2.1.2=py311hc9b5ff0_0
+ - gnutls=3.6.15=he1e5248_0
+ - idna=3.4=py311h06a4308_0
+ - intel-openmp=2023.1.0=hdb19cb5_46305
+ - jinja2=3.1.2=py311h06a4308_0
+ - jpeg=9e=h5eee18b_1
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - lerc=3.0=h295c915_0
+ - libcublas=11.10.3.66=0
+ - libcufft=10.7.2.124=h4fbf590_0
+ - libcufile=1.7.2.10=0
+ - libcurand=10.3.3.141=0
+ - libcusolver=11.4.0.1=0
+ - libcusparse=11.7.4.91=0
+ - libdeflate=1.17=h5eee18b_0
+ - libffi=3.4.4=h6a678d5_0
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libiconv=1.16=h7f8727e_2
+ - libidn2=2.3.4=h5eee18b_0
+ - libnpp=11.7.4.75=0
+ - libnvjpeg=11.8.0.2=0
+ - libpng=1.6.39=h5eee18b_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtasn1=4.19.0=h5eee18b_0
+ - libtiff=4.5.1=h6a678d5_0
+ - libunistring=0.9.10=h27cfd23_0
+ - libuuid=1.41.5=h5eee18b_0
+ - libwebp=1.2.4=h11a3e52_1
+ - libwebp-base=1.2.4=h5eee18b_1
+ - lz4-c=1.9.4=h6a678d5_0
+ - markupsafe=2.1.1=py311h5eee18b_0
+ - mkl=2023.1.0=h213fc3f_46343
+ - mkl-service=2.4.0=py311h5eee18b_1
+ - mkl_fft=1.3.6=py311ha02d727_1
+ - mkl_random=1.2.2=py311ha02d727_1
+ - mpc=1.1.0=h10f8cd9_1
+ - mpfr=4.0.2=hb69a4c5_1
+ - mpmath=1.3.0=py311h06a4308_0
+ - ncurses=6.4=h6a678d5_0
+ - nettle=3.7.3=hbbd107a_1
+ - networkx=3.1=py311h06a4308_0
+ - numpy-base=1.25.2=py311hf175353_0
+ - openh264=2.1.1=h4ff587b_0
+ - openssl=3.0.10=h7f8727e_2
+ - pip=23.2.1=py311h06a4308_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pyopenssl=23.2.0=py311h06a4308_0
+ - pysocks=1.7.1=py311h06a4308_0
+ - python=3.11.5=h955ad1f_0
+ - pytorch=2.0.1=py3.11_cuda11.7_cudnn8.5.0_0
+ - pytorch-cuda=11.7=h778d358_5
+ - pytorch-mutex=1.0=cuda
+ - readline=8.2=h5eee18b_0
+ - requests=2.31.0=py311h06a4308_0
+ - setuptools=68.0.0=py311h06a4308_0
+ - sqlite=3.41.2=h5eee18b_0
+ - sympy=1.11.1=py311h06a4308_0
+ - tbb=2021.8.0=hdb19cb5_0
+ - tk=8.6.12=h1ccaba5_0
+ - torchaudio=2.0.2=py311_cu117
+ - torchtriton=2.0.0=py311
+ - torchvision=0.15.2=py311_cu117
+ - typing_extensions=4.7.1=py311h06a4308_0
+ - urllib3=1.26.16=py311h06a4308_0
+ - wheel=0.38.4=py311h06a4308_0
+ - xz=5.4.2=h5eee18b_0
+ - zlib=1.2.13=h5eee18b_0
+ - zstd=1.5.5=hc292b87_0
+ - pip:
+ - absl-py==1.4.0
+ - aiofiles==23.2.1
+ - aiohttp==3.8.5
+ - aiosignal==1.3.1
+ - altair==5.1.1
+ - annotated-types==0.5.0
+ - antlr4-python3-runtime==4.9.3
+ - anyio==3.7.1
+ - async-timeout==4.0.3
+ - attrs==23.1.0
+ - audioread==3.0.0
+ - cachetools==5.3.1
+ - chardet==5.2.0
+ - charset-normalizer==3.2.0
+ - click==8.1.7
+ - contourpy==1.1.0
+ - cycler==0.11.0
+ - cython==3.0.2
+ - decorator==5.1.1
+ - fastapi==0.103.1
+ - ffmpy==0.3.1
+ - fonttools==4.42.1
+ - frozenlist==1.4.0
+ - fsspec==2023.9.0
+ - google-auth==2.23.0
+ - google-auth-oauthlib==1.0.0
+ - gradio==3.36.1
+ - gradio-client==0.5.0
+ - grpcio==1.58.0
+ - h11==0.14.0
+ - httpcore==0.18.0
+ - httpx==0.25.0
+ - huggingface-hub==0.17.1
+ - joblib==1.3.2
+ - jsonschema==4.19.0
+ - jsonschema-specifications==2023.7.1
+ - kiwisolver==1.4.5
+ - lazy-loader==0.3
+ - librosa==0.10.1
+ - linkify-it-py==2.0.2
+ - llvmlite==0.40.1
+ - markdown==3.4.4
+ - markdown-it-py==2.2.0
+ - matplotlib==3.7.3
+ - mdit-py-plugins==0.3.3
+ - mdurl==0.1.2
+ - msgpack==1.0.5
+ - multidict==6.0.4
+ - numba==0.57.1
+ - numpy==1.24.0
+ - oauthlib==3.2.2
+ - omegaconf==2.3.0
+ - orjson==3.9.7
+ - packaging==23.1
+ - pandas==2.1.0
+ - pillow==10.0.0
+ - platformdirs==3.10.0
+ - pooch==1.7.0
+ - protobuf==4.24.3
+ - pyasn1==0.5.0
+ - pyasn1-modules==0.3.0
+ - pydantic==2.3.0
+ - pydantic-core==2.6.3
+ - pydub==0.25.1
+ - pygments==2.16.1
+ - pyparsing==3.1.1
+ - python-dateutil==2.8.2
+ - python-multipart==0.0.6
+ - pytz==2023.3.post1
+ - pyworld==0.3.4
+ - pyyaml==6.0.1
+ - referencing==0.30.2
+ - regex==2023.8.8
+ - requests-oauthlib==1.3.1
+ - resampy==0.4.2
+ - rpds-py==0.10.2
+ - rsa==4.9
+ - ruamel-yaml==0.17.32
+ - ruamel-yaml-clib==0.2.7
+ - safetensors==0.3.3
+ - scikit-learn==1.3.0
+ - scipy==1.11.2
+ - semantic-version==2.10.0
+ - six==1.16.0
+ - sniffio==1.3.0
+ - soundfile==0.12.1
+ - soxr==0.3.6
+ - starlette==0.27.0
+ - tensorboard==2.14.0
+ - tensorboard-data-server==0.7.1
+ - threadpoolctl==3.2.0
+ - tokenizers==0.13.3
+ - toolz==0.12.0
+ - tqdm==4.66.1
+ - transformers==4.33.1
+ - tzdata==2023.3
+ - uc-micro-py==1.0.2
+ - uvicorn==0.23.2
+ - websockets==11.0.3
+ - werkzeug==2.3.7
+ - yarl==1.9.2
diff --git a/feature_retrieval/__init__.py b/feature_retrieval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..80bd7dbf8d83e7c6b7659ec3bb5b896e9e251888
--- /dev/null
+++ b/feature_retrieval/__init__.py
@@ -0,0 +1,4 @@
+from .index import *
+from .train import *
+from .transform import *
+from .retrieval import *
diff --git a/feature_retrieval/index.py b/feature_retrieval/index.py
new file mode 100644
index 0000000000000000000000000000000000000000..96c770196c6c4cb2437dd53ec0f4194c60f8f112
--- /dev/null
+++ b/feature_retrieval/index.py
@@ -0,0 +1,166 @@
+import abc
+import logging
+import math
+import time
+from pathlib import Path
+from typing import TypeVar, Generic, cast, Any
+
+import numpy as np
+import numpy.typing as npt
+
+from tqdm import tqdm
+
+import faiss
+from faiss import IndexIVF, Index
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=Index)
+NumpyArray = npt.NDArray[np.float32]
+
+
+class FaissFeatureIndex(Generic[T], abc.ABC):
+ def __init__(self, index: T) -> None:
+ self._index = index
+
+ def save(self, filepath: Path, rewrite: bool = False) -> None:
+ if filepath.exists() and not rewrite:
+ raise FileExistsError(f"index already exists by path {filepath}")
+ faiss.write_index(self._index, str(filepath))
+
+
+class FaissRetrievableFeatureIndex(FaissFeatureIndex[Index], abc.ABC):
+ """retrieve voice feature vectors by faiss index"""
+
+ def __init__(self, index: T, ratio: float, n_nearest_vectors: int) -> None:
+ super().__init__(index=index)
+ if index.metric_type != self.supported_distance:
+ raise ValueError(f"index metric type {index.metric_type=} is unsupported {self.supported_distance=}")
+
+        if n_nearest_vectors < 1:
+            raise ValueError("n-retrieval-vectors must be at least 1")
+        self._n_nearest = n_nearest_vectors
+
+        if ratio < 0 or ratio > 1:
+            raise ValueError(f"{ratio=} must be in range [0, 1]")
+ self._ratio = ratio
+
+ @property
+ @abc.abstractmethod
+ def supported_distance(self) -> Any:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def _weight_nearest_vectors(self, nearest_vectors: NumpyArray, scores: NumpyArray) -> NumpyArray:
+ raise NotImplementedError
+
+ def retriv(self, features: NumpyArray) -> NumpyArray:
+ # use method search_and_reconstruct instead of recreating the whole matrix
+ scores, _, nearest_vectors = self._index.search_and_reconstruct(features, k=self._n_nearest)
+ weighted_nearest_vectors = self._weight_nearest_vectors(nearest_vectors, scores)
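+        # Blend the original features with the weighted average of the retrieved
+        # neighbours; self._ratio controls the mix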
+ retriv_vector = (1 - self._ratio) * features + self._ratio * weighted_nearest_vectors
+ return retriv_vector
+
+
+class FaissRVCRetrievableFeatureIndex(FaissRetrievableFeatureIndex):
+ """
+    retrieve encoded voice features with the algorithm from the RVC repository
+ https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI
+ """
+
+ @property
+ def supported_distance(self) -> Any:
+ return faiss.METRIC_L2
+
+ def _weight_nearest_vectors(self, nearest_vectors: NumpyArray, scores: NumpyArray) -> NumpyArray:
+ """
+ magic code from original RVC
+ https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/86ed98aacaa8b2037aad795abd11cdca122cf39f/vc_infer_pipeline.py#L213C18-L213C19
+
+        nearest_vectors dim (num_vectors, n_nearest, vector_dim)
+ scores dim (num_vectors, n_nearest)
+ """
+ logger.debug("shape: nv=%s sc=%s", nearest_vectors.shape, scores.shape)
+ weight = np.square(1 / scores)
+ weight /= weight.sum(axis=1, keepdims=True)
+ weight = np.expand_dims(weight, axis=2)
+ weighted_nearest_vectors = np.sum(nearest_vectors * weight, axis=1)
+ logger.debug(
+ "shape: nv=%s weight=%s weight_nearest=%s",
+ nearest_vectors.shape,
+ weight.shape,
+ weighted_nearest_vectors.shape,
+ )
+ return cast(NumpyArray, weighted_nearest_vectors)
+
+
+class FaissIVFTrainableFeatureIndex(FaissFeatureIndex[IndexIVF]):
+ """IVF faiss index that can train and add feature vectors"""
+
+ def __init__(self, index: IndexIVF, batch_size: int) -> None:
+ super().__init__(index=index)
+ self._batch_size = batch_size
+
+ @property
+ def _trained_index(self) -> IndexIVF:
+ if not self._index.is_trained:
+ raise RuntimeError("index needs to be trained first")
+ return self._index
+
+ @property
+ def _not_trained_index(self) -> IndexIVF:
+ if self._index.is_trained:
+ raise RuntimeError("index is already trained")
+ return self._index
+
+ def _batch_count(self, feature_matrix: NumpyArray) -> int:
+ return math.ceil(feature_matrix.shape[0] / self._batch_size)
+
+ def _split_matrix_by_batch(self, feature_matrix: NumpyArray) -> list[NumpyArray]:
+ return np.array_split(feature_matrix, indices_or_sections=self._batch_count(feature_matrix), axis=0)
+
+ def _train_index(self, train_feature_matrix: NumpyArray) -> None:
+ start = time.monotonic()
+ self._not_trained_index.train(train_feature_matrix)
+ took = time.monotonic() - start
+ logger.info("index is trained. Took %.2f seconds", took)
+
+ def add_to_index(self, feature_matrix: NumpyArray) -> None:
+ n_batches = self._batch_count(feature_matrix)
+ logger.info("adding %s batches to index", n_batches)
+ start = time.monotonic()
+ for batch in tqdm(self._split_matrix_by_batch(feature_matrix), total=n_batches):
+ self._trained_index.add(batch)
+ took = time.monotonic() - start
+ logger.info("all batches added. Took %.2f seconds", took)
+
+ def add_with_train(self, feature_matrix: NumpyArray) -> None:
+ self._train_index(feature_matrix)
+ self.add_to_index(feature_matrix)
+
+
+class FaissIVFFlatTrainableFeatureIndexBuilder:
+ def __init__(self, batch_size: int, distance: int) -> None:
+ self._batch_size = batch_size
+ self._distance = distance
+
+ def _build_index(self, num_vectors: int, vector_dim: int) -> IndexIVF:
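+        # Heuristic number of IVF lists: grows with sqrt(num_vectors),
+        # capped so that each list keeps at least ~39 training vectors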
+ n_ivf = min(int(16 * np.sqrt(num_vectors)), num_vectors // 39)
+ factory_string = f"IVF{n_ivf},Flat"
+ index = faiss.index_factory(vector_dim, factory_string, self._distance)
+ logger.debug('faiss index built by string "%s" and dimension %s', factory_string, vector_dim)
+ index_ivf = faiss.extract_index_ivf(index)
+ index_ivf.nprobe = 1
+ return index
+
+ def build(self, num_vectors: int, vector_dim: int) -> FaissIVFTrainableFeatureIndex:
+ return FaissIVFTrainableFeatureIndex(
+ index=self._build_index(num_vectors, vector_dim),
+ batch_size=self._batch_size,
+ )
+
+
+def load_retrieve_index(filepath: Path, ratio: float, n_nearest_vectors: int) -> FaissRetrievableFeatureIndex:
+ return FaissRVCRetrievableFeatureIndex(
+ index=faiss.read_index(str(filepath)), ratio=ratio, n_nearest_vectors=n_nearest_vectors
+ )
diff --git a/feature_retrieval/retrieval.py b/feature_retrieval/retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2145ab20f1b07c3f0cde8c3f591b8571253c9b6
--- /dev/null
+++ b/feature_retrieval/retrieval.py
@@ -0,0 +1,44 @@
+import abc
+import logging
+
+import torch
+
+from feature_retrieval import FaissRetrievableFeatureIndex
+
+logger = logging.getLogger(__name__)
+
+
+class IRetrieval(abc.ABC):
+ @abc.abstractmethod
+ def retriv_whisper(self, vec: torch.Tensor) -> torch.Tensor:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def retriv_hubert(self, vec: torch.Tensor) -> torch.Tensor:
+ raise NotImplementedError
+
+
+class DummyRetrieval(IRetrieval):
+ def retriv_whisper(self, vec: torch.FloatTensor) -> torch.FloatTensor:
+ logger.debug("start dummy retriv whisper")
+ return vec.clone().to(torch.device("cpu"))
+
+ def retriv_hubert(self, vec: torch.FloatTensor) -> torch.FloatTensor:
+ logger.debug("start dummy retriv hubert")
+ return vec.clone().to(torch.device("cpu"))
+
+
+class FaissIndexRetrieval(IRetrieval):
+ def __init__(self, hubert_index: FaissRetrievableFeatureIndex, whisper_index: FaissRetrievableFeatureIndex) -> None:
+ self._hubert_index = hubert_index
+ self._whisper_index = whisper_index
+
+ def retriv_whisper(self, vec: torch.Tensor) -> torch.Tensor:
+ logger.debug("start retriv whisper")
+ np_vec = self._whisper_index.retriv(vec.numpy())
+ return torch.from_numpy(np_vec)
+
+ def retriv_hubert(self, vec: torch.Tensor) -> torch.Tensor:
+ logger.debug("start retriv hubert")
+ np_vec = self._hubert_index.retriv(vec.numpy())
+ return torch.from_numpy(np_vec)
diff --git a/feature_retrieval/train.py b/feature_retrieval/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a565804afe32175b2e5a83f747447773f5f56b4
--- /dev/null
+++ b/feature_retrieval/train.py
@@ -0,0 +1,37 @@
+from pathlib import Path
+from typing import cast
+
+import numpy as np
+
+from feature_retrieval import NumpyArray
+from feature_retrieval.index import FaissIVFFlatTrainableFeatureIndexBuilder, logger
+from feature_retrieval.transform import IFeatureMatrixTransform
+
+
+def train_index(
+ features_path: Path,
+ index_save_filepath: Path,
+ index_builder: FaissIVFFlatTrainableFeatureIndexBuilder,
+ feature_transform: IFeatureMatrixTransform,
+) -> None:
+ logger.info("start getting feature vectors from %s", features_path.absolute())
+ feature_matrix = get_feature_matrix(features_path)
+ logger.debug("fetched %s features", feature_matrix.shape[0])
+
+ logger.info("apply transform to feature matrix")
+ feature_matrix = feature_transform.transform(feature_matrix)
+ num_vectors, vector_dim = feature_matrix.shape
+ logger.debug("features transformed. Current features %s", num_vectors)
+
+ feature_index = index_builder.build(num_vectors=num_vectors, vector_dim=vector_dim)
+ logger.info("adding features to index with training")
+
+ feature_index.add_with_train(feature_matrix)
+ feature_index.save(index_save_filepath)
+ logger.info("index saved to %s", index_save_filepath.absolute())
+
+
+def get_feature_matrix(features_dir_path: Path) -> NumpyArray:
+ matrices = [np.load(str(features_path)) for features_path in features_dir_path.rglob("*.npy")]
+ feature_matrix = np.concatenate(matrices, axis=0)
+ return cast(NumpyArray, feature_matrix)
diff --git a/feature_retrieval/transform.py b/feature_retrieval/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c4ecf4e817ef17631fddb065e3fd0742aad4b44
--- /dev/null
+++ b/feature_retrieval/transform.py
@@ -0,0 +1,72 @@
+import abc
+import logging
+from typing import cast, Callable
+
+from sklearn.cluster import MiniBatchKMeans
+
+from feature_retrieval.index import NumpyArray
+
+
+logger = logging.getLogger(__name__)
+
+
+class IFeatureMatrixTransform:
+ """Interface for transform encoded voice feature from (n_features,vector_dim) to (m_features,vector_dim)"""
+
+ @abc.abstractmethod
+ def transform(self, matrix: NumpyArray) -> NumpyArray:
+ """transform given feature matrix from (n_features,vector_dim) to (m_features,vector_dim)"""
+ raise NotImplementedError
+
+
+class DummyFeatureTransform(IFeatureMatrixTransform):
+ """do nothing"""
+
+ def transform(self, matrix: NumpyArray) -> NumpyArray:
+ return matrix
+
+
+class MinibatchKmeansFeatureTransform(IFeatureMatrixTransform):
+ """replaces number of examples with k-means centroids using minibatch algorythm"""
+
+ def __init__(self, n_clusters: int, n_parallel: int) -> None:
+ self._n_clusters = n_clusters
+ self._n_parallel = n_parallel
+
+ @property
+ def _batch_size(self) -> int:
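+        # 256 samples per requested parallel worker for the MiniBatchKMeans mini-batches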
+ return self._n_parallel * 256
+
+ def transform(self, matrix: NumpyArray) -> NumpyArray:
+ """transform given feature matrix from (n_features,vector_dim) to (n_clusters,vector_dim)"""
+ cluster = MiniBatchKMeans(
+ n_clusters=self._n_clusters,
+ verbose=True,
+ batch_size=self._batch_size,
+ compute_labels=False,
+ init="k-means++",
+ )
+ return cast(NumpyArray, cluster.fit(matrix).cluster_centers_)
+
+
+class OnConditionFeatureTransform(IFeatureMatrixTransform):
+ """call given transform if condition is True else call otherwise transform"""
+
+ def __init__(
+ self,
+ condition: Callable[[NumpyArray], bool],
+ on_condition: IFeatureMatrixTransform,
+ otherwise: IFeatureMatrixTransform,
+ ) -> None:
+ self._condition = condition
+ self._on_condition = on_condition
+ self._otherwise = otherwise
+
+ def transform(self, matrix: NumpyArray) -> NumpyArray:
+ if self._condition(matrix):
+            transform_name = self._on_condition.__class__.__name__
+            logger.info("condition passed; transforming with %s", transform_name)
+            return self._on_condition.transform(matrix)
+        transform_name = self._otherwise.__class__.__name__
+        logger.info("condition not passed; transforming with %s", transform_name)
+ return self._otherwise.transform(matrix)
diff --git a/hubert/LICENSE.txt b/hubert/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6eb2af050447968cc32481fcfe67b5a4c6cdc69e
--- /dev/null
+++ b/hubert/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Benjamin van Niekerk
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/hubert/__init__.py b/hubert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hubert/hubert_model.py b/hubert/hubert_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fb642d89b07ca60792debab18e3454f52d8f357
--- /dev/null
+++ b/hubert/hubert_model.py
@@ -0,0 +1,222 @@
+import copy
+import random
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as t_func
+from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+
+
+class Hubert(nn.Module):
+ def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
+ super().__init__()
+ self._mask = mask
+ self.feature_extractor = FeatureExtractor()
+ self.feature_projection = FeatureProjection()
+ self.positional_embedding = PositionalConvEmbedding()
+ self.norm = nn.LayerNorm(768)
+ self.dropout = nn.Dropout(0.1)
+ self.encoder = TransformerEncoder(
+ nn.TransformerEncoderLayer(
+ 768, 12, 3072, activation="gelu", batch_first=True
+ ),
+ 12,
+ )
+ self.proj = nn.Linear(768, 256)
+
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
+ self.label_embedding = nn.Embedding(num_label_embeddings, 256)
+
+ def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ mask = None
+ if self.training and self._mask:
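+            # SpecAugment-style span masking: mask probability 0.8,
+            # span length 10 frames, at least 2 spans per utterance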
+ mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
+ x[mask] = self.masked_spec_embed.to(x.dtype)
+ return x, mask
+
+ def encode(
+ self, x: torch.Tensor, layer: Optional[int] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ x = self.feature_extractor(x)
+ x = self.feature_projection(x.transpose(1, 2))
+ x, mask = self.mask(x)
+ x = x + self.positional_embedding(x)
+ x = self.dropout(self.norm(x))
+ x = self.encoder(x, output_layer=layer)
+ return x, mask
+
+ def logits(self, x: torch.Tensor) -> torch.Tensor:
+ logits = torch.cosine_similarity(
+ x.unsqueeze(2),
+ self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
+ dim=-1,
+ )
+ return logits / 0.1
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ x, mask = self.encode(x)
+ x = self.proj(x)
+ logits = self.logits(x)
+ return logits, mask
+
+
+class HubertSoft(Hubert):
+ def __init__(self):
+ super().__init__()
+
+ @torch.inference_mode()
+ def units(self, wav: torch.Tensor) -> torch.Tensor:
+ wav = t_func.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
+ x, _ = self.encode(wav)
+ return self.proj(x)
+
+
+class FeatureExtractor(nn.Module):
+ def __init__(self):
+ super().__init__()
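+        # Seven strided convolutions downsample the 16 kHz waveform by 320x (20 ms per frame)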
+ self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
+ self.norm0 = nn.GroupNorm(512, 512)
+ self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
+ self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
+ self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
+ self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
+ self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
+ self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = t_func.gelu(self.norm0(self.conv0(x)))
+ x = t_func.gelu(self.conv1(x))
+ x = t_func.gelu(self.conv2(x))
+ x = t_func.gelu(self.conv3(x))
+ x = t_func.gelu(self.conv4(x))
+ x = t_func.gelu(self.conv5(x))
+ x = t_func.gelu(self.conv6(x))
+ return x
+
+
+class FeatureProjection(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.norm = nn.LayerNorm(512)
+ self.projection = nn.Linear(512, 768)
+ self.dropout = nn.Dropout(0.1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x)
+ x = self.projection(x)
+ x = self.dropout(x)
+ return x
+
+
+class PositionalConvEmbedding(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ 768,
+ 768,
+ kernel_size=128,
+ padding=128 // 2,
+ groups=16,
+ )
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv(x.transpose(1, 2))
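+        # The even kernel (128) with padding 64 yields one extra frame; drop it before GELU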
+ x = t_func.gelu(x[:, :, :-1])
+ return x.transpose(1, 2)
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
+ ) -> None:
+ super(TransformerEncoder, self).__init__()
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ def forward(
+ self,
+ src: torch.Tensor,
+ mask: torch.Tensor = None,
+ src_key_padding_mask: torch.Tensor = None,
+ output_layer: Optional[int] = None,
+ ) -> torch.Tensor:
+ output = src
+ for layer in self.layers[:output_layer]:
+ output = layer(
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
+ )
+ return output
+
+
+def _compute_mask(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ device: torch.device,
+ min_masks: int = 0,
+) -> torch.Tensor:
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
+ )
+
+ # compute number of masked spans in batch
+ num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
+ num_masked_spans = max(num_masked_spans, min_masks)
+
+ # make sure num masked indices <= sequence_length
+ if num_masked_spans * mask_length > sequence_length:
+ num_masked_spans = sequence_length // mask_length
+
+ # SpecAugment mask to fill
+ mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
+
+ # uniform distribution to sample from, make sure that offset samples are < sequence_length
+ uniform_dist = torch.ones(
+ (batch_size, sequence_length - (mask_length - 1)), device=device
+ )
+
+ # get random indices to mask
+ mask_indices = torch.multinomial(uniform_dist, num_masked_spans)
+
+ # expand masked indices to masked spans
+ mask_indices = (
+ mask_indices.unsqueeze(dim=-1)
+ .expand((batch_size, num_masked_spans, mask_length))
+ .reshape(batch_size, num_masked_spans * mask_length)
+ )
+ offsets = (
+ torch.arange(mask_length, device=device)[None, None, :]
+ .expand((batch_size, num_masked_spans, mask_length))
+ .reshape(batch_size, num_masked_spans * mask_length)
+ )
+ mask_idxs = mask_indices + offsets
+
+ # scatter indices to mask
+ mask = mask.scatter(1, mask_idxs, True)
+
+ return mask
+
+
+def hubert_soft(
+ path: str,
+) -> HubertSoft:
+ r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
+ Args:
+ path (str): path of a pretrained model
+ """
+ hubert = HubertSoft()
+ checkpoint = torch.load(path)
+ consume_prefix_in_state_dict_if_present(checkpoint, "module.")
+ hubert.load_state_dict(checkpoint)
+ hubert.eval()
+ return hubert
diff --git a/hubert/inference.py b/hubert/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac4fbeec8cb636569aef7aeb938ef709b115ab4a
--- /dev/null
+++ b/hubert/inference.py
@@ -0,0 +1,67 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import numpy as np
+import argparse
+import torch
+import librosa
+
+from hubert import hubert_model
+
+
+def load_audio(file: str, sr: int = 16000):
+ x, sr = librosa.load(file, sr=sr)
+ return x
+
+
+def load_model(path, device):
+ model = hubert_model.hubert_soft(path)
+ model.eval()
+ if not (device == "cpu"):
+ model.half()
+ model.to(device)
+ return model
+
+
+def pred_vec(model, wavPath, vecPath, device):
+ audio = load_audio(wavPath)
+ audln = audio.shape[0]
+ vec_a = []
+ idx_s = 0
+ while (idx_s + 20 * 16000 < audln):
+ feats = audio[idx_s:idx_s + 20 * 16000]
+ feats = torch.from_numpy(feats).to(device)
+ feats = feats[None, None, :]
+ if not (device == "cpu"):
+ feats = feats.half()
+ with torch.no_grad():
+ vec = model.units(feats).squeeze().data.cpu().float().numpy()
+ vec_a.extend(vec)
+ idx_s = idx_s + 20 * 16000
+ if (idx_s < audln):
+ feats = audio[idx_s:audln]
+ feats = torch.from_numpy(feats).to(device)
+ feats = feats[None, None, :]
+ if not (device == "cpu"):
+ feats = feats.half()
+ with torch.no_grad():
+ vec = model.units(feats).squeeze().data.cpu().float().numpy()
+ # print(vec.shape) # [length, dim=256] hop=320
+ vec_a.extend(vec)
+ np.save(vecPath, vec_a, allow_pickle=False)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
+ parser.add_argument("-v", "--vec", help="vec", dest="vec", required=True)
+ args = parser.parse_args()
+ print(args.wav)
+ print(args.vec)
+
+ wavPath = args.wav
+ vecPath = args.vec
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ hubert = load_model(os.path.join(
+ "hubert_pretrain", "hubert-soft-0d54a1f4.pt"), device)
+ pred_vec(hubert, wavPath, vecPath, device)
diff --git a/hubert_pretrain/.DS_Store b/hubert_pretrain/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/hubert_pretrain/.DS_Store differ
diff --git a/hubert_pretrain/README.md b/hubert_pretrain/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dbecfeb8bdf2b4120f3331252dd123ffba46e30c
--- /dev/null
+++ b/hubert_pretrain/README.md
@@ -0,0 +1,3 @@
+Path for:
+
+ hubert-soft-0d54a1f4.pt
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..276ba5e2edadb1eb5bffc5f4209d26136fe2b63d
--- /dev/null
+++ b/main.py
@@ -0,0 +1,32 @@
+import gradio as gr
+
+def click_test():
+    """Generate a random integer between 1 and 10."""
+    import random
+    number = random.randint(1, 10)
+    return f"Generated number: {number}"
+
+# Build the Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Random Number Generator")
+    gr.Markdown("Click the button below to generate a random number between 1 and 10.")
+
+    # Text box for displaying the result
+    output_text = gr.Text(label="Result")
+
+    # Custom button
+    generate_btn = gr.Button(
+        value="Generate number",  # button label
+        variant="primary",        # button style
+        size="lg"                 # button size
+    )
+
+    # Set up the click event for the button
+    generate_btn.click(
+        fn=click_test,
+        outputs=output_text
+    )
+
+# Launch the application
+if __name__ == "__main__":
+ demo.launch()
\ No newline at end of file
diff --git a/pitch/__init__.py b/pitch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc418142ecbe692527dfe1f1768205a50726dd2f
--- /dev/null
+++ b/pitch/__init__.py
@@ -0,0 +1 @@
+from .inference import load_csv_pitch
\ No newline at end of file
diff --git a/pitch/core/LICENCE b/pitch/core/LICENCE
new file mode 100644
index 0000000000000000000000000000000000000000..7e7c9386da890afff77de85d059a03ea6139865c
--- /dev/null
+++ b/pitch/core/LICENCE
@@ -0,0 +1,25 @@
+MIT License
+
+Copyright (c) 2022 Sebastian Rosenzweig, Simon Schwär, Meinard Müller, International Audio Laboratories Erlangen, Germany.
+We thank the German Research Foundation (DFG) for various research grants that
+allow us for conducting fundamental research in music processing.
+The International Audio Laboratories Erlangen are a joint institution of the
+Friedrich-Alexander-Universität Erlangen-Nürnberg (FAU) and Fraunhofer
+Institute for Integrated Circuits IIS.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/pitch/core/README.md b/pitch/core/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f75171a8bf86370f212f0a6ae508c34a7ca9c421
--- /dev/null
+++ b/pitch/core/README.md
@@ -0,0 +1,41 @@
+This repository contains a Python package called libf0 which provides open-source implementations of four popular model-based F0-estimation approaches: YIN (Cheveigné & Kawahara, 2002), pYIN (Mauch & Dixon, 2014), an approach inspired by Melodia (Salamon & Gómez, 2012), and SWIPE (Camacho & Harris, 2008).
+
+If you use libf0 in your research, please consider citing the following references.
+
+## References
+
+Sebastian Rosenzweig, Simon Schwär, and Meinard Müller.
+[A Python Library for Fundamental Frequency Estimation.](https://archives.ismir.net/ismir2022/latebreaking/000003.pdf)
+In Late Breaking Demos of the International Society for Music Information Retrieval Conference (ISMIR), Bengaluru, India, 2022.
+
+Alain de Cheveigné and Hideki Kawahara.
+YIN, a fundamental frequency estimator for speech and music. Journal of the Acoustical Society of America (JASA), 111(4):1917–1930, 2002.
+
+Matthias Mauch and Simon Dixon.
+pYIN: A fundamental frequency estimator using probabilistic threshold distributions. In IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pages 659–663, Florence, Italy, 2014.
+
+Justin Salamon and Emilia Gómez.
+Melody extraction from polyphonic music signals using pitch contour characteristics. IEEE Transactions on Audio, Speech, and Language Processing, 20(6):
+1759–1770, 2012.
+
+Arturo Camacho and John G. Harris.
+A sawtooth waveform inspired pitch estimator for speech and music. The Journal of the Acoustical Society of America, 124(3):1638–1652, 2008.
+
+Meinard Müller. Fundamentals of Music Processing – Using Python and Jupyter Notebooks. Springer Verlag, 2nd edition, 2021. ISBN 978-3-030-69807-2. doi: 10.1007/978-3-030-69808-9.
+
+## Documentation
+There is also an API documentation for libf0:
+
+https://groupmm.github.io/libf0
+
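+## Usage in this repository
+
+The original package is distributed as `libf0`; in this repository the same modules live under `pitch/core`. The following minimal sketch (added here for illustration, assuming the repository root is on the Python path) runs plain YIN on a synthetic 220 Hz tone:
+
+```python
+import numpy as np
+from pitch.core.yin import yin
+
+Fs = 22050
+t = np.arange(Fs) / Fs
+x = 0.5 * np.sin(2 * np.pi * 220.0 * t)  # one second of a 220 Hz tone
+
+f0, times, ap = yin(x, Fs=Fs, N=2048, H=256, F_min=55.0, F_max=1760.0)
+print(f0[10:15])  # estimates should be close to 220 Hz
+```
+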
+## Contributing
+
+We are happy to receive suggestions and contributions. Please either contact us directly via email (meinard.mueller@audiolabs-erlangen.de) or create an issue in our GitHub repository. Please do not submit a pull request without prior consultation with us.
+
+## Licence
+
+The code for this toolbox is published under an MIT licence.
+
+## Acknowledgements
+
+This work was supported by the German Research Foundation (MU 2686/13-1, SCHE 280/20-1). We thank Edgar Suárez and Vojtěch Pešek for helping with the implementations. Furthermore, we thank Fatemeh Eftekhar and Maryam Pirmoradi for testing the toolbox. The International Audio Laboratories Erlangen are a joint institution of the Friedrich-Alexander-Universität Erlangen-Nürnberg (FAU) and Fraunhofer Institute for Integrated Circuits IIS.
diff --git a/pitch/core/__init__.py b/pitch/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/pitch/core/pyin.py b/pitch/core/pyin.py
new file mode 100644
index 0000000000000000000000000000000000000000..8045599eaf292b00c67d1783e97e7facacb6492e
--- /dev/null
+++ b/pitch/core/pyin.py
@@ -0,0 +1,481 @@
+"""
+| Description: libf0 yin implementation
+| Contributors: Sebastian Rosenzweig, Simon Schwär, Edgar Suárez, Meinard Müller
+| License: The MIT license, https://opensource.org/licenses/MIT
+| This file is part of libf0.
+"""
+import numpy as np
+from scipy.special import beta, comb # Scipy library for binomial beta distribution
+from scipy.stats import triang # Scipy library for triangular distribution
+from .yin import cumulative_mean_normalized_difference_function, parabolic_interpolation
+from numba import njit
+
+
+# pYIN estimate computation
+def pyin(x, Fs=22050, N=2048, H=256, F_min=55.0, F_max=1760.0, R=10, thresholds=np.arange(0.01, 1, 0.01),
+ beta_params=[1, 18], absolute_min_prob=0.01, voicing_prob=0.5):
+ """
+ Implementation of the pYIN F0-estimation algorithm.
+
+ .. [#] Matthias Mauch and Simon Dixon.
+ "PYIN: A fundamental frequency estimator using probabilistic threshold distributions".
+ IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (2014): 659-663.
+
+ Parameters
+ ----------
+ x : ndarray
+ Audio signal
+ Fs : int
+ Sampling rate
+ N : int
+ Window size
+ H : int
+ Hop size
+ F_min : float or int
+ Minimal frequency
+ F_max : float or int
+ Maximal frequency
+ R : int
+ Frequency resolution given in cents
+ thresholds : ndarray
+ Range of thresholds
+ beta_params : tuple or list
+ Parameters of beta-distribution in the form [alpha, beta]
+    absolute_min_prob : float
+        Probability assigned to the absolute minimum of the CMNDF when no local minimum falls below a threshold
+    voicing_prob : float
+        Prior probability of a frame being voiced
+ Returns
+ -------
+ f0 : ndarray
+ Estimated F0-trajectory
+ t : ndarray
+ Time axis
+ conf : ndarray
+ Confidence
+ """
+
+ if F_min > F_max:
+ raise Exception("F_min must be smaller than F_max!")
+
+ if F_min < Fs/N:
+ raise Exception(f"The condition (F_min >= Fs/N) was not met. With Fs = {Fs}, N = {N} and F_min = {F_min} you have the following options: \n1) Set F_min >= {np.ceil(Fs/N)} Hz. \n2) Set N >= {np.ceil(Fs/F_min).astype(int)}. \n3) Set Fs <= {np.floor(F_min * N)} Hz.")
+
+ x_pad = np.concatenate((np.zeros(N // 2), x, np.zeros(N // 2))) # Add zeros for centered estimates
+
+ # Compute Beta distribution
+ thr_idxs = np.arange(len(thresholds))
+ beta_distr = comb(len(thresholds), thr_idxs) * beta(thr_idxs+beta_params[0],
+ len(thresholds)-thr_idxs+beta_params[1]) / beta(beta_params[0],
+ beta_params[1])
+
+ # YIN with multiple thresholds, yielding observation matrix
+ B = int(np.log2(F_max / F_min) * (1200 / R))
+ F_axis = F_min * np.power(2, np.arange(B) * R / 1200) # for quantizing the estimated F0s
+ O, rms, p_orig, val_orig = yin_multi_thr(x_pad, Fs=Fs, N=N, H=H, F_min=F_min, F_max=F_max, thresholds=thresholds,
+ beta_distr=beta_distr, absolute_min_prob=absolute_min_prob, F_axis=F_axis,
+ voicing_prob=voicing_prob)
+
+    # Transition matrix, using a triangular distribution for pitch transition probabilities
+ max_step_cents = 50 # Pitch jump can be at most 50 cents from frame to frame
+ max_step = int(max_step_cents / R)
+ triang_distr = triang.pdf(np.arange(-max_step, max_step+1), 0.5, scale=2*max_step, loc=-max_step)
+ A = compute_transition_matrix(B, triang_distr)
+
+ # HMM smoothing
+ C = np.ones((2*B, 1)) / (2*B) # uniform initialization
+ f0_idxs = viterbi_log_likelihood(A, C.flatten(), O) # libfmp Viterbi implementation
+
+ # Obtain F0-trajectory
+ F_axis_extended = np.concatenate((F_axis, np.zeros(len(F_axis))))
+ f0 = F_axis_extended[f0_idxs]
+
+ # Suppress low power estimates
+    f0[0] = 0  # for algorithmic reasons, the first frame is set to unvoiced
+ f0[rms < 0.01] = 0
+
+ # confidence
+ O_norm = O[:, np.arange(O.shape[1])]/np.max(O, axis=0)
+ conf = O_norm[f0_idxs, np.arange(O.shape[1])]
+
+ # Refine estimates by choosing the closest original YIN estimate
+ refine_estimates = True
+ if refine_estimates:
+ f0 = refine_estimates_yin(f0, p_orig, val_orig, Fs, R)
+
+ t = np.arange(O.shape[1]) * H / Fs # Time axis
+
+ return f0, t, conf
+
+
+@njit
+def refine_estimates_yin(f0, p_orig, val_orig, Fs, tol):
+ """
+ Refine estimates using YIN CMNDF information.
+
+ Parameters
+ ----------
+ f0 : ndarray
+ F0 in Hz
+ p_orig : ndarray
+ Original lag as computed by YIN
+ val_orig : ndarray
+ Original CMNDF values as computed by YIN
+ Fs : float
+ Sampling frequency
+ tol : float
+ Tolerance for refinements in cents
+
+ Returns
+ -------
+ f0_refined : ndarray
+ Refined F0-trajectory
+ """
+ f0_refined = np.zeros_like(f0)
+ voiced_idxs = np.where(f0 > 0)[0]
+
+ f_orig = Fs / p_orig
+
+ # find closest original YIN estimate, maximally allowed absolute deviation: R (quantization error)
+ for m in voiced_idxs:
+ diff_cents = np.abs(1200 * np.log2(f_orig[:, m] / f0[m]))
+ candidate_idxs = np.where(diff_cents < tol)[0]
+
+ if not candidate_idxs.size:
+ f0_refined[m] = f0[m]
+ else:
+ f0_refined[m] = f_orig[candidate_idxs[np.argmin(val_orig[candidate_idxs, m])], m]
+
+ return f0_refined
+
+
+@njit
+def probabilistic_thresholding(cmndf, thresholds, p_min, p_max, absolute_min_prob, F_axis, Fs, beta_distr,
+ parabolic_interp=True):
+ """
+ Probabilistic thresholding of the YIN CMNDF.
+
+ Parameters
+ ----------
+ cmndf : ndarray
+ Cumulative Mean Normalized Difference Function
+ thresholds : ndarray
+ Array of thresholds for CMNDF
+ p_min : float
+ Period corresponding to the lower frequency bound
+ p_max : float
+ Period corresponding to the upper frequency bound
+ absolute_min_prob : float
+        Probability to choose the absolute minimum
+ F_axis : ndarray
+ Frequency axis
+ Fs : float
+ Sampling rate
+ beta_distr : ndarray
+ Beta distribution that defines mapping between thresholds and probabilities
+ parabolic_interp : bool
+ Switch to activate/deactivate parabolic interpolation
+
+ Returns
+ -------
+ O_m : ndarray
+ Observations for given frame
+ lag_thr : ndarray
+ Computed lags for every threshold
+ val_thr : ndarray
+ CMNDF values for computed lag
+ """
+ # restrict search range to interval [p_min:p_max]
+ cmndf[:p_min] = np.inf
+ cmndf[p_max:] = np.inf
+
+ # find local minima (assuming that cmndf is real in [p_min:p_max], you will always find a minimum,
+ # at least p_min or p_max)
+ min_idxs = (np.argwhere((cmndf[1:-1] < cmndf[0:-2]) & (cmndf[1:-1] < cmndf[2:]))).flatten().astype(np.int64) + 1
+
+ O_m = np.zeros(2 * len(F_axis))
+
+ # return if no minima are found, e.g., when frame is silence
+ if min_idxs.size == 0:
+ return O_m, np.ones_like(thresholds)*p_min, np.ones_like(thresholds)
+
+ # Optional: Parabolic Interpolation of local minima
+ if parabolic_interp:
+        # do not interpolate at the borders, Numba-compatible workaround for np.delete()
+ min_idxs_interp = delete_numba(min_idxs, np.argwhere(min_idxs == p_min))
+ min_idxs_interp = delete_numba(min_idxs_interp, np.argwhere(min_idxs_interp == p_max - 1))
+ p_corr, cmndf[min_idxs_interp] = parabolic_interpolation(cmndf[min_idxs_interp - 1],
+ cmndf[min_idxs_interp],
+ cmndf[min_idxs_interp + 1])
+ else:
+ p_corr = np.zeros_like(min_idxs).astype(np.float64)
+
+    # set p_corr=0 at the borders (no correction done later)
+ if min_idxs[0] == p_min:
+ p_corr = np.concatenate((np.array([0.0]), p_corr))
+
+ if min_idxs[-1] == p_max - 1:
+ p_corr = np.concatenate((p_corr, np.array([0.0])))
+
+ lag_thr = np.zeros_like(thresholds)
+ val_thr = np.zeros_like(thresholds)
+
+ # loop over all thresholds
+ for i, threshold in enumerate(thresholds):
+ # minima below absolute threshold
+ min_idxs_thr = min_idxs[cmndf[min_idxs] < threshold]
+
+ # find first local minimum
+ if not min_idxs_thr.size:
+ lag = np.argmin(cmndf) # choose absolute minimum when no local minimum is found
+ am_prob = absolute_min_prob
+ val = np.min(cmndf)
+ else:
+ am_prob = 1
+ lag = np.min(min_idxs_thr) # choose first local minimum
+ val = cmndf[lag]
+
+ # correct lag
+ if parabolic_interp:
+ lag += p_corr[np.argmin(min_idxs_thr)]
+
+ # ensure that lag is in [p_min:p_max]
+ if lag < p_min:
+ lag = p_min
+ elif lag >= p_max:
+ lag = p_max - 1
+
+ lag_thr[i] = lag
+ val_thr[i] = val
+
+ idx = np.argmin(np.abs(1200 * np.log2(F_axis / (Fs / lag)))) # quantize estimated period
+ O_m[idx] += am_prob * beta_distr[i] # pYIN-Paper, Formula 4/5
+
+ return O_m, lag_thr, val_thr
+
+
+@njit
+def yin_multi_thr(x, Fs, N, H, F_min, F_max, thresholds, beta_distr, absolute_min_prob, F_axis, voicing_prob,
+ parabolic_interp=True):
+ """
+ Applies YIN multiple times on input audio signals using different thresholds for CMNDF.
+
+ Parameters
+ ----------
+ x : ndarray
+ Input audio signal
+ Fs : int
+ Sampling rate
+ N : int
+ Window size
+ H : int
+ Hop size
+ F_min : float
+ Lower frequency bound
+ F_max : float
+ Upper frequency bound
+ thresholds : ndarray
+ Array of thresholds
+ beta_distr : ndarray
+ Beta distribution that defines mapping between thresholds and probabilities
+    absolute_min_prob : float
+        Probability to choose the absolute minimum
+ F_axis : ndarray
+ Frequency axis
+ voicing_prob : float
+ Probability of a frame being voiced
+ parabolic_interp : bool
+ Switch to activate/deactivate parabolic interpolation
+
+ Returns
+ -------
+ O : ndarray
+ Observations based on YIN output
+ rms : ndarray
+ Root mean square power
+ p_orig : ndarray
+ Original YIN period estimates
+ val_orig : ndarray
+ CMNDF values corresponding to original YIN period estimates
+ """
+
+ M = int(np.floor((len(x) - N) / H)) + 1 # Compute number of estimates that will be generated
+ B = len(F_axis)
+
+ p_min = max(int(np.ceil(Fs / F_max)), 1) # period of maximal frequency in frames
+ p_max = int(np.ceil(Fs / F_min)) # period of minimal frequency in frames
+
+ if p_max > N:
+ raise Exception("The condition (Fmin >= Fs/N) was not met.")
+
+ rms = np.zeros(M) # RMS Power
+ O = np.zeros((2 * B, M)) # every voiced state has an unvoiced state (important for later HMM modeling)
+ p_orig = np.zeros((len(thresholds), M))
+ val_orig = np.zeros((len(thresholds), M))
+
+ for m in range(M):
+ # Take a frame from input signal
+ frame = x[m * H:m * H + N]
+
+ # Cumulative Mean Normalized Difference Function
+ cmndf = cumulative_mean_normalized_difference_function(frame, p_max)
+
+ # compute RMS power
+ rms[m] = np.sqrt(np.mean(frame ** 2))
+
+ # Probabilistic Thresholding with different thresholds
+ O_m, p_est_thr, val_thr = probabilistic_thresholding(cmndf, thresholds, p_min, p_max, absolute_min_prob, F_axis,
+ Fs, beta_distr, parabolic_interp=parabolic_interp)
+
+ O[:, m] = O_m
+ p_orig[:, m] = p_est_thr # store original YIN estimates for later refinement
+ val_orig[:, m] = val_thr # store original cmndf value of minimum corresponding to p_est
+
+ # normalization (pYIN-Paper, Formula 6)
+ O[0:B, :] *= voicing_prob
+ O[B:2 * B, :] = (1 - voicing_prob) * (1 - np.sum(O[0:B, :], axis=0)) / B
+
+ return O, rms, p_orig, val_orig
+
+
+@njit
+def compute_transition_matrix(M, triang_distr):
+ """
+ Compute a transition matrix for PYIN Viterbi.
+
+ Parameters
+ ----------
+ M : int
+ Matrix dimension
+ triang_distr : ndarray
+ (Triangular) distribution, defining tolerance for jumps deviating from the main diagonal
+
+ Returns
+ -------
+ A : ndarray
+ Transition matrix
+ """
+ prob_self = 0.99
+
+ A = np.zeros((2*M, 2*M))
+ max_step = len(triang_distr) // 2
+
+ for i in range(M):
+ if i < max_step:
+ A[i, 0:i+max_step] = prob_self * triang_distr[max_step - i:-1] / np.sum(triang_distr[max_step - i:-1])
+ A[i+M, M:i+M+max_step] = prob_self * triang_distr[max_step - i:-1] / np.sum(triang_distr[max_step - i:-1])
+
+ if i >= max_step and i < M-max_step:
+ A[i, i-max_step:i+max_step+1] = prob_self * triang_distr
+ A[i+M, (i+M)-max_step:(i+M)+max_step+1] = prob_self * triang_distr
+
+ if i >= M-max_step:
+ A[i, i-max_step:M] = prob_self * triang_distr[0:max_step - (i-M)] / np.sum(triang_distr[0:max_step - (i-M)])
+ A[i+M, i+M-max_step:2*M] = prob_self * triang_distr[0:max_step - (i - M)] / \
+ np.sum(triang_distr[0:max_step - (i - M)])
+
+ A[i, i+M] = 1 - prob_self
+ A[i+M, i] = 1 - prob_self
+
+ return A
+
+
+@njit
+def viterbi_pyin(A, C, O):
+ """Viterbi algorithm (pYIN variant)
+
+ Args:
+ A : ndarray
+ State transition probability matrix of dimension I x I
+ C : ndarray
+ Initial state distribution of dimension I X 1
+ O : ndarray
+ Likelihood matrix of dimension I x N
+
+ Returns:
+ idxs : ndarray
+ Optimal state sequence of length N
+ """
+ B = O.shape[0] // 2
+ M = O.shape[1]
+ D = np.zeros((B * 2, M))
+ E = np.zeros((B * 2, M - 1))
+
+ idxs = np.zeros(M)
+
+ for i in range(B * 2):
+        D[i, 0] = C[i, 0] * O[i, 0]  # D matrix initial state setting
+
+ D[:, 0] = D[:, 0] / np.sum(D[:, 0]) # Normalization (using pYIN source code as a basis)
+
+ for n in range(1, M):
+ for i in range(B * 2):
+ abyd = np.multiply(A[:, i], D[:, n-1])
+ D[i, n] = np.max(abyd) * O[i, n]
+ E[i, n-1] = np.argmax(abyd)
+
+ D[:, n] = D[:, n] / np.sum(D[:, n]) # Row normalization to avoid underflow (pYIN source code sparseHMM)
+
+ idxs[M - 1] = np.argmax(D[:, M - 1])
+
+    for n in range(M - 2, -1, -1):  # backtrack down to the first frame
+ bkd = int(idxs[n+1]) # Intermediate variable to be compatible with Numba
+ idxs[n] = E[bkd, n]
+
+ return idxs.astype(np.int32)
+
+
+@njit
+def viterbi_log_likelihood(A, C, B_O):
+ """Viterbi algorithm (log variant) for solving the uncovering problem
+
+ Notebook: C5/C5S3_Viterbi.ipynb
+
+ Args:
+ A : ndarray
+ State transition probability matrix of dimension I x I
+ C : ndarray
+ Initial state distribution of dimension I
+ B_O : ndarray
+ Likelihood matrix of dimension I x N
+
+ Returns:
+ S_opt : ndarray
+ Optimal state sequence of length N
+ """
+ I = A.shape[0] # Number of states
+ N = B_O.shape[1] # Length of observation sequence
+ tiny = np.finfo(0.).tiny
+ A_log = np.log(A + tiny)
+ C_log = np.log(C + tiny)
+ B_O_log = np.log(B_O + tiny)
+
+ # Initialize D and E matrices
+ D_log = np.zeros((I, N))
+ E = np.zeros((I, N-1)).astype(np.int32)
+ D_log[:, 0] = C_log + B_O_log[:, 0]
+
+ # Compute D and E in a nested loop
+ for n in range(1, N):
+ for i in range(I):
+ temp_sum = A_log[:, i] + D_log[:, n-1]
+ D_log[i, n] = np.max(temp_sum) + B_O_log[i, n]
+ E[i, n-1] = np.argmax(temp_sum)
+
+ # Backtracking
+ S_opt = np.zeros(N).astype(np.int32)
+ S_opt[-1] = np.argmax(D_log[:, -1])
+ for n in range(N-2, -1, -1):
+ S_opt[n] = E[int(S_opt[n+1]), n]
+
+ return S_opt
+
+
+@njit
+def delete_numba(arr, num):
+ """Delete number from array, Numba compatible. Inspired by:
+ https://stackoverflow.com/questions/53602663/delete-a-row-in-numpy-array-in-numba
+ """
+ mask = np.zeros(len(arr), dtype=np.int64) == 0
+ mask[np.where(arr == num)[0]] = False
+ return arr[mask]
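+
+
+# Usage sketch (added for illustration, not part of the original libf0 file):
+# run pYIN on a synthetic 440 Hz tone using the default analysis parameters
+# (Fs = 22050, N = 2048, H = 256).
+if __name__ == "__main__":
+    Fs = 22050
+    t_ax = np.arange(2 * Fs) / Fs
+    x_test = 0.5 * np.sin(2 * np.pi * 440.0 * t_ax)
+    f0, t_frames, conf = pyin(x_test, Fs=Fs)
+    print(f0[20:25], conf[20:25])  # voiced frames should be close to 440 Hz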
diff --git a/pitch/core/salience.py b/pitch/core/salience.py
new file mode 100644
index 0000000000000000000000000000000000000000..54b33ab0e4caf9bdb700b01d68ee6145fd48423d
--- /dev/null
+++ b/pitch/core/salience.py
@@ -0,0 +1,441 @@
+"""
+| Description: libf0 salience-based F0 estimation implementation
+| Author: Sebastian Rosenzweig, Simon Schwär, Meinard Müller
+| License: The MIT license, https://opensource.org/licenses/MIT
+| This file is part of libf0.
+"""
+import numpy as np
+from librosa import stft
+from scipy import ndimage, linalg
+from numba import njit
+
+
+def salience(x, Fs=22050, N=2048, H=256, F_min=55.0, F_max=1760.0, R=10.0, num_harm=10, freq_smooth_len=11,
+ alpha=0.9, gamma=0.0, constraint_region=None, tol=5, score_low=0.01, score_high=1.0):
+ """
+ Implementation of a salience-based F0-estimation algorithm using pitch contours, inspired by Melodia.
+
+ .. [#] Justin Salamon and Emilia Gómez,
+ "Melody Extraction From Polyphonic Music Signals Using Pitch Contour Characteristics."
+ IEEE Transactions on Audio, Speech, and Language Processing, vol. 20, no. 6, pp. 1759–1770, Aug. 2012.
+
+ Parameters
+ ----------
+ x : ndarray
+ Audio signal
+ Fs : int
+ Sampling rate
+ N : int
+ Window size
+ H : int
+ Hop size
+ F_min : float or int
+ Minimal frequency
+ F_max : float or int
+ Maximal frequency
+ R : int
+ Frequency resolution given in cents
+ num_harm : int
+ Number of harmonics (Default value = 10)
+ freq_smooth_len : int
+ Filter length for vertical smoothing (Default value = 11)
+ alpha : float
+ Weighting parameter for harmonics (Default value = 0.9)
+ gamma : float
+ Logarithmic compression factor (Default value = 0.0)
+ constraint_region : None or ndarray
+        Constraint regions, row-format: (t_start_sec, t_end_sec, f_start_hz, f_end_hz)
+ (Default value = None)
+ tol : int
+ Tolerance parameter for transition matrix (Default value = 5)
+ score_low : float
+ Score (low) for transition matrix (Default value = 0.01)
+ score_high : float
+ Score (high) for transition matrix (Default value = 1.0)
+
+ Returns
+ -------
+ f0 : ndarray
+ Estimated F0-trajectory
+ T_coef: ndarray
+ Time axis
+ sal: ndarray
+ Salience value of estimated F0
+
+ See also
+ --------
+ [FMP] Notebook: C8/C8S2_SalienceRepresentation.ipynb
+ """
+
+ # compute salience representation via instantaneous frequency and harmonic summation
+ Z, F_coef_hertz = compute_salience_rep(x, Fs, N=N, H=H, F_min=F_min, F_max=F_max, R=R,
+ num_harm=num_harm, freq_smooth_len=freq_smooth_len,
+ alpha=alpha, gamma=gamma)
+
+ # compute trajectory via dynamic programming
+ T_coef = (np.arange(Z.shape[1]) * H) / Fs
+ index_CR = compute_trajectory_cr(Z, T_coef, F_coef_hertz, constraint_region,
+ tol=tol, score_low=score_low, score_high=score_high)
+
+ traj = F_coef_hertz[index_CR]
+ traj[index_CR == -1] = 0
+
+ # compute salience value
+ Z_max = np.max(Z, axis=0)
+ Z_norm = np.divide(Z, np.ones((Z.shape[0], 1)) * Z_max)
+ sal = Z_norm[index_CR, np.arange(Z.shape[1])]
+ sal[traj == 0] = 0
+
+ return traj, T_coef, sal
+
+
+def compute_salience_rep(x, Fs, N, H, F_min, F_max, R, num_harm, freq_smooth_len, alpha, gamma):
+ """
+ Compute salience representation [FMP, Eq. (8.56)]
+
+ Parameters
+ ----------
+ x : ndarray
+ Audio signal
+ Fs : int
+ Sampling rate
+ N : int
+ Window size
+ H : int
+ Hop size
+ F_min : float or int
+ Minimal frequency
+ F_max : float or int
+ Maximal frequency
+ R : int
+ Frequency resolution given in cents
+ num_harm : int
+ Number of harmonics
+ freq_smooth_len : int
+ Filter length for vertical smoothing
+ alpha : float
+ Weighting parameter for harmonics
+ gamma : float
+ Logarithmic compression factor
+
+ Returns
+ -------
+ Z : ndarray
+ Salience representation
+ F_coef_hertz : ndarray
+ Frequency axis in Hz
+
+ See also
+ --------
+ [FMP] Notebook: C8/C8S2_SalienceRepresentation.ipynb
+ """
+
+ X = stft(x, n_fft=N, hop_length=H, win_length=N, pad_mode='constant')
+ Y_LF_IF_bin, F_coef_hertz = compute_y_lf_if_bin_eff(X, Fs, N, H, F_min, F_max, R)
+
+ # smoothing
+ Y_LF_IF_bin = ndimage.convolve1d(Y_LF_IF_bin, np.hanning(freq_smooth_len), axis=0, mode='constant')
+
+ Z = compute_salience_from_logfreq_spec(Y_LF_IF_bin, R, n_harmonics=num_harm, alpha=alpha, beta=1, gamma=gamma)
+ return Z, F_coef_hertz
+
+
+def compute_y_lf_if_bin_eff(X, Fs, N, H, F_min, F_max, R):
+ """
+ Binned Log-frequency Spectrogram with variable frequency resolution based on instantaneous frequency,
+ more efficient implementation than FMP
+
+ Parameters
+ ----------
+ X : ndarray
+ Complex spectrogram
+ Fs : int
+ Sampling rate in Hz
+ N : int
+ Window size
+ H : int
+ Hop size
+ F_min : float or int
+ Minimal frequency
+ F_max : float or int
+ Maximal frequency
+ R : int
+ Frequency resolution given in cents
+
+ Returns
+ -------
+ Y_LF_IF_bin : ndarray
+ Binned log-frequency spectrogram using instantaneous frequency (shape: [freq, time])
+ F_coef_hertz : ndarray
+ Frequency axis in Hz
+ """
+
+ # calculate number of bins on log frequency axis
+ B = frequency_to_bin_index(F_max, R, F_min) + 1
+
+ # center frequencies of the final bins
+ F_coef_hertz = F_min * np.power(2, (np.arange(0, B) * R / 1200))
+
+ # calculate heterodyned phase increment (hpi)
+ k = np.arange(X.shape[0]).reshape(-1, 1)
+ omega = 2 * np.pi * k / N # center frequency for each bin in rad
+ hpi = (np.angle(X[:, 1:]) - np.angle(X[:, 0:-1])) - omega * H
+
+ # reduce hpi to -pi:pi range
+ # this is much faster than using the modulo function below, but gives the same result
+ # hpi = np.mod(hpi + np.pi, 2 * np.pi) - np.pi
+ hpi = hpi - 2 * np.pi * (np.around((hpi / (2 * np.pi)) + 1) - 1)
+
+ # calculate instantaneous frequencies in Hz
+ inst_f = (omega + hpi / H) * Fs / (2 * np.pi)
+ # repeat the first time frame to match dimensions of X
+ inst_f = np.hstack((np.copy(inst_f[:, 0]).reshape(-1, 1), inst_f))
+
+ # mask frequencies that are not relevant
+ mask = np.logical_and(inst_f >= F_min, inst_f < F_max)
+ inst_f *= mask
+    # set 0 to NaN so that it stays NaN in the bin assignment calculation
+ inst_f[np.where(inst_f == 0)] = np.nan
+
+ # find which inst_f values belong to which bin
+ bin_assignment = frequency_to_bin_index(inst_f, R, F_min)
+ # we map the discarded values to an extra bin that we remove before returning the binned spectrogram
+ bin_assignment[np.where(np.isnan(inst_f))] = B
+
+ # perform binning on power spectrogram for each time frame separately
+ Y = np.abs(X) ** 2
+ Y_LF_IF_bin = np.zeros((B+1, Y.shape[1]))
+ for t in range(Y.shape[1]):
+ np.add.at(Y_LF_IF_bin[:, t], bin_assignment[:, t], Y[:, t])
+
+ return Y_LF_IF_bin[:B, :], F_coef_hertz
+
+
+def compute_salience_from_logfreq_spec(lf_spec, R, n_harmonics, alpha, beta, gamma, harmonic_win_len=11):
+ """
+ Compute salience representation using harmonic summation following [1]
+
+ [1] J. Salamon and E. Gomez,
+ "Melody Extraction From Polyphonic Music Signals Using Pitch Contour Characteristics."
+ IEEE Transactions on Audio, Speech, and Language Processing, vol. 20, no. 6, pp. 1759–1770, Aug. 2012.
+
+ Parameters
+ ----------
+ lf_spec : ndarray
+ (F, T) log-spectrogram
+ R : int
+ Frequency resolution given in cents
+ n_harmonics : int
+ Number of harmonics
+ alpha : float
+ Weighting parameter for harmonics
+ beta : float
+ Compression parameter for spectrogram magnitudes
+ gamma : float
+ Magnitude threshold
+ harmonic_win_len : int
+ Length of a frequency weighting window in bins
+
+ Returns
+ -------
+ Z : ndarray
+ (F, T) salience representation of the input spectrogram
+ """
+
+ # magnitude thresholding and compression
+ eps = np.finfo(np.float32).eps
+ threshold_mask = (20 * np.log10(lf_spec/np.max(lf_spec) + eps)) < gamma
+ lf_spec = lf_spec**beta * threshold_mask
+
+ # compute window
+ max_diff_bins = harmonic_win_len // 2
+ window = np.cos(np.linspace(-1, 1, 2*max_diff_bins+1)*np.pi/2)**2 # cosine^2 window
+
+ # compute indices of harmonics
+ harmonics = np.round(np.log2(np.arange(1, n_harmonics + 1)) * 1200 / R).astype(int)
+ weighting_vec = np.zeros((lf_spec.shape[0] + max_diff_bins))
+
+ # compute weights
+ for idx, h in enumerate(harmonics):
+ if h+harmonic_win_len > len(weighting_vec):
+ break # we reached the maximum length available
+ weighting_vec[h:h+harmonic_win_len] += window * alpha**idx
+
+ # correlate lf_spec with the weighting vector on the frequency axis
+ Z = ndimage.correlate1d(lf_spec, weighting_vec[:],
+ axis=0, mode='constant', cval=0, origin=-len(weighting_vec)//2 + max_diff_bins)
+
+ # magnitude thresholding and compression
+ threshold_mask = (20 * np.log10(Z / np.max(Z) + eps)) < gamma
+ Z = Z ** beta * threshold_mask
+
+ return Z
+
+
+def define_transition_matrix(B, tol=0, score_low=0.01, score_high=1.0):
+ """
+ Generate transition matrix for dynamic programming
+
+ Parameters
+ ----------
+ B : int
+ Number of bins
+ tol : int
+ Tolerance parameter for transition matrix (Default value = 0)
+ score_low : float
+ Score (low) for transition matrix (Default value = 0.01)
+ score_high : float
+ Score (high) for transition matrix (Default value = 1.0)
+
+ Returns
+ -------
+ T : ndarray
+ (B, B) Transition matrix
+
+ See also
+ --------
+ [FMP] Notebook: C8/C8S2_FundFreqTracking.ipynb
+ """
+
+ col = np.ones((B,)) * score_low
+ col[0:tol+1] = np.ones((tol+1, )) * score_high
+ T = linalg.toeplitz(col)
+ return T
+
+
+@njit
+def compute_trajectory_dp(Z, T):
+ """
+ Trajectory tracking using dynamic programming
+
+ Parameters
+ ----------
+ Z : ndarray
+ Salience representation
+ T : ndarray
+        Transition matrix
+
+ Returns
+ -------
+ eta_DP : ndarray
+ Trajectory indices
+
+ See also
+ --------
+ [FMP] Notebook: C8/C8S2_FundFreqTracking.ipynb
+ """
+
+ B, N = Z.shape
+ eps_machine = np.finfo(np.float32).eps
+ Z_log = np.log(Z + eps_machine)
+ T_log = np.log(T + eps_machine)
+
+ E = np.zeros((B, N))
+ D = np.zeros((B, N))
+ D[:, 0] = Z_log[:, 0]
+
+ for n in np.arange(1, N):
+ for b in np.arange(0, B):
+ D[b, n] = np.max(T_log[b, :] + D[:, n-1]) + Z_log[b, n]
+ E[b, n-1] = np.argmax(T_log[b, :] + D[:, n-1])
+
+ # backtracking
+ eta_DP = np.zeros(N)
+ eta_DP[N-1] = int(np.argmax(D[:, N-1]))
+
+ for n in np.arange(N-2, -1, -1):
+ eta_DP[n] = E[int(eta_DP[n+1]), n]
+
+ return eta_DP.astype(np.int64)
+
+
+def compute_trajectory_cr(Z, T_coef, F_coef_hertz, constraint_region=None,
+ tol=5, score_low=0.01, score_high=1.0):
+ """
+ Trajectory tracking with constraint regions
+ Notebook: C8/C8S2_FundFreqTracking.ipynb
+
+ Parameters
+ ----------
+ Z : ndarray
+ Salience representation
+ T_coef : ndarray
+ Time axis
+ F_coef_hertz : ndarray
+ Frequency axis in Hz
+ constraint_region : ndarray or None
+ Constraint regions, row-format: (t_start_sec, t_end_sec, f_start_hz, f_end_hz)
+ (Default value = None)
+ tol : int
+ Tolerance parameter for transition matrix (Default value = 5)
+ score_low : float
+ Score (low) for transition matrix (Default value = 0.01)
+ score_high : float
+ Score (high) for transition matrix (Default value = 1.0)
+
+ Returns
+ -------
+ eta : ndarray
+ Trajectory indices, unvoiced frames are indicated with -1
+
+ See also
+ --------
+ [FMP] Notebook: C8/C8S2_FundFreqTracking.ipynb
+ """
+
+ # do tracking within every constraint region
+ if constraint_region is not None:
+ # initialize contour, unvoiced frames are indicated with -1
+ eta = np.full(len(T_coef), -1)
+
+ for row_idx in range(constraint_region.shape[0]):
+ t_start = constraint_region[row_idx, 0] # sec
+ t_end = constraint_region[row_idx, 1] # sec
+ f_start = constraint_region[row_idx, 2] # Hz
+ f_end = constraint_region[row_idx, 3] # Hz
+
+ # convert start/end values to indices
+ t_start_idx = np.argmin(np.abs(T_coef - t_start))
+ t_end_idx = np.argmin(np.abs(T_coef - t_end))
+ f_start_idx = np.argmin(np.abs(F_coef_hertz - f_start))
+ f_end_idx = np.argmin(np.abs(F_coef_hertz - f_end))
+
+ # track in salience part
+ cur_Z = Z[f_start_idx:f_end_idx+1, t_start_idx:t_end_idx+1]
+ T = define_transition_matrix(cur_Z.shape[0], tol=tol,
+ score_low=score_low, score_high=score_high)
+ cur_eta = compute_trajectory_dp(cur_Z, T)
+
+ # fill contour
+ eta[t_start_idx:t_end_idx+1] = f_start_idx + cur_eta
+ else:
+ T = define_transition_matrix(Z.shape[0], tol=tol, score_low=score_low, score_high=score_high)
+ eta = compute_trajectory_dp(Z, T)
+
+ return eta
+
+
+def frequency_to_bin_index(F, R, F_ref):
+ """
+ Binning function with variable frequency resolution
+ Note: Indexing starts with 0 (opposed to [FMP, Eq. (8.49)])
+
+ Parameters
+ ----------
+ F : float or ndarray
+ Frequency in Hz
+ R : float
+ Frequency resolution in cents (Default value = 10.0)
+ F_ref : float
+ Reference frequency in Hz (Default value = 55.0)
+
+ Returns
+ -------
+    bin_index : int or ndarray
+        Index of the bin (starting with index 0)
+
+ See also
+ --------
+ [FMP] Notebook: C8/C8S2_SalienceRepresentation.ipynb
+ """
+ bin_index = np.floor((1200 / R) * np.log2(F / F_ref) + 0.5).astype(np.int64)
+ return bin_index
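+
+
+# Usage sketch (added for illustration, not part of the original libf0 file):
+# salience-based tracking on a synthetic tone; with constraint_region=None the
+# trajectory is tracked over the full time-frequency plane.
+if __name__ == "__main__":
+    Fs = 22050
+    t_ax = np.arange(2 * Fs) / Fs
+    x_test = 0.5 * np.sin(2 * np.pi * 330.0 * t_ax)
+    f0, T_coef, sal = salience(x_test, Fs=Fs)
+    print(f0[20:25], sal[20:25])  # estimates should be close to 330 Hz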
diff --git a/pitch/core/swipe.py b/pitch/core/swipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..99960caf303cc2403437bacb9bd50494dcaf1670
--- /dev/null
+++ b/pitch/core/swipe.py
@@ -0,0 +1,282 @@
+"""
+| Description: libf0 SWIPE implementation
+| Contributors: Sebastian Rosenzweig, Vojtěch Pešek, Simon Schwär, Meinard Müller
+| License: The MIT license, https://opensource.org/licenses/MIT
+| This file is part of libf0.
+"""
+from scipy import interpolate
+import numpy as np
+import librosa
+
+
+def swipe(x, Fs=22050, H=256, F_min=55.0, F_max=1760.0, dlog2p=1 / 96, derbs=0.1, strength_threshold=0):
+ """
+ Implementation of a sawtooth waveform inspired pitch estimator (SWIPE).
+ This version of the algorithm follows the original implementation, see `swipe_slim` for a more efficient
+ alternative.
+
+ .. [#] Arturo Camacho and John G. Harris,
+ "A sawtooth waveform inspired pitch estimator for speech and music."
+ The Journal of the Acoustical Society of America, vol. 124, no. 3, pp. 1638–1652, Sep. 2008
+
+ Parameters
+ ----------
+ x : ndarray
+ Audio signal
+ Fs : int
+ Sampling rate
+ H : int
+ Hop size
+ F_min : float or int
+ Minimal frequency
+ F_max : float or int
+ Maximal frequency
+ dlog2p : float
+ resolution of the pitch candidate bins in octaves (default value = 1/96 -> 96 bins per octave)
+ derbs : float
+ resolution of the ERB bands (default value = 0.1)
+ strength_threshold : float
+ confidence threshold [0, 1] for the pitch detection (default value = 0)
+
+ Returns
+ -------
+ f0 : ndarray
+ Estimated F0-trajectory
+ t : ndarray
+ Time axis
+ strength : ndarray
+ Confidence/Pitch Strength
+ """
+
+ t = np.arange(0, len(x), H) / Fs # Times
+
+ # Compute pitch candidates
+ pc = 2 ** np.arange(np.log2(F_min), np.log2(F_max), dlog2p)
+
+ # Pitch strength matrix
+ S = np.zeros((len(pc), len(t)))
+
+ # Determine P2-WSs [max, min]
+ log_ws_max = np.ceil(np.log2((8 / F_min) * Fs))
+ log_ws_min = np.floor(np.log2((8 / F_max) * Fs))
+
+ # P2-WSs - window sizes in samples
+ ws = 2 ** np.arange(log_ws_max, log_ws_min - 1, -1, dtype=np.int32)
+ # print(f'window sizes in samples: {ws}')
+
+ # Determine window sizes used by each pitch candidate
+ log2pc = np.arange(np.log2(F_min), np.log2(F_max), dlog2p)
+ d = log2pc - np.log2(np.divide(8 * Fs, ws[0]))
+
+ # Create ERBs spaced frequencies (in Hertz)
+ fERBs = erbs2hz(np.arange(hz2erbs(pc[0] / 4), hz2erbs(Fs / 2), derbs))
+
+ for i in range(0, len(ws)):
+ N = ws[i]
+ H = int(N / 2)
+
+ x_zero_padded = np.concatenate([x, np.zeros(N)])
+
+ X = librosa.stft(x_zero_padded, n_fft=N, hop_length=H, pad_mode='constant', center=True)
+ ti = librosa.frames_to_time(np.arange(0, X.shape[1]), sr=Fs, hop_length=H, n_fft=N)
+ f = librosa.fft_frequencies(sr=Fs, n_fft=N)
+
+ ti = np.insert(ti, 0, 0)
+ ti = np.delete(ti, -1)
+
+ spectrum = np.abs(X)
+ magnitude = resample_ferbs(spectrum, f, fERBs)
+ loudness = np.sqrt(magnitude)
+
+ # Select candidates that use this window size
+ # First window
+ if i == 0:
+ j = np.argwhere(d < 1).flatten()
+ k = np.argwhere(d[j] > 0).flatten()
+ # Last Window
+ elif i == len(ws) - 1:
+ j = np.argwhere(d - i > -1).flatten()
+ k = np.argwhere(d[j] - i < 0).flatten()
+ else:
+ j = np.argwhere(np.abs(d - i) < 1).flatten()
+ k = np.arange(0, len(j))
+
+ pc_to_compute = pc[j]
+
+ pitch_strength = pitch_strength_all_candidates(fERBs, loudness, pc_to_compute)
+
+ resampled_pitch_strength = resample_time(pitch_strength, t, ti)
+
+ lambda_ = d[j[k]] - i
+ mu = np.ones(len(j))
+ mu[k] = 1 - np.abs(lambda_)
+
+ S[j, :] = S[j, :] + np.multiply(
+ np.ones(resampled_pitch_strength.shape) * mu.reshape((mu.shape[0], 1)),
+ resampled_pitch_strength
+ )
+
+ # Fine-tune the pitch using parabolic interpolation
+ pitches, strength = parabolic_int(S, strength_threshold, pc)
+
+ pitches[np.where(np.isnan(pitches))] = 0 # avoid NaN output
+
+ return pitches, t, strength
+
+
+def nyquist(Fs):
+ """Nyquist Frequency"""
+ return Fs / 2
+
+
+def F_coef(k, N, Fs):
+ """Physical frequency of STFT coefficients"""
+ return (k * Fs) / N
+
+
+def T_coef(m, H, Fs):
+ """Physical time of STFT coefficients"""
+ return m * H / Fs
+
+
+def stft_with_f_t(y, N, H, Fs):
+ """STFT wrapper"""
+    x = librosa.stft(y, n_fft=int(N), hop_length=int(H), pad_mode='constant', center=True)  # keyword args; newer librosa makes n_fft keyword-only
+ f = F_coef(np.arange(0, x.shape[0]), N, Fs)
+ t = T_coef(np.arange(0, x.shape[1]), H, Fs)
+
+ return x, f, t
+
+
+def hz2erbs(hz):
+ """Convert Hz to ERB scale"""
+ return 21.4 * np.log10(1 + hz / 229)
+
+
+def erbs2hz(erbs):
+ """Convert ERB to Hz"""
+ return (10 ** np.divide(erbs, 21.4) - 1) * 229
+
+
+def pitch_strength_all_candidates(ferbs, loudness, pitch_candidates):
+ """Compute pitch strength for all pitch candidates"""
+ # Normalize loudness
+ normalization_loudness = np.full_like(loudness, np.sqrt(np.sum(loudness * loudness, axis=0)))
+ with np.errstate(divide='ignore', invalid='ignore'):
+ loudness = loudness / normalization_loudness
+
+ # Create pitch salience matrix
+ S = np.zeros((len(pitch_candidates), loudness.shape[1]))
+
+ for j in range(0, len(pitch_candidates)):
+ S[j, :] = pitch_strength_one(ferbs, loudness, pitch_candidates[j])
+ return S
+
+
+def pitch_strength_one(erbs_frequencies, normalized_loudness, pitch_candidate):
+ """Compute pitch strength for one pitch candidate"""
+ number_of_harmonics = np.floor(erbs_frequencies[-1] / pitch_candidate - 0.75).astype(np.int32)
+ k = np.zeros(erbs_frequencies.shape)
+
+ # f_prime / f
+ q = erbs_frequencies / pitch_candidate
+
+ for i in np.concatenate(([1], primes(number_of_harmonics))):
+ a = np.abs(q - i)
+ p = a < 0.25
+ k[p] = np.cos(np.dot(2 * np.pi, q[p]))
+ v = np.logical_and(0.25 < a, a < 0.75)
+ k[v] = k[v] + np.cos(np.dot(2 * np.pi, q[v])) / 2
+
+ # Apply envelope
+ k = np.multiply(k, np.sqrt(1.0 / erbs_frequencies))
+
+ # K+-normalize kernel
+ k = k / np.linalg.norm(k[k > 0])
+
+ # Compute pitch strength
+ S = np.dot(k, normalized_loudness)
+ return S
+
+
+def resample_ferbs(spectrum, f, ferbs):
+ """Resample to ERB scale"""
+ magnitude = np.zeros((len(ferbs), spectrum.shape[1]))
+
+ for t in range(spectrum.shape[1]):
+ spl = interpolate.splrep(f, spectrum[:, t])
+ interpolate.splev(ferbs, spl)
+
+ magnitude[:, t] = interpolate.splev(ferbs, spl)
+
+ return np.maximum(magnitude, 0)
+
+
+def resample_time(pitch_strength, resampled_time, ti):
+ """Resample time axis"""
+ if pitch_strength.shape[1] > 0:
+ pitch_strength = interpolate_one_candidate(pitch_strength, ti, resampled_time)
+ else:
+        pitch_strength = np.kron(np.ones((len(pitch_strength), len(resampled_time))), np.nan)
+ return pitch_strength
+
+
+def interpolate_one_candidate(pitch_strength, ti, resampled_time):
+ """Interpolate time axis"""
+ pitch_strength_interpolated = np.zeros((pitch_strength.shape[0], len(resampled_time)))
+
+ for s in range(pitch_strength.shape[0]):
+ t_i = interpolate.interp1d(ti, pitch_strength[s, :], 'linear', bounds_error=True)
+ pitch_strength_interpolated[s, :] = t_i(resampled_time)
+
+ return pitch_strength_interpolated
+
+
+def parabolic_int(pitch_strength, strength_threshold, pc):
+ """Parabolic interpolation between pitch candidates using pitch strength"""
+    p = np.full((pitch_strength.shape[1],), np.nan)
+    s = np.full((pitch_strength.shape[1],), np.nan)
+
+ for j in range(pitch_strength.shape[1]):
+ i = np.argmax(pitch_strength[:, j])
+ s[j] = pitch_strength[i, j]
+
+ if s[j] < strength_threshold:
+ continue
+
+ if i == 0:
+ p[j] = pc[0]
+ elif i == len(pc) - 1:
+            p[j] = pc[-1]  # the maximum sits at the last candidate
+ else:
+ I = np.arange(i - 1, i + 2)
+ tc = 1 / pc[I]
+ ntc = np.dot((tc / tc[1] - 1), 2 * np.pi)
+ if np.any(np.isnan(pitch_strength[I, j])):
+ s[j] = np.nan
+ p[j] = np.nan
+ else:
+ c = np.polyfit(ntc, pitch_strength[I, j], 2)
+ ftc = 1 / 2 ** np.arange(np.log2(pc[I[0]]), np.log2(pc[I[2]]), 1 / 12 / 64)
+ nftc = np.dot((ftc / tc[1] - 1), 2 * np.pi)
+ poly = np.polyval(c, nftc)
+ k = np.argmax(poly)
+ s[j] = poly[k]
+ p[j] = 2 ** (np.log2(pc[I[0]]) + k / 12 / 64)
+ return p, s
+
+
+def primes(n):
+ """Returns a set of n prime numbers"""
+ small_primes = np.array([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89,
+ 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181,
+ 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281,
+ 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397,
+ 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503,
+ 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619,
+ 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743,
+ 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863,
+ 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997])
+
+ b = small_primes <= n
+ return small_primes[b]
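+
+
+# Usage sketch (added for illustration, not part of the original libf0 file):
+# SWIPE on a short synthetic tone with the default pitch range of 55-1760 Hz.
+if __name__ == "__main__":
+    Fs = 22050
+    t_ax = np.arange(Fs) / Fs
+    x_test = 0.5 * np.sin(2 * np.pi * 220.0 * t_ax)
+    f0, t_frames, strength = swipe(x_test, Fs=Fs, H=256)
+    print(f0[10:15], strength[10:15])  # estimates should be close to 220 Hz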
diff --git a/pitch/core/swipe_slim.py b/pitch/core/swipe_slim.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf23f2b561b47d654c2a5b98ba392bdf6c05f9a
--- /dev/null
+++ b/pitch/core/swipe_slim.py
@@ -0,0 +1,180 @@
+"""
+| Description: libf0 SWIPE slim implementation
+| Contributors: Sebastian Rosenzweig, Simon Schwär, Meinard Müller
+| License: The MIT license, https://opensource.org/licenses/MIT
+| This file is part of libf0.
+"""
+import numpy as np
+import librosa
+from .yin import parabolic_interpolation
+from scipy.interpolate import interp1d
+
+
+def swipe_slim(x, Fs=22050, H=256, F_min=55.0, F_max=1760.0, R=10, strength_threshold=0):
+ """
+ Slim and didactical implementation of a sawtooth waveform inspired pitch estimator (SWIPE).
+ This version uses a log-frequency spectrogram instead of ERB filters. Furthermore, it is implemented more
+ efficiently. See `swipe()` for the original implementation.
+
+ .. [#] A. Camacho and J. G. Harris,
+ "A sawtooth waveform inspired pitch estimator for speech and music."
+ The Journal of the Acoustical Society of America, vol. 124, no. 3, pp. 1638–1652, Sep. 2008
+
+ Parameters
+ ----------
+ x : ndarray
+ Audio signal
+ Fs : int
+ Sampling rate
+ H : int
+ Hop size
+ F_min : float or int
+ Minimal frequency
+ F_max : float or int
+ Maximal frequency
+ R : float
+ resolution of the pitch candidate bins in cents (default = 10)
+ strength_threshold : float
+ confidence threshold [0, 1] for the pitch detection (default value = 0)
+
+ Returns
+ -------
+ f0 : ndarray
+ Estimated F0-trajectory
+ t : ndarray
+ Time axis
+ conf : ndarray
+ Confidence / Pitch Strength
+ """
+
+ # compute time and frequency axis
+ t = np.arange(0, len(x), H) / Fs # time axis
+ F_coef_log = np.arange(0, np.log2(Fs/2/F_min), R/1200)
+ F_coef_log_hz = F_min * 2 ** F_coef_log # pitch candidates
+
+ # pre-compute kernels, one kernel for each pitch candidate in range [F_min : F_max]
+ F_min_idx = np.argmin(np.abs(F_coef_log_hz - F_min))
+ F_max_idx = np.argmin(np.abs(F_coef_log_hz - F_max))
+ B = F_max_idx - F_min_idx # Number of pitch candidates
+ kernels = np.zeros((B, len(F_coef_log_hz)))
+ for i, f in enumerate(F_coef_log_hz[F_min_idx:F_max_idx]):
+ kernels[i, :] = compute_kernel(f, F_coef_log_hz)
+
+ # determine optimal window length for each candidate
+ L_opt = np.log2(Fs * 8 / np.array([F_min, F_max])) # exponents for optimal window sizes 2^L, see paper Section II.G
+ L_rnd = np.arange(np.round(L_opt[1]), np.round(L_opt[0])+1).astype(np.int32) # range of rounded exponents
+ N_pow2 = 2 ** L_rnd # Compute rounded power-2 windows sizes
+ # Quantization error between optimal window size (see paper Section II.G) and rounded power-2 windows size
+ # Using only the largest N here, since errors for other N can be derived from err by subtracting exponent (cyclic)
+ err = np.abs(np.log2(8 * Fs / F_coef_log_hz[F_min_idx:F_max_idx]) - np.log2(np.max(N_pow2)))
+
+ S = np.zeros((B, len(t))) # "pitch-strength" matrix
+
+ # loop through all window sizes
+ for octave, N in enumerate(N_pow2):
+ # Compute STFT
+ x_pad = np.pad(x, (0, N)) # to avoid problems during time axis interpolation
+ H = N // 2
+ X = librosa.stft(x_pad, n_fft=N, hop_length=H, win_length=N, window='hann', pad_mode='constant', center=True)
+ Y = np.abs(X)
+ T_coef_lin_s = np.arange(0, X.shape[1]) * H / Fs
+ F_coef_lin_hz = np.arange(N // 2 + 1) * Fs / N
+
+ # Resample to log-frequency axis
+ compute_Y_log = interp1d(F_coef_lin_hz, Y, kind='cubic', axis=0)
+ Y_log = compute_Y_log(F_coef_log_hz)
+
+ # Normalize magnitudes
+ Y_log /= np.sqrt(np.sum(Y_log ** 2, axis=0)) + np.finfo(float).eps
+
+ # Correlate kernels with log-spectrum for pitch candidates where N is optimal
+ S_N = np.matmul(kernels, Y_log)
+
+ # Resample time axis
+ compute_S_N_res = interp1d(T_coef_lin_s, S_N, kind='linear', axis=1)
+ S_N_res = compute_S_N_res(t)
+
+ # Weight pitch strength according to quantization error
+ candidates = (err > octave - 1) & (err < octave + 1) # consider pitches +/- 1 octave from current window
+ mu = 1 - np.abs(err[candidates] - octave)
+
+ S[candidates, :] += np.multiply(mu.reshape(-1, 1), S_N_res[candidates, :])
+
+ # Obtain pitch estimates and corresponding confidence
+ max_indices = np.argmax(S, axis=0)
+ conf = np.max(S, axis=0)
+
+ # Parabolic Interpolation of pitch estimates for refinement
+ time_idx = np.arange(S.shape[1])
+ indeces_shift, _ = parabolic_interpolation(S[max_indices-1, time_idx],
+ S[max_indices, time_idx],
+ S[max_indices+1, time_idx])
+ compute_f0_log = interp1d(np.arange(len(F_coef_log)), F_coef_log, kind='linear')
+ f0_hz = F_min * 2 ** compute_f0_log(max_indices+indeces_shift)
+
+ # Thresholding
+ f0_hz[conf < strength_threshold] = 0 # discard estimates where confidence is low
+
+ return f0_hz, t, conf
+
+
+def compute_kernel(f, F_coef_log_hz):
+ """
+ Compute a SWIPE' kernel.
+
+ Parameters
+ ----------
+ f : float
+ Frequency in Hz
+ F_coef_log_hz :
+ Logarithmic frequency axis in Hz
+
+ Returns
+ -------
+ k : ndarray
+ Kernel
+ """
+ k = np.zeros(len(F_coef_log_hz))
+ n_harmonics = np.floor(F_coef_log_hz[-1] / f).astype(np.int32)
+ prime_numbers = prime_and_one(100)[:n_harmonics] # only consider prime harmonics for kernel peaks
+
+ ratio = F_coef_log_hz / f
+
+ # loop through all prime harmonics
+ for p in prime_numbers:
+ a = np.abs(ratio - p) # normalized distance between harmonic and current pitch candidate
+ main_peak_bins = a < 0.25
+ k[main_peak_bins] = np.cos(np.dot(np.array(2 * np.pi).reshape(-1, 1),
+ ratio[main_peak_bins].reshape(1, -1))).flatten()
+ valley_bins = np.logical_and(0.25 < a, a < 0.75)
+ k[valley_bins] += np.cos(np.dot(np.array(2 * np.pi).reshape(-1, 1),
+ ratio[valley_bins].reshape(1, -1))).flatten() / 2
+
+ # Apply decay
+ k = np.multiply(k, np.sqrt(1.0 / F_coef_log_hz))
+
+ # K+-normalize kernel
+ k = k / np.linalg.norm(k[k > 0])
+
+ return k
+
+
+def prime_and_one(upto=1000000):
+ """
+ Returns a set of prime numbers, adapted from http://rebrained.com/?p=458
+
+ Parameters
+ ----------
+ upto : int
+ Find prime numbers up to this number
+
+ Returns
+ -------
+ A set of prime numbers including 1 & 2
+ """
+ primes = np.arange(3, upto+1, 2)
+    isprime = np.ones((upto-1)//2, dtype=bool)  # np.bool8 is deprecated/removed in recent NumPy
+ for factor in primes[:int(np.sqrt(upto))//2]:
+ if isprime[(factor-2)//2]:
+ isprime[(factor*3-2)//2::factor] = 0
+ return np.concatenate((np.array([1, 2]), primes[isprime]))
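+
+
+# Usage sketch (added for illustration, not part of the original libf0 file):
+# the slim SWIPE variant on a short synthetic tone.
+if __name__ == "__main__":
+    Fs = 22050
+    t_ax = np.arange(Fs) / Fs
+    x_test = 0.5 * np.sin(2 * np.pi * 220.0 * t_ax)
+    f0, t_frames, conf = swipe_slim(x_test, Fs=Fs, H=256)
+    print(f0[10:15], conf[10:15])  # estimates should be close to 220 Hz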
diff --git a/pitch/core/utils.py b/pitch/core/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc488d104b0e9fa4ed3ab7835293aab0ae3757f6
--- /dev/null
+++ b/pitch/core/utils.py
@@ -0,0 +1,119 @@
+"""
+| Description: libf0 utility functions
+| Contributors: Sebastian Rosenzweig, Simon Schwär, Meinard Müller
+| License: The MIT license, https://opensource.org/licenses/MIT
+| This file is part of libf0.
+"""
+import numpy as np
+
+
+def sonify_trajectory_with_sinusoid(f0, t, audio_len, confidence=None, Fs=22050, smooth_len=11):
+ """
+    Sonification of an F0-trajectory with a sinusoid. Adapted from FMP notebook: C8/C8S2_FundFreqTracking.ipynb
+
+ Parameters
+ ----------
+ f0 : ndarray
+ F0-trajectory
+ t : ndarray
+ Time axis
+ audio_len : int
+ Desired audio length in samples
+ confidence : None or ndarray
+ Confidence values for amplitude control
+ Fs : int
+ Sampling rate
+ smooth_len : int
+ Smoothing filter length to avoid clicks in the sonification
+
+ Returns
+ -------
+ x_soni : ndarray
+ Sonified F0-trajectory
+ """
+ if confidence is None:
+ confidence = np.ones_like(f0)
+
+ # initialize
+ x_soni = np.zeros(audio_len)
+ amplitude_mod = np.zeros(audio_len)
+
+ # Computation of hop size
+ sine_len = int(t[1] * Fs)
+
+ t = np.arange(0, sine_len) / Fs
+ phase = 0
+
+ # loop over all F0 values, ensure continuous phase
+ for idx in np.arange(0, len(f0)):
+ cur_f = f0[idx]
+ cur_amp = confidence[idx]
+
+ if cur_f == 0:
+ phase = 0
+ continue
+
+ cur_soni = np.sin(2*np.pi*(cur_f*t+phase))
+ diff = np.maximum(0, (idx+1)*sine_len - len(x_soni))
+ if diff > 0:
+ x_soni[idx * sine_len:(idx + 1) * sine_len - diff] = cur_soni[:-diff]
+ amplitude_mod[idx * sine_len:(idx + 1) * sine_len - diff] = cur_amp
+ else:
+ x_soni[idx*sine_len:(idx+1)*sine_len-diff] = cur_soni
+ amplitude_mod[idx*sine_len:(idx+1)*sine_len-diff] = cur_amp
+
+ phase += cur_f * sine_len / Fs
+ phase -= 2 * np.round(phase/2)
+
+ # filter amplitudes to avoid transients
+ amplitude_mod = np.convolve(amplitude_mod, np.hanning(smooth_len)/np.sum(np.hanning(smooth_len)), 'same')
+ x_soni = x_soni * amplitude_mod
+ return x_soni
+
+
+def hz_to_cents(F, F_ref=55.0):
+ """
+ Converts frequency in Hz to cents.
+
+ Parameters
+ ----------
+ F : float or ndarray
+ Frequency value in Hz
+ F_ref : float
+ Reference frequency in Hz (Default value = 55.0)
+ Returns
+ -------
+ F_cents : float or ndarray
+ Frequency in cents
+ """
+
+ # Avoid division by 0
+ F_temp = np.array(F).astype(float)
+ F_temp[F_temp == 0] = np.nan
+
+ F_cents = 1200 * np.log2(F_temp / F_ref)
+
+ return F_cents
+
+
+def cents_to_hz(F_cents, F_ref=55.0):
+ """
+ Converts frequency in cents to Hz.
+
+ Parameters
+ ----------
+ F_cents : float or ndarray
+ Frequency in cents
+ F_ref : float
+ Reference frequency in Hz (Default value = 55.0)
+ Returns
+ -------
+ F : float or ndarray
+ Frequency in Hz
+ """
+ F = F_ref * 2 ** (F_cents / 1200)
+
+ # Avoid NaN output
+ F = np.nan_to_num(F, copy=False, nan=0)
+
+ return F
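+
+
+# Worked example (added for illustration): 110 Hz lies one octave above the
+# default reference of 55 Hz, i.e. 1200 cents, and the conversion round-trips.
+if __name__ == "__main__":
+    print(hz_to_cents(np.array([110.0])))   # -> [1200.]
+    print(cents_to_hz(np.array([1200.0])))  # -> [110.]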
diff --git a/pitch/core/yin.py b/pitch/core/yin.py
new file mode 100644
index 0000000000000000000000000000000000000000..7408443a59dd6c7be820aac85445fbe4fec177fc
--- /dev/null
+++ b/pitch/core/yin.py
@@ -0,0 +1,238 @@
+"""
+| Description: libf0 YIN implementation
+| Contributors: Sebastian Rosenzweig, Simon Schwär, Edgar Suárez, Meinard Müller
+| License: The MIT license, https://opensource.org/licenses/MIT
+| This file is part of libf0.
+"""
+import numpy as np
+from numba import njit
+
+
+def yin(x, Fs=22050, N=2048, H=256, F_min=55.0, F_max=1760.0, threshold=0.15, verbose=False):
+ """
+ Implementation of the YIN algorithm.
+
+ .. [#] Alain De Cheveigné and Hideki Kawahara.
+ "YIN, a fundamental frequency estimator for speech and music."
+ The Journal of the Acoustical Society of America 111.4 (2002): 1917-1930.
+
+ Parameters
+ ----------
+ x : ndarray [shape=(L, )], real - valued
+ Audio signal
+ Fs : int
+ Sampling frequency
+ N : int
+ Window size
+ H : int
+ Hop size
+ F_min : float
+ Minimal frequency
+ F_max : float
+ Maximal frequency
+ threshold : float
+ Threshold for cumulative mean normalized difference function
+ verbose : bool
+ Switch to activate/deactivate status bar
+
+ Returns
+ -------
+ f0 : ndarray
+ Estimated F0-trajectory
+ t : ndarray
+ Time axis
+ ap: ndarray
+ Aperiodicity (indicator for voicing: the lower, the more reliable the estimate)
+ """
+
+ if F_min > F_max:
+ raise Exception("F_min must be smaller than F_max!")
+
+ if F_min < Fs/N:
+ raise Exception(f"The condition (F_min >= Fs/N) was not met. With Fs = {Fs}, N = {N} and F_min = {F_min} you have the following options: \n1) Set F_min >= {np.ceil(Fs/N)} Hz. \n2) Set N >= {np.ceil(Fs/F_min).astype(int)}. \n3) Set Fs <= {np.floor(F_min * N)} Hz.")
+
+ x_pad = np.concatenate((np.zeros(N//2), x, np.zeros(N//2))) # Add zeros for centered estimates
+ M = int(np.floor((len(x_pad) - N) / H)) + 1 # Compute number of estimates that will be generated
+ f0 = np.zeros(M) # Estimated fundamental frequencies (0 for unspecified frames)
+ t = np.arange(M)*H/Fs # Time axis
+ ap = np.zeros(M) # Aperiodicity
+
+ lag_min = max(int(np.ceil(Fs / F_max)), 1) # lag of maximal frequency in samples
+ lag_max = int(np.ceil(Fs / F_min)) # lag of minimal frequency in samples
+
+ for m in range(M):
+ if verbose:
+ print(f"YIN Progress: {np.ceil(100*m/M).astype(int)}%", end='\r')
+ # Take a frame from input signal
+ frame = x_pad[m*H:m*H + N]
+
+ # Cumulative Mean Normalized Difference Function
+ cmndf = cumulative_mean_normalized_difference_function(frame, lag_max)
+
+ # Absolute Thresholding
+ lag_est = absolute_thresholding(cmndf, threshold, lag_min, lag_max, parabolic_interp=True)
+
+ # Refine estimate by constraining search to vicinity of best local estimate (default: +/- 25 cents)
+ tol_cents = 25
+ lag_min_local = int(np.round(Fs / ((Fs / lag_est) * 2 ** (tol_cents/1200))))
+ if lag_min_local < lag_min:
+ lag_min_local = lag_min
+ lag_max_local = int(np.round(Fs / ((Fs / lag_est) * 2 ** (-tol_cents/1200))))
+ if lag_max_local > lag_max:
+ lag_max_local = lag_max
+ lag_new = absolute_thresholding(cmndf, threshold=np.inf, lag_min=lag_min_local, lag_max=lag_max_local,
+ parabolic_interp=True)
+
+ # Compute Fundamental Frequency Estimate
+ f0[m] = Fs / lag_new
+
+ # Compute Aperiodicity
+ ap[m] = aperiodicity(frame, lag_new)
+
+ return f0, t, ap
+
+
+@njit
+def cumulative_mean_normalized_difference_function(frame, lag_max):
+ """
+ Computes Cumulative Mean Normalized Difference Function (CMNDF).
+
+ Parameters
+ ----------
+ frame : ndarray
+ Audio frame
+ lag_max : int
+ Maximum expected lag in the CMNDF
+
+ Returns
+ -------
+ cmndf : ndarray
+ Cumulative Mean Normalized Difference Function
+ """
+
+ cmndf = np.zeros(lag_max+1) # Initialize CMNDF
+ cmndf[0] = 1
+ diff_mean = 0
+
+ for tau in range(1, lag_max+1):
+ # Difference function
+ diff = np.sum((frame[0:-tau] - frame[0 + tau:]) ** 2)
+ # Iterative mean of the difference function
+ diff_mean = diff_mean*(tau-1)/tau + diff/tau
+
+ cmndf[tau] = diff / (diff_mean + np.finfo(np.float64).eps)
+
+ return cmndf
+
+
+def absolute_thresholding(cmndf, threshold, lag_min, lag_max, parabolic_interp=True):
+ """
+ Absolute thresholding:
+ Set an absolute threshold and choose the smallest value of tau that gives a minimum of d' deeper than that
+ threshold. If none is found, the global minimum is chosen instead.
+
+ Parameters
+ ----------
+ cmndf : ndarray
+ Cumulative Mean Normalized Difference Function
+ threshold : float
+ Threshold
+    lag_min : int
+        Minimal lag of the search range (in samples)
+    lag_max : int
+        Maximal lag of the search range (in samples)
+    parabolic_interp : bool
+        Switch to activate/deactivate parabolic interpolation
+
+    Returns
+    -------
+    lag : int or float
+        Estimated lag (fractional when parabolic interpolation is enabled)
+    """
+
+ # take shortcut if search range only allows for one possible lag
+ if lag_min == lag_max:
+ return lag_min
+
+ # find local minima below absolute threshold in interval [lag_min:lag_max]
+ local_min_idxs = (np.argwhere((cmndf[1:-1] < cmndf[0:-2]) & (cmndf[1:-1] < cmndf[2:]))).flatten() + 1
+ below_thr_idxs = np.argwhere(cmndf[lag_min:lag_max] < threshold).flatten() + lag_min
+ # numba compatible intersection of indices sets
+ min_idxs = np.unique(np.array([i for i in local_min_idxs for j in below_thr_idxs if i == j]))
+
+ # if no local minima below threshold are found, return global minimum
+ if not min_idxs.size:
+ return np.argmin(cmndf[lag_min:lag_max]) + lag_min
+
+ # find first local minimum
+ lag = np.min(min_idxs) # choose first local minimum
+
+ # Optional: Parabolic Interpolation of local minima
+ if parabolic_interp:
+ lag_corr, cmndf[lag] = parabolic_interpolation(cmndf[lag-1], cmndf[lag], cmndf[lag+1])
+ lag += lag_corr
+
+ return lag
+
+
+@njit
+def parabolic_interpolation(y1, y2, y3):
+ """
+ Parabolic interpolation of an extremal value given three samples with equal spacing on the x-axis.
+ The middle value y2 is assumed to be the extremal sample of the three.
+
+ Parameters
+ ----------
+ y1: f(x1)
+ y2: f(x2)
+ y3: f(x3)
+
+ Returns
+ -------
+    x_interp: Interpolated x-value of the extremum, relative to x2 and in units of the sample spacing (x3 - x2)
+ y_interp: Interpolated y-value, f(x_interp)
+ """
+
+ a = np.finfo(np.float64).eps + (y1 + y3 - 2 * y2) / 2
+ b = (y3 - y1) / 2
+ x_interp = -b / (2 * a)
+ y_interp = y2 - (b ** 2) / (4 * a)
+
+ return x_interp, y_interp
+
+
+def aperiodicity(frame, lag_est):
+ """
+ Compute aperiodicity of given frame (serves as indicator for reliability or voicing detection).
+
+ Parameters
+ ----------
+ frame : ndarray
+ Frame
+ lag_est : float
+ Estimated lag
+
+ Returns
+ -------
+ ap: float
+ Aperiodicity (the lower, the more reliable the estimate)
+ """
+
+ lag_int = int(np.floor(lag_est)) # uncorrected period estimate
+ frac = lag_est - lag_int # residual
+
+    # Pad frame to ensure constant size
+ frame_pad = np.concatenate((frame, np.flip(frame))) # mirror padding
+
+ # Shift frame by estimated period
+ if frac == 0:
+ frame_shift = frame_pad[lag_int:lag_int+len(frame)]
+ else:
+ # linear interpolation between adjacent shifts
+ frame_shift = (1 - frac) * frame_pad[lag_int:lag_int+len(frame)] + \
+ frac * frame_pad[lag_int+1:lag_int+1+len(frame)]
+
+ pwr = (np.mean(frame ** 2) + np.mean(frame_shift ** 2)) / 2 # average power over fixed and shifted frame
+ res = np.mean((frame - frame_shift) ** 2) / 2 # residual power
+ ap = res / (pwr + np.finfo(np.float64).eps)
+
+ return ap
diff --git a/pitch/debug.py b/pitch/debug.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa21c7c88889c0b401de6c6693e4b68c4a8ffc4d
--- /dev/null
+++ b/pitch/debug.py
@@ -0,0 +1,23 @@
+import argparse
+import numpy as np
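+# Usage sketch (path is a placeholder): python pitch/debug.py -p <pitch.npy>
+# Dumps the loaded pitch track to pitch_debug.csv, one "Xm Ys Zms,f0" row per 10 ms frame.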
+
+
+def save_csv_pitch(pitch, path):
+ with open(path, "w", encoding='utf-8') as pitch_file:
+ for i in range(len(pitch)):
+ t = i * 10
+ minute = t // 60000
+ seconds = (t - minute * 60000) // 1000
+ millisecond = t % 1000
+ print(
+ f"{minute}m {seconds}s {millisecond:3d},{int(pitch[i])}", file=pitch_file)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-p", "--pit", help="pit", dest="pit", required=True) # pit for train
+ args = parser.parse_args()
+ print(args.pit)
+
+ pitch = np.load(args.pit)
+ save_csv_pitch(pitch, 'pitch_debug.csv')
diff --git a/pitch/inference.py b/pitch/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b204f09b4617fa9bade30d70ed9b7bf5c846b32
--- /dev/null
+++ b/pitch/inference.py
@@ -0,0 +1,134 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import librosa
+import argparse
+import numpy as np
+import crepe
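+# Usage sketch (paths are placeholders): python pitch/inference.py -w <input.wav> -p <output.csv>
+# Loads the wave at 16 kHz, estimates F0 with CREPE (20 ms hop, upsampled to 10 ms) and writes a CSV for inspection.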
+
+
+def move_average(a, n, mode="same"):
+ return (np.convolve(a, np.ones((n,))/n, mode=mode))
+
+
+def compute_f0_mouth(path, device):
+ # pip install praat-parselmouth
+ import parselmouth
+
+ x, sr = librosa.load(path, sr=16000)
+ assert sr == 16000
+ lpad = 1024 // 160
+ rpad = lpad
+ f0 = parselmouth.Sound(x, sr).to_pitch_ac(
+ time_step=160 / sr,
+ voicing_threshold=0.5,
+ pitch_floor=30,
+ pitch_ceiling=1000).selected_array['frequency']
+ f0 = np.pad(f0, [[lpad, rpad]], mode='constant')
+ return f0
+
+
+def compute_f0_salience(filename, device):
+ from pitch.core.salience import salience
+ audio, sr = librosa.load(filename, sr=16000)
+ assert sr == 16000
+ f0, t, s = salience(
+ audio,
+ Fs=sr,
+ H=320,
+ N=2048,
+ F_min=45.0,
+ F_max=1760.0)
+ f0 = np.repeat(f0, 2, -1) # 320 -> 160 * 2
+ f0 = move_average(f0, 3)
+ return f0
+
+
+def compute_f0_voice(filename, device):
+ audio, sr = librosa.load(filename, sr=16000)
+ assert sr == 16000
+ audio = torch.tensor(np.copy(audio))[None]
+ audio = audio + torch.randn_like(audio) * 0.001
+ # Here we'll use a 10 millisecond hop length
+ hop_length = 160
+ fmin = 50
+ fmax = 1000
+ model = "full"
+ batch_size = 512
+ pitch = crepe.predict(
+ audio,
+ sr,
+ hop_length,
+ fmin,
+ fmax,
+ model,
+ batch_size=batch_size,
+ device=device,
+ return_periodicity=False,
+ )
+ pitch = crepe.filter.mean(pitch, 3)
+ pitch = pitch.squeeze(0)
+ return pitch
+
+
+def compute_f0_sing(filename, device):
+ audio, sr = librosa.load(filename, sr=16000)
+ assert sr == 16000
+ audio = torch.tensor(np.copy(audio))[None]
+ audio = audio + torch.randn_like(audio) * 0.001
+ # Here we'll use a 20 millisecond hop length
+ hop_length = 320
+ fmin = 50
+ fmax = 1000
+ model = "full"
+ batch_size = 512
+ pitch = crepe.predict(
+ audio,
+ sr,
+ hop_length,
+ fmin,
+ fmax,
+ model,
+ batch_size=batch_size,
+ device=device,
+ return_periodicity=False,
+ )
+ pitch = np.repeat(pitch, 2, -1) # 320 -> 160 * 2
+ pitch = crepe.filter.mean(pitch, 5)
+ pitch = pitch.squeeze(0)
+ return pitch
+
+
+def save_csv_pitch(pitch, path):
+ with open(path, "w", encoding='utf-8') as pitch_file:
+ for i in range(len(pitch)):
+ t = i * 10
+ minute = t // 60000
+ seconds = (t - minute * 60000) // 1000
+ millisecond = t % 1000
+ print(
+ f"{minute}m {seconds}s {millisecond:3d},{int(pitch[i])}", file=pitch_file)
+
+
+def load_csv_pitch(path):
+ pitch = []
+ with open(path, "r", encoding='utf-8') as pitch_file:
+ for line in pitch_file.readlines():
+ pit = line.strip().split(",")[-1]
+ pitch.append(int(pit))
+ return pitch
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
+ parser.add_argument("-p", "--pit", help="pit", dest="pit", required=True) # csv for excel
+ args = parser.parse_args()
+ print(args.wav)
+ print(args.pit)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ pitch = compute_f0_sing(args.wav, device)
+ save_csv_pitch(pitch, args.pit)
+ # tmp = load_csv_pitch(args.pit)
+ # save_csv_pitch(tmp, "tmp.csv")
diff --git a/prepare/preprocess_a.py b/prepare/preprocess_a.py
new file mode 100644
index 0000000000000000000000000000000000000000..87d03b5baffc1c6f355bb59dc94e299ac37b2427
--- /dev/null
+++ b/prepare/preprocess_a.py
@@ -0,0 +1,58 @@
+import os
+import librosa
+import argparse
+import numpy as np
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from scipy.io import wavfile
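+# Usage sketch (paths are placeholders): python prepare/preprocess_a.py -w <raw_dir> -o <out_dir> -s 32000 -t 0
+# Expects <raw_dir>/<speaker>/*.wav and writes peak-normalized 16-bit PCM copies resampled to 16 kHz or 32 kHz.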
+
+
+def resample_wave(wav_in, wav_out, sample_rate):
+ wav, _ = librosa.load(wav_in, sr=sample_rate)
+ wav = wav / np.abs(wav).max() * 0.6
+ wav = wav / max(0.01, np.max(np.abs(wav))) * 32767 * 0.6
+ wavfile.write(wav_out, sample_rate, wav.astype(np.int16))
+
+
+def process_file(file, wavPath, spks, outPath, sr):
+ if file.endswith(".wav"):
+ file = file[:-4]
+ resample_wave(f"{wavPath}/{spks}/{file}.wav", f"{outPath}/{spks}/{file}.wav", sr)
+
+
+def process_files_with_thread_pool(wavPath, spks, outPath, sr, thread_num=None):
+ files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")]
+
+ with ThreadPoolExecutor(max_workers=thread_num) as executor:
+ futures = {executor.submit(process_file, file, wavPath, spks, outPath, sr): file for file in files}
+
+ for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing {sr} {spks}'):
+ future.result()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
+ parser.add_argument("-o", "--out", help="out", dest="out", required=True)
+ parser.add_argument("-s", "--sr", help="sample rate", dest="sr", type=int, required=True)
+ parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1)
+
+ args = parser.parse_args()
+ print(args.wav)
+ print(args.out)
+ print(args.sr)
+
+ os.makedirs(args.out, exist_ok=True)
+ wavPath = args.wav
+ outPath = args.out
+
+ assert args.sr == 16000 or args.sr == 32000
+
+ for spks in os.listdir(wavPath):
+ if os.path.isdir(f"./{wavPath}/{spks}"):
+ os.makedirs(f"./{outPath}/{spks}", exist_ok=True)
+ if args.thread_count == 0:
+ process_num = os.cpu_count() // 2 + 1
+ else:
+ process_num = args.thread_count
+ process_files_with_thread_pool(wavPath, spks, outPath, args.sr, process_num)
diff --git a/prepare/preprocess_cdc.py b/prepare/preprocess_cdc.py
new file mode 100644
index 0000000000000000000000000000000000000000..730feaacd24620136e8fb60de7989cffc7f56043
--- /dev/null
+++ b/prepare/preprocess_cdc.py
@@ -0,0 +1,51 @@
+import os
+import argparse
+import torch
+import torchaudio
+
+from tqdm import tqdm
+from scipy.io.wavfile import read
+from scipy.io.wavfile import write
+# torch=1.9.0 -> pip install torchaudio==0.9.0 -i https://mirrors.aliyun.com/pypi/simple/
+# this file is for VCTK
+
+
+MAX_WAV_VALUE = 32768.0
+
+
+def cut_direct_content(iWave, oWave):
+ source, sr = torchaudio.load(iWave)
+ stft = torch.stft(source, 1024, 256, 1024, torch.hann_window(1024), return_complex=True)
+ stft[:, 0, :] = 0
+ stft[:, 1, :] = 0
+ istft = torch.istft(stft, 1024, 256, 1024, torch.hann_window(1024))
+ audio = istft.squeeze()
+ audio = MAX_WAV_VALUE * audio
+ audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
+ audio = audio.short()
+ audio = audio.data.cpu().detach().numpy()
+ write(oWave, sr, audio)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-i", help="input path", dest="inPath", required=True)
+ parser.add_argument("-o", help="output path", dest="outPath", required=True)
+
+ args = parser.parse_args()
+ print(args.inPath)
+ print(args.outPath)
+
+ os.makedirs(args.outPath, exist_ok=True)
+ rootPath = args.inPath
+ outPath = args.outPath
+
+ for spks in os.listdir(rootPath):
+ if (os.path.isdir(f"./{rootPath}/{spks}")):
+ os.makedirs(f"./{outPath}/{spks}", exist_ok=True)
+
+ files = [f for f in os.listdir(f"./{rootPath}/{spks}") if f.endswith(".wav")]
+ for file in tqdm(files, desc=f'Processing cdc {spks}'):
+ iWave = f"./{rootPath}/{spks}/{file}"
+ oWave = f"./{outPath}/{spks}/{file}"
+ cut_direct_content(iWave, oWave)
diff --git a/prepare/preprocess_crepe.py b/prepare/preprocess_crepe.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f9fda489d8fce1c8ee5d4f44aea7165a0b0534d
--- /dev/null
+++ b/prepare/preprocess_crepe.py
@@ -0,0 +1,69 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import numpy as np
+import librosa
+import torch
+import crepe
+import argparse
+from tqdm import tqdm
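+# Usage sketch (paths are placeholders): python prepare/preprocess_crepe.py -w <wav_dir> -p <pit_dir>
+# Extracts a CREPE F0 track per wav (10 ms hop), zeroes low-periodicity frames, and saves it as <name>.pit.npy.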
+
+
+def compute_f0(filename, save, device):
+ audio, sr = librosa.load(filename, sr=16000)
+ assert sr == 16000
+ # Load audio
+ audio = torch.tensor(np.copy(audio))[None]
+ audio = audio + torch.randn_like(audio) * 0.001
+ # Here we'll use a 10 millisecond hop length
+ hop_length = 160
+ # Provide a sensible frequency range for your domain (upper limit is 2006 Hz)
+ # This would be a reasonable range for speech
+ fmin = 50
+ fmax = 1000
+ # Select a model capacity--one of "tiny" or "full"
+ model = "full"
+ # Pick a batch size that doesn't cause memory errors on your gpu
+ batch_size = 512
+ # Compute pitch using first gpu
+ pitch, periodicity = crepe.predict(
+ audio,
+ sr,
+ hop_length,
+ fmin,
+ fmax,
+ model,
+ batch_size=batch_size,
+ device=device,
+ return_periodicity=True,
+ )
+    # CREPE was not trained on silent audio, so silent frames give unreliable estimates; filter them out via the periodicity below.
+ periodicity = crepe.filter.median(periodicity, 7)
+ pitch = crepe.filter.mean(pitch, 5)
+ pitch[periodicity < 0.5] = 0
+ pitch = pitch.squeeze(0)
+ np.save(save, pitch, allow_pickle=False)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
+ parser.add_argument("-p", "--pit", help="pit", dest="pit", required=True)
+
+ args = parser.parse_args()
+ print(args.wav)
+ print(args.pit)
+
+ os.makedirs(args.pit, exist_ok=True)
+ wavPath = args.wav
+ pitPath = args.pit
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ for spks in os.listdir(wavPath):
+ if os.path.isdir(f"./{wavPath}/{spks}"):
+ os.makedirs(f"./{pitPath}/{spks}", exist_ok=True)
+
+ files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")]
+ for file in tqdm(files, desc=f'Processing crepe {spks}'):
+ file = file[:-4]
+ compute_f0(f"{wavPath}/{spks}/{file}.wav", f"{pitPath}/{spks}/{file}.pit", device)
diff --git a/prepare/preprocess_f0.py b/prepare/preprocess_f0.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b6ae384f8511455c660caf9974815f8d781bc8c
--- /dev/null
+++ b/prepare/preprocess_f0.py
@@ -0,0 +1,62 @@
+import os
+import numpy as np
+import librosa
+import pyworld
+import argparse
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor, as_completed
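+# Usage sketch (paths are placeholders): python prepare/preprocess_f0.py -w <wav_dir> -p <pit_dir> -t 0
+# CPU-only alternative to the CREPE extractor: estimates F0 with pyworld DIO + StoneMask at a 10 ms hop.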
+
+
+def compute_f0(path, save):
+ x, sr = librosa.load(path, sr=16000)
+ assert sr == 16000
+ f0, t = pyworld.dio(
+ x.astype(np.double),
+ fs=sr,
+ f0_ceil=900,
+ frame_period=1000 * 160 / sr,
+ )
+ f0 = pyworld.stonemask(x.astype(np.double), f0, t, fs=16000)
+ for index, pitch in enumerate(f0):
+ f0[index] = round(pitch, 1)
+ np.save(save, f0, allow_pickle=False)
+
+
+def process_file(file, wavPath, spks, pitPath):
+ if file.endswith(".wav"):
+ file = file[:-4]
+ compute_f0(f"{wavPath}/{spks}/{file}.wav", f"{pitPath}/{spks}/{file}.pit")
+
+
+def process_files_with_process_pool(wavPath, spks, pitPath, process_num=None):
+ files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")]
+
+ with ProcessPoolExecutor(max_workers=process_num) as executor:
+ futures = {executor.submit(process_file, file, wavPath, spks, pitPath): file for file in files}
+
+ for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing f0 {spks}'):
+ future.result()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
+ parser.add_argument("-p", "--pit", help="pit", dest="pit", required=True)
+ parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1)
+
+ args = parser.parse_args()
+ print(args.wav)
+ print(args.pit)
+
+ os.makedirs(args.pit, exist_ok=True)
+ wavPath = args.wav
+ pitPath = args.pit
+
+ for spks in os.listdir(wavPath):
+ if os.path.isdir(f"./{wavPath}/{spks}"):
+ os.makedirs(f"./{pitPath}/{spks}", exist_ok=True)
+ if args.thread_count == 0:
+ process_num = os.cpu_count() // 2 + 1
+ else:
+ process_num = args.thread_count
+ process_files_with_process_pool(wavPath, spks, pitPath, process_num)
diff --git a/prepare/preprocess_f0_mouth.py b/prepare/preprocess_f0_mouth.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a03ff6e2403dc736beb40829a1da8c416353f00
--- /dev/null
+++ b/prepare/preprocess_f0_mouth.py
@@ -0,0 +1,62 @@
+import os
+import numpy as np
+import librosa
+import argparse
+import parselmouth
+# pip install praat-parselmouth
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+
+def compute_f0(path, save):
+ x, sr = librosa.load(path, sr=16000)
+ assert sr == 16000
+ lpad = 1024 // 160
+ rpad = lpad
+ f0 = parselmouth.Sound(x, sr).to_pitch_ac(
+ time_step=160 / sr,
+ voicing_threshold=0.5,
+ pitch_floor=30,
+ pitch_ceiling=1000).selected_array['frequency']
+ f0 = np.pad(f0, [[lpad, rpad]], mode='constant')
+ np.save(save, f0, allow_pickle=False)
+
+
+def process_file(file, wavPath, spks, pitPath):
+ if file.endswith(".wav"):
+ file = file[:-4]
+ compute_f0(f"{wavPath}/{spks}/{file}.wav", f"{pitPath}/{spks}/{file}.pit")
+
+
+def process_files_with_process_pool(wavPath, spks, pitPath, process_num=None):
+ files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")]
+
+ with ProcessPoolExecutor(max_workers=process_num) as executor:
+ futures = {executor.submit(process_file, file, wavPath, spks, pitPath): file for file in files}
+
+ for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing f0 {spks}'):
+ future.result()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
+ parser.add_argument("-p", "--pit", help="pit", dest="pit", required=True)
+ parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1)
+
+ args = parser.parse_args()
+ print(args.wav)
+ print(args.pit)
+
+ os.makedirs(args.pit, exist_ok=True)
+ wavPath = args.wav
+ pitPath = args.pit
+
+ for spks in os.listdir(wavPath):
+ if os.path.isdir(f"./{wavPath}/{spks}"):
+ os.makedirs(f"./{pitPath}/{spks}", exist_ok=True)
+ if args.thread_count == 0:
+ process_num = os.cpu_count() // 2 + 1
+ else:
+ process_num = args.thread_count
+ process_files_with_process_pool(wavPath, spks, pitPath, process_num)
diff --git a/prepare/preprocess_hubert.py b/prepare/preprocess_hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd4265b715877a9b97cb8192a3c0d9c450cb2fbb
--- /dev/null
+++ b/prepare/preprocess_hubert.py
@@ -0,0 +1,58 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import numpy as np
+import argparse
+import torch
+import librosa
+
+from tqdm import tqdm
+from hubert import hubert_model
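+# Usage sketch (paths are placeholders): python prepare/preprocess_hubert.py -w <wav_dir> -v <vec_dir>
+# Requires hubert_pretrain/hubert-soft-0d54a1f4.pt; saves one [length, 256] unit sequence (hop 320) per wav as <name>.vec.npy.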
+
+
+def load_audio(file: str, sr: int = 16000):
+ x, sr = librosa.load(file, sr=sr)
+ return x
+
+
+def load_model(path, device):
+ model = hubert_model.hubert_soft(path)
+ model.eval()
+ model.half()
+ model.to(device)
+ return model
+
+
+def pred_vec(model, wavPath, vecPath, device):
+ feats = load_audio(wavPath)
+ feats = torch.from_numpy(feats).to(device)
+ feats = feats[None, None, :].half()
+ with torch.no_grad():
+ vec = model.units(feats).squeeze().data.cpu().float().numpy()
+ # print(vec.shape) # [length, dim=256] hop=320
+ np.save(vecPath, vec, allow_pickle=False)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
+ parser.add_argument("-v", "--vec", help="vec", dest="vec", required=True)
+
+ args = parser.parse_args()
+ print(args.wav)
+ print(args.vec)
+ os.makedirs(args.vec, exist_ok=True)
+
+ wavPath = args.wav
+ vecPath = args.vec
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ hubert = load_model(os.path.join("hubert_pretrain", "hubert-soft-0d54a1f4.pt"), device)
+
+ for spks in os.listdir(wavPath):
+ if os.path.isdir(f"./{wavPath}/{spks}"):
+ os.makedirs(f"./{vecPath}/{spks}", exist_ok=True)
+
+ files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")]
+ for file in tqdm(files, desc=f'Processing vec {spks}'):
+ file = file[:-4]
+ pred_vec(hubert, f"{wavPath}/{spks}/{file}.wav", f"{vecPath}/{spks}/{file}.vec", device)
diff --git a/prepare/preprocess_ppg.py b/prepare/preprocess_ppg.py
new file mode 100644
index 0000000000000000000000000000000000000000..999bec671c8b044c5d73c53614268ce36530b279
--- /dev/null
+++ b/prepare/preprocess_ppg.py
@@ -0,0 +1,71 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import numpy as np
+import argparse
+import torch
+import random
+from tqdm import tqdm
+from whisper.model import Whisper, ModelDimensions
+from whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram
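+# Usage sketch (paths are placeholders): python prepare/preprocess_ppg.py -w <wav_dir> -p <ppg_dir>
+# Requires whisper_pretrain/large-v2.pt; keeps only the first 3/4 of the encoder blocks and saves [length, 1280] PPG features as <name>.ppg.npy.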
+
+
+def load_model(path) -> Whisper:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ checkpoint = torch.load(path, map_location="cpu")
+ dims = ModelDimensions(**checkpoint["dims"])
+ print(dims)
+ model = Whisper(dims)
+ del model.decoder
+ cut = len(model.encoder.blocks) // 4
+ cut = -1 * cut
+ del model.encoder.blocks[cut:]
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
+ model.eval()
+ model.half()
+ model.to(device)
+ return model
+
+
+def pred_ppg(whisper: Whisper, wavPath, ppgPath):
+ audio = load_audio(wavPath)
+ audln = audio.shape[0]
+ ppgln = audln // 320
+ audio = pad_or_trim(audio)
+ mel = log_mel_spectrogram(audio).half().to(whisper.device)
+ with torch.no_grad():
+ ppg = whisper.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy()
+ ppg = ppg[:ppgln,] # [length, dim=1280]
+ # print(ppg.shape)
+ np.save(ppgPath, ppg, allow_pickle=False)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
+ parser.add_argument("-p", "--ppg", help="ppg", dest="ppg", required=True)
+ args = parser.parse_args()
+ print(args.wav)
+ print(args.ppg)
+
+ os.makedirs(args.ppg, exist_ok=True)
+ wavPath = args.wav
+ ppgPath = args.ppg
+
+ whisper = load_model(os.path.join("whisper_pretrain", "large-v2.pt"))
+ spkPaths = os.listdir(wavPath)
+ random.shuffle(spkPaths)
+
+ for spks in spkPaths:
+ if os.path.isdir(f"./{wavPath}/{spks}"):
+ os.makedirs(f"./{ppgPath}/{spks}", exist_ok=True)
+
+ files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")]
+ for file in tqdm(files, desc=f'Processing ppg {spks}'):
+ if file.endswith(".wav"):
+ # print(file)
+ file = file[:-4]
+ path_wav = f"{wavPath}/{spks}/{file}.wav"
+ path_ppg = f"{ppgPath}/{spks}/{file}.ppg"
+ if os.path.isfile(f"{path_ppg}.npy"):
+ continue
+ pred_ppg(whisper, path_wav, path_ppg)
diff --git a/prepare/preprocess_random.py b/prepare/preprocess_random.py
new file mode 100644
index 0000000000000000000000000000000000000000..f84977bb49d090b333a382772374830d5d1318c6
--- /dev/null
+++ b/prepare/preprocess_random.py
@@ -0,0 +1,23 @@
+import random
+
+
+if __name__ == "__main__":
+ all_items = []
+ fo = open("./files/train_all.txt", "r+", encoding='utf-8')
+    while True:
+        try:
+            item = fo.readline().strip()
+        except Exception as e:
+            print('error while reading:', e)
+            break
+        if not item:
+ break
+ all_items.append(item)
+ fo.close()
+
+ random.shuffle(all_items)
+
+ fw = open("./files/train_all.txt", "w", encoding="utf-8")
+ for strs in all_items:
+ print(strs, file=fw)
+ fw.close()
diff --git a/prepare/preprocess_speaker.py b/prepare/preprocess_speaker.py
new file mode 100644
index 0000000000000000000000000000000000000000..797b60edbeb16f8a50a1be8bd2095f206bd8875e
--- /dev/null
+++ b/prepare/preprocess_speaker.py
@@ -0,0 +1,103 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import numpy as np
+import argparse
+
+from tqdm import tqdm
+from functools import partial
+from argparse import RawTextHelpFormatter
+from multiprocessing.pool import ThreadPool
+
+from speaker.models.lstm import LSTMSpeakerEncoder
+from speaker.config import SpeakerEncoderConfig
+from speaker.utils.audio import AudioProcessor
+from speaker.infer import read_json
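+# Usage sketch (paths are placeholders): python prepare/preprocess_speaker.py <wav_dir> <spk_dir> -t 0
+# Requires speaker_pretrain/best_model.pth.tar and config.json; writes one 256-d d-vector per wav as <name>.spk.npy.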
+
+
+def get_spk_wavs(dataset_path, output_path):
+ wav_files = []
+ os.makedirs(f"./{output_path}", exist_ok=True)
+ for spks in os.listdir(dataset_path):
+ if os.path.isdir(f"./{dataset_path}/{spks}"):
+ os.makedirs(f"./{output_path}/{spks}", exist_ok=True)
+ for file in os.listdir(f"./{dataset_path}/{spks}"):
+ if file.endswith(".wav"):
+ wav_files.append(f"./{dataset_path}/{spks}/{file}")
+ elif spks.endswith(".wav"):
+ wav_files.append(f"./{dataset_path}/{spks}")
+ return wav_files
+
+
+def process_wav(wav_file, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder):
+ waveform = speaker_encoder_ap.load_wav(
+ wav_file, sr=speaker_encoder_ap.sample_rate
+ )
+ spec = speaker_encoder_ap.melspectrogram(waveform)
+ spec = torch.from_numpy(spec.T)
+ if args.use_cuda:
+ spec = spec.cuda()
+ spec = spec.unsqueeze(0)
+ embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy()
+ embed = embed.squeeze()
+ embed_path = wav_file.replace(dataset_path, output_path)
+ embed_path = embed_path.replace(".wav", ".spk")
+ np.save(embed_path, embed, allow_pickle=False)
+
+
+def extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, concurrency):
+ bound_process_wav = partial(process_wav, dataset_path=dataset_path, output_path=output_path, args=args, speaker_encoder_ap=speaker_encoder_ap, speaker_encoder=speaker_encoder)
+
+ with ThreadPool(concurrency) as pool:
+ list(tqdm(pool.imap(bound_process_wav, wav_files), total=len(wav_files)))
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description="""Compute embedding vectors for each wav file in a dataset.""",
+ formatter_class=RawTextHelpFormatter,
+ )
+ parser.add_argument("dataset_path", type=str, help="Path to dataset waves.")
+    parser.add_argument(
+        "output_path", type=str, help="Output directory for the .spk.npy embeddings (mirrors the dataset folder structure)."
+    )
+ parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
+ parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1)
+ args = parser.parse_args()
+ dataset_path = args.dataset_path
+ output_path = args.output_path
+ thread_count = args.thread_count
+ # model
+ args.model_path = os.path.join("speaker_pretrain", "best_model.pth.tar")
+ args.config_path = os.path.join("speaker_pretrain", "config.json")
+ # config
+ config_dict = read_json(args.config_path)
+
+ # model
+ config = SpeakerEncoderConfig(config_dict)
+ config.from_dict(config_dict)
+
+ speaker_encoder = LSTMSpeakerEncoder(
+ config.model_params["input_dim"],
+ config.model_params["proj_dim"],
+ config.model_params["lstm_dim"],
+ config.model_params["num_lstm_layers"],
+ )
+
+ speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda)
+
+ # preprocess
+ speaker_encoder_ap = AudioProcessor(**config.audio)
+ # normalize the input audio level and trim silences
+ speaker_encoder_ap.do_sound_norm = True
+ speaker_encoder_ap.do_trim_silence = True
+
+ wav_files = get_spk_wavs(dataset_path, output_path)
+
+ if thread_count == 0:
+ process_num = os.cpu_count()
+ else:
+ process_num = thread_count
+
+ extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, process_num)
\ No newline at end of file
diff --git a/prepare/preprocess_speaker_ave.py b/prepare/preprocess_speaker_ave.py
new file mode 100644
index 0000000000000000000000000000000000000000..9423f61693a91fef7ff9836f89f157856767d924
--- /dev/null
+++ b/prepare/preprocess_speaker_ave.py
@@ -0,0 +1,54 @@
+import os
+import torch
+import argparse
+import numpy as np
+from tqdm import tqdm
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("dataset_speaker", type=str)
+ parser.add_argument("dataset_singer", type=str)
+
+ data_speaker = parser.parse_args().dataset_speaker
+ data_singer = parser.parse_args().dataset_singer
+
+ os.makedirs(data_singer, exist_ok=True)
+
+ for speaker in os.listdir(data_speaker):
+ subfile_num = 0
+ speaker_ave = 0
+
+ for file in tqdm(os.listdir(os.path.join(data_speaker, speaker)), desc=f"average {speaker}"):
+ if not file.endswith(".npy"):
+ continue
+ source_embed = np.load(os.path.join(data_speaker, speaker, file))
+ source_embed = source_embed.astype(np.float32)
+ speaker_ave = speaker_ave + source_embed
+ subfile_num = subfile_num + 1
+ if subfile_num == 0:
+ continue
+ speaker_ave = speaker_ave / subfile_num
+
+ np.save(os.path.join(data_singer, f"{speaker}.spk.npy"),
+ speaker_ave, allow_pickle=False)
+
+ # rewrite timbre code by average, if similarity is larger than cmp_val
+ rewrite_timbre_code = False
+ if not rewrite_timbre_code:
+ continue
+ cmp_src = torch.FloatTensor(speaker_ave)
+ cmp_num = 0
+ cmp_val = 0.85
+ for file in tqdm(os.listdir(os.path.join(data_speaker, speaker)), desc=f"rewrite {speaker}"):
+ if not file.endswith(".npy"):
+ continue
+ cmp_tmp = np.load(os.path.join(data_speaker, speaker, file))
+ cmp_tmp = cmp_tmp.astype(np.float32)
+ cmp_tmp = torch.FloatTensor(cmp_tmp)
+ cmp_cos = torch.cosine_similarity(cmp_src, cmp_tmp, dim=0)
+ if (cmp_cos > cmp_val):
+ cmp_num += 1
+ np.save(os.path.join(data_speaker, speaker, file),
+ speaker_ave, allow_pickle=False)
+ print(f"rewrite timbre for {speaker} with :", cmp_num)
diff --git a/prepare/preprocess_spec.py b/prepare/preprocess_spec.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2eef6e3cde6072342dffe8e39cec82d30d0f80e
--- /dev/null
+++ b/prepare/preprocess_spec.py
@@ -0,0 +1,62 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import argparse
+import multiprocessing
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+from vits import spectrogram
+from vits import utils
+from omegaconf import OmegaConf
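+# Usage sketch (paths are placeholders): python prepare/preprocess_spec.py -w <wav_dir> -s <spec_dir> -t 0
+# Reads the STFT settings from configs/base.yaml and saves a linear spectrogram tensor per wav as <name>.pt.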
+
+
+def compute_spec(hps, filename, specname):
+ audio, sampling_rate = utils.load_wav_to_torch(filename)
+ assert sampling_rate == hps.sampling_rate, f"{sampling_rate} is not {hps.sampling_rate}"
+ audio_norm = audio / hps.max_wav_value
+ audio_norm = audio_norm.unsqueeze(0)
+ n_fft = hps.filter_length
+ sampling_rate = hps.sampling_rate
+ hop_size = hps.hop_length
+ win_size = hps.win_length
+ spec = spectrogram.spectrogram_torch(
+ audio_norm, n_fft, sampling_rate, hop_size, win_size, center=False)
+ spec = torch.squeeze(spec, 0)
+ torch.save(spec, specname)
+
+
+def process_file(file):
+ if file.endswith(".wav"):
+ file = file[:-4]
+ compute_spec(hps.data, f"{wavPath}/{spks}/{file}.wav", f"{spePath}/{spks}/{file}.pt")
+
+
+def process_files_with_thread_pool(wavPath, spks, thread_num):
+ files = os.listdir(f"./{wavPath}/{spks}")
+ with ThreadPoolExecutor(max_workers=thread_num) as executor:
+ list(tqdm(executor.map(process_file, files), total=len(files), desc=f'Processing spec {spks}'))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
+ parser.add_argument("-s", "--spe", help="spe", dest="spe", required=True)
+ parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1)
+
+ args = parser.parse_args()
+ print(args.wav)
+ print(args.spe)
+
+ os.makedirs(args.spe, exist_ok=True)
+ wavPath = args.wav
+ spePath = args.spe
+ hps = OmegaConf.load("./configs/base.yaml")
+
+ for spks in os.listdir(wavPath):
+ if os.path.isdir(f"./{wavPath}/{spks}"):
+ os.makedirs(f"./{spePath}/{spks}", exist_ok=True)
+ if args.thread_count == 0:
+ process_num = os.cpu_count() // 2 + 1
+ else:
+ process_num = args.thread_count
+ process_files_with_thread_pool(wavPath, spks, process_num)
diff --git a/prepare/preprocess_train.py b/prepare/preprocess_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..985738ec3ef4e2c5d123558ae5a9e400b1cbba85
--- /dev/null
+++ b/prepare/preprocess_train.py
@@ -0,0 +1,68 @@
+import os
+import random
+
+
+def print_error(info):
+ print(f"\033[31m File isn't existed: {info}\033[0m")
+
+
+IndexBySinger = False
+if __name__ == "__main__":
+ os.makedirs("./files/", exist_ok=True)
+
+ rootPath = "./data_svc/waves-32k/"
+ all_items = []
+ for spks in os.listdir(f"./{rootPath}"):
+ if not os.path.isdir(f"./{rootPath}/{spks}"):
+ continue
+ print(f"./{rootPath}/{spks}")
+ for file in os.listdir(f"./{rootPath}/{spks}"):
+ if file.endswith(".wav"):
+ file = file[:-4]
+
+                if not IndexBySinger:
+ path_spk = f"./data_svc/speaker/{spks}/{file}.spk.npy"
+ else:
+ path_spk = f"./data_svc/singer/{spks}.spk.npy"
+
+ path_wave = f"./data_svc/waves-32k/{spks}/{file}.wav"
+ path_spec = f"./data_svc/specs/{spks}/{file}.pt"
+ path_pitch = f"./data_svc/pitch/{spks}/{file}.pit.npy"
+ path_hubert = f"./data_svc/hubert/{spks}/{file}.vec.npy"
+ path_whisper = f"./data_svc/whisper/{spks}/{file}.ppg.npy"
+ has_error = 0
+ if not os.path.isfile(path_spk):
+ print_error(path_spk)
+ has_error = 1
+ if not os.path.isfile(path_wave):
+ print_error(path_wave)
+ has_error = 1
+ if not os.path.isfile(path_spec):
+ print_error(path_spec)
+ has_error = 1
+ if not os.path.isfile(path_pitch):
+ print_error(path_pitch)
+ has_error = 1
+ if not os.path.isfile(path_hubert):
+ print_error(path_hubert)
+ has_error = 1
+ if not os.path.isfile(path_whisper):
+ print_error(path_whisper)
+ has_error = 1
+ if has_error == 0:
+ all_items.append(
+ f"{path_wave}|{path_spec}|{path_pitch}|{path_hubert}|{path_whisper}|{path_spk}")
+
+ random.shuffle(all_items)
+ valids = all_items[:10]
+ valids.sort()
+ trains = all_items[10:]
+ # trains.sort()
+ fw = open("./files/valid.txt", "w", encoding="utf-8")
+ for strs in valids:
+ print(strs, file=fw)
+ fw.close()
+ fw = open("./files/train.txt", "w", encoding="utf-8")
+ for strs in trains:
+ print(strs, file=fw)
+ fw.close()
diff --git a/prepare/preprocess_trim.py b/prepare/preprocess_trim.py
new file mode 100644
index 0000000000000000000000000000000000000000..3856fd413b5a5e781fd342f6cb2c01fd7e48f300
--- /dev/null
+++ b/prepare/preprocess_trim.py
@@ -0,0 +1,50 @@
+import os
+import argparse
+
+from tqdm import tqdm
+from pydub import AudioSegment
+from pydub.silence import split_on_silence
+from pydub import effects
+# this file is for VCTK, use after CDC
+
+
+def trim_silence(iWave, oWave):
+ try:
+ audio = AudioSegment.from_wav(iWave)
+ # audio = effects.normalize(audio, 6)# max - 6dB
+ audio_chunks = split_on_silence(
+ audio,
+ min_silence_len=200,
+ silence_thresh=-45,
+ keep_silence=200,
+ )
+ for chunk in audio_chunks[1:]:
+ audio_chunks[0] += chunk
+ audio_chunks[0].export(oWave, format="wav")
+ except Exception as e:
+ print(str(e))
+ print(iWave)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-i", help="input path", dest="inPath", required=True)
+ parser.add_argument("-o", help="output path", dest="outPath", required=True)
+
+ args = parser.parse_args()
+ print(args.inPath)
+ print(args.outPath)
+
+ os.makedirs(args.outPath, exist_ok=True)
+ rootPath = args.inPath
+ outPath = args.outPath
+
+ for spks in os.listdir(rootPath):
+ if (os.path.isdir(f"./{rootPath}/{spks}")):
+ os.makedirs(f"./{outPath}/{spks}", exist_ok=True)
+
+ files = [f for f in os.listdir(f"./{rootPath}/{spks}") if f.endswith(".wav")]
+ for file in tqdm(files, desc=f'Processing sil {spks}'):
+ iWave = f"./{rootPath}/{spks}/{file}"
+ oWave = f"./{outPath}/{spks}/{file}"
+ trim_silence(iWave, oWave)
diff --git a/prepare/preprocess_zzz.py b/prepare/preprocess_zzz.py
new file mode 100644
index 0000000000000000000000000000000000000000..79e62a97271a9c5f14e220063900d48f09207c61
--- /dev/null
+++ b/prepare/preprocess_zzz.py
@@ -0,0 +1,31 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+from omegaconf import OmegaConf
+from vits.data_utils import TextAudioSpeakerSet
+from vits.data_utils import TextAudioSpeakerCollate
+from vits.data_utils import DistributedBucketSampler
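+# Sanity-check script (no outputs): iterates once over files/valid.txt with the dataset class
+# and the bucketed DataLoader to verify that all preprocessed features load and collate correctly.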
+
+
+hps = OmegaConf.load("./configs/base.yaml")
+dataset = TextAudioSpeakerSet("files/valid.txt", hps.data)
+
+for _ in tqdm(dataset):
+ pass
+
+
+sampler = DistributedBucketSampler(
+ dataset,
+ 4,
+ [150, 300, 450],
+ num_replicas=1,
+ rank=0,
+ shuffle=True)
+collate_fn = TextAudioSpeakerCollate()
+loader = DataLoader(dataset, num_workers=0, shuffle=False, pin_memory=True,
+ collate_fn=collate_fn, batch_sampler=sampler)
+
+
+for _ in tqdm(loader):
+ pass
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2b6efa293e00ed0fe56a5f32717fb80b2d669e43
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+torch==2.2.2
+torchvision==0.17.2
+torchaudio==2.2.2
+fsspec
+pyworld
+matplotlib
+soundfile
+scikit-learn
+scipy
+tensorboard
+transformers
+tqdm
+librosa
+omegaconf
+ruamel.yaml
+resampy
+numpy==1.24
+chardet
+faiss-cpu==1.7.4
diff --git a/speaker/README.md b/speaker/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b6f541f884f6165a37540cc7fae4df7bf2fa2ac7
--- /dev/null
+++ b/speaker/README.md
@@ -0,0 +1,18 @@
+### Speaker Encoder
+
+This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding.
+
+With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart.
+
+Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q).
+
+
+
+Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page.
+
+To run the code, you need to follow the same flow as in TTS.
+
+- Define 'config.json' for your needs. Note that the audio parameters should match your TTS model.
+- Example training call: ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360```
+- Generate embedding vectors: ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth.tar model/config/path/config.json dataset/path/ output_path```. This command parses all .wav files at the given dataset path and recreates the same folder structure under the output path with the generated embedding files.
+- Watch training on Tensorboard as in TTS
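+
+For use inside this repository, a minimal sketch of d-vector extraction (mirroring `prepare/preprocess_speaker.py`; the checkpoint/config locations and the input wav are assumptions) looks like:
+
+```python
+import torch
+import numpy as np
+
+from speaker.models.lstm import LSTMSpeakerEncoder
+from speaker.config import SpeakerEncoderConfig
+from speaker.utils.audio import AudioProcessor
+from speaker.infer import read_json
+
+# build the encoder and audio front-end from the released config
+config_dict = read_json("speaker_pretrain/config.json")
+config = SpeakerEncoderConfig(config_dict)
+config.from_dict(config_dict)
+encoder = LSTMSpeakerEncoder(
+    config.model_params["input_dim"],
+    config.model_params["proj_dim"],
+    config.model_params["lstm_dim"],
+    config.model_params["num_lstm_layers"],
+)
+encoder.load_checkpoint("speaker_pretrain/best_model.pth.tar", eval=True, use_cuda=False)
+
+ap = AudioProcessor(**config.audio)
+ap.do_sound_norm = True      # normalize input level
+ap.do_trim_silence = True    # trim leading/trailing silence
+
+# compute a single 256-d embedding for one utterance (path is a placeholder)
+wav = ap.load_wav("some_utterance.wav", sr=ap.sample_rate)
+spec = torch.from_numpy(ap.melspectrogram(wav).T).unsqueeze(0)
+embed = encoder.compute_embedding(spec).detach().cpu().numpy().squeeze()
+np.save("some_utterance.spk.npy", embed, allow_pickle=False)
+```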
diff --git a/speaker/__init__.py b/speaker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/speaker/config.py b/speaker/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7172ee231f9aaf3d9aa21e7244a1e6b48ebaad39
--- /dev/null
+++ b/speaker/config.py
@@ -0,0 +1,64 @@
+from dataclasses import asdict, dataclass, field
+from typing import Dict, List
+
+from .utils.coqpit import MISSING
+from .utils.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
+
+
+@dataclass
+class SpeakerEncoderConfig(BaseTrainingConfig):
+ """Defines parameters for Speaker Encoder model."""
+
+ model: str = "speaker_encoder"
+ audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
+ datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
+ # model params
+ model_params: Dict = field(
+ default_factory=lambda: {
+ "model_name": "lstm",
+ "input_dim": 80,
+ "proj_dim": 256,
+ "lstm_dim": 768,
+ "num_lstm_layers": 3,
+ "use_lstm_with_projection": True,
+ }
+ )
+
+ audio_augmentation: Dict = field(default_factory=lambda: {})
+
+ storage: Dict = field(
+ default_factory=lambda: {
+ "sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage
+ "storage_size": 15, # the size of the in-memory storage with respect to a single batch
+ }
+ )
+
+ # training params
+ max_train_step: int = 1000000 # end training when number of training steps reaches this value.
+ loss: str = "angleproto"
+ grad_clip: float = 3.0
+ lr: float = 0.0001
+ lr_decay: bool = False
+ warmup_steps: int = 4000
+ wd: float = 1e-6
+
+ # logging params
+ tb_model_param_stats: bool = False
+ steps_plot_stats: int = 10
+ checkpoint: bool = True
+ save_step: int = 1000
+ print_step: int = 20
+
+ # data loader
+ num_speakers_in_batch: int = MISSING
+ num_utters_per_speaker: int = MISSING
+ num_loader_workers: int = MISSING
+ skip_speakers: bool = False
+ voice_len: float = 1.6
+
+ def check_values(self):
+ super().check_values()
+ c = asdict(self)
+ assert (
+ c["model_params"]["input_dim"] == self.audio.num_mels
+ ), " [!] model input dimendion must be equal to melspectrogram dimension."
diff --git a/speaker/infer.py b/speaker/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b69b2ee6d0c1f00492e50fc11411cf6e245a18e8
--- /dev/null
+++ b/speaker/infer.py
@@ -0,0 +1,108 @@
+import re
+import json
+import fsspec
+import torch
+import numpy as np
+import argparse
+
+from argparse import RawTextHelpFormatter
+from .models.lstm import LSTMSpeakerEncoder
+from .config import SpeakerEncoderConfig
+from .utils.audio import AudioProcessor
+
+
+def read_json(json_path):
+ config_dict = {}
+ try:
+ with fsspec.open(json_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ except json.decoder.JSONDecodeError:
+ # backwards compat.
+ data = read_json_with_comments(json_path)
+ config_dict.update(data)
+ return config_dict
+
+
+def read_json_with_comments(json_path):
+ """for backward compat."""
+ # fallback to json
+ with fsspec.open(json_path, "r", encoding="utf-8") as f:
+ input_str = f.read()
+ # handle comments
+ input_str = re.sub(r"\\\n", "", input_str)
+ input_str = re.sub(r"//.*\n", "\n", input_str)
+ data = json.loads(input_str)
+ return data
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description="""Compute embedding vectors for each wav file in a dataset.""",
+ formatter_class=RawTextHelpFormatter,
+ )
+ parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
+ parser.add_argument(
+ "config_path",
+ type=str,
+ help="Path to model config file.",
+ )
+
+ parser.add_argument("-s", "--source", help="input wave", dest="source")
+ parser.add_argument(
+ "-t", "--target", help="output 256d speaker embeddimg", dest="target"
+ )
+
+ parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
+ parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
+
+ args = parser.parse_args()
+ source_file = args.source
+ target_file = args.target
+
+ # config
+ config_dict = read_json(args.config_path)
+ # print(config_dict)
+
+ # model
+ config = SpeakerEncoderConfig(config_dict)
+ config.from_dict(config_dict)
+
+ speaker_encoder = LSTMSpeakerEncoder(
+ config.model_params["input_dim"],
+ config.model_params["proj_dim"],
+ config.model_params["lstm_dim"],
+ config.model_params["num_lstm_layers"],
+ )
+
+ speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda)
+
+ # preprocess
+ speaker_encoder_ap = AudioProcessor(**config.audio)
+ # normalize the input audio level and trim silences
+ speaker_encoder_ap.do_sound_norm = True
+ speaker_encoder_ap.do_trim_silence = True
+
+ # compute speaker embeddings
+
+ # extract the embedding
+ waveform = speaker_encoder_ap.load_wav(
+ source_file, sr=speaker_encoder_ap.sample_rate
+ )
+ spec = speaker_encoder_ap.melspectrogram(waveform)
+ spec = torch.from_numpy(spec.T)
+ if args.use_cuda:
+ spec = spec.cuda()
+ spec = spec.unsqueeze(0)
+ embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy()
+ embed = embed.squeeze()
+ # print(embed)
+ # print(embed.size)
+ np.save(target_file, embed, allow_pickle=False)
+
+
+ if hasattr(speaker_encoder, 'module'):
+ state_dict = speaker_encoder.module.state_dict()
+ else:
+ state_dict = speaker_encoder.state_dict()
+ torch.save({'model': state_dict}, "model_small.pth")
diff --git a/speaker/models/__init__.py b/speaker/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/speaker/models/lstm.py b/speaker/models/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..45e8ccefb76f5b0200f7f2d8392c87624abb4965
--- /dev/null
+++ b/speaker/models/lstm.py
@@ -0,0 +1,131 @@
+import numpy as np
+import torch
+from torch import nn
+
+from ..utils.io import load_fsspec
+
+
+class LSTMWithProjection(nn.Module):
+ def __init__(self, input_size, hidden_size, proj_size):
+ super().__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.proj_size = proj_size
+ self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
+ self.linear = nn.Linear(hidden_size, proj_size, bias=False)
+
+ def forward(self, x):
+ self.lstm.flatten_parameters()
+ o, (_, _) = self.lstm(x)
+ return self.linear(o)
+
+
+class LSTMWithoutProjection(nn.Module):
+ def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
+ super().__init__()
+ self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
+ self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ _, (hidden, _) = self.lstm(x)
+ return self.relu(self.linear(hidden[-1]))
+
+
+class LSTMSpeakerEncoder(nn.Module):
+ def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
+ super().__init__()
+ self.use_lstm_with_projection = use_lstm_with_projection
+ layers = []
+        # choose the LSTM variant: stacked LSTMs with projection, or a plain LSTM
+ if use_lstm_with_projection:
+ layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
+ for _ in range(num_lstm_layers - 1):
+ layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
+ self.layers = nn.Sequential(*layers)
+ else:
+ self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
+
+ self._init_layers()
+
+ def _init_layers(self):
+ for name, param in self.layers.named_parameters():
+ if "bias" in name:
+ nn.init.constant_(param, 0.0)
+ elif "weight" in name:
+ nn.init.xavier_normal_(param)
+
+ def forward(self, x):
+ # TODO: implement state passing for lstms
+ d = self.layers(x)
+ if self.use_lstm_with_projection:
+ d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
+ else:
+ d = torch.nn.functional.normalize(d, p=2, dim=1)
+ return d
+
+ @torch.no_grad()
+ def inference(self, x):
+ d = self.layers.forward(x)
+ if self.use_lstm_with_projection:
+ d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
+ else:
+ d = torch.nn.functional.normalize(d, p=2, dim=1)
+ return d
+
+ def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
+ """
+ Generate embeddings for a batch of utterances
+ x: 1xTxD
+ """
+ max_len = x.shape[1]
+
+ if max_len < num_frames:
+ num_frames = max_len
+
+ offsets = np.linspace(0, max_len - num_frames, num=num_eval)
+
+ frames_batch = []
+ for offset in offsets:
+ offset = int(offset)
+ end_offset = int(offset + num_frames)
+ frames = x[:, offset:end_offset]
+ frames_batch.append(frames)
+
+ frames_batch = torch.cat(frames_batch, dim=0)
+ embeddings = self.inference(frames_batch)
+
+ if return_mean:
+ embeddings = torch.mean(embeddings, dim=0, keepdim=True)
+
+ return embeddings
+
+ def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
+ """
+ Generate embeddings for a batch of utterances
+ x: BxTxD
+ """
+        num_overlap = int(num_frames * overlap)  # integer, so range() below gets an int step
+ max_len = x.shape[1]
+ embed = None
+ num_iters = seq_lens / (num_frames - num_overlap)
+ cur_iter = 0
+ for offset in range(0, max_len, num_frames - num_overlap):
+ cur_iter += 1
+ end_offset = min(x.shape[1], offset + num_frames)
+ frames = x[:, offset:end_offset]
+ if embed is None:
+ embed = self.inference(frames)
+ else:
+ embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
+ return embed / num_iters
+
+ # pylint: disable=unused-argument, redefined-builtin
+ def load_checkpoint(self, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+ self.load_state_dict(state["model"])
+ if use_cuda:
+ self.cuda()
+ if eval:
+ self.eval()
+ assert not self.training
diff --git a/speaker/models/resnet.py b/speaker/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcc850d7b87e03c2490e6b88232d9a0f668586ad
--- /dev/null
+++ b/speaker/models/resnet.py
@@ -0,0 +1,212 @@
+import numpy as np
+import torch
+from torch import nn
+
+from ..utils.io import load_fsspec  # package-local import, consistent with models/lstm.py
+
+
+class SELayer(nn.Module):
+ def __init__(self, channel, reduction=8):
+ super(SELayer, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction),
+ nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y
+
+
+class SEBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
+ super(SEBasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.se = SELayer(planes, reduction)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.relu(out)
+ out = self.bn1(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+ return out
+
+
+class ResNetSpeakerEncoder(nn.Module):
+ """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
+ Adapted from: https://github.com/clovaai/voxceleb_trainer
+ """
+
+ # pylint: disable=W0102
+ def __init__(
+ self,
+ input_dim=64,
+ proj_dim=512,
+ layers=[3, 4, 6, 3],
+ num_filters=[32, 64, 128, 256],
+ encoder_type="ASP",
+ log_input=False,
+ ):
+ super(ResNetSpeakerEncoder, self).__init__()
+
+ self.encoder_type = encoder_type
+ self.input_dim = input_dim
+ self.log_input = log_input
+ self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+ self.bn1 = nn.BatchNorm2d(num_filters[0])
+
+ self.inplanes = num_filters[0]
+ self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
+ self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
+ self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
+ self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
+
+ self.instancenorm = nn.InstanceNorm1d(input_dim)
+
+ outmap_size = int(self.input_dim / 8)
+
+ self.attention = nn.Sequential(
+ nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
+ nn.ReLU(),
+ nn.BatchNorm1d(128),
+ nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
+ nn.Softmax(dim=2),
+ )
+
+ if self.encoder_type == "SAP":
+ out_dim = num_filters[3] * outmap_size
+ elif self.encoder_type == "ASP":
+ out_dim = num_filters[3] * outmap_size * 2
+ else:
+ raise ValueError("Undefined encoder")
+
+ self.fc = nn.Linear(out_dim, proj_dim)
+
+ self._init_layers()
+
+ def _init_layers(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def create_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ # pylint: disable=R0201
+ def new_parameter(self, *size):
+ out = nn.Parameter(torch.FloatTensor(*size))
+ nn.init.xavier_normal_(out)
+ return out
+
+ def forward(self, x, l2_norm=False):
+ x = x.transpose(1, 2)
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ if self.log_input:
+ x = (x + 1e-6).log()
+ x = self.instancenorm(x).unsqueeze(1)
+
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.bn1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = x.reshape(x.size()[0], -1, x.size()[-1])
+
+ w = self.attention(x)
+
+ if self.encoder_type == "SAP":
+ x = torch.sum(x * w, dim=2)
+ elif self.encoder_type == "ASP":
+ mu = torch.sum(x * w, dim=2)
+ sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
+ x = torch.cat((mu, sg), 1)
+
+ x = x.view(x.size()[0], -1)
+ x = self.fc(x)
+
+ if l2_norm:
+ x = torch.nn.functional.normalize(x, p=2, dim=1)
+ return x
+
+ @torch.no_grad()
+ def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
+ """
+ Generate embeddings for a batch of utterances
+ x: 1xTxD
+ """
+ max_len = x.shape[1]
+
+ if max_len < num_frames:
+ num_frames = max_len
+
+ offsets = np.linspace(0, max_len - num_frames, num=num_eval)
+
+ frames_batch = []
+ for offset in offsets:
+ offset = int(offset)
+ end_offset = int(offset + num_frames)
+ frames = x[:, offset:end_offset]
+ frames_batch.append(frames)
+
+ frames_batch = torch.cat(frames_batch, dim=0)
+ embeddings = self.forward(frames_batch, l2_norm=True)
+
+ if return_mean:
+ embeddings = torch.mean(embeddings, dim=0, keepdim=True)
+
+ return embeddings
+
+ def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+ self.load_state_dict(state["model"])
+ if use_cuda:
+ self.cuda()
+ if eval:
+ self.eval()
+ assert not self.training
diff --git a/speaker/umap.png b/speaker/umap.png
new file mode 100644
index 0000000000000000000000000000000000000000..ca8aefeac8cbe616983b35e968c9c9133eb41ede
Binary files /dev/null and b/speaker/umap.png differ
diff --git a/speaker/utils/__init__.py b/speaker/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/speaker/utils/audio.py b/speaker/utils/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2c9627e93cf7ba864532144bb4522cf575d3d2b
--- /dev/null
+++ b/speaker/utils/audio.py
@@ -0,0 +1,822 @@
+from typing import Dict, Tuple
+
+import librosa
+import numpy as np
+import pyworld as pw
+import scipy.io.wavfile
+import scipy.signal
+import soundfile as sf
+import torch
+from torch import nn
+
+class StandardScaler:
+ """StandardScaler for mean-scale normalization with the given mean and scale values."""
+
+ def __init__(self, mean: np.ndarray = None, scale: np.ndarray = None) -> None:
+ self.mean_ = mean
+ self.scale_ = scale
+
+ def set_stats(self, mean, scale):
+ self.mean_ = mean
+ self.scale_ = scale
+
+ def reset_stats(self):
+ delattr(self, "mean_")
+ delattr(self, "scale_")
+
+ def transform(self, X):
+ X = np.asarray(X)
+ X -= self.mean_
+ X /= self.scale_
+ return X
+
+ def inverse_transform(self, X):
+ X = np.asarray(X)
+ X *= self.scale_
+ X += self.mean_
+ return X
+
+class TorchSTFT(nn.Module): # pylint: disable=abstract-method
+ """Some of the audio processing funtions using Torch for faster batch processing.
+
+ TODO: Merge this with audio.py
+ """
+
+ def __init__(
+ self,
+ n_fft,
+ hop_length,
+ win_length,
+ pad_wav=False,
+ window="hann_window",
+ sample_rate=None,
+ mel_fmin=0,
+ mel_fmax=None,
+ n_mels=80,
+ use_mel=False,
+ do_amp_to_db=False,
+ spec_gain=1.0,
+ ):
+ super().__init__()
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.pad_wav = pad_wav
+ self.sample_rate = sample_rate
+ self.mel_fmin = mel_fmin
+ self.mel_fmax = mel_fmax
+ self.n_mels = n_mels
+ self.use_mel = use_mel
+ self.do_amp_to_db = do_amp_to_db
+ self.spec_gain = spec_gain
+ self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
+ self.mel_basis = None
+ if use_mel:
+ self._build_mel_basis()
+
+ def __call__(self, x):
+ """Compute spectrogram frames by torch based stft.
+
+ Args:
+ x (Tensor): input waveform
+
+ Returns:
+ Tensor: spectrogram frames.
+
+ Shapes:
+ x: [B x T] or [:math:`[B, 1, T]`]
+ """
+ if x.ndim == 2:
+ x = x.unsqueeze(1)
+ if self.pad_wav:
+ padding = int((self.n_fft - self.hop_length) / 2)
+ x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
+ # B x D x T x 2
+ o = torch.stft(
+ x.squeeze(1),
+ self.n_fft,
+ self.hop_length,
+ self.win_length,
+ self.window,
+ center=True,
+ pad_mode="reflect", # compatible with audio.py
+ normalized=False,
+ onesided=True,
+ return_complex=False,
+ )
+ M = o[:, :, :, 0]
+ P = o[:, :, :, 1]
+ S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
+ if self.use_mel:
+ S = torch.matmul(self.mel_basis.to(x), S)
+ if self.do_amp_to_db:
+ S = self._amp_to_db(S, spec_gain=self.spec_gain)
+ return S
+
+ def _build_mel_basis(self):
+ mel_basis = librosa.filters.mel(
+ sr=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
+ )
+ self.mel_basis = torch.from_numpy(mel_basis).float()
+
+ @staticmethod
+ def _amp_to_db(x, spec_gain=1.0):
+ return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
+
+ @staticmethod
+ def _db_to_amp(x, spec_gain=1.0):
+ return torch.exp(x) / spec_gain
+
+
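+# --- Illustrative usage sketch (added for clarity; not part of the original module) ---
+# TorchSTFT can act as a plain magnitude STFT or, with use_mel=True, as a log-mel
+# front end for batched tensors. The parameter values below are examples only.
+def _example_torch_stft():
+    stft = TorchSTFT(
+        n_fft=1024,
+        hop_length=256,
+        win_length=1024,
+        sample_rate=16000,
+        n_mels=80,
+        use_mel=True,
+        do_amp_to_db=True,
+    )
+    wav = torch.randn(2, 16000)  # batch of two one-second waveforms
+    mel = stft(wav)              # -> [2, 80, n_frames] log-mel frames
+    return mel
+
+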
+# pylint: disable=too-many-public-methods
+class AudioProcessor(object):
+ """Audio Processor for TTS used by all the data pipelines.
+
+ Note:
+ All the class arguments are set to default values to enable a flexible initialization
+ of the class with the model config. They are not meaningful for all the arguments.
+
+ Args:
+ sample_rate (int, optional):
+ target audio sampling rate. Defaults to None.
+
+ resample (bool, optional):
+ enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False.
+
+ num_mels (int, optional):
+ number of melspectrogram dimensions. Defaults to None.
+
+        log_func (str, optional):
+            log function ("np.log" or "np.log10") used for converting spectrogram amplitude to dB. Defaults to "np.log10".
+
+ min_level_db (int, optional):
+ minimum db threshold for the computed melspectrograms. Defaults to None.
+
+ frame_shift_ms (int, optional):
+ milliseconds of frames between STFT columns. Defaults to None.
+
+ frame_length_ms (int, optional):
+ milliseconds of STFT window length. Defaults to None.
+
+ hop_length (int, optional):
+ number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None.
+
+ win_length (int, optional):
+ STFT window length. Used if ```frame_length_ms``` is None. Defaults to None.
+
+ ref_level_db (int, optional):
+            reference dB level to avoid background noise. In general, <20 dB corresponds to air noise. Defaults to None.
+
+ fft_size (int, optional):
+ FFT window size for STFT. Defaults to 1024.
+
+ power (int, optional):
+ Exponent value applied to the spectrogram before GriffinLim. Defaults to None.
+
+ preemphasis (float, optional):
+ Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0.
+
+ signal_norm (bool, optional):
+ enable/disable signal normalization. Defaults to None.
+
+ symmetric_norm (bool, optional):
+ enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None.
+
+ max_norm (float, optional):
+ ```k``` defining the normalization range. Defaults to None.
+
+ mel_fmin (int, optional):
+ minimum filter frequency for computing melspectrograms. Defaults to None.
+
+ mel_fmax (int, optional):
+            maximum filter frequency for computing melspectrograms. Defaults to None.
+
+ spec_gain (int, optional):
+ gain applied when converting amplitude to DB. Defaults to 20.
+
+ stft_pad_mode (str, optional):
+ Padding mode for STFT. Defaults to 'reflect'.
+
+ clip_norm (bool, optional):
+            enable/disable clipping out-of-range values in the normalized audio signal. Defaults to True.
+
+ griffin_lim_iters (int, optional):
+ Number of GriffinLim iterations. Defaults to None.
+
+ do_trim_silence (bool, optional):
+ enable/disable silence trimming when loading the audio signal. Defaults to False.
+
+ trim_db (int, optional):
+ DB threshold used for silence trimming. Defaults to 60.
+
+ do_sound_norm (bool, optional):
+ enable/disable signal normalization. Defaults to False.
+
+ do_amp_to_db_linear (bool, optional):
+ enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
+
+ do_amp_to_db_mel (bool, optional):
+ enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
+
+ stats_path (str, optional):
+ Path to the computed stats file. Defaults to None.
+
+ verbose (bool, optional):
+ enable/disable logging. Defaults to True.
+
+ """
+
+ def __init__(
+ self,
+ sample_rate=None,
+ resample=False,
+ num_mels=None,
+ log_func="np.log10",
+ min_level_db=None,
+ frame_shift_ms=None,
+ frame_length_ms=None,
+ hop_length=None,
+ win_length=None,
+ ref_level_db=None,
+ fft_size=1024,
+ power=None,
+ preemphasis=0.0,
+ signal_norm=None,
+ symmetric_norm=None,
+ max_norm=None,
+ mel_fmin=None,
+ mel_fmax=None,
+ spec_gain=20,
+ stft_pad_mode="reflect",
+ clip_norm=True,
+ griffin_lim_iters=None,
+ do_trim_silence=False,
+ trim_db=60,
+ do_sound_norm=False,
+ do_amp_to_db_linear=True,
+ do_amp_to_db_mel=True,
+ stats_path=None,
+ verbose=True,
+ **_,
+ ):
+
+        # set up class attributes
+ self.sample_rate = sample_rate
+ self.resample = resample
+ self.num_mels = num_mels
+ self.log_func = log_func
+ self.min_level_db = min_level_db or 0
+ self.frame_shift_ms = frame_shift_ms
+ self.frame_length_ms = frame_length_ms
+ self.ref_level_db = ref_level_db
+ self.fft_size = fft_size
+ self.power = power
+ self.preemphasis = preemphasis
+ self.griffin_lim_iters = griffin_lim_iters
+ self.signal_norm = signal_norm
+ self.symmetric_norm = symmetric_norm
+ self.mel_fmin = mel_fmin or 0
+ self.mel_fmax = mel_fmax
+ self.spec_gain = float(spec_gain)
+ self.stft_pad_mode = stft_pad_mode
+ self.max_norm = 1.0 if max_norm is None else float(max_norm)
+ self.clip_norm = clip_norm
+ self.do_trim_silence = do_trim_silence
+ self.trim_db = trim_db
+ self.do_sound_norm = do_sound_norm
+ self.do_amp_to_db_linear = do_amp_to_db_linear
+ self.do_amp_to_db_mel = do_amp_to_db_mel
+ self.stats_path = stats_path
+ # setup exp_func for db to amp conversion
+ if log_func == "np.log":
+ self.base = np.e
+ elif log_func == "np.log10":
+ self.base = 10
+ else:
+ raise ValueError(" [!] unknown `log_func` value.")
+ # setup stft parameters
+ if hop_length is None:
+ # compute stft parameters from given time values
+ self.hop_length, self.win_length = self._stft_parameters()
+ else:
+ # use stft parameters from config file
+ self.hop_length = hop_length
+ self.win_length = win_length
+ assert min_level_db != 0.0, " [!] min_level_db is 0"
+ assert self.win_length <= self.fft_size, " [!] win_length cannot be larger than fft_size"
+ members = vars(self)
+ if verbose:
+ print(" > Setting up Audio Processor...")
+ for key, value in members.items():
+ print(" | > {}:{}".format(key, value))
+ # create spectrogram utils
+ self.mel_basis = self._build_mel_basis()
+ self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis())
+ # setup scaler
+ if stats_path and signal_norm:
+ mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
+ self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)
+ self.signal_norm = True
+ self.max_norm = None
+ self.clip_norm = None
+ self.symmetric_norm = None
+
+ ### setting up the parameters ###
+ def _build_mel_basis(
+ self,
+ ) -> np.ndarray:
+ """Build melspectrogram basis.
+
+ Returns:
+ np.ndarray: melspectrogram basis.
+ """
+ if self.mel_fmax is not None:
+ assert self.mel_fmax <= self.sample_rate // 2
+ return librosa.filters.mel(
+ sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
+ )
+
+ def _stft_parameters(
+ self,
+ ) -> Tuple[int, int]:
+ """Compute the real STFT parameters from the time values.
+
+ Returns:
+ Tuple[int, int]: hop length and window length for STFT.
+ """
+ factor = self.frame_length_ms / self.frame_shift_ms
+ assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
+ hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
+ win_length = int(hop_length * factor)
+ return hop_length, win_length
+
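+    # Worked example (illustrative numbers, not project defaults): with
+    # sample_rate=16000, frame_shift_ms=12.5 and frame_length_ms=50.0, the factor is
+    # 50.0 / 12.5 = 4, so hop_length = int(0.0125 * 16000) = 200 samples and
+    # win_length = 200 * 4 = 800 samples.
+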
+ ### normalization ###
+ def normalize(self, S: np.ndarray) -> np.ndarray:
+ """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`
+
+ Args:
+ S (np.ndarray): Spectrogram to normalize.
+
+ Raises:
+ RuntimeError: Mean and variance is computed from incompatible parameters.
+
+ Returns:
+ np.ndarray: Normalized spectrogram.
+ """
+ # pylint: disable=no-else-return
+ S = S.copy()
+ if self.signal_norm:
+ # mean-var scaling
+ if hasattr(self, "mel_scaler"):
+ if S.shape[0] == self.num_mels:
+ return self.mel_scaler.transform(S.T).T
+ elif S.shape[0] == self.fft_size / 2:
+ return self.linear_scaler.transform(S.T).T
+ else:
+ raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")
+ # range normalization
+ S -= self.ref_level_db # discard certain range of DB assuming it is air noise
+ S_norm = (S - self.min_level_db) / (-self.min_level_db)
+ if self.symmetric_norm:
+ S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
+ if self.clip_norm:
+ S_norm = np.clip(
+ S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type
+ )
+ return S_norm
+ else:
+ S_norm = self.max_norm * S_norm
+ if self.clip_norm:
+ S_norm = np.clip(S_norm, 0, self.max_norm)
+ return S_norm
+ else:
+ return S
+
+ def denormalize(self, S: np.ndarray) -> np.ndarray:
+ """Denormalize spectrogram values.
+
+ Args:
+ S (np.ndarray): Spectrogram to denormalize.
+
+ Raises:
+ RuntimeError: Mean and variance are incompatible.
+
+ Returns:
+ np.ndarray: Denormalized spectrogram.
+ """
+ # pylint: disable=no-else-return
+ S_denorm = S.copy()
+ if self.signal_norm:
+ # mean-var scaling
+ if hasattr(self, "mel_scaler"):
+ if S_denorm.shape[0] == self.num_mels:
+ return self.mel_scaler.inverse_transform(S_denorm.T).T
+ elif S_denorm.shape[0] == self.fft_size / 2:
+ return self.linear_scaler.inverse_transform(S_denorm.T).T
+ else:
+ raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")
+ if self.symmetric_norm:
+ if self.clip_norm:
+ S_denorm = np.clip(
+ S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type
+ )
+ S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
+ return S_denorm + self.ref_level_db
+ else:
+ if self.clip_norm:
+ S_denorm = np.clip(S_denorm, 0, self.max_norm)
+ S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db
+ return S_denorm + self.ref_level_db
+ else:
+ return S_denorm
+
+ ### Mean-STD scaling ###
+ def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]:
+ """Loading mean and variance statistics from a `npy` file.
+
+ Args:
+            stats_path (str): Path to the `npy` file containing the precomputed statistics.
+
+ Returns:
+ Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to
+ compute them.
+ """
+ stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg
+ mel_mean = stats["mel_mean"]
+ mel_std = stats["mel_std"]
+ linear_mean = stats["linear_mean"]
+ linear_std = stats["linear_std"]
+ stats_config = stats["audio_config"]
+ # check all audio parameters used for computing stats
+ skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"]
+ for key in stats_config.keys():
+ if key in skip_parameters:
+ continue
+ if key not in ["sample_rate", "trim_db"]:
+ assert (
+ stats_config[key] == self.__dict__[key]
+ ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
+ return mel_mean, mel_std, linear_mean, linear_std, stats_config
+
+ # pylint: disable=attribute-defined-outside-init
+ def setup_scaler(
+ self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray
+ ) -> None:
+ """Initialize scaler objects used in mean-std normalization.
+
+ Args:
+ mel_mean (np.ndarray): Mean for melspectrograms.
+ mel_std (np.ndarray): STD for melspectrograms.
+ linear_mean (np.ndarray): Mean for full scale spectrograms.
+ linear_std (np.ndarray): STD for full scale spectrograms.
+ """
+ self.mel_scaler = StandardScaler()
+ self.mel_scaler.set_stats(mel_mean, mel_std)
+ self.linear_scaler = StandardScaler()
+ self.linear_scaler.set_stats(linear_mean, linear_std)
+
+ ### DB and AMP conversion ###
+ # pylint: disable=no-self-use
+ def _amp_to_db(self, x: np.ndarray) -> np.ndarray:
+ """Convert amplitude values to decibels.
+
+ Args:
+ x (np.ndarray): Amplitude spectrogram.
+
+ Returns:
+ np.ndarray: Decibels spectrogram.
+ """
+ return self.spec_gain * _log(np.maximum(1e-5, x), self.base)
+
+ # pylint: disable=no-self-use
+ def _db_to_amp(self, x: np.ndarray) -> np.ndarray:
+ """Convert decibels spectrogram to amplitude spectrogram.
+
+ Args:
+ x (np.ndarray): Decibels spectrogram.
+
+ Returns:
+ np.ndarray: Amplitude spectrogram.
+ """
+ return _exp(x / self.spec_gain, self.base)
+
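+    # Worked example: with the default spec_gain=20 and log_func="np.log10" (base 10),
+    # an amplitude of 0.1 maps to 20 * log10(0.1) = -20 dB, and _db_to_amp(-20)
+    # recovers 10 ** (-20 / 20) = 0.1.
+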
+ ### Preemphasis ###
+ def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
+ """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.
+
+ Args:
+ x (np.ndarray): Audio signal.
+
+ Raises:
+ RuntimeError: Preemphasis coeff is set to 0.
+
+ Returns:
+ np.ndarray: Decorrelated audio signal.
+ """
+ if self.preemphasis == 0:
+            raise RuntimeError(" [!] Preemphasis is set to 0.0.")
+ return scipy.signal.lfilter([1, -self.preemphasis], [1], x)
+
+ def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
+ """Reverse pre-emphasis."""
+ if self.preemphasis == 0:
+            raise RuntimeError(" [!] Preemphasis is set to 0.0.")
+ return scipy.signal.lfilter([1], [1, -self.preemphasis], x)
+
+ ### SPECTROGRAMs ###
+ def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray:
+ """Project a full scale spectrogram to a melspectrogram.
+
+ Args:
+ spectrogram (np.ndarray): Full scale spectrogram.
+
+ Returns:
+ np.ndarray: Melspectrogram
+ """
+ return np.dot(self.mel_basis, spectrogram)
+
+ def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray:
+ """Convert a melspectrogram to full scale spectrogram."""
+ return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec))
+
+ def spectrogram(self, y: np.ndarray) -> np.ndarray:
+ """Compute a spectrogram from a waveform.
+
+ Args:
+ y (np.ndarray): Waveform.
+
+ Returns:
+ np.ndarray: Spectrogram.
+ """
+ if self.preemphasis != 0:
+ D = self._stft(self.apply_preemphasis(y))
+ else:
+ D = self._stft(y)
+ if self.do_amp_to_db_linear:
+ S = self._amp_to_db(np.abs(D))
+ else:
+ S = np.abs(D)
+ return self.normalize(S).astype(np.float32)
+
+ def melspectrogram(self, y: np.ndarray) -> np.ndarray:
+ """Compute a melspectrogram from a waveform."""
+ if self.preemphasis != 0:
+ D = self._stft(self.apply_preemphasis(y))
+ else:
+ D = self._stft(y)
+ if self.do_amp_to_db_mel:
+ S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
+ else:
+ S = self._linear_to_mel(np.abs(D))
+ return self.normalize(S).astype(np.float32)
+
+ def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
+        """Convert a spectrogram to a waveform using the Griffin-Lim vocoder."""
+ S = self.denormalize(spectrogram)
+ S = self._db_to_amp(S)
+ # Reconstruct phase
+ if self.preemphasis != 0:
+ return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+ return self._griffin_lim(S ** self.power)
+
+ def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
+        """Convert a melspectrogram to a waveform using the Griffin-Lim vocoder."""
+ D = self.denormalize(mel_spectrogram)
+ S = self._db_to_amp(D)
+ S = self._mel_to_linear(S) # Convert back to linear
+ if self.preemphasis != 0:
+ return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+ return self._griffin_lim(S ** self.power)
+
+ def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
+ """Convert a full scale linear spectrogram output of a network to a melspectrogram.
+
+ Args:
+ linear_spec (np.ndarray): Normalized full scale linear spectrogram.
+
+ Returns:
+ np.ndarray: Normalized melspectrogram.
+ """
+ S = self.denormalize(linear_spec)
+ S = self._db_to_amp(S)
+ S = self._linear_to_mel(np.abs(S))
+ S = self._amp_to_db(S)
+ mel = self.normalize(S)
+ return mel
+
+ ### STFT and ISTFT ###
+ def _stft(self, y: np.ndarray) -> np.ndarray:
+ """Librosa STFT wrapper.
+
+ Args:
+ y (np.ndarray): Audio signal.
+
+ Returns:
+ np.ndarray: Complex number array.
+ """
+ return librosa.stft(
+ y=y,
+ n_fft=self.fft_size,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ pad_mode=self.stft_pad_mode,
+ window="hann",
+ center=True,
+ )
+
+ def _istft(self, y: np.ndarray) -> np.ndarray:
+ """Librosa iSTFT wrapper."""
+ return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
+
+ def _griffin_lim(self, S):
+ angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
+        S_complex = np.abs(S).astype(complex)  # builtin complex: np.complex was removed in NumPy 1.24
+ y = self._istft(S_complex * angles)
+ if not np.isfinite(y).all():
+ print(" [!] Waveform is not finite everywhere. Skipping the GL.")
+ return np.array([0.0])
+ for _ in range(self.griffin_lim_iters):
+ angles = np.exp(1j * np.angle(self._stft(y)))
+ y = self._istft(S_complex * angles)
+ return y
+
+ def compute_stft_paddings(self, x, pad_sides=1):
+ """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
+ (first and final frames)"""
+ assert pad_sides in (1, 2)
+ pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
+ if pad_sides == 1:
+ return 0, pad
+ return pad // 2, pad // 2 + pad % 2
+
+ def compute_f0(self, x: np.ndarray) -> np.ndarray:
+ """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram.
+
+ Args:
+ x (np.ndarray): Waveform.
+
+ Returns:
+ np.ndarray: Pitch.
+
+ Examples:
+ >>> WAV_FILE = filename = librosa.util.example_audio_file()
+ >>> from TTS.config import BaseAudioConfig
+ >>> from TTS.utils.audio import AudioProcessor
+ >>> conf = BaseAudioConfig(mel_fmax=8000)
+ >>> ap = AudioProcessor(**conf)
+ >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
+ >>> pitch = ap.compute_f0(wav)
+ """
+ f0, t = pw.dio(
+ x.astype(np.double),
+ fs=self.sample_rate,
+ f0_ceil=self.mel_fmax,
+ frame_period=1000 * self.hop_length / self.sample_rate,
+ )
+ f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
+ # pad = int((self.win_length / self.hop_length) / 2)
+ # f0 = [0.0] * pad + f0 + [0.0] * pad
+ # f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0)
+ # f0 = np.array(f0, dtype=np.float32)
+
+ # f01, _, _ = librosa.pyin(
+ # x,
+ # fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
+ # fmax=self.mel_fmax,
+ # frame_length=self.win_length,
+ # sr=self.sample_rate,
+ # fill_na=0.0,
+ # )
+
+ # spec = self.melspectrogram(x)
+ return f0
+
+ ### Audio Processing ###
+ def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int:
+        """Find the last point without silence at the end of an audio signal.
+
+ Args:
+ wav (np.ndarray): Audio signal.
+ threshold_db (int, optional): Silence threshold in decibels. Defaults to -40.
+            min_silence_sec (float, optional): Ignore silences shorter than this, in seconds. Defaults to 0.8.
+
+ Returns:
+ int: Last point without silence.
+ """
+ window_length = int(self.sample_rate * min_silence_sec)
+ hop_length = int(window_length / 4)
+ threshold = self._db_to_amp(threshold_db)
+ for x in range(hop_length, len(wav) - window_length, hop_length):
+ if np.max(wav[x : x + window_length]) < threshold:
+ return x + hop_length
+ return len(wav)
+
+ def trim_silence(self, wav):
+ """Trim silent parts with a threshold and 0.01 sec margin"""
+ margin = int(self.sample_rate * 0.01)
+ wav = wav[margin:-margin]
+ return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[
+ 0
+ ]
+
+ @staticmethod
+ def sound_norm(x: np.ndarray) -> np.ndarray:
+ """Normalize the volume of an audio signal.
+
+ Args:
+ x (np.ndarray): Raw waveform.
+
+ Returns:
+ np.ndarray: Volume normalized waveform.
+ """
+ return x / abs(x).max() * 0.95
+
+ ### save and load ###
+ def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
+ """Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
+
+ Args:
+ filename (str): Path to the wav file.
+ sr (int, optional): Sampling rate for resampling. Defaults to None.
+
+ Returns:
+ np.ndarray: Loaded waveform.
+ """
+ if self.resample:
+ x, sr = librosa.load(filename, sr=self.sample_rate)
+ elif sr is None:
+ x, sr = sf.read(filename)
+ assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
+ else:
+ x, sr = librosa.load(filename, sr=sr)
+ if self.do_trim_silence:
+ try:
+ x = self.trim_silence(x)
+ except ValueError:
+ print(f" [!] File cannot be trimmed for silence - {filename}")
+ if self.do_sound_norm:
+ x = self.sound_norm(x)
+ return x
+
+ def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None:
+ """Save a waveform to a file using Scipy.
+
+ Args:
+ wav (np.ndarray): Waveform to save.
+            path (str): Path to an output file.
+ sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
+ """
+ wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
+ scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16))
+
+ @staticmethod
+ def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
+ mu = 2 ** qc - 1
+ # wav_abs = np.minimum(np.abs(wav), 1.0)
+ signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
+ # Quantize signal to the specified number of levels.
+ signal = (signal + 1) / 2 * mu + 0.5
+ return np.floor(
+ signal,
+ )
+
+ @staticmethod
+ def mulaw_decode(wav, qc):
+ """Recovers waveform from quantized values."""
+ mu = 2 ** qc - 1
+ x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
+ return x
+
+ @staticmethod
+ def encode_16bits(x):
+ return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
+
+ @staticmethod
+ def quantize(x: np.ndarray, bits: int) -> np.ndarray:
+ """Quantize a waveform to a given number of bits.
+
+ Args:
+ x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
+ bits (int): Number of quantization bits.
+
+ Returns:
+ np.ndarray: Quantized waveform.
+ """
+ return (x + 1.0) * (2 ** bits - 1) / 2
+
+ @staticmethod
+ def dequantize(x, bits):
+ """Dequantize a waveform from the given number of bits."""
+ return 2 * x / (2 ** bits - 1) - 1
+
+
+def _log(x, base):
+ if base == 10:
+ return np.log10(x)
+ return np.log(x)
+
+
+def _exp(x, base):
+ if base == 10:
+ return np.power(10, x)
+ return np.exp(x)
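+
+
+# --- Illustrative usage sketch (added for clarity; not part of the original module) ---
+# End-to-end example of AudioProcessor: load a wav, extract a normalized mel
+# spectrogram and run a rough Griffin-Lim reconstruction. All parameter values are
+# examples rather than project defaults, and the file paths are placeholders.
+def _example_audio_processor(wav_path="example.wav"):
+    ap = AudioProcessor(
+        sample_rate=16000,
+        resample=True,            # resample on load so the sample-rate assert cannot trip
+        num_mels=80,
+        fft_size=1024,
+        hop_length=256,
+        win_length=1024,
+        min_level_db=-100,
+        ref_level_db=20,
+        power=1.5,
+        griffin_lim_iters=30,
+        signal_norm=True,
+        symmetric_norm=True,
+        max_norm=4.0,
+        clip_norm=True,
+        mel_fmin=0,
+        mel_fmax=8000.0,
+    )
+    wav = ap.load_wav(wav_path)                # load and resample the audio to 16 kHz
+    mel = ap.melspectrogram(wav)               # -> [num_mels, n_frames], normalized dB
+    wav_hat = ap.inv_melspectrogram(mel)       # rough Griffin-Lim reconstruction
+    ap.save_wav(wav_hat, "reconstructed.wav")  # placeholder output path
+    return mel.shape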
diff --git a/speaker/utils/coqpit.py b/speaker/utils/coqpit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e214c8b8a2045701b15e77c5d66012a64f135429
--- /dev/null
+++ b/speaker/utils/coqpit.py
@@ -0,0 +1,954 @@
+import argparse
+import functools
+import json
+import operator
+import os
+from collections.abc import MutableMapping
+from dataclasses import MISSING as _MISSING
+from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace
+from pathlib import Path
+from pprint import pprint
+from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, get_type_hints
+
+T = TypeVar("T")
+MISSING: Any = "???"
+
+
+class _NoDefault(Generic[T]):
+ pass
+
+
+NoDefaultVar = Union[_NoDefault[T], T]
+no_default: NoDefaultVar = _NoDefault()
+
+
+def is_primitive_type(arg_type: Any) -> bool:
+ """Check if the input type is one of `int, float, str, bool`.
+
+ Args:
+ arg_type (typing.Any): input type to check.
+
+ Returns:
+ bool: True if input type is one of `int, float, str, bool`.
+ """
+ try:
+ return isinstance(arg_type(), (int, float, str, bool))
+ except (AttributeError, TypeError):
+ return False
+
+
+def is_list(arg_type: Any) -> bool:
+ """Check if the input type is `list`
+
+ Args:
+ arg_type (typing.Any): input type.
+
+ Returns:
+ bool: True if input type is `list`
+ """
+ try:
+ return arg_type is list or arg_type is List or arg_type.__origin__ is list or arg_type.__origin__ is List
+ except AttributeError:
+ return False
+
+
+def is_dict(arg_type: Any) -> bool:
+ """Check if the input type is `dict`
+
+ Args:
+ arg_type (typing.Any): input type.
+
+ Returns:
+ bool: True if input type is `dict`
+ """
+ try:
+ return arg_type is dict or arg_type is Dict or arg_type.__origin__ is dict
+ except AttributeError:
+ return False
+
+
+def is_union(arg_type: Any) -> bool:
+ """Check if the input type is `Union`.
+
+ Args:
+ arg_type (typing.Any): input type.
+
+ Returns:
+ bool: True if input type is `Union`
+ """
+ try:
+ return safe_issubclass(arg_type.__origin__, Union)
+ except AttributeError:
+ return False
+
+
+def safe_issubclass(cls, classinfo) -> bool:
+ """Check if the input type is a subclass of the given class.
+
+ Args:
+ cls (type): input type.
+ classinfo (type): parent class.
+
+ Returns:
+ bool: True if the input type is a subclass of the given class
+ """
+ try:
+ r = issubclass(cls, classinfo)
+ except Exception: # pylint: disable=broad-except
+ return cls is classinfo
+ else:
+ return r
+
+
+def _coqpit_json_default(obj: Any) -> Any:
+ if isinstance(obj, Path):
+ return str(obj)
+ raise TypeError(f"Can't encode object of type {type(obj).__name__}")
+
+
+def _default_value(x: Field):
+ """Return the default value of the input Field.
+
+ Args:
+ x (Field): input Field.
+
+ Returns:
+ object: default value of the input Field.
+ """
+ if x.default not in (MISSING, _MISSING):
+ return x.default
+ if x.default_factory not in (MISSING, _MISSING):
+ return x.default_factory()
+ return x.default
+
+
+def _is_optional_field(field) -> bool:
+ """Check if the input field is optional.
+
+ Args:
+ field (Field): input Field to check.
+
+ Returns:
+ bool: True if the input field is optional.
+ """
+ # return isinstance(field.type, _GenericAlias) and type(None) in getattr(field.type, "__args__")
+ return type(None) in getattr(field.type, "__args__")
+
+
+def my_get_type_hints(
+ cls,
+):
+ """Custom `get_type_hints` dealing with https://github.com/python/typing/issues/737
+
+ Returns:
+ [dataclass]: dataclass to get the type hints of its fields.
+ """
+ r_dict = {}
+ for base in cls.__class__.__bases__:
+ if base == object:
+ break
+ r_dict.update(my_get_type_hints(base))
+ r_dict.update(get_type_hints(cls))
+ return r_dict
+
+
+def _serialize(x):
+ """Pick the right serialization for the datatype of the given input.
+
+ Args:
+ x (object): input object.
+
+ Returns:
+ object: serialized object.
+ """
+ if isinstance(x, Path):
+ return str(x)
+ if isinstance(x, dict):
+ return {k: _serialize(v) for k, v in x.items()}
+ if isinstance(x, list):
+ return [_serialize(xi) for xi in x]
+ if isinstance(x, Serializable) or issubclass(type(x), Serializable):
+ return x.serialize()
+ if isinstance(x, type) and issubclass(x, Serializable):
+ return x.serialize(x)
+ return x
+
+
+def _deserialize_dict(x: Dict) -> Dict:
+ """Deserialize dict.
+
+ Args:
+        x (Dict): value to be deserialized.
+
+ Returns:
+ Dict: deserialized dictionary.
+ """
+ out_dict = {}
+ for k, v in x.items():
+ if v is None: # if {'key':None}
+ out_dict[k] = None
+ else:
+ out_dict[k] = _deserialize(v, type(v))
+ return out_dict
+
+
+def _deserialize_list(x: List, field_type: Type) -> List:
+ """Deserialize values for List typed fields.
+
+ Args:
+ x (List): value to be deserialized
+ field_type (Type): field type.
+
+ Raises:
+ ValueError: Coqpit does not support multi type-hinted lists.
+
+ Returns:
+ [List]: deserialized list.
+ """
+ field_args = None
+ if hasattr(field_type, "__args__") and field_type.__args__:
+ field_args = field_type.__args__
+ elif hasattr(field_type, "__parameters__") and field_type.__parameters__:
+ # bandaid for python 3.6
+ field_args = field_type.__parameters__
+ if field_args:
+ if len(field_args) > 1:
+ raise ValueError(" [!] Coqpit does not support multi-type hinted 'List'")
+ field_arg = field_args[0]
+ # if field type is TypeVar set the current type by the value's type.
+ if isinstance(field_arg, TypeVar):
+ field_arg = type(x)
+ return [_deserialize(xi, field_arg) for xi in x]
+ return x
+
+
+def _deserialize_union(x: Any, field_type: Type) -> Any:
+ """Deserialize values for Union typed fields
+
+ Args:
+ x (Any): value to be deserialized.
+ field_type (Type): field type.
+
+ Returns:
+        [Any]: deserialized value.
+ """
+ for arg in field_type.__args__:
+ # stop after first matching type in Union
+ try:
+ x = _deserialize(x, arg)
+ break
+ except ValueError:
+ pass
+ return x
+
+
+def _deserialize_primitive_types(x: Union[int, float, str, bool], field_type: Type) -> Union[int, float, str, bool]:
+ """Deserialize python primitive types (float, int, str, bool).
+ It handles `inf` values exclusively and keeps them float against int fields since int does not support inf values.
+
+ Args:
+ x (Union[int, float, str, bool]): value to be deserialized.
+ field_type (Type): field type.
+
+ Returns:
+ Union[int, float, str, bool]: deserialized value.
+ """
+
+ if isinstance(x, (str, bool)):
+ return x
+ if isinstance(x, (int, float)):
+ if x == float("inf") or x == float("-inf"):
+ # if value type is inf return regardless.
+ return x
+ x = field_type(x)
+ return x
+ # TODO: Raise an error when x does not match the types.
+ return None
+
+
+def _deserialize(x: Any, field_type: Any) -> Any:
+    """Pick the right deserialization for the given object and the corresponding field type.
+
+ Args:
+ x (object): object to be deserialized.
+ field_type (type): expected type after deserialization.
+
+ Returns:
+ object: deserialized object
+
+ """
+ # pylint: disable=too-many-return-statements
+ if is_dict(field_type):
+ return _deserialize_dict(x)
+ if is_list(field_type):
+ return _deserialize_list(x, field_type)
+ if is_union(field_type):
+ return _deserialize_union(x, field_type)
+ if issubclass(field_type, Serializable):
+ return field_type.deserialize_immutable(x)
+ if is_primitive_type(field_type):
+ return _deserialize_primitive_types(x, field_type)
+ raise ValueError(f" [!] '{type(x)}' value type of '{x}' does not match '{field_type}' field type.")
+
+
+# Recursive setattr (supports dotted attr names)
+def rsetattr(obj, attr, val):
+ def _setitem(obj, attr, val):
+ return operator.setitem(obj, int(attr), val)
+
+ pre, _, post = attr.rpartition(".")
+ setfunc = _setitem if post.isnumeric() else setattr
+
+ return setfunc(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+# Recursive getattr (supports dotted attr names)
+def rgetattr(obj, attr, *args):
+ def _getitem(obj, attr):
+ return operator.getitem(obj, int(attr), *args)
+
+ def _getattr(obj, attr):
+ getfunc = _getitem if attr.isnumeric() else getattr
+ return getfunc(obj, attr, *args)
+
+ return functools.reduce(_getattr, [obj] + attr.split("."))
+
+
+# Recursive setitem (supports dotted attr names)
+def rsetitem(obj, attr, val):
+ pre, _, post = attr.rpartition(".")
+ return operator.setitem(rgetitem(obj, pre) if pre else obj, post, val)
+
+
+# Recursive getitem (supports dotted attr names)
+def rgetitem(obj, attr, *args):
+ def _getitem(obj, attr):
+ return operator.getitem(obj, int(attr) if attr.isnumeric() else attr, *args)
+
+ return functools.reduce(_getitem, [obj] + attr.split("."))
+
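+
+# --- Illustrative usage sketch (added for clarity; not part of the original module) ---
+# The recursive helpers above accept dotted attribute paths where numeric segments
+# index into lists. `argparse.Namespace` is used here only as a convenient dummy object.
+def _example_recursive_attrs():
+    cfg = argparse.Namespace(audio=argparse.Namespace(fft_sizes=[1024, 2048]))
+    rsetattr(cfg, "audio.fft_sizes.0", 512)    # numeric path segment -> list index
+    return rgetattr(cfg, "audio.fft_sizes.0")  # -> 512
+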
+
+@dataclass
+class Serializable:
+ """Gives serialization ability to any inheriting dataclass."""
+
+ def __post_init__(self):
+ self._validate_contracts()
+ for key, value in self.__dict__.items():
+ if value is no_default:
+ raise TypeError(f"__init__ missing 1 required argument: '{key}'")
+
+ def _validate_contracts(self):
+ dataclass_fields = fields(self)
+
+ for field in dataclass_fields:
+
+ value = getattr(self, field.name)
+
+ if value is None:
+ if not _is_optional_field(field):
+ raise TypeError(f"{field.name} is not optional")
+
+ contract = field.metadata.get("contract", None)
+
+ if contract is not None:
+ if value is not None and not contract(value):
+ raise ValueError(f"break the contract for {field.name}, {self.__class__.__name__}")
+
+ def validate(self):
+ """validate if object can serialize / deserialize correctly."""
+ self._validate_contracts()
+ if self != self.__class__.deserialize( # pylint: disable=no-value-for-parameter
+ json.loads(json.dumps(self.serialize()))
+ ):
+ raise ValueError("could not be deserialized with same value")
+
+ def to_dict(self) -> dict:
+ """Transform serializable object to dict."""
+ cls_fields = fields(self)
+ o = {}
+ for cls_field in cls_fields:
+ o[cls_field.name] = getattr(self, cls_field.name)
+ return o
+
+ def serialize(self) -> dict:
+ """Serialize object to be json serializable representation."""
+ if not is_dataclass(self):
+ raise TypeError("need to be decorated as dataclass")
+
+ dataclass_fields = fields(self)
+
+ o = {}
+
+ for field in dataclass_fields:
+ value = getattr(self, field.name)
+ value = _serialize(value)
+ o[field.name] = value
+ return o
+
+ def deserialize(self, data: dict) -> "Serializable":
+        """Parse input dictionary and deserialize its fields to a dataclass.
+
+ Returns:
+ self: deserialized `self`.
+ """
+ if not isinstance(data, dict):
+ raise ValueError()
+ data = data.copy()
+ init_kwargs = {}
+ for field in fields(self):
+ # if field.name == 'dataset_config':
+ if field.name not in data:
+ if field.name in vars(self):
+ init_kwargs[field.name] = vars(self)[field.name]
+ continue
+ raise ValueError(f' [!] Missing required field "{field.name}"')
+ value = data.get(field.name, _default_value(field))
+ if value is None:
+ init_kwargs[field.name] = value
+ continue
+ if value == MISSING:
+ raise ValueError(f"deserialized with unknown value for {field.name} in {self.__name__}")
+ value = _deserialize(value, field.type)
+ init_kwargs[field.name] = value
+ for k, v in init_kwargs.items():
+ setattr(self, k, v)
+ return self
+
+ @classmethod
+ def deserialize_immutable(cls, data: dict) -> "Serializable":
+        """Parse input dictionary and deserialize its fields to a dataclass.
+
+ Returns:
+ Newly created deserialized object.
+ """
+ if not isinstance(data, dict):
+ raise ValueError()
+ data = data.copy()
+ init_kwargs = {}
+ for field in fields(cls):
+ # if field.name == 'dataset_config':
+ if field.name not in data:
+ if field.name in vars(cls):
+ init_kwargs[field.name] = vars(cls)[field.name]
+ continue
+ # if not in cls and the default value is not Missing use it
+ default_value = _default_value(field)
+ if default_value not in (MISSING, _MISSING):
+ init_kwargs[field.name] = default_value
+ continue
+ raise ValueError(f' [!] Missing required field "{field.name}"')
+ value = data.get(field.name, _default_value(field))
+ if value is None:
+ init_kwargs[field.name] = value
+ continue
+ if value == MISSING:
+ raise ValueError(f"Deserialized with unknown value for {field.name} in {cls.__name__}")
+ value = _deserialize(value, field.type)
+ init_kwargs[field.name] = value
+ return cls(**init_kwargs)
+
+
+# ---------------------------------------------------------------------------- #
+# Argument Parsing from `argparse` #
+# ---------------------------------------------------------------------------- #
+
+
+def _get_help(field):
+ try:
+ field_help = field.metadata["help"]
+ except KeyError:
+ field_help = ""
+ return field_help
+
+
+def _init_argparse(
+ parser,
+ field_name,
+ field_type,
+ field_default,
+ field_default_factory,
+ field_help,
+ arg_prefix="",
+ help_prefix="",
+ relaxed_parser=False,
+):
+ has_default = False
+ default = None
+ if field_default:
+ has_default = True
+ default = field_default
+ elif field_default_factory not in (None, _MISSING):
+ has_default = True
+ default = field_default_factory()
+
+ if not has_default and not is_primitive_type(field_type) and not is_list(field_type):
+ # aggregate types (fields with a Coqpit subclass as type) are not supported without None
+ return parser
+ arg_prefix = field_name if arg_prefix == "" else f"{arg_prefix}.{field_name}"
+ help_prefix = field_help if help_prefix == "" else f"{help_prefix} - {field_help}"
+ if is_dict(field_type): # pylint: disable=no-else-raise
+ # NOTE: accept any string in json format as input to dict field.
+ parser.add_argument(
+ f"--{arg_prefix}",
+ dest=arg_prefix,
+ default=json.dumps(field_default) if field_default else None,
+ type=json.loads,
+ )
+ elif is_list(field_type):
+ # TODO: We need a more clear help msg for lists.
+ if hasattr(field_type, "__args__"): # if the list is hinted
+ if len(field_type.__args__) > 1 and not relaxed_parser:
+ raise ValueError(" [!] Coqpit does not support multi-type hinted 'List'")
+ list_field_type = field_type.__args__[0]
+ else:
+ raise ValueError(" [!] Coqpit does not support un-hinted 'List'")
+
+ # TODO: handle list of lists
+ if is_list(list_field_type) and relaxed_parser:
+ return parser
+
+ if not has_default or field_default_factory is list:
+ if not is_primitive_type(list_field_type) and not relaxed_parser:
+ raise NotImplementedError(" [!] Empty list with non primitive inner type is currently not supported.")
+
+ # If the list's default value is None, the user can specify the entire list by passing multiple parameters
+ parser.add_argument(
+ f"--{arg_prefix}",
+ nargs="*",
+ type=list_field_type,
+ help=f"Coqpit Field: {help_prefix}",
+ )
+ else:
+ # If a default value is defined, just enable editing the values from argparse
+ # TODO: allow inserting a new value/obj to the end of the list.
+ for idx, fv in enumerate(default):
+ parser = _init_argparse(
+ parser,
+ str(idx),
+ list_field_type,
+ fv,
+ field_default_factory,
+ field_help="",
+ help_prefix=f"{help_prefix} - ",
+ arg_prefix=f"{arg_prefix}",
+ relaxed_parser=relaxed_parser,
+ )
+ elif is_union(field_type):
+ # TODO: currently I don't know how to handle Union type on argparse
+ if not relaxed_parser:
+ raise NotImplementedError(
+ " [!] Parsing `Union` field from argparse is not yet implemented. Please create an issue."
+ )
+ elif issubclass(field_type, Serializable):
+ return default.init_argparse(
+ parser, arg_prefix=arg_prefix, help_prefix=help_prefix, relaxed_parser=relaxed_parser
+ )
+ elif isinstance(field_type(), bool):
+
+ def parse_bool(x):
+ if x not in ("true", "false"):
+ raise ValueError(f' [!] Value for boolean field must be either "true" or "false". Got "{x}".')
+ return x == "true"
+
+ parser.add_argument(
+ f"--{arg_prefix}",
+ type=parse_bool,
+ default=field_default,
+ help=f"Coqpit Field: {help_prefix}",
+ metavar="true/false",
+ )
+ elif is_primitive_type(field_type):
+ parser.add_argument(
+ f"--{arg_prefix}",
+ default=field_default,
+ type=field_type,
+ help=f"Coqpit Field: {help_prefix}",
+ )
+ else:
+ if not relaxed_parser:
+ raise NotImplementedError(f" [!] '{field_type}' is not supported by arg_parser. Please file a bug report.")
+ return parser
+
+
+# ---------------------------------------------------------------------------- #
+# Main Coqpit Class #
+# ---------------------------------------------------------------------------- #
+
+
+@dataclass
+class Coqpit(Serializable, MutableMapping):
+ """Coqpit base class to be inherited by any Coqpit dataclasses.
+ It overrides Python `dict` interface and provides `dict` compatible API.
+ It also enables serializing/deserializing a dataclass to/from a json file, plus some semi-dynamic type and value check.
+    Note that it does not support all datatypes and is likely to fail in some cases.
+ """
+
+ _initialized = False
+
+ def _is_initialized(self):
+ """Check if Coqpit is initialized. Useful to prevent running some aux functions
+ at the initialization when no attribute has been defined."""
+ return "_initialized" in vars(self) and self._initialized
+
+ def __post_init__(self):
+ self._initialized = True
+ try:
+ self.check_values()
+ except AttributeError:
+ pass
+
+ ## `dict` API functions
+
+ def __iter__(self):
+ return iter(asdict(self))
+
+ def __len__(self):
+ return len(fields(self))
+
+ def __setitem__(self, arg: str, value: Any):
+ setattr(self, arg, value)
+
+ def __getitem__(self, arg: str):
+ """Access class attributes with ``[arg]``."""
+ return self.__dict__[arg]
+
+ def __delitem__(self, arg: str):
+ delattr(self, arg)
+
+ def _keytransform(self, key): # pylint: disable=no-self-use
+ return key
+
+ ## end `dict` API functions
+
+ def __getattribute__(self, arg: str): # pylint: disable=no-self-use
+ """Check if the mandatory field is defined when accessing it."""
+ value = super().__getattribute__(arg)
+ if isinstance(value, str) and value == "???":
+ raise AttributeError(f" [!] MISSING field {arg} must be defined.")
+ return value
+
+ def __contains__(self, arg: str):
+ return arg in self.to_dict()
+
+ def get(self, key: str, default: Any = None):
+ if self.has(key):
+ return asdict(self)[key]
+ return default
+
+ def items(self):
+ return asdict(self).items()
+
+ def merge(self, coqpits: Union["Coqpit", List["Coqpit"]]):
+ """Merge a coqpit instance or a list of coqpit instances to self.
+ Note that it does not pass the fields and overrides attributes with
+ the last Coqpit instance in the given List.
+ TODO: find a way to merge instances with all the class internals.
+
+ Args:
+ coqpits (Union[Coqpit, List[Coqpit]]): coqpit instance or list of instances to be merged.
+ """
+
+ def _merge(coqpit):
+ self.__dict__.update(coqpit.__dict__)
+ self.__annotations__.update(coqpit.__annotations__)
+ self.__dataclass_fields__.update(coqpit.__dataclass_fields__)
+
+ if isinstance(coqpits, list):
+ for coqpit in coqpits:
+ _merge(coqpit)
+ else:
+ _merge(coqpits)
+
+ def check_values(self):
+ pass
+
+ def has(self, arg: str) -> bool:
+ return arg in vars(self)
+
+ def copy(self):
+ return replace(self)
+
+ def update(self, new: dict, allow_new=False) -> None:
+ """Update Coqpit fields by the input ```dict```.
+
+ Args:
+ new (dict): dictionary with new values.
+ allow_new (bool, optional): allow new fields to add. Defaults to False.
+ """
+ for key, value in new.items():
+ if allow_new:
+ setattr(self, key, value)
+ else:
+ if hasattr(self, key):
+ setattr(self, key, value)
+ else:
+ raise KeyError(f" [!] No key - {key}")
+
+ def pprint(self) -> None:
+ """Print Coqpit fields in a format."""
+ pprint(asdict(self))
+
+ def to_dict(self) -> dict:
+ # return asdict(self)
+ return self.serialize()
+
+ def from_dict(self, data: dict) -> None:
+ self = self.deserialize(data) # pylint: disable=self-cls-assignment
+
+ @classmethod
+ def new_from_dict(cls: Serializable, data: dict) -> "Coqpit":
+ return cls.deserialize_immutable(data)
+
+ def to_json(self) -> str:
+ """Returns a JSON string representation."""
+ return json.dumps(asdict(self), indent=4, default=_coqpit_json_default)
+
+ def save_json(self, file_name: str) -> None:
+ """Save Coqpit to a json file.
+
+ Args:
+ file_name (str): path to the output json file.
+ """
+ with open(file_name, "w", encoding="utf8") as f:
+ json.dump(asdict(self), f, indent=4)
+
+ def load_json(self, file_name: str) -> None:
+ """Load a json file and update matching config fields with type checking.
+ Non-matching parameters in the json file are ignored.
+
+ Args:
+ file_name (str): path to the json file.
+
+ Returns:
+ Coqpit: new Coqpit with updated config fields.
+ """
+ with open(file_name, "r", encoding="utf8") as f:
+ input_str = f.read()
+ dump_dict = json.loads(input_str)
+ # TODO: this looks stupid 💆
+ self = self.deserialize(dump_dict) # pylint: disable=self-cls-assignment
+ self.check_values()
+
+ @classmethod
+ def init_from_argparse(
+ cls, args: Optional[Union[argparse.Namespace, List[str]]] = None, arg_prefix: str = "coqpit"
+ ) -> "Coqpit":
+ """Create a new Coqpit instance from argparse input.
+
+ Args:
+ args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```.
+ arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed.
+ """
+ if not args:
+ # If args was not specified, parse from sys.argv
+ parser = cls.init_argparse(cls, arg_prefix=arg_prefix)
+ args = parser.parse_args() # pylint: disable=E1120, E1111
+ if isinstance(args, list):
+            # If a list was passed in (e.g. the second result of `parse_known_args`), run it through argparse to get a parsed Namespace
+ parser = cls.init_argparse(cls, arg_prefix=arg_prefix)
+ args = parser.parse_args(args) # pylint: disable=E1120, E1111
+
+ # Handle list and object attributes with defaults, which can be modified
+ # directly (eg. --coqpit.list.0.val_a 1), by constructing real objects
+ # from defaults and passing those to `cls.__init__`
+ args_with_lists_processed = {}
+ class_fields = fields(cls)
+ for field in class_fields:
+ has_default = False
+ default = None
+ field_default = field.default if field.default is not _MISSING else None
+ field_default_factory = field.default_factory if field.default_factory is not _MISSING else None
+ if field_default:
+ has_default = True
+ default = field_default
+ elif field_default_factory:
+ has_default = True
+ default = field_default_factory()
+
+ if has_default and (not is_primitive_type(field.type) or is_list(field.type)):
+ args_with_lists_processed[field.name] = default
+
+ args_dict = vars(args)
+ for k, v in args_dict.items():
+ # Remove argparse prefix (eg. "--coqpit." if present)
+ if k.startswith(f"{arg_prefix}."):
+ k = k[len(f"{arg_prefix}.") :]
+
+ rsetitem(args_with_lists_processed, k, v)
+
+ return cls(**args_with_lists_processed)
+
+ def parse_args(
+ self, args: Optional[Union[argparse.Namespace, List[str]]] = None, arg_prefix: str = "coqpit"
+ ) -> None:
+ """Update config values from argparse arguments with some meta-programming ✨.
+
+ Args:
+ args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```.
+ arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed.
+ """
+ if not args:
+ # If args was not specified, parse from sys.argv
+ parser = self.init_argparse(arg_prefix=arg_prefix)
+ args = parser.parse_args()
+ if isinstance(args, list):
+            # If a list was passed in (e.g. the second result of `parse_known_args`), run it through argparse to get a parsed Namespace
+ parser = self.init_argparse(arg_prefix=arg_prefix)
+ args = parser.parse_args(args)
+
+ args_dict = vars(args)
+
+ for k, v in args_dict.items():
+ if k.startswith(f"{arg_prefix}."):
+ k = k[len(f"{arg_prefix}.") :]
+ try:
+ rgetattr(self, k)
+ except (TypeError, AttributeError) as e:
+                raise Exception(f" [!] '{k}' does not exist and cannot be overridden from argparse.") from e
+
+ rsetattr(self, k, v)
+
+ self.check_values()
+
+ def parse_known_args(
+ self,
+ args: Optional[Union[argparse.Namespace, List[str]]] = None,
+ arg_prefix: str = "coqpit",
+ relaxed_parser=False,
+ ) -> List[str]:
+ """Update config values from argparse arguments. Ignore unknown arguments.
+        This is analogous to argparse.ArgumentParser.parse_known_args (vs parse_args).
+
+ Args:
+ args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```.
+ arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed.
+ relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. Defaults to False.
+
+ Returns:
+ List of unknown parameters.
+ """
+ if not args:
+ # If args was not specified, parse from sys.argv
+ parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser)
+ args, unknown = parser.parse_known_args()
+ if isinstance(args, list):
+            # If a list was passed in (e.g. the second result of `parse_known_args`), run it through argparse to get a parsed Namespace
+ parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser)
+ args, unknown = parser.parse_known_args(args)
+
+ self.parse_args(args)
+ return unknown
+
+ def init_argparse(
+ self,
+ parser: Optional[argparse.ArgumentParser] = None,
+ arg_prefix="coqpit",
+ help_prefix="",
+ relaxed_parser=False,
+ ) -> argparse.ArgumentParser:
+        """Pass Coqpit fields as argparse arguments. This allows editing values through the command line.
+
+ Args:
+ parser (argparse.ArgumentParser, optional): argparse.ArgumentParser instance. If unspecified a new one will be created.
+ arg_prefix (str, optional): Prefix to be used for the argument name. Defaults to 'coqpit'.
+ help_prefix (str, optional): Prefix to be used for the argument description. Defaults to ''.
+ relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. Defaults to False.
+
+ Returns:
+ argparse.ArgumentParser: parser instance with the new arguments.
+ """
+ if not parser:
+ parser = argparse.ArgumentParser()
+ class_fields = fields(self)
+ for field in class_fields:
+ if field.name in vars(self):
+ # use the current value of the field
+ # prevent dropping the current value
+ field_default = vars(self)[field.name]
+ else:
+ # use the default value of the field
+ field_default = field.default if field.default is not _MISSING else None
+ field_type = field.type
+ field_default_factory = field.default_factory
+ field_help = _get_help(field)
+ _init_argparse(
+ parser,
+ field.name,
+ field_type,
+ field_default,
+ field_default_factory,
+ field_help,
+ arg_prefix,
+ help_prefix,
+ relaxed_parser,
+ )
+ return parser
+
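+
+# --- Illustrative usage sketch (added for clarity; not part of the original module) ---
+# A minimal Coqpit config with plain defaults only (so no extra imports are needed).
+# The field names and values are examples, not real project settings.
+@dataclass
+class _ExampleTrainConfig(Coqpit):
+    lr: float = 1e-4
+    batch_size: int = 32
+    run_name: str = "debug"
+
+
+def _example_coqpit_usage():
+    cfg = _ExampleTrainConfig()
+    # Override fields from CLI-style arguments; the default prefix is "coqpit".
+    cfg.parse_args(["--coqpit.lr", "0.001", "--coqpit.batch_size", "16"])
+    return cfg.to_dict()  # {'lr': 0.001, 'batch_size': 16, 'run_name': 'debug'}
+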
+
+def check_argument(
+ name,
+ c,
+ is_path: bool = False,
+ prerequest: str = None,
+ enum_list: list = None,
+ max_val: float = None,
+ min_val: float = None,
+ restricted: bool = False,
+ alternative: str = None,
+ allow_none: bool = True,
+) -> None:
+ """Simple type and value checking for Coqpit.
+ It is intended to be used under ```__post_init__()``` of config dataclasses.
+
+ Args:
+ name (str): name of the field to be checked.
+ c (dict): config dictionary.
+        is_path (bool, optional): if ```True``` check that the path exists. Defaults to False.
+        prerequest (list or str, optional): a list of field names that are prerequisites of the target field.
+            Defaults to ```[]```.
+ enum_list (list, optional): list of possible values for the target field. Defaults to None.
+ max_val (float, optional): maximum possible value for the target field. Defaults to None.
+ min_val (float, optional): minimum possible value for the target field. Defaults to None.
+ restricted (bool, optional): if ```True``` the target field has to be defined. Defaults to False.
+        alternative (str, optional): a field name superseding the target field. Defaults to None.
+        allow_none (bool, optional): if ```True``` allow the target field to be ```None```. Defaults to True.
+
+
+ Example:
+        >>> c = {"num_mels": 80, "fft_size": 1024}
+        >>> check_argument('num_mels', c, restricted=True, min_val=10, max_val=2056)
+        >>> check_argument('fft_size', c, restricted=True, min_val=128, max_val=4058)
+ # check if None allowed
+ if allow_none and c[name] is None:
+ return
+ if not allow_none:
+ assert c[name] is not None, f" [!] None value is not allowed for {name}."
+    # if restricted, check that the field is defined
+ if isinstance(restricted, bool) and restricted:
+ assert name in c.keys(), f" [!] {name} not defined in config.json"
+ # check prerequest fields are defined
+ if isinstance(prerequest, list):
+ assert any(
+ f not in c.keys() for f in prerequest
+        ), f" [!] prerequisite fields {prerequest} for {name} are not defined."
+ else:
+ assert (
+ prerequest is None or prerequest in c.keys()
+        ), f" [!] prerequisite fields {prerequest} for {name} are not defined."
+ # check if the path exists
+ if is_path:
+ assert os.path.exists(c[name]), f' [!] path for {name} ("{c[name]}") does not exist.'
+ # skip the rest if the alternative field is defined.
+ if alternative in c.keys() and c[alternative] is not None:
+ return
+ # check value constraints
+ if name in c.keys():
+ if max_val is not None:
+ assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}"
+ if min_val is not None:
+ assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}"
+ if enum_list is not None:
+ assert c[name].lower() in enum_list, f" [!] {name} is not a valid value"
diff --git a/speaker/utils/io.py b/speaker/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d4c07940d872cb6773d388029595aecf67e4408
--- /dev/null
+++ b/speaker/utils/io.py
@@ -0,0 +1,198 @@
+import datetime
+import json
+import os
+import pickle as pickle_tts
+import shutil
+from typing import Any, Callable, Dict, Union
+
+import fsspec
+import torch
+from .coqpit import Coqpit
+
+
+class RenamingUnpickler(pickle_tts.Unpickler):
+ """Overload default pickler to solve module renaming problem"""
+
+ def find_class(self, module, name):
+ return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)
+
+
+class AttrDict(dict):
+ """A custom dict which converts dict keys
+ to class attributes"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def copy_model_files(config: Coqpit, out_path, new_fields):
+ """Copy config.json and other model files to training folder and add
+ new fields.
+
+ Args:
+ config (Coqpit): Coqpit config defining the training run.
+ out_path (str): output path to copy the file.
+        new_fields (dict): new fields to be added or edited
+ in the config file.
+ """
+ copy_config_path = os.path.join(out_path, "config.json")
+ # add extra information fields
+ config.update(new_fields, allow_new=True)
+ # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
+ with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
+ json.dump(config.to_dict(), f, indent=4)
+
+ # copy model stats file if available
+ if config.audio.stats_path is not None:
+ copy_stats_path = os.path.join(out_path, "scale_stats.npy")
+ filesystem = fsspec.get_mapper(copy_stats_path).fs
+ if not filesystem.exists(copy_stats_path):
+ with fsspec.open(config.audio.stats_path, "rb") as source_file:
+ with fsspec.open(copy_stats_path, "wb") as target_file:
+ shutil.copyfileobj(source_file, target_file)
+
+
+def load_fsspec(
+ path: str,
+ map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
+ **kwargs,
+) -> Any:
+ """Like torch.load but can load from other locations (e.g. s3:// , gs://).
+
+ Args:
+ path: Any path or url supported by fsspec.
+ map_location: torch.device or str.
+ **kwargs: Keyword arguments forwarded to torch.load.
+
+ Returns:
+ Object stored in path.
+ """
+ with fsspec.open(path, "rb") as f:
+ return torch.load(f, map_location=map_location, **kwargs)
+
+
+def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
+ try:
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+ except ModuleNotFoundError:
+ pickle_tts.Unpickler = RenamingUnpickler
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
+ model.load_state_dict(state["model"])
+ if use_cuda:
+ model.cuda()
+ if eval:
+ model.eval()
+ return model, state
+
+
+def save_fsspec(state: Any, path: str, **kwargs):
+ """Like torch.save but can save to other locations (e.g. s3:// , gs://).
+
+ Args:
+ state: State object to save
+ path: Any path or url supported by fsspec.
+ **kwargs: Keyword arguments forwarded to torch.save.
+ """
+ with fsspec.open(path, "wb") as f:
+ torch.save(state, f, **kwargs)
+
+
+def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
+ if hasattr(model, "module"):
+ model_state = model.module.state_dict()
+ else:
+ model_state = model.state_dict()
+ if isinstance(optimizer, list):
+ optimizer_state = [optim.state_dict() for optim in optimizer]
+ else:
+ optimizer_state = optimizer.state_dict() if optimizer is not None else None
+
+ if isinstance(scaler, list):
+ scaler_state = [s.state_dict() for s in scaler]
+ else:
+ scaler_state = scaler.state_dict() if scaler is not None else None
+
+ if isinstance(config, Coqpit):
+ config = config.to_dict()
+
+ state = {
+ "config": config,
+ "model": model_state,
+ "optimizer": optimizer_state,
+ "scaler": scaler_state,
+ "step": current_step,
+ "epoch": epoch,
+ "date": datetime.date.today().strftime("%B %d, %Y"),
+ }
+ state.update(kwargs)
+ save_fsspec(state, output_path)
+
+
+def save_checkpoint(
+ config,
+ model,
+ optimizer,
+ scaler,
+ current_step,
+ epoch,
+ output_folder,
+ **kwargs,
+):
+ file_name = "checkpoint_{}.pth.tar".format(current_step)
+ checkpoint_path = os.path.join(output_folder, file_name)
+ print("\n > CHECKPOINT : {}".format(checkpoint_path))
+ save_model(
+ config,
+ model,
+ optimizer,
+ scaler,
+ current_step,
+ epoch,
+ checkpoint_path,
+ **kwargs,
+ )
+
+
+def save_best_model(
+ current_loss,
+ best_loss,
+ config,
+ model,
+ optimizer,
+ scaler,
+ current_step,
+ epoch,
+ out_path,
+ keep_all_best=False,
+ keep_after=10000,
+ **kwargs,
+):
+ if current_loss < best_loss:
+ best_model_name = f"best_model_{current_step}.pth.tar"
+ checkpoint_path = os.path.join(out_path, best_model_name)
+ print(" > BEST MODEL : {}".format(checkpoint_path))
+ save_model(
+ config,
+ model,
+ optimizer,
+ scaler,
+ current_step,
+ epoch,
+ checkpoint_path,
+ model_loss=current_loss,
+ **kwargs,
+ )
+ fs = fsspec.get_mapper(out_path).fs
+ # only delete previous if current is saved successfully
+ if not keep_all_best or (current_step < keep_after):
+ model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar"))
+ for model_name in model_names:
+ if os.path.basename(model_name) != best_model_name:
+ fs.rm(model_name)
+ # create a shortcut which always points to the currently best model
+ shortcut_name = "best_model.pth.tar"
+ shortcut_path = os.path.join(out_path, shortcut_name)
+ fs.copy(checkpoint_path, shortcut_path)
+ best_loss = current_loss
+ return best_loss
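A minimal usage sketch for the fsspec-backed checkpoint helpers above. The import path is an assumption (this file sits next to speaker/utils/shared_configs.py in the tree); adjust it to the actual module name, and the path may be a local file or any fsspec-supported URL (s3://, gs://).

    import torch
    from speaker.utils.io import save_fsspec, load_fsspec  # assumed module path

    state = {"model": torch.nn.Linear(4, 4).state_dict(), "step": 0}
    save_fsspec(state, "checkpoint_0.pth.tar")            # local path or fsspec URL
    restored = load_fsspec("checkpoint_0.pth.tar", map_location="cpu")
    print(restored["step"])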
diff --git a/speaker/utils/shared_configs.py b/speaker/utils/shared_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..a89d3a91c31679989b60b657fe8ef6ace5f02552
--- /dev/null
+++ b/speaker/utils/shared_configs.py
@@ -0,0 +1,342 @@
+from dataclasses import asdict, dataclass
+from typing import List
+
+from .coqpit import Coqpit, check_argument
+
+
+@dataclass
+class BaseAudioConfig(Coqpit):
+ """Base config to definge audio processing parameters. It is used to initialize
+ ```TTS.utils.audio.AudioProcessor.```
+
+ Args:
+ fft_size (int):
+ Number of STFT frequency levels, aka the size of the linear spectrogram frame. Defaults to 1024.
+
+ win_length (int):
+ Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match
+ ```fft_size```. Defaults to 1024.
+
+ hop_length (int):
+ Number of audio samples between adjacent STFT columns. Defaults to 256.
+
+ frame_shift_ms (int):
+ Set ```hop_length``` based on milliseconds and sampling rate.
+
+ frame_length_ms (int):
+ Set ```win_length``` based on milliseconds and sampling rate.
+
+ stft_pad_mode (str):
+ Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.
+
+ sample_rate (int):
+ Audio sampling rate. Defaults to 22050.
+
+ resample (bool):
+ Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.
+
+ preemphasis (float):
+ Preemphasis coefficient. Defaults to 0.0.
+
+ ref_level_db (int):
+ Reference dB level to rebase the audio signal and ignore levels below it. 20 dB is assumed to be the sound of air.
+ Defaults to 20.
+
+ do_sound_norm (bool):
+ Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
+
+ log_func (str):
+ Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'.
+
+ do_trim_silence (bool):
+ Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
+
+ do_amp_to_db_linear (bool, optional):
+ enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
+
+ do_amp_to_db_mel (bool, optional):
+ enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
+
+ trim_db (int):
+ Silence threshold used for silence trimming. Defaults to 45.
+
+ power (float):
+ Exponent used for expanding spectrogram levels before running Griffin-Lim. It helps to reduce the
+ artifacts in the synthesized voice. Defaults to 1.5.
+
+ griffin_lim_iters (int):
+ Number of Griffin-Lim iterations. Defaults to 60.
+
+ num_mels (int):
+ Number of mel-basis filters that defines the size of each mel-spectrogram frame. Defaults to 80.
+
+ mel_fmin (float):
+ Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
+ It needs to be adjusted for a dataset. Defaults to 0.
+
+ mel_fmax (float):
+ Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
+
+ spec_gain (int):
+ Gain applied when converting amplitude to DB. Defaults to 20.
+
+ signal_norm (bool):
+ enable/disable signal normalization. Defaults to True.
+
+ min_level_db (int):
+ minimum db threshold for the computed melspectrograms. Defaults to -100.
+
+ symmetric_norm (bool):
+ enable/disable symmetric normalization. If set to True, normalization is performed in the range [-k, k],
+ else [0, k]. Defaults to True.
+
+ max_norm (float):
+ ```k``` defining the normalization range. Defaults to 4.0.
+
+ clip_norm (bool):
+ enable/disable clipping of out-of-range values in the normalized audio signal. Defaults to True.
+
+ stats_path (str):
+ Path to the computed stats file. Defaults to None.
+ """
+
+ # stft parameters
+ fft_size: int = 1024
+ win_length: int = 1024
+ hop_length: int = 256
+ frame_shift_ms: int = None
+ frame_length_ms: int = None
+ stft_pad_mode: str = "reflect"
+ # audio processing parameters
+ sample_rate: int = 22050
+ resample: bool = False
+ preemphasis: float = 0.0
+ ref_level_db: int = 20
+ do_sound_norm: bool = False
+ log_func: str = "np.log10"
+ # silence trimming
+ do_trim_silence: bool = True
+ trim_db: int = 45
+ # griffin-lim params
+ power: float = 1.5
+ griffin_lim_iters: int = 60
+ # mel-spec params
+ num_mels: int = 80
+ mel_fmin: float = 0.0
+ mel_fmax: float = None
+ spec_gain: int = 20
+ do_amp_to_db_linear: bool = True
+ do_amp_to_db_mel: bool = True
+ # normalization params
+ signal_norm: bool = True
+ min_level_db: int = -100
+ symmetric_norm: bool = True
+ max_norm: float = 4.0
+ clip_norm: bool = True
+ stats_path: str = None
+
+ def check_values(
+ self,
+ ):
+ """Check config fields"""
+ c = asdict(self)
+ check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056)
+ check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058)
+ check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000)
+ check_argument(
+ "frame_length_ms",
+ c,
+ restricted=True,
+ min_val=10,
+ max_val=1000,
+ alternative="win_length",
+ )
+ check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length")
+ check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1)
+ check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10)
+ check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000)
+ check_argument("power", c, restricted=True, min_val=1, max_val=5)
+ check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000)
+
+ # normalization parameters
+ check_argument("signal_norm", c, restricted=True)
+ check_argument("symmetric_norm", c, restricted=True)
+ check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000)
+ check_argument("clip_norm", c, restricted=True)
+ check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000)
+ check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True)
+ check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100)
+ check_argument("do_trim_silence", c, restricted=True)
+ check_argument("trim_db", c, restricted=True)
+
+
+@dataclass
+class BaseDatasetConfig(Coqpit):
+ """Base config for TTS datasets.
+
+ Args:
+ name (str):
+ Dataset name that defines the preprocessor in use. Defaults to None.
+
+ path (str):
+ Root path to the dataset files. Defaults to None.
+
+ meta_file_train (str):
+ Name of the dataset meta file, or a list of speakers to be ignored during training for multi-speaker
+ datasets. Defaults to None.
+
+ unused_speakers (List):
+ List of speaker IDs that are not used during training. Defaults to None.
+
+ meta_file_val (str):
+ Name of the dataset meta file that defines the instances used at validation.
+
+ meta_file_attn_mask (str):
+ Path to the file that lists the attention mask files used with models that require attention masks to
+ train the duration predictor.
+ """
+
+ name: str = ""
+ path: str = ""
+ meta_file_train: str = ""
+ ununsed_speakers: List[str] = None
+ meta_file_val: str = ""
+ meta_file_attn_mask: str = ""
+
+ def check_values(
+ self,
+ ):
+ """Check config fields"""
+ c = asdict(self)
+ check_argument("name", c, restricted=True)
+ check_argument("path", c, restricted=True)
+ check_argument("meta_file_train", c, restricted=True)
+ check_argument("meta_file_val", c, restricted=False)
+ check_argument("meta_file_attn_mask", c, restricted=False)
+
+
+@dataclass
+class BaseTrainingConfig(Coqpit):
+ """Base config to define the basic training parameters that are shared
+ among all the models.
+
+ Args:
+ model (str):
+ Name of the model that is used in the training.
+
+ run_name (str):
+ Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`.
+
+ run_description (str):
+ Short description of the experiment.
+
+ epochs (int):
+ Number of training epochs. Defaults to 10000.
+
+ batch_size (int):
+ Training batch size.
+
+ eval_batch_size (int):
+ Validation batch size.
+
+ mixed_precision (bool):
+ Enable / Disable mixed precision training. It reduces VRAM usage and allows larger batch sizes; however,
+ it may also cause numerical instability in some cases.
+
+ scheduler_after_epoch (bool):
+ If true, run the scheduler step after each epoch else run it after each model step.
+
+ run_eval (bool):
+ Enable / Disable evaluation (validation) run. Defaults to True.
+
+ test_delay_epochs (int):
+ Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful
+ results, hence waiting for a couple of epochs might save some time.
+
+ print_eval (bool):
+ Enable / Disable console logging for evaluation steps. If disabled, it only shows the final values at
+ the end of the evaluation. Defaults to ```False```.
+
+ print_step (int):
+ Number of steps required to print the next training log.
+
+ dashboard_logger (str): "tensorboard" or "wandb"
+ Set the experiment tracking tool. Defaults to "tensorboard".
+
+ plot_step (int):
+ Number of steps required to log training on Tensorboard.
+
+ model_param_stats (bool):
+ Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging.
+ Defaults to ```False```.
+
+ project_name (str):
+ Name of the project. Defaults to config.model
+
+ wandb_entity (str):
+ Name of W&B entity/team. Enables collaboration across a team or org.
+
+ log_model_step (int):
+ Number of steps required to log a checkpoint as W&B artifact
+
+ save_step (int):
+ Number of steps required to save the next checkpoint.
+
+ checkpoint (bool):
+ Enable / Disable checkpointing.
+
+ keep_all_best (bool):
+ Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults
+ to ```False```.
+
+ keep_after (int):
+ Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults
+ to 10000.
+
+ num_loader_workers (int):
+ Number of workers for training time dataloader.
+
+ num_eval_loader_workers (int):
+ Number of workers for evaluation time dataloader.
+
+ output_path (str):
+ Path for training output folder, either a local file path or other
+ URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or
+ S3 (s3://) paths. Nonexistent parts of the given path are created
+ automatically. All training artifacts are saved there.
+ """
+
+ model: str = None
+ run_name: str = "coqui_tts"
+ run_description: str = ""
+ # training params
+ epochs: int = 10000
+ batch_size: int = None
+ eval_batch_size: int = None
+ mixed_precision: bool = False
+ scheduler_after_epoch: bool = False
+ # eval params
+ run_eval: bool = True
+ test_delay_epochs: int = 0
+ print_eval: bool = False
+ # logging
+ dashboard_logger: str = "tensorboard"
+ print_step: int = 25
+ plot_step: int = 100
+ model_param_stats: bool = False
+ project_name: str = None
+ log_model_step: int = None
+ wandb_entity: str = None
+ # checkpointing
+ save_step: int = 10000
+ checkpoint: bool = True
+ keep_all_best: bool = False
+ keep_after: int = 10000
+ # dataloading
+ num_loader_workers: int = 0
+ num_eval_loader_workers: int = 0
+ use_noise_augment: bool = False
+ # paths
+ output_path: str = None
+ # distributed
+ distributed_backend: str = "nccl"
+ distributed_url: str = "tcp://localhost:54321"
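A minimal sketch of how these dataclasses are meant to be used, assuming the module path from the diff header above and that the bundled coqpit helpers behave like upstream Coqpit (check_values() is expected to raise when a field falls outside its documented range). The dataset name and paths below are placeholders.

    from speaker.utils.shared_configs import BaseAudioConfig, BaseDatasetConfig

    audio = BaseAudioConfig(sample_rate=16000, num_mels=80, hop_length=256)
    audio.check_values()        # validates ranges such as 10 <= num_mels <= 2056
    dataset = BaseDatasetConfig(name="vctk_slim", path="./VCTK-Corpus/", meta_file_train="files.txt")
    dataset.check_values()
    print(audio.to_dict()["sample_rate"], dataset.name)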
diff --git a/speaker_pretrain/.DS_Store b/speaker_pretrain/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/speaker_pretrain/.DS_Store differ
diff --git a/speaker_pretrain/README.md b/speaker_pretrain/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b585941e1639b2dee0ef76773c27a1dad4407e83
--- /dev/null
+++ b/speaker_pretrain/README.md
@@ -0,0 +1,5 @@
+Path for:
+
+ best_model.pth.tar
+
+ config.json
diff --git a/speaker_pretrain/config.json b/speaker_pretrain/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..e330aabe8aba41a76af1250f90bb35cbe15d3cdc
--- /dev/null
+++ b/speaker_pretrain/config.json
@@ -0,0 +1,104 @@
+{
+ "model_name": "lstm",
+ "run_name": "mueller91",
+ "run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ",
+ "audio":{
+ // Audio processing parameters
+ "num_mels": 80, // size of the mel spec frame.
+ "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
+ "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
+ "win_length": 1024, // stft window length in ms.
+ "hop_length": 256, // stft window hop-lengh in ms.
+ "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
+ "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
+ "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
+ "min_level_db": -100, // normalization range
+ "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
+ "power": 1.5, // value to sharpen wav signals after GL algorithm.
+ "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
+ // Normalization parameters
+ "signal_norm": true, // normalize the spec values in range [0, 1]
+ "symmetric_norm": true, // move normalization to range [-1, 1]
+ "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+ "clip_norm": true, // clip normalized values into the range.
+ "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
+ "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
+ "do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
+ "trim_db": 60 // threshold for timming silence. Set this according to your dataset.
+ },
+ "reinit_layers": [],
+ "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA)
+ "grad_clip": 3.0, // upper limit for gradients for clipping.
+ "epochs": 1000, // total number of epochs to train.
+ "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
+ "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
+ "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
+ "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+ "steps_plot_stats": 10, // number of steps to plot embeddings.
+ "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
+ "voice_len": 2.0, // size of the voice
+ "num_utters_per_speaker": 10, //
+ "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
+ "wd": 0.000001, // Weight decay weight.
+ "checkpoint": true, // If true, it saves checkpoints per "save_step"
+ "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
+ "print_step": 20, // Number of steps to log traning on console.
+ "output_path": "../../OutputsMozilla/checkpoints/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
+ "model": {
+ "input_dim": 80,
+ "proj_dim": 256,
+ "lstm_dim": 768,
+ "num_lstm_layers": 3,
+ "use_lstm_with_projection": true
+ },
+ "storage": {
+ "sample_from_storage_p": 0.9, // the probability with which we'll sample from the DataSet in-memory storage
+ "storage_size": 25, // the size of the in-memory storage with respect to a single batch
+ "additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness
+ },
+ "datasets":
+ [
+ {
+ "name": "vctk_slim",
+ "path": "../../../audio-datasets/en/VCTK-Corpus/",
+ "meta_file_train": null,
+ "meta_file_val": null
+ },
+ {
+ "name": "libri_tts",
+ "path": "../../../audio-datasets/en/LibriTTS/train-clean-100",
+ "meta_file_train": null,
+ "meta_file_val": null
+ },
+ {
+ "name": "libri_tts",
+ "path": "../../../audio-datasets/en/LibriTTS/train-clean-360",
+ "meta_file_train": null,
+ "meta_file_val": null
+ },
+ {
+ "name": "libri_tts",
+ "path": "../../../audio-datasets/en/LibriTTS/train-other-500",
+ "meta_file_train": null,
+ "meta_file_val": null
+ },
+ {
+ "name": "voxceleb1",
+ "path": "../../../audio-datasets/en/voxceleb1/",
+ "meta_file_train": null,
+ "meta_file_val": null
+ },
+ {
+ "name": "voxceleb2",
+ "path": "../../../audio-datasets/en/voxceleb2/",
+ "meta_file_train": null,
+ "meta_file_val": null
+ },
+ {
+ "name": "common_voice",
+ "path": "../../../audio-datasets/en/MozillaCommonVoice",
+ "meta_file_train": "train.tsv",
+ "meta_file_val": "test.tsv"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/svc_eva.py b/svc_eva.py
new file mode 100644
index 0000000000000000000000000000000000000000..905d34e7e432299d2aa3bf9a500b178569bbd96f
--- /dev/null
+++ b/svc_eva.py
@@ -0,0 +1,20 @@
+import os
+import numpy as np
+
+# average -> ave -> eva :haha
+
+eva_conf = {
+ './configs/singers/singer0022.npy': 0,
+ './configs/singers/singer0030.npy': 0,
+ './configs/singers/singer0047.npy': 0.5,
+ './configs/singers/singer0051.npy': 0.5,
+}
+
+if __name__ == "__main__":
+
+ eva = np.zeros(256)
+ for k, v in eva_conf.items():
+ assert os.path.isfile(k), k
+ spk = np.load(k)
+ eva = eva + spk * v
+ np.save("eva.spk.npy", eva, allow_pickle=False)
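The weights above are meant to sum to 1 so the blended embedding stays on the same scale as the individual speaker vectors. A hedged follow-up sketch showing how the averaged file can be used for conversion; paths mirror the usage comments elsewhere in this patch, and test.wav is the sample file added by this commit.

    import os
    os.system("python svc_inference.py --config configs/base.yaml "
              "--model vits_pretrain/sovits5.0.pth --spk eva.spk.npy --wave test.wav")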
diff --git a/svc_export.py b/svc_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..13dea0c9a8f9aedfe9cfb77d1d1b81fcb5b922bb
--- /dev/null
+++ b/svc_export.py
@@ -0,0 +1,68 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+import torch
+import argparse
+from omegaconf import OmegaConf
+
+from vits.models import SynthesizerInfer
+
+
+def load_model(checkpoint_path, model):
+ assert os.path.isfile(checkpoint_path)
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+ saved_state_dict = checkpoint_dict["model_g"]
+ if hasattr(model, "module"):
+ state_dict = model.module.state_dict()
+ else:
+ state_dict = model.state_dict()
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ try:
+ new_state_dict[k] = saved_state_dict[k]
+ except KeyError:
+ new_state_dict[k] = v
+ if hasattr(model, "module"):
+ model.module.load_state_dict(new_state_dict)
+ else:
+ model.load_state_dict(new_state_dict)
+ return model
+
+
+def save_pretrain(checkpoint_path, save_path):
+ assert os.path.isfile(checkpoint_path)
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+ torch.save({
+ 'model_g': checkpoint_dict['model_g'],
+ 'model_d': checkpoint_dict['model_d'],
+ }, save_path)
+
+
+def save_model(model, checkpoint_path):
+ if hasattr(model, 'module'):
+ state_dict = model.module.state_dict()
+ else:
+ state_dict = model.state_dict()
+ torch.save({'model_g': state_dict}, checkpoint_path)
+
+
+def main(args):
+ hp = OmegaConf.load(args.config)
+ model = SynthesizerInfer(
+ hp.data.filter_length // 2 + 1,
+ hp.data.segment_size // hp.data.hop_length,
+ hp)
+
+ # save_pretrain(args.checkpoint_path, "sovits5.0.pretrain.pth")
+ load_model(args.checkpoint_path, model)
+ save_model(model, "sovits5.0.pth")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-c', '--config', type=str, required=True,
+ help="yaml file for config. will use hp_str from checkpoint if not given.")
+ parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
+ help="path of checkpoint pt file for evaluation")
+ args = parser.parse_args()
+
+ main(args)
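A hedged usage sketch for the exporter: it strips the optimizer and discriminator state and writes a generator-only sovits5.0.pth for inference. The training checkpoint path below is a placeholder; the config mirrors configs/base.yaml used elsewhere in this patch.

    import os
    os.system("python svc_export.py --config configs/base.yaml "
              "--checkpoint_path chkpt/sovits5.0/sovits5.0_1000.pt")  # checkpoint path is a placeholder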
diff --git a/svc_inference.py b/svc_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e8560eab01bcccbf9118780225a520f9107001f
--- /dev/null
+++ b/svc_inference.py
@@ -0,0 +1,241 @@
+import logging
+import sys,os
+from pathlib import Path
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+import torch
+import argparse
+import numpy as np
+
+from omegaconf import OmegaConf
+from scipy.io.wavfile import write
+from vits.models import SynthesizerInfer
+from pitch import load_csv_pitch
+from feature_retrieval import IRetrieval, DummyRetrieval, FaissIndexRetrieval, load_retrieve_index
+
+logger = logging.getLogger(__name__)
+
+
+def get_speaker_name_from_path(speaker_path: Path) -> str:
+ # strip the full suffix chain (e.g. ".spk.npy"); str.rstrip removes a character set, not a suffix
+ suffixes = "".join(speaker_path.suffixes)
+ filename = speaker_path.name
+ return filename[:-len(suffixes)] if suffixes else filename
+
+
+def create_retrival(cli_args) -> IRetrieval:
+ if not cli_args.enable_retrieval:
+ logger.info("infer without retrival")
+ return DummyRetrieval()
+ else:
+ logger.info("load index retrival model")
+
+ speaker_name = get_speaker_name_from_path(Path(args.spk))
+ base_path = Path(".").absolute() / "data_svc" / "indexes" / speaker_name
+
+ if cli_args.hubert_index_path:
+ hubert_index_filepath = cli_args.hubert_index_path
+ else:
+ index_name = f"{cli_args.retrieval_index_prefix}hubert.index"
+ hubert_index_filepath = base_path / index_name
+
+ if cli_args.whisper_index_path:
+ whisper_index_filepath = cli_args.whisper_index_path
+ else:
+ index_name = f"{cli_args.retrieval_index_prefix}whisper.index"
+ whisper_index_filepath = base_path / index_name
+
+ return FaissIndexRetrieval(
+ hubert_index=load_retrieve_index(
+ filepath=hubert_index_filepath,
+ ratio=cli_args.retrieval_ratio,
+ n_nearest_vectors=cli_args.n_retrieval_vectors
+ ),
+ whisper_index=load_retrieve_index(
+ filepath=whisper_index_filepath,
+ ratio=cli_args.retrieval_ratio,
+ n_nearest_vectors=cli_args.n_retrieval_vectors
+ ),
+ )
+
+
+def load_svc_model(checkpoint_path, model):
+ assert os.path.isfile(checkpoint_path)
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+ saved_state_dict = checkpoint_dict["model_g"]
+ state_dict = model.state_dict()
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ try:
+ new_state_dict[k] = saved_state_dict[k]
+ except KeyError:
+ print("%s is not in the checkpoint" % k)
+ new_state_dict[k] = v
+ model.load_state_dict(new_state_dict)
+ return model
+
+
+def svc_infer(model, retrieval: IRetrieval, spk, pit, ppg, vec, hp, device):
+ len_pit = pit.size()[0]
+ len_vec = vec.size()[0]
+ len_ppg = ppg.size()[0]
+ len_min = min(len_pit, len_vec)
+ len_min = min(len_min, len_ppg)
+ pit = pit[:len_min]
+ vec = vec[:len_min, :]
+ ppg = ppg[:len_min, :]
+
+ with torch.no_grad():
+ spk = spk.unsqueeze(0).to(device)
+ source = pit.unsqueeze(0).to(device)
+ source = model.pitch2source(source)
+ pitwav = model.source2wav(source)
+ write("svc_out_pit.wav", hp.data.sampling_rate, pitwav)
+
+ hop_size = hp.data.hop_length
+ all_frame = len_min
+ hop_frame = 10
+ out_chunk = 2500 # 2500 frames, about 25 seconds
+ out_index = 0
+ out_audio = []
+
+ while (out_index < all_frame):
+
+ if (out_index == 0): # start frame
+ cut_s = 0
+ cut_s_out = 0
+ else:
+ cut_s = out_index - hop_frame
+ cut_s_out = hop_frame * hop_size
+
+ if (out_index + out_chunk + hop_frame > all_frame): # end frame
+ cut_e = all_frame
+ cut_e_out = -1
+ else:
+ cut_e = out_index + out_chunk + hop_frame
+ cut_e_out = -1 * hop_frame * hop_size
+
+ sub_ppg = retrieval.retriv_whisper(ppg[cut_s:cut_e, :])
+ sub_vec = retrieval.retriv_hubert(vec[cut_s:cut_e, :])
+ sub_ppg = sub_ppg.unsqueeze(0).to(device)
+ sub_vec = sub_vec.unsqueeze(0).to(device)
+ sub_pit = pit[cut_s:cut_e].unsqueeze(0).to(device)
+ sub_len = torch.LongTensor([cut_e - cut_s]).to(device)
+ sub_har = source[:, :, cut_s *
+ hop_size:cut_e * hop_size].to(device)
+ sub_out = model.inference(
+ sub_ppg, sub_vec, sub_pit, spk, sub_len, sub_har)
+ sub_out = sub_out[0, 0].data.cpu().detach().numpy()
+
+ sub_out = sub_out[cut_s_out:cut_e_out]
+ out_audio.extend(sub_out)
+ out_index = out_index + out_chunk
+
+ out_audio = np.asarray(out_audio)
+ return out_audio
+
+
+def main(args):
+ if (args.ppg == None):
+ args.ppg = "svc_tmp.ppg.npy"
+ print(
+ f"Auto run : python whisper/inference.py -w {args.wave} -p {args.ppg}")
+ os.system(f"python whisper/inference.py -w {args.wave} -p {args.ppg}")
+
+ if (args.vec == None):
+ args.vec = "svc_tmp.vec.npy"
+ print(
+ f"Auto run : python hubert/inference.py -w {args.wave} -v {args.vec}")
+ os.system(f"python hubert/inference.py -w {args.wave} -v {args.vec}")
+
+ if (args.pit == None):
+ args.pit = "svc_tmp.pit.csv"
+ print(
+ f"Auto run : python pitch/inference.py -w {args.wave} -p {args.pit}")
+ os.system(f"python pitch/inference.py -w {args.wave} -p {args.pit}")
+
+ if args.debug:
+ logging.basicConfig(level=logging.DEBUG)
+ else:
+ logging.basicConfig(level=logging.INFO)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ hp = OmegaConf.load(args.config)
+ model = SynthesizerInfer(
+ hp.data.filter_length // 2 + 1,
+ hp.data.segment_size // hp.data.hop_length,
+ hp)
+ load_svc_model(args.model, model)
+ retrieval = create_retrival(args)
+ model.eval()
+ model.to(device)
+
+ spk = np.load(args.spk)
+ spk = torch.FloatTensor(spk)
+
+ ppg = np.load(args.ppg)
+ ppg = np.repeat(ppg, 2, 0) # 320 PPG -> 160 * 2
+ ppg = torch.FloatTensor(ppg)
+ # ppg = torch.zeros_like(ppg)
+
+ vec = np.load(args.vec)
+ vec = np.repeat(vec, 2, 0) # 320 PPG -> 160 * 2
+ vec = torch.FloatTensor(vec)
+ # vec = torch.zeros_like(vec)
+
+ pit = load_csv_pitch(args.pit)
+ print("pitch shift: ", args.shift)
+ if (args.shift == 0):
+ pass
+ else:
+ pit = np.array(pit)
+ source = pit[pit > 0]
+ source_ave = source.mean()
+ source_min = source.min()
+ source_max = source.max()
+ print(f"source pitch statics: mean={source_ave:0.1f}, \
+ min={source_min:0.1f}, max={source_max:0.1f}")
+ shift = args.shift
+ shift = 2 ** (shift / 12)
+ pit = pit * shift
+ pit = torch.FloatTensor(pit)
+
+ out_audio = svc_infer(model, retrieval, spk, pit, ppg, vec, hp, device)
+ write("svc_out.wav", hp.data.sampling_rate, out_audio)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, required=True,
+ help="yaml file for config.")
+ parser.add_argument('--model', type=str, required=True,
+ help="path of model for evaluation")
+ parser.add_argument('--wave', type=str, required=True,
+ help="Path of raw audio.")
+ parser.add_argument('--spk', type=str, required=True,
+ help="Path of speaker.")
+ parser.add_argument('--ppg', type=str,
+ help="Path of content vector.")
+ parser.add_argument('--vec', type=str,
+ help="Path of hubert vector.")
+ parser.add_argument('--pit', type=str,
+ help="Path of pitch csv file.")
+ parser.add_argument('--shift', type=int, default=0,
+ help="Pitch shift key.")
+
+ parser.add_argument('--enable-retrieval', action="store_true",
+ help="Enable index feature retrieval")
+ parser.add_argument('--retrieval-index-prefix', default='',
+ help='retrieval index file prefix. Will load file %prefix%hubert.index/%prefix%whisper.index')
+ parser.add_argument('--retrieval-ratio', type=float, default=.5,
+ help="ratio of feature retrieval effect. Must be in range 0..1")
+ parser.add_argument('--n-retrieval-vectors', type=int, default=3,
+ help="get n nearest vectors from retrieval index. Works stably in range 1..3")
+ parser.add_argument('--hubert-index-path', required=False,
+ help='path to hubert index file. Default data_svc/indexes/speaker.../%prefix%hubert.index')
+ parser.add_argument('--whisper-index-path', required=False,
+ help='path to whisper index file. Default data_svc/indexes/speaker.../%prefix%whisper.index')
+
+ parser.add_argument('--debug', action="store_true")
+ args = parser.parse_args()
+
+ main(args)
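A hedged end-to-end sketch: convert test.wav to singer0047 with feature retrieval enabled. The flags mirror the argparse definitions above; the index files are expected under data_svc/indexes/<speaker>/ as produced by svc_train_retrieval.py, and the ratio/vector values are just the documented defaults spelled out.

    import os
    os.system("python svc_inference.py --config configs/base.yaml --model vits_pretrain/sovits5.0.pth "
              "--spk configs/singers/singer0047.npy --wave test.wav "
              "--enable-retrieval --retrieval-ratio 0.5 --n-retrieval-vectors 3")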
diff --git a/svc_inference_batch.py b/svc_inference_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f86c19ad55216f8f882f274bff61c76005edf1ad
--- /dev/null
+++ b/svc_inference_batch.py
@@ -0,0 +1,43 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+import tqdm
+import torch
+import argparse
+
+from whisper.inference import load_model, pred_ppg
+
+# How to use
+# python svc_inference_batch.py --config configs/base.yaml --model vits_pretrain/sovits5.0.pth --wave test_waves/ --spk configs/singers/singer0047.npy
+
+out_path = "./_svc_out"
+os.makedirs(out_path, exist_ok=True)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, required=True,
+ help="yaml file for config.")
+ parser.add_argument('--model', type=str, required=True,
+ help="path of model for evaluation")
+ parser.add_argument('--wave', type=str, required=True,
+ help="Path of raw audio.")
+ parser.add_argument('--spk', type=str, required=True,
+ help="Path of speaker.")
+ parser.add_argument('--shift', type=int, default=0,
+ help="Pitch shift key.")
+ args = parser.parse_args()
+ wave_path = args.wave
+ assert os.path.isdir(wave_path), f"{wave_path} is not a folder"
+ waves = [file for file in os.listdir(wave_path) if file.endswith(".wav")]
+ for file in waves:
+ print(file)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ whisper = load_model(os.path.join("whisper_pretrain", "large-v2.pt"), device=device)
+ for file in tqdm.tqdm(waves, desc="whisper"):
+ pred_ppg(whisper, f"{wave_path}/{file}", f"{out_path}/{file}.ppg.npy", device=device)
+ del whisper
+
+ for file in tqdm.tqdm(waves, desc="svc"):
+ os.system(
+ f"python svc_inference.py --config {args.config} --model {args.model} --wave {wave_path}/{file} --ppg {out_path}/{file}.ppg.npy --spk {args.spk} --shift {args.shift}")
+ os.system(f"mv svc_out.wav {out_path}/{file}")
+ os.system(f"rm {out_path}/{file}.ppg.npy")
diff --git a/svc_inference_post.py b/svc_inference_post.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d0bc24ffde5100dbe60889b844e1223d020f7b7
--- /dev/null
+++ b/svc_inference_post.py
@@ -0,0 +1,51 @@
+import sys, os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+import torch
+import librosa
+import argparse
+import numpy as np
+from scipy.io.wavfile import write
+from vad.utils import init_jit_model, get_speech_timestamps
+
+
+def load_audio(file: str, sr: int = 16000):
+ x, sr = librosa.load(file, sr=sr)
+ return x
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--ref', type=str, required=True,
+ help="Path of ref audio.")
+ parser.add_argument('--svc', type=str, required=True,
+ help="Path of svc audio.")
+ parser.add_argument('--out', type=str, required=True,
+ help="Path of out audio.")
+
+ args = parser.parse_args()
+ print("svc in wave :", args.ref)
+ print("svc out wave :", args.svc)
+ print("svc post wave :", args.out)
+
+ model = init_jit_model(os.path.join('vad/assets', 'silero_vad.jit'))
+ model.eval()
+
+ ref_wave = load_audio(args.ref, sr=16000)
+ tmp_wave = torch.from_numpy(ref_wave).squeeze(0)
+ tag_wave = get_speech_timestamps(
+ tmp_wave, model, threshold=0.2, sampling_rate=16000)
+
+ ref_wave[:] = 0
+ for tag in tag_wave:
+ ref_wave[tag["start"]:tag["end"]] = 1
+
+ ref_wave = np.repeat(ref_wave, 2, -1)
+ svc_wave = load_audio(args.svc, sr=32000)
+
+ min_len = min(len(ref_wave), len(svc_wave))
+ ref_wave = ref_wave[:min_len]
+ svc_wave = svc_wave[:min_len]
+ svc_wave[ref_wave == 0] = 0
+
+ write(args.out, 32000, svc_wave)
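A hedged usage sketch: the script builds a speech/non-speech mask from the reference audio with the Silero VAD model and mutes the converted audio wherever the reference is silent. svc_out.wav is the file written by svc_inference.py; the output filename below is a placeholder.

    import os
    os.system("python svc_inference_post.py --ref test.wav --svc svc_out.wav --out svc_out_post.wav")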
diff --git a/svc_inference_shift.py b/svc_inference_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aa74a90ebdbca5d4ed9b3185aca4b8494a35b1c
--- /dev/null
+++ b/svc_inference_shift.py
@@ -0,0 +1,102 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+import torch
+import argparse
+import numpy as np
+
+from omegaconf import OmegaConf
+from scipy.io.wavfile import write
+from pitch import load_csv_pitch
+from vits.models import SynthesizerInfer
+from svc_inference import load_svc_model, svc_infer
+from feature_retrieval import DummyRetrieval
+
+
+def main(args):
+ if (args.ppg == None):
+ args.ppg = "svc_tmp.ppg.npy"
+ print(
+ f"Auto run : python whisper/inference.py -w {args.wave} -p {args.ppg}")
+ os.system(f"python whisper/inference.py -w {args.wave} -p {args.ppg}")
+
+ if (args.vec == None):
+ args.vec = "svc_tmp.vec.npy"
+ print(
+ f"Auto run : python hubert/inference.py -w {args.wave} -v {args.vec}")
+ os.system(f"python hubert/inference.py -w {args.wave} -v {args.vec}")
+
+ if (args.pit == None):
+ args.pit = "svc_tmp.pit.csv"
+ print(
+ f"Auto run : python pitch/inference.py -w {args.wave} -p {args.pit}")
+ os.system(f"python pitch/inference.py -w {args.wave} -p {args.pit}")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ hp = OmegaConf.load(args.config)
+ model = SynthesizerInfer(
+ hp.data.filter_length // 2 + 1,
+ hp.data.segment_size // hp.data.hop_length,
+ hp)
+ load_svc_model(args.model, model)
+ model.eval()
+ model.to(device)
+
+ spk = np.load(args.spk)
+ spk = torch.FloatTensor(spk)
+
+ ppg = np.load(args.ppg)
+ ppg = np.repeat(ppg, 2, 0)
+ ppg = torch.FloatTensor(ppg)
+
+ vec = np.load(args.vec)
+ vec = np.repeat(vec, 2, 0)
+ vec = torch.FloatTensor(vec)
+
+ pit = load_csv_pitch(args.pit)
+
+ shift_l = args.shift_l
+ shift_r = args.shift_r
+
+ print(f"pitch shift: [{shift_l}, {shift_r}]")
+
+ for shift in range(shift_l, shift_r + 1):
+ print(shift)
+ tmp = np.array(pit)
+ tmp = tmp * (2 ** (shift / 12))
+ tmp = torch.FloatTensor(tmp)
+
+ out_audio = svc_infer(model, DummyRetrieval(), spk, tmp, ppg, vec, hp, device)
+ write(os.path.join("./_svc_out", f"svc_out_{shift}.wav"),
+ hp.data.sampling_rate, out_audio)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, required=True,
+ help="yaml file for config.")
+ parser.add_argument('--model', type=str, required=True,
+ help="path of model for evaluation")
+ parser.add_argument('--wave', type=str, required=True,
+ help="Path of raw audio.")
+ parser.add_argument('--spk', type=str, required=True,
+ help="Path of speaker.")
+ parser.add_argument('--ppg', type=str,
+ help="Path of content vector.")
+ parser.add_argument('--vec', type=str,
+ help="Path of hubert vector.")
+ parser.add_argument('--pit', type=str,
+ help="Path of pitch csv file.")
+ parser.add_argument('--shift_l', type=int, default=0,
+ help="Pitch shift key for [shift_l, shift_r]")
+ parser.add_argument('--shift_r', type=int, default=0,
+ help="Pitch shift key for [shift_l, shift_r]")
+ args = parser.parse_args()
+
+ assert args.shift_l >= -12
+ assert args.shift_r >= -12
+ assert args.shift_l <= 12
+ assert args.shift_r <= 12
+ assert args.shift_l <= args.shift_r
+
+ os.makedirs("./_svc_out", exist_ok=True)
+
+ main(args)
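A hedged usage sketch: render the same input at every semitone shift in [-3, 3]; one svc_out_<shift>.wav per key lands in ./_svc_out. Paths mirror the other usage examples in this patch.

    import os
    os.system("python svc_inference_shift.py --config configs/base.yaml --model vits_pretrain/sovits5.0.pth "
              "--spk configs/singers/singer0047.npy --wave test.wav --shift_l -3 --shift_r 3")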
diff --git a/svc_merge.py b/svc_merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..d84f6c1ddc0461530a590e7b815bfa69bf6366e4
--- /dev/null
+++ b/svc_merge.py
@@ -0,0 +1,58 @@
+import os
+import torch
+import argparse
+import collections
+
+
+def load_model(checkpoint_path):
+ assert os.path.isfile(checkpoint_path)
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+ saved_state_dict = checkpoint_dict["model_g"]
+ return saved_state_dict
+
+
+def save_model(state_dict, checkpoint_path):
+ torch.save({'model_g': state_dict}, checkpoint_path)
+
+
+def average_model(model_list):
+ model_keys = list(model_list[0].keys())
+ model_average = collections.OrderedDict()
+ for key in model_keys:
+ key_sum = 0
+ for i in range(len(model_list)):
+ key_sum = (key_sum + model_list[i][key])
+ model_average[key] = torch.div(key_sum, float(len(model_list)))
+ return model_average
+# ss_list = []
+# ss_list.append(s1)
+# ss_list.append(s2)
+# ss_merge = average_model(ss_list)
+
+
+def merge_model(model1, model2, rate):
+ model_keys = model1.keys()
+ model_merge = collections.OrderedDict()
+ for key in model_keys:
+ key_merge = rate * model1[key] + (1 - rate) * model2[key]
+ model_merge[key] = key_merge
+ return model_merge
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-m1', '--model1', type=str, required=True)
+ parser.add_argument('-m2', '--model2', type=str, required=True)
+ parser.add_argument('-r1', '--rate', type=float, required=True)
+ args = parser.parse_args()
+
+ print(args.model1)
+ print(args.model2)
+ print(args.rate)
+
+ assert args.rate > 0 and args.rate < 1, f"{args.rate} should be in range (0, 1)"
+ s1 = load_model(args.model1)
+ s2 = load_model(args.model2)
+
+ merge = merge_model(s1, s2, args.rate)
+ save_model(merge, "sovits5.0_merge.pth")
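A hedged usage sketch: blend two generator checkpoints with a 60/40 weighting; the result is written to sovits5.0_merge.pth. Both checkpoint paths below are placeholders.

    import os
    os.system("python svc_merge.py -m1 chkpt/sovits5.0/model_A.pt -m2 chkpt/sovits5.0/model_B.pt -r1 0.6")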
diff --git a/svc_preprocessing.py b/svc_preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ae5c2be49f987aba20d6d17cb7c5a8fa22ed878
--- /dev/null
+++ b/svc_preprocessing.py
@@ -0,0 +1,34 @@
+import os
+import torch
+import argparse
+import subprocess
+
+assert torch.cuda.is_available(), "\033[31m You need GPU to Train! \033[0m"
+print("CPU Count is :", os.cpu_count())
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-t", type=int, default=0, help="thread count")
+args = parser.parse_args()
+
+
+commands = [
+ "python prepare/preprocess_a.py -w ./dataset_raw -o ./data_svc/waves-16k -s 16000 -t 0",
+ "python prepare/preprocess_a.py -w ./dataset_raw -o ./data_svc/waves-32k -s 32000 -t 0",
+ "python prepare/preprocess_crepe.py -w data_svc/waves-16k/ -p data_svc/pitch",
+ "python prepare/preprocess_ppg.py -w data_svc/waves-16k/ -p data_svc/whisper",
+ "python prepare/preprocess_hubert.py -w data_svc/waves-16k/ -v data_svc/hubert",
+ "python prepare/preprocess_speaker.py data_svc/waves-16k/ data_svc/speaker -t 0",
+ "python prepare/preprocess_speaker_ave.py data_svc/speaker/ data_svc/singer",
+ "python prepare/preprocess_spec.py -w data_svc/waves-32k/ -s data_svc/specs -t 0",
+ "python prepare/preprocess_train.py",
+ "python prepare/preprocess_zzz.py",
+]
+
+
+for command in commands:
+ print(f"Command: {command}")
+
+ process = subprocess.Popen(command, shell=True)
+ outcode = process.wait()
+ if (outcode):
+ break
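A hedged usage sketch: run the whole preprocessing pipeline, with -t forwarded to the steps that accept a thread/worker count. The value 8 below is just an example.

    import os
    os.system("python svc_preprocessing.py -t 8")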
diff --git a/svc_train_retrieval.py b/svc_train_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b8e293533ad5f29792e687db6fcc9e7ef9834ff
--- /dev/null
+++ b/svc_train_retrieval.py
@@ -0,0 +1,114 @@
+import argparse
+import logging
+import multiprocessing
+from functools import partial
+from pathlib import Path
+
+import faiss
+
+from feature_retrieval import (
+ train_index,
+ FaissIVFFlatTrainableFeatureIndexBuilder,
+ OnConditionFeatureTransform,
+ MinibatchKmeansFeatureTransform,
+ DummyFeatureTransform,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def get_speaker_list(base_path: Path):
+ speakers_path = base_path / "waves-16k"
+ if not speakers_path.exists():
+ raise FileNotFoundError(f"path {speakers_path} does not exist")
+ return [speaker_dir.name for speaker_dir in speakers_path.iterdir() if speaker_dir.is_dir()]
+
+
+def create_indexes_path(base_path: Path) -> Path:
+ indexes_path = base_path / "indexes"
+ logger.info("create indexes folder %s", indexes_path)
+ indexes_path.mkdir(exist_ok=True)
+ return indexes_path
+
+
+def create_index(
+ feature_name: str,
+ prefix: str,
+ speaker: str,
+ base_path: Path,
+ indexes_path: Path,
+ compress_features_after: int,
+ n_clusters: int,
+ n_parallel: int,
+ train_batch_size: int = 8192,
+) -> None:
+ features_path = base_path / feature_name / speaker
+ if not features_path.exists():
+ raise ValueError(f'features not found by path {features_path}')
+ index_path = indexes_path / speaker
+ index_path.mkdir(exist_ok=True)
+ index_filename = f"{prefix}{feature_name}.index"
+ index_filepath = index_path / index_filename
+ logger.debug('index will be saved to %s', index_filepath)
+
+ builder = FaissIVFFlatTrainableFeatureIndexBuilder(train_batch_size, distance=faiss.METRIC_L2)
+ transform = OnConditionFeatureTransform(
+ condition=lambda matrix: matrix.shape[0] > compress_features_after,
+ on_condition=MinibatchKmeansFeatureTransform(n_clusters, n_parallel),
+ otherwise=DummyFeatureTransform()
+ )
+ train_index(features_path, index_filepath, builder, transform)
+
+
+def main() -> None:
+ arg_parser = argparse.ArgumentParser("crate faiss indexes for feature retrieval")
+ arg_parser.add_argument("--debug", action="store_true")
+ arg_parser.add_argument("--prefix", default='', help="add prefix to index filename")
+ arg_parser.add_argument('--speakers', nargs="+",
+ help="speaker names to create an index. By default all speakers are from data_svc")
+ arg_parser.add_argument("--compress-features-after", type=int, default=200_000,
+ help="If the number of features is greater than the value compress "
+ "feature vectors using MiniBatchKMeans.")
+ arg_parser.add_argument("--n-clusters", type=int, default=10_000,
+ help="Number of centroids to which features will be compressed")
+
+ arg_parser.add_argument("--n-parallel", type=int, default=multiprocessing.cpu_count()-1,
+ help="Nuber of parallel job of MinibatchKmeans. Default is cpus-1")
+ args = arg_parser.parse_args()
+
+ if args.debug:
+ logging.basicConfig(level=logging.DEBUG)
+ else:
+ logging.basicConfig(level=logging.INFO)
+
+ base_path = Path(".").absolute() / "data_svc"
+ if args.speakers:
+ speakers = args.speakers
+ else:
+ speakers = get_speaker_list(base_path)
+
+ logger.info("got %s speakers: %s", len(speakers), speakers)
+ indexes_path = create_indexes_path(base_path)
+
+ create_index_func = partial(
+ create_index,
+ prefix=args.prefix,
+ base_path=base_path,
+ indexes_path=indexes_path,
+ compress_features_after=args.compress_features_after,
+ n_clusters=args.n_clusters,
+ n_parallel=args.n_parallel,
+ )
+
+ for speaker in speakers:
+ logger.info("create hubert index for speaker %s", speaker)
+ create_index_func(feature_name="hubert", speaker=speaker)
+
+ logger.info("create whisper index for speaker %s", speaker)
+ create_index_func(feature_name="whisper", speaker=speaker)
+
+ logger.info("done!")
+
+
+if __name__ == '__main__':
+ main()
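A hedged usage sketch: build hubert and whisper indexes for two speaker folders only, with a filename prefix. The speaker names below are placeholders; by default every folder under data_svc/waves-16k is indexed.

    import os
    os.system("python svc_train_retrieval.py --speakers speaker0 speaker1 --prefix my_ --n-clusters 10000")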
diff --git a/svc_trainer.py b/svc_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a42186d379936b67d5dcd7c3049cc5185b24e454
--- /dev/null
+++ b/svc_trainer.py
@@ -0,0 +1,43 @@
+import sys,os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+import argparse
+import torch
+import torch.multiprocessing as mp
+from omegaconf import OmegaConf
+
+from vits_extend.train import train
+
+torch.backends.cudnn.benchmark = True
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-c', '--config', type=str, required=True,
+ help="yaml file for configuration")
+ parser.add_argument('-p', '--checkpoint_path', type=str, default=None,
+ help="path of checkpoint pt file to resume training")
+ parser.add_argument('-n', '--name', type=str, required=True,
+ help="name of the model for logging, saving checkpoint")
+ args = parser.parse_args()
+
+ hp = OmegaConf.load(args.config)
+ with open(args.config, 'r') as f:
+ hp_str = ''.join(f.readlines())
+
+ assert hp.data.hop_length == 320, \
+ 'hp.data.hop_length must be equal to 320, got %d' % hp.data.hop_length
+
+ args.num_gpus = 0
+ torch.manual_seed(hp.train.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(hp.train.seed)
+ args.num_gpus = torch.cuda.device_count()
+ print('Batch size per GPU :', hp.train.batch_size)
+
+ if args.num_gpus > 1:
+ mp.spawn(train, nprocs=args.num_gpus,
+ args=(args, args.checkpoint_path, hp, hp_str,))
+ else:
+ train(0, args, args.checkpoint_path, hp, hp_str)
+ else:
+ print('No GPU found!')
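A hedged usage sketch: start training with the base config, or resume from an earlier checkpoint; the resume path below is a placeholder.

    import os
    os.system("python svc_trainer.py -c configs/base.yaml -n sovits5.0")
    # resume: os.system("python svc_trainer.py -c configs/base.yaml -n sovits5.0 -p chkpt/sovits5.0/sovits5.0_1000.pt")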
diff --git a/test.wav b/test.wav
new file mode 100644
index 0000000000000000000000000000000000000000..290a4c3dde752c39de9575dbbd7b194a0c09f274
--- /dev/null
+++ b/test.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b9d96b70f6ec6a72410bff26b71004010c203c8010fd6fb0c60c4acd53fd2ec
+size 4849732
diff --git a/vad/LICENSE b/vad/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..0bf5e90cac691b999d4a35044f97167d7bbbf0b9
--- /dev/null
+++ b/vad/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020-present Silero Team
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/vad/assets/silero_vad.jit b/vad/assets/silero_vad.jit
new file mode 100644
index 0000000000000000000000000000000000000000..2a0958e90969784c90489cea36ab538a7b44384b
--- /dev/null
+++ b/vad/assets/silero_vad.jit
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:99033608562094bbb44e2363198cd47647a668f846c4c9a9edde68b4800b5fd4
+size 1439299
diff --git a/vad/utils.py b/vad/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..faf56bddcd5012e5a9877a2ba5bf3fd6db2d2df3
--- /dev/null
+++ b/vad/utils.py
@@ -0,0 +1,533 @@
+import torch
+import torchaudio
+from typing import Callable, List
+import torch.nn.functional as F
+import warnings
+
+languages = ['ru', 'en', 'de', 'es']
+
+
+class OnnxWrapper():
+
+ def __init__(self, path, force_onnx_cpu=False):
+ import numpy as np
+ global np
+ import onnxruntime
+
+ opts = onnxruntime.SessionOptions()
+ opts.inter_op_num_threads = 1
+ opts.intra_op_num_threads = 1
+
+ if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+ self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
+ else:
+ self.session = onnxruntime.InferenceSession(path, sess_options=opts)
+
+ self.reset_states()
+ self.sample_rates = [8000, 16000]
+
+ def _validate_input(self, x, sr: int):
+ if x.dim() == 1:
+ x = x.unsqueeze(0)
+ if x.dim() > 2:
+ raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
+
+ if sr != 16000 and (sr % 16000 == 0):
+ step = sr // 16000
+ x = x[:,::step]
+ sr = 16000
+
+ if sr not in self.sample_rates:
+ raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
+
+ if sr / x.shape[1] > 31.25:
+ raise ValueError("Input audio chunk is too short")
+
+ return x, sr
+
+ def reset_states(self, batch_size=1):
+ self._h = np.zeros((2, batch_size, 64)).astype('float32')
+ self._c = np.zeros((2, batch_size, 64)).astype('float32')
+ self._last_sr = 0
+ self._last_batch_size = 0
+
+ def __call__(self, x, sr: int):
+
+ x, sr = self._validate_input(x, sr)
+ batch_size = x.shape[0]
+
+ if not self._last_batch_size:
+ self.reset_states(batch_size)
+ if (self._last_sr) and (self._last_sr != sr):
+ self.reset_states(batch_size)
+ if (self._last_batch_size) and (self._last_batch_size != batch_size):
+ self.reset_states(batch_size)
+
+ if sr in [8000, 16000]:
+ ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
+ ort_outs = self.session.run(None, ort_inputs)
+ out, self._h, self._c = ort_outs
+ else:
+ raise ValueError()
+
+ self._last_sr = sr
+ self._last_batch_size = batch_size
+
+ out = torch.tensor(out)
+ return out
+
+ def audio_forward(self, x, sr: int, num_samples: int = 512):
+ outs = []
+ x, sr = self._validate_input(x, sr)
+
+ if x.shape[1] % num_samples:
+ pad_num = num_samples - (x.shape[1] % num_samples)
+ x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
+
+ self.reset_states(x.shape[0])
+ for i in range(0, x.shape[1], num_samples):
+ wavs_batch = x[:, i:i+num_samples]
+ out_chunk = self.__call__(wavs_batch, sr)
+ outs.append(out_chunk)
+
+ stacked = torch.cat(outs, dim=1)
+ return stacked.cpu()
+
+
+class Validator():
+ def __init__(self, url, force_onnx_cpu):
+ self.onnx = True if url.endswith('.onnx') else False
+ torch.hub.download_url_to_file(url, 'inf.model')
+ if self.onnx:
+ import onnxruntime
+ if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+ self.model = onnxruntime.InferenceSession('inf.model', providers=['CPUExecutionProvider'])
+ else:
+ self.model = onnxruntime.InferenceSession('inf.model')
+ else:
+ self.model = init_jit_model(model_path='inf.model')
+
+ def __call__(self, inputs: torch.Tensor):
+ with torch.no_grad():
+ if self.onnx:
+ ort_inputs = {'input': inputs.cpu().numpy()}
+ outs = self.model.run(None, ort_inputs)
+ outs = [torch.Tensor(x) for x in outs]
+ else:
+ outs = self.model(inputs)
+
+ return outs
+
+
+def read_audio(path: str,
+ sampling_rate: int = 16000):
+
+ wav, sr = torchaudio.load(path)
+
+ if wav.size(0) > 1:
+ wav = wav.mean(dim=0, keepdim=True)
+
+ if sr != sampling_rate:
+ transform = torchaudio.transforms.Resample(orig_freq=sr,
+ new_freq=sampling_rate)
+ wav = transform(wav)
+ sr = sampling_rate
+
+ assert sr == sampling_rate
+ return wav.squeeze(0)
+
+
+def save_audio(path: str,
+ tensor: torch.Tensor,
+ sampling_rate: int = 16000):
+ torchaudio.save(path, tensor.unsqueeze(0), sampling_rate, bits_per_sample=16)
+
+
+def init_jit_model(model_path: str,
+ device=torch.device('cpu')):
+ torch.set_grad_enabled(False)
+ model = torch.jit.load(model_path, map_location=device)
+ model.eval()
+ return model
+
+
+def make_visualization(probs, step):
+ import pandas as pd
+ pd.DataFrame({'probs': probs},
+ index=[x * step for x in range(len(probs))]).plot(figsize=(16, 8),
+ kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
+ xlabel='seconds',
+ ylabel='speech probability',
+ colormap='tab20')
+
+
+def get_speech_timestamps(audio: torch.Tensor,
+ model,
+ threshold: float = 0.5,
+ sampling_rate: int = 16000,
+ min_speech_duration_ms: int = 250,
+ max_speech_duration_s: float = float('inf'),
+ min_silence_duration_ms: int = 100,
+ window_size_samples: int = 512,
+ speech_pad_ms: int = 30,
+ return_seconds: bool = False,
+ visualize_probs: bool = False,
+ progress_tracking_callback: Callable[[float], None] = None):
+
+ """
+ This method is used for splitting long audios into speech chunks using silero VAD
+
+ Parameters
+ ----------
+ audio: torch.Tensor, one dimensional
+ One dimensional float torch.Tensor; other types are cast to torch.Tensor if possible
+
+ model: preloaded .jit silero VAD model
+
+ threshold: float (default - 0.5)
+ Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
+ It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
+
+ sampling_rate: int (default - 16000)
+ Currently silero VAD models support 8000 and 16000 sample rates
+
+ min_speech_duration_ms: int (default - 250 milliseconds)
+ Final speech chunks shorter than min_speech_duration_ms are thrown out
+
+ max_speech_duration_s: int (default - inf)
+ Maximum duration of speech chunks in seconds
+ Chunks longer than max_speech_duration_s will be split at the timestamp of the last silence that lasts more than 100ms (if any), to prevent aggressive cutting.
+ Otherwise, they will be split aggressively just before max_speech_duration_s.
+
+ min_silence_duration_ms: int (default - 100 milliseconds)
+ In the end of each speech chunk wait for min_silence_duration_ms before separating it
+
+ window_size_samples: int (default - 512 samples)
+ Audio chunks of window_size_samples size are fed to the silero VAD model.
+ WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples for 8000 sample rate.
+ Values other than these may affect model performance!!
+
+ speech_pad_ms: int (default - 30 milliseconds)
+ Final speech chunks are padded by speech_pad_ms each side
+
+ return_seconds: bool (default - False)
+ whether to return timestamps in seconds (default - samples)
+
+ visualize_probs: bool (default - False)
+ whether to draw the probability histogram or not
+
+ progress_tracking_callback: Callable[[float], None] (default - None)
+ callback function taking progress in percents as an argument
+
+ Returns
+ ----------
+ speeches: list of dicts
+ list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
+ """
+
+ if not torch.is_tensor(audio):
+ try:
+ audio = torch.Tensor(audio)
+ except:
+ raise TypeError("Audio cannot be casted to tensor. Cast it manually")
+
+ if len(audio.shape) > 1:
+ for i in range(len(audio.shape)): # trying to squeeze empty dimensions
+ audio = audio.squeeze(0)
+ if len(audio.shape) > 1:
+ raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")
+
+ if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
+ step = sampling_rate // 16000
+ sampling_rate = 16000
+ audio = audio[::step]
+ warnings.warn('Sampling rate is a multiple of 16000, casting to 16000 manually!')
+ else:
+ step = 1
+
+ if sampling_rate == 8000 and window_size_samples > 768:
+ warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
+ if window_size_samples not in [256, 512, 768, 1024, 1536]:
+ warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')
+
+ model.reset_states()
+ min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
+ speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+ max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
+ min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+ min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+
+ audio_length_samples = len(audio)
+
+ speech_probs = []
+ for current_start_sample in range(0, audio_length_samples, window_size_samples):
+ chunk = audio[current_start_sample: current_start_sample + window_size_samples]
+ if len(chunk) < window_size_samples:
+ chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
+ speech_prob = model(chunk, sampling_rate).item()
+ speech_probs.append(speech_prob)
+ # calculate progress and send it to the callback function
+ progress = current_start_sample + window_size_samples
+ if progress > audio_length_samples:
+ progress = audio_length_samples
+ progress_percent = (progress / audio_length_samples) * 100
+ if progress_tracking_callback:
+ progress_tracking_callback(progress_percent)
+
+ triggered = False
+ speeches = []
+ current_speech = {}
+ neg_threshold = threshold - 0.15
+ temp_end = 0 # to save potential segment end (and tolerate some silence)
+ prev_end = next_start = 0 # to save potential segment limits in case of maximum segment size reached
+
+ for i, speech_prob in enumerate(speech_probs):
+ if (speech_prob >= threshold) and temp_end:
+ temp_end = 0
+ if next_start < prev_end:
+ next_start = window_size_samples * i
+
+ if (speech_prob >= threshold) and not triggered:
+ triggered = True
+ current_speech['start'] = window_size_samples * i
+ continue
+
+ if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
+ if prev_end:
+ current_speech['end'] = prev_end
+ speeches.append(current_speech)
+ current_speech = {}
+ if next_start < prev_end: # previously reached silence (< neg_thres) and is still not speech (< thres)
+ triggered = False
+ else:
+ current_speech['start'] = next_start
+ prev_end = next_start = temp_end = 0
+ else:
+ current_speech['end'] = window_size_samples * i
+ speeches.append(current_speech)
+ current_speech = {}
+ prev_end = next_start = temp_end = 0
+ triggered = False
+ continue
+
+ if (speech_prob < neg_threshold) and triggered:
+ if not temp_end:
+ temp_end = window_size_samples * i
+            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # avoid using a very short silence as the split point
+ prev_end = temp_end
+ if (window_size_samples * i) - temp_end < min_silence_samples:
+ continue
+ else:
+ current_speech['end'] = temp_end
+ if (current_speech['end'] - current_speech['start']) > min_speech_samples:
+ speeches.append(current_speech)
+ current_speech = {}
+ prev_end = next_start = temp_end = 0
+ triggered = False
+ continue
+
+ if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
+ current_speech['end'] = audio_length_samples
+ speeches.append(current_speech)
+
+ for i, speech in enumerate(speeches):
+ if i == 0:
+ speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
+ if i != len(speeches) - 1:
+ silence_duration = speeches[i+1]['start'] - speech['end']
+ if silence_duration < 2 * speech_pad_samples:
+ speech['end'] += int(silence_duration // 2)
+ speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
+ else:
+ speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
+ speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
+ else:
+ speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
+
+ if return_seconds:
+ for speech_dict in speeches:
+ speech_dict['start'] = round(speech_dict['start'] / sampling_rate, 1)
+ speech_dict['end'] = round(speech_dict['end'] / sampling_rate, 1)
+ elif step > 1:
+ for speech_dict in speeches:
+ speech_dict['start'] *= step
+ speech_dict['end'] *= step
+
+ if visualize_probs:
+ make_visualization(speech_probs, window_size_samples / sampling_rate)
+
+ return speeches
+
+
+def get_number_ts(wav: torch.Tensor,
+ model,
+ model_stride=8,
+ hop_length=160,
+ sample_rate=16000):
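+    # number-detector head: expand the strided per-frame predictions back to the hop-length
+    # frame rate, then collect start/end timings (in milliseconds) of the detected number regions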
+ wav = torch.unsqueeze(wav, dim=0)
+ perframe_logits = model(wav)[0]
+ perframe_preds = torch.argmax(torch.softmax(perframe_logits, dim=1), dim=1).squeeze() # (1, num_frames_strided)
+ extended_preds = []
+ for i in perframe_preds:
+ extended_preds.extend([i.item()] * model_stride)
+ # len(extended_preds) is *num_frames_real*; for each frame of audio we know if it has a number in it.
+ triggered = False
+ timings = []
+ cur_timing = {}
+ for i, pred in enumerate(extended_preds):
+ if pred == 1:
+ if not triggered:
+ cur_timing['start'] = int((i * hop_length) / (sample_rate / 1000))
+ triggered = True
+ elif pred == 0:
+ if triggered:
+ cur_timing['end'] = int((i * hop_length) / (sample_rate / 1000))
+ timings.append(cur_timing)
+ cur_timing = {}
+ triggered = False
+ if cur_timing:
+        cur_timing['end'] = int(wav.shape[-1] / (sample_rate / 1000))  # total duration in ms (wav is [1, T] after unsqueeze)
+ timings.append(cur_timing)
+ return timings
+
+
+def get_language(wav: torch.Tensor,
+ model):
+ wav = torch.unsqueeze(wav, dim=0)
+ lang_logits = model(wav)[2]
+ lang_pred = torch.argmax(torch.softmax(lang_logits, dim=1), dim=1).item() # from 0 to len(languages) - 1
+ assert lang_pred < len(languages)
+ return languages[lang_pred]
+
+
+def get_language_and_group(wav: torch.Tensor,
+ model,
+ lang_dict: dict,
+ lang_group_dict: dict,
+ top_n=1):
+ wav = torch.unsqueeze(wav, dim=0)
+ lang_logits, lang_group_logits = model(wav)
+
+ softm = torch.softmax(lang_logits, dim=1).squeeze()
+ softm_group = torch.softmax(lang_group_logits, dim=1).squeeze()
+
+ srtd = torch.argsort(softm, descending=True)
+ srtd_group = torch.argsort(softm_group, descending=True)
+
+ outs = []
+ outs_group = []
+ for i in range(top_n):
+ prob = round(softm[srtd[i]].item(), 2)
+ prob_group = round(softm_group[srtd_group[i]].item(), 2)
+ outs.append((lang_dict[str(srtd[i].item())], prob))
+ outs_group.append((lang_group_dict[str(srtd_group[i].item())], prob_group))
+
+ return outs, outs_group
+
+
+class VADIterator:
+ def __init__(self,
+ model,
+ threshold: float = 0.5,
+ sampling_rate: int = 16000,
+ min_silence_duration_ms: int = 100,
+ speech_pad_ms: int = 30
+ ):
+
+ """
+ Class for stream imitation
+
+ Parameters
+ ----------
+ model: preloaded .jit silero VAD model
+
+ threshold: float (default - 0.5)
+ Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
+ It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
+
+ sampling_rate: int (default - 16000)
+ Currently silero VAD models support 8000 and 16000 sample rates
+
+ min_silence_duration_ms: int (default - 100 milliseconds)
+            At the end of each speech chunk, wait for min_silence_duration_ms before separating it
+
+ speech_pad_ms: int (default - 30 milliseconds)
+            Final speech chunks are padded by speech_pad_ms on each side
+ """
+
+ self.model = model
+ self.threshold = threshold
+ self.sampling_rate = sampling_rate
+
+ if sampling_rate not in [8000, 16000]:
+ raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
+
+ self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+ self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+ self.reset_states()
+
+ def reset_states(self):
+
+ self.model.reset_states()
+ self.triggered = False
+ self.temp_end = 0
+ self.current_sample = 0
+
+ def __call__(self, x, return_seconds=False):
+ """
+ x: torch.Tensor
+ audio chunk (see examples in repo)
+
+ return_seconds: bool (default - False)
+            whether to return timestamps in seconds (default - samples)
+ """
+
+ if not torch.is_tensor(x):
+ try:
+ x = torch.Tensor(x)
+            except Exception:
+                raise TypeError("Audio cannot be cast to a tensor. Cast it manually")
+
+ window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
+ self.current_sample += window_size_samples
+
+ speech_prob = self.model(x, self.sampling_rate).item()
+
+ if (speech_prob >= self.threshold) and self.temp_end:
+ self.temp_end = 0
+
+ if (speech_prob >= self.threshold) and not self.triggered:
+ self.triggered = True
+ speech_start = self.current_sample - self.speech_pad_samples
+ return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
+
+ if (speech_prob < self.threshold - 0.15) and self.triggered:
+ if not self.temp_end:
+ self.temp_end = self.current_sample
+ if self.current_sample - self.temp_end < self.min_silence_samples:
+ return None
+ else:
+ speech_end = self.temp_end + self.speech_pad_samples
+ self.temp_end = 0
+ self.triggered = False
+ return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
+
+ return None
+
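+# A minimal streaming sketch for VADIterator (assumes a preloaded silero VAD model and 16 kHz mono audio):
+#     vad_iterator = VADIterator(model)
+#     for chunk in wav.split(512):              # 512-sample windows at 16 kHz
+#         event = vad_iterator(chunk)           # None, {'start': sample} or {'end': sample}
+#         if event is not None:
+#             print(event)
+#     vad_iterator.reset_states()               # reset between audio files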
+
+def collect_chunks(tss: List[dict],
+ wav: torch.Tensor):
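+    # concatenate the speech regions described by tss (e.g. the output of get_speech_timestamps)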
+ chunks = []
+ for i in tss:
+ chunks.append(wav[i['start']: i['end']])
+ return torch.cat(chunks)
+
+
+def drop_chunks(tss: List[dict],
+ wav: torch.Tensor):
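+    # keep only the audio outside the tss regions, i.e. drop the detected speech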
+ chunks = []
+ cur_start = 0
+ for i in tss:
+ chunks.append((wav[cur_start: i['start']]))
+ cur_start = i['end']
+ return torch.cat(chunks)
diff --git a/vits/LICENSE b/vits/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..6a6c3181fcdc4e20901a6ecbee5a406b78a5b560
--- /dev/null
+++ b/vits/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Jaehyeon Kim
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/vits/__init__.py b/vits/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vits/attentions.py b/vits/attentions.py
new file mode 100644
index 0000000000000000000000000000000000000000..26624519b01497cc402dea5f860cf5022d1e7c89
--- /dev/null
+++ b/vits/attentions.py
@@ -0,0 +1,416 @@
+import copy
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from vits import commons
+from vits.modules import LayerNorm
+
+
+class Encoder(nn.Module):
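+    # Stack of relative-position self-attention + FFN blocks (window_size controls the relative
+    # attention span); used as the backbone of the TextEncoder in vits/models.py.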
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ window_size=4,
+ **kwargs
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+
+ self.drop = nn.Dropout(p_dropout)
+ self.attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ window_size=window_size,
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask):
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size=1,
+ p_dropout=0.0,
+ proximal_bias=False,
+ proximal_init=True,
+ **kwargs
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+
+ self.drop = nn.Dropout(p_dropout)
+ self.self_attn_layers = nn.ModuleList()
+ self.norm_layers_0 = nn.ModuleList()
+ self.encdec_attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.self_attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels,
+ hidden_channels,
+ n_heads,
+ p_dropout=p_dropout,
+ proximal_bias=proximal_bias,
+ proximal_init=proximal_init,
+ )
+ )
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
+ self.encdec_attn_layers.append(
+ MultiHeadAttention(
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+ )
+ )
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(
+ hidden_channels,
+ hidden_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=p_dropout,
+ causal=True,
+ )
+ )
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+ def forward(self, x, x_mask, h, h_mask):
+ """
+ x: decoder input
+ h: encoder output
+ """
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+ device=x.device, dtype=x.dtype
+ )
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+ x = x * x_mask
+ for i in range(self.n_layers):
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_0[i](x + y)
+
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+ y = self.drop(y)
+ x = self.norm_layers_1[i](x + y)
+
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = self.norm_layers_2[i](x + y)
+ x = x * x_mask
+ return x
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(
+ self,
+ channels,
+ out_channels,
+ n_heads,
+ p_dropout=0.0,
+ window_size=None,
+ heads_share=True,
+ block_length=None,
+ proximal_bias=False,
+ proximal_init=False,
+ ):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.block_length = block_length
+ self.proximal_bias = proximal_bias
+ self.proximal_init = proximal_init
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = nn.Conv1d(channels, channels, 1)
+ self.conv_k = nn.Conv1d(channels, channels, 1)
+ self.conv_v = nn.Conv1d(channels, channels, 1)
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels**-0.5
+ self.emb_rel_k = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+ self.emb_rel_v = nn.Parameter(
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+ * rel_stddev
+ )
+
+ nn.init.xavier_uniform_(self.conv_q.weight)
+ nn.init.xavier_uniform_(self.conv_k.weight)
+ nn.init.xavier_uniform_(self.conv_v.weight)
+ if proximal_init:
+ with torch.no_grad():
+ self.conv_k.weight.copy_(self.conv_q.weight)
+ self.conv_k.bias.copy_(self.conv_q.bias)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+ if self.window_size is not None:
+ assert (
+ t_s == t_t
+ ), "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(
+ query / math.sqrt(self.k_channels), key_relative_embeddings
+ )
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(
+ device=scores.device, dtype=scores.dtype
+ )
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ if self.block_length is not None:
+ assert (
+ t_s == t_t
+ ), "Local attention is only available for self-attention."
+ block_mask = (
+ torch.ones_like(scores)
+ .triu(-self.block_length)
+ .tril(self.block_length)
+ )
+ scores = scores.masked_fill(block_mask == 0, -1e4)
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(
+ self.emb_rel_v, t_s
+ )
+ output = output + self._matmul_with_relative_values(
+ relative_weights, value_relative_embeddings
+ )
+ output = (
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ """
+ x: [b, h, l, m]
+ y: [h or 1, m, d]
+ ret: [b, h, l, d]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ """
+ x: [b, h, l, d]
+ y: [h or 1, m, d]
+ ret: [b, h, l, m]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ max_relative_position = 2 * self.window_size + 1
+ # Pad first before slice to avoid using cond ops.
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = F.pad(
+ relative_embeddings,
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+ )
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[
+ :, slice_start_position:slice_end_position
+ ]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
+ """
+ x: [b, h, l, 2*l-1]
+ ret: [b, h, l, l]
+ """
+ batch, heads, length, _ = x.size()
+ # Concat columns of pad to shift from relative to absolute indexing.
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = F.pad(
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+ )
+
+ # Reshape and slice out the padded elements.
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+ :, :, :length, length - 1 :
+ ]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
+ """
+ x: [b, h, l, l]
+ ret: [b, h, l, 2*l-1]
+ """
+ batch, heads, length, _ = x.size()
+ # padd along column
+ x = F.pad(
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+ )
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+ # add 0's in the beginning that will skew the elements after reshape
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ """Bias for self-attention to encourage attention to close positions.
+ Args:
+ length: an integer scalar.
+ Returns:
+ a Tensor with shape [1, 1, length, length]
+ """
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ filter_channels,
+ kernel_size,
+ p_dropout=0.0,
+ activation=None,
+ causal=False,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.activation = activation
+ self.causal = causal
+
+ if causal:
+ self.padding = self._causal_padding
+ else:
+ self.padding = self._same_padding
+
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+ self.drop = nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(self.padding(x * x_mask))
+ if self.activation == "gelu":
+ x = x * torch.sigmoid(1.702 * x)
+ else:
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(self.padding(x * x_mask))
+ return x * x_mask
+
+ def _causal_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = self.kernel_size - 1
+ pad_r = 0
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, commons.convert_pad_shape(padding))
+ return x
+
+ def _same_padding(self, x):
+ if self.kernel_size == 1:
+ return x
+ pad_l = (self.kernel_size - 1) // 2
+ pad_r = self.kernel_size // 2
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+ x = F.pad(x, commons.convert_pad_shape(padding))
+ return x
diff --git a/vits/commons.py b/vits/commons.py
new file mode 100644
index 0000000000000000000000000000000000000000..045a538d5a3ef8033eca70639a894346b11d5f61
--- /dev/null
+++ b/vits/commons.py
@@ -0,0 +1,187 @@
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def slice_pitch_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, idx_str:idx_end]
+ return ret
+
+
+def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size + 1
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size)
+ return ret, ret_pitch, ids_str
+
+
+def rand_spec_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+ """KL(P||Q)"""
+ kl = (logs_q - logs_p) - 0.5
+ kl += (
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+ )
+ return kl
+
+
+def rand_gumbel(shape):
+ """Sample from the Gumbel distribution, protect from overflows."""
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+ return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+ return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
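+    # gather x[i, :, ids_str[i]:ids_str[i] + segment_size] for every element in the batch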
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size + 1
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+ position = torch.arange(length, dtype=torch.float)
+ num_timescales = channels // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+ num_timescales - 1
+ )
+ inv_timescales = min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+ )
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
+ signal = signal.view(1, channels, length)
+ return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+ return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def shift_1d(x):
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+ return x
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+ """
+ duration: [b, 1, t_x]
+ mask: [b, 1, t_y, t_x]
+ """
+ device = duration.device
+
+ b, _, t_y, t_x = mask.shape
+ cum_duration = torch.cumsum(duration, -1)
+
+ cum_duration_flat = cum_duration.view(b * t_x)
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+ path = path.view(b, t_x, t_y)
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+ path = path.unsqueeze(1).transpose(2, 3) * mask
+ return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ norm_type = float(norm_type)
+ if clip_value is not None:
+ clip_value = float(clip_value)
+
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ if clip_value is not None:
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
+ total_norm = total_norm ** (1.0 / norm_type)
+ return total_norm
diff --git a/vits/data_utils.py b/vits/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb9c6635f7287ffa7307893b210680a65754c898
--- /dev/null
+++ b/vits/data_utils.py
@@ -0,0 +1,325 @@
+import os
+import numpy as np
+import random
+import torch
+import torch.utils.data
+
+
+from vits.utils import load_wav_to_torch
+
+
+def load_filepaths(filename, split="|"):
+ with open(filename, encoding='utf-8') as f:
+ filepaths = [line.strip().split(split) for line in f]
+ return filepaths
+
+
+class TextAudioSpeakerSet(torch.utils.data.Dataset):
+ def __init__(self, filename, hparams):
+ self.items = load_filepaths(filename)
+ self.max_wav_value = hparams.max_wav_value
+ self.sampling_rate = hparams.sampling_rate
+ self.segment_size = hparams.segment_size
+ self.hop_length = hparams.hop_length
+ self._filter()
+ print(f'----------{len(self.items)}----------')
+
+ def _filter(self):
+ lengths = []
+ items_new = []
+ items_min = int(self.segment_size / self.hop_length * 4) # 1 S
+ items_max = int(self.segment_size / self.hop_length * 16) # 4 S
+ for wavpath, spec, pitch, vec, ppg, spk in self.items:
+ if not os.path.isfile(wavpath):
+ continue
+ if not os.path.isfile(spec):
+ continue
+ if not os.path.isfile(pitch):
+ continue
+ if not os.path.isfile(vec):
+ continue
+ if not os.path.isfile(ppg):
+ continue
+ if not os.path.isfile(spk):
+ continue
+ temp = np.load(pitch)
+ usel = int(temp.shape[0] - 1) # useful length
+ if (usel < items_min):
+ continue
+ if (usel >= items_max):
+ usel = items_max
+ items_new.append([wavpath, spec, pitch, vec, ppg, spk, usel])
+ lengths.append(usel)
+ self.items = items_new
+ self.lengths = lengths
+
+ def read_wav(self, filename):
+ audio, sampling_rate = load_wav_to_torch(filename)
+        assert sampling_rate == self.sampling_rate, f"error: the sample rate of {filename} is {sampling_rate}, expected {self.sampling_rate}"
+ audio_norm = audio / self.max_wav_value
+ audio_norm = audio_norm.unsqueeze(0)
+ return audio_norm
+
+ def __getitem__(self, index):
+ return self.my_getitem(index)
+
+ def __len__(self):
+ return len(self.items)
+
+ def my_getitem(self, idx):
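+        # load wav / spectrogram / pitch / content vec / PPG / speaker embedding for one item,
+        # align their lengths and, for long items, crop a random window of at most `use` frames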
+ item = self.items[idx]
+ # print(item)
+ wav = item[0]
+ spe = item[1]
+ pit = item[2]
+ vec = item[3]
+ ppg = item[4]
+ spk = item[5]
+ use = item[6]
+
+ wav = self.read_wav(wav)
+ spe = torch.load(spe)
+
+ pit = np.load(pit)
+ vec = np.load(vec)
+ vec = np.repeat(vec, 2, 0) # 320 PPG -> 160 * 2
+ ppg = np.load(ppg)
+ ppg = np.repeat(ppg, 2, 0) # 320 PPG -> 160 * 2
+ spk = np.load(spk)
+
+ pit = torch.FloatTensor(pit)
+ vec = torch.FloatTensor(vec)
+ ppg = torch.FloatTensor(ppg)
+ spk = torch.FloatTensor(spk)
+
+ len_pit = pit.size()[0]
+        len_vec = vec.size()[0] - 2  # safety margin
+        len_ppg = ppg.size()[0] - 2  # safety margin
+ len_min = min(len_pit, len_vec)
+ len_min = min(len_min, len_ppg)
+ len_wav = len_min * self.hop_length
+
+ pit = pit[:len_min]
+ vec = vec[:len_min, :]
+ ppg = ppg[:len_min, :]
+ spe = spe[:, :len_min]
+ wav = wav[:, :len_wav]
+ if len_min > use:
+ max_frame_start = ppg.size(0) - use - 1
+ frame_start = random.randint(0, max_frame_start)
+ frame_end = frame_start + use
+
+ pit = pit[frame_start:frame_end]
+ vec = vec[frame_start:frame_end, :]
+ ppg = ppg[frame_start:frame_end, :]
+ spe = spe[:, frame_start:frame_end]
+
+ wav_start = frame_start * self.hop_length
+ wav_end = frame_end * self.hop_length
+ wav = wav[:, wav_start:wav_end]
+ # print(spe.shape)
+ # print(wav.shape)
+ # print(ppg.shape)
+ # print(pit.shape)
+ # print(spk.shape)
+ return spe, wav, ppg, vec, pit, spk
+
+
+class TextAudioSpeakerCollate:
+ """Zero-pads model inputs and targets"""
+
+ def __call__(self, batch):
+ # Right zero-pad all one-hot text sequences to max input length
+ # mel: [freq, length]
+ # wav: [1, length]
+ # ppg: [len, 1024]
+ # pit: [len]
+ # spk: [256]
+ _, ids_sorted_decreasing = torch.sort(
+ torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
+ )
+
+ max_spe_len = max([x[0].size(1) for x in batch])
+ max_wav_len = max([x[1].size(1) for x in batch])
+ spe_lengths = torch.LongTensor(len(batch))
+ wav_lengths = torch.LongTensor(len(batch))
+ spe_padded = torch.FloatTensor(
+ len(batch), batch[0][0].size(0), max_spe_len)
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
+ spe_padded.zero_()
+ wav_padded.zero_()
+
+ max_ppg_len = max([x[2].size(0) for x in batch])
+ ppg_lengths = torch.FloatTensor(len(batch))
+ ppg_padded = torch.FloatTensor(
+ len(batch), max_ppg_len, batch[0][2].size(1))
+ vec_padded = torch.FloatTensor(
+ len(batch), max_ppg_len, batch[0][3].size(1))
+ pit_padded = torch.FloatTensor(len(batch), max_ppg_len)
+ ppg_padded.zero_()
+ vec_padded.zero_()
+ pit_padded.zero_()
+ spk = torch.FloatTensor(len(batch), batch[0][5].size(0))
+
+ for i in range(len(ids_sorted_decreasing)):
+ row = batch[ids_sorted_decreasing[i]]
+
+ spe = row[0]
+ spe_padded[i, :, : spe.size(1)] = spe
+ spe_lengths[i] = spe.size(1)
+
+ wav = row[1]
+ wav_padded[i, :, : wav.size(1)] = wav
+ wav_lengths[i] = wav.size(1)
+
+ ppg = row[2]
+ ppg_padded[i, : ppg.size(0), :] = ppg
+ ppg_lengths[i] = ppg.size(0)
+
+ vec = row[3]
+ vec_padded[i, : vec.size(0), :] = vec
+
+ pit = row[4]
+ pit_padded[i, : pit.size(0)] = pit
+
+ spk[i] = row[5]
+ # print(ppg_padded.shape)
+ # print(ppg_lengths.shape)
+ # print(pit_padded.shape)
+ # print(spk.shape)
+ # print(spe_padded.shape)
+ # print(spe_lengths.shape)
+ # print(wav_padded.shape)
+ # print(wav_lengths.shape)
+ return (
+ ppg_padded,
+ ppg_lengths,
+ vec_padded,
+ pit_padded,
+ spk,
+ spe_padded,
+ spe_lengths,
+ wav_padded,
+ wav_lengths,
+ )
+
+
+class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
+ """
+ Maintain similar input lengths in a batch.
+ Length groups are specified by boundaries.
+    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
+ It removes samples which are not included in the boundaries.
+ Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
+ """
+
+ def __init__(
+ self,
+ dataset,
+ batch_size,
+ boundaries,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ ):
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+ self.lengths = dataset.lengths
+ self.batch_size = batch_size
+ self.boundaries = boundaries
+
+ self.buckets, self.num_samples_per_bucket = self._create_buckets()
+ self.total_size = sum(self.num_samples_per_bucket)
+ self.num_samples = self.total_size // self.num_replicas
+
+ def _create_buckets(self):
+ buckets = [[] for _ in range(len(self.boundaries) - 1)]
+ for i in range(len(self.lengths)):
+ length = self.lengths[i]
+ idx_bucket = self._bisect(length)
+ if idx_bucket != -1:
+ buckets[idx_bucket].append(i)
+
+ for i in range(len(buckets) - 1, 0, -1):
+ if len(buckets[i]) == 0:
+ buckets.pop(i)
+ self.boundaries.pop(i + 1)
+
+ num_samples_per_bucket = []
+ for i in range(len(buckets)):
+ len_bucket = len(buckets[i])
+ total_batch_size = self.num_replicas * self.batch_size
+ rem = (
+ total_batch_size - (len_bucket % total_batch_size)
+ ) % total_batch_size
+ num_samples_per_bucket.append(len_bucket + rem)
+ return buckets, num_samples_per_bucket
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+
+ indices = []
+ if self.shuffle:
+ for bucket in self.buckets:
+ indices.append(torch.randperm(
+ len(bucket), generator=g).tolist())
+ else:
+ for bucket in self.buckets:
+ indices.append(list(range(len(bucket))))
+
+ batches = []
+ for i in range(len(self.buckets)):
+ bucket = self.buckets[i]
+ len_bucket = len(bucket)
+ if (len_bucket == 0):
+ continue
+ ids_bucket = indices[i]
+ num_samples_bucket = self.num_samples_per_bucket[i]
+
+ # add extra samples to make it evenly divisible
+ rem = num_samples_bucket - len_bucket
+ ids_bucket = (
+ ids_bucket
+ + ids_bucket * (rem // len_bucket)
+ + ids_bucket[: (rem % len_bucket)]
+ )
+
+ # subsample
+ ids_bucket = ids_bucket[self.rank:: self.num_replicas]
+
+ # batching
+ for j in range(len(ids_bucket) // self.batch_size):
+ batch = [
+ bucket[idx]
+ for idx in ids_bucket[
+ j * self.batch_size: (j + 1) * self.batch_size
+ ]
+ ]
+ batches.append(batch)
+
+ if self.shuffle:
+ batch_ids = torch.randperm(len(batches), generator=g).tolist()
+ batches = [batches[i] for i in batch_ids]
+ self.batches = batches
+
+ assert len(self.batches) * self.batch_size == self.num_samples
+ return iter(self.batches)
+
+ def _bisect(self, x, lo=0, hi=None):
+ if hi is None:
+ hi = len(self.boundaries) - 1
+
+ if hi > lo:
+ mid = (hi + lo) // 2
+ if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
+ return mid
+ elif x <= self.boundaries[mid]:
+ return self._bisect(x, lo, mid)
+ else:
+ return self._bisect(x, mid + 1, hi)
+ else:
+ return -1
+
+ def __len__(self):
+ return self.num_samples // self.batch_size
diff --git a/vits/losses.py b/vits/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..9244de65482650fedea4de6e3cc26d4700ee76e9
--- /dev/null
+++ b/vits/losses.py
@@ -0,0 +1,79 @@
+import torch
+
+
+def feature_loss(fmap_r, fmap_g):
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ rl = rl.float().detach()
+ gl = gl.float()
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ dr = dr.float()
+ dg = dg.float()
+ r_loss = torch.mean((1 - dr) ** 2)
+ g_loss = torch.mean(dg**2)
+ loss += r_loss + g_loss
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ dg = dg.float()
+ l = torch.mean((1 - dg) ** 2)
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
+
+
+def kl_loss(z_p, logs_q, m_p, logs_p, total_logdet, z_mask):
+ """
+ z_p, logs_q: [b, h, t_t]
+ m_p, logs_p: [b, h, t_t]
+ total_logdet: [b] - total_logdet summed over each batch
+ """
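+    # KL between the posterior and the prior, estimated from posterior samples (z_p, logs_q)
+    # against (m_p, logs_p); masked, the flow log-determinant subtracted, and normalized by
+    # the number of non-padded frames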
+ z_p = z_p.float()
+ logs_q = logs_q.float()
+ m_p = m_p.float()
+ logs_p = logs_p.float()
+ z_mask = z_mask.float()
+
+ kl = logs_p - logs_q - 0.5
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+ kl = torch.sum(kl * z_mask)
+ # add total_logdet (Negative LL)
+ kl -= torch.sum(total_logdet)
+ l = kl / torch.sum(z_mask)
+ return l
+
+
+def kl_loss_back(z_p, logs_q, m_p, logs_p, z_mask):
+ """
+ z_p, logs_q: [b, h, t_t]
+ m_p, logs_p: [b, h, t_t]
+ """
+ z_p = z_p.float()
+ logs_q = logs_q.float()
+ m_p = m_p.float()
+ logs_p = logs_p.float()
+ z_mask = z_mask.float()
+
+ kl = logs_p - logs_q - 0.5
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+ kl = torch.sum(kl * z_mask)
+ l = kl / torch.sum(z_mask)
+ return l
diff --git a/vits/models.py b/vits/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..49c74ded38ee5e3731d563b3c2cbdb2bb821a5ac
--- /dev/null
+++ b/vits/models.py
@@ -0,0 +1,256 @@
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+from vits import attentions
+from vits import commons
+from vits import modules
+from vits.utils import f0_to_coarse
+from vits_decoder.generator import Generator
+from vits.modules_grl import SpeakerClassifier
+
+
+class TextEncoder(nn.Module):
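+    # Content (prior) encoder: projects PPG features and content vectors, adds a coarse-F0
+    # embedding, runs the attention encoder and predicts the prior statistics (m, logs).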
+ def __init__(self,
+ in_channels,
+ vec_channels,
+ out_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout):
+ super().__init__()
+ self.out_channels = out_channels
+ self.pre = nn.Conv1d(in_channels, hidden_channels, kernel_size=5, padding=2)
+ self.hub = nn.Conv1d(vec_channels, hidden_channels, kernel_size=5, padding=2)
+ self.pit = nn.Embedding(256, hidden_channels)
+ self.enc = attentions.Encoder(
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout)
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(self, x, x_lengths, v, f0):
+ x = torch.transpose(x, 1, -1) # [b, h, t]
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+ x.dtype
+ )
+ x = self.pre(x) * x_mask
+ v = torch.transpose(v, 1, -1) # [b, h, t]
+ v = self.hub(v) * x_mask
+ x = x + v + self.pit(f0).transpose(1, 2)
+ x = self.enc(x * x_mask, x_mask)
+ stats = self.proj(x) * x_mask
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+ return z, m, logs, x_mask, x
+
+
+class ResidualCouplingBlock(nn.Module):
+ def __init__(
+ self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ n_flows=4,
+ gin_channels=0,
+ ):
+ super().__init__()
+ self.flows = nn.ModuleList()
+ for i in range(n_flows):
+ self.flows.append(
+ modules.ResidualCouplingLayer(
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=gin_channels,
+ mean_only=True,
+ )
+ )
+ self.flows.append(modules.Flip())
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ if not reverse:
+ total_logdet = 0
+ for flow in self.flows:
+ x, log_det = flow(x, x_mask, g=g, reverse=reverse)
+ total_logdet += log_det
+ return x, total_logdet
+ else:
+ total_logdet = 0
+ for flow in reversed(self.flows):
+ x, log_det = flow(x, x_mask, g=g, reverse=reverse)
+ total_logdet += log_det
+ return x, total_logdet
+
+ def remove_weight_norm(self):
+        for flow in self.flows[::2]:  # even entries are coupling layers; odd entries are Flip (no weight norm)
+            flow.remove_weight_norm()
+
+
+class PosteriorEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0,
+ ):
+ super().__init__()
+ self.out_channels = out_channels
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+ self.enc = modules.WN(
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=gin_channels,
+ )
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(self, x, x_lengths, g=None):
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+ x.dtype
+ )
+ x = self.pre(x) * x_mask
+ x = self.enc(x, x_mask, g=g)
+ stats = self.proj(x) * x_mask
+ m, logs = torch.split(stats, self.out_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+ return z, m, logs, x_mask
+
+ def remove_weight_norm(self):
+ self.enc.remove_weight_norm()
+
+
+class SynthesizerTrn(nn.Module):
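+    # Training-time model: posterior encoder (spectrogram), content/prior encoder (PPG + vec + F0),
+    # speaker-conditioned SNAC flow, waveform decoder, and a gradient-reversal speaker classifier.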
+ def __init__(
+ self,
+ spec_channels,
+ segment_size,
+ hp
+ ):
+ super().__init__()
+ self.segment_size = segment_size
+ self.emb_g = nn.Linear(hp.vits.spk_dim, hp.vits.gin_channels)
+ self.enc_p = TextEncoder(
+ hp.vits.ppg_dim,
+ hp.vits.vec_dim,
+ hp.vits.inter_channels,
+ hp.vits.hidden_channels,
+ hp.vits.filter_channels,
+ 2,
+ 6,
+ 3,
+ 0.1,
+ )
+ self.speaker_classifier = SpeakerClassifier(
+ hp.vits.hidden_channels,
+ hp.vits.spk_dim,
+ )
+ self.enc_q = PosteriorEncoder(
+ spec_channels,
+ hp.vits.inter_channels,
+ hp.vits.hidden_channels,
+ 5,
+ 1,
+ 16,
+ gin_channels=hp.vits.gin_channels,
+ )
+ self.flow = ResidualCouplingBlock(
+ hp.vits.inter_channels,
+ hp.vits.hidden_channels,
+ 5,
+ 1,
+ 4,
+ gin_channels=hp.vits.spk_dim
+ )
+ self.dec = Generator(hp=hp)
+
+ def forward(self, ppg, vec, pit, spec, spk, ppg_l, spec_l):
+ ppg = ppg + torch.randn_like(ppg) * 1 # Perturbation
+ vec = vec + torch.randn_like(vec) * 2 # Perturbation
+ g = self.emb_g(F.normalize(spk)).unsqueeze(-1)
+ z_p, m_p, logs_p, ppg_mask, x = self.enc_p(
+ ppg, ppg_l, vec, f0=f0_to_coarse(pit))
+ z_q, m_q, logs_q, spec_mask = self.enc_q(spec, spec_l, g=g)
+
+ z_slice, pit_slice, ids_slice = commons.rand_slice_segments_with_pitch(
+ z_q, pit, spec_l, self.segment_size)
+ audio = self.dec(spk, z_slice, pit_slice)
+
+ # SNAC to flow
+ z_f, logdet_f = self.flow(z_q, spec_mask, g=spk)
+ z_r, logdet_r = self.flow(z_p, spec_mask, g=spk, reverse=True)
+ # speaker
+ spk_preds = self.speaker_classifier(x)
+ return audio, ids_slice, spec_mask, (z_f, z_r, z_p, m_p, logs_p, z_q, m_q, logs_q, logdet_f, logdet_r), spk_preds
+
+ def infer(self, ppg, vec, pit, spk, ppg_l):
+ ppg = ppg + torch.randn_like(ppg) * 0.0001 # Perturbation
+ z_p, m_p, logs_p, ppg_mask, x = self.enc_p(
+ ppg, ppg_l, vec, f0=f0_to_coarse(pit))
+ z, _ = self.flow(z_p, ppg_mask, g=spk, reverse=True)
+ o = self.dec(spk, z * ppg_mask, f0=pit)
+ return o
+
+
+class SynthesizerInfer(nn.Module):
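+    # Inference-only variant: no posterior encoder or speaker classifier; audio is generated from
+    # content features, F0 and a speaker embedding via the reversed flow and the decoder.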
+ def __init__(
+ self,
+ spec_channels,
+ segment_size,
+ hp
+ ):
+ super().__init__()
+ self.segment_size = segment_size
+ self.enc_p = TextEncoder(
+ hp.vits.ppg_dim,
+ hp.vits.vec_dim,
+ hp.vits.inter_channels,
+ hp.vits.hidden_channels,
+ hp.vits.filter_channels,
+ 2,
+ 6,
+ 3,
+ 0.1,
+ )
+ self.flow = ResidualCouplingBlock(
+ hp.vits.inter_channels,
+ hp.vits.hidden_channels,
+ 5,
+ 1,
+ 4,
+ gin_channels=hp.vits.spk_dim
+ )
+ self.dec = Generator(hp=hp)
+
+ def remove_weight_norm(self):
+ self.flow.remove_weight_norm()
+ self.dec.remove_weight_norm()
+
+ def pitch2source(self, f0):
+ return self.dec.pitch2source(f0)
+
+ def source2wav(self, source):
+ return self.dec.source2wav(source)
+
+ def inference(self, ppg, vec, pit, spk, ppg_l, source):
+ z_p, m_p, logs_p, ppg_mask, x = self.enc_p(
+ ppg, ppg_l, vec, f0=f0_to_coarse(pit))
+ z, _ = self.flow(z_p, ppg_mask, g=spk, reverse=True)
+ o = self.dec.inference(spk, z * ppg_mask, source)
+ return o
diff --git a/vits/modules.py b/vits/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a0e545add587fd5f18437ef4756f34ae23b0e08
--- /dev/null
+++ b/vits/modules.py
@@ -0,0 +1,324 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+from vits import commons
+
+
+LRELU_SLOPE = 0.1
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ x = x.transpose(1, -1)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ return x.transpose(1, -1)
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ hidden_channels,
+ out_channels,
+ kernel_size,
+ n_layers,
+ p_dropout,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 1."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+        # pad along columns
+ nn.Conv1d(
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+ )
+ )
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(
+ nn.Conv1d(
+ hidden_channels,
+ hidden_channels,
+ kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+class DDSConv(nn.Module):
+ """
+    Dilated and Depth-Separable Convolution
+ """
+
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
+ super().__init__()
+ self.channels = channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+
+ self.drop = nn.Dropout(p_dropout)
+ self.convs_sep = nn.ModuleList()
+ self.convs_1x1 = nn.ModuleList()
+ self.norms_1 = nn.ModuleList()
+ self.norms_2 = nn.ModuleList()
+ for i in range(n_layers):
+ dilation = kernel_size**i
+ padding = (kernel_size * dilation - dilation) // 2
+ self.convs_sep.append(
+ nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ groups=channels,
+ dilation=dilation,
+ padding=padding,
+ )
+ )
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+ self.norms_1.append(LayerNorm(channels))
+ self.norms_2.append(LayerNorm(channels))
+
+ def forward(self, x, x_mask, g=None):
+ if g is not None:
+ x = x + g
+ for i in range(self.n_layers):
+ y = self.convs_sep[i](x * x_mask)
+ y = self.norms_1[i](y)
+ y = F.gelu(y)
+ y = self.convs_1x1[i](y)
+ y = self.norms_2[i](y)
+ y = F.gelu(y)
+ y = self.drop(y)
+ x = x + y
+ return x * x_mask
+
+
+class WN(torch.nn.Module):
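+    # Non-causal WaveNet-like stack of gated dilated convolutions with optional global
+    # conditioning; used by the posterior encoder and the coupling layers.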
+ def __init__(
+ self,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ gin_channels=0,
+ p_dropout=0,
+ ):
+ super(WN, self).__init__()
+ assert kernel_size % 2 == 1
+ self.hidden_channels = hidden_channels
+ self.kernel_size = (kernel_size,)
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+ self.drop = nn.Dropout(p_dropout)
+
+ if gin_channels != 0:
+ cond_layer = torch.nn.Conv1d(
+ gin_channels, 2 * hidden_channels * n_layers, 1
+ )
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+
+ for i in range(n_layers):
+ dilation = dilation_rate**i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(
+ hidden_channels,
+ 2 * hidden_channels,
+ kernel_size,
+ dilation=dilation,
+ padding=padding,
+ )
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+ self.in_layers.append(in_layer)
+
+            # the residual path is not needed for the last layer; it only produces the skip output
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_channels
+ else:
+ res_skip_channels = hidden_channels
+
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, x, x_mask, g=None, **kwargs):
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+ if g is not None:
+ g = self.cond_layer(g)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ if g is not None:
+ cond_offset = i * 2 * self.hidden_channels
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+ else:
+ g_l = torch.zeros_like(x_in)
+
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+ acts = self.drop(acts)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
+ x = (x + res_acts) * x_mask
+ output = output + res_skip_acts[:, self.hidden_channels:, :]
+ else:
+ output = output + res_skip_acts
+ return output * x_mask
+
+ def remove_weight_norm(self):
+ if self.gin_channels != 0:
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
+ for l in self.in_layers:
+ torch.nn.utils.remove_weight_norm(l)
+ for l in self.res_skip_layers:
+ torch.nn.utils.remove_weight_norm(l)
+
+
+class Log(nn.Module):
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
+ logdet = torch.sum(-y, [1, 2])
+ return y, logdet
+ else:
+ x = torch.exp(x) * x_mask
+ return x
+
+
+class Flip(nn.Module):
+ def forward(self, x, *args, reverse=False, **kwargs):
+ x = torch.flip(x, [1])
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+ return x, logdet
+
+
+class ElementwiseAffine(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.channels = channels
+ self.m = nn.Parameter(torch.zeros(channels, 1))
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
+
+ def forward(self, x, x_mask, reverse=False, **kwargs):
+ if not reverse:
+ y = self.m + torch.exp(self.logs) * x
+ y = y * x_mask
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
+ return y, logdet
+ else:
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
+ return x
+
+
+class ResidualCouplingLayer(nn.Module):
+ def __init__(
+ self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=0,
+ gin_channels=0,
+ mean_only=False,
+ ):
+ assert channels % 2 == 0, "channels should be divisible by 2"
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.half_channels = channels // 2
+ self.mean_only = mean_only
+
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        # gin_channels is not passed to WN here; speaker conditioning is applied via SNAC below
+ self.enc = WN(
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=p_dropout,
+ )
+ self.post = nn.Conv1d(
+ hidden_channels, self.half_channels * (2 - mean_only), 1)
+ self.post.weight.data.zero_()
+ self.post.bias.data.zero_()
+ # SNAC Speaker-normalized Affine Coupling Layer
+ self.snac = nn.Conv1d(gin_channels, 2 * self.half_channels, 1)
+
+ def forward(self, x, x_mask, g=None, reverse=False):
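+        # SNAC: normalize the coupling inputs with the speaker-dependent shift/scale before the
+        # affine transform and fold the extra scaling into the log-determinant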
+ speaker = self.snac(g.unsqueeze(-1))
+ speaker_m, speaker_v = speaker.chunk(2, dim=1) # (B, half_channels, 1)
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ # x0 norm
+ x0_norm = (x0 - speaker_m) * torch.exp(-speaker_v) * x_mask
+ h = self.pre(x0_norm) * x_mask
+ # don't use global condition
+ h = self.enc(h, x_mask)
+ stats = self.post(h) * x_mask
+ if not self.mean_only:
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not reverse:
+ # x1 norm before affine xform
+ x1_norm = (x1 - speaker_m) * torch.exp(-speaker_v) * x_mask
+ x1 = (m + x1_norm * torch.exp(logs)) * x_mask
+ x = torch.cat([x0, x1], 1)
+ # speaker var to logdet
+ logdet = torch.sum(logs * x_mask, [1, 2]) - torch.sum(
+ speaker_v.expand(-1, -1, logs.size(-1)) * x_mask, [1, 2])
+ return x, logdet
+ else:
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
+ # x1 denorm before output
+ x1 = (speaker_m + x1 * torch.exp(speaker_v)) * x_mask
+ x = torch.cat([x0, x1], 1)
+ # speaker var to logdet
+ logdet = torch.sum(-logs * x_mask, [1, 2]) + torch.sum(
+ speaker_v.expand(-1, -1, logs.size(-1)) * x_mask, [1, 2])
+ return x, logdet
+
+ def remove_weight_norm(self):
+ self.enc.remove_weight_norm()
diff --git a/vits/modules_grl.py b/vits/modules_grl.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c8510725210f5f31b3677f2e8f30c3b6c215f0f
--- /dev/null
+++ b/vits/modules_grl.py
@@ -0,0 +1,62 @@
+# Adapted from https://github.com/ubisoft/ubisoft-laforge-daft-exprt Apache License Version 2.0
+# Unsupervised Domain Adaptation by Backpropagation
+
+import torch
+import torch.nn as nn
+
+from torch.autograd import Function
+from torch.nn.utils import weight_norm
+
+
+class GradientReversalFunction(Function):
+ @staticmethod
+ def forward(ctx, x, lambda_):
+ ctx.lambda_ = lambda_
+ return x.clone()
+
+ @staticmethod
+ def backward(ctx, grads):
+ lambda_ = ctx.lambda_
+ lambda_ = grads.new_tensor(lambda_)
+ dx = -lambda_ * grads
+ return dx, None
+
+
+class GradientReversal(torch.nn.Module):
+ ''' Gradient Reversal Layer
+ Y. Ganin, V. Lempitsky,
+ "Unsupervised Domain Adaptation by Backpropagation",
+ in ICML, 2015.
+ Forward pass is the identity function
+    In the backward pass, upstream gradients are multiplied by -lambda (i.e. gradients are reversed)
+ '''
+
+ def __init__(self, lambda_reversal=1):
+ super(GradientReversal, self).__init__()
+ self.lambda_ = lambda_reversal
+
+ def forward(self, x):
+ return GradientReversalFunction.apply(x, self.lambda_)
+
+
+class SpeakerClassifier(nn.Module):
+
+ def __init__(self, embed_dim, spk_dim):
+ super(SpeakerClassifier, self).__init__()
+ self.classifier = nn.Sequential(
+ GradientReversal(lambda_reversal=1),
+ weight_norm(nn.Conv1d(embed_dim, embed_dim, kernel_size=5, padding=2)),
+ nn.ReLU(),
+ weight_norm(nn.Conv1d(embed_dim, embed_dim, kernel_size=5, padding=2)),
+ nn.ReLU(),
+ weight_norm(nn.Conv1d(embed_dim, spk_dim, kernel_size=5, padding=2))
+ )
+
+ def forward(self, x):
+ ''' Forward function of Speaker Classifier:
+ x = (B, embed_dim, len)
+ '''
+ # pass through classifier
+        outputs = self.classifier(x)           # (B, spk_dim, len)
+        outputs = torch.mean(outputs, dim=-1)  # average over time -> (B, spk_dim)
+ return outputs
diff --git a/vits/spectrogram.py b/vits/spectrogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..67b54b1757f977f840ba97e0ad28b241fceeecd7
--- /dev/null
+++ b/vits/spectrogram.py
@@ -0,0 +1,140 @@
+import torch
+import torch.utils.data
+
+from librosa.filters import mel as librosa_mel_fn
+
+MAX_WAV_VALUE = 32768.0
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ """
+ PARAMS
+ ------
+ C: compression factor
+ """
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ """
+ PARAMS
+ ------
+ C: compression factor used to compress
+ """
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
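+    # linear-magnitude STFT; Hann windows are cached in hann_window per (win_size, dtype, device)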
+ if torch.min(y) < -1.0:
+ print("min value is ", torch.min(y))
+ if torch.max(y) > 1.0:
+ print("max value is ", torch.max(y))
+
+ global hann_window
+ dtype_device = str(y.dtype) + "_" + str(y.device)
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
+ if wnsize_dtype_device not in hann_window:
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
+ dtype=y.dtype, device=y.device
+ )
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+ mode="reflect",
+ )
+ y = y.squeeze(1)
+
+ spec = torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window[wnsize_dtype_device],
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=False,
+ )
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+ return spec
+
+
+def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
+ global mel_basis
+ dtype_device = str(spec.dtype) + "_" + str(spec.device)
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
+ if fmax_dtype_device not in mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
+ dtype=spec.dtype, device=spec.device
+ )
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+ spec = spectral_normalize_torch(spec)
+ return spec
+
+
+def mel_spectrogram_torch(
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
+ if torch.min(y) < -1.0:
+ print("min value is ", torch.min(y))
+ if torch.max(y) > 1.0:
+ print("max value is ", torch.max(y))
+
+ global mel_basis, hann_window
+ dtype_device = str(y.dtype) + "_" + str(y.device)
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
+ if fmax_dtype_device not in mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
+ dtype=y.dtype, device=y.device
+ )
+ if wnsize_dtype_device not in hann_window:
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
+ dtype=y.dtype, device=y.device
+ )
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+ mode="reflect",
+ )
+ y = y.squeeze(1)
+
+ spec = torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window[wnsize_dtype_device],
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=False,
+ )
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
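
A usage sketch for the two transforms above; the 16 kHz / 1024-point STFT settings are placeholders rather than the project's configured values.

```python
import torch
from vits.spectrogram import spectrogram_torch, mel_spectrogram_torch

wav = torch.randn(1, 16000).clamp(-1.0, 1.0)  # (B, T) audio in [-1, 1]

# Linear-magnitude spectrogram: (B, n_fft // 2 + 1, frames)
spec = spectrogram_torch(wav, n_fft=1024, sampling_rate=16000,
                         hop_size=256, win_size=1024)

# Log-compressed mel spectrogram: (B, num_mels, frames)
mel = mel_spectrogram_torch(wav, n_fft=1024, num_mels=80, sampling_rate=16000,
                            hop_size=256, win_size=1024, fmin=0, fmax=8000)
print(spec.shape, mel.shape)
```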
diff --git a/vits/utils.py b/vits/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2ae3a16ecd0112c41422a0696a09740a3f2f6a3
--- /dev/null
+++ b/vits/utils.py
@@ -0,0 +1,33 @@
+import torch
+import numpy as np
+from scipy.io.wavfile import read
+
+MATPLOTLIB_FLAG = False
+
+
+def load_wav_to_torch(full_path):
+ sampling_rate, data = read(full_path)
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
+
+
+f0_bin = 256
+f0_max = 1100.0
+f0_min = 50.0
+f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
+
+def f0_to_coarse(f0):
+ is_torch = isinstance(f0, torch.Tensor)
+ f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * \
+ np.log(1 + f0 / 700)
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * \
+ (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+ f0_mel[f0_mel <= 1] = 1
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+    f0_coarse = (
+        f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int64)
+ assert f0_coarse.max() <= 255 and f0_coarse.min(
+ ) >= 1, (f0_coarse.max(), f0_coarse.min())
+ return f0_coarse
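
A quick sketch of the F0 quantizer above: pitch in Hz is mapped onto a mel scale and bucketed into 255 coarse bins, with bin 1 effectively reserved for unvoiced (0 Hz) frames. The pitch values below are arbitrary.

```python
import torch
from vits.utils import f0_to_coarse

f0 = torch.tensor([0.0, 110.0, 220.0, 440.0, 880.0])  # 0 Hz marks unvoiced frames
print(f0_to_coarse(f0))  # integer bins in [1, 255]; the unvoiced frame maps to 1
```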
diff --git a/vits_decoder/LICENSE.txt b/vits_decoder/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e9663595cc28938f88d6299acd3ba791542e4c0c
--- /dev/null
+++ b/vits_decoder/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 NVIDIA CORPORATION.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/vits_decoder/__init__.py b/vits_decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..986a0cfe522626f45f6c2d4dede44374c86bbe71
--- /dev/null
+++ b/vits_decoder/__init__.py
@@ -0,0 +1 @@
+from .alias.act import SnakeAlias
\ No newline at end of file
diff --git a/vits_decoder/alias/LICENSE-alias.txt b/vits_decoder/alias/LICENSE-alias.txt
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/vits_decoder/alias/LICENSE-alias.txt
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/vits_decoder/alias/LICENSE-snake.txt b/vits_decoder/alias/LICENSE-snake.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9c28182ace9ed5b2d9c8ee4b0e003d1f6f10c757
--- /dev/null
+++ b/vits_decoder/alias/LICENSE-snake.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Edward Dixon
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/vits_decoder/alias/__init__.py b/vits_decoder/alias/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2318b63198250856809c0cb46210a4147b829bc
--- /dev/null
+++ b/vits_decoder/alias/__init__.py
@@ -0,0 +1,6 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+from .filter import *
+from .resample import *
+from .act import *
\ No newline at end of file
diff --git a/vits_decoder/alias/act.py b/vits_decoder/alias/act.py
new file mode 100644
index 0000000000000000000000000000000000000000..308344fb6ccbc39317c584a3ee1fb2f29084678e
--- /dev/null
+++ b/vits_decoder/alias/act.py
@@ -0,0 +1,129 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch import sin, pow
+from torch.nn import Parameter
+from .resample import UpSample1d, DownSample1d
+
+
+class Activation1d(nn.Module):
+ def __init__(self,
+ activation,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12):
+ super().__init__()
+ self.up_ratio = up_ratio
+ self.down_ratio = down_ratio
+ self.act = activation
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ # x: [B,C,T]
+ def forward(self, x):
+ x = self.upsample(x)
+ x = self.act(x)
+ x = self.downsample(x)
+
+ return x
+
+
+class SnakeBeta(nn.Module):
+ '''
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
+ Shape:
+ - Input: (B, C, T)
+ - Output: (B, C, T), same shape as the input
+ Parameters:
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ References:
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+ https://arxiv.org/abs/2006.08195
+ Examples:
+        >>> a1 = SnakeBeta(256)
+        >>> x = torch.randn(4, 256, 100)  # (B, C=256, T)
+        >>> x = a1(x)
+ '''
+
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+ '''
+ Initialization.
+ INPUT:
+ - in_features: shape of the input
+ - alpha - trainable parameter that controls frequency
+ - beta - trainable parameter that controls magnitude
+ alpha is initialized to 1 by default, higher values = higher-frequency.
+ beta is initialized to 1 by default, higher values = higher-magnitude.
+ alpha will be trained along with the rest of your model.
+ '''
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
+ self.beta = Parameter(torch.ones(in_features) * alpha)
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ '''
+ Forward pass of the function.
+ Applies the function to the input elementwise.
+ SnakeBeta = x + 1/b * sin^2 (xa)
+ '''
+ alpha = self.alpha.unsqueeze(
+ 0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+ return x
+
+
+class Mish(nn.Module):
+ """
+ Mish activation function is proposed in "Mish: A Self
+ Regularized Non-Monotonic Neural Activation Function"
+ paper, https://arxiv.org/abs/1908.08681.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x * torch.tanh(F.softplus(x))
+
+
+class SnakeAlias(nn.Module):
+ def __init__(self,
+ channels,
+ up_ratio: int = 2,
+ down_ratio: int = 2,
+ up_kernel_size: int = 12,
+ down_kernel_size: int = 12):
+ super().__init__()
+ self.up_ratio = up_ratio
+ self.down_ratio = down_ratio
+ self.act = SnakeBeta(channels, alpha_logscale=True)
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+ # x: [B,C,T]
+ def forward(self, x):
+ x = self.upsample(x)
+ x = self.act(x)
+ x = self.downsample(x)
+
+ return x
\ No newline at end of file
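
A shape-level sketch of the activations above; the 64 channels and T=1000 are arbitrary choices for illustration.

```python
import torch
from vits_decoder.alias.act import SnakeBeta, SnakeAlias

x = torch.randn(2, 64, 1000)  # (B, C, T)

# SnakeBeta: x + 1/beta * sin^2(alpha * x) with per-channel alpha and beta.
act = SnakeBeta(64, alpha_logscale=True)
print(act(x).shape)  # torch.Size([2, 64, 1000])

# SnakeAlias wraps SnakeBeta between 2x up- and downsampling to suppress aliasing;
# the output keeps the input's (B, C, T) shape.
aa = SnakeAlias(64)
print(aa(x).shape)  # torch.Size([2, 64, 1000])
```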
diff --git a/vits_decoder/alias/filter.py b/vits_decoder/alias/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ad6ea87c1f10ddd94c544037791d7a4634d5ae1
--- /dev/null
+++ b/vits_decoder/alias/filter.py
@@ -0,0 +1,95 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if 'sinc' in dir(torch):
+ sinc = torch.sinc
+else:
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
+ # https://adefossez.github.io/julius/julius/core.html
+ # LICENSE is in incl_licenses directory.
+ def sinc(x: torch.Tensor):
+ """
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+        __Warning__: Different from julius.sinc, the input is multiplied by `pi`!
+ """
+ return torch.where(x == 0,
+ torch.tensor(1., device=x.device, dtype=x.dtype),
+ torch.sin(math.pi * x) / math.pi / x)
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+# LICENSE is in incl_licenses directory.
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
+ even = (kernel_size % 2 == 0)
+ half_size = kernel_size // 2
+
+ #For kaiser window
+ delta_f = 4 * half_width
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+ if A > 50.:
+ beta = 0.1102 * (A - 8.7)
+ elif A >= 21.:
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+ else:
+ beta = 0.
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+ if even:
+ time = (torch.arange(-half_size, half_size) + 0.5)
+ else:
+ time = torch.arange(kernel_size) - half_size
+ if cutoff == 0:
+ filter_ = torch.zeros_like(time)
+ else:
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
+ # of the constant component in the input signal.
+ filter_ /= filter_.sum()
+ filter = filter_.view(1, 1, kernel_size)
+
+ return filter
+
+
+class LowPassFilter1d(nn.Module):
+ def __init__(self,
+ cutoff=0.5,
+ half_width=0.6,
+ stride: int = 1,
+ padding: bool = True,
+ padding_mode: str = 'replicate',
+ kernel_size: int = 12):
+        # kernel_size should be an even number for the StyleGAN3 setup;
+        # this implementation also supports odd kernel sizes.
+        super().__init__()
+        if cutoff < -0.:
+            raise ValueError("Minimum cutoff must be non-negative.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+ self.kernel_size = kernel_size
+ self.even = (kernel_size % 2 == 0)
+ self.pad_left = kernel_size // 2 - int(self.even)
+ self.pad_right = kernel_size // 2
+ self.stride = stride
+ self.padding = padding
+ self.padding_mode = padding_mode
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+ self.register_buffer("filter", filter)
+
+ #input [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ if self.padding:
+ x = F.pad(x, (self.pad_left, self.pad_right),
+ mode=self.padding_mode)
+ out = F.conv1d(x, self.filter.expand(C, -1, -1),
+ stride=self.stride, groups=C)
+
+ return out
\ No newline at end of file
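
A small sketch of the filter helpers above; the cutoff/half-width values mirror what the 2x resamplers pass in, and the tensor sizes are arbitrary.

```python
import torch
from vits_decoder.alias.filter import kaiser_sinc_filter1d, LowPassFilter1d

# 12-tap Kaiser-windowed sinc low-pass at half of Nyquist; taps are normalized to sum to 1.
taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(taps.shape, float(taps.sum()))  # torch.Size([1, 1, 12]) ~1.0

# Applied as a depthwise conv over (B, C, T); with stride=1 and padding, length is preserved.
lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(lpf(torch.randn(1, 4, 256)).shape)  # torch.Size([1, 4, 256])
```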
diff --git a/vits_decoder/alias/resample.py b/vits_decoder/alias/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..750e6c3402cc5ac939c4b9d075246562e0e1d1a7
--- /dev/null
+++ b/vits_decoder/alias/resample.py
@@ -0,0 +1,49 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from torch.nn import functional as F
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
+
+
+class UpSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.stride = ratio
+ self.pad = self.kernel_size // ratio - 1
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ kernel_size=self.kernel_size)
+ self.register_buffer("filter", filter)
+
+ # x: [B, C, T]
+ def forward(self, x):
+ _, C, _ = x.shape
+
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
+ x = self.ratio * F.conv_transpose1d(
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+ x = x[..., self.pad_left:-self.pad_right]
+
+ return x
+
+
+class DownSample1d(nn.Module):
+ def __init__(self, ratio=2, kernel_size=None):
+ super().__init__()
+ self.ratio = ratio
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
+ half_width=0.6 / ratio,
+ stride=ratio,
+ kernel_size=self.kernel_size)
+
+ def forward(self, x):
+ xx = self.lowpass(x)
+
+ return xx
\ No newline at end of file
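
A round-trip sketch of the resamplers above with ratio 2; channel count and length are arbitrary.

```python
import torch
from vits_decoder.alias.resample import UpSample1d, DownSample1d

x = torch.randn(1, 8, 512)    # (B, C, T)
up = UpSample1d(ratio=2)      # sinc-interpolating 2x upsampler
down = DownSample1d(ratio=2)  # low-pass filter followed by stride-2 decimation

y = up(x)
print(y.shape)        # torch.Size([1, 8, 1024])
print(down(y).shape)  # torch.Size([1, 8, 512])
```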
diff --git a/vits_decoder/bigv.py b/vits_decoder/bigv.py
new file mode 100644
index 0000000000000000000000000000000000000000..029362c34b2c850cc2d59eea4410f77380d84bbe
--- /dev/null
+++ b/vits_decoder/bigv.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+
+from torch.nn import Conv1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+from .alias.act import SnakeAlias
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size*dilation - dilation)/2)
+
+
+class AMPBlock(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(AMPBlock, self).__init__()
+ self.convs1 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2])))
+ ])
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1)))
+ ])
+ self.convs2.apply(init_weights)
+
+ # total number of conv layers
+ self.num_layers = len(self.convs1) + len(self.convs2)
+
+ # periodic nonlinearity with snakebeta function and anti-aliasing
+ self.activations = nn.ModuleList([
+ SnakeAlias(channels) for _ in range(self.num_layers)
+ ])
+
+ def forward(self, x):
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+ xt = a1(x)
+ xt = c1(xt)
+ xt = a2(xt)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
\ No newline at end of file
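
A shape sketch of the AMP residual block above; 64 channels and 400 frames are arbitrary.

```python
import torch
from vits_decoder.bigv import AMPBlock

# Dilated convs interleaved with anti-aliased SnakeBeta activations, added back residually;
# the (B, C, T) shape is preserved.
block = AMPBlock(channels=64, kernel_size=3, dilation=(1, 3, 5))
x = torch.randn(2, 64, 400)
print(block(x).shape)  # torch.Size([2, 64, 400])
```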
diff --git a/vits_decoder/discriminator.py b/vits_decoder/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..764c0ca806b707e4f36ca2abb64ce79971358dd9
--- /dev/null
+++ b/vits_decoder/discriminator.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+
+from omegaconf import OmegaConf
+from .msd import ScaleDiscriminator
+from .mpd import MultiPeriodDiscriminator
+from .mrd import MultiResolutionDiscriminator
+
+
+class Discriminator(nn.Module):
+ def __init__(self, hp):
+ super(Discriminator, self).__init__()
+ self.MRD = MultiResolutionDiscriminator(hp)
+ self.MPD = MultiPeriodDiscriminator(hp)
+ self.MSD = ScaleDiscriminator()
+
+ def forward(self, x):
+ r = self.MRD(x)
+ p = self.MPD(x)
+ s = self.MSD(x)
+ return r + p + s
+
+
+if __name__ == '__main__':
+ hp = OmegaConf.load('../config/base.yaml')
+ model = Discriminator(hp)
+
+ x = torch.randn(3, 1, 16384)
+ print(x.shape)
+
+ output = model(x)
+ for features, score in output:
+ for feat in features:
+ print(feat.shape)
+ print(score.shape)
+
+ pytorch_total_params = sum(p.numel()
+ for p in model.parameters() if p.requires_grad)
+ print(pytorch_total_params)
diff --git a/vits_decoder/generator.py b/vits_decoder/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..787302bd496ee0545d9699b1cff1835a243cd62b
--- /dev/null
+++ b/vits_decoder/generator.py
@@ -0,0 +1,200 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from torch.nn import Conv1d
+from torch.nn import ConvTranspose1d
+from torch.nn.utils import weight_norm
+from torch.nn.utils import remove_weight_norm
+
+from .nsf import SourceModuleHnNSF
+from .bigv import init_weights, AMPBlock, SnakeAlias
+
+
+class SpeakerAdapter(nn.Module):
+
+ def __init__(self,
+ speaker_dim,
+ adapter_dim,
+ epsilon=1e-5
+ ):
+ super(SpeakerAdapter, self).__init__()
+ self.speaker_dim = speaker_dim
+ self.adapter_dim = adapter_dim
+ self.epsilon = epsilon
+ self.W_scale = nn.Linear(self.speaker_dim, self.adapter_dim)
+ self.W_bias = nn.Linear(self.speaker_dim, self.adapter_dim)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ torch.nn.init.constant_(self.W_scale.weight, 0.0)
+ torch.nn.init.constant_(self.W_scale.bias, 1.0)
+ torch.nn.init.constant_(self.W_bias.weight, 0.0)
+ torch.nn.init.constant_(self.W_bias.bias, 0.0)
+
+ def forward(self, x, speaker_embedding):
+ x = x.transpose(1, -1)
+ mean = x.mean(dim=-1, keepdim=True)
+ var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
+ std = (var + self.epsilon).sqrt()
+ y = (x - mean) / std
+ scale = self.W_scale(speaker_embedding)
+ bias = self.W_bias(speaker_embedding)
+ y *= scale.unsqueeze(1)
+ y += bias.unsqueeze(1)
+ y = y.transpose(1, -1)
+ return y
+
+
+class Generator(torch.nn.Module):
+ # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
+ def __init__(self, hp):
+ super(Generator, self).__init__()
+ self.hp = hp
+ self.num_kernels = len(hp.gen.resblock_kernel_sizes)
+ self.num_upsamples = len(hp.gen.upsample_rates)
+        # speaker adapter: the speaker embedding size (spk_dim) must match the speaker encoder in use
+ self.adapter = SpeakerAdapter(hp.vits.spk_dim, hp.gen.upsample_input)
+ # pre conv
+ self.conv_pre = Conv1d(hp.gen.upsample_input,
+ hp.gen.upsample_initial_channel, 7, 1, padding=3)
+ # nsf
+ self.f0_upsamp = torch.nn.Upsample(
+ scale_factor=np.prod(hp.gen.upsample_rates))
+ self.m_source = SourceModuleHnNSF(sampling_rate=hp.data.sampling_rate)
+ self.noise_convs = nn.ModuleList()
+ # transposed conv-based upsamplers. does not apply anti-aliasing
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(hp.gen.upsample_rates, hp.gen.upsample_kernel_sizes)):
+ # print(f'ups: {i} {k}, {u}, {(k - u) // 2}')
+ # base
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ hp.gen.upsample_initial_channel // (2 ** i),
+ hp.gen.upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2)
+ )
+ )
+ # nsf
+ if i + 1 < len(hp.gen.upsample_rates):
+ stride_f0 = np.prod(hp.gen.upsample_rates[i + 1:])
+ stride_f0 = int(stride_f0)
+ self.noise_convs.append(
+ Conv1d(
+ 1,
+ hp.gen.upsample_initial_channel // (2 ** (i + 1)),
+ kernel_size=stride_f0 * 2,
+ stride=stride_f0,
+ padding=stride_f0 // 2,
+ )
+ )
+ else:
+ self.noise_convs.append(
+ Conv1d(1, hp.gen.upsample_initial_channel //
+ (2 ** (i + 1)), kernel_size=1)
+ )
+
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = hp.gen.upsample_initial_channel // (2 ** (i + 1))
+ for k, d in zip(hp.gen.resblock_kernel_sizes, hp.gen.resblock_dilation_sizes):
+ self.resblocks.append(AMPBlock(ch, k, d))
+
+ # post conv
+ self.activation_post = SnakeAlias(ch)
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+ # weight initialization
+ self.ups.apply(init_weights)
+
+ def forward(self, spk, x, f0):
+ # Perturbation
+ x = x + torch.randn_like(x)
+ # adapter
+ x = self.adapter(x, spk)
+ x = self.conv_pre(x)
+ x = x * torch.tanh(F.softplus(x))
+ # nsf
+ f0 = f0[:, None]
+ f0 = self.f0_upsamp(f0).transpose(1, 2)
+ har_source = self.m_source(f0)
+ har_source = har_source.transpose(1, 2)
+
+ for i in range(self.num_upsamples):
+ # upsampling
+ x = self.ups[i](x)
+ # nsf
+ x_source = self.noise_convs[i](har_source)
+ x = x + x_source
+ # AMP blocks
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ # post conv
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+
+ def eval(self, inference=False):
+ super(Generator, self).eval()
+ # don't remove weight norm while validation in training loop
+ if inference:
+ self.remove_weight_norm()
+
+ def pitch2source(self, f0):
+ f0 = f0[:, None]
+ f0 = self.f0_upsamp(f0).transpose(1, 2) # [1,len,1]
+ har_source = self.m_source(f0)
+ har_source = har_source.transpose(1, 2) # [1,1,len]
+ return har_source
+
+ def source2wav(self, audio):
+ MAX_WAV_VALUE = 32768.0
+ audio = audio.squeeze()
+ audio = MAX_WAV_VALUE * audio
+ audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
+ audio = audio.short()
+ return audio.cpu().detach().numpy()
+
+ def inference(self, spk, x, har_source):
+ # adapter
+ x = self.adapter(x, spk)
+ x = self.conv_pre(x)
+ x = x * torch.tanh(F.softplus(x))
+
+ for i in range(self.num_upsamples):
+ # upsampling
+ x = self.ups[i](x)
+ # nsf
+ x_source = self.noise_convs[i](har_source)
+ x = x + x_source
+ # AMP blocks
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+
+ # post conv
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+ return x
diff --git a/vits_decoder/med.py b/vits_decoder/med.py
new file mode 100644
index 0000000000000000000000000000000000000000..77554d3c07b98328c0cc5c9b0b8301c22568f55c
--- /dev/null
+++ b/vits_decoder/med.py
@@ -0,0 +1,65 @@
+import torch
+import torchaudio
+import typing as T
+
+
+class MelspecDiscriminator(torch.nn.Module):
+ """mel spectrogram (frequency domain) discriminator"""
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.SAMPLE_RATE = 48000
+ # mel filterbank transform
+ self._melspec = torchaudio.transforms.MelSpectrogram(
+ sample_rate=self.SAMPLE_RATE,
+ n_fft=2048,
+ win_length=int(0.025 * self.SAMPLE_RATE),
+ hop_length=int(0.010 * self.SAMPLE_RATE),
+ n_mels=128,
+ power=1,
+ )
+
+ # time-frequency 2D convolutions
+ kernel_sizes = [(7, 7), (4, 4), (4, 4), (4, 4)]
+ strides = [(1, 2), (1, 2), (1, 2), (1, 2)]
+ self._convs = torch.nn.ModuleList(
+ [
+ torch.nn.Sequential(
+ torch.nn.Conv2d(
+ in_channels=1 if i == 0 else 32,
+ out_channels=64,
+ kernel_size=k,
+ stride=s,
+ padding=(1, 2),
+ bias=False,
+ ),
+ torch.nn.BatchNorm2d(num_features=64),
+ torch.nn.GLU(dim=1),
+ )
+ for i, (k, s) in enumerate(zip(kernel_sizes, strides))
+ ]
+ )
+
+ # output adversarial projection
+ self._postnet = torch.nn.Conv2d(
+ in_channels=32,
+ out_channels=1,
+ kernel_size=(15, 3),
+ stride=(1, 2),
+ )
+
+    def forward(self, x: torch.Tensor) -> T.List[T.Tuple[T.List[torch.Tensor], torch.Tensor]]:
+ # apply the log-scale mel spectrogram transform
+ x = torch.log(self._melspec(x) + 1e-5)
+
+ # compute hidden layers and feature maps
+ f = []
+ for c in self._convs:
+ x = c(x)
+ f.append(x)
+
+ # apply the output projection and global average pooling
+ x = self._postnet(x)
+ x = x.mean(dim=[-2, -1])
+
+ return [(f, x)]
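
A sketch of how the mel-spectrogram discriminator above is driven; the 48 kHz expectation comes from its hard-coded SAMPLE_RATE, and the one-second input length is arbitrary.

```python
import torch
from vits_decoder.med import MelspecDiscriminator

disc = MelspecDiscriminator()
wav = torch.randn(2, 1, 48000)     # (B, 1, T) waveform at 48 kHz
(features, score), = disc(wav)     # same [(feature_maps, score)] convention as MPD/MRD/MSD
print(len(features), score.shape)  # 4 torch.Size([2, 1])
```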
diff --git a/vits_decoder/mpd.py b/vits_decoder/mpd.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dc63e859dd2920f9d02b285ebc4dae8cf318d6a
--- /dev/null
+++ b/vits_decoder/mpd.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import weight_norm, spectral_norm
+
+class DiscriminatorP(nn.Module):
+ def __init__(self, hp, period):
+ super(DiscriminatorP, self).__init__()
+
+ self.LRELU_SLOPE = hp.mpd.lReLU_slope
+ self.period = period
+
+ kernel_size = hp.mpd.kernel_size
+ stride = hp.mpd.stride
+ norm_f = weight_norm if hp.mpd.use_spectral_norm == False else spectral_norm
+
+ self.convs = nn.ModuleList([
+ norm_f(nn.Conv2d(1, 64, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ norm_f(nn.Conv2d(64, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ norm_f(nn.Conv2d(128, 256, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ norm_f(nn.Conv2d(256, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), 1, padding=(kernel_size // 2, 0))),
+ ])
+ self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, self.LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return fmap, x
+
+
+class MultiPeriodDiscriminator(nn.Module):
+ def __init__(self, hp):
+ super(MultiPeriodDiscriminator, self).__init__()
+
+ self.discriminators = nn.ModuleList(
+ [DiscriminatorP(hp, period) for period in hp.mpd.periods]
+ )
+
+ def forward(self, x):
+ ret = list()
+ for disc in self.discriminators:
+ ret.append(disc(x))
+
+ return ret # [(feat, score), (feat, score), (feat, score), (feat, score), (feat, score)]
diff --git a/vits_decoder/mrd.py b/vits_decoder/mrd.py
new file mode 100644
index 0000000000000000000000000000000000000000..da6db1a416366603d2e65b400d66c44262e2baef
--- /dev/null
+++ b/vits_decoder/mrd.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import weight_norm, spectral_norm
+
+class DiscriminatorR(torch.nn.Module):
+ def __init__(self, hp, resolution):
+ super(DiscriminatorR, self).__init__()
+
+ self.resolution = resolution
+ self.LRELU_SLOPE = hp.mpd.lReLU_slope
+
+ norm_f = weight_norm if hp.mrd.use_spectral_norm == False else spectral_norm
+
+ self.convs = nn.ModuleList([
+ norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+ norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
+ ])
+ self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.spectrogram(x)
+ x = x.unsqueeze(1)
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, self.LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return fmap, x
+
+ def spectrogram(self, x):
+ n_fft, hop_length, win_length = self.resolution
+ x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
+ x = x.squeeze(1)
+ x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=False) #[B, F, TT, 2]
+        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
+
+ return mag
+
+
+class MultiResolutionDiscriminator(torch.nn.Module):
+ def __init__(self, hp):
+ super(MultiResolutionDiscriminator, self).__init__()
+ self.resolutions = eval(hp.mrd.resolutions)
+ self.discriminators = nn.ModuleList(
+ [DiscriminatorR(hp, resolution) for resolution in self.resolutions]
+ )
+
+ def forward(self, x):
+ ret = list()
+ for disc in self.discriminators:
+ ret.append(disc(x))
+
+ return ret # [(feat, score), (feat, score), (feat, score)]
diff --git a/vits_decoder/msd.py b/vits_decoder/msd.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e254fa3f1b53368332751a3e7235e93297c44c3
--- /dev/null
+++ b/vits_decoder/msd.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import weight_norm
+
+
+class ScaleDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super(ScaleDiscriminator, self).__init__()
+ self.convs = nn.ModuleList([
+ weight_norm(nn.Conv1d(1, 16, 15, 1, padding=7)),
+ weight_norm(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+ weight_norm(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+ weight_norm(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+ weight_norm(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+ weight_norm(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
+ ])
+ self.conv_post = weight_norm(nn.Conv1d(1024, 1, 3, 1, padding=1))
+
+ def forward(self, x):
+ fmap = []
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, 0.1)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+ return [(fmap, x)]
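
A sketch of the waveform-scale discriminator above; the 16384-sample input matches the length used in discriminator.py's self-test, the batch size is arbitrary.

```python
import torch
from vits_decoder.msd import ScaleDiscriminator

msd = ScaleDiscriminator()
wav = torch.randn(2, 1, 16384)  # (B, 1, T) raw waveform
(fmap, score), = msd(wav)       # single (feature_maps, score) pair wrapped in a list
print(len(fmap), score.shape)   # 7 torch.Size([2, 64])
```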
diff --git a/vits_decoder/nsf.py b/vits_decoder/nsf.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e9e6c7e344eb616a7ca427da1a02a2c2093c942
--- /dev/null
+++ b/vits_decoder/nsf.py
@@ -0,0 +1,394 @@
+import torch
+import numpy as np
+import sys
+import torch.nn.functional as torch_nn_func
+
+
+class PulseGen(torch.nn.Module):
+ """Definition of Pulse train generator
+
+    There are many ways to implement a pulse-train generator.
+    Here, PulseGen is built on top of SineGen.
+    """
+
+ def __init__(self, samp_rate, pulse_amp=0.1, noise_std=0.003, voiced_threshold=0):
+ super(PulseGen, self).__init__()
+ self.pulse_amp = pulse_amp
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+ self.noise_std = noise_std
+ self.l_sinegen = SineGen(
+ self.sampling_rate,
+ harmonic_num=0,
+ sine_amp=self.pulse_amp,
+ noise_std=0,
+ voiced_threshold=self.voiced_threshold,
+ flag_for_pulse=True,
+ )
+
+ def forward(self, f0):
+ """Pulse train generator
+        pulse_train, sine_wav, uv, noise = forward(f0)
+ input F0: tensor(batchsize=1, length, dim=1)
+ f0 for unvoiced steps should be 0
+ output pulse_train: tensor(batchsize=1, length, dim)
+ output uv: tensor(batchsize=1, length, 1)
+
+        Note: self.l_sinegen doesn't make sure that the initial phase of
+        a voiced segment is np.pi, so the first pulse in a voiced segment
+        may not be at the first time step within that segment
+ """
+ with torch.no_grad():
+ sine_wav, uv, noise = self.l_sinegen(f0)
+
+ # sine without additive noise
+ pure_sine = sine_wav - noise
+
+ # step t corresponds to a pulse if
+ # sine[t] > sine[t+1] & sine[t] > sine[t-1]
+ # & sine[t-1], sine[t+1], and sine[t] are voiced
+ # or
+ # sine[t] is voiced, sine[t-1] is unvoiced
+ # we use torch.roll to simulate sine[t+1] and sine[t-1]
+ sine_1 = torch.roll(pure_sine, shifts=1, dims=1)
+ uv_1 = torch.roll(uv, shifts=1, dims=1)
+ uv_1[:, 0, :] = 0
+ sine_2 = torch.roll(pure_sine, shifts=-1, dims=1)
+ uv_2 = torch.roll(uv, shifts=-1, dims=1)
+ uv_2[:, -1, :] = 0
+
+ loc = (pure_sine > sine_1) * (pure_sine > sine_2) \
+ * (uv_1 > 0) * (uv_2 > 0) * (uv > 0) \
+ + (uv_1 < 1) * (uv > 0)
+
+ # pulse train without noise
+ pulse_train = pure_sine * loc
+
+ # additive noise to pulse train
+ # note that noise from sinegen is zero in voiced regions
+ pulse_noise = torch.randn_like(pure_sine) * self.noise_std
+
+ # with additive noise on pulse, and unvoiced regions
+ pulse_train += pulse_noise * loc + pulse_noise * (1 - uv)
+ return pulse_train, sine_wav, uv, pulse_noise
+
+
+class SignalsConv1d(torch.nn.Module):
+ """Filtering input signal with time invariant filter
+ Note: FIRFilter conducted filtering given fixed FIR weight
+ SignalsConv1d convolves two signals
+ Note: this is based on torch.nn.functional.conv1d
+
+ """
+
+ def __init__(self):
+ super(SignalsConv1d, self).__init__()
+
+ def forward(self, signal, system_ir):
+ """output = forward(signal, system_ir)
+
+ signal: (batchsize, length1, dim)
+ system_ir: (length2, dim)
+
+ output: (batchsize, length1, dim)
+ """
+ if signal.shape[-1] != system_ir.shape[-1]:
+ print("Error: SignalsConv1d expects shape:")
+ print("signal (batchsize, length1, dim)")
+ print("system_id (batchsize, length2, dim)")
+ print("But received signal: {:s}".format(str(signal.shape)))
+ print(" system_ir: {:s}".format(str(system_ir.shape)))
+ sys.exit(1)
+ padding_length = system_ir.shape[0] - 1
+ groups = signal.shape[-1]
+
+ # pad signal on the left
+ signal_pad = torch_nn_func.pad(signal.permute(0, 2, 1), (padding_length, 0))
+ # prepare system impulse response as (dim, 1, length2)
+ # also flip the impulse response
+ ir = torch.flip(system_ir.unsqueeze(1).permute(2, 1, 0), dims=[2])
+        # convolve the padded signal with the flipped impulse response
+ output = torch_nn_func.conv1d(signal_pad, ir, groups=groups)
+ return output.permute(0, 2, 1)
+
+
+class CyclicNoiseGen_v1(torch.nn.Module):
+ """CyclicnoiseGen_v1
+ Cyclic noise with a single parameter of beta.
+ Pytorch v1 implementation assumes f_t is also fixed
+ """
+
+ def __init__(self, samp_rate, noise_std=0.003, voiced_threshold=0):
+ super(CyclicNoiseGen_v1, self).__init__()
+ self.samp_rate = samp_rate
+ self.noise_std = noise_std
+ self.voiced_threshold = voiced_threshold
+
+ self.l_pulse = PulseGen(
+ samp_rate,
+ pulse_amp=1.0,
+ noise_std=noise_std,
+ voiced_threshold=voiced_threshold,
+ )
+ self.l_conv = SignalsConv1d()
+
+ def noise_decay(self, beta, f0mean):
+ """decayed_noise = noise_decay(beta, f0mean)
+ decayed_noise = n[t]exp(-t * f_mean / beta / samp_rate)
+
+ beta: (dim=1) or (batchsize=1, 1, dim=1)
+ f0mean (batchsize=1, 1, dim=1)
+
+ decayed_noise (batchsize=1, length, dim=1)
+ """
+ with torch.no_grad():
+ # exp(-1.0 n / T) < 0.01 => n > -log(0.01)*T = 4.60*T
+ # truncate the noise when decayed by -40 dB
+ length = 4.6 * self.samp_rate / f0mean
+ length = length.int()
+ time_idx = torch.arange(0, length, device=beta.device)
+ time_idx = time_idx.unsqueeze(0).unsqueeze(2)
+ time_idx = time_idx.repeat(beta.shape[0], 1, beta.shape[2])
+
+ noise = torch.randn(time_idx.shape, device=beta.device)
+
+ # due to Pytorch implementation, use f0_mean as the f0 factor
+ decay = torch.exp(-time_idx * f0mean / beta / self.samp_rate)
+ return noise * self.noise_std * decay
+
+ def forward(self, f0s, beta):
+ """Producde cyclic-noise"""
+ # pulse train
+ pulse_train, sine_wav, uv, noise = self.l_pulse(f0s)
+ pure_pulse = pulse_train - noise
+
+ # decayed_noise (length, dim=1)
+ if (uv < 1).all():
+ # all unvoiced
+ cyc_noise = torch.zeros_like(sine_wav)
+ else:
+ f0mean = f0s[uv > 0].mean()
+
+ decayed_noise = self.noise_decay(beta, f0mean)[0, :, :]
+            # convolve the pulse train with the decayed noise
+            cyc_noise = self.l_conv(pure_pulse, decayed_noise)
+
+        # add noise in unvoiced segments
+ cyc_noise = cyc_noise + noise * (1.0 - uv)
+ return cyc_noise, pulse_train, sine_wav, uv, noise
+
+
+class SineGen(torch.nn.Module):
+ """Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+    sine_amp: amplitude of sine waveform (default 0.1)
+    noise_std: std of Gaussian noise (default 0.003)
+    voiced_threshold: F0 threshold for U/V classification (default 0)
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
+
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(
+ self,
+ samp_rate,
+ harmonic_num=0,
+ sine_amp=0.1,
+ noise_std=0.003,
+ voiced_threshold=0,
+ flag_for_pulse=False,
+ ):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.dim = self.harmonic_num + 1
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+ self.flag_for_pulse = flag_for_pulse
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = torch.ones_like(f0)
+ uv = uv * (f0 > self.voiced_threshold)
+ return uv
+
+ def _f02sine(self, f0_values):
+ """f0_values: (batchsize, length, dim)
+ where dim indicates fundamental tone and overtones
+ """
+        # convert to F0 in rad. The integer part n can be ignored
+ # because 2 * np.pi * n doesn't affect phase
+ rad_values = (f0_values / self.sampling_rate) % 1
+
+ # initial phase noise (no noise for fundamental component)
+ rand_ini = torch.rand(
+ f0_values.shape[0], f0_values.shape[2], device=f0_values.device
+ )
+ rand_ini[:, 0] = 0
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+        # instantaneous phase sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
+ if not self.flag_for_pulse:
+ # for normal case
+
+ # To prevent torch.cumsum numerical overflow,
+ # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+ # Buffer tmp_over_one_idx indicates the time step to add -1.
+ # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+ tmp_over_one = torch.cumsum(rad_values, 1) % 1
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
+ cumsum_shift = torch.zeros_like(rad_values)
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+ sines = torch.sin(
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
+ )
+ else:
+ # If necessary, make sure that the first time step of every
+ # voiced segments is sin(pi) or cos(0)
+ # This is used for pulse-train generation
+
+ # identify the last time step in unvoiced segments
+ uv = self._f02uv(f0_values)
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
+ uv_1[:, -1, :] = 1
+ u_loc = (uv < 1) * (uv_1 > 0)
+
+            # get the instantaneous phase
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
+ # different batch needs to be processed differently
+ for idx in range(f0_values.shape[0]):
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+ # stores the accumulation of i.phase within
+ # each voiced segments
+ tmp_cumsum[idx, :, :] = 0
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
+ # within the previous voiced segment.
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+ # get the sines
+ sines = torch.cos(i_phase * 2 * np.pi)
+ return sines
+
+ def forward(self, f0):
+ """sine_tensor, uv = forward(f0)
+ input F0: tensor(batchsize=1, length, dim=1)
+ f0 for unvoiced steps should be 0
+ output sine_tensor: tensor(batchsize=1, length, dim)
+ output uv: tensor(batchsize=1, length, 1)
+ """
+ with torch.no_grad():
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
+ # fundamental component
+ f0_buf[:, :, 0] = f0[:, :, 0]
+ for idx in np.arange(self.harmonic_num):
+ # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
+
+ # generate sine waveforms
+ sine_waves = self._f02sine(f0_buf) * self.sine_amp
+
+ # generate uv signal
+ # uv = torch.ones(f0.shape)
+ # uv = uv * (f0 > self.voiced_threshold)
+ uv = self._f02uv(f0)
+
+            # noise: for unvoiced frames the noise std is sine_amp / 3
+            #        (so its max value is roughly sine_amp);
+            #        for voiced frames the noise std is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves
+
+
+class SourceModuleCycNoise_v1(torch.nn.Module):
+ """SourceModuleCycNoise_v1
+ SourceModule(sampling_rate, noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+
+ noise_std: std of Gaussian noise (default: 0.003)
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
+
+ cyc, noise, uv = SourceModuleCycNoise_v1(F0_upsampled, beta)
+ F0_upsampled (batchsize, length, 1)
+ beta (1)
+ cyc (batchsize, length, 1)
+ noise (batchsize, length, 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleCycNoise_v1, self).__init__()
+ self.sampling_rate = sampling_rate
+ self.noise_std = noise_std
+ self.l_cyc_gen = CyclicNoiseGen_v1(sampling_rate, noise_std, voiced_threshod)
+
+ def forward(self, f0_upsamped, beta):
+ """
+ cyc, noise, uv = SourceModuleCycNoise_v1(F0, beta)
+ F0_upsampled (batchsize, length, 1)
+ beta (1)
+ cyc (batchsize, length, 1)
+ noise (batchsize, length, 1)
+ uv (batchsize, length, 1)
+ """
+ # source for harmonic branch
+ cyc, pulse, sine, uv, add_noi = self.l_cyc_gen(f0_upsamped, beta)
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.noise_std / 3
+ return cyc, noise, uv
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+ def __init__(
+ self,
+ sampling_rate=32000,
+ sine_amp=0.1,
+ add_noise_std=0.003,
+ voiced_threshod=0,
+ ):
+ super(SourceModuleHnNSF, self).__init__()
+ harmonic_num = 10
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen(
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
+ )
+
+ # to merge source harmonics into a single excitation
+ self.l_tanh = torch.nn.Tanh()
+ self.register_buffer('merge_w', torch.FloatTensor([[
+ 0.2942, -0.2243, 0.0033, -0.0056, -0.0020, -0.0046,
+ 0.0221, -0.0083, -0.0241, -0.0036, -0.0581]]))
+ self.register_buffer('merge_b', torch.FloatTensor([0.0008]))
+
+ def forward(self, x):
+ """
+ Sine_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ """
+ # source for harmonic branch
+ sine_wavs = self.l_sin_gen(x)
+ sine_wavs = torch_nn_func.linear(
+ sine_wavs, self.merge_w) + self.merge_b
+ sine_merge = self.l_tanh(sine_wavs)
+ return sine_merge
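
A sketch of the harmonic source module above; 32 kHz is its default rate, and one second of a constant 220 Hz pitch is used purely for illustration.

```python
import torch
from vits_decoder.nsf import SourceModuleHnNSF

source = SourceModuleHnNSF(sampling_rate=32000)

# F0 must already be upsampled to the waveform rate and shaped (B, T, 1);
# unvoiced samples are marked with 0 Hz.
f0 = torch.full((1, 32000, 1), 220.0)
excitation = source(f0)  # fundamental plus overtones merged into one excitation channel
print(excitation.shape)  # torch.Size([1, 32000, 1])
```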
diff --git a/vits_extend/__init__.py b/vits_extend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vits_extend/dataloader.py b/vits_extend/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f26fe0e15f719b6594110799f3863e720377150
--- /dev/null
+++ b/vits_extend/dataloader.py
@@ -0,0 +1,38 @@
+from torch.utils.data import DataLoader
+from vits.data_utils import DistributedBucketSampler
+from vits.data_utils import TextAudioSpeakerCollate
+from vits.data_utils import TextAudioSpeakerSet
+
+
+def create_dataloader_train(hps, n_gpus, rank):
+ collate_fn = TextAudioSpeakerCollate()
+ train_dataset = TextAudioSpeakerSet(hps.data.training_files, hps.data)
+ train_sampler = DistributedBucketSampler(
+ train_dataset,
+ hps.train.batch_size,
+ [150, 300, 450],
+ num_replicas=n_gpus,
+ rank=rank,
+ shuffle=True)
+ train_loader = DataLoader(
+ train_dataset,
+ num_workers=4,
+ shuffle=False,
+ pin_memory=True,
+ collate_fn=collate_fn,
+ batch_sampler=train_sampler)
+ return train_loader
+
+
+def create_dataloader_eval(hps):
+ collate_fn = TextAudioSpeakerCollate()
+ eval_dataset = TextAudioSpeakerSet(hps.data.validation_files, hps.data)
+ eval_loader = DataLoader(
+ eval_dataset,
+ num_workers=2,
+ shuffle=False,
+ batch_size=hps.train.batch_size,
+ pin_memory=True,
+ drop_last=False,
+ collate_fn=collate_fn)
+ return eval_loader
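+
+
+# Usage sketch (illustrative only): `hps` is assumed to be the parsed
+# hyper-parameter object used by the training scripts, with
+# hps.data.training_files, hps.data.validation_files and hps.train.batch_size
+# set. Each batch is the 9-tuple produced by TextAudioSpeakerCollate.
+#
+#   train_loader = create_dataloader_train(hps, n_gpus=1, rank=0)
+#   eval_loader = create_dataloader_eval(hps)
+#   for ppg, ppg_l, vec, pit, spk, spec, spec_l, audio, audio_l in train_loader:
+#       ...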
diff --git a/vits_extend/plotting.py b/vits_extend/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..89ff909af85f1ab8788a8047abe8434844a8e16c
--- /dev/null
+++ b/vits_extend/plotting.py
@@ -0,0 +1,49 @@
+import logging
+mpl_logger = logging.getLogger('matplotlib')  # must be set before importing matplotlib
+mpl_logger.setLevel(logging.WARNING)
+import matplotlib
+matplotlib.use("Agg")
+
+import numpy as np
+import matplotlib.pylab as plt
+
+
+def save_figure_to_numpy(fig):
+ # render the figure canvas into a numpy array (np.frombuffer replaces the deprecated np.fromstring)
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ data = np.transpose(data, (2, 0, 1))
+ return data
+
+
+def plot_waveform_to_numpy(waveform):
+ fig, ax = plt.subplots(figsize=(12, 4))
+ ax.plot(range(len(waveform)), waveform,
+ linewidth=0.1, alpha=0.7, color='blue')
+
+ plt.xlabel("Samples")
+ plt.ylabel("Amplitude")
+ plt.ylim(-1, 1)
+ plt.tight_layout()
+
+ fig.canvas.draw()
+ data = save_figure_to_numpy(fig)
+ plt.close()
+
+ return data
+
+
+def plot_spectrogram_to_numpy(spectrogram):
+ fig, ax = plt.subplots(figsize=(12, 4))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+ interpolation='none')
+ plt.colorbar(im, ax=ax)
+ plt.xlabel("Frames")
+ plt.ylabel("Channels")
+ plt.tight_layout()
+
+ fig.canvas.draw()
+ data = save_figure_to_numpy(fig)
+ plt.close()
+ return data
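+
+
+# Minimal demo appended for illustration: both helpers return (3, H, W) uint8
+# arrays, which is the layout SummaryWriter.add_image expects.
+if __name__ == "__main__":
+    wave = 0.5 * np.sin(np.linspace(0, 100, 16000))
+    print(plot_waveform_to_numpy(wave).shape)      # e.g. (3, 400, 1200)
+    spec = np.abs(np.random.randn(80, 200))
+    print(plot_spectrogram_to_numpy(spec).shape)   # e.g. (3, 400, 1200)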
diff --git a/vits_extend/stft.py b/vits_extend/stft.py
new file mode 100644
index 0000000000000000000000000000000000000000..9510305ffa19528c80380f1e30bb71e38e9fbcf8
--- /dev/null
+++ b/vits_extend/stft.py
@@ -0,0 +1,104 @@
+# MIT License
+#
+# Copyright (c) 2020 Jungil Kong
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import math
+import os
+import random
+import torch
+import torch.utils.data
+import numpy as np
+from librosa.util import normalize
+from scipy.io.wavfile import read
+from librosa.filters import mel as librosa_mel_fn
+
+
+class TacotronSTFT(torch.nn.Module):
+ def __init__(self, filter_length=512, hop_length=160, win_length=512,
+ n_mel_channels=80, sampling_rate=16000, mel_fmin=0.0,
+ mel_fmax=None, center=False, device='cpu'):
+ super(TacotronSTFT, self).__init__()
+ self.n_mel_channels = n_mel_channels
+ self.sampling_rate = sampling_rate
+ self.n_fft = filter_length
+ self.hop_size = hop_length
+ self.win_size = win_length
+ self.fmin = mel_fmin
+ self.fmax = mel_fmax
+ self.center = center
+
+ mel = librosa_mel_fn(
+ sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax)
+
+ mel_basis = torch.from_numpy(mel).float().to(device)
+ hann_window = torch.hann_window(win_length).to(device)
+
+ self.register_buffer('mel_basis', mel_basis)
+ self.register_buffer('hann_window', hann_window)
+
+ def linear_spectrogram(self, y):
+ assert (torch.min(y.data) >= -1)
+ assert (torch.max(y.data) <= 1)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1),
+ (int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)),
+ mode='reflect')
+ y = y.squeeze(1)
+ spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
+ center=self.center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+ spec = torch.norm(spec, p=2, dim=-1)
+
+ return spec
+
+ def mel_spectrogram(self, y):
+ """Computes mel-spectrograms from a batch of waves
+ PARAMS
+ ------
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+ RETURNS
+ -------
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+ """
+ assert(torch.min(y.data) >= -1)
+ assert(torch.max(y.data) <= 1)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1),
+ (int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)),
+ mode='reflect')
+ y = y.squeeze(1)
+
+ spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
+ center=self.center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+ spec = torch.matmul(self.mel_basis, spec)
+ spec = self.spectral_normalize_torch(spec)
+
+ return spec
+
+ def spectral_normalize_torch(self, magnitudes):
+ output = self.dynamic_range_compression_torch(magnitudes)
+ return output
+
+ def dynamic_range_compression_torch(self, x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
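+
+
+# Quick sanity check appended for illustration: the defaults mirror the 16 kHz
+# configuration used in this repo (n_fft 512, hop 160, win 512, 80 mel bins).
+if __name__ == "__main__":
+    stft = TacotronSTFT()
+    wav = torch.rand(1, 16000) * 2 - 1             # (B, T) in [-1, 1]
+    print(stft.linear_spectrogram(wav).shape)      # (1, n_fft // 2 + 1, frames)
+    print(stft.mel_spectrogram(wav).shape)         # (1, 80, frames)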
diff --git a/vits_extend/stft_loss.py b/vits_extend/stft_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed672b0000b993067668413f2dc6562ae8febdeb
--- /dev/null
+++ b/vits_extend/stft_loss.py
@@ -0,0 +1,133 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""STFT-based Loss modules."""
+
+import torch
+import torch.nn.functional as F
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+ """Perform STFT and convert to magnitude spectrogram.
+ Args:
+ x (Tensor): Input signal tensor (B, T).
+ fft_size (int): FFT size.
+ hop_size (int): Hop size.
+ win_length (int): Window length.
+ window (str): Window function type.
+ Returns:
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+ """
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False)
+ real = x_stft[..., 0]
+ imag = x_stft[..., 1]
+
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+ return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
+
+
+class SpectralConvergengeLoss(torch.nn.Module):
+ """Spectral convergence loss module."""
+
+ def __init__(self):
+ """Initilize spectral convergence loss module."""
+ super(SpectralConvergengeLoss, self).__init__()
+
+ def forward(self, x_mag, y_mag):
+ """Calculate forward propagation.
+ Args:
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+ Returns:
+ Tensor: Spectral convergence loss value.
+ """
+ return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+
+
+class LogSTFTMagnitudeLoss(torch.nn.Module):
+ """Log STFT magnitude loss module."""
+
+ def __init__(self):
+ """Initilize los STFT magnitude loss module."""
+ super(LogSTFTMagnitudeLoss, self).__init__()
+
+ def forward(self, x_mag, y_mag):
+ """Calculate forward propagation.
+ Args:
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+ Returns:
+ Tensor: Log STFT magnitude loss value.
+ """
+ return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
+
+
+class STFTLoss(torch.nn.Module):
+ """STFT loss module."""
+
+ def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
+ """Initialize STFT loss module."""
+ super(STFTLoss, self).__init__()
+ self.fft_size = fft_size
+ self.shift_size = shift_size
+ self.win_length = win_length
+ self.window = getattr(torch, window)(win_length).to(device)
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+
+ def forward(self, x, y):
+ """Calculate forward propagation.
+ Args:
+ x (Tensor): Predicted signal (B, T).
+ y (Tensor): Groundtruth signal (B, T).
+ Returns:
+ Tensor: Spectral convergence loss value.
+ Tensor: Log STFT magnitude loss value.
+ """
+ x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
+ y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+ return sc_loss, mag_loss
+
+
+class MultiResolutionSTFTLoss(torch.nn.Module):
+ """Multi resolution STFT loss module."""
+
+ def __init__(self,
+ device,
+ resolutions,
+ window="hann_window"):
+ """Initialize Multi resolution STFT loss module.
+ Args:
+ resolutions (list): List of (FFT size, hop size, window length).
+ window (str): Window function type.
+ """
+ super(MultiResolutionSTFTLoss, self).__init__()
+ self.stft_losses = torch.nn.ModuleList()
+ for fs, ss, wl in resolutions:
+ self.stft_losses += [STFTLoss(device, fs, ss, wl, window)]
+
+ def forward(self, x, y):
+ """Calculate forward propagation.
+ Args:
+ x (Tensor): Predicted signal (B, T).
+ y (Tensor): Groundtruth signal (B, T).
+ Returns:
+ Tensor: Multi resolution spectral convergence loss value.
+ Tensor: Multi resolution log STFT magnitude loss value.
+ """
+ sc_loss = 0.0
+ mag_loss = 0.0
+ for f in self.stft_losses:
+ sc_l, mag_l = f(x, y)
+ sc_loss += sc_l
+ mag_loss += mag_l
+
+ sc_loss /= len(self.stft_losses)
+ mag_loss /= len(self.stft_losses)
+
+ return sc_loss, mag_loss
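+
+
+# Usage sketch appended for illustration: the (fft_size, hop_size, win_length)
+# triples below are common multi-resolution settings; the values actually used
+# for training come from hp.mrd.resolutions in the config.
+if __name__ == "__main__":
+    device = torch.device("cpu")
+    resolutions = [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]
+    criterion = MultiResolutionSTFTLoss(device, resolutions)
+    fake, real = torch.randn(2, 16000), torch.randn(2, 16000)
+    sc_loss, mag_loss = criterion(fake, real)
+    print(sc_loss.item(), mag_loss.item())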
diff --git a/vits_extend/train.py b/vits_extend/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a93c294476e70262fa3e399c74c60698157b13c
--- /dev/null
+++ b/vits_extend/train.py
@@ -0,0 +1,312 @@
+import os
+import time
+import logging
+import math
+import re
+import tqdm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed import init_process_group
+from torch.nn.parallel import DistributedDataParallel
+
+from vits_extend.dataloader import create_dataloader_train
+from vits_extend.dataloader import create_dataloader_eval
+from vits_extend.writer import MyWriter
+from vits_extend.stft import TacotronSTFT
+from vits_extend.stft_loss import MultiResolutionSTFTLoss
+from vits_extend.validation import validate
+from vits_decoder.discriminator import Discriminator
+from vits.models import SynthesizerTrn
+from vits import commons
+from vits.losses import kl_loss
+from vits.commons import clip_grad_value_
+
+
+def load_part(model, saved_state_dict):
+ if hasattr(model, 'module'):
+ state_dict = model.module.state_dict()
+ else:
+ state_dict = model.state_dict()
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith('TODO'):
+ new_state_dict[k] = v
+ else:
+ new_state_dict[k] = saved_state_dict[k]
+ if hasattr(model, 'module'):
+ model.module.load_state_dict(new_state_dict)
+ else:
+ model.load_state_dict(new_state_dict)
+ return model
+
+
+def load_model(model, saved_state_dict):
+ if hasattr(model, 'module'):
+ state_dict = model.module.state_dict()
+ else:
+ state_dict = model.state_dict()
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ try:
+ new_state_dict[k] = saved_state_dict[k]
+ except KeyError:
+ print("%s is not in the checkpoint" % k)
+ new_state_dict[k] = v
+ if hasattr(model, 'module'):
+ model.module.load_state_dict(new_state_dict)
+ else:
+ model.load_state_dict(new_state_dict)
+ return model
+
+
+def train(rank, args, chkpt_path, hp, hp_str):
+
+ if args.num_gpus > 1:
+ init_process_group(backend=hp.dist_config.dist_backend, init_method=hp.dist_config.dist_url,
+ world_size=hp.dist_config.world_size * args.num_gpus, rank=rank)
+
+ torch.cuda.manual_seed(hp.train.seed)
+ device = torch.device('cuda:{:d}'.format(rank))
+
+ model_g = SynthesizerTrn(
+ hp.data.filter_length // 2 + 1,
+ hp.data.segment_size // hp.data.hop_length,
+ hp).to(device)
+ model_d = Discriminator(hp).to(device)
+
+ optim_g = torch.optim.AdamW(model_g.parameters(),
+ lr=hp.train.learning_rate, betas=hp.train.betas, eps=hp.train.eps)
+ optim_d = torch.optim.AdamW(model_d.parameters(),
+ lr=(hp.train.learning_rate / hp.train.accum_step), betas=hp.train.betas, eps=hp.train.eps)
+
+ init_epoch = 1
+ step = 0
+
+ stft = TacotronSTFT(filter_length=hp.data.filter_length,
+ hop_length=hp.data.hop_length,
+ win_length=hp.data.win_length,
+ n_mel_channels=hp.data.mel_channels,
+ sampling_rate=hp.data.sampling_rate,
+ mel_fmin=hp.data.mel_fmin,
+ mel_fmax=hp.data.mel_fmax,
+ center=False,
+ device=device)
+ # define logger, writer, valloader, stft at rank_zero
+ if rank == 0:
+ pth_dir = os.path.join(hp.log.pth_dir, args.name)
+ log_dir = os.path.join(hp.log.log_dir, args.name)
+ os.makedirs(pth_dir, exist_ok=True)
+ os.makedirs(log_dir, exist_ok=True)
+
+ logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(levelname)s - %(message)s',
+ handlers=[
+ logging.FileHandler(os.path.join(log_dir, '%s-%d.log' % (args.name, time.time()))),
+ logging.StreamHandler()
+ ]
+ )
+ logger = logging.getLogger()
+ writer = MyWriter(hp, log_dir)
+ valloader = create_dataloader_eval(hp)
+
+ if os.path.isfile(hp.train.pretrain):
+ if rank == 0:
+ logger.info("Start from 32k pretrain model: %s" % hp.train.pretrain)
+ checkpoint = torch.load(hp.train.pretrain, map_location='cpu')
+ load_model(model_g, checkpoint['model_g'])
+ load_model(model_d, checkpoint['model_d'])
+
+ if chkpt_path is not None:
+ if rank == 0:
+ logger.info("Resuming from checkpoint: %s" % chkpt_path)
+ checkpoint = torch.load(chkpt_path, map_location='cpu')
+ load_model(model_g, checkpoint['model_g'])
+ load_model(model_d, checkpoint['model_d'])
+ optim_g.load_state_dict(checkpoint['optim_g'])
+ optim_d.load_state_dict(checkpoint['optim_d'])
+ init_epoch = checkpoint['epoch']
+ step = checkpoint['step']
+
+ if rank == 0:
+ if hp_str != checkpoint['hp_str']:
+ logger.warning("New hparams is different from checkpoint. Will use new.")
+ else:
+ if rank == 0:
+ logger.info("Starting new training run.")
+
+ if args.num_gpus > 1:
+ model_g = DistributedDataParallel(model_g, device_ids=[rank])
+ model_d = DistributedDataParallel(model_d, device_ids=[rank])
+
+ # cudnn benchmark accelerates training when the minibatch size stays consistent;
+ # if it varies, it can slow training down considerably.
+ torch.backends.cudnn.benchmark = True
+
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hp.train.lr_decay, last_epoch=init_epoch-2)
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hp.train.lr_decay, last_epoch=init_epoch-2)
+
+ stft_criterion = MultiResolutionSTFTLoss(device, eval(hp.mrd.resolutions))
+ spkc_criterion = nn.CosineEmbeddingLoss()
+
+ trainloader = create_dataloader_train(hp, args.num_gpus, rank)
+
+ for epoch in range(init_epoch, hp.train.epochs):
+
+ trainloader.batch_sampler.set_epoch(epoch)
+
+ if rank == 0 and epoch % hp.log.eval_interval == 0:
+ with torch.no_grad():
+ validate(hp, args, model_g, model_d, valloader, stft, writer, step, device)
+
+ if rank == 0:
+ loader = tqdm.tqdm(trainloader, desc='Loading train data')
+ else:
+ loader = trainloader
+
+ model_g.train()
+ model_d.train()
+
+ for ppg, ppg_l, vec, pit, spk, spec, spec_l, audio, audio_l in loader:
+
+ ppg = ppg.to(device)
+ vec = vec.to(device)
+ pit = pit.to(device)
+ spk = spk.to(device)
+ spec = spec.to(device)
+ audio = audio.to(device)
+ ppg_l = ppg_l.to(device)
+ spec_l = spec_l.to(device)
+ audio_l = audio_l.to(device)
+
+ # generator
+ fake_audio, ids_slice, z_mask, \
+ (z_f, z_r, z_p, m_p, logs_p, z_q, m_q, logs_q, logdet_f, logdet_r), spk_preds = model_g(
+ ppg, vec, pit, spec, spk, ppg_l, spec_l)
+
+ audio = commons.slice_segments(
+ audio, ids_slice * hp.data.hop_length, hp.data.segment_size) # slice
+ # Spk Loss
+ spk_loss = spkc_criterion(spk, spk_preds, torch.Tensor(spk_preds.size(0))
+ .to(device).fill_(1.0))
+ # Mel Loss
+ mel_fake = stft.mel_spectrogram(fake_audio.squeeze(1))
+ mel_real = stft.mel_spectrogram(audio.squeeze(1))
+ mel_loss = F.l1_loss(mel_fake, mel_real) * hp.train.c_mel
+
+ # Multi-Resolution STFT Loss
+ sc_loss, mag_loss = stft_criterion(fake_audio.squeeze(1), audio.squeeze(1))
+ stft_loss = (sc_loss + mag_loss) * hp.train.c_stft
+
+ # Generator Loss
+ disc_fake = model_d(fake_audio)
+ score_loss = 0.0
+ for (_, score_fake) in disc_fake:
+ score_loss += torch.mean(torch.pow(score_fake - 1.0, 2))
+ score_loss = score_loss / len(disc_fake)
+
+ # Feature Loss
+ disc_real = model_d(audio)
+ feat_loss = 0.0
+ for (feat_fake, _), (feat_real, _) in zip(disc_fake, disc_real):
+ for fake, real in zip(feat_fake, feat_real):
+ feat_loss += torch.mean(torch.abs(fake - real))
+ feat_loss = feat_loss / len(disc_fake)
+ feat_loss = feat_loss * 2
+
+ # Kl Loss
+ loss_kl_f = kl_loss(z_f, logs_q, m_p, logs_p, logdet_f, z_mask) * hp.train.c_kl
+ loss_kl_r = kl_loss(z_r, logs_p, m_q, logs_q, logdet_r, z_mask) * hp.train.c_kl
+
+ # Loss
+ loss_g = score_loss + feat_loss + mel_loss + stft_loss + loss_kl_f + loss_kl_r * 0.5 + spk_loss * 2
+ loss_g.backward()
+
+ if ((step + 1) % hp.train.accum_step == 0) or (step + 1 == len(loader)):
+ # accumulate gradients for accum steps
+ for param in model_g.parameters():
+ param.grad /= hp.train.accum_step
+ clip_grad_value_(model_g.parameters(), None)
+ # update model
+ optim_g.step()
+ optim_g.zero_grad()
+
+ # discriminator
+ optim_d.zero_grad()
+ disc_fake = model_d(fake_audio.detach())
+ disc_real = model_d(audio)
+
+ loss_d = 0.0
+ for (_, score_fake), (_, score_real) in zip(disc_fake, disc_real):
+ loss_d += torch.mean(torch.pow(score_real - 1.0, 2))
+ loss_d += torch.mean(torch.pow(score_fake, 2))
+ loss_d = loss_d / len(disc_fake)
+
+ loss_d.backward()
+ clip_grad_value_(model_d.parameters(), None)
+ optim_d.step()
+
+ step += 1
+ # logging
+ loss_g = loss_g.item()
+ loss_d = loss_d.item()
+ loss_s = stft_loss.item()
+ loss_m = mel_loss.item()
+ loss_k = loss_kl_f.item()
+ loss_r = loss_kl_r.item()
+ loss_i = spk_loss.item()
+
+ if rank == 0 and step % hp.log.info_interval == 0:
+ writer.log_training(
+ loss_g, loss_d, loss_m, loss_s, loss_k, loss_r, score_loss.item(), step)
+ logger.info("epoch %d | g %.04f m %.04f s %.04f d %.04f k %.04f r %.04f i %.04f | step %d" % (
+ epoch, loss_g, loss_m, loss_s, loss_d, loss_k, loss_r, loss_i, step))
+
+ if rank == 0 and epoch % hp.log.save_interval == 0:
+ save_path = os.path.join(pth_dir, '%s_%04d.pt'
+ % (args.name, epoch))
+ torch.save({
+ 'model_g': (model_g.module if args.num_gpus > 1 else model_g).state_dict(),
+ 'model_d': (model_d.module if args.num_gpus > 1 else model_d).state_dict(),
+ 'optim_g': optim_g.state_dict(),
+ 'optim_d': optim_d.state_dict(),
+ 'step': step,
+ 'epoch': epoch,
+ 'hp_str': hp_str,
+ }, save_path)
+ logger.info("Saved checkpoint to: %s" % save_path)
+
+ if rank == 0:
+ def clean_checkpoints(path_to_models=f'{pth_dir}', n_ckpts_to_keep=hp.log.keep_ckpts, sort_by_time=True):
+ """Freeing up space by deleting saved ckpts
+ Arguments:
+ path_to_models -- Path to the model directory
+ n_ckpts_to_keep -- Number of ckpts to keep, excluding sovits5.0_0.pth
+ If n_ckpts_to_keep == 0, do not delete any ckpts
+ sort_by_time -- True -> chronologically delete ckpts
+ False -> lexicographically delete ckpts
+ """
+ assert isinstance(n_ckpts_to_keep, int) and n_ckpts_to_keep >= 0
+ ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
+ name_key = (lambda _f: int(re.compile(rf'{args.name}_(\d+)\.pt').match(_f).group(1)))
+ time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
+ sort_key = time_key if sort_by_time else name_key
+ x_sorted = lambda _x: sorted(
+ [f for f in ckpts_files if f.startswith(_x) and not f.endswith('sovits5.0_0.pth')], key=sort_key)
+ if n_ckpts_to_keep == 0:
+ to_del = []
+ else:
+ to_del = [os.path.join(path_to_models, fn) for fn in x_sorted(f'{args.name}')[:-n_ckpts_to_keep]]
+ del_info = lambda fn: logger.info(f"Free up space by deleting ckpt {fn}")
+ del_routine = lambda x: [os.remove(x), del_info(x)]
+ rs = [del_routine(fn) for fn in to_del]
+
+ clean_checkpoints()
+
+ scheduler_g.step()
+ scheduler_d.step()
diff --git a/vits_extend/validation.py b/vits_extend/validation.py
new file mode 100644
index 0000000000000000000000000000000000000000..acf93a1bb428b25386e8365bac19d7cfe22759d7
--- /dev/null
+++ b/vits_extend/validation.py
@@ -0,0 +1,48 @@
+import tqdm
+import torch
+import torch.nn.functional as F
+
+
+def validate(hp, args, generator, discriminator, valloader, stft, writer, step, device):
+ generator.eval()
+ discriminator.eval()
+ torch.backends.cudnn.benchmark = False
+
+ loader = tqdm.tqdm(valloader, desc='Validation loop')
+ mel_loss = 0.0
+ for idx, (ppg, ppg_l, vec, pit, spk, spec, spec_l, audio, audio_l) in enumerate(loader):
+ ppg = ppg.to(device)
+ vec = vec.to(device)
+ pit = pit.to(device)
+ spk = spk.to(device)
+ ppg_l = ppg_l.to(device)
+ audio = audio.to(device)
+
+ if hasattr(generator, 'module'):
+ fake_audio = generator.module.infer(ppg, vec, pit, spk, ppg_l)[
+ :, :, :audio.size(2)]
+ else:
+ fake_audio = generator.infer(ppg, vec, pit, spk, ppg_l)[
+ :, :, :audio.size(2)]
+
+ mel_fake = stft.mel_spectrogram(fake_audio.squeeze(1))
+ mel_real = stft.mel_spectrogram(audio.squeeze(1))
+
+ mel_loss += F.l1_loss(mel_fake, mel_real).item()
+
+ if idx < hp.log.num_audio:
+ spec_fake = stft.linear_spectrogram(fake_audio.squeeze(1))
+ spec_real = stft.linear_spectrogram(audio.squeeze(1))
+
+ audio = audio[0][0].cpu().detach().numpy()
+ fake_audio = fake_audio[0][0].cpu().detach().numpy()
+ spec_fake = spec_fake[0].cpu().detach().numpy()
+ spec_real = spec_real[0].cpu().detach().numpy()
+ writer.log_fig_audio(
+ audio, fake_audio, spec_fake, spec_real, idx, step)
+
+ mel_loss = mel_loss / len(valloader.dataset)
+
+ writer.log_validation(mel_loss, generator, discriminator, step)
+
+ torch.backends.cudnn.benchmark = True
diff --git a/vits_extend/writer.py b/vits_extend/writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..386682bfc4467ee027efdca6d2bdbbe50d574895
--- /dev/null
+++ b/vits_extend/writer.py
@@ -0,0 +1,39 @@
+from torch.utils.tensorboard import SummaryWriter
+import numpy as np
+import librosa
+
+from .plotting import plot_waveform_to_numpy, plot_spectrogram_to_numpy
+
+class MyWriter(SummaryWriter):
+ def __init__(self, hp, logdir):
+ super(MyWriter, self).__init__(logdir)
+ self.sample_rate = hp.data.sampling_rate
+
+ def log_training(self, g_loss, d_loss, mel_loss, stft_loss, k_loss, r_loss, score_loss, step):
+ self.add_scalar('train/g_loss', g_loss, step)
+ self.add_scalar('train/d_loss', d_loss, step)
+
+ self.add_scalar('train/score_loss', score_loss, step)
+ self.add_scalar('train/stft_loss', stft_loss, step)
+ self.add_scalar('train/mel_loss', mel_loss, step)
+ self.add_scalar('train/kl_f_loss', k_loss, step)
+ self.add_scalar('train/kl_r_loss', r_loss, step)
+
+ def log_validation(self, mel_loss, generator, discriminator, step):
+ self.add_scalar('validation/mel_loss', mel_loss, step)
+
+ def log_fig_audio(self, real, fake, spec_fake, spec_real, idx, step):
+ if idx == 0:
+ spec_fake = librosa.amplitude_to_db(spec_fake, ref=np.max,top_db=80.)
+ spec_real = librosa.amplitude_to_db(spec_real, ref=np.max,top_db=80.)
+ self.add_image(f'spec_fake/{step}', plot_spectrogram_to_numpy(spec_fake), step)
+ self.add_image(f'wave_fake/{step}', plot_waveform_to_numpy(fake), step)
+ self.add_image(f'spec_real/{step}', plot_spectrogram_to_numpy(spec_real), step)
+ self.add_image(f'wave_real/{step}', plot_waveform_to_numpy(real), step)
+
+ self.add_audio(f'fake/{step}', fake, step, self.sample_rate)
+ self.add_audio(f'real/{step}', real, step, self.sample_rate)
+
+ def log_histogram(self, model, step):
+ for tag, value in model.named_parameters():
+ self.add_histogram(tag.replace('.', '/'), value.cpu().detach().numpy(), step)
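+
+
+# Usage sketch (illustrative only): `hp` is assumed to be the project's
+# hyper-parameter object with hp.data.sampling_rate set; in practice the
+# writer is constructed once per run inside vits_extend/train.py.
+#
+#   writer = MyWriter(hp, "logs/my_run")
+#   writer.log_training(g_loss, d_loss, mel_loss, stft_loss,
+#                       k_loss, r_loss, score_loss, step)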
diff --git a/vits_pretrain/.DS_Store b/vits_pretrain/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/vits_pretrain/.DS_Store differ
diff --git a/vits_pretrain/README.md b/vits_pretrain/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2be30a36a0e1e4718afca2cd53a9c69e5edc7df1
--- /dev/null
+++ b/vits_pretrain/README.md
@@ -0,0 +1,3 @@
+Path for:
+
+ sovits5.0_bigvgan_mix_v2.pth
diff --git a/whisper/LICENSE b/whisper/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d25552598bb9c5400612159ed4bab92ce12a5ce5
--- /dev/null
+++ b/whisper/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 OpenAI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/whisper/README.md b/whisper/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9ea3a38e58aa56be82a79e31461849083917babb
--- /dev/null
+++ b/whisper/README.md
@@ -0,0 +1,147 @@
+# Whisper
+
+[[Blog]](https://openai.com/blog/whisper)
+[[Paper]](https://arxiv.org/abs/2212.04356)
+[[Model card]](https://github.com/openai/whisper/blob/main/model-card.md)
+[[Colab example]](https://colab.research.google.com/github/openai/whisper/blob/master/notebooks/LibriSpeech.ipynb)
+
+Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.
+
+
+## Approach
+
+
+
+A Transformer sequence-to-sequence model is trained on various speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. These tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing a single model to replace many stages of a traditional speech-processing pipeline. The multitask training format uses a set of special tokens that serve as task specifiers or classification targets.
+
+
+## Setup
+
+We used Python 3.9.9 and [PyTorch](https://pytorch.org/) 1.10.1 to train and test our models, but the codebase is expected to be compatible with Python 3.8-3.10 and recent PyTorch versions. The codebase also depends on a few Python packages, most notably [HuggingFace Transformers](https://huggingface.co/docs/transformers/index) for their fast tokenizer implementation and [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) for reading audio files. You can download and install (or update to) the latest release of Whisper with the following command:
+
+ pip install -U openai-whisper
+
+Alternatively, the following command will pull and install the latest commit from this repository, along with its Python dependencies:
+
+ pip install git+https://github.com/openai/whisper.git
+
+To update the package to the latest version of this repository, please run:
+
+ pip install --upgrade --no-deps --force-reinstall git+https://github.com/openai/whisper.git
+
+It also requires the command-line tool [`ffmpeg`](https://ffmpeg.org/) to be installed on your system, which is available from most package managers:
+
+```bash
+# on Ubuntu or Debian
+sudo apt update && sudo apt install ffmpeg
+
+# on Arch Linux
+sudo pacman -S ffmpeg
+
+# on MacOS using Homebrew (https://brew.sh/)
+brew install ffmpeg
+
+# on Windows using Chocolatey (https://chocolatey.org/)
+choco install ffmpeg
+
+# on Windows using Scoop (https://scoop.sh/)
+scoop install ffmpeg
+```
+
+You may need [`rust`](http://rust-lang.org) installed as well, in case [tokenizers](https://pypi.org/project/tokenizers/) does not provide a pre-built wheel for your platform. If you see installation errors during the `pip install` command above, please follow the [Getting started page](https://www.rust-lang.org/learn/get-started) to install the Rust development environment. Additionally, you may need to configure the `PATH` environment variable, e.g. `export PATH="$HOME/.cargo/bin:$PATH"`. If the installation fails with `No module named 'setuptools_rust'`, you need to install `setuptools_rust`, e.g. by running:
+
+```bash
+pip install setuptools-rust
+```
+
+
+## Available models and languages
+
+There are five model sizes, four with English-only versions, offering speed and accuracy tradeoffs. Below are the names of the available models and their approximate memory requirements and relative speed.
+
+
+| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
+|:------:|:----------:|:------------------:|:------------------:|:-------------:|:--------------:|
+| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB | ~32x |
+| base | 74 M | `base.en` | `base` | ~1 GB | ~16x |
+| small | 244 M | `small.en` | `small` | ~2 GB | ~6x |
+| medium | 769 M | `medium.en` | `medium` | ~5 GB | ~2x |
+| large | 1550 M | N/A | `large` | ~10 GB | 1x |
+
+The `.en` models for English-only applications tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
+
+Whisper's performance varies widely depending on the language. The figure below shows a WER (Word Error Rate) breakdown by language on the Fleurs dataset using the `large-v2` model. More WER and BLEU scores corresponding to the other models and datasets can be found in Appendix D of [the paper](https://arxiv.org/abs/2212.04356). The lower the WER, the better.
+
+
+
+
+
+## Command-line usage
+
+The following command will transcribe speech in audio files, using the `medium` model:
+
+ whisper audio.flac audio.mp3 audio.wav --model medium
+
+The default setting (which selects the `small` model) works well for transcribing English. To transcribe an audio file containing non-English speech, you can specify the language using the `--language` option:
+
+ whisper japanese.wav --language Japanese
+
+Adding `--task translate` will translate the speech into English:
+
+ whisper japanese.wav --language Japanese --task translate
+
+Run the following to view all available options:
+
+ whisper --help
+
+See [tokenizer.py](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py) for the list of all available languages.
+
+
+## Python usage
+
+Transcription can also be performed within Python:
+
+```python
+import whisper
+
+model = whisper.load_model("base")
+result = model.transcribe("audio.mp3")
+print(result["text"])
+```
+
+Internally, the `transcribe()` method reads the entire file and processes the audio with a sliding 30-second window, performing autoregressive sequence-to-sequence predictions on each window.
+
+Below is an example usage of `whisper.detect_language()` and `whisper.decode()` which provide lower-level access to the model.
+
+```python
+import whisper
+
+model = whisper.load_model("base")
+
+# load audio and pad/trim it to fit 30 seconds
+audio = whisper.load_audio("audio.mp3")
+audio = whisper.pad_or_trim(audio)
+
+# make log-Mel spectrogram and move to the same device as the model
+mel = whisper.log_mel_spectrogram(audio).to(model.device)
+
+# detect the spoken language
+_, probs = model.detect_language(mel)
+print(f"Detected language: {max(probs, key=probs.get)}")
+
+# decode the audio
+options = whisper.DecodingOptions()
+result = whisper.decode(model, mel, options)
+
+# print the recognized text
+print(result.text)
+```
+
+## More examples
+
+Please use the [🙌 Show and tell](https://github.com/openai/whisper/discussions/categories/show-and-tell) category in Discussions for sharing more example usages of Whisper and third-party extensions such as web demos, integrations with other tools, ports for different platforms, etc.
+
+
+## License
+
+Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
\ No newline at end of file
diff --git a/whisper/__init__.py b/whisper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/whisper/audio.py b/whisper/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dfe105adda10dfe78179edb5e39cc6d3bde39f9
--- /dev/null
+++ b/whisper/audio.py
@@ -0,0 +1,100 @@
+import os
+from functools import lru_cache
+from typing import Union
+
+import librosa
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .utils import exact_div
+
+from librosa.filters import mel as librosa_mel_fn
+
+# hard-coded audio hyperparameters
+SAMPLE_RATE = 16000
+N_FFT = 400
+N_MELS = 80
+HOP_LENGTH = 160
+CHUNK_LENGTH = 30
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
+
+
+def load_audio(file: str, sr: int = SAMPLE_RATE):
+ x, sr = librosa.load(file, sr=sr)
+ return x
+
+
+def pad_or_trim(array, length_max: int = N_SAMPLES, length_min: int = N_SAMPLES // 2, *, axis: int = -1):
+ """
+ Pad the audio array to at least length_min samples, or trim it to at most length_max samples, along the given axis.
+ """
+ if torch.is_tensor(array):
+ if array.shape[axis] > length_max:
+ array = array.index_select(dim=axis, index=torch.arange(length_max, device=array.device))
+
+ if array.shape[axis] < length_min:
+ pad_widths = [(0, 0)] * array.ndim
+ pad_widths[axis] = (0, length_min - array.shape[axis])
+ array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
+ else:
+ if array.shape[axis] > length_max:
+ array = array.take(indices=range(length_max), axis=axis)
+
+ if array.shape[axis] < length_min:
+ pad_widths = [(0, 0)] * array.ndim
+ pad_widths[axis] = (0, length_min - array.shape[axis])
+ array = np.pad(array, pad_widths)
+
+ return array
+
+
+@lru_cache(maxsize=None)
+def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
+ """
+ Load the mel filterbank matrix for projecting an STFT into a Mel spectrogram.
+ Unlike upstream Whisper, which loads precomputed filters from "mel_filters.npz",
+ this fork builds the filterbank directly with librosa.filters.mel at runtime.
+ """
+ assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
+ return torch.from_numpy(librosa_mel_fn(sr=SAMPLE_RATE,n_fft=N_FFT,n_mels=n_mels)).to(device)
+
+
+def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
+ """
+ Compute the log-Mel spectrogram of the input audio.
+
+ Parameters
+ ----------
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+ The path to an audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz
+
+ n_mels: int
+ The number of Mel-frequency filters, only 80 is supported
+
+ Returns
+ -------
+ torch.Tensor, shape = (80, n_frames)
+ A Tensor that contains the Mel spectrogram
+ """
+ if not torch.is_tensor(audio):
+ if isinstance(audio, str):
+ audio = load_audio(audio)
+ audio = torch.from_numpy(audio)
+
+ window = torch.hann_window(N_FFT).to(audio.device)
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+ magnitudes = stft[..., :-1].abs() ** 2
+
+ filters = mel_filters(audio.device, n_mels)
+ mel_spec = filters @ magnitudes
+
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+ log_spec = (log_spec + 4.0) / 4.0
+ return log_spec
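+
+
+# Example appended for illustration: "test.wav" is a placeholder path; any
+# audio file works, since load_audio resamples to 16 kHz via librosa.
+if __name__ == "__main__":
+    wav = load_audio("test.wav")
+    wav = pad_or_trim(wav)
+    mel = log_mel_spectrogram(wav)    # torch.Tensor, shape (80, n_frames)
+    print(mel.shape)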
diff --git a/whisper/decoding.py b/whisper/decoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..603546d4c9ff67514d2567576935b974fe373bef
--- /dev/null
+++ b/whisper/decoding.py
@@ -0,0 +1,712 @@
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.distributions import Categorical
+
+from .audio import CHUNK_LENGTH
+from .tokenizer import Tokenizer, get_tokenizer
+from .utils import compression_ratio
+
+if TYPE_CHECKING:
+ from .model import Whisper
+
+
+@torch.no_grad()
+def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
+ """
+ Detect the spoken language in the audio and return the ids of the most probable
+ language tokens, along with the probability distribution over all language tokens.
+ This is performed outside the main decode loop in order to not interfere with kv-caching.
+
+ Returns
+ -------
+ language_tokens : Tensor, shape = (n_audio,)
+ ids of the most probable language tokens, which appear after the startoftranscript token.
+ language_probs : List[Dict[str, float]], length = n_audio
+ list of dictionaries containing the probability distribution over all languages.
+ """
+ if tokenizer is None:
+ tokenizer = get_tokenizer(model.is_multilingual)
+ if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
+ raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
+
+ single = mel.ndim == 2
+ if single:
+ mel = mel.unsqueeze(0)
+
+ # skip encoder forward pass if already-encoded audio features were given
+ if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
+ mel = model.encoder(mel)
+
+ # forward pass using a single token, startoftranscript
+ n_audio = mel.shape[0]
+ x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
+ logits = model.logits(x, mel)[:, 0]
+
+ # collect detected languages; suppress all non-language tokens
+ mask = torch.ones(logits.shape[-1], dtype=torch.bool)
+ mask[list(tokenizer.all_language_tokens)] = False
+ logits[:, mask] = -np.inf
+ language_tokens = logits.argmax(dim=-1)
+ language_token_probs = logits.softmax(dim=-1).cpu()
+ language_probs = [
+ {
+ c: language_token_probs[i, j].item()
+ for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
+ }
+ for i in range(n_audio)
+ ]
+
+ if single:
+ language_tokens = language_tokens[0]
+ language_probs = language_probs[0]
+
+ return language_tokens, language_probs
+
+
+@dataclass(frozen=True)
+class DecodingOptions:
+ task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
+ language: Optional[str] = None # language that the audio is in; uses detected language if None
+
+ # sampling-related options
+ temperature: float = 0.0
+ sample_len: Optional[int] = None # maximum number of tokens to sample
+ best_of: Optional[int] = None # number of independent samples to collect, when t > 0
+ beam_size: Optional[int] = None # number of beams in beam search, when t == 0
+ patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
+
+ # options for ranking generations (either beams or best-of-N samples)
+ length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
+
+ # prompt, prefix, and token suppression
+ prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
+ prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
+ suppress_blank: bool = True # this will suppress blank outputs
+
+ # list of tokens ids (or comma-separated token ids) to suppress
+ # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
+ suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
+
+ # timestamp sampling options
+ without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
+ max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
+
+ # implementation details
+ fp16: bool = True # use fp16 for most of the calculation
+
+
+@dataclass(frozen=True)
+class DecodingResult:
+ audio_features: Tensor
+ language: str
+ language_probs: Optional[Dict[str, float]] = None
+ tokens: List[int] = field(default_factory=list)
+ text: str = ""
+ avg_logprob: float = np.nan
+ no_speech_prob: float = np.nan
+ temperature: float = np.nan
+ compression_ratio: float = np.nan
+
+
+class Inference:
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+ """Perform a forward pass on the decoder and return per-token logits"""
+ raise NotImplementedError
+
+ def rearrange_kv_cache(self, source_indices) -> None:
+ """Update the key-value cache according to the updated beams"""
+ raise NotImplementedError
+
+ def cleanup_caching(self) -> None:
+ """Clean up any resources or hooks after decoding is finished"""
+ pass
+
+
+class PyTorchInference(Inference):
+ def __init__(self, model: "Whisper", initial_token_length: int):
+ self.model: "Whisper" = model
+ self.initial_token_length = initial_token_length
+ self.kv_cache = {}
+ self.hooks = []
+
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+ if not self.kv_cache:
+ self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
+
+ if tokens.shape[-1] > self.initial_token_length:
+ # only need to use the last token except in the first forward pass
+ tokens = tokens[:, -1:]
+
+ return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
+
+ def cleanup_caching(self):
+ for hook in self.hooks:
+ hook.remove()
+
+ self.kv_cache = {}
+ self.hooks = []
+
+ def rearrange_kv_cache(self, source_indices):
+ for module, tensor in self.kv_cache.items():
+ # update the key/value cache to contain the selected sequences
+ self.kv_cache[module] = tensor[source_indices].detach()
+
+
+class SequenceRanker:
+ def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
+ """
+ Given a list of groups of samples and their cumulative log probabilities,
+ return the indices of the samples in each group to select as the final result
+ """
+ raise NotImplementedError
+
+
+class MaximumLikelihoodRanker(SequenceRanker):
+ """
+ Select the sample with the highest log probabilities, penalized using either
+ a simple length normalization or Google NMT paper's length penalty
+ """
+
+ def __init__(self, length_penalty: Optional[float]):
+ self.length_penalty = length_penalty
+
+ def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
+ def scores(logprobs, lengths):
+ result = []
+ for logprob, length in zip(logprobs, lengths):
+ if self.length_penalty is None:
+ penalty = length
+ else:
+ # from the Google NMT paper
+ penalty = ((5 + length) / 6) ** self.length_penalty
+ result.append(logprob / penalty)
+ return result
+
+ # get the sequence with the highest score
+ lengths = [[len(t) for t in s] for s in tokens]
+ return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
+
+
+class TokenDecoder:
+ def reset(self):
+ """Initialize any stateful variables for decoding a new sequence"""
+
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+ """Specify how to select the next token, based on the current trace and logits
+
+ Parameters
+ ----------
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
+ all tokens in the context so far, including the prefix and sot_sequence tokens
+
+ logits : Tensor, shape = (n_batch, vocab_size)
+ per-token logits of the probability distribution at the current step
+
+ sum_logprobs : Tensor, shape = (n_batch)
+ cumulative log probabilities for each sequence
+
+ Returns
+ -------
+ tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
+ the tokens, appended with the selected next token
+
+ completed : bool
+ True if all sequences have reached the end of text
+
+ """
+ raise NotImplementedError
+
+ def finalize(
+ self, tokens: Tensor, sum_logprobs: Tensor
+ ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
+ """Finalize search and return the final candidate sequences
+
+ Parameters
+ ----------
+ tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
+ all tokens in the context so far, including the prefix and sot_sequence
+
+ sum_logprobs : Tensor, shape = (n_audio, n_group)
+ cumulative log probabilities for each sequence
+
+ Returns
+ -------
+ tokens : Sequence[Sequence[Tensor]], length = n_audio
+ sequence of Tensors containing candidate token sequences, for each audio input
+
+ sum_logprobs : List[List[float]], length = n_audio
+ sequence of cumulative log probabilities corresponding to the above
+
+ """
+ raise NotImplementedError
+
+
+class GreedyDecoder(TokenDecoder):
+ def __init__(self, temperature: float, eot: int):
+ self.temperature = temperature
+ self.eot = eot
+
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+ temperature = self.temperature
+ if temperature == 0:
+ next_tokens = logits.argmax(dim=-1)
+ else:
+ next_tokens = Categorical(logits=logits / temperature).sample()
+
+ logprobs = F.log_softmax(logits.float(), dim=-1)
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
+ sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
+
+ next_tokens[tokens[:, -1] == self.eot] = self.eot
+ tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
+
+ completed = (tokens[:, -1] == self.eot).all()
+ return tokens, completed
+
+ def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
+ # make sure each sequence has at least one EOT token at the end
+ tokens = F.pad(tokens, (0, 1), value=self.eot)
+ return tokens, sum_logprobs.tolist()
+
+
+class BeamSearchDecoder(TokenDecoder):
+ def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
+ self.beam_size = beam_size
+ self.eot = eot
+ self.inference = inference
+ self.patience = patience or 1.0
+ self.max_candidates: int = round(beam_size * self.patience)
+ self.finished_sequences = None
+
+ assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
+
+ def reset(self):
+ self.finished_sequences = None
+
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
+ if tokens.shape[0] % self.beam_size != 0:
+ raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
+
+ n_audio = tokens.shape[0] // self.beam_size
+ if self.finished_sequences is None: # for the first update
+ self.finished_sequences = [{} for _ in range(n_audio)]
+
+ logprobs = F.log_softmax(logits.float(), dim=-1)
+ next_tokens, source_indices, finished_sequences = [], [], []
+ for i in range(n_audio):
+ scores, sources, finished = {}, {}, {}
+
+ # STEP 1: calculate the cumulative log probabilities for possible candidates
+ for j in range(self.beam_size):
+ idx = i * self.beam_size + j
+ prefix = tokens[idx].tolist()
+ for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
+ new_logprob = (sum_logprobs[idx] + logprob).item()
+ sequence = tuple(prefix + [token.item()])
+ scores[sequence] = new_logprob
+ sources[sequence] = idx
+
+ # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
+ saved = 0
+ for sequence in sorted(scores, key=scores.get, reverse=True):
+ if sequence[-1] == self.eot:
+ finished[sequence] = scores[sequence]
+ else:
+ sum_logprobs[len(next_tokens)] = scores[sequence]
+ next_tokens.append(sequence)
+ source_indices.append(sources[sequence])
+
+ saved += 1
+ if saved == self.beam_size:
+ break
+
+ finished_sequences.append(finished)
+
+ tokens = torch.tensor(next_tokens, device=tokens.device)
+ self.inference.rearrange_kv_cache(source_indices)
+
+ # add newly finished sequences to self.finished_sequences
+ assert len(self.finished_sequences) == len(finished_sequences)
+ for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
+ for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
+ if len(previously_finished) >= self.max_candidates:
+ break # the candidate list is full
+ previously_finished[seq] = newly_finished[seq]
+
+ # mark as completed if all audio has enough number of samples
+ completed = all(
+ len(sequences) >= self.max_candidates for sequences in self.finished_sequences
+ )
+ return tokens, completed
+
+ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
+ # collect all finished sequences, including patience, and add unfinished ones if not enough
+ sum_logprobs = sum_logprobs.cpu()
+ for i, sequences in enumerate(self.finished_sequences):
+ if len(sequences) < self.beam_size: # when not enough sequences are finished
+ for j in list(np.argsort(sum_logprobs[i]))[::-1]:
+ sequence = preceding_tokens[i, j].tolist() + [self.eot]
+ sequences[tuple(sequence)] = sum_logprobs[i][j].item()
+ if len(sequences) >= self.beam_size:
+ break
+
+ tokens: List[List[Tensor]] = [
+ [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
+ ]
+ sum_logprobs: List[List[float]] = [
+ list(sequences.values()) for sequences in self.finished_sequences
+ ]
+ return tokens, sum_logprobs
+
+
+class LogitFilter:
+ def apply(self, logits: Tensor, tokens: Tensor) -> None:
+ """Apply any filtering or masking to logits in-place
+
+ Parameters
+ ----------
+ logits : Tensor, shape = (n_batch, vocab_size)
+ per-token logits of the probability distribution at the current step
+
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
+ all tokens in the context so far, including the prefix and sot_sequence tokens
+
+ """
+ raise NotImplementedError
+
+
+class SuppressBlank(LogitFilter):
+ def __init__(self, tokenizer: Tokenizer, sample_begin: int):
+ self.tokenizer = tokenizer
+ self.sample_begin = sample_begin
+
+ def apply(self, logits: Tensor, tokens: Tensor):
+ if tokens.shape[1] == self.sample_begin:
+ logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
+
+
+class SuppressTokens(LogitFilter):
+ def __init__(self, suppress_tokens: Sequence[int]):
+ self.suppress_tokens = list(suppress_tokens)
+
+ def apply(self, logits: Tensor, tokens: Tensor):
+ logits[:, self.suppress_tokens] = -np.inf
+
+
+class ApplyTimestampRules(LogitFilter):
+ def __init__(
+ self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
+ ):
+ self.tokenizer = tokenizer
+ self.sample_begin = sample_begin
+ self.max_initial_timestamp_index = max_initial_timestamp_index
+
+ def apply(self, logits: Tensor, tokens: Tensor):
+ # suppress <|notimestamps|> which is handled by without_timestamps
+ if self.tokenizer.no_timestamps is not None:
+ logits[:, self.tokenizer.no_timestamps] = -np.inf
+
+ # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
+ for k in range(tokens.shape[0]):
+ seq = [t for t in tokens[k, self.sample_begin :].tolist()]
+ last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
+ penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
+
+ if last_was_timestamp:
+ if penultimate_was_timestamp: # has to be non-timestamp
+ logits[k, self.tokenizer.timestamp_begin :] = -np.inf
+ else: # cannot be normal text tokens
+ logits[k, : self.tokenizer.eot] = -np.inf
+
+ if tokens.shape[1] == self.sample_begin:
+ # suppress generating non-timestamp tokens at the beginning
+ logits[:, : self.tokenizer.timestamp_begin] = -np.inf
+
+ # apply the `max_initial_timestamp` option
+ if self.max_initial_timestamp_index is not None:
+ last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
+ logits[:, last_allowed + 1 :] = -np.inf
+
+ # if sum of probability over timestamps is above any other token, sample timestamp
+ logprobs = F.log_softmax(logits.float(), dim=-1)
+ for k in range(tokens.shape[0]):
+ timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
+ max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
+ if timestamp_logprob > max_text_token_logprob:
+ logits[k, : self.tokenizer.timestamp_begin] = -np.inf
+
+
+class DecodingTask:
+ inference: Inference
+ sequence_ranker: SequenceRanker
+ decoder: TokenDecoder
+ logit_filters: List[LogitFilter]
+
+ def __init__(self, model: "Whisper", options: DecodingOptions):
+ self.model = model
+
+ language = options.language or "en"
+ tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
+ self.tokenizer: Tokenizer = tokenizer
+ self.options: DecodingOptions = self._verify_options(options)
+
+ self.n_group: int = options.beam_size or options.best_of or 1
+ self.n_ctx: int = model.dims.n_text_ctx
+ self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
+
+ self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
+ if self.options.without_timestamps:
+ self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
+
+ self.initial_tokens: Tuple[int] = self._get_initial_tokens()
+ self.sample_begin: int = len(self.initial_tokens)
+ self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
+
+ # inference: implements the forward pass through the decoder, including kv caching
+ self.inference = PyTorchInference(model, len(self.initial_tokens))
+
+ # sequence ranker: implements how to rank a group of sampled sequences
+ self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
+
+ # decoder: implements how to select the next tokens, given the autoregressive distribution
+ if options.beam_size is not None:
+ self.decoder = BeamSearchDecoder(
+ options.beam_size, tokenizer.eot, self.inference, options.patience
+ )
+ else:
+ self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
+
+ # logit filters: applies various rules to suppress or penalize certain tokens
+ self.logit_filters = []
+ if self.options.suppress_blank:
+ self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
+ if self.options.suppress_tokens:
+ self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
+ if not options.without_timestamps:
+ precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
+ max_initial_timestamp_index = None
+ if options.max_initial_timestamp:
+ max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
+ self.logit_filters.append(
+ ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
+ )
+
+ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
+ if options.beam_size is not None and options.best_of is not None:
+ raise ValueError("beam_size and best_of can't be given together")
+ if options.temperature == 0:
+ if options.best_of is not None:
+ raise ValueError("best_of with greedy sampling (T=0) is not compatible")
+ if options.patience is not None and options.beam_size is None:
+ raise ValueError("patience requires beam_size to be given")
+ if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
+ raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
+
+ return options
+
+ def _get_initial_tokens(self) -> Tuple[int]:
+ tokens = list(self.sot_sequence)
+ prefix = self.options.prefix
+ prompt = self.options.prompt
+
+ if prefix:
+ prefix_tokens = (
+ self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
+ )
+ if self.sample_len is not None:
+ max_prefix_len = self.n_ctx // 2 - self.sample_len
+ prefix_tokens = prefix_tokens[-max_prefix_len:]
+ tokens = tokens + prefix_tokens
+
+ if prompt:
+ prompt_tokens = (
+ self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
+ )
+ tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
+
+ return tuple(tokens)
+
+ def _get_suppress_tokens(self) -> Tuple[int]:
+ suppress_tokens = self.options.suppress_tokens
+
+ if isinstance(suppress_tokens, str):
+ suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
+
+ if -1 in suppress_tokens:
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
+ suppress_tokens.extend(self.tokenizer.non_speech_tokens)
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
+ suppress_tokens = [] # interpret empty string as an empty list
+ else:
+ assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
+
+ suppress_tokens.extend(
+ [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
+ )
+ if self.tokenizer.no_speech is not None:
+ # no-speech probability is collected separately
+ suppress_tokens.append(self.tokenizer.no_speech)
+
+ return tuple(sorted(set(suppress_tokens)))
+
+ def _get_audio_features(self, mel: Tensor):
+ if self.options.fp16:
+ mel = mel.half()
+
+        if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
+            # encoded audio features are given; skip audio encoding
+            audio_features = mel
+        else:
+            audio_features = self.model.encoder(mel)
+
+ if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
+            raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
+
+ return audio_features
+
+ def _detect_language(self, audio_features: Tensor, tokens: Tensor):
+ languages = [self.options.language] * audio_features.shape[0]
+ lang_probs = None
+
+ if self.options.language is None or self.options.task == "lang_id":
+ lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
+ languages = [max(probs, key=probs.get) for probs in lang_probs]
+ if self.options.language is None:
+ tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
+
+ return languages, lang_probs
+
+ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
+ assert audio_features.shape[0] == tokens.shape[0]
+ n_batch = tokens.shape[0]
+ sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
+ no_speech_probs = [np.nan] * n_batch
+
+ try:
+ for i in range(self.sample_len):
+ logits = self.inference.logits(tokens, audio_features)
+
+ if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
+ probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
+ no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
+
+ # now we need to consider the logits at the last token only
+ logits = logits[:, -1]
+
+                # apply the logit filters, e.g. for suppressing or penalizing certain tokens
+ for logit_filter in self.logit_filters:
+ logit_filter.apply(logits, tokens)
+
+ # expand the tokens tensor with the selected next tokens
+ tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
+
+ if completed or tokens.shape[-1] > self.n_ctx:
+ break
+ finally:
+ self.inference.cleanup_caching()
+
+ return tokens, sum_logprobs, no_speech_probs
+
+ @torch.no_grad()
+ def run(self, mel: Tensor) -> List[DecodingResult]:
+ self.decoder.reset()
+ tokenizer: Tokenizer = self.tokenizer
+ n_audio: int = mel.shape[0]
+
+ audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
+ tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
+
+ # detect language if requested, overwriting the language token
+ languages, language_probs = self._detect_language(audio_features, tokens)
+ if self.options.task == "lang_id":
+ return [
+ DecodingResult(audio_features=features, language=language, language_probs=probs)
+ for features, language, probs in zip(audio_features, languages, language_probs)
+ ]
+
+ # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
+ audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
+ tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
+
+ # call the main sampling loop
+ tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
+
+ # reshape the tensors to have (n_audio, n_group) as the first two dimensions
+ audio_features = audio_features[:: self.n_group]
+ no_speech_probs = no_speech_probs[:: self.n_group]
+ assert audio_features.shape[0] == len(no_speech_probs) == n_audio
+
+ tokens = tokens.reshape(n_audio, self.n_group, -1)
+ sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
+
+ # get the final candidates for each group, and slice between the first sampled token and EOT
+ tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
+ tokens: List[List[Tensor]] = [
+ [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
+ ]
+
+ # select the top-ranked sample in each group
+ selected = self.sequence_ranker.rank(tokens, sum_logprobs)
+ tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
+ texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
+
+ sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
+ avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
+
+ fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
+ if len(set(map(len, fields))) != 1:
+ raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
+
+ return [
+ DecodingResult(
+ audio_features=features,
+ language=language,
+ tokens=tokens,
+ text=text,
+ avg_logprob=avg_logprob,
+ no_speech_prob=no_speech_prob,
+ temperature=self.options.temperature,
+ compression_ratio=compression_ratio(text),
+ )
+ for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
+ ]
+
+
+@torch.no_grad()
+def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
+ """
+ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
+
+ Parameters
+ ----------
+ model: Whisper
+ the Whisper model instance
+
+ mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
+ A tensor containing the Mel spectrogram(s)
+
+ options: DecodingOptions
+ A dataclass that contains all necessary options for decoding 30-second segments
+
+ Returns
+ -------
+ result: Union[DecodingResult, List[DecodingResult]]
+ The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
+ """
+ single = mel.ndim == 2
+ if single:
+ mel = mel.unsqueeze(0)
+ result = DecodingTask(model, options).run(mel)
+
+ if single:
+ result = result[0]
+
+ return result
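+
+
+# A minimal usage sketch (not called anywhere in this repo) of the `decode` API above, starting
+# from a raw waveform. `model` is assumed to be an already-loaded `Whisper` instance and
+# `audio_path` a hypothetical input file; both names are placeholders, not part of this module.
+def _example_decode(model: "Whisper", audio_path: str) -> DecodingResult:
+    from .audio import load_audio, log_mel_spectrogram, pad_or_trim
+
+    audio = pad_or_trim(load_audio(audio_path))         # exactly 30 seconds at 16 kHz
+    mel = log_mel_spectrogram(audio).to(model.device)   # shape (80, 3000)
+    options = DecodingOptions(task="transcribe", fp16=False)
+    return decode(model, mel, options)                  # 2-D mel -> a single DecodingResult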
diff --git a/whisper/inference.py b/whisper/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..16174c1499d49eccce6a212524309138bfe730b8
--- /dev/null
+++ b/whisper/inference.py
@@ -0,0 +1,78 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import numpy as np
+import argparse
+import torch
+
+from whisper.model import Whisper, ModelDimensions
+from whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram
+
+
+def load_model(path, device) -> Whisper:
+ checkpoint = torch.load(path, map_location="cpu")
+ dims = ModelDimensions(**checkpoint["dims"])
+ # print(dims)
+ model = Whisper(dims)
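+    # Only the encoder output is needed for PPG extraction, so the text decoder is dropped and
+    # the last quarter of the encoder blocks is trimmed; the content features produced by
+    # `whisper.encoder(...)` in `pred_ppg` therefore come from an intermediate encoder layer.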
+ del model.decoder
+ cut = len(model.encoder.blocks) // 4
+ cut = -1 * cut
+ del model.encoder.blocks[cut:]
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
+ model.eval()
+ if not (device == "cpu"):
+ model.half()
+ model.to(device)
+ # torch.save({
+ # 'dims': checkpoint["dims"],
+ # 'model_state_dict': model.state_dict(),
+ # }, "large-v2.pt")
+ return model
+
+
+def pred_ppg(whisper: Whisper, wavPath, ppgPath, device):
+ audio = load_audio(wavPath)
+ audln = audio.shape[0]
+ ppg_a = []
+ idx_s = 0
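+    # Window bookkeeping: the 16 kHz waveform is processed in 15-second chunks, and the encoder
+    # yields one feature frame per 320 samples (20 ms), i.e. 15 * 16000 // 320 = 750 frames for
+    # a full chunk; any remaining tail shorter than 15 s is handled after the loop.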
+ while (idx_s + 15 * 16000 < audln):
+ short = audio[idx_s:idx_s + 15 * 16000]
+ idx_s = idx_s + 15 * 16000
+ ppgln = 15 * 16000 // 320
+ # short = pad_or_trim(short)
+ mel = log_mel_spectrogram(short).to(device)
+ if not (device == "cpu"):
+ mel = mel.half()
+ with torch.no_grad():
+ mel = mel + torch.randn_like(mel) * 0.1
+ ppg = whisper.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy()
+            ppg = ppg[:ppgln]  # [length, n_audio_state]
+ ppg_a.extend(ppg)
+ if (idx_s < audln):
+ short = audio[idx_s:audln]
+ ppgln = (audln - idx_s) // 320
+ # short = pad_or_trim(short)
+ mel = log_mel_spectrogram(short).to(device)
+ if not (device == "cpu"):
+ mel = mel.half()
+ with torch.no_grad():
+ mel = mel + torch.randn_like(mel) * 0.1
+ ppg = whisper.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy()
+            ppg = ppg[:ppgln]  # [length, n_audio_state]
+ ppg_a.extend(ppg)
+ np.save(ppgPath, ppg_a, allow_pickle=False)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
+ parser.add_argument("-p", "--ppg", help="ppg", dest="ppg", required=True)
+ args = parser.parse_args()
+ print(args.wav)
+ print(args.ppg)
+
+ wavPath = args.wav
+ ppgPath = args.ppg
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ whisper = load_model(os.path.join("whisper_pretrain", "large-v2.pt"), device)
+ pred_ppg(whisper, wavPath, ppgPath, device)
diff --git a/whisper/model.py b/whisper/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..78d6d135ee8df21e4deb0fffa5d15e5631f960c5
--- /dev/null
+++ b/whisper/model.py
@@ -0,0 +1,270 @@
+from dataclasses import dataclass
+from typing import Dict, Iterable, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch import nn
+
+from .decoding import detect_language as detect_language_function, decode as decode_function
+
+
+@dataclass
+class ModelDimensions:
+ n_mels: int
+ n_audio_ctx: int
+ n_audio_state: int
+ n_audio_head: int
+ n_audio_layer: int
+ n_vocab: int
+ n_text_ctx: int
+ n_text_state: int
+ n_text_head: int
+ n_text_layer: int
+
+
+class LayerNorm(nn.LayerNorm):
+ def forward(self, x: Tensor) -> Tensor:
+        # upstream Whisper computes LayerNorm in float32: super().forward(x.float()).type(x.dtype);
+        # sovits5.0 keeps the input dtype here instead
+        return super().forward(x).type(x.dtype)
+
+
+class Linear(nn.Linear):
+ def forward(self, x: Tensor) -> Tensor:
+ return F.linear(
+ x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
+ )
+
+
+class Conv1d(nn.Conv1d):
+ def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
+ return super()._conv_forward(
+ x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
+ )
+
+
+def sinusoids(length, channels, max_timescale=10000):
+ """Returns sinusoids for positional embedding"""
+ assert channels % 2 == 0
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, n_state: int, n_head: int):
+ super().__init__()
+ self.n_head = n_head
+ self.query = Linear(n_state, n_state)
+ self.key = Linear(n_state, n_state, bias=False)
+ self.value = Linear(n_state, n_state)
+ self.out = Linear(n_state, n_state)
+
+ def forward(
+ self,
+ x: Tensor,
+ xa: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ kv_cache: Optional[dict] = None,
+ ):
+ q = self.query(x)
+
+ if kv_cache is None or xa is None or self.key not in kv_cache:
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
+ k = self.key(x if xa is None else xa)
+ v = self.value(x if xa is None else xa)
+ else:
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+ k = kv_cache[self.key]
+ v = kv_cache[self.value]
+
+ wv, qk = self.qkv_attention(q, k, v, mask)
+ return self.out(wv), qk
+
+ def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
+ n_batch, n_ctx, n_state = q.shape
+ scale = (n_state // self.n_head) ** -0.25
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+ qk = q @ k
+ if mask is not None:
+ qk = qk + mask[:n_ctx, :n_ctx]
+ qk = qk.float()
+
+ w = F.softmax(qk, dim=-1).to(q.dtype)
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+ super().__init__()
+
+ self.attn = MultiHeadAttention(n_state, n_head)
+ self.attn_ln = LayerNorm(n_state)
+
+ self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+ n_mlp = n_state * 4
+ self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
+ self.mlp_ln = LayerNorm(n_state)
+
+ def forward(
+ self,
+ x: Tensor,
+ xa: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ kv_cache: Optional[dict] = None,
+ ):
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
+ if self.cross_attn:
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
+ x = x + self.mlp(self.mlp_ln(x))
+ return x
+
+
+class AudioEncoder(nn.Module):
+ def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+ super().__init__()
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+ )
+ self.ln_post = LayerNorm(n_state)
+
+ def forward(self, x: Tensor):
+ """
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+ the mel spectrogram of the audio
+ """
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+ x = x.permute(0, 2, 1)
+
+ len_x = x.shape[1]
+ len_e = self.positional_embedding.shape[0]
+ assert len_x <= len_e, "incorrect audio shape"
+ pos_e = self.positional_embedding[:len_x, :]
+ x = (x + pos_e).to(x.dtype)
+
+ for block in self.blocks:
+ x = block(x)
+
+ x = self.ln_post(x)
+ return x
+
+
+class TextDecoder(nn.Module):
+ def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+ super().__init__()
+
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+ [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
+ )
+ self.ln = LayerNorm(n_state)
+
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+ self.register_buffer("mask", mask, persistent=False)
+
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+ """
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+ the text tokens
+        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+ the encoded audio features to be attended on
+ """
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+ x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
+ x = x.to(xa.dtype)
+
+ for block in self.blocks:
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+
+ x = self.ln(x)
+ logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
+
+ return logits
+
+
+class Whisper(nn.Module):
+ def __init__(self, dims: ModelDimensions):
+ super().__init__()
+ self.dims = dims
+ self.encoder = AudioEncoder(
+ self.dims.n_mels,
+ self.dims.n_audio_ctx,
+ self.dims.n_audio_state,
+ self.dims.n_audio_head,
+ self.dims.n_audio_layer,
+ )
+ self.decoder = TextDecoder(
+ self.dims.n_vocab,
+ self.dims.n_text_ctx,
+ self.dims.n_text_state,
+ self.dims.n_text_head,
+ self.dims.n_text_layer,
+ )
+
+ def embed_audio(self, mel: torch.Tensor):
+ return self.encoder(mel)
+
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+ return self.decoder(tokens, audio_features)
+
+ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
+ return self.decoder(tokens, self.encoder(mel))
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def is_multilingual(self):
+ return self.dims.n_vocab == 51865
+
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+ """
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+ tensors calculated for the previous positions. This method returns a dictionary that stores
+ all caches, and the necessary hooks for the key and value projection modules that save the
+ intermediate tensors to be reused during later calculations.
+
+ Returns
+ -------
+ cache : Dict[nn.Module, torch.Tensor]
+ A dictionary object mapping the key/value projection modules to its cache
+ hooks : List[RemovableHandle]
+            List of PyTorch RemovableHandle objects that can be used to remove the hooks
+ """
+ cache = {**cache} if cache is not None else {}
+ hooks = []
+
+ def save_to_cache(module, _, output):
+ if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
+ cache[module] = output # save as-is, for the first token or cross attention
+ else:
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
+ return cache[module]
+
+ def install_hooks(layer: nn.Module):
+ if isinstance(layer, MultiHeadAttention):
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+ self.decoder.apply(install_hooks)
+ return cache, hooks
+
+ detect_language = detect_language_function
+ decode = decode_function
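+
+
+# A minimal sketch (not used elsewhere in this file) of incremental decoding with the kv-cache
+# hooks above. `model`, `audio_features` and `tokens` are assumed to already exist; removing the
+# hooks afterwards mirrors the cleanup performed by the inference wrapper in decoding.py.
+def _example_incremental_logits(model: Whisper, audio_features: Tensor, tokens: Tensor) -> Tensor:
+    cache, hooks = model.install_kv_cache_hooks()
+    try:
+        # the first call feeds the whole token prefix; later calls could pass only the newly
+        # sampled token, because the hooks keep concatenating key/value projections into `cache`
+        logits = model.decoder(tokens, audio_features, kv_cache=cache)
+    finally:
+        for hook in hooks:
+            hook.remove()
+    return logits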
diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27cb359ee891590d3f793624f9f8ec768a26cc3
--- /dev/null
+++ b/whisper/tokenizer.py
@@ -0,0 +1,331 @@
+import os
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import GPT2TokenizerFast
+
+LANGUAGES = {
+ "en": "english",
+ "zh": "chinese",
+ "de": "german",
+ "es": "spanish",
+ "ru": "russian",
+ "ko": "korean",
+ "fr": "french",
+ "ja": "japanese",
+ "pt": "portuguese",
+ "tr": "turkish",
+ "pl": "polish",
+ "ca": "catalan",
+ "nl": "dutch",
+ "ar": "arabic",
+ "sv": "swedish",
+ "it": "italian",
+ "id": "indonesian",
+ "hi": "hindi",
+ "fi": "finnish",
+ "vi": "vietnamese",
+ "he": "hebrew",
+ "uk": "ukrainian",
+ "el": "greek",
+ "ms": "malay",
+ "cs": "czech",
+ "ro": "romanian",
+ "da": "danish",
+ "hu": "hungarian",
+ "ta": "tamil",
+ "no": "norwegian",
+ "th": "thai",
+ "ur": "urdu",
+ "hr": "croatian",
+ "bg": "bulgarian",
+ "lt": "lithuanian",
+ "la": "latin",
+ "mi": "maori",
+ "ml": "malayalam",
+ "cy": "welsh",
+ "sk": "slovak",
+ "te": "telugu",
+ "fa": "persian",
+ "lv": "latvian",
+ "bn": "bengali",
+ "sr": "serbian",
+ "az": "azerbaijani",
+ "sl": "slovenian",
+ "kn": "kannada",
+ "et": "estonian",
+ "mk": "macedonian",
+ "br": "breton",
+ "eu": "basque",
+ "is": "icelandic",
+ "hy": "armenian",
+ "ne": "nepali",
+ "mn": "mongolian",
+ "bs": "bosnian",
+ "kk": "kazakh",
+ "sq": "albanian",
+ "sw": "swahili",
+ "gl": "galician",
+ "mr": "marathi",
+ "pa": "punjabi",
+ "si": "sinhala",
+ "km": "khmer",
+ "sn": "shona",
+ "yo": "yoruba",
+ "so": "somali",
+ "af": "afrikaans",
+ "oc": "occitan",
+ "ka": "georgian",
+ "be": "belarusian",
+ "tg": "tajik",
+ "sd": "sindhi",
+ "gu": "gujarati",
+ "am": "amharic",
+ "yi": "yiddish",
+ "lo": "lao",
+ "uz": "uzbek",
+ "fo": "faroese",
+ "ht": "haitian creole",
+ "ps": "pashto",
+ "tk": "turkmen",
+ "nn": "nynorsk",
+ "mt": "maltese",
+ "sa": "sanskrit",
+ "lb": "luxembourgish",
+ "my": "myanmar",
+ "bo": "tibetan",
+ "tl": "tagalog",
+ "mg": "malagasy",
+ "as": "assamese",
+ "tt": "tatar",
+ "haw": "hawaiian",
+ "ln": "lingala",
+ "ha": "hausa",
+ "ba": "bashkir",
+ "jw": "javanese",
+ "su": "sundanese",
+}
+
+# language code lookup by name, with a few language aliases
+TO_LANGUAGE_CODE = {
+ **{language: code for code, language in LANGUAGES.items()},
+ "burmese": "my",
+ "valencian": "ca",
+ "flemish": "nl",
+ "haitian": "ht",
+ "letzeburgesch": "lb",
+ "pushto": "ps",
+ "panjabi": "pa",
+ "moldavian": "ro",
+ "moldovan": "ro",
+ "sinhalese": "si",
+ "castilian": "es",
+}
+
+
+@dataclass(frozen=True)
+class Tokenizer:
+ """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
+
+ tokenizer: "GPT2TokenizerFast"
+ language: Optional[str]
+ sot_sequence: Tuple[int]
+
+ def encode(self, text, **kwargs):
+ return self.tokenizer.encode(text, **kwargs)
+
+ def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
+ return self.tokenizer.decode(token_ids, **kwargs)
+
+ def decode_with_timestamps(self, tokens) -> str:
+ """
+ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
+        This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
+ """
+ outputs = [[]]
+ for token in tokens:
+ if token >= self.timestamp_begin:
+ timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
+ outputs.append(timestamp)
+ outputs.append([])
+ else:
+ outputs[-1].append(token)
+ outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
+ return "".join(outputs)
+
+ @property
+ @lru_cache()
+ def eot(self) -> int:
+ return self.tokenizer.eos_token_id
+
+ @property
+ @lru_cache()
+ def sot(self) -> int:
+ return self._get_single_token_id("<|startoftranscript|>")
+
+ @property
+ @lru_cache()
+ def sot_lm(self) -> int:
+ return self._get_single_token_id("<|startoflm|>")
+
+ @property
+ @lru_cache()
+ def sot_prev(self) -> int:
+ return self._get_single_token_id("<|startofprev|>")
+
+ @property
+ @lru_cache()
+ def no_speech(self) -> int:
+ return self._get_single_token_id("<|nospeech|>")
+
+ @property
+ @lru_cache()
+ def no_timestamps(self) -> int:
+ return self._get_single_token_id("<|notimestamps|>")
+
+ @property
+ @lru_cache()
+ def timestamp_begin(self) -> int:
+ return self.tokenizer.all_special_ids[-1] + 1
+
+ @property
+ @lru_cache()
+ def language_token(self) -> int:
+ """Returns the token id corresponding to the value of the `language` field"""
+ if self.language is None:
+ raise ValueError(f"This tokenizer does not have language token configured")
+
+ additional_tokens = dict(
+ zip(
+ self.tokenizer.additional_special_tokens,
+ self.tokenizer.additional_special_tokens_ids,
+ )
+ )
+ candidate = f"<|{self.language}|>"
+ if candidate in additional_tokens:
+ return additional_tokens[candidate]
+
+ raise KeyError(f"Language {self.language} not found in tokenizer.")
+
+ @property
+ @lru_cache()
+ def all_language_tokens(self) -> Tuple[int]:
+ result = []
+ for token, token_id in zip(
+ self.tokenizer.additional_special_tokens,
+ self.tokenizer.additional_special_tokens_ids,
+ ):
+ if token.strip("<|>") in LANGUAGES:
+ result.append(token_id)
+ return tuple(result)
+
+ @property
+ @lru_cache()
+ def all_language_codes(self) -> Tuple[str]:
+ return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
+
+ @property
+ @lru_cache()
+ def sot_sequence_including_notimestamps(self) -> Tuple[int]:
+ return tuple(list(self.sot_sequence) + [self.no_timestamps])
+
+ @property
+ @lru_cache()
+ def non_speech_tokens(self) -> Tuple[int]:
+ """
+ Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
+ annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
+
+ - ♪♪♪
+ - ( SPEAKING FOREIGN LANGUAGE )
+ - [DAVID] Hey there,
+
+ keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
+ """
+ symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
+ symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
+
+ # symbols that may be a single token or multiple tokens depending on the tokenizer.
+ # In case they're multiple tokens, suppress the first token, which is safe because:
+ # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
+ # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
+ miscellaneous = set("♩♪♫♬♭♮♯")
+ assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
+
+ # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+ result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
+ for symbol in symbols + list(miscellaneous):
+ for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
+ if len(tokens) == 1 or symbol in miscellaneous:
+ result.add(tokens[0])
+
+ return tuple(sorted(result))
+
+ def _get_single_token_id(self, text) -> int:
+ tokens = self.tokenizer.encode(text)
+ assert len(tokens) == 1, f"{text} is not encoded as a single token"
+ return tokens[0]
+
+
+@lru_cache(maxsize=None)
+def build_tokenizer(name: str = "gpt2"):
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ path = os.path.join(os.path.dirname(__file__), "assets", name)
+ tokenizer = GPT2TokenizerFast.from_pretrained(path)
+
+ specials = [
+ "<|startoftranscript|>",
+ *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
+ "<|translate|>",
+ "<|transcribe|>",
+ "<|startoflm|>",
+ "<|startofprev|>",
+ "<|nospeech|>",
+ "<|notimestamps|>",
+ ]
+
+ tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
+ return tokenizer
+
+
+@lru_cache(maxsize=None)
+def get_tokenizer(
+ multilingual: bool,
+ *,
+ task: Optional[str] = None, # Literal["transcribe", "translate", None]
+ language: Optional[str] = None,
+) -> Tokenizer:
+ if language is not None:
+ language = language.lower()
+ if language not in LANGUAGES:
+ if language in TO_LANGUAGE_CODE:
+ language = TO_LANGUAGE_CODE[language]
+ else:
+ raise ValueError(f"Unsupported language: {language}")
+
+ if multilingual:
+ tokenizer_name = "multilingual"
+ task = task or "transcribe"
+ language = language or "en"
+ else:
+ tokenizer_name = "gpt2"
+ task = None
+ language = None
+
+ tokenizer = build_tokenizer(name=tokenizer_name)
+ all_special_ids: List[int] = tokenizer.all_special_ids
+ sot: int = all_special_ids[1]
+ translate: int = all_special_ids[-6]
+ transcribe: int = all_special_ids[-5]
+
+ langs = tuple(LANGUAGES.keys())
+ sot_sequence = [sot]
+ if language is not None:
+ sot_sequence.append(sot + 1 + langs.index(language))
+ if task is not None:
+ sot_sequence.append(transcribe if task == "transcribe" else translate)
+
+ return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
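+
+
+# A minimal usage sketch (not called in this module), assuming the tokenizer assets under
+# whisper/assets/ are present on disk:
+def _example_tokenizer() -> None:
+    tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
+    # sot_sequence covers <|startoftranscript|><|en|><|transcribe|>
+    print(tokenizer.decode(list(tokenizer.sot_sequence)))
+    print(tokenizer.encode(" hello"))  # ordinary text goes straight through GPT2TokenizerFast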
diff --git a/whisper/utils.py b/whisper/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dacc173c40bcd6e999d728862e29a968000b12e
--- /dev/null
+++ b/whisper/utils.py
@@ -0,0 +1,163 @@
+import json
+import os
+import sys
+import zlib
+from typing import Callable, TextIO
+
+system_encoding = sys.getdefaultencoding()
+
+if system_encoding != "utf-8":
+ def make_safe(string):
+        # replaces any character not representable using the system default encoding with a '?',
+ # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
+ return string.encode(system_encoding, errors="replace").decode(system_encoding)
+else:
+ def make_safe(string):
+ # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
+ return string
+
+
+def exact_div(x, y):
+ assert x % y == 0
+ return x // y
+
+
+def str2bool(string):
+ str2val = {"True": True, "False": False}
+ if string in str2val:
+ return str2val[string]
+ else:
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
+
+
+def optional_int(string):
+ return None if string == "None" else int(string)
+
+
+def optional_float(string):
+ return None if string == "None" else float(string)
+
+
+def compression_ratio(text) -> float:
+ text_bytes = text.encode("utf-8")
+ return len(text_bytes) / len(zlib.compress(text_bytes))
+
+
+def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
+ assert seconds >= 0, "non-negative timestamp expected"
+ milliseconds = round(seconds * 1000.0)
+
+ hours = milliseconds // 3_600_000
+ milliseconds -= hours * 3_600_000
+
+ minutes = milliseconds // 60_000
+ milliseconds -= minutes * 60_000
+
+ seconds = milliseconds // 1_000
+ milliseconds -= seconds * 1_000
+
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+
+
+class ResultWriter:
+ extension: str
+
+ def __init__(self, output_dir: str):
+ self.output_dir = output_dir
+
+ def __call__(self, result: dict, audio_path: str):
+ audio_basename = os.path.basename(audio_path)
+ output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ self.write_result(result, file=f)
+
+ def write_result(self, result: dict, file: TextIO):
+ raise NotImplementedError
+
+
+class WriteTXT(ResultWriter):
+ extension: str = "txt"
+
+ def write_result(self, result: dict, file: TextIO):
+ for segment in result["segments"]:
+ print(segment['text'].strip(), file=file, flush=True)
+
+
+class WriteVTT(ResultWriter):
+ extension: str = "vtt"
+
+ def write_result(self, result: dict, file: TextIO):
+ print("WEBVTT\n", file=file)
+ for segment in result["segments"]:
+ print(
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
+ f"{segment['text'].strip().replace('-->', '->')}\n",
+ file=file,
+ flush=True,
+ )
+
+
+class WriteSRT(ResultWriter):
+ extension: str = "srt"
+
+ def write_result(self, result: dict, file: TextIO):
+ for i, segment in enumerate(result["segments"], start=1):
+ # write srt lines
+ print(
+ f"{i}\n"
+ f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
+ f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
+ f"{segment['text'].strip().replace('-->', '->')}\n",
+ file=file,
+ flush=True,
+ )
+
+
+class WriteTSV(ResultWriter):
+ """
+ Write a transcript to a file in TSV (tab-separated values) format containing lines like:
+    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
+
+ Using integer milliseconds as start and end times means there's no chance of interference from
+ an environment setting a language encoding that causes the decimal in a floating point number
+    to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
+ """
+ extension: str = "tsv"
+
+ def write_result(self, result: dict, file: TextIO):
+ print("start", "end", "text", sep="\t", file=file)
+ for segment in result["segments"]:
+ print(round(1000 * segment['start']), file=file, end="\t")
+ print(round(1000 * segment['end']), file=file, end="\t")
+ print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
+
+
+class WriteJSON(ResultWriter):
+ extension: str = "json"
+
+ def write_result(self, result: dict, file: TextIO):
+ json.dump(result, file)
+
+
+def get_writer(output_format: str, output_dir: str) -> Callable[[dict, str], None]:
+ writers = {
+ "txt": WriteTXT,
+ "vtt": WriteVTT,
+ "srt": WriteSRT,
+ "tsv": WriteTSV,
+ "json": WriteJSON,
+ }
+
+ if output_format == "all":
+ all_writers = [writer(output_dir) for writer in writers.values()]
+
+        def write_all(result: dict, audio_path: str):
+            for writer in all_writers:
+                writer(result, audio_path)
+
+ return write_all
+
+ return writers[output_format](output_dir)
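+
+
+# A minimal usage sketch (not exercised in this repo): `result` is assumed to be a dict with a
+# "segments" list of {"start", "end", "text"} entries, and "out" / "example.wav" are
+# hypothetical names used only for illustration.
+def _example_write(result: dict) -> None:
+    os.makedirs("out", exist_ok=True)
+    writer = get_writer("srt", "out")  # or "all" to emit every supported format
+    writer(result, "example.wav")      # writes out/example.wav.srt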
+
diff --git a/whisper_pretrain/.DS_Store b/whisper_pretrain/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/whisper_pretrain/.DS_Store differ
diff --git a/whisper_pretrain/README.md b/whisper_pretrain/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f615cae0154333d0d3c778d83fc5263d30990b54
--- /dev/null
+++ b/whisper_pretrain/README.md
@@ -0,0 +1,3 @@
+Place the Whisper pretrained checkpoint here; `whisper/inference.py` loads it from `whisper_pretrain/large-v2.pt`:
+
+    large-v2.pt