diff --git a/.dockerignore b/.dockerignore index da10b929bbf1260f01287b7abec2b980575de832..bf1b3a0378eee07e7e2d9a8480820a14bd5093da 100644 --- a/.dockerignore +++ b/.dockerignore @@ -6,12 +6,13 @@ !/style_bert_vits2/ !/bert/deberta-v2-large-japanese-char-wwm/ -!/common/ !/configs/ !/dict_data/default.csv !/model_assets/ +!/static/ !/config.py !/default_config.yml +!/initialize.py !/requirements.txt !/server_editor.py diff --git a/.github/workflows/update_space.yml b/.github/workflows/update_space.yml new file mode 100644 index 0000000000000000000000000000000000000000..67dbc84e4e59320a7c98b94460eb976e5cd2984f --- /dev/null +++ b/.github/workflows/update_space.yml @@ -0,0 +1,28 @@ +name: Run Python script + +on: + push: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Install Gradio + run: python -m pip install gradio + + - name: Log in to Hugging Face + run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")' + + - name: Deploy to Spaces + run: gradio deploy diff --git a/.gitignore b/.gitignore index aa114e3fc92d0c690fce848c53bf7c60b0bf189b..6ca8d26c695d59390d770fdb1c948858f6f87907 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,8 @@ dist/ /bert/*/*.safetensors /bert/*/*.msgpack +/configs/paths.yml + /pretrained/*.safetensors /pretrained/*.pth @@ -37,3 +39,5 @@ safetensors.ipynb # pyopenjtalk's dictionary *.dic + +playground.ipynb diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem new file mode 100644 index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3 --- /dev/null +++ b/.gradio/certificate.pem @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW +T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- diff --git a/App.bat b/App.bat index 2f2b6534bdb03d42eac717940cc35f431207db3c..99498ab03f75c5b0a4f4ca4bb212a537a87b1a59 100644 --- a/App.bat +++ b/App.bat @@ -1,11 +1,11 @@ -chcp 65001 > NUL -@echo off - -pushd %~dp0 -echo Running app.py... -venv\Scripts\python app.py - -if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) - -popd +chcp 65001 > NUL +@echo off + +pushd %~dp0 +echo Running app.py... +venv\Scripts\python app.py + +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +popd pause \ No newline at end of file diff --git a/Dataset.bat b/Dataset.bat new file mode 100644 index 0000000000000000000000000000000000000000..65a938d625d041342a3759762f9cec6d93bf69b7 --- /dev/null +++ b/Dataset.bat @@ -0,0 +1,11 @@ +chcp 65001 > NUL +@echo off + +pushd %~dp0 +echo Running gradio_tabs/dataset.py... +venv\Scripts\python -m gradio_tabs.dataset + +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +popd +pause \ No newline at end of file diff --git a/Dockerfile.deploy b/Dockerfile.deploy index 7c6aaf6bc83b19dbb9674ce65d0a96cfe243b01b..48c22d0b8d4f0b2cd53cc467d0103274a6944125 100644 --- a/Dockerfile.deploy +++ b/Dockerfile.deploy @@ -20,4 +20,4 @@ COPY --chown=user . $HOME/app RUN pip install --no-cache-dir -r $HOME/app/requirements.txt # 必要に応じて制限を変更してください -CMD ["python", "app.py", "--share" "--line_length", "50", "--line_count", "3"] +CMD ["python", "server_editor.py", "--line_length", "50", "--line_count", "3", "--skip_static_files"] diff --git a/Editor.bat b/Editor.bat index 8405cd5e9273f30e54c2836ce08e488fdfde9452..43cb2a88506edc84f1d936b5bd5b6bd2f6361fc7 100644 --- a/Editor.bat +++ b/Editor.bat @@ -1,11 +1,11 @@ -chcp 65001 > NUL -@echo off - -pushd %~dp0 -echo Running server_editor.py --inbrowser -venv\Scripts\python server_editor.py --inbrowser - -if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) - -popd +chcp 65001 > NUL +@echo off + +pushd %~dp0 +echo Running server_editor.py --inbrowser +venv\Scripts\python server_editor.py --inbrowser + +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +popd pause \ No newline at end of file diff --git a/Inference.bat b/Inference.bat new file mode 100644 index 0000000000000000000000000000000000000000..9f7ea2e720c1f6dd7f8d27fc40a6af0db047dd97 --- /dev/null +++ b/Inference.bat @@ -0,0 +1,11 @@ +chcp 65001 > NUL +@echo off + +pushd %~dp0 +echo Running gradio_tabs/inference.py... +venv\Scripts\python -m gradio_tabs.inference + +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +popd +pause \ No newline at end of file diff --git a/Initialize.bat b/Initialize.bat new file mode 100644 index 0000000000000000000000000000000000000000..b2a5aaf0404e297d286ad52a844d0cd82781328c --- /dev/null +++ b/Initialize.bat @@ -0,0 +1,11 @@ +chcp 65001 > NUL +@echo off + +pushd %~dp0 +echo Running initialize.py... +venv\Scripts\python initialize.py + +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +popd +pause \ No newline at end of file diff --git a/Merge.bat b/Merge.bat new file mode 100644 index 0000000000000000000000000000000000000000..6490d18095f5c325ba30bb9150d6e92c339b8c38 --- /dev/null +++ b/Merge.bat @@ -0,0 +1,11 @@ +chcp 65001 > NUL +@echo off + +pushd %~dp0 +echo Running gradio_tabs/merge.py... +venv\Scripts\python -m gradio_tabs.merge + +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +popd +pause \ No newline at end of file diff --git a/README.md b/README.md index bd39e64f29b435c4913285c2d22b50cde48e6370..3525fd7edb6feddff040190532b82f87e6b04d0c 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,12 @@ title: Style-Bert-VITS2 app_file: app.py sdk: gradio -sdk_version: 4.23.0 +sdk_version: 5.16.0 --- # Style-Bert-VITS2 +**利用の際は必ず[お願いとデフォルトモデルの利用規約](/docs/TERMS_OF_USE.md)をお読みください。** + Bert-VITS2 with more controllable voice styles. https://github.com/litagin02/Style-Bert-VITS2/assets/139731664/e853f9a2-db4a-4202-a1dd-56ded3c562a0 @@ -13,13 +15,16 @@ https://github.com/litagin02/Style-Bert-VITS2/assets/139731664/e853f9a2-db4a-420 You can install via `pip install style-bert-vits2` (inference only), see [library.ipynb](/library.ipynb) for example usage. - **解説チュートリアル動画** [YouTube](https://youtu.be/aTUSzgDl1iY) [ニコニコ動画](https://www.nicovideo.jp/watch/sm43391524) -- [English README](docs/README_en.md) - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/litagin02/Style-Bert-VITS2/blob/master/colab.ipynb) +- [**よくある質問** (FAQ)](/docs/FAQ.md) - [🤗 オンラインデモはこちらから](https://huggingface.co/spaces/litagin/Style-Bert-VITS2-Editor-Demo) - [Zennの解説記事](https://zenn.dev/litagin/articles/034819a5256ff4) - [**リリースページ**](https://github.com/litagin02/Style-Bert-VITS2/releases/)、[更新履歴](/docs/CHANGELOG.md) - + - 2024-09-09: Ver 2.6.1: Google colabでうまく学習できない等のバグ修正のみ + - 2024-06-16: Ver 2.6.0 (モデルの差分マージ・加重マージ・ヌルモデルマージの追加、使い道については[この記事](https://zenn.dev/litagin/articles/1297b1dc7bdc79)参照) + - 2024-06-14: Ver 2.5.1 (利用規約をお願いへ変更したのみ) + - 2024-06-02: Ver 2.5.0 (**[利用規約](/docs/TERMS_OF_USE.md)の追加**、フォルダ分けからのスタイル生成、小春音アミ・あみたろモデルの追加、インストールの高速化等) - 2024-03-16: ver 2.4.1 (**batファイルによるインストール方法の変更**) - 2024-03-15: ver 2.4.0 (大規模リファクタリングや種々の改良、ライブラリ化) - 2024-02-26: ver 2.3 (辞書機能とエディター機能) @@ -38,13 +43,15 @@ This repository is based on [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2 - 入力されたテキストの内容をもとに感情豊かな音声を生成する[Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)のv2.1とJapanese-Extraを元に、感情や発話スタイルを強弱込みで自由に制御できるようにしたものです。 - GitやPythonがない人でも(Windowsユーザーなら)簡単にインストールでき、学習もできます (多くを[EasyBertVits2](https://github.com/Zuntan03/EasyBertVits2/)からお借りしました)。またGoogle Colabでの学習もサポートしています: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/litagin02/Style-Bert-VITS2/blob/master/colab.ipynb) - 音声合成のみに使う場合は、グラボがなくてもCPUで動作します。 +- 音声合成のみに使う場合、Pythonライブラリとして`pip install style-bert-vits2`でインストールできます。例は[library.ipynb](/library.ipynb)を参照してください。 - 他との連携に使えるAPIサーバーも同梱しています ([@darai0512](https://github.com/darai0512) 様によるPRです、ありがとうございます)。 - 元々「楽しそうな文章は楽しそうに、悲しそうな文章は悲しそうに」読むのがBert-VITS2の強みですので、スタイル指定がデフォルトでも感情豊かな音声を生成することができます。 ## 使い方 -CLIでの使い方は[こちら](/docs/CLI.md)を参照してください。 +- CLIでの使い方は[こちら](/docs/CLI.md)を参照してください。 +- [よくある質問](/docs/FAQ.md)も参照してください。 ### 動作環境 @@ -58,7 +65,7 @@ Pythonライブラリとしてのpipでのインストールや使用例は[libr Windowsを前提としています。 -1. [このzipファイル](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.4.1/sbv2.zip)を**パスに日本語や空白が含まれない場所に**ダウンロードして展開します。 +1. [このzipファイル](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.6.0/sbv2.zip)を**パスに日本語や空白が含まれない場所に**ダウンロードして展開します。 - グラボがある方は、`Install-Style-Bert-VITS2.bat`をダブルクリックします。 - グラボがない方は、`Install-Style-Bert-VITS2-CPU.bat`をダブルクリックします。CPU版では学習はできませんが、音声合成とマージは可能です。 2. 待つと自動で必要な環境がインストールされます。 @@ -70,13 +77,17 @@ Windowsを前提としています。 #### GitやPython使える人 +Pythonの仮想環境・パッケージ管理ツールである[uv](https://github.com/astral-sh/uv)がpipより高速なので、それを使ってインストールすることをお勧めします。 +(使いたくない場合は通常のpipでも大丈夫です。) + ```bash +powershell -c "irm https://astral.sh/uv/install.ps1 | iex" git clone https://github.com/litagin02/Style-Bert-VITS2.git cd Style-Bert-VITS2 -python -m venv venv +uv venv venv venv\Scripts\activate -pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -pip install -r requirements.txt +uv pip install "torch<2.4" "torchaudio<2.4" --index-url https://download.pytorch.org/whl/cu118 +uv pip install -r requirements.txt python initialize.py # 必要なモデルとデフォルトTTSモデルをダウンロード ``` 最後を忘れずに。 @@ -88,7 +99,7 @@ python initialize.py # 必要なモデルとデフォルトTTSモデルをダ エディター部分は[別リポジトリ](https://github.com/litagin02/Style-Bert-VITS2-Editor)に分かれています。 -バージョン2.2以前での音声合成WebUIは、`App.bat`をダブルクリックか、`python app.py`するとWebUIが起動します。 +バージョン2.2以前での音声合成WebUIは、`App.bat`をダブルクリックか、`python app.py`するとWebUIが起動します。または`Inference.bat`でも音声合成単独タブが開きます。 音声合成に必要なモデルファイルたちの構造は以下の通りです(手動で配置する必要はありません)。 ``` @@ -119,23 +130,19 @@ model_assets #### データセット作り -- `App.bat`をダブルクリックか`python app.py`したところの「データセット作成」タブから、音声ファイルを適切な長さにスライスし、その後に文字の書き起こしを自動で行えます。 +- `App.bat`をダブルクリックか`python app.py`したところの「データセット作成」タブから、音声ファイルを適切な長さにスライスし、その後に文字の書き起こしを自動で行えます。または`Dataset.bat`をダブルクリックでもその単独タブが開きます。 - 指示に従った後、下の「学習」タブでそのまま学習を行うことができます。 -注意: データセットの手動修正やノイズ除去等、細かい修正を行いたい場合は[Aivis](https://github.com/tsukumijima/Aivis)や、そのデータセット部分のWindows対応版 [Aivis Dataset](https://github.com/litagin02/Aivis-Dataset) を使うといいかもしれません。ですがファイル数が多い場合などは、このツールで簡易的に切り出してデータセットを作るだけでも十分という気もしています。 - -データセットがどのようなものがいいかは各自試行錯誤中してください。 - #### 学習WebUI -- `App.bat`をダブルクリックか`python app.py`して開くWebUIの「学習」タブから指示に従ってください。 +- `App.bat`をダブルクリックか`python app.py`して開くWebUIの「学習」タブから指示に従ってください。または`Train.bat`をダブルクリックでもその単独タブが開きます。 ### スタイルの生成 -- デフォルトスタイル「Neutral」以外のスタイルを使いたい人向けです。 -- `App.bat`をダブルクリックか`python app.py`して開くWebUIの「スタイル作成」タブから、音声ファイルを使ってスタイルを生成できます。 +- デフォルトでは、デフォルトスタイル「Neutral」の他、学習フォルダのフォルダ分けに応じたスタイルが生成されます。 +- それ以外の方法で手動でスタイルを作成したい人向けです。 +- `App.bat`をダブルクリックか`python app.py`して開くWebUIの「スタイル作成」タブから、音声ファイルを使ってスタイルを生成できます。または`StyleVectors.bat`をダブルクリックでもその単独タブが開きます。 - 学習とは独立しているので、学習中でもできるし、学習が終わっても何度もやりなおせます(前処理は終わらせている必要があります)。 -- スタイルについての仕様の詳細は[clustering.ipynb](clustering.ipynb)を参照してください。 ### API Server @@ -151,8 +158,8 @@ API仕様は起動後に`/docs`にて確認ください。 ### マージ -2つのモデルを、「声質」「声の高さ」「感情表現」「テンポ」の4点で混ぜ合わせて、新しいモデルを作ることが出来ます。 -`App.bat`をダブルクリックか`python app.py`して開くWebUIの「マージ」タブから、2つのモデルを選択してマージすることができます。 +2つのモデルを、「声質」「声の高さ」「感情表現」「テンポ」の4点で混ぜ合わせて、新しいモデルを作ったり、また「あるモデルに、別の2つのモデルの差分を足す」等の操作ができます。 +`App.bat`をダブルクリックか`python app.py`して開くWebUIの「マージ」タブから、2つのモデルを選択してマージすることができます。または`Merge.bat`をダブルクリックでもその単独タブが開きます。 ### 自然性評価 diff --git a/Server.bat b/Server.bat index cec9b1de8f642feaf2f8e35460e3c022a853bee4..816cc45473999f6fc5b79d351b344a51f1b25aa1 100644 --- a/Server.bat +++ b/Server.bat @@ -1,11 +1,11 @@ -chcp 65001 > NUL -@echo off - -pushd %~dp0 -echo Running server_fastapi.py -venv\Scripts\python server_fastapi.py - -if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) - -popd +chcp 65001 > NUL +@echo off + +pushd %~dp0 +echo Running server_fastapi.py +venv\Scripts\python server_fastapi.py + +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +popd pause \ No newline at end of file diff --git a/StyleVectors.bat b/StyleVectors.bat new file mode 100644 index 0000000000000000000000000000000000000000..d6dd3d5af405d351ac8e44954d82f7ce6558300c --- /dev/null +++ b/StyleVectors.bat @@ -0,0 +1,11 @@ +chcp 65001 > NUL +@echo off + +pushd %~dp0 +echo Running gradio_tabs/style_vectors.py... +venv\Scripts\python -m gradio_tabs.style_vectors + +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +popd +pause \ No newline at end of file diff --git a/Train.bat b/Train.bat new file mode 100644 index 0000000000000000000000000000000000000000..e0e72d688f28bdabe6705fc953af5fa1e3425283 --- /dev/null +++ b/Train.bat @@ -0,0 +1,11 @@ +chcp 65001 > NUL +@echo off + +pushd %~dp0 +echo Running gradio_tabs/train.py... +venv\Scripts\python -m gradio_tabs.train + +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +popd +pause \ No newline at end of file diff --git a/app.py b/app.py index f556a2377835c659e4b6e69f56f11177aceca007..987739b6a9f57c546ba6492b1c5abdd5d73cc408 100644 --- a/app.py +++ b/app.py @@ -3,8 +3,8 @@ from pathlib import Path import gradio as gr import torch -import yaml +from config import get_path_config from gradio_tabs.dataset import create_dataset_app from gradio_tabs.inference import create_inference_app from gradio_tabs.merge import create_merge_app @@ -22,11 +22,6 @@ pyopenjtalk_worker.initialize_worker() # dict_data/ 以下の辞書データを pyopenjtalk に適用 update_dict() -# Get path settings -with Path("configs/paths.yml").open("r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - # dataset_root = path_config["dataset_root"] - assets_root = path_config["assets_root"] parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default="cuda") @@ -34,13 +29,18 @@ parser.add_argument("--host", type=str, default="127.0.0.1") parser.add_argument("--port", type=int, default=None) parser.add_argument("--no_autolaunch", action="store_true") parser.add_argument("--share", action="store_true") +# parser.add_argument("--skip_default_models", action="store_true") args = parser.parse_args() device = args.device if device == "cuda" and not torch.cuda.is_available(): device = "cpu" -model_holder = TTSModelHolder(Path(assets_root), device) +# if not args.skip_default_models: +# download_default_models() + +path_config = get_path_config() +model_holder = TTSModelHolder(Path(path_config.assets_root), device) with gr.Blocks(theme=GRADIO_THEME) as app: gr.Markdown(f"# Style-Bert-VITS2 WebUI (version {VERSION})") @@ -56,7 +56,6 @@ with gr.Blocks(theme=GRADIO_THEME) as app: with gr.Tab("マージ"): create_merge_app(model_holder=model_holder) - app.launch( server_name=args.host, server_port=args.port, diff --git a/bert_gen.py b/bert_gen.py index 1dcab9eb2ccb23c04dd59c5afb905d0700826cba..c4995cb16954c6a0f63f0914a178e54ad53cb7f5 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -5,21 +5,18 @@ import torch import torch.multiprocessing as mp from tqdm import tqdm -from config import config +from config import get_config from style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger from style_bert_vits2.models import commons from style_bert_vits2.models.hyper_parameters import HyperParameters -from style_bert_vits2.nlp import ( - bert_models, - cleaned_text_to_sequence, - extract_bert_feature, -) +from style_bert_vits2.nlp import cleaned_text_to_sequence, extract_bert_feature from style_bert_vits2.nlp.japanese import pyopenjtalk_worker from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT +config = get_config() # このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化 pyopenjtalk_worker.initialize_worker() @@ -61,7 +58,7 @@ def process_line(x: tuple[str, bool]): bert = torch.load(bert_path) assert bert.shape[-1] == len(phone) except Exception: - bert = extract_bert_feature(text, word2ph, language_str, device) + bert = extract_bert_feature(text, word2ph, Languages(language_str), device) assert bert.shape[-1] == len(phone) torch.save(bert, bert_path) @@ -77,10 +74,10 @@ if __name__ == "__main__": config_path = args.config hps = HyperParameters.load_from_json(config_path) lines: list[str] = [] - with open(hps.data.training_files, "r", encoding="utf-8") as f: + with open(hps.data.training_files, encoding="utf-8") as f: lines.extend(f.readlines()) - with open(hps.data.validation_files, "r", encoding="utf-8") as f: + with open(hps.data.validation_files, encoding="utf-8") as f: lines.extend(f.readlines()) add_blank = [hps.data.add_blank] * len(lines) diff --git a/colab.ipynb b/colab.ipynb index 4f21124c71f23d6fffdd0fd0ea1fe25f18f1cfa7..c400e65b16b6e233576c45ba2ccb41730392492b 100644 --- a/colab.ipynb +++ b/colab.ipynb @@ -1,384 +1,455 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Style-Bert-VITS2 (ver 2.4.1) のGoogle Colabでの学習\n", - "\n", - "Google Colab上でStyle-Bert-VITS2の学習を行うことができます。\n", - "\n", - "このnotebookでは、通常使用ではあなたのGoogle Driveにフォルダ`Style-Bert-VITS2`を作り、その内部での作業を行います。他のフォルダには触れません。\n", - "Google Driveを使わない場合は、初期設定のところで適切なパスを指定してください。\n", - "\n", - "## 流れ\n", - "\n", - "### 学習を最初からやりたいとき\n", - "上から順に実行していけばいいです。音声合成に必要なファイルはGoogle Driveの`Style-Bert-VITS2/model_assets/`に保存されます。また、途中経過も`Style-Bert-VITS2/Data/`に保存されるので、学習を中断したり、途中から再開することもできます。\n", - "\n", - "### 学習を途中から再開したいとき\n", - "0と1を行い、3の前処理は飛ばして、4から始めてください。スタイル分け5は、学習が終わったら必要なら行ってください。\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 0. 環境構築\n", - "\n", - "Style-Bert-VITS2の環境をcolab上に構築します。グラボモードが有効になっていることを確認し、以下のセルを順に実行してください。\n", - "\n", - "**最近のcolabのアップデートにより、エラーダイアログ「WARNING: The following packages were previously imported in this runtime: [pydevd_plugins]」が出るが、「キャンセル」を選択して続行してください。**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# このセルを実行して環境構築してください。\n", - "# エラーダイアログ「WARNING: The following packages were previously imported in this runtime: [pydevd_plugins]」が出るが「キャンセル」を選択して続行してください。\n", - "\n", - "!git clone https://github.com/litagin02/Style-Bert-VITS2.git\n", - "%cd Style-Bert-VITS2/\n", - "!pip install -r requirements.txt\n", - "!python initialize.py --skip_jvnv" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Google driveを使う方はこちらを実行してください。\n", - "\n", - "from google.colab import drive\n", - "drive.mount(\"/content/drive\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. 初期設定\n", - "\n", - "学習とその結果を保存するディレクトリ名を指定します。\n", - "Google driveの場合はそのまま実行、カスタマイズしたい方は変更して実行してください。" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# 学習に必要なファイルや途中経過が保存されるディレクトリ\n", - "dataset_root = \"/content/drive/MyDrive/Style-Bert-VITS2/Data\"\n", - "\n", - "# 学習結果(音声合成に必要なファイルたち)が保存されるディレクトリ\n", - "assets_root = \"/content/drive/MyDrive/Style-Bert-VITS2/model_assets\"\n", - "\n", - "import yaml\n", - "\n", - "\n", - "with open(\"configs/paths.yml\", \"w\", encoding=\"utf-8\") as f:\n", - " yaml.dump({\"dataset_root\": dataset_root, \"assets_root\": assets_root}, f)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. 学習に使うデータ準備\n", - "\n", - "すでに音声ファイル(1ファイル2-12秒程度)とその書き起こしデータがある場合は2.2を、ない場合は2.1を実行してください。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.1 音声ファイルからのデータセットの作成(ある人はスキップ可)\n", - "\n", - "音声ファイル(1ファイル2-12秒程度)とその書き起こしのデータセットを持っていない方は、(日本語の)音声ファイルのみから以下の手順でデータセットを作成することができます。Google drive上の`Style-Bert-VITS2/inputs/`フォルダに音声ファイル(wavファイル形式、1ファイルでも複数ファイルでも可)を置いて、下を実行すると、データセットが作られ、自動的に正しい場所へ配置されます。" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# 元となる音声ファイル(wav形式)を入れるディレクトリ\n", - "input_dir = \"/content/drive/MyDrive/Style-Bert-VITS2/inputs\"\n", - "# モデル名(話者名)を入力\n", - "model_name = \"your_model_name\"\n", - "\n", - "# こういうふうに書き起こして欲しいという例文(句読点の入れ方・笑い方や固有名詞等)\n", - "initial_prompt = \"こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!\"\n", - "\n", - "!python slice.py -i {input_dir} --model_name {model_name}\n", - "!python transcribe.py --model_name {model_name} --initial_prompt {initial_prompt} --use_hf_whisper" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "成功したらそのまま3へ進んでください" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2.2 音声ファイルと書き起こしデータがすでにある場合\n", - "\n", - "指示に従って適切にデータセットを配置してください。\n", - "\n", - "次のセルを実行して、学習データをいれるフォルダ(1で設定した`dataset_root`)を作成します。" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "esCNJl704h52" - }, - "outputs": [], - "source": [ - "import os\n", - "\n", - "os.makedirs(dataset_root, exist_ok=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "次に、学習に必要なデータを、Google driveに作成された`Style-Bert-VITS2/Data`フォルダに配置します。\n", - "\n", - "まず音声データ(wavファイルで1ファイルが2-12秒程度の、長すぎず短すぎない発話のものをいくつか)と、書き起こしテキストを用意してください。wavファイル名やモデルの名前は空白を含まない半角で、wavファイルの拡張子は小文字`.wav`である必要があります。\n", - "\n", - "書き起こしテキストは、次の形式で記述してください。\n", - "```\n", - "****.wav|{話者名}|{言語ID、ZHかJPかEN}|{書き起こしテキスト}\n", - "```\n", - "\n", - "例:\n", - "```\n", - "wav_number1.wav|hanako|JP|こんにちは、聞こえて、いますか?\n", - "wav_next.wav|taro|JP|はい、聞こえています……。\n", - "english_teacher.wav|Mary|EN|How are you? I'm fine, thank you, and you?\n", - "...\n", - "```\n", - "日本語話者の単一話者データセットで構いません。\n", - "\n", - "### データセットの配置\n", - "\n", - "次にモデルの名前を適当に決めてください(空白を含まない半角英数字がよいです)。\n", - "そして、書き起こしファイルを`esd.list`という名前で保存し、またwavファイルも`raw`というフォルダを作成し、あなたのGoogle Driveの中の(上で自動的に作られるはずの)`Data`フォルダのなかに、次のように配置します。\n", - "```\n", - "├── Data\n", - "│ ├── {モデルの名前}\n", - "│ │ ├── esd.list\n", - "│ │ ├── raw\n", - "│ │ │ ├── ****.wav\n", - "│ │ │ ├── ****.wav\n", - "│ │ │ ├── ...\n", - "```" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5r85-W20ECcr" - }, - "source": [ - "## 3. 学習の前処理\n", - "\n", - "次に学習の前処理を行います。必要なパラメータをここで指定します。次のセルに設定等を入力して実行してください。「~~かどうか」は`True`もしくは`False`を指定してください。" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "CXR7kjuF5GlE" - }, - "outputs": [], - "source": [ - "# 上でつけたフォルダの名前`Data/{model_name}/`\n", - "model_name = \"your_model_name\"\n", - "\n", - "# JP-Extra (日本語特化版)を使うかどうか。日本語の能力が向上する代わりに英語と中国語は使えなくなります。\n", - "use_jp_extra = True\n", - "\n", - "# 学習のバッチサイズ。VRAMのはみ出具合に応じて調整してください。\n", - "batch_size = 4\n", - "\n", - "# 学習のエポック数(データセットを合計何周するか)。\n", - "# 100で多すぎるほどかもしれませんが、もっと多くやると質が上がるのかもしれません。\n", - "epochs = 100\n", - "\n", - "# 保存頻度。何ステップごとにモデルを保存するか。分からなければデフォルトのままで。\n", - "save_every_steps = 1000\n", - "\n", - "# 音声ファイルの音量を正規化するかどうか\n", - "normalize = False\n", - "\n", - "# 音声ファイルの開始・終了にある無音区間を削除するかどうか\n", - "trim = False\n", - "\n", - "# 読みのエラーが出た場合にどうするか。\n", - "# \"raise\"ならテキスト前処理が終わったら中断、\"skip\"なら読めない行は学習に使わない、\"use\"なら無理やり使う\n", - "yomi_error = \"skip\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "上のセルが実行されたら、次のセルを実行して学習の前処理を行います。" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "xMVaOIPLabV5", - "outputId": "15fac868-9132-45d9-9f5f-365b6aeb67b0" - }, - "outputs": [], - "source": [ - "from gradio_tabs.train import preprocess_all\n", - "\n", - "preprocess_all(\n", - " model_name=model_name,\n", - " batch_size=batch_size,\n", - " epochs=epochs,\n", - " save_every_steps=save_every_steps,\n", - " num_processes=2,\n", - " normalize=normalize,\n", - " trim=trim,\n", - " freeze_EN_bert=False,\n", - " freeze_JP_bert=False,\n", - " freeze_ZH_bert=False,\n", - " freeze_style=False,\n", - " freeze_decoder=False, # ここをTrueにするともしかしたら違う結果になるかもしれません。\n", - " use_jp_extra=use_jp_extra,\n", - " val_per_lang=0,\n", - " log_interval=200,\n", - " yomi_error=yomi_error\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. 学習\n", - "\n", - "前処理が正常に終わったら、学習を行います。次のセルを実行すると学習が始まります。\n", - "\n", - "学習の結果は、上で指定した`save_every_steps`の間隔で、Google Driveの中の`Style-Bert-VITS2/Data/{モデルの名前}/model_assets/`フォルダに保存されます。\n", - "\n", - "このフォルダをダウンロードし、ローカルのStyle-Bert-VITS2の`model_assets`フォルダに上書きすれば、学習結果を使うことができます。" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "F7aJhsgLAWvO" + }, + "source": [ + "# Style-Bert-VITS2 (ver 2.6.1) のGoogle Colabでの学習\n", + "\n", + "Google Colab上でStyle-Bert-VITS2の学習を行うことができます。\n", + "\n", + "このnotebookでは、通常使用ではあなたのGoogle Driveにフォルダ`Style-Bert-VITS2`を作り、その内部での作業を行います。他のフォルダには触れません。\n", + "Google Driveを使わない場合は、初期設定のところで適切なパスを指定してください。\n", + "\n", + "## 流れ\n", + "\n", + "### 学習を最初からやりたいとき\n", + "上から順に実行していけばいいです。音声合成に必要なファイルはGoogle Driveの`Style-Bert-VITS2/model_assets/`に保存されます。また、途中経過も`Style-Bert-VITS2/Data/`に保存されるので、学習を中断したり、途中から再開することもできます。\n", + "\n", + "### 学習を途中から再開したいとき\n", + "0と1を行い、3の前処理は飛ばして、4から始めてください。スタイル分け5は、学習が終わったら必要なら行ってください。\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L-gAIubBAWvQ" + }, + "source": [ + "## 0. 環境構築\n", + "\n", + "Style-Bert-VITS2の環境をcolab上に構築します。ランタイムがT4等のGPUバックエンドになっていることを確認し、実行してください。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "laieKrbEb6Ij", - "outputId": "72238c88-f294-4ed9-84f6-84c1c17999ca" - }, - "outputs": [], - "source": [ - "# 上でつけたモデル名を入力。学習を途中からする場合はきちんとモデルが保存されているフォルダ名を入力。\n", - "model_name = \"your_model_name\"\n", - "\n", - "\n", - "import yaml\n", - "from gradio_tabs.train import get_path\n", - "\n", - "dataset_path, _, _, _, config_path = get_path(model_name)\n", - "\n", - "with open(\"default_config.yml\", \"r\", encoding=\"utf-8\") as f:\n", - " yml_data = yaml.safe_load(f)\n", - "yml_data[\"model_name\"] = model_name\n", - "with open(\"config.yml\", \"w\", encoding=\"utf-8\") as f:\n", - " yaml.dump(yml_data, f, allow_unicode=True)" - ] + "id": "0GNj8JyDAlm2", + "outputId": "d8be4a1a-e52d-46f8-8675-3f1a24bc9a51" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"PATH\"] += \":/root/.cargo/bin\"\n", + "\n", + "!curl -LsSf https://astral.sh/uv/install.sh | sh\n", + "!git clone https://github.com/litagin02/Style-Bert-VITS2.git\n", + "%cd Style-Bert-VITS2/\n", + "!uv pip install --system -r requirements-colab.txt\n", + "!python initialize.py --skip_default_models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# 日本語特化版を「使う」場合\n", - "!python train_ms_jp_extra.py --config {config_path} --model {dataset_path} --assets_root {assets_root}" - ] + "id": "o5z1nzkvAWvR", + "outputId": "cd87f053-18e0-4dbb-f904-d5230d1fa7ef" + }, + "outputs": [], + "source": [ + "# Google driveを使う方はこちらを実行してください。\n", + "\n", + "from google.colab import drive\n", + "\n", + "drive.mount(\"/content/drive\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WU9apXzcAWvR" + }, + "source": [ + "## 1. 初期設定\n", + "\n", + "学習とその結果を保存するディレクトリ名を指定します。\n", + "Google driveの場合はそのまま実行、カスタマイズしたい方は変更して実行してください。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gO3OwZV1AWvR" + }, + "outputs": [], + "source": [ + "# 学習に必要なファイルや途中経過が保存されるディレクトリ\n", + "dataset_root = \"/content/drive/MyDrive/Style-Bert-VITS2/Data\"\n", + "\n", + "# 学習結果(音声合成に必要なファイルたち)が保存されるディレクトリ\n", + "assets_root = \"/content/drive/MyDrive/Style-Bert-VITS2/model_assets\"\n", + "\n", + "import yaml\n", + "\n", + "\n", + "with open(\"configs/paths.yml\", \"w\", encoding=\"utf-8\") as f:\n", + " yaml.dump({\"dataset_root\": dataset_root, \"assets_root\": assets_root}, f)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dA_yLeezAWvS" + }, + "source": [ + "## 2. 学習に使うデータ準備\n", + "\n", + "すでに音声ファイル(1ファイル2-12秒程度)とその書き起こしデータがある場合は2.2を、ない場合は2.1を実行してください。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8s9gOnTCAWvS" + }, + "source": [ + "### 2.1 音声ファイルからのデータセットの作成(ある人はスキップ可)\n", + "\n", + "音声ファイル(1ファイル2-12秒程度)とその書き起こしのデータセットを持っていない方は、(日本語の)音声ファイルのみから以下の手順でデータセットを作成することができます。Google drive上の`Style-Bert-VITS2/inputs/`フォルダに音声ファイル(wavやmp3等の通常の音声ファイル形式、1ファイルでも複数ファイルでも可)を置いて、下を実行すると、データセットが作られ、自動的に正しい場所へ配置されます。\n", + "\n", + "**2024-06-02のVer 2.5以降**、`inputs/`フォルダにサブフォルダを2個以上作ってそこへ音声ファイルをスタイルに応じて振り分けて置くと、学習の際にサブディレクトリに応じたスタイルが自動的に作成されます。デフォルトスタイルのみでよい場合や手動でスタイルを後で作成する場合は`inputs/`直下へ入れれば大丈夫です。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# 日本語特化版を「使わない」場合\n", - "!python train_ms.py --config {config_path} --model {dataset_path} --assets_root {assets_root}" - ] + "id": "_fXCTPuiAWvS", + "outputId": "47abd55b-efe5-48e2-f6fa-8e2016efe0ec" + }, + "outputs": [], + "source": [ + "# 元となる音声ファイル(wav形式)を入れるディレクトリ\n", + "input_dir = \"/content/drive/MyDrive/Style-Bert-VITS2/inputs\"\n", + "# モデル名(話者名)を入力\n", + "model_name = \"your_model_name\"\n", + "\n", + "# こういうふうに書き起こして欲しいという例文(句読点の入れ方・笑い方や固有名詞等)\n", + "initial_prompt = \"こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!\"\n", + "\n", + "!python slice.py -i {input_dir} --model_name {model_name}\n", + "!python transcribe.py --model_name {model_name} --initial_prompt {initial_prompt} --use_hf_whisper" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j7vEWewoAWvS" + }, + "source": [ + "成功したらそのまま3へ進んでください" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z3AC-3zpAWvS" + }, + "source": [ + "### 2.2 音声ファイルと書き起こしデータがすでにある場合\n", + "\n", + "指示に従って適切にデータセットを配置してください。\n", + "\n", + "次のセルを実行して、学習データをいれるフォルダ(1で設定した`dataset_root`)を作成します。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "esCNJl704h52" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.makedirs(dataset_root, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aaDgJCjCAWvT" + }, + "source": [ + "まず音声データと、書き起こしテキストを用意してください。\n", + "\n", + "それを次のように配置します。\n", + "```\n", + "├── Data/\n", + "│ ├── {モデルの名前}\n", + "│ │ ├── esd.list\n", + "│ │ ├── raw/\n", + "│ │ │ ├── foo.wav\n", + "│ │ │ ├── bar.mp3\n", + "│ │ │ ├── style1/\n", + "│ │ │ │ ├── baz.wav\n", + "│ │ │ │ ├── qux.wav\n", + "│ │ │ ├── style2/\n", + "│ │ │ │ ├── corge.wav\n", + "│ │ │ │ ├── grault.wav\n", + "...\n", + "```\n", + "\n", + "### 配置の仕方\n", + "- 上のように配置すると、`style1/`と`style2/`フォルダの内部(直下以外も含む)に入っている音声ファイルたちから、自動的にデフォルトスタイルに加えて`style1`と`style2`というスタイルが作成されます\n", + "- 特にスタイルを作る必要がない場合や、スタイル分類機能等でスタイルを作る場合は、`raw/`フォルダ直下に全てを配置してください。このように`raw/`のサブディレクトリの個数が0または1の場合は、スタイルはデフォルトスタイルのみが作成されます。\n", + "- 音声ファイルのフォーマットはwav形式以外にもmp3等の多くの音声ファイルに対応しています\n", + "\n", + "### 書き起こしファイル`esd.list`\n", + "\n", + "`Data/{モデルの名前}/esd.list` ファイルには、以下のフォーマットで各音声ファイルの情報を記述してください。\n", + "\n", + "\n", + "```\n", + "path/to/audio.wav(wavファイル以外でもこう書く)|{話者名}|{言語ID、ZHかJPかEN}|{書き起こしテキスト}\n", + "```\n", + "\n", + "- ここで、最初の`path/to/audio.wav`は、`raw/`からの相対パスです。つまり、`raw/foo.wav`の場合は`foo.wav`、`raw/style1/bar.wav`の場合は`style1/bar.wav`となります。\n", + "- 拡張子がwavでない場合でも、`esd.list`には`wav`と書いてください、つまり、`raw/bar.mp3`の場合でも`bar.wav`と書いてください。\n", + "\n", + "\n", + "例:\n", + "```\n", + "foo.wav|hanako|JP|こんにちは、元気ですか?\n", + "bar.wav|taro|JP|はい、聞こえています……。何か用ですか?\n", + "style1/baz.wav|hanako|JP|今日はいい天気ですね。\n", + "style1/qux.wav|taro|JP|はい、そうですね。\n", + "...\n", + "english_teacher.wav|Mary|EN|How are you? I'm fine, thank you, and you?\n", + "...\n", + "```\n", + "もちろん日本語話者の単一話者データセットでも構いません。" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5r85-W20ECcr" + }, + "source": [ + "## 3. 学習の前処理\n", + "\n", + "次に学習の前処理を行います。必要なパラメータをここで指定します。次のセルに設定等を入力して実行してください。「~~かどうか」は`True`もしくは`False`を指定してください。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CXR7kjuF5GlE" + }, + "outputs": [], + "source": [ + "# 上でつけたフォルダの名前`Data/{model_name}/`\n", + "model_name = \"your_model_name\"\n", + "\n", + "# JP-Extra (日本語特化版)を使うかどうか。日本語の能力が向上する代わりに英語と中国語は使えなくなります。\n", + "use_jp_extra = True\n", + "\n", + "# 学習のバッチサイズ。VRAMのはみ出具合に応じて調整してください。\n", + "batch_size = 4\n", + "\n", + "# 学習のエポック数(データセットを合計何周するか)。\n", + "# 100で多すぎるほどかもしれませんが、もっと多くやると質が上がるのかもしれません。\n", + "epochs = 100\n", + "\n", + "# 保存頻度。何ステップごとにモデルを保存するか。分からなければデフォルトのままで。\n", + "save_every_steps = 1000\n", + "\n", + "# 音声ファイルの音量を正規化するかどうか\n", + "normalize = False\n", + "\n", + "# 音声ファイルの開始・終了にある無音区間を削除するかどうか\n", + "trim = False\n", + "\n", + "# 読みのエラーが出た場合にどうするか。\n", + "# \"raise\"ならテキスト前処理が終わったら中断、\"skip\"なら読めない行は学習に使わない、\"use\"なら無理やり使う\n", + "yomi_error = \"skip\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BFZdLTtpAWvT" + }, + "source": [ + "上のセルが実行されたら、次のセルを実行して学習の前処理を行います。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "c7g0hrdeP1Tl", - "outputId": "94f9a6f6-027f-4554-ce0c-60ac56251c22" - }, - "outputs": [], - "source": [ - "# 学習結果を試す・マージ・スタイル分けはこちらから\n", - "!python app.py --share" - ] - } - ], - "metadata": { - "accelerator": "GPU", + "id": "xMVaOIPLabV5", + "outputId": "36b1c2b2-6df0-4d00-d86a-519a0fc0af63" + }, + "outputs": [], + "source": [ + "from gradio_tabs.train import preprocess_all\n", + "from style_bert_vits2.nlp.japanese import pyopenjtalk_worker\n", + "\n", + "\n", + "pyopenjtalk_worker.initialize_worker()\n", + "\n", + "preprocess_all(\n", + " model_name=model_name,\n", + " batch_size=batch_size,\n", + " epochs=epochs,\n", + " save_every_steps=save_every_steps,\n", + " num_processes=2,\n", + " normalize=normalize,\n", + " trim=trim,\n", + " freeze_EN_bert=False,\n", + " freeze_JP_bert=False,\n", + " freeze_ZH_bert=False,\n", + " freeze_style=False,\n", + " freeze_decoder=False,\n", + " use_jp_extra=use_jp_extra,\n", + " val_per_lang=0,\n", + " log_interval=200,\n", + " yomi_error=yomi_error,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sVhwI5C-AWvT" + }, + "source": [ + "## 4. 学習\n", + "\n", + "前処理が正常に終わったら、学習を行います。次のセルを実行すると学習が始まります。\n", + "\n", + "学習の結果は、上で指定した`save_every_steps`の間隔で、Google Driveの中の`Style-Bert-VITS2/Data/{モデルの名前}/model_assets/`フォルダに保存されます。\n", + "\n", + "このフォルダをダウンロードし、ローカルのStyle-Bert-VITS2の`model_assets`フォルダに上書きすれば、学習結果を使うことができます。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "laieKrbEb6Ij" + }, + "outputs": [], + "source": [ + "# 上でつけたモデル名を入力。学習を途中からする場合はきちんとモデルが保存されているフォルダ名を入力。\n", + "model_name = \"your_model_name\"\n", + "\n", + "\n", + "import yaml\n", + "from gradio_tabs.train import get_path\n", + "\n", + "paths = get_path(model_name)\n", + "dataset_path = str(paths.dataset_path)\n", + "config_path = str(paths.config_path)\n", + "\n", + "with open(\"default_config.yml\", \"r\", encoding=\"utf-8\") as f:\n", + " yml_data = yaml.safe_load(f)\n", + "yml_data[\"model_name\"] = model_name\n", + "with open(\"config.yml\", \"w\", encoding=\"utf-8\") as f:\n", + " yaml.dump(yml_data, f, allow_unicode=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { "colab": { - "gpuType": "T4", - "provenance": [] + "background_save": true, + "base_uri": "https://localhost:8080/" }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" + "id": "JqGeHNabAWvT", + "outputId": "c51b422c-728b-420b-fa92-b787fa058adf" + }, + "outputs": [], + "source": [ + "# 日本語特化版を「使う」場合\n", + "!python train_ms_jp_extra.py --config {config_path} --model {dataset_path} --assets_root {assets_root}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rVbjh-WPAWvU" + }, + "outputs": [], + "source": [ + "# 日本語特化版を「使わない」場合\n", + "!python train_ms.py --config {config_path} --model {dataset_path} --assets_root {assets_root}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } + "id": "c7g0hrdeP1Tl", + "outputId": "4bb9d21e-50df-4ba5-a547-daa78a4b63dc" + }, + "outputs": [], + "source": [ + "# 学習結果を試す・マージ・スタイル分けはこちらから\n", + "!python app.py --share" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 0 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/config.py b/config.py index 2229e615694cc2cfa543e10f4d1a9de0e5b09358..4ce32f8d3fe47335d40ea60ab02f8255446a807a 100644 --- a/config.py +++ b/config.py @@ -2,9 +2,9 @@ @Desc: 全局配置文件读取 """ -import os import shutil -from typing import Dict, List +from pathlib import Path +from typing import Any import torch import yaml @@ -12,6 +12,12 @@ import yaml from style_bert_vits2.logging import logger +class PathConfig: + def __init__(self, dataset_root: str, assets_root: str): + self.dataset_root = Path(dataset_root) + self.assets_root = Path(assets_root) + + # If not cuda available, set possible devices to cpu cuda_available = torch.cuda.is_available() @@ -20,17 +26,17 @@ class Resample_config: """重采样配置""" def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100): - self.sampling_rate: int = sampling_rate # 目标采样率 - self.in_dir: str = in_dir # 待处理音频目录路径 - self.out_dir: str = out_dir # 重采样输出路径 + self.sampling_rate = sampling_rate # 目标采样率 + self.in_dir = Path(in_dir) # 待处理音频目录路径 + self.out_dir = Path(out_dir) # 重采样输出路径 @classmethod - def from_dict(cls, dataset_path: str, data: Dict[str, any]): + def from_dict(cls, dataset_path: Path, data: dict[str, Any]): """从字典中生成实例""" # 不检查路径是否有效,此逻辑在resample.py中处理 - data["in_dir"] = os.path.join(dataset_path, data["in_dir"]) - data["out_dir"] = os.path.join(dataset_path, data["out_dir"]) + data["in_dir"] = dataset_path / data["in_dir"] + data["out_dir"] = dataset_path / data["out_dir"] return cls(**data) @@ -49,39 +55,32 @@ class Preprocess_text_config: max_val_total: int = 10000, clean: bool = True, ): - self.transcription_path: str = ( - transcription_path # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。 - ) - self.cleaned_path: str = ( - cleaned_path # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成 - ) - self.train_path: str = ( - train_path # 训练集路径,可以不填。不填则将在原始文本目录生成 - ) - self.val_path: str = ( - val_path # 验证集路径,可以不填。不填则将在原始文本目录生成 - ) - self.config_path: str = config_path # 配置文件路径 - self.val_per_lang: int = val_per_lang # 每个speaker的验证集条数 - self.max_val_total: int = ( - max_val_total # 验证集最大条数,多于的会被截断并放到训练集中 - ) - self.clean: bool = clean # 是否进行数据清洗 + self.transcription_path = Path(transcription_path) + self.train_path = Path(train_path) + if cleaned_path == "" or cleaned_path is None: + self.cleaned_path = self.transcription_path.with_name( + self.transcription_path.name + ".cleaned" + ) + else: + self.cleaned_path = Path(cleaned_path) + self.val_path = Path(val_path) + self.config_path = Path(config_path) + self.val_per_lang = val_per_lang + self.max_val_total = max_val_total + self.clean = clean @classmethod - def from_dict(cls, dataset_path: str, data: Dict[str, any]): + def from_dict(cls, dataset_path: Path, data: dict[str, Any]): """从字典中生成实例""" - data["transcription_path"] = os.path.join( - dataset_path, data["transcription_path"] - ) + data["transcription_path"] = dataset_path / data["transcription_path"] if data["cleaned_path"] == "" or data["cleaned_path"] is None: - data["cleaned_path"] = None + data["cleaned_path"] = "" else: - data["cleaned_path"] = os.path.join(dataset_path, data["cleaned_path"]) - data["train_path"] = os.path.join(dataset_path, data["train_path"]) - data["val_path"] = os.path.join(dataset_path, data["val_path"]) - data["config_path"] = os.path.join(dataset_path, data["config_path"]) + data["cleaned_path"] = dataset_path / data["cleaned_path"] + data["train_path"] = dataset_path / data["train_path"] + data["val_path"] = dataset_path / data["val_path"] + data["config_path"] = dataset_path / data["config_path"] return cls(**data) @@ -96,7 +95,7 @@ class Bert_gen_config: device: str = "cuda", use_multi_device: bool = False, ): - self.config_path = config_path + self.config_path = Path(config_path) self.num_processes = num_processes if not cuda_available: device = "cpu" @@ -104,8 +103,8 @@ class Bert_gen_config: self.use_multi_device = use_multi_device @classmethod - def from_dict(cls, dataset_path: str, data: Dict[str, any]): - data["config_path"] = os.path.join(dataset_path, data["config_path"]) + def from_dict(cls, dataset_path: Path, data: dict[str, Any]): + data["config_path"] = dataset_path / data["config_path"] return cls(**data) @@ -119,15 +118,15 @@ class Style_gen_config: num_processes: int = 4, device: str = "cuda", ): - self.config_path = config_path + self.config_path = Path(config_path) self.num_processes = num_processes if not cuda_available: device = "cpu" self.device = device @classmethod - def from_dict(cls, dataset_path: str, data: Dict[str, any]): - data["config_path"] = os.path.join(dataset_path, data["config_path"]) + def from_dict(cls, dataset_path: Path, data: dict[str, Any]): + data["config_path"] = dataset_path / data["config_path"] return cls(**data) @@ -138,7 +137,7 @@ class Train_ms_config: def __init__( self, config_path: str, - env: Dict[str, any], + env: dict[str, Any], # base: Dict[str, any], model_dir: str, num_workers: int, @@ -147,16 +146,18 @@ class Train_ms_config: ): self.env = env # 需要加载的环境变量 # self.base = base # 底模配置 - self.model_dir = model_dir # 训练模型存储目录,该路径为相对于dataset_path的路径,而非项目根目录 - self.config_path = config_path # 配置文件路径 + self.model_dir = Path( + model_dir + ) # 训练模型存储目录,该路径为相对于dataset_path的路径,而非项目根目录 + self.config_path = Path(config_path) # 配置文件路径 self.num_workers = num_workers # worker数量 self.spec_cache = spec_cache # 是否启用spec缓存 self.keep_ckpts = keep_ckpts # ckpt数量 @classmethod - def from_dict(cls, dataset_path: str, data: Dict[str, any]): + def from_dict(cls, dataset_path: Path, data: dict[str, Any]): # data["model"] = os.path.join(dataset_path, data["model"]) - data["config_path"] = os.path.join(dataset_path, data["config_path"]) + data["config_path"] = dataset_path / data["config_path"] return cls(**data) @@ -176,20 +177,18 @@ class Webui_config: ): if not cuda_available: device = "cpu" - self.device: str = device - self.model: str = model # 端口号 - self.config_path: str = config_path # 是否公开部署,对外网开放 - self.port: int = port # 是否开启debug模式 - self.share: bool = share # 模型路径 - self.debug: bool = debug # 配置文件路径 - self.language_identification_library: str = ( - language_identification_library # 语种识别库 - ) + self.device = device + self.model = Path(model) + self.config_path = Path(config_path) + self.port: int = port + self.share: bool = share + self.debug: bool = debug + self.language_identification_library: str = language_identification_library @classmethod - def from_dict(cls, dataset_path: str, data: Dict[str, any]): - data["config_path"] = os.path.join(dataset_path, data["config_path"]) - data["model"] = os.path.join(dataset_path, data["model"]) + def from_dict(cls, dataset_path: Path, data: dict[str, Any]): + data["config_path"] = dataset_path / data["config_path"] + data["model"] = dataset_path / data["model"] return cls(**data) @@ -200,7 +199,7 @@ class Server_config: device: str = "cuda", limit: int = 100, language: str = "JP", - origins: List[str] = None, + origins: list[str] = ["*"], ): self.port: int = port if not cuda_available: @@ -208,10 +207,10 @@ class Server_config: self.device: str = device self.language: str = language self.limit: int = limit - self.origins: List[str] = origins + self.origins: list[str] = origins @classmethod - def from_dict(cls, data: Dict[str, any]): + def from_dict(cls, data: dict[str, Any]): return cls(**data) @@ -223,32 +222,33 @@ class Translate_config: self.secret_key = secret_key @classmethod - def from_dict(cls, data: Dict[str, any]): + def from_dict(cls, data: dict[str, Any]): return cls(**data) class Config: - def __init__(self, config_path: str, path_config: dict[str, str]): - if not os.path.isfile(config_path) and os.path.isfile("default_config.yml"): + def __init__(self, config_path: str, path_config: PathConfig): + if not Path(config_path).exists(): shutil.copy(src="default_config.yml", dst=config_path) logger.info( f"A configuration file {config_path} has been generated based on the default configuration file default_config.yml." ) logger.info( - "If you have no special needs, please do not modify default_config.yml." + "Please do not modify default_config.yml. Instead, modify config.yml." ) # sys.exit(0) - with open(config_path, "r", encoding="utf-8") as file: - yaml_config: Dict[str, any] = yaml.safe_load(file.read()) + with open(config_path, encoding="utf-8") as file: + yaml_config: dict[str, Any] = yaml.safe_load(file.read()) model_name: str = yaml_config["model_name"] self.model_name: str = model_name if "dataset_path" in yaml_config: - dataset_path = yaml_config["dataset_path"] + dataset_path = Path(yaml_config["dataset_path"]) else: - dataset_path = os.path.join(path_config["dataset_root"], model_name) - self.dataset_path: str = dataset_path - self.assets_root: str = path_config["assets_root"] - self.out_dir = os.path.join(self.assets_root, model_name) + dataset_path = path_config.dataset_root / model_name + self.dataset_path = dataset_path + self.dataset_root = path_config.dataset_root + self.assets_root = path_config.assets_root + self.out_dir = self.assets_root / model_name self.resample_config: Resample_config = Resample_config.from_dict( dataset_path, yaml_config["resample"] ) @@ -277,16 +277,31 @@ class Config: # ) -with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - # Should contain the following keys: - # - dataset_root: the root directory of the dataset, default to "Data" - # - assets_root: the root directory of the assets, default to "model_assets" +# Load and initialize the configuration + + +def get_path_config() -> PathConfig: + path_config_path = Path("configs/paths.yml") + if not path_config_path.exists(): + shutil.copy(src="configs/default_paths.yml", dst=path_config_path) + logger.info( + f"A configuration file {path_config_path} has been generated based on the default configuration file default_paths.yml." + ) + logger.info( + "Please do not modify configs/default_paths.yml. Instead, modify configs/paths.yml." + ) + with open(path_config_path, encoding="utf-8") as file: + path_config_dict: dict[str, str] = yaml.safe_load(file.read()) + return PathConfig(**path_config_dict) + +def get_config() -> Config: + path_config = get_path_config() + try: + config = Config("config.yml", path_config) + except (TypeError, KeyError): + logger.warning("Old config.yml found. Replace it with default_config.yml.") + shutil.copy(src="default_config.yml", dst="config.yml") + config = Config("config.yml", path_config) -try: - config = Config("config.yml", path_config) -except (TypeError, KeyError): - logger.warning("Old config.yml found. Replace it with default_config.yml.") - shutil.copy(src="default_config.yml", dst="config.yml") - config = Config("config.yml", path_config) + return config diff --git a/configs/config.json b/configs/config.json index 2f3988b4f1799c8f4254151763617ac7e7620513..d1edf16ebb815207117a95ce22f7f6b1798bbf51 100644 --- a/configs/config.json +++ b/configs/config.json @@ -69,5 +69,5 @@ "use_spectral_norm": false, "gin_channels": 256 }, - "version": "2.4.1" + "version": "2.6.1" } diff --git a/configs/config_jp_extra.json b/configs/config_jp_extra.json index d7548094b2248e4e3fb2017b522093eeb1ce768a..fa93293c2c394d46be4869f7e332bdf89e2b6279 100644 --- a/configs/config_jp_extra.json +++ b/configs/config_jp_extra.json @@ -76,5 +76,5 @@ "initial_channel": 64 } }, - "version": "2.4.1-JP-Extra" + "version": "2.6.1-JP-Extra" } diff --git a/configs/default_paths.yml b/configs/default_paths.yml new file mode 100644 index 0000000000000000000000000000000000000000..d743748f7f52d5c41f1487b53872143c2f010222 --- /dev/null +++ b/configs/default_paths.yml @@ -0,0 +1,8 @@ +# Root directory of the training dataset. +# The training dataset of {model_name} should be placed in {dataset_root}/{model_name}. +dataset_root: Data + +# Root directory of the model assets (for inference). +# In training, the model assets will be saved to {assets_root}/{model_name}, +# and in inference, we load all the models from {assets_root}. +assets_root: model_assets diff --git a/data_utils.py b/data_utils.py index 73d4303c80494f15dc8287809aafb3df2f323a4a..96eab38668ee97c15e5d635bb5eb07c688fa6b69 100644 --- a/data_utils.py +++ b/data_utils.py @@ -7,7 +7,7 @@ import torch import torch.utils.data from tqdm import tqdm -from config import config +from config import get_config from mel_processing import mel_spectrogram_torch, spectrogram_torch from style_bert_vits2.logging import logger from style_bert_vits2.models import commons @@ -16,6 +16,7 @@ from style_bert_vits2.models.utils import load_filepaths_and_text, load_wav_to_t from style_bert_vits2.nlp import cleaned_text_to_sequence +config = get_config() """Multi speaker version""" @@ -70,16 +71,16 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): self.audiopaths_sid_text, file=sys.stdout ): audiopath = f"{_id}" - if self.min_text_len <= len(phones) and len(phones) <= self.max_text_len: - phones = phones.split(" ") - tone = [int(i) for i in tone.split(" ")] - word2ph = [int(i) for i in word2ph.split(" ")] - audiopaths_sid_text_new.append( - [audiopath, spk, language, text, phones, tone, word2ph] - ) - lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) - else: - skipped += 1 + # if self.min_text_len <= len(phones) and len(phones) <= self.max_text_len: + phones = phones.split(" ") + tone = [int(i) for i in tone.split(" ")] + word2ph = [int(i) for i in word2ph.split(" ")] + audiopaths_sid_text_new.append( + [audiopath, spk, language, text, phones, tone, word2ph] + ) + lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) + # else: + # skipped += 1 logger.info( "skipped: " + str(skipped) @@ -120,9 +121,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): audio, sampling_rate = load_wav_to_torch(filename) if sampling_rate != self.sampling_rate: raise ValueError( - "{} {} SR doesn't match target {} SR".format( - filename, sampling_rate, self.sampling_rate - ) + f"{filename} {sampling_rate} SR doesn't match target {self.sampling_rate} SR" ) audio_norm = audio / self.max_wav_value audio_norm = audio_norm.unsqueeze(0) diff --git a/default_style.py b/default_style.py index 49881eb1caa6a0cb0b898ac1c2277ab904b8ecf7..b17b6c03b97b34f7f173b5a1acfae847800ce17f 100644 --- a/default_style.py +++ b/default_style.py @@ -1,5 +1,4 @@ import json -import os from pathlib import Path from typing import Union @@ -9,26 +8,91 @@ from style_bert_vits2.constants import DEFAULT_STYLE from style_bert_vits2.logging import logger -def set_style_config(json_path: Path, output_path: Path): - with open(json_path, "r", encoding="utf-8") as f: +def save_neutral_vector( + wav_dir: Union[Path, str], + output_dir: Union[Path, str], + config_path: Union[Path, str], + config_output_path: Union[Path, str], +): + wav_dir = Path(wav_dir) + output_dir = Path(output_dir) + embs = [] + for file in wav_dir.rglob("*.npy"): + xvec = np.load(file) + embs.append(np.expand_dims(xvec, axis=0)) + + x = np.concatenate(embs, axis=0) # (N, 256) + mean = np.mean(x, axis=0) # (256,) + only_mean = np.stack([mean]) # (1, 256) + np.save(output_dir / "style_vectors.npy", only_mean) + logger.info(f"Saved mean style vector to {output_dir}") + + with open(config_path, encoding="utf-8") as f: json_dict = json.load(f) json_dict["data"]["num_styles"] = 1 json_dict["data"]["style2id"] = {DEFAULT_STYLE: 0} - with open(output_path, "w", encoding="utf-8") as f: + with open(config_output_path, "w", encoding="utf-8") as f: json.dump(json_dict, f, indent=2, ensure_ascii=False) - logger.info(f"Save style config (only {DEFAULT_STYLE}) to {output_path}") + logger.info(f"Saved style config to {config_output_path}") -def save_neutral_vector(wav_dir: Union[Path, str], output_path: Union[Path, str]): +def save_styles_by_dirs( + wav_dir: Union[Path, str], + output_dir: Union[Path, str], + config_path: Union[Path, str], + config_output_path: Union[Path, str], +): wav_dir = Path(wav_dir) - output_path = Path(output_path) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + config_path = Path(config_path) + config_output_path = Path(config_output_path) + + subdirs = [d for d in wav_dir.iterdir() if d.is_dir()] + subdirs.sort() + if len(subdirs) in (0, 1): + logger.info( + f"At least 2 subdirectories are required for generating style vectors with respect to them, found {len(subdirs)}." + ) + logger.info("Generating only neutral style vector instead.") + save_neutral_vector(wav_dir, output_dir, config_path, config_output_path) + return + + # First get mean of all for Neutral embs = [] for file in wav_dir.rglob("*.npy"): xvec = np.load(file) embs.append(np.expand_dims(xvec, axis=0)) - x = np.concatenate(embs, axis=0) # (N, 256) mean = np.mean(x, axis=0) # (256,) - only_mean = np.stack([mean]) # (1, 256) - np.save(output_path, only_mean) - logger.info(f"Saved mean style vector to {output_path}") + style_vectors = [mean] + + names = [DEFAULT_STYLE] + for style_dir in subdirs: + npy_files = list(style_dir.rglob("*.npy")) + if not npy_files: + continue + embs = [] + for file in npy_files: + xvec = np.load(file) + embs.append(np.expand_dims(xvec, axis=0)) + + x = np.concatenate(embs, axis=0) # (N, 256) + mean = np.mean(x, axis=0) # (256,) + style_vectors.append(mean) + names.append(style_dir.name) + + # Stack them to make (num_styles, 256) + style_vectors_npy = np.stack(style_vectors, axis=0) + np.save(output_dir / "style_vectors.npy", style_vectors_npy) + logger.info(f"Saved style vectors to {output_dir / 'style_vectors.npy'}") + + # Save style2id config to json + style2id = {name: i for i, name in enumerate(names)} + with open(config_path, encoding="utf-8") as f: + json_dict = json.load(f) + json_dict["data"]["num_styles"] = len(names) + json_dict["data"]["style2id"] = style2id + with open(config_output_path, "w", encoding="utf-8") as f: + json.dump(json_dict, f, indent=2, ensure_ascii=False) + logger.info(f"Saved style config to {config_output_path}") diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index fe1ed527a775c5f84a4cd2782f2bb53c2f313771..e19f74dd756df3b0446fbc2fd5aac9c3e25d1133 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,5 +1,76 @@ # Changelog +## v2.6.1 (2024-09-09) + +- Google colabで、torchのバージョン由来でエラーが発生する不具合の修正(たぶん) +- WebUIからのスタイル作成での、サブフォルダによるスタイル分けでエラーが発生していた点の修正 + +## v2.6.0 (2024-06-16) + +### 新機能 +モデルのマージ時に、今までの `new = (1 - weight) * A + weight * B` の他に、次を追加 + +- `new = A + weight * (B - C)`: 差分マージ +- `new = a * A + b * B + c * C`: 加重和マージ +- `new = A + weight * B`: ヌルモデルのマージ + +差分マージは、例えばBを「Cと同じ話者だけど囁いているモデル」とすると、`B - C`が囁きベクトル的なものだと思えるので、それをAに足すことで、Aの話者が囁いているような音声を生成できるようになります。 + +また、加重和で`new = A - B`を作って、それをヌルモデルマージで別のモデルに足せば、実質差分マージを実現できます。また謎に`new = -A`や`new = 41 * A`等のモデルも作ることができます。 + +これらのマージの活用法については各自いろいろ考えて実験してみて、面白い使い方があればぜひ共有してください。 + +囁きについて実験的に作ったヌルモデルを[こちら](https://huggingface.co/litagin/sbv2_null_models)に置いています。これをヌルモデルマージで使うことで、任意のモデルを囁きモデルにある程度は変換できます。 + +### 改善 + +- スタイルベクトルのマージ部分のUIの改善 +- WebUIの`App.bat`の起動が少し重いので、それぞれの機能を分割した`Dataset.bat`, `Inference.bat`, `Merge.bat`, `StyleVectors.bat`, `Train.bat`を追加 (今までの`App.bat`もこれまで通り使えます) + +## v2.5.1 (2024-06-14) + +ライセンスとのコンフリクトから、[利用規約](/docs/TERMS_OF_USE.md)を[開発陣からのお願いとデフォルトモデルの利用規約](/docs/TERMS_OF_USE.md)に変更しました。 + +## v2.5.0 (2024-06-02) + +このバージョンから[利用規約](/docs/TERMS_OF_USE.md)が追加されました。ご利用の際は必ずお読みください。 + +### 新機能等 + +- デフォルトモデルに [あみたろの声素材工房](https://amitaro.net/) のあみたろ様が公開しているコーパスとライブ配信音声を利用して学習した[**小春音アミ**](https://huggingface.co/litagin/sbv2_koharune_ami)と[**あみたろ**](https://huggingface.co/litagin/sbv2_amitaro)モデルを追加(あみたろ様には事前に連絡して許諾を得ています) + - アプデの場合は`Initialize.bat`をダブルクリックすればモデルをダウンロードできます(手動でダウンロードして`model_assets`フォルダに入れることも可能) +- 学習時に音声データをスタイルごとにフォルダ分けしておくことで、そのフォルダごとのスタイルを学習時に自動的に作成するように + - `inputs`からスライスして使う場合は`inputs`直下に作りたいスタイルだけサブフォルダを作りそこに音声ファイルを配置 + - `Data/モデル名/raw`から使う場合も`raw`直下に同様に配置 + - サブフォルダの個数が0または1の場合は、今まで通りのNeutralスタイルのみが作成されます +- batファイルでのインストールの大幅な高速化(Pythonのライブラリインストールに[uv](https://github.com/astral-sh/uv)を使用) +- 学習時に「カスタムバッチサンプラーを無効化」オプションを追加。これにより、長い音声ファイルも学習に使われるようになりますが、使用VRAMがかなり増えたり学習が不安定になる可能性があります。 +- [よくある質問](/docs/FAQ.md)を追加 +- 英語の音声合成の速度向上([gordon0414](https://github.com/gordon0414)さんによる[PR](https://github.com/litagin02/Style-Bert-VITS2/pull/124)です、ありがとうございます!) +- エディターの各種機能改善(多くが[kamexy](https://github.com/kamexy)様による[エディターリポジトリ](https://github.com/litagin02/Style-Bert-VITS2-Editor)へのプルリク群です、ありがとうございます!) + - 選択した行の下に新規の行を作成できるように + - Mac使用時に日本語変換のエンターで音声合成が走るバグの修正 + - ペースト時に改行を含まない場合は通常のペーストの振る舞いになるように修正 + + +### その他の改善 + +- 上のスタイル自動作成機能を既存モデルでも使えるような機能追加。具体的には、スタイル作成タブにて、フォルダ分けされた音声ファイルのディレクトリを任意に指定し、そのフォルダ分けを使って既存のモデルのスタイルの作成が可能に +- 音声書き起こしに[kotoba-whisper](https://huggingface.co/kotoba-tech/kotoba-whisper-v1.1)を追加 +- 音声書き起こし時にHugging FaceのWhisperモデルを使う際に、書き起こしを順次保存するように改善 +- 音声書き起こしのデフォルトをfaster-whiperからHugging FaceのWhisperモデルへ変更 +- (**ライブラリとしてのみ**)依存関係の軽量化、音声合成時に読み上げテキストの読みを表す音素列を指定する機能を追加 + 様々な改善 ([tsukumijimaさん](https://github.com/tsukumijima)による[プルリク](https://github.com/litagin02/Style-Bert-VITS2/pull/118)です、ありがとうございます!) + +### 内部変更 + +- これまでpath管理に`configs/paths.yml`を使っていたが、`configs/default_paths.yml`にリネームし、`configs/paths.yml`はgitの管理対象外に変更 + +### バグ修正 + +- Gradioのアップデートにより、モデル選択時やスタイルのDBSCAN作成時等に`TypeError: Type is not JSON serializable: WindowsPath`のようなエラーが出る問題を修正 +- TensorboardをWebUIから立ち上げた際にエラーが出る問題の修正 ([#129](https://github.com/litagin02/Style-Bert-VITS2/issues/129)) + + ## v2.4.1 (2024-03-16) **batファイルでのインストール・アップデート方法の変更**(それ以外の変更はありません) diff --git a/docs/CLI.md b/docs/CLI.md index 97d296aa0a0e8d2773f5dea266a378068cab37bb..726c42d168ae962d060ac72f73fda2e846bd350b 100644 --- a/docs/CLI.md +++ b/docs/CLI.md @@ -7,17 +7,17 @@ git clone https://github.com/litagin02/Style-Bert-VITS2.git cd Style-Bert-VITS2 python -m venv venv venv\Scripts\activate -pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 +pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118 pip install -r requirements.txt ``` Then download the necessary models and the default TTS model, and set the global paths. ```bash -python initialize.py [--skip_jvnv] [--dataset_root ] [--assets_root ] +python initialize.py [--skip_default_models] [--dataset_root ] [--assets_root ] ``` Optional: -- `--skip_jvnv`: Skip downloading the default JVNV voice models (use this if you only have to train your own models). +- `--skip_default_models`: Skip downloading the default voice models (use this if you only have to train your own models). - `--dataset_root`: Default: `Data`. Root directory of the training dataset. The training dataset of `{model_name}` should be placed in `{dataset_root}/{model_name}`. - `--assets_root`: Default: `model_assets`. Root directory of the model assets (for inference). In training, the model assets will be saved to `{assets_root}/{model_name}`, and in inference, we load all the models from `{assets_root}`. @@ -26,7 +26,7 @@ Optional: ### 1.1. Slice audio files -The following audio formats are supported: ".wav", ".flac", ".mp3", ".ogg", ".opus". +The following audio formats are supported: ".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a". ```bash python slice.py --model_name [-i ] [-m ] [-M ] [--time_suffix] ``` @@ -101,4 +101,4 @@ python train_ms_jp_extra.py [--repo_id /] [--skip_default_s Optional: - `--repo_id`: Hugging Face repository ID to upload the trained model to. You should have logged in using `huggingface-cli login` before running this command. -- `--skip_default_style`: Skip making the default style vector. Use this if you want to resume training (since the default style vector is already made). +- `--skip_default_style`: Skip making the default style vector. Use this if you want to resume training (since the default style vector has been already made). diff --git a/docs/FAQ.md b/docs/FAQ.md new file mode 100644 index 0000000000000000000000000000000000000000..35e926369b381d0720a1b14ce576b918b4c54630 --- /dev/null +++ b/docs/FAQ.md @@ -0,0 +1,57 @@ +# よくある質問 + +## 書き起こしにエラーが出たりして失敗する + +ffmpegが入っていないことが問題のようです。 +ググるか、おそらくWindowsなら +``` +winget install ffmpeg +``` +によりffmpegをインストールできます。その後で試してみてください。 + +## Google Colabでの学習がエラーが何か出て動かない + +Google Colabのノートブックは以前のバージョンのノートブックのコピーを使っていませんか? +Colabノートブックは最新のバージョンに合ったノートブックで動かすことを前提としています。ノートブック記載のバージョンを確認して、[最新のcolabノートブック](http://colab.research.google.com/github/litagin02/Style-Bert-VITS2/blob/master/colab.ipynb)(を必要ならコピーして)から使うようにしてください。 + +## `ModuleNotFoundError: No module named '_socket'`と出る + +フォルダ名をインストールした時から変えていませんか?フォルダ名を変えるとパスが変わってしまい、インストール時に指定したパスと異なるためにエラーが出ます。フォルダ名を元に戻してください。 + +## 学習に時間がかかりすぎる + +デフォルトの100エポックは音声データ量によっては過剰な場合があります。デフォルトでは1000ステップごとにモデルが保存されるはずなので、途中で学習を中断してみて途中のもので試してみてもいいでしょう。 + +またバッチサイズが大き過ぎてメモリがVRAMから溢れると非常に遅くなることがあります。VRAM使用量がギリギリだったり物理メモリに溢れている場合はバッチサイズを小さくしてみてください。 + +## どのくらいの音声データが必要なの? + +分かりません。試行錯誤してください。 + +参考として、数分程度でも学習はできるらしく、またRVCでよく言われているのは多くても45分くらいで十分説があります。ただ多ければ多いほど精度が上がる可能性もありますが、分かりません。 + + +## どのくらいのステップ・エポックがいいの? + +分かりません。試行錯誤してください。`python speech_mos.py -m <モデル名>`によって自然性の一つの評価ができるので、それが少し参考になります(ただあくまで一つの指標です)。 + +参考として、最初の2k-3kで声音はかなり似始めて、5k-10k-15kステップほどで感情含めてよい感じになりやすく、そこからどんどん回して20kなり30kなり50kなり100kなりでどんどん微妙に変わっていきます。が、微妙に変わるので、どこがいいとかは分かりません。 + +## APIサーバーで長い文章が合成できない + +デフォルトで`server_fastapi.py`の入力文字上限は100文字に設定されています。 +`config.yml`の`server.limit`の100を好きな数字に変更してください。 +上限をなくしたい方は`server.limit`を-1に設定してください。 + +## 学習を中断・再開するには + +- 学習を中断するには、学習の進捗が表示されている画面(bat使用ならコマンドプロンプト)を好きなタイミングで閉じてください。 +- 学習を再開するには、WebUIでモデル名を再開したいモデルと同じ名前に設定して、前処理等はせずに一番下の「学習を開始する」ボタンを押してください(「スタイルファイルの生成をスキップする」にチェックを入れるのをおすすめします)。 + +## 途中でバッチサイズやエポック数を変更したい + +`Data/{モデル名}/config.json`を手動で変更してから、学習を再開してください。 + +## その他 + +ググったり調べたりChatGPTに聞くか、それでも分からない場合・または手順通りやってもエラーが出る等明らかに不具合やバグと思われる場合は、GitHubの[Issue](https://github.com/litagin02/Style-Bert-VITS2/issues)に投稿してください。 diff --git a/docs/TERMS_OF_USE.md b/docs/TERMS_OF_USE.md new file mode 100644 index 0000000000000000000000000000000000000000..585d6bd0f7632d49ef8545681e8cb83e2e3709bf --- /dev/null +++ b/docs/TERMS_OF_USE.md @@ -0,0 +1,54 @@ +# 開発陣からのお願いとデフォルトモデルの利用規約 + +- 2024-06-14: ライセンスとの整合性から「利用規約」を「お願い」に変更 +- 2024-06-01: 初版 + +Style-Bert-VITS2を用いる際は、以下のお願いを守っていただけると幸いです。ただしモデルの利用規約以前の箇所はあくまで「お願い」であり、何の強制力はなく、Style-Bert-VITS2の利用規約ではありません。よって[リポジトリのライセンス](https://github.com/litagin02/Style-Bert-VITS2#license)とは矛盾せず、リポジトリの利用にあたっては常にリポジトリのライセンスのみが拘束力を持ちます。 + +## やってほしくないこと + +以下の目的での利用はStyle-Bert-VITS2を使ってほしくありません。 + +- 法律に違反する目的 +- 政治的な目的(本家Bert-VITS2で禁止されています) +- 他者を傷つける目的 +- なりすまし・ディープフェイク作成目的 + +## 守ってほしいこと + +- Style-Bert-VITS2を利用する際は、使用するモデルの利用規約・ライセンス必ず確認し、存在する場合はそれに従ってほしいです。 +- またソースコードを利用する際は、[リポジトリのライセンス](https://github.com/litagin02/Style-Bert-VITS2#license)に従ってほしいです。 + +# モデルの利用規約・ライセンス + +以下はデフォルトで付随しているモデルの利用規約・ライセンスです。このリポジトリ自体にはモデルは付随していないので、[リポジトリのライセンス](https://github.com/litagin02/Style-Bert-VITS2#license)とは関係がありません(なのでリポジトリライセンスとの矛盾は発生しません)。 + +## JVNVコーパス (jvnv-F1-jp, jvnv-F2-jp, jvnv-M1-jp, jvnv-M2-jp) + +- [JVNVコーパス](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus) のライセンスは[CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.ja)ですので、これを継承します。 + +## 小春音アミ (koharune-ami) / あみたろ (amitaro) + +[あみたろの声素材工房様の規約](https://amitaro.net/voice/voice_rule/) と [あみたろのライブ配信音声・利用規約](https://amitaro.net/voice/livevoice/#index_id6) を全て守らなければなりません。特に、以下の事項を遵守してください(規約を守れば商用非商用問わず利用できます)。 + +### 禁止事項 + +- 年齢制限のある作品・用途への使用 +- 新興宗教・政治・マルチ購などに深く関係する作品・用途 +- 特定の団体や個人や国家を誹謗中傷する作品・用途 +- 生成された音声を、あみたろ本人の声として扱うこと +- 生成された音声を、あみたろ以外の人の声として扱うこと + +### クレジット表記 + +生成音声を公開する際は(媒体は問わない)、必ず分かりやすい場所に `あみたろの声素材工房 (https://amitaro.net/)` の声を元にした音声モデルを使用していることが分かるようなクレジット表記を記載してください。 + +クレジット表記例: +- `Style-BertVITS2モデル: 小春音アミ、あみたろの声素材工房 (https://amitaro.net/)` +- `Style-BertVITS2モデル: あみたろ、あみたろの声素材工房 (https://amitaro.net/)` + +### モデルマージ + +モデルマージに関しては、[あみたろの声素材工房のよくある質問への回答](https://amitaro.net/voice/faq/#index_id17)を遵守してください: +- 本モデルを別モデルとマージできるのは、その別モデル作成の際に学習に使われた声の権利者が許諾している場合に限る +- あみたろの声の特徴が残っている場合(マージの割合が25%以上の場合)は、その利用は[あみたろの声素材工房様の規約](https://amitaro.net/voice/voice_rule/)の範囲内に限定され、そのモデルに関してもこの規約が適応される \ No newline at end of file diff --git a/gen_yaml.py b/gen_yaml.py index ac27103ea4cb53876827e1832b5ab05df2ebd862..18e3115fde9e209fcb1ea41b69f2ae562abcd106 100644 --- a/gen_yaml.py +++ b/gen_yaml.py @@ -22,7 +22,7 @@ args = parser.parse_args() def gen_yaml(model_name, dataset_path): if not os.path.exists("config.yml"): shutil.copy(src="default_config.yml", dst="config.yml") - with open("config.yml", "r", encoding="utf-8") as f: + with open("config.yml", encoding="utf-8") as f: yml_data = yaml.safe_load(f) yml_data["model_name"] = model_name yml_data["dataset_path"] = dataset_path diff --git a/gradio_tabs/dataset.py b/gradio_tabs/dataset.py index 21b1063b56b5a8c6870109e5a01d78d5eb5c0add..540459527c4cba7b0c6bee49ba39bf10fd02463a 100644 --- a/gradio_tabs/dataset.py +++ b/gradio_tabs/dataset.py @@ -1,5 +1,6 @@ import gradio as gr +from style_bert_vits2.constants import GRADIO_THEME from style_bert_vits2.logging import logger from style_bert_vits2.utils.subprocess import run_script_with_log @@ -43,10 +44,10 @@ def do_transcribe( compute_type, language, initial_prompt, - device, use_hf_whisper, batch_size, num_beams, + hf_repo_id, ): if model_name == "": return "Error: モデル名を入力してください。" @@ -59,8 +60,6 @@ def do_transcribe( whisper_model, "--compute_type", compute_type, - "--device", - device, "--language", language, "--initial_prompt", @@ -71,9 +70,12 @@ def do_transcribe( if use_hf_whisper: cmd.append("--use_hf_whisper") cmd.extend(["--batch_size", str(batch_size)]) - success, message = run_script_with_log(cmd) + if hf_repo_id != "openai/whisper": + cmd.extend(["--hf_repo_id", hf_repo_id]) + success, message = run_script_with_log(cmd, ignore_warning=True) if not success: return f"Error: {message}. エラーメッセージが空の場合、何も問題がない可能性があるので、書き起こしファイルをチェックして問題なければ無視してください。" + return "音声の文字起こしが完了しました。" how_to_md = """ @@ -82,46 +84,51 @@ Style-Bert-VITS2の学習用データセットを作成するためのツール - 与えられた音声からちょうどいい長さの発話区間を切り取りスライス - 音声に対して文字起こし -このうち両方を使ってもよいし、スライスする必要がない場合は後者のみを使ってもよいです。 +このうち両方を使ってもよいし、スライスする必要がない場合は後者のみを使ってもよいです。**コーパス音源などすでに適度な長さの音声ファイルがある場合はスライスは不要**です。 ## 必要なもの -学習したい音声が入ったwavファイルいくつか。 +学習したい音声が入った音声ファイルいくつか(形式はwav以外でもmp3等通常の音声ファイル形式なら可能)。 合計時間がある程度はあったほうがいいかも、10分とかでも大丈夫だったとの報告あり。単一ファイルでも良いし複数ファイルでもよい。 ## スライス使い方 -1. `inputs`フォルダにwavファイルをすべて入れる +1. `inputs`フォルダに音声ファイルをすべて入れる(スタイル分けをしたい場合は、サブフォルダにスタイルごとに音声を分けて入れる) 2. `モデル名`を入力して、設定を必要なら調整して`音声のスライス`ボタンを押す 3. 出来上がった音声ファイルたちは`Data/{モデル名}/raw`に保存される ## 書き起こし使い方 -1. 書き起こしたい音声ファイルのあるフォルダを指定(デフォルトは`Data/{モデル名}/raw`なのでスライス後に行う場合は省略してよい) +1. `Data/{モデル名}/raw`に音声ファイルが入っていることを確認(直下でなくてもよい) 2. 設定を必要なら調整してボタンを押す 3. 書き起こしファイルは`Data/{モデル名}/esd.list`に保存される ## 注意 -- 長すぎる秒数(12-15秒くらいより長い?)のwavファイルは学習に用いられないようです。また短すぎてもあまりよくない可能性もあります。 +- ~~長すぎる秒数(12-15秒くらいより長い?)のwavファイルは学習に用いられないようです。また短すぎてもあまりよくない可能性もあります。~~ この制限はVer 2.5では学習時に「カスタムバッチサンプラーを使わない」を選択すればなくなりました。が、長すぎる音声があるとVRAM消費量が増えたり安定しなかったりするので、適度な長さにスライスすることをおすすめします。 - 書き起こしの結果をどれだけ修正すればいいかはデータセットに依存しそうです。 -- 手動で書き起こしをいろいろ修正したり結果を細かく確認したい場合は、[Aivis Dataset](https://github.com/litagin02/Aivis-Dataset)もおすすめします。書き起こし部分もかなり工夫されています。ですがファイル数が多い場合などは、このツールで簡易的に切り出してデータセットを作るだけでも十分という気もしています。 """ def create_dataset_app() -> gr.Blocks: - with gr.Blocks() as app: + with gr.Blocks(theme=GRADIO_THEME) as app: + gr.Markdown( + "**既に1ファイル2-12秒程度の音声ファイル集とその書き起こしデータがある場合は、このタブは使用せずに学習できます。**" + ) with gr.Accordion("使い方", open=False): gr.Markdown(how_to_md) model_name = gr.Textbox( label="モデル名を入力してください(話者名としても使われます)。" ) with gr.Accordion("音声のスライス"): + gr.Markdown( + "**すでに適度な長さの音声ファイルからなるデータがある場合は、その音声をData/{モデル名}/rawに入れれば、このステップは不要です。**" + ) with gr.Row(): with gr.Column(): input_dir = gr.Textbox( label="元音声の入っているフォルダパス", value="inputs", - info="下記フォルダにwavファイルを入れておいてください", + info="下記フォルダにwavやmp3等のファイルを入れておいてください", ) min_sec = gr.Slider( minimum=0, @@ -167,6 +174,12 @@ def create_dataset_app() -> gr.Blocks: ) use_hf_whisper = gr.Checkbox( label="HuggingFaceのWhisperを使う(速度が速いがVRAMを多く使う)", + value=True, + ) + hf_repo_id = gr.Dropdown( + ["openai/whisper", "kotoba-tech/kotoba-whisper-v1.1"], + label="HuggingFaceのWhisperモデル", + value="openai/whisper", ) compute_type = gr.Dropdown( [ @@ -181,6 +194,7 @@ def create_dataset_app() -> gr.Blocks: ], label="計算精度", value="bfloat16", + visible=False, ) batch_size = gr.Slider( minimum=1, @@ -189,9 +203,7 @@ def create_dataset_app() -> gr.Blocks: step=1, label="バッチサイズ", info="大きくすると速度が速くなるがVRAMを多く使う", - visible=False, ) - device = gr.Radio(["cuda", "cpu"], label="デバイス", value="cuda") language = gr.Dropdown(["ja", "en", "zh"], value="ja", label="言語") initial_prompt = gr.Textbox( label="初期プロンプト", @@ -228,21 +240,26 @@ def create_dataset_app() -> gr.Blocks: compute_type, language, initial_prompt, - device, use_hf_whisper, batch_size, num_beams, + hf_repo_id, ], outputs=[result2], ) use_hf_whisper.change( lambda x: ( gr.update(visible=x), - gr.update(visible=not x), + gr.update(visible=x), gr.update(visible=not x), ), inputs=[use_hf_whisper], - outputs=[batch_size, compute_type, device], + outputs=[hf_repo_id, batch_size, compute_type], ) return app + + +if __name__ == "__main__": + app = create_dataset_app() + app.launch(inbrowser=True) diff --git a/gradio_tabs/inference.py b/gradio_tabs/inference.py index ef71928a892fc9726dce713ea30381efc7a8dbf7..53393be2119892ec5654caec3c93950c84a93a5f 100644 --- a/gradio_tabs/inference.py +++ b/gradio_tabs/inference.py @@ -96,9 +96,65 @@ examples = [ ] initial_md = """ -- Ver 2.3で追加されたエディターのほうが実際に読み上げさせるには使いやすいかもしれません。`Editor.bat`か`python server_editor.py --inbrowser`で起動できます。 +- Ver 2.5で追加されたデフォルトの [`koharune-ami`(小春音アミ)モデル](https://huggingface.co/litagin/sbv2_koharune_ami) と[`amitaro`(あみたろ)モデル](https://huggingface.co/litagin/sbv2_amitaro) は、[あみたろの声素材工房](https://amitaro.net/)で公開されているコーパス音源・ライブ配信音声を利用して事前に許可を得て学習したモデルです。下記の**利用規約を必ず読んで**からご利用ください。 -- 初期からある[jvnvのモデル](https://huggingface.co/litagin/style_bert_vits2_jvnv)は、[JVNVコーパス(言語音声と非言語音声を持つ日本語感情音声コーパス)](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus)で学習されたモデルです。ライセンスは[CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.ja)です。 +- Ver 2.5のアップデート後に上記モデルをダウンロードするには、`Initialize.bat`をダブルクリックするか、手動でダウンロードして`model_assets`ディレクトリに配置してください。 + +- Ver 2.3で追加された**エディター版**のほうが実際に読み上げさせるには使いやすいかもしれません。`Editor.bat`か`python server_editor.py --inbrowser`で起動できます。 +""" + +terms_of_use_md = """ +## お願いとデフォルトモデルのライセンス + +最新のお願い・利用規約は [こちら](https://github.com/litagin02/Style-Bert-VITS2/blob/master/docs/TERMS_OF_USE.md) を参照してください。常に最新のものが適用されます。 + +Style-Bert-VITS2を用いる際は、以下のお願いを守っていただけると幸いです。ただしモデルの利用規約以前の箇所はあくまで「お願い」であり、何の強制力はなく、Style-Bert-VITS2の利用規約ではありません。よって[リポジトリのライセンス](https://github.com/litagin02/Style-Bert-VITS2#license)とは矛盾せず、リポジトリの利用にあたっては常にリポジトリのライセンスのみが拘束力を持ちます。 + +### やってほしくないこと + +以下の目的での利用はStyle-Bert-VITS2を使ってほしくありません。 + +- 法律に違反する目的 +- 政治的な目的(本家Bert-VITS2で禁止されています) +- 他者を傷つける目的 +- なりすまし・ディープフェイク作成目的 + +### 守ってほしいこと + +- Style-Bert-VITS2を利用する際は、使用するモデルの利用規約・ライセンス必ず確認し、存在する場合はそれに従ってほしいです。 +- またソースコードを利用する際は、[リポジトリのライセンス](https://github.com/litagin02/Style-Bert-VITS2#license)に従ってほしいです。 + +以下はデフォルトで付随しているモデルのライセンスです。 + +### JVNVコーパス (jvnv-F1-jp, jvnv-F2-jp, jvnv-M1-jp, jvnv-M2-jp) + +- [JVNVコーパス](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus) のライセンスは[CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.ja)ですので、これを継承します。 + +### 小春音アミ (koharune-ami) / あみたろ (amitaro) + +[あみたろの声素材工房様の規約](https://amitaro.net/voice/voice_rule/) と [あみたろのライブ配信音声・利用規約](https://amitaro.net/voice/livevoice/#index_id6) を全て守らなければなりません。特に、以下の事項を遵守してください(規約を守れば商用非商用問わず利用できます): + +#### 禁止事項 + +- 年齢制限のある作品・用途への使用 +- 新興宗教・政治・マルチ購などに深く関係する作品・用途 +- 特定の団体や個人や国家を誹謗中傷する作品・用途 +- 生成された音声を、あみたろ本人の声として扱うこと +- 生成された音声を、あみたろ以外の人の声として扱うこと + +#### クレジット表記 + +生成音声を公開する際は(媒体は問わない)、必ず分かりやすい場所に `あみたろの声素材工房 (https://amitaro.net/)` の声を元にした音声モデルを使用していることが分かるようなクレジット表記を記載してください。 + +クレジット表記例: +- `Style-BertVITS2モデル: 小春音アミ、あみたろの声素材工房 (https://amitaro.net/)` +- `Style-BertVITS2モデル: あみたろ、あみたろの声素材工房 (https://amitaro.net/)` + +#### モデルマージ + +モデルマージに関しては、[あみたろの声素材工房のよくある質問への回答](https://amitaro.net/voice/faq/#index_id17)を遵守してください: +- 本モデルを別モデルとマージできるのは、その別モデル作成の際に学習に使われた声の権利者が許諾している場合に限る +- あみたろの声の特徴が残っている場合(マージの割合が25%以上の場合)は、その利用は[あみたろの声素材工房様の規約](https://amitaro.net/voice/voice_rule/)の範囲内に限定され、そのモデルに関してもこの規約が適応される """ how_to_md = """ @@ -260,10 +316,13 @@ def create_inference_app(model_holder: TTSModelHolder) -> gr.Blocks: ) return app initial_id = 0 - initial_pth_files = model_holder.model_files_dict[model_names[initial_id]] + initial_pth_files = [ + str(f) for f in model_holder.model_files_dict[model_names[initial_id]] + ] with gr.Blocks(theme=GRADIO_THEME) as app: gr.Markdown(initial_md) + gr.Markdown(terms_of_use_md) with gr.Accordion(label="使い方", open=False): gr.Markdown(how_to_md) with gr.Row(): @@ -392,10 +451,10 @@ def create_inference_app(model_holder: TTSModelHolder) -> gr.Blocks: ) style_weight = gr.Slider( minimum=0, - maximum=50, + maximum=20, value=DEFAULT_STYLE_WEIGHT, step=0.1, - label="スタイルの強さ", + label="スタイルの強さ(声が崩壊したら小さくしてください)", ) ref_audio_path = gr.Audio( label="参照音声", type="filepath", visible=False @@ -464,3 +523,15 @@ def create_inference_app(model_holder: TTSModelHolder) -> gr.Blocks: ) return app + + +if __name__ == "__main__": + from config import get_path_config + import torch + + path_config = get_path_config() + assets_root = path_config.assets_root + device = "cuda" if torch.cuda.is_available() else "cpu" + model_holder = TTSModelHolder(assets_root, device) + app = create_inference_app(model_holder) + app.launch(inbrowser=True) diff --git a/gradio_tabs/merge.py b/gradio_tabs/merge.py index 661f233690b1eecb31637ed4ca96aef6593ff8d6..b4c9204e3e59d3050edd9b95246949ab89792190 100644 --- a/gradio_tabs/merge.py +++ b/gradio_tabs/merge.py @@ -1,14 +1,14 @@ import json -import os from pathlib import Path +from typing import Any, Union import gradio as gr import numpy as np import torch -import yaml from safetensors import safe_open from safetensors.torch import save_file +from config import get_path_config from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME from style_bert_vits2.logging import logger from style_bert_vits2.tts_model import TTSModel, TTSModelHolder @@ -20,45 +20,72 @@ speech_style_keys = ["enc_p"] tempo_keys = ["sdp", "dp"] device = "cuda" if torch.cuda.is_available() else "cpu" +path_config = get_path_config() +assets_root = path_config.assets_root -# Get path settings -with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - # dataset_root = path_config["dataset_root"] - assets_root = path_config["assets_root"] +def load_safetensors(model_path: Union[str, Path]) -> dict[str, torch.Tensor]: + result: dict[str, torch.Tensor] = {} + with safe_open(model_path, framework="pt", device="cpu") as f: + for k in f.keys(): + result[k] = f.get_tensor(k) + return result + + +def load_config(model_name: str) -> dict[str, Any]: + with open(assets_root / model_name / "config.json", encoding="utf-8") as f: + config = json.load(f) + return config + + +def save_config(config: dict[str, Any], model_name: str): + with open(assets_root / model_name / "config.json", "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, ensure_ascii=False) + + +def load_recipe(model_name: str) -> dict[str, Any]: + receipe_path = assets_root / model_name / "recipe.json" + if receipe_path.exists(): + with open(receipe_path, encoding="utf-8") as f: + recipe = json.load(f) + else: + recipe = {} + return recipe + + +def save_recipe(recipe: dict[str, Any], model_name: str): + with open(assets_root / model_name / "recipe.json", "w", encoding="utf-8") as f: + json.dump(recipe, f, indent=2, ensure_ascii=False) + + +def load_style_vectors(model_name: str) -> np.ndarray: + return np.load(assets_root / model_name / "style_vectors.npy") -def merge_style(model_name_a, model_name_b, weight, output_name, style_triple_list): + +def save_style_vectors(style_vectors: np.ndarray, model_name: str): + np.save(assets_root / model_name / "style_vectors.npy", style_vectors) + + +def merge_style_usual( + model_name_a: str, + model_name_b: str, + weight: float, + output_name: str, + style_tuple_list: list[tuple[str, ...]], +) -> list[str]: """ + new = (1 - weight) * A + weight * B style_triple_list: list[(model_aでのスタイル名, model_bでのスタイル名, 出力するスタイル名)] """ - # 新スタイル名リストにNeutralが含まれているか確認し、Neutralを先頭に持ってくる - if any(triple[2] == DEFAULT_STYLE for triple in style_triple_list): - # 存在する場合、リストをソート - sorted_list = sorted(style_triple_list, key=lambda x: x[2] != DEFAULT_STYLE) - else: - # 存在しない場合、エラーを発生 - raise ValueError(f"No element with {DEFAULT_STYLE} output style name found.") - - style_vectors_a = np.load( - os.path.join(assets_root, model_name_a, "style_vectors.npy") - ) # (style_num_a, 256) - style_vectors_b = np.load( - os.path.join(assets_root, model_name_b, "style_vectors.npy") - ) # (style_num_b, 256) - with open( - os.path.join(assets_root, model_name_a, "config.json"), "r", encoding="utf-8" - ) as f: - config_a = json.load(f) - with open( - os.path.join(assets_root, model_name_b, "config.json"), "r", encoding="utf-8" - ) as f: - config_b = json.load(f) + style_vectors_a = load_style_vectors(model_name_a) + style_vectors_b = load_style_vectors(model_name_b) + config_a = load_config(model_name_a) + config_b = load_config(model_name_b) style2id_a = config_a["data"]["style2id"] style2id_b = config_b["data"]["style2id"] new_style_vecs = [] new_style2id = {} - for style_a, style_b, style_out in sorted_list: + for style_a, style_b, style_out in style_tuple_list: if style_a not in style2id_a: logger.error(f"{style_a} is not in {model_name_a}.") raise ValueError(f"{style_a} は {model_name_a} にありません。") @@ -72,38 +99,191 @@ def merge_style(model_name_a, model_name_b, weight, output_name, style_triple_li new_style_vecs.append(new_style) new_style2id[style_out] = len(new_style_vecs) - 1 new_style_vecs = np.array(new_style_vecs) + save_style_vectors(new_style_vecs, output_name) + + new_config = config_a.copy() + new_config["data"]["num_styles"] = len(new_style2id) + new_config["data"]["style2id"] = new_style2id + new_config["model_name"] = output_name + save_config(new_config, output_name) + + receipe = load_recipe(output_name) + receipe["style_tuple_list"] = style_tuple_list + save_recipe(receipe, output_name) + + return list(new_style2id.keys()) + + +def merge_style_add_diff( + model_name_a: str, + model_name_b: str, + model_name_c: str, + weight: float, + output_name: str, + style_tuple_list: list[tuple[str, ...]], +) -> list[str]: + """ + new = A + weight * (B - C) + style_tuple_list: list[(model_aでのスタイル名, model_bでのスタイル名, model_cでのスタイル名, 出力するスタイル名)] + """ + style_vectors_a = load_style_vectors(model_name_a) + style_vectors_b = load_style_vectors(model_name_b) + style_vectors_c = load_style_vectors(model_name_c) + config_a = load_config(model_name_a) + config_b = load_config(model_name_b) + config_c = load_config(model_name_c) + style2id_a = config_a["data"]["style2id"] + style2id_b = config_b["data"]["style2id"] + style2id_c = config_c["data"]["style2id"] + new_style_vecs = [] + new_style2id = {} + for style_a, style_b, style_c, style_out in style_tuple_list: + if style_a not in style2id_a: + logger.error(f"{style_a} is not in {model_name_a}.") + raise ValueError(f"{style_a} は {model_name_a} にありません。") + if style_b not in style2id_b: + logger.error(f"{style_b} is not in {model_name_b}.") + raise ValueError(f"{style_b} は {model_name_b} にありません。") + if style_c not in style2id_c: + logger.error(f"{style_c} is not in {model_name_c}.") + raise ValueError(f"{style_c} は {model_name_c} にありません。") + new_style = style_vectors_a[style2id_a[style_a]] + weight * ( + style_vectors_b[style2id_b[style_b]] - style_vectors_c[style2id_c[style_c]] + ) + new_style_vecs.append(new_style) + new_style2id[style_out] = len(new_style_vecs) - 1 + new_style_vecs = np.array(new_style_vecs) - output_style_path = os.path.join(assets_root, output_name, "style_vectors.npy") - np.save(output_style_path, new_style_vecs) + save_style_vectors(new_style_vecs, output_name) new_config = config_a.copy() new_config["data"]["num_styles"] = len(new_style2id) new_config["data"]["style2id"] = new_style2id new_config["model_name"] = output_name - with open( - os.path.join(assets_root, output_name, "config.json"), "w", encoding="utf-8" - ) as f: - json.dump(new_config, f, indent=2, ensure_ascii=False) + save_config(new_config, output_name) + + receipe = load_recipe(output_name) + receipe["style_tuple_list"] = style_tuple_list + save_recipe(receipe, output_name) + + return list(new_style2id.keys()) - # recipe.jsonを読み込んで、style_triple_listを追記 - info_path = os.path.join(assets_root, output_name, "recipe.json") - if os.path.exists(info_path): - with open(info_path, "r", encoding="utf-8") as f: - info = json.load(f) - else: - info = {} - info["style_triple_list"] = style_triple_list - with open(info_path, "w", encoding="utf-8") as f: - json.dump(info, f, indent=2, ensure_ascii=False) - return output_style_path, list(new_style2id.keys()) +def merge_style_weighted_sum( + model_name_a: str, + model_name_b: str, + model_name_c: str, + model_a_coeff: float, + model_b_coeff: float, + model_c_coeff: float, + output_name: str, + style_tuple_list: list[tuple[str, ...]], +) -> list[str]: + """ + new = A * model_a_coeff + B * model_b_coeff + C * model_c_coeff + style_tuple_list: list[(model_aでのスタイル名, model_bでのスタイル名, model_cでのスタイル名, 出力するスタイル名)] + """ + style_vectors_a = load_style_vectors(model_name_a) + style_vectors_b = load_style_vectors(model_name_b) + style_vectors_c = load_style_vectors(model_name_c) + config_a = load_config(model_name_a) + config_b = load_config(model_name_b) + config_c = load_config(model_name_c) + style2id_a = config_a["data"]["style2id"] + style2id_b = config_b["data"]["style2id"] + style2id_c = config_c["data"]["style2id"] + new_style_vecs = [] + new_style2id = {} + for style_a, style_b, style_c, style_out in style_tuple_list: + if style_a not in style2id_a: + logger.error(f"{style_a} is not in {model_name_a}.") + raise ValueError(f"{style_a} は {model_name_a} にありません。") + if style_b not in style2id_b: + logger.error(f"{style_b} is not in {model_name_b}.") + raise ValueError(f"{style_b} は {model_name_b} にありません。") + if style_c not in style2id_c: + logger.error(f"{style_c} is not in {model_name_c}.") + raise ValueError(f"{style_c} は {model_name_c} にありません。") + new_style = ( + style_vectors_a[style2id_a[style_a]] * model_a_coeff + + style_vectors_b[style2id_b[style_b]] * model_b_coeff + + style_vectors_c[style2id_c[style_c]] * model_c_coeff + ) + new_style_vecs.append(new_style) + new_style2id[style_out] = len(new_style_vecs) - 1 + new_style_vecs = np.array(new_style_vecs) + save_style_vectors(new_style_vecs, output_name) -def lerp_tensors(t, v0, v1): + new_config = config_a.copy() + new_config["data"]["num_styles"] = len(new_style2id) + new_config["data"]["style2id"] = new_style2id + new_config["model_name"] = output_name + save_config(new_config, output_name) + + receipe = load_recipe(output_name) + receipe["style_tuple_list"] = style_tuple_list + save_recipe(receipe, output_name) + + return list(new_style2id.keys()) + + +def merge_style_add_null( + model_name_a: str, + model_name_b: str, + weight: float, + output_name: str, + style_tuple_list: list[tuple[str, ...]], +) -> list[str]: + """ + new = A + weight * B + style_tuple_list: list[(model_aでのスタイル名, model_bでのスタイル名, 出力するスタイル名)] + """ + style_vectors_a = load_style_vectors(model_name_a) + style_vectors_b = load_style_vectors(model_name_b) + config_a = load_config(model_name_a) + config_b = load_config(model_name_b) + style2id_a = config_a["data"]["style2id"] + style2id_b = config_b["data"]["style2id"] + new_style_vecs = [] + new_style2id = {} + for style_a, style_b, style_out in style_tuple_list: + if style_a not in style2id_a: + logger.error(f"{style_a} is not in {model_name_a}.") + raise ValueError(f"{style_a} は {model_name_a} にありません。") + if style_b not in style2id_b: + logger.error(f"{style_b} is not in {model_name_b}.") + raise ValueError(f"{style_b} は {model_name_b} にありません。") + new_style = ( + style_vectors_a[style2id_a[style_a]] + + weight * style_vectors_b[style2id_b[style_b]] + ) + new_style_vecs.append(new_style) + new_style2id[style_out] = len(new_style_vecs) - 1 + new_style_vecs = np.array(new_style_vecs) + + save_style_vectors(new_style_vecs, output_name) + + new_config = config_a.copy() + new_config["data"]["num_styles"] = len(new_style2id) + new_config["data"]["style2id"] = new_style2id + new_config["model_name"] = output_name + save_config(new_config, output_name) + + receipe = load_recipe(output_name) + receipe["style_tuple_list"] = style_tuple_list + save_recipe(receipe, output_name) + + return list(new_style2id.keys()) + + +def lerp_tensors(t: float, v0: torch.Tensor, v1: torch.Tensor): return v0 * (1 - t) + v1 * t -def slerp_tensors(t, v0, v1, dot_thres=0.998): +def slerp_tensors( + t: float, v0: torch.Tensor, v1: torch.Tensor, dot_thres: float = 0.998 +): device = v0.device v0c = v0.cpu().numpy() v1c = v1.cpu().numpy() @@ -122,31 +302,25 @@ def slerp_tensors(t, v0, v1, dot_thres=0.998): ).to(device) -def merge_models( - model_path_a, - model_path_b, - voice_weight, - voice_pitch_weight, - speech_style_weight, - tempo_weight, - output_name, - use_slerp_instead_of_lerp, +def merge_models_usual( + model_path_a: str, + model_path_b: str, + voice_weight: float, + voice_pitch_weight: float, + speech_style_weight: float, + tempo_weight: float, + output_name: str, + use_slerp_instead_of_lerp: bool, ): - """model Aを起点に、model Bの各要素を重み付けしてマージする。 - safetensors形式を前提とする。""" - model_a_weight = {} - with safe_open(model_path_a, framework="pt", device="cpu") as f: - for k in f.keys(): - model_a_weight[k] = f.get_tensor(k) - - model_b_weight = {} - with safe_open(model_path_b, framework="pt", device="cpu") as f: - for k in f.keys(): - model_b_weight[k] = f.get_tensor(k) + """ + new = (1 - weight) * A + weight * B + """ + model_a_weight = load_safetensors(model_path_a) + model_b_weight = load_safetensors(model_path_b) merged_model_weight = model_a_weight.copy() - for key in model_a_weight.keys(): + for key in model_a_weight: if any([key.startswith(prefix) for prefix in voice_keys]): weight = voice_weight elif any([key.startswith(prefix) for prefix in voice_pitch_keys]): @@ -161,13 +335,239 @@ def merge_models( slerp_tensors if use_slerp_instead_of_lerp else lerp_tensors )(weight, model_a_weight[key], model_b_weight[key]) - merged_model_path = os.path.join( - assets_root, output_name, f"{output_name}.safetensors" + merged_model_path = assets_root / output_name / f"{output_name}.safetensors" + merged_model_path.parent.mkdir(parents=True, exist_ok=True) + save_file(merged_model_weight, merged_model_path) + + receipe = { + "method": "usual", + "model_a": model_path_a, + "model_b": model_path_b, + "voice_weight": voice_weight, + "voice_pitch_weight": voice_pitch_weight, + "speech_style_weight": speech_style_weight, + "tempo_weight": tempo_weight, + "use_slerp_instead_of_lerp": use_slerp_instead_of_lerp, + } + save_recipe(receipe, output_name) + + # Merge default Neutral style vectors and save + model_name_a = Path(model_path_a).parent.name + model_name_b = Path(model_path_b).parent.name + style_vectors_a = load_style_vectors(model_name_a) + style_vectors_b = load_style_vectors(model_name_b) + + new_config = load_config(model_name_a) + new_config["model_name"] = output_name + new_config["data"]["num_styles"] = 1 + new_config["data"]["style2id"] = {DEFAULT_STYLE: 0} + save_config(new_config, output_name) + + neutral_vector_a = style_vectors_a[0] + neutral_vector_b = style_vectors_b[0] + weight = speech_style_weight + new_neutral_vector = (1 - weight) * neutral_vector_a + weight * neutral_vector_b + new_style_vectors = np.array([new_neutral_vector]) + save_style_vectors(new_style_vectors, output_name) + return merged_model_path + + +def merge_models_add_diff( + model_path_a: str, + model_path_b: str, + model_path_c: str, + voice_weight: float, + voice_pitch_weight: float, + speech_style_weight: float, + tempo_weight: float, + output_name: str, +): + """ + new = A + weight * (B - C) + """ + model_a_weight = load_safetensors(model_path_a) + model_b_weight = load_safetensors(model_path_b) + model_c_weight = load_safetensors(model_path_c) + + merged_model_weight = model_a_weight.copy() + + for key in model_a_weight: + if any([key.startswith(prefix) for prefix in voice_keys]): + weight = voice_weight + elif any([key.startswith(prefix) for prefix in voice_pitch_keys]): + weight = voice_pitch_weight + elif any([key.startswith(prefix) for prefix in speech_style_keys]): + weight = speech_style_weight + elif any([key.startswith(prefix) for prefix in tempo_keys]): + weight = tempo_weight + else: + continue + merged_model_weight[key] = model_a_weight[key] + weight * ( + model_b_weight[key] - model_c_weight[key] + ) + + merged_model_path = assets_root / output_name / f"{output_name}.safetensors" + merged_model_path.parent.mkdir(parents=True, exist_ok=True) + save_file(merged_model_weight, merged_model_path) + + info = { + "method": "add_diff", + "model_a": model_path_a, + "model_b": model_path_b, + "model_c": model_path_c, + "voice_weight": voice_weight, + "voice_pitch_weight": voice_pitch_weight, + "speech_style_weight": speech_style_weight, + "tempo_weight": tempo_weight, + } + with open(assets_root / output_name / "recipe.json", "w", encoding="utf-8") as f: + json.dump(info, f, indent=2, ensure_ascii=False) + + # Default style merge only using Neutral style + model_name_a = Path(model_path_a).parent.name + model_name_b = Path(model_path_b).parent.name + model_name_c = Path(model_path_c).parent.name + + style_vectors_a = np.load( + assets_root / model_name_a / "style_vectors.npy" + ) # (style_num_a, 256) + style_vectors_b = np.load( + assets_root / model_name_b / "style_vectors.npy" + ) # (style_num_b, 256) + style_vectors_c = np.load( + assets_root / model_name_c / "style_vectors.npy" + ) # (style_num_c, 256) + with open(assets_root / model_name_a / "config.json", encoding="utf-8") as f: + new_config = json.load(f) + + new_config["model_name"] = output_name + new_config["data"]["num_styles"] = 1 + new_config["data"]["style2id"] = {DEFAULT_STYLE: 0} + with open(assets_root / output_name / "config.json", "w", encoding="utf-8") as f: + json.dump(new_config, f, indent=2, ensure_ascii=False) + + neutral_vector_a = style_vectors_a[0] + neutral_vector_b = style_vectors_b[0] + neutral_vector_c = style_vectors_c[0] + weight = speech_style_weight + new_neutral_vector = neutral_vector_a + weight * ( + neutral_vector_b - neutral_vector_c + ) + new_style_vectors = np.array([new_neutral_vector]) + new_style_path = assets_root / output_name / "style_vectors.npy" + np.save(new_style_path, new_style_vectors) + return merged_model_path + + +def merge_models_weighted_sum( + model_path_a: str, + model_path_b: str, + model_path_c: str, + model_a_coeff: float, + model_b_coeff: float, + model_c_coeff: float, + output_name: str, +): + model_a_weight = load_safetensors(model_path_a) + model_b_weight = load_safetensors(model_path_b) + model_c_weight = load_safetensors(model_path_c) + + merged_model_weight = model_a_weight.copy() + + for key in model_a_weight: + merged_model_weight[key] = ( + model_a_coeff * model_a_weight[key] + + model_b_coeff * model_b_weight[key] + + model_c_coeff * model_c_weight[key] + ) + + merged_model_path = assets_root / output_name / f"{output_name}.safetensors" + merged_model_path.parent.mkdir(parents=True, exist_ok=True) + save_file(merged_model_weight, merged_model_path) + + info = { + "method": "weighted_sum", + "model_a": model_path_a, + "model_b": model_path_b, + "model_c": model_path_c, + "model_a_coeff": model_a_coeff, + "model_b_coeff": model_b_coeff, + "model_c_coeff": model_c_coeff, + } + with open(assets_root / output_name / "recipe.json", "w", encoding="utf-8") as f: + json.dump(info, f, indent=2, ensure_ascii=False) + + # Default style merge only using Neutral style + model_name_a = Path(model_path_a).parent.name + model_name_b = Path(model_path_b).parent.name + model_name_c = Path(model_path_c).parent.name + + style_vectors_a = np.load( + assets_root / model_name_a / "style_vectors.npy" + ) # (style_num_a, 256) + style_vectors_b = np.load( + assets_root / model_name_b / "style_vectors.npy" + ) # (style_num_b, 256) + style_vectors_c = np.load( + assets_root / model_name_c / "style_vectors.npy" + ) # (style_num_c, 256) + + with open(assets_root / model_name_a / "config.json", encoding="utf-8") as f: + new_config = json.load(f) + + new_config["model_name"] = output_name + new_config["data"]["num_styles"] = 1 + new_config["data"]["style2id"] = {DEFAULT_STYLE: 0} + with open(assets_root / output_name / "config.json", "w", encoding="utf-8") as f: + json.dump(new_config, f, indent=2, ensure_ascii=False) + + neutral_vector_a = style_vectors_a[0] + neutral_vector_b = style_vectors_b[0] + neutral_vector_c = style_vectors_c[0] + new_neutral_vector = ( + model_a_coeff * neutral_vector_a + + model_b_coeff * neutral_vector_b + + model_c_coeff * neutral_vector_c ) - os.makedirs(os.path.dirname(merged_model_path), exist_ok=True) + new_style_vectors = np.array([new_neutral_vector]) + new_style_path = assets_root / output_name / "style_vectors.npy" + np.save(new_style_path, new_style_vectors) + return merged_model_path + + +def merge_models_add_null( + model_path_a: str, + model_path_b: str, + voice_weight: float, + voice_pitch_weight: float, + speech_style_weight: float, + tempo_weight: float, + output_name: str, +): + model_a_weight = load_safetensors(model_path_a) + model_b_weight = load_safetensors(model_path_b) + + merged_model_weight = model_a_weight.copy() + + for key in model_a_weight: + if any([key.startswith(prefix) for prefix in voice_keys]): + weight = voice_weight + elif any([key.startswith(prefix) for prefix in voice_pitch_keys]): + weight = voice_pitch_weight + elif any([key.startswith(prefix) for prefix in speech_style_keys]): + weight = speech_style_weight + elif any([key.startswith(prefix) for prefix in tempo_keys]): + weight = tempo_weight + else: + continue + merged_model_weight[key] = model_a_weight[key] + weight * model_b_weight[key] + + merged_model_path = assets_root / output_name / f"{output_name}.safetensors" + merged_model_path.parent.mkdir(parents=True, exist_ok=True) save_file(merged_model_weight, merged_model_path) info = { + "method": "add_null", "model_a": model_path_a, "model_b": model_path_b, "voice_weight": voice_weight, @@ -175,98 +575,253 @@ def merge_models( "speech_style_weight": speech_style_weight, "tempo_weight": tempo_weight, } - with open( - os.path.join(assets_root, output_name, "recipe.json"), "w", encoding="utf-8" - ) as f: + with open(assets_root / output_name / "recipe.json", "w", encoding="utf-8") as f: json.dump(info, f, indent=2, ensure_ascii=False) + + # Default style merge only using Neutral style + model_name_a = Path(model_path_a).parent.name + model_name_b = Path(model_path_b).parent.name + + style_vectors_a = np.load( + assets_root / model_name_a / "style_vectors.npy" + ) # (style_num_a, 256) + style_vectors_b = np.load( + assets_root / model_name_b / "style_vectors.npy" + ) # (style_num_b, 256) + with open(assets_root / model_name_a / "config.json", encoding="utf-8") as f: + new_config = json.load(f) + + new_config["model_name"] = output_name + new_config["data"]["num_styles"] = 1 + new_config["data"]["style2id"] = {DEFAULT_STYLE: 0} + with open(assets_root / output_name / "config.json", "w", encoding="utf-8") as f: + json.dump(new_config, f, indent=2, ensure_ascii=False) + + neutral_vector_a = style_vectors_a[0] + neutral_vector_b = style_vectors_b[0] + weight = speech_style_weight + new_neutral_vector = neutral_vector_a + weight * neutral_vector_b + new_style_vectors = np.array([new_neutral_vector]) + new_style_path = assets_root / output_name / "style_vectors.npy" + np.save(new_style_path, new_style_vectors) return merged_model_path def merge_models_gr( - model_name_a, - model_path_a, - model_name_b, - model_path_b, - output_name, - voice_weight, - voice_pitch_weight, - speech_style_weight, - tempo_weight, - use_slerp_instead_of_lerp, + model_path_a: str, + model_path_b: str, + model_path_c: str, + model_a_coeff: float, + model_b_coeff: float, + model_c_coeff: float, + method: str, + output_name: str, + voice_weight: float, + voice_pitch_weight: float, + speech_style_weight: float, + tempo_weight: float, + use_slerp_instead_of_lerp: bool, ): if output_name == "": return "Error: 新しいモデル名を入力してください。" - merged_model_path = merge_models( - model_path_a, - model_path_b, - voice_weight, - voice_pitch_weight, - speech_style_weight, - tempo_weight, + assert method in [ + "usual", + "add_diff", + "weighted_sum", + "add_null", + ], f"Invalid method: {method}" + model_a_name = Path(model_path_a).parent.name + model_b_name = Path(model_path_b).parent.name + model_c_name = Path(model_path_c).parent.name + if method == "usual": + if output_name in [model_a_name, model_b_name]: + return "Error: マージ元のモデル名と同じ名前は使用できません。", None + merged_model_path = merge_models_usual( + model_path_a, + model_path_b, + voice_weight, + voice_pitch_weight, + speech_style_weight, + tempo_weight, + output_name, + use_slerp_instead_of_lerp, + ) + elif method == "add_diff": + if output_name in [model_a_name, model_b_name, model_c_name]: + return "Error: マージ元のモデル名と同じ名前は使用できません。", None + merged_model_path = merge_models_add_diff( + model_path_a, + model_path_b, + model_path_c, + voice_weight, + voice_pitch_weight, + speech_style_weight, + tempo_weight, + output_name, + ) + elif method == "weighted_sum": + if output_name in [model_a_name, model_b_name, model_c_name]: + return "Error: マージ元のモデル名と同じ名前は使用できません。", None + merged_model_path = merge_models_weighted_sum( + model_path_a, + model_path_b, + model_path_c, + model_a_coeff, + model_b_coeff, + model_c_coeff, + output_name, + ) + else: # add_null + if output_name in [model_a_name, model_b_name]: + return "Error: マージ元のモデル名と同じ名前は使用できません。", None + merged_model_path = merge_models_add_null( + model_path_a, + model_path_b, + voice_weight, + voice_pitch_weight, + speech_style_weight, + tempo_weight, + output_name, + ) + return f"Success: モデルを{merged_model_path}に保存しました。", gr.Dropdown( + choices=[DEFAULT_STYLE], value=DEFAULT_STYLE + ) + + +def merge_style_usual_gr( + model_name_a: str, + model_name_b: str, + weight: float, + output_name: str, + style_tuple_list: list[tuple[str, ...]], +): + if output_name == "": + return "Error: 新しいモデル名を入力してください。", None + new_styles = merge_style_usual( + model_name_a, + model_name_b, + weight, + output_name, + style_tuple_list, + ) + return f"Success: {output_name}のスタイルを保存しました。", gr.Dropdown( + choices=new_styles, value=new_styles[0] + ) + + +def merge_style_add_diff_gr( + model_name_a: str, + model_name_b: str, + model_name_c: str, + weight: float, + output_name: str, + style_tuple_list: list[tuple[str, ...]], +): + if output_name == "": + return "Error: 新しいモデル名を入力してください。", None + new_styles = merge_style_add_diff( + model_name_a, + model_name_b, + model_name_c, + weight, output_name, - use_slerp_instead_of_lerp, + style_tuple_list, + ) + return f"Success: {output_name}のスタイルを保存しました。", gr.Dropdown( + choices=new_styles, value=new_styles[0] ) - return f"Success: モデルを{merged_model_path}に保存しました。" -def merge_style_gr( - model_name_a, - model_name_b, - weight, - output_name, - style_triple_list_str: str, +def merge_style_weighted_sum_gr( + model_name_a: str, + model_name_b: str, + model_name_c: str, + model_a_coeff: float, + model_b_coeff: float, + model_c_coeff: float, + output_name: str, + style_tuple_list: list[tuple[str, ...]], ): if output_name == "": return "Error: 新しいモデル名を入力してください。", None - style_triple_list = [] - for line in style_triple_list_str.split("\n"): - if not line: - continue - style_triple = line.split(",") - if len(style_triple) != 3: - logger.error(f"Invalid style triple: {line}") - return ( - f"Error: スタイルを3つのカンマ区切りで入力してください:\n{line}", - None, - ) - style_a, style_b, style_out = style_triple - style_a = style_a.strip() - style_b = style_b.strip() - style_out = style_out.strip() - style_triple_list.append((style_a, style_b, style_out)) - try: - new_style_path, new_styles = merge_style( - model_name_a, model_name_b, weight, output_name, style_triple_list - ) - except ValueError as e: - return f"Error: {e}" - return f"Success: スタイルを{new_style_path}に保存しました。", gr.Dropdown( + new_styles = merge_style_weighted_sum( + model_name_a, + model_name_b, + model_name_c, + model_a_coeff, + model_b_coeff, + model_c_coeff, + output_name, + style_tuple_list, + ) + return f"Success: {output_name}のスタイルを保存しました。", gr.Dropdown( choices=new_styles, value=new_styles[0] ) -def simple_tts(model_name, text, style=DEFAULT_STYLE, style_weight=1.0): - model_path = os.path.join(assets_root, model_name, f"{model_name}.safetensors") - config_path = os.path.join(assets_root, model_name, "config.json") - style_vec_path = os.path.join(assets_root, model_name, "style_vectors.npy") +def merge_style_add_null_gr( + model_name_a: str, + model_name_b: str, + weight: float, + output_name: str, + style_tuple_list: list[tuple[str, ...]], +): + if output_name == "": + return "Error: 新しいモデル名を入力してください。", None + new_styles = merge_style_add_null( + model_name_a, + model_name_b, + weight, + output_name, + style_tuple_list, + ) + return f"Success: {output_name}のスタイルを保存しました。", gr.Dropdown( + choices=new_styles, value=new_styles[0] + ) + - model = TTSModel(Path(model_path), Path(config_path), Path(style_vec_path), device) - return model.infer(text, style=style, style_weight=style_weight) +def simple_tts( + model_name: str, text: str, style: str = DEFAULT_STYLE, style_weight: float = 1.0 +): + if model_name == "": + return "Error: モデル名を入力してください。", None + model_path = assets_root / model_name / f"{model_name}.safetensors" + config_path = assets_root / model_name / "config.json" + style_vec_path = assets_root / model_name / "style_vectors.npy" + model = TTSModel(model_path, config_path, style_vec_path, device) -def update_two_model_names_dropdown(model_holder: TTSModelHolder): + return ( + "Success: 音声を生成しました。", + model.infer(text, style=style, style_weight=style_weight), + ) + + +def update_three_model_names_dropdown(model_holder: TTSModelHolder): new_names, new_files, _ = model_holder.update_model_names_for_gradio() - return new_names, new_files, new_names, new_files + return new_names, new_files, new_names, new_files, new_names, new_files + + +def get_styles(model_name: str): + config_path = assets_root / model_name / "config.json" + with open(config_path, encoding="utf-8") as f: + config = json.load(f) + styles = list(config["data"]["style2id"].keys()) + return styles -def load_styles_gr(model_name_a, model_name_b): - config_path_a = os.path.join(assets_root, model_name_a, "config.json") - with open(config_path_a, "r", encoding="utf-8") as f: +def get_triple_styles(model_name_a: str, model_name_b: str, model_name_c: str): + return get_styles(model_name_a), get_styles(model_name_b), get_styles(model_name_c) + + +def load_styles_gr(model_name_a: str, model_name_b: str): + config_path_a = assets_root / model_name_a / "config.json" + with open(config_path_a, encoding="utf-8") as f: config_a = json.load(f) styles_a = list(config_a["data"]["style2id"].keys()) - config_path_b = os.path.join(assets_root, model_name_b, "config.json") - with open(config_path_b, "r", encoding="utf-8") as f: + config_path_b = assets_root / model_name_b / "config.json" + with open(config_path_b, encoding="utf-8") as f: config_b = json.load(f) styles_b = list(config_b["data"]["style2id"].keys()) @@ -288,11 +843,30 @@ def load_styles_gr(model_name_a, model_name_b): initial_md = """ ## 使い方 -1. マージしたい2つのモデルを選択してください(`model_assets`フォルダの中から選ばれます)。 -2. マージ後のモデルの名前を入力してください。 -3. マージ後のモデルの声質・話し方・話す速さを調整してください。 -4. 「モデルファイルのマージ」ボタンを押してください(safetensorsファイルがマージされる)。 -5. スタイルベクトルファイルも生成する必要があるので、指示に従ってマージ方法を入力後、「スタイルのマージ」ボタンを押してください。 +### マージ方法の選択 + +マージの方法には4つの方法があります。 +- 通常のマージ `new = (1 - weight) * A + weight * B`: AとBのモデルを指定して、要素ごとに比率を指定して混ぜる + - 単純にAとBの二人の話し方や声音を混ぜたいとき +- 差分マージ `new = A + weight * (B - C)`: AとBとCのモデルを指定して、「Bの要素からCの要素を引いたもの」をAに足す + - 例えば、Bが「Cと同じ人だけど囁いているモデル」とすると、`B - C`は「囁きを表すベクトル」だと思えるので、それをAに足すことで、Aの声のままで囁き声を出すモデルができたりする + - 他にも活用例はいろいろありそう +- 重み付き和 `new = a * A + b * B + c * C`: AとBとCのモデルを指定して、各モデルの係数を指定して混ぜる + - 例えば`new = A - B` としておくと、結果としてできたモデルを別のモデルと「ヌルモデルの加算」で使うことで、差分マージが実現できる + - 他にも何らかの活用法があるかもしれない +- ヌルモデルの加算 `new = A + weight * B`: AとBのモデルを指定して、Bのモデルに要素ごとに比率をかけたものをAに足す + - Bのモデルは重み付き和などで `C - D` などとして作っている場合を想定している + - 他にも何らかの活用法があるかもしれない + + +### マージの手順 + +1. マージ元のモデルたちを選択(`model_assets`フォルダの中から選ばれます) +2. マージ後のモデルの名前を入力 +3. 指示に従って重みや係数を入力 +4. 「モデルファイルのマージ」ボタンを押す (safetensorsファイルがマージされる) +5. 結果を簡易音声合成で確認 +6. 必要に応じてスタイルベクトルのマージを行う 以上でマージは完了で、`model_assets/マージ後のモデル名`にマージ後のモデルが保存され、音声合成のときに使えます。 @@ -301,28 +875,131 @@ initial_md = """ 一番下にマージしたモデルによる簡易的な音声合成機能もつけています。 ## 注意 -1.x系と2.x-JP-Extraのモデルマージは失敗するようです。 + +- 1.x系と2.x-JP-Extraのモデルマージは失敗するようです。 +- 話者数が違うモデル同士はおそらくマージできません。 """ style_merge_md = f""" -## スタイルベクトルのマージ +## 3. スタイルベクトルのマージ + +1. マージ後のモデルにいくつスタイルを追加したいかを「作りたいスタイル数」で指定 +2. マージ前のモデルのスタイルを「各モデルのスタイルを取得」ボタンで取得 +3. どのスタイルたちから新しいスタイルを作るかを下の欄で入力 +4. 「スタイルのマージ」をクリック + +### スタイルベクトルの混ぜられ方 -1行に「モデルAのスタイル名, モデルBのスタイル名, 左の2つを混ぜて出力するスタイル名」 -という形式で入力してください。例えば、 +- 構造上の相性の関係で、スタイルベクトルを混ぜる重みは、加重和以外の場合は、上の「話し方」と同じ比率で混ぜられます。例えば「話し方」が0のときはモデルAのみしか使われません。 +- 加重和の場合は、AとBとCの係数によって混ぜられます。 +""" + +usual_md = """ +`weight` を下の各スライダーで定める数値とすると、各要素ごとに、 ``` -{DEFAULT_STYLE}, {DEFAULT_STYLE}, {DEFAULT_STYLE} -Happy, Surprise, HappySurprise +new_model = (1 - weight) * A + weight * B ``` -と入力すると、マージ後のスタイルベクトルは、 -- `{DEFAULT_STYLE}`: モデルAの`{DEFAULT_STYLE}`とモデルBの`{DEFAULT_STYLE}`を混ぜたもの -- `HappySurprise`: モデルAの`Happy`とモデルBの`Surprise`を混ぜたもの -の2つになります。 - -### 注意 -- 必ず「{DEFAULT_STYLE}」という名前のスタイルを作ってください。これは、マージ後のモデルの平均スタイルになります。 -- 構造上の相性の関係で、スタイルベクトルを混ぜる重みは、上の「話し方」と同じ比率で混ぜられます。例えば「話し方」が0のときはモデルAのみしか使われません。 +としてマージされます。 + +つまり、`weight = 0` のときはモデルA、`weight = 1` のときはモデルBになります。 """ +add_diff_md = """ +`weight` を下の各スライダーで定める数値とすると、各要素ごとに、 +``` +new_model = A + weight * (B - C) +``` +としてマージされます。 + +通常のマージと違い、**重みを1にしてもAの要素はそのまま保たれます**。 +""" + +weighted_sum_md = """ +モデルの係数をそれぞれ `a`, `b`, `c` とすると、 **全要素に対して**、 +``` +new_model = a * A + b * B + c * C +``` +としてマージされます。 + +## TIPS + +- A, B, C が全て通常モデルで、通常モデルを作りたい場合は、`a + b + c = 1`となるようにするのがよいと思います。 +- `a + b + c = 0` とすると(たとえば `A - B`)、話者性を持たないヌルモデルを作ることができ、「ヌルモデルとの和」で結果を使うことが出来ます(差分マージの材料などに) +- 他にも、`a = 0.5, b = c = 0`などでモデルAを謎に小さくしたり大きくしたり負にしたりできるので、実験に使ってください。 +""" + +add_null_md = """ +「ヌルモデル」を、いくつかのモデルの加重和であってその係数の和が0であるようなものとします(例えば `C - D` など)。 + +そうして作ったヌルモデルBと通常モデルAに対して、`weight` を下の各スライダーで定める数値とすると、各要素ごとに、 +``` +new_model = A + weight * B +``` +としてマージされます。 + +通常のマージと違い、**重みを1にしてもAの要素はそのまま保たれます**。 + +実際にはヌルモデルでないBに対しても使えますが、その場合はおそらく音声が正常に生成されないモデルができる気がします。が、もしかしたら何かに使えるかもしれません。 + +囁きについて実験的に作ったヌルモデルを[こちら](https://huggingface.co/litagin/sbv2_null_models)に置いています。これを `B` に使うことで、任意のモデルを囁きモデルにある程度は変換できます。 +""" + +tts_md = f""" +## 2. 結果のテスト + +マージ後のモデルで音声合成を行います。ただし、デフォルトではスタイルは`{DEFAULT_STYLE}`しか使えないので、他のスタイルを使いたい場合は、下の「スタイルベクトルのマージ」を行ってください。 +""" + + +def method_change(x: str): + assert x in [ + "usual", + "add_diff", + "weighted_sum", + "add_null", + ], f"Invalid method: {x}" + # model_desc, c_col, model_a_coeff, model_b_coeff, model_c_coeff, weight_row, use_slerp_instead_of_lerp + if x == "usual": + return ( + gr.Markdown(usual_md), + gr.Column(visible=False), + gr.Number(visible=False), + gr.Number(visible=False), + gr.Number(visible=False), + gr.Row(visible=True), + gr.Checkbox(visible=True), + ) + elif x == "add_diff": + return ( + gr.Markdown(add_diff_md), + gr.Column(visible=True), + gr.Number(visible=False), + gr.Number(visible=False), + gr.Number(visible=False), + gr.Row(visible=True), + gr.Checkbox(visible=False), + ) + elif x == "add_null": + return ( + gr.Markdown(add_null_md), + gr.Column(visible=False), + gr.Number(visible=False), + gr.Number(visible=False), + gr.Number(visible=False), + gr.Row(visible=True), + gr.Checkbox(visible=False), + ) + else: # weighted_sum + return ( + gr.Markdown(weighted_sum_md), + gr.Column(visible=True), + gr.Number(visible=True), + gr.Number(visible=True), + gr.Number(visible=True), + gr.Row(visible=False), + gr.Checkbox(visible=False), + ) + def create_merge_app(model_holder: TTSModelHolder) -> gr.Blocks: model_names = model_holder.model_names @@ -336,14 +1013,26 @@ def create_merge_app(model_holder: TTSModelHolder) -> gr.Blocks: ) return app initial_id = 0 - initial_model_files = model_holder.model_files_dict[model_names[initial_id]] + initial_model_files = [ + str(f) for f in model_holder.model_files_dict[model_names[initial_id]] + ] with gr.Blocks(theme=GRADIO_THEME) as app: gr.Markdown( - "2つのStyle-Bert-VITS2モデルから、声質・話し方・話す速さを取り替えたり混ぜたりできます。" + "複数のStyle-Bert-VITS2モデルから、声質・話し方・話す速さを取り替えたり混ぜたり引いたりして新しいモデルを作成できます。" ) with gr.Accordion(label="使い方", open=False): gr.Markdown(initial_md) + method = gr.Radio( + label="マージ方法", + choices=[ + ("通常マージ", "usual"), + ("差分マージ", "add_diff"), + ("加重和", "weighted_sum"), + ("ヌルモデルマージ", "add_null"), + ], + value="usual", + ) with gr.Row(): with gr.Column(scale=3): model_name_a = gr.Dropdown( @@ -356,6 +1045,12 @@ def create_merge_app(model_holder: TTSModelHolder) -> gr.Blocks: choices=initial_model_files, value=initial_model_files[0], ) + model_a_coeff = gr.Number( + label="モデルAの係数", + value=1.0, + step=0.1, + visible=False, + ) with gr.Column(scale=3): model_name_b = gr.Dropdown( label="モデルB", @@ -367,10 +1062,34 @@ def create_merge_app(model_holder: TTSModelHolder) -> gr.Blocks: choices=initial_model_files, value=initial_model_files[0], ) + model_b_coeff = gr.Number( + label="モデルBの係数", + value=-1.0, + step=0.1, + visible=False, + ) + with gr.Column(scale=3, visible=False) as c_col: + model_name_c = gr.Dropdown( + label="モデルC", + choices=model_names, + value=model_names[initial_id], + ) + model_path_c = gr.Dropdown( + label="モデルファイル", + choices=initial_model_files, + value=initial_model_files[0], + ) + model_c_coeff = gr.Number( + label="モデルCの係数", + value=0.0, + step=0.1, + visible=False, + ) refresh_button = gr.Button("更新", scale=1, visible=True) + method_desc = gr.Markdown(usual_md) with gr.Column(variant="panel"): new_name = gr.Textbox(label="新しいモデル名", placeholder="new_model") - with gr.Row(): + with gr.Row() as weight_row: voice_slider = gr.Slider( label="声質", value=0, @@ -402,45 +1121,337 @@ def create_merge_app(model_holder: TTSModelHolder) -> gr.Blocks: use_slerp_instead_of_lerp = gr.Checkbox( label="線形補完のかわりに球面線形補完を使う", value=False, + visible=True, ) - with gr.Column(variant="panel"): - gr.Markdown("## モデルファイル(safetensors)のマージ") + with gr.Column(variant="panel"): + gr.Markdown("## 1. モデルファイル (safetensors) のマージ") + with gr.Row(): model_merge_button = gr.Button( "モデルファイルのマージ", variant="primary" ) info_model_merge = gr.Textbox(label="情報") - with gr.Column(variant="panel"): - gr.Markdown(style_merge_md) - with gr.Row(): - load_style_button = gr.Button("スタイル一覧をロード", scale=1) - styles_a = gr.Textbox(label="モデルAのスタイル一覧") - styles_b = gr.Textbox(label="モデルBのスタイル一覧") - style_triple_list = gr.TextArea( - label="スタイルのマージリスト", - placeholder=f"{DEFAULT_STYLE}, {DEFAULT_STYLE},{DEFAULT_STYLE}\nAngry, Angry, Angry", - value=f"{DEFAULT_STYLE}, {DEFAULT_STYLE}, {DEFAULT_STYLE}", - ) - style_merge_button = gr.Button("スタイルのマージ", variant="primary") - info_style_merge = gr.Textbox(label="情報") + with gr.Column(variant="panel"): + gr.Markdown(tts_md) + text_input = gr.TextArea( + label="テキスト", value="これはテストです。聞こえていますか?" + ) + with gr.Row(): + with gr.Column(): + style = gr.Dropdown( + label="スタイル", + choices=[DEFAULT_STYLE], + value=DEFAULT_STYLE, + ) + emotion_weight = gr.Slider( + minimum=0, + maximum=50, + value=1, + step=0.1, + label="スタイルの強さ", + ) + tts_button = gr.Button("音声合成", variant="primary") + tts_info = gr.Textbox(label="情報") + audio_output = gr.Audio(label="結果") + with gr.Column(variant="panel"): + gr.Markdown(style_merge_md) + style_a_list = gr.State([DEFAULT_STYLE]) + style_b_list = gr.State([DEFAULT_STYLE]) + style_c_list = gr.State([DEFAULT_STYLE]) + gr.Markdown("Hello world!") + with gr.Row(): + style_count = gr.Number(label="作るスタイルの数", value=1, step=1) - text_input = gr.TextArea( - label="テキスト", value="これはテストです。聞こえていますか?" - ) - style = gr.Dropdown( - label="スタイル", - choices=["スタイルをマージしてください"], - value="スタイルをマージしてください", - ) - emotion_weight = gr.Slider( - minimum=0, - maximum=50, - value=1, - step=0.1, - label="スタイルの強さ", - ) - tts_button = gr.Button("音声合成", variant="primary") - audio_output = gr.Audio(label="結果") + get_style_btn = gr.Button("各モデルのスタイルを取得", variant="primary") + get_style_btn.click( + get_triple_styles, + inputs=[model_name_a, model_name_b, model_name_c], + outputs=[style_a_list, style_b_list, style_c_list], + ) + + def join_names(*args): + if all(arg == DEFAULT_STYLE for arg in args): + return DEFAULT_STYLE + return "_".join(args) + + @gr.render( + inputs=[ + style_count, + style_a_list, + style_b_list, + style_c_list, + method, + ] + ) + def render_style( + style_count, style_a_list, style_b_list, style_c_list, method + ): + a_components = [] + b_components = [] + c_components = [] + out_components = [] + if method in ["usual", "add_null"]: + for i in range(style_count): + with gr.Row(): + style_a = gr.Dropdown( + label="モデルAのスタイル名", + key=f"style_a_{i}", + choices=style_a_list, + value=DEFAULT_STYLE, + interactive=i != 0, + ) + style_b = gr.Dropdown( + label="モデルBのスタイル名", + key=f"style_b_{i}", + choices=style_b_list, + value=DEFAULT_STYLE, + interactive=i != 0, + ) + style_out = gr.Textbox( + label="出力スタイル名", + key=f"style_out_{i}", + value=DEFAULT_STYLE, + interactive=i != 0, + ) + style_a.change( + join_names, + inputs=[style_a, style_b], + outputs=[style_out], + ) + style_b.change( + join_names, + inputs=[style_a, style_b], + outputs=[style_out], + ) + a_components.append(style_a) + b_components.append(style_b) + out_components.append(style_out) + if method == "usual": + + def _merge_usual(data): + style_tuple_list = [ + (data[a], data[b], data[out]) + for a, b, out in zip( + a_components, b_components, out_components + ) + ] + return merge_style_usual_gr( + data[model_name_a], + data[model_name_b], + data[speech_style_slider], + data[new_name], + style_tuple_list, + ) + + style_merge_btn.click( + _merge_usual, + inputs=set( + a_components + + b_components + + out_components + + [ + model_name_a, + model_name_b, + speech_style_slider, + new_name, + ] + ), + outputs=[info_style_merge, style], + ) + else: # add_null + + def _merge_add_null(data): + print("Method is add_null") + style_tuple_list = [ + (data[a], data[b], data[out]) + for a, b, out in zip( + a_components, b_components, out_components + ) + ] + return merge_style_add_null_gr( + data[model_name_a], + data[model_name_b], + data[speech_style_slider], + data[new_name], + style_tuple_list, + ) + + style_merge_btn.click( + _merge_add_null, + inputs=set( + a_components + + b_components + + out_components + + [ + model_name_a, + model_name_b, + speech_style_slider, + new_name, + ] + ), + outputs=[info_style_merge, style], + ) + + elif method in ["add_diff", "weighted_sum"]: + for i in range(style_count): + with gr.Row(): + style_a = gr.Dropdown( + label="モデルAのスタイル名", + key=f"style_a_{i}", + choices=style_a_list, + value=DEFAULT_STYLE, + interactive=i != 0, + ) + style_b = gr.Dropdown( + label="モデルBのスタイル名", + key=f"style_b_{i}", + choices=style_b_list, + value=DEFAULT_STYLE, + interactive=i != 0, + ) + style_c = gr.Dropdown( + label="モデルCのスタイル名", + key=f"style_c_{i}", + choices=style_c_list, + value=DEFAULT_STYLE, + interactive=i != 0, + ) + style_out = gr.Textbox( + label="出力スタイル名", + key=f"style_out_{i}", + value=DEFAULT_STYLE, + interactive=i != 0, + ) + style_a.change( + join_names, + inputs=[style_a, style_b, style_c], + outputs=[style_out], + ) + style_b.change( + join_names, + inputs=[style_a, style_b, style_c], + outputs=[style_out], + ) + style_c.change( + join_names, + inputs=[style_a, style_b, style_c], + outputs=[style_out], + ) + a_components.append(style_a) + b_components.append(style_b) + c_components.append(style_c) + out_components.append(style_out) + if method == "add_diff": + + def _merge_add_diff(data): + style_tuple_list = [ + (data[a], data[b], data[c], data[out]) + for a, b, c, out in zip( + a_components, + b_components, + c_components, + out_components, + ) + ] + return merge_style_add_diff_gr( + data[model_name_a], + data[model_name_b], + data[model_name_c], + data[speech_style_slider], + data[new_name], + style_tuple_list, + ) + + style_merge_btn.click( + _merge_add_diff, + inputs=set( + a_components + + b_components + + c_components + + out_components + + [ + model_name_a, + model_name_b, + model_name_c, + speech_style_slider, + new_name, + ] + ), + outputs=[info_style_merge, style], + ) + else: # weighted_sum + + def _merge_weighted_sum(data): + style_tuple_list = [ + (data[a], data[b], data[c], data[out]) + for a, b, c, out in zip( + a_components, + b_components, + c_components, + out_components, + ) + ] + return merge_style_weighted_sum_gr( + data[model_name_a], + data[model_name_b], + data[model_name_c], + data[model_a_coeff], + data[model_b_coeff], + data[model_c_coeff], + data[new_name], + style_tuple_list, + ) + + style_merge_btn.click( + _merge_weighted_sum, + inputs=set( + a_components + + b_components + + c_components + + out_components + + [ + model_name_a, + model_name_b, + model_name_c, + model_a_coeff, + model_b_coeff, + model_c_coeff, + new_name, + ] + ), + outputs=[info_style_merge, style], + ) + + with gr.Row(): + add_btn = gr.Button("スタイルを増やす") + del_btn = gr.Button("スタイルを減らす") + add_btn.click( + lambda x: x + 1, + inputs=[style_count], + outputs=[style_count], + ) + del_btn.click( + lambda x: x - 1 if x > 1 else 1, + inputs=[style_count], + outputs=[style_count], + ) + style_merge_btn = gr.Button("スタイルのマージ", variant="primary") + + info_style_merge = gr.Textbox(label="情報") + + method.change( + method_change, + inputs=[method], + outputs=[ + method_desc, + c_col, + model_a_coeff, + model_b_coeff, + model_c_coeff, + weight_row, + use_slerp_instead_of_lerp, + ], + ) model_name_a.change( model_holder.update_model_files_for_gradio, inputs=[model_name_a], @@ -451,25 +1462,34 @@ def create_merge_app(model_holder: TTSModelHolder) -> gr.Blocks: inputs=[model_name_b], outputs=[model_path_b], ) - - refresh_button.click( - lambda: update_two_model_names_dropdown(model_holder), - outputs=[model_name_a, model_path_a, model_name_b, model_path_b], + model_name_c.change( + model_holder.update_model_files_for_gradio, + inputs=[model_name_c], + outputs=[model_path_c], ) - load_style_button.click( - load_styles_gr, - inputs=[model_name_a, model_name_b], - outputs=[styles_a, styles_b, style_triple_list], + refresh_button.click( + lambda: update_three_model_names_dropdown(model_holder), + outputs=[ + model_name_a, + model_path_a, + model_name_b, + model_path_b, + model_name_c, + model_path_c, + ], ) model_merge_button.click( merge_models_gr, inputs=[ - model_name_a, model_path_a, - model_name_b, model_path_b, + model_path_c, + model_a_coeff, + model_b_coeff, + model_c_coeff, + method, new_name, voice_slider, voice_pitch_slider, @@ -477,25 +1497,35 @@ def create_merge_app(model_holder: TTSModelHolder) -> gr.Blocks: tempo_slider, use_slerp_instead_of_lerp, ], - outputs=[info_model_merge], + outputs=[info_model_merge, style], ) - style_merge_button.click( - merge_style_gr, - inputs=[ - model_name_a, - model_name_b, - speech_style_slider, - new_name, - style_triple_list, - ], - outputs=[info_style_merge, style], - ) + # style_merge_button.click( + # merge_style_gr, + # inputs=[ + # model_name_a, + # model_name_b, + # model_name_c, + # method, + # speech_style_slider, + # new_name, + # style_triple_list, + # ], + # outputs=[info_style_merge, style], + # ) tts_button.click( simple_tts, inputs=[new_name, text_input, style, emotion_weight], - outputs=[audio_output], + outputs=[tts_info, audio_output], ) return app + + +if __name__ == "__main__": + model_holder = TTSModelHolder( + assets_root, device="cuda" if torch.cuda.is_available() else "cpu" + ) + app = create_merge_app(model_holder) + app.launch(inbrowser=True) diff --git a/gradio_tabs/style_vectors.py b/gradio_tabs/style_vectors.py index 9234cff626a73f1c8550ea53484fd31bcdd58857..c9014cd5e580db8a9325d9f1ba6134e788efe5b7 100644 --- a/gradio_tabs/style_vectors.py +++ b/gradio_tabs/style_vectors.py @@ -1,26 +1,29 @@ +""" +TODO: +importが重いので、WebUI全般が重くなっている。どうにかしたい。 +""" + import json -import os import shutil from pathlib import Path import gradio as gr import matplotlib.pyplot as plt import numpy as np -import yaml from scipy.spatial.distance import pdist, squareform from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans from sklearn.manifold import TSNE from umap import UMAP -from config import config +from config import get_path_config +from default_style import save_styles_by_dirs from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME from style_bert_vits2.logging import logger -# Get path settings -with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - dataset_root = Path(path_config["dataset_root"]) - # assets_root = path_config["assets_root"] + +path_config = get_path_config() +dataset_root = path_config.dataset_root +assets_root = path_config.assets_root MAX_CLUSTER_NUM = 10 MAX_AUDIO_NUM = 10 @@ -38,11 +41,7 @@ centroids = [] def load(model_name: str, reduction_method: str): global wav_files, x, x_reduced, mean - # wavs_dir = os.path.join(dataset_root, model_name, "wavs") wavs_dir = dataset_root / model_name / "wavs" - # style_vector_files = [ - # os.path.join(wavs_dir, f) for f in os.listdir(wavs_dir) if f.endswith(".npy") - # ] style_vector_files = [f for f in wavs_dir.rglob("*.npy") if f.is_file()] # foo.wav.npy -> foo.wav wav_files = [f.with_suffix("") for f in style_vector_files] @@ -142,7 +141,7 @@ def do_dbscan_gradio(eps=2.5, min_samples=15): ) plt.legend() - n_clusters = max(y_pred) + 1 + n_clusters = int(max(y_pred) + 1) if n_clusters > MAX_CLUSTER_NUM: # raise ValueError(f"The number of clusters is too large: {n_clusters}") @@ -169,7 +168,7 @@ def representative_wav_files_gradio(cluster_id, num_files=1): closest_indices = representative_wav_files(cluster_id, num_files) actual_num_files = len(closest_indices) # ファイル数が少ないときのため return [ - gr.Audio(wav_files[i], visible=True, label=wav_files[i]) + gr.Audio(wav_files[i], visible=True, label=str(wav_files[i])) for i in closest_indices ] + [gr.update(visible=False)] * (MAX_AUDIO_NUM - actual_num_files) @@ -195,21 +194,21 @@ def do_clustering_gradio(n_clusters=4, method="KMeans"): ] * MAX_AUDIO_NUM -def save_style_vectors_from_clustering(model_name, style_names_str: str): +def save_style_vectors_from_clustering(model_name: str, style_names_str: str): """centerとcentroidsを保存する""" - result_dir = os.path.join(config.assets_root, model_name) - os.makedirs(result_dir, exist_ok=True) + result_dir = assets_root / model_name + result_dir.mkdir(parents=True, exist_ok=True) style_vectors = np.stack([mean] + centroids) - style_vector_path = os.path.join(result_dir, "style_vectors.npy") - if os.path.exists(style_vector_path): + style_vector_path = result_dir / "style_vectors.npy" + if style_vector_path.exists(): logger.info(f"Backup {style_vector_path} to {style_vector_path}.bak") shutil.copy(style_vector_path, f"{style_vector_path}.bak") np.save(style_vector_path, style_vectors) logger.success(f"Saved style vectors to {style_vector_path}") # config.jsonの更新 - config_path = os.path.join(result_dir, "config.json") - if not os.path.exists(config_path): + config_path = result_dir / "config.json" + if not config_path.exists(): return f"{config_path}が存在しません。" style_names = [name.strip() for name in style_names_str.split(",")] style_name_list = [DEFAULT_STYLE] + style_names @@ -220,7 +219,7 @@ def save_style_vectors_from_clustering(model_name, style_names_str: str): logger.info(f"Backup {config_path} to {config_path}.bak") shutil.copy(config_path, f"{config_path}.bak") - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: json_dict = json.load(f) json_dict["data"]["num_styles"] = len(style_name_list) style_dict = {name: i for i, name in enumerate(style_name_list)} @@ -232,7 +231,7 @@ def save_style_vectors_from_clustering(model_name, style_names_str: str): def save_style_vectors_from_files( - model_name, audio_files_str: str, style_names_str: str + model_name: str, audio_files_str: str, style_names_str: str ): """音声ファイルからスタイルベクトルを作成して保存する""" global mean @@ -240,8 +239,8 @@ def save_style_vectors_from_files( return "Error: スタイルベクトルを読み込んでください。" mean = np.mean(x, axis=0) - result_dir = os.path.join(config.assets_root, model_name) - os.makedirs(result_dir, exist_ok=True) + result_dir = assets_root / model_name + result_dir.mkdir(parents=True, exist_ok=True) audio_files = [name.strip() for name in audio_files_str.split(",")] style_names = [name.strip() for name in style_names_str.split(",")] if len(audio_files) != len(style_names): @@ -251,28 +250,28 @@ def save_style_vectors_from_files( return "スタイル名が重複しています。" style_vectors = [mean] - wavs_dir = os.path.join(dataset_root, model_name, "wavs") + wavs_dir = dataset_root / model_name / "wavs" for audio_file in audio_files: - path = os.path.join(wavs_dir, audio_file) - if not os.path.exists(path): + path = wavs_dir / audio_file + if not path.exists(): return f"{path}が存在しません。" style_vectors.append(np.load(f"{path}.npy")) style_vectors = np.stack(style_vectors) assert len(style_name_list) == len(style_vectors) - style_vector_path = os.path.join(result_dir, "style_vectors.npy") - if os.path.exists(style_vector_path): + style_vector_path = result_dir / "style_vectors.npy" + if style_vector_path.exists(): logger.info(f"Backup {style_vector_path} to {style_vector_path}.bak") shutil.copy(style_vector_path, f"{style_vector_path}.bak") np.save(style_vector_path, style_vectors) # config.jsonの更新 - config_path = os.path.join(result_dir, "config.json") - if not os.path.exists(config_path): + config_path = result_dir / "config.json" + if not config_path.exists(): return f"{config_path}が存在しません。" logger.info(f"Backup {config_path} to {config_path}.bak") shutil.copy(config_path, f"{config_path}.bak") - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: json_dict = json.load(f) json_dict["data"]["num_styles"] = len(style_name_list) style_dict = {name: i for i, name in enumerate(style_name_list)} @@ -283,20 +282,107 @@ def save_style_vectors_from_files( return f"成功!\n{style_vector_path}に保存し{config_path}を更新しました。" -how_to_md = f""" -Style-Bert-VITS2でこまかくスタイルを指定して音声合成するには、モデルごとにスタイルベクトルのファイル`style_vectors.npy`を手動で作成する必要があります。 +def save_style_vectors_by_dirs(model_name: str, audio_dir_str: str): + if model_name == "": + return "モデル名を入力してください。" + if audio_dir_str == "": + return "音声ファイルが入っているディレクトリを入力してください。" + + from concurrent.futures import ThreadPoolExecutor + from multiprocessing import cpu_count + + from tqdm import tqdm + + from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT + from style_gen import save_style_vector + + # First generate style vectors for each audio file + + audio_dir = Path(audio_dir_str) + audio_suffixes = [".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a"] + audio_files = [f for f in audio_dir.rglob("*") if f.suffix in audio_suffixes] + + def process(file: Path): + # f: `test.wav` -> search `test.wav.npy` + if (file.with_name(file.name + ".npy")).exists(): + return file, None + try: + save_style_vector(str(file)) + except Exception as e: + return file, e + return file, None + + with ThreadPoolExecutor(max_workers=cpu_count() // 2) as executor: + _ = list( + tqdm( + executor.map( + process, + audio_files, + ), + total=len(audio_files), + file=SAFE_STDOUT, + desc="Generating style vectors", + ) + ) -ただし、学習の過程で自動的に平均スタイル「{DEFAULT_STYLE}」のみは作成されるので、それをそのまま使うこともできます(その場合はこのWebUIは使いません)。 + result_dir = assets_root / model_name + config_path = result_dir / "config.json" + if not config_path.exists(): + return f"{config_path}が存在しません。" + logger.info(f"Backup {config_path} to {config_path}.bak") + shutil.copy(config_path, f"{config_path}.bak") -このプロセスは学習とは全く関係がないので、何回でも独立して繰り返して試せます。また学習中にもたぶん軽いので動くはずです。 + style_vector_path = result_dir / "style_vectors.npy" + if style_vector_path.exists(): + logger.info(f"Backup {style_vector_path} to {style_vector_path}.bak") + shutil.copy(style_vector_path, f"{style_vector_path}.bak") + save_styles_by_dirs( + wav_dir=audio_dir, + output_dir=result_dir, + config_path=config_path, + config_output_path=config_path, + ) + return f"成功!\n{result_dir}にスタイルベクトルを保存しました。" + + +how_to_md = f""" +Style-Bert-VITS2でこまかくスタイルを指定して音声合成するには、モデルごとにスタイルベクトルのファイル`style_vectors.npy`を作成する必要があります。 + +ただし、学習の過程では自動的に、平均スタイル「{DEFAULT_STYLE}」と、(**Ver 2.5.0以降からは**)音声をサブフォルダに分けていた場合はそのサブフォルダごとのスタイルが保存されています。 ## 方法 +- 方法0: 音声を作りたいスタイルごとのサブフォルダに分け、そのフォルダごとにスタイルベクトルを作成 - 方法1: 音声ファイルを自動でスタイル別に分け、その各スタイルの平均を取って保存 - 方法2: スタイルを代表する音声ファイルを手動で選んで、その音声のスタイルベクトルを保存 - 方法3: 自分でもっと頑張ってこだわって作る(JVNVコーパスなど、もともとスタイルラベル等が利用可能な場合はこれがよいかも) """ +method0 = """ +音声をスタイルごとにサブフォルダを作り、その中に音声ファイルを入れてください。 + +**注意** + +- Ver 2.5.0以降では、`inputs/`フォルダや`raw/`フォルダにサブディレクトリに分けて音声ファイルを入れるだけで、スタイルベクトルが自動で作成されるので、この手順は不要です。 +- それ未満のバージョンで学習したモデルに新しくスタイルベクトルをつけたい場合や、学習に使ったのとは別の音声でスタイルベクトルを作成したい場合に使います。 +- 学習との整合性のため、もし**現在学習中や、今後学習する予定がある場合は**、音声ファイルは、`Data/{モデル名}/wavs`フォルダではなく**新しい別のディレクトリに保存してください**。 + +例: + +```bash +audio_dir +├── style1 +│ ├── audio1.wav +│ ├── audio2.wav +│ └── ... +├── style2 +│ ├── audio1.wav +│ ├── audio2.wav +│ └── ... +└── ... +``` +""" + method1 = f""" 学習の時に取り出したスタイルベクトルを読み込んで、可視化を見ながらスタイルを分けていきます。 @@ -332,138 +418,168 @@ def create_style_vectors_app(): with gr.Blocks(theme=GRADIO_THEME) as app: with gr.Accordion("使い方", open=False): gr.Markdown(how_to_md) - with gr.Row(): - model_name = gr.Textbox(placeholder="your_model_name", label="モデル名") - reduction_method = gr.Radio( - choices=["UMAP", "t-SNE"], - label="次元削減方法", - info="v 1.3以前はt-SNEでしたがUMAPのほうがよい可能性もあります。", - value="UMAP", + model_name = gr.Textbox(placeholder="your_model_name", label="モデル名") + with gr.Tab("方法0: サブフォルダごとにスタイルベクトルを作成"): + gr.Markdown(method0) + audio_dir = gr.Textbox( + placeholder="path/to/audio_dir", + label="音声が入っているフォルダ", + info="音声ファイルをスタイルごとにサブフォルダに分けて保存してください。", ) - load_button = gr.Button("スタイルベクトルを読み込む", variant="primary") - output = gr.Plot(label="音声スタイルの可視化") - load_button.click(load, inputs=[model_name, reduction_method], outputs=[output]) - with gr.Tab("方法1: スタイル分けを自動で行う"): - with gr.Tab("スタイル分け1"): - n_clusters = gr.Slider( - minimum=2, - maximum=10, - step=1, - value=4, - label="作るスタイルの数(平均スタイルを除く)", - info="上の図を見ながらスタイルの数を試行錯誤してください。", - ) - c_method = gr.Radio( - choices=[ - "Agglomerative after reduction", - "KMeans after reduction", - "Agglomerative", - "KMeans", - ], - label="アルゴリズム", - info="分類する(クラスタリング)アルゴリズムを選択します。いろいろ試してみてください。", - value="Agglomerative after reduction", - ) - c_button = gr.Button("スタイル分けを実行") - with gr.Tab("スタイル分け2: DBSCAN"): - gr.Markdown(dbscan_md) - eps = gr.Slider( - minimum=0.1, - maximum=10, - step=0.01, - value=0.3, - label="eps", - ) - min_samples = gr.Slider( - minimum=1, - maximum=50, - step=1, - value=15, - label="min_samples", - ) - with gr.Row(): - dbscan_button = gr.Button("スタイル分けを実行") - num_styles_result = gr.Textbox(label="スタイル数") - gr.Markdown("スタイル分けの結果") - gr.Markdown( - "注意: もともと256次元なものをを2次元に落としているので、正確なベクトルの位置関係ではありません。" + method0_btn = gr.Button("スタイルベクトルを作成", variant="primary") + method0_info = gr.Textbox(label="結果") + method0_btn.click( + save_style_vectors_by_dirs, + inputs=[model_name, audio_dir], + outputs=[method0_info], ) + with gr.Tab("その他の方法"): with gr.Row(): - gr_plot = gr.Plot() - with gr.Column(): - with gr.Row(): - cluster_index = gr.Slider( - minimum=1, - maximum=MAX_CLUSTER_NUM, - step=1, - value=1, - label="スタイル番号", - info="選択したスタイルの代表音声を表示します。", - ) - num_files = gr.Slider( - minimum=1, - maximum=MAX_AUDIO_NUM, - step=1, - value=5, - label="代表音声の数をいくつ表示するか", - ) - get_audios_button = gr.Button("代表音声を取得") - with gr.Row(): - audio_list = [] - for i in range(MAX_AUDIO_NUM): - audio_list.append(gr.Audio(visible=False, show_label=True)) - c_button.click( - do_clustering_gradio, - inputs=[n_clusters, c_method], - outputs=[gr_plot, cluster_index] + audio_list, - ) - dbscan_button.click( - do_dbscan_gradio, - inputs=[eps, min_samples], - outputs=[gr_plot, cluster_index, num_styles_result] + audio_list, + reduction_method = gr.Radio( + choices=["UMAP", "t-SNE"], + label="次元削減方法", + info="v 1.3以前はt-SNEでしたがUMAPのほうがよい可能性もあります。", + value="UMAP", ) - get_audios_button.click( - representative_wav_files_gradio, - inputs=[cluster_index, num_files], - outputs=audio_list, - ) - gr.Markdown("結果が良さそうなら、これを保存します。") - style_names = gr.Textbox( - "Angry, Sad, Happy", - label="スタイルの名前", - info=f"スタイルの名前を`,`で区切って入力してください(日本語可)。例: `Angry, Sad, Happy`や`怒り, 悲しみ, 喜び`など。平均音声は{DEFAULT_STYLE}として自動的に保存されます。", - ) - with gr.Row(): - save_button1 = gr.Button("スタイルベクトルを保存", variant="primary") - info2 = gr.Textbox(label="保存結果") - - save_button1.click( - save_style_vectors_from_clustering, - inputs=[model_name, style_names], - outputs=[info2], - ) - with gr.Tab("方法2: 手動でスタイルを選ぶ"): - gr.Markdown( - "下のテキスト欄に、各スタイルの代表音声のファイル名を`,`区切りで、その横に対応するスタイル名を`,`区切りで入力してください。" - ) - gr.Markdown("例: `angry.wav, sad.wav, happy.wav`と`Angry, Sad, Happy`") - gr.Markdown( - f"注意: {DEFAULT_STYLE}スタイルは自動的に保存されます、手動では{DEFAULT_STYLE}という名前のスタイルは指定しないでください。" + load_button = gr.Button("スタイルベクトルを読み込む", variant="primary") + output = gr.Plot(label="音声スタイルの可視化") + load_button.click( + load, inputs=[model_name, reduction_method], outputs=[output] ) - with gr.Row(): - audio_files_text = gr.Textbox( - label="音声ファイル名", placeholder="angry.wav, sad.wav, happy.wav" + with gr.Tab("方法1: スタイル分けを自動で行う"): + with gr.Tab("スタイル分け1"): + n_clusters = gr.Slider( + minimum=2, + maximum=10, + step=1, + value=4, + label="作るスタイルの数(平均スタイルを除く)", + info="上の図を見ながらスタイルの数を試行錯誤してください。", + ) + c_method = gr.Radio( + choices=[ + "Agglomerative after reduction", + "KMeans after reduction", + "Agglomerative", + "KMeans", + ], + label="アルゴリズム", + info="分類する(クラスタリング)アルゴリズムを選択します。いろいろ試してみてください。", + value="Agglomerative after reduction", + ) + c_button = gr.Button("スタイル分けを実行") + with gr.Tab("スタイル分け2: DBSCAN"): + gr.Markdown(dbscan_md) + eps = gr.Slider( + minimum=0.1, + maximum=10, + step=0.01, + value=0.3, + label="eps", + ) + min_samples = gr.Slider( + minimum=1, + maximum=50, + step=1, + value=15, + label="min_samples", + ) + with gr.Row(): + dbscan_button = gr.Button("スタイル分けを実行") + num_styles_result = gr.Textbox(label="スタイル数") + gr.Markdown("スタイル分けの結果") + gr.Markdown( + "注意: もともと256次元なものをを2次元に落としているので、正確なベクトルの位置関係ではありません。" ) - style_names_text = gr.Textbox( - label="スタイル名", placeholder="Angry, Sad, Happy" + with gr.Row(): + gr_plot = gr.Plot() + with gr.Column(): + with gr.Row(): + cluster_index = gr.Slider( + minimum=1, + maximum=MAX_CLUSTER_NUM, + step=1, + value=1, + label="スタイル番号", + info="選択したスタイルの代表音声を表示します。", + ) + num_files = gr.Slider( + minimum=1, + maximum=MAX_AUDIO_NUM, + step=1, + value=5, + label="代表音声の数をいくつ表示するか", + ) + get_audios_button = gr.Button("代表音声を取得") + with gr.Row(): + audio_list = [] + for i in range(MAX_AUDIO_NUM): + audio_list.append( + gr.Audio(visible=False, show_label=True) + ) + c_button.click( + do_clustering_gradio, + inputs=[n_clusters, c_method], + outputs=[gr_plot, cluster_index] + audio_list, + ) + dbscan_button.click( + do_dbscan_gradio, + inputs=[eps, min_samples], + outputs=[gr_plot, cluster_index, num_styles_result] + + audio_list, + ) + get_audios_button.click( + representative_wav_files_gradio, + inputs=[cluster_index, num_files], + outputs=audio_list, + ) + gr.Markdown("結果が良さそうなら、これを保存します。") + style_names = gr.Textbox( + "Angry, Sad, Happy", + label="スタイルの名前", + info=f"スタイルの名前を`,`で区切って入力してください(日本語可)。例: `Angry, Sad, Happy`や`怒り, 悲しみ, 喜び`など。平均音声は{DEFAULT_STYLE}として自動的に保存されます。", ) - with gr.Row(): - save_button2 = gr.Button("スタイルベクトルを保存", variant="primary") - info2 = gr.Textbox(label="保存結果") - save_button2.click( - save_style_vectors_from_files, - inputs=[model_name, audio_files_text, style_names_text], + with gr.Row(): + save_button1 = gr.Button( + "スタイルベクトルを保存", variant="primary" + ) + info2 = gr.Textbox(label="保存結果") + + save_button1.click( + save_style_vectors_from_clustering, + inputs=[model_name, style_names], outputs=[info2], ) + with gr.Tab("方法2: 手動でスタイルを選ぶ"): + gr.Markdown( + "下のテキスト欄に、各スタイルの代表音声のファイル名を`,`区切りで、その横に対応するスタイル名を`,`区切りで入力してください。" + ) + gr.Markdown("例: `angry.wav, sad.wav, happy.wav`と`Angry, Sad, Happy`") + gr.Markdown( + f"注意: {DEFAULT_STYLE}スタイルは自動的に保存されます、手動では{DEFAULT_STYLE}という名前のスタイルは指定しないでください。" + ) + with gr.Row(): + audio_files_text = gr.Textbox( + label="音声ファイル名", + placeholder="angry.wav, sad.wav, happy.wav", + ) + style_names_text = gr.Textbox( + label="スタイル名", placeholder="Angry, Sad, Happy" + ) + with gr.Row(): + save_button2 = gr.Button( + "スタイルベクトルを保存", variant="primary" + ) + info2 = gr.Textbox(label="保存結果") + save_button2.click( + save_style_vectors_from_files, + inputs=[model_name, audio_files_text, style_names_text], + outputs=[info2], + ) return app + + +if __name__ == "__main__": + app = create_style_vectors_app() + app.launch(inbrowser=True) diff --git a/gradio_tabs/train.py b/gradio_tabs/train.py index 03efdde579f1869bb3a25bdd8b42e4456f818af2..fbf8522c495d41739098cdf9bff870200a936123 100644 --- a/gradio_tabs/train.py +++ b/gradio_tabs/train.py @@ -1,11 +1,11 @@ import json -import os import shutil import socket import subprocess import sys import time import webbrowser +from dataclasses import dataclass from datetime import datetime from multiprocessing import cpu_count from pathlib import Path @@ -13,6 +13,8 @@ from pathlib import Path import gradio as gr import yaml +from config import get_path_config +from style_bert_vits2.constants import GRADIO_THEME from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from style_bert_vits2.utils.subprocess import run_script_with_log, second_elem_of @@ -21,20 +23,27 @@ from style_bert_vits2.utils.subprocess import run_script_with_log, second_elem_o logger_handler = None tensorboard_executed = False -# Get path settings -with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - dataset_root = Path(path_config["dataset_root"]) +path_config = get_path_config() +dataset_root = path_config.dataset_root -def get_path(model_name: str) -> tuple[Path, Path, Path, Path, Path]: +@dataclass +class PathsForPreprocess: + dataset_path: Path + esd_path: Path + train_path: Path + val_path: Path + config_path: Path + + +def get_path(model_name: str) -> PathsForPreprocess: assert model_name != "", "モデル名は空にできません" dataset_path = dataset_root / model_name - lbl_path = dataset_path / "esd.list" + esd_path = dataset_path / "esd.list" train_path = dataset_path / "train.list" val_path = dataset_path / "val.list" config_path = dataset_path / "config.json" - return dataset_path, lbl_path, train_path, val_path, config_path + return PathsForPreprocess(dataset_path, esd_path, train_path, val_path, config_path) def initialize( @@ -51,14 +60,14 @@ def initialize( log_interval: int, ): global logger_handler - dataset_path, _, train_path, val_path, config_path = get_path(model_name) + paths = get_path(model_name) # 前処理のログをファイルに保存する timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") file_name = f"preprocess_{timestamp}.log" if logger_handler is not None: logger.remove(logger_handler) - logger_handler = logger.add(os.path.join(dataset_path, file_name)) + logger_handler = logger.add(paths.dataset_path / file_name) logger.info( f"Step 1: start initialization...\nmodel_name: {model_name}, batch_size: {batch_size}, epochs: {epochs}, save_every_steps: {save_every_steps}, freeze_ZH_bert: {freeze_ZH_bert}, freeze_JP_bert: {freeze_JP_bert}, freeze_EN_bert: {freeze_EN_bert}, freeze_style: {freeze_style}, freeze_decoder: {freeze_decoder}, use_jp_extra: {use_jp_extra}" @@ -68,11 +77,11 @@ def initialize( "configs/config.json" if not use_jp_extra else "configs/config_jp_extra.json" ) - with open(default_config_path, "r", encoding="utf-8") as f: + with open(default_config_path, encoding="utf-8") as f: config = json.load(f) config["model_name"] = model_name - config["data"]["training_files"] = str(train_path) - config["data"]["validation_files"] = str(val_path) + config["data"]["training_files"] = str(paths.train_path) + config["data"]["validation_files"] = str(paths.val_path) config["train"]["batch_size"] = batch_size config["train"]["epochs"] = epochs config["train"]["eval_interval"] = save_every_steps @@ -89,14 +98,14 @@ def initialize( # 今はデフォルトであるが、以前は非JP-Extra版になくバグの原因になるので念のため config["data"]["use_jp_extra"] = use_jp_extra - model_path = dataset_path / "models" + model_path = paths.dataset_path / "models" if model_path.exists(): logger.warning( f"Step 1: {model_path} already exists, so copy it to backup to {model_path}_backup" ) shutil.copytree( src=model_path, - dst=dataset_path / "models_backup", + dst=paths.dataset_path / "models_backup", dirs_exist_ok=True, ) shutil.rmtree(model_path) @@ -110,14 +119,14 @@ def initialize( logger.error(f"Step 1: {pretrained_dir} folder not found.") return False, f"Step 1, Error: {pretrained_dir}フォルダが見つかりません。" - with open(config_path, "w", encoding="utf-8") as f: + with open(paths.config_path, "w", encoding="utf-8") as f: json.dump(config, f, indent=2, ensure_ascii=False) if not Path("config.yml").exists(): shutil.copy(src="default_config.yml", dst="config.yml") - with open("config.yml", "r", encoding="utf-8") as f: + with open("config.yml", encoding="utf-8") as f: yml_data = yaml.safe_load(f) yml_data["model_name"] = model_name - yml_data["dataset_path"] = str(dataset_path) + yml_data["dataset_path"] = str(paths.dataset_path) with open("config.yml", "w", encoding="utf-8") as f: yaml.dump(yml_data, f, allow_unicode=True) logger.success("Step 1: initialization finished.") @@ -126,7 +135,7 @@ def initialize( def resample(model_name: str, normalize: bool, trim: bool, num_processes: int): logger.info("Step 2: start resampling...") - dataset_path, _, _, _, _ = get_path(model_name) + dataset_path = get_path(model_name).dataset_path input_dir = dataset_path / "raw" output_dir = dataset_path / "wavs" cmd = [ @@ -159,21 +168,24 @@ def preprocess_text( model_name: str, use_jp_extra: bool, val_per_lang: int, yomi_error: str ): logger.info("Step 3: start preprocessing text...") - _, lbl_path, train_path, val_path, config_path = get_path(model_name) - if not lbl_path.exists(): - logger.error(f"Step 3: {lbl_path} not found.") - return False, f"Step 3, Error: 書き起こしファイル {lbl_path} が見つかりません。" + paths = get_path(model_name) + if not paths.esd_path.exists(): + logger.error(f"Step 3: {paths.esd_path} not found.") + return ( + False, + f"Step 3, Error: 書き起こしファイル {paths.esd_path} が見つかりません。", + ) cmd = [ "preprocess_text.py", "--config-path", - str(config_path), + str(paths.config_path), "--transcription-path", - str(lbl_path), + str(paths.esd_path), "--train-path", - str(train_path), + str(paths.train_path), "--val-path", - str(val_path), + str(paths.val_path), "--val-per-lang", str(val_per_lang), "--yomi_error", @@ -201,7 +213,7 @@ def preprocess_text( def bert_gen(model_name: str): logger.info("Step 4: start bert_gen...") - _, _, _, _, config_path = get_path(model_name) + config_path = get_path(model_name).config_path success, message = run_script_with_log( ["bert_gen.py", "--config", str(config_path)] ) @@ -220,7 +232,7 @@ def bert_gen(model_name: str): def style_gen(model_name: str, num_processes: int): logger.info("Step 5: start style_gen...") - _, _, _, _, config_path = get_path(model_name) + config_path = get_path(model_name).config_path success, message = run_script_with_log( [ "style_gen.py", @@ -318,22 +330,31 @@ def train( skip_style: bool = False, use_jp_extra: bool = True, speedup: bool = False, + not_use_custom_batch_sampler: bool = False, ): - dataset_path, _, _, _, config_path = get_path(model_name) + paths = get_path(model_name) # 学習再開の場合を考えて念のためconfig.ymlの名前等を更新 - with open("config.yml", "r", encoding="utf-8") as f: + with open("config.yml", encoding="utf-8") as f: yml_data = yaml.safe_load(f) yml_data["model_name"] = model_name - yml_data["dataset_path"] = str(dataset_path) + yml_data["dataset_path"] = str(paths.dataset_path) with open("config.yml", "w", encoding="utf-8") as f: yaml.dump(yml_data, f, allow_unicode=True) train_py = "train_ms.py" if not use_jp_extra else "train_ms_jp_extra.py" - cmd = [train_py, "--config", str(config_path), "--model", str(dataset_path)] + cmd = [ + train_py, + "--config", + str(paths.config_path), + "--model", + str(paths.dataset_path), + ] if skip_style: cmd.append("--skip_default_style") if speedup: cmd.append("--speedup") + if not_use_custom_batch_sampler: + cmd.append("--not_use_custom_batch_sampler") success, message = run_script_with_log(cmd, ignore_warning=True) if not success: logger.error("Train failed.") @@ -385,6 +406,15 @@ def run_tensorboard(model_name: str): yield gr.Button("Tensorboardを開く") +change_log_md = """ +**Ver 2.5以降の変更点** + +- `raw/`フォルダの中で音声をサブディレクトリに分けて配置することで、自動的にスタイルが作成されるようになりました。詳細は下の「使い方/データの前準備」を参照してください。 +- これまでは1ファイルあたり14秒程度を超えた音声ファイルは学習には用いられていませんでしたが、Ver 2.5以降では「カスタムバッチサンプラーを無効化」にチェックを入れることでその制限が無しに学習できるようになりました(デフォルトはオフ)。ただし: + - 音声ファイルが長い場合の学習効率は悪いかもしれず、挙動も確認していません + - チェックを入れると要求VRAMがかなり増えるようので、学習に失敗したりVRAM不足になる場合は、バッチサイズを小さくするか、チェックを外してください +""" + how_to_md = """ ## 使い方 @@ -396,9 +426,6 @@ how_to_md = """ - 途中から学習を再開する場合は、モデル名を入力してから「学習を開始する」を押せばよいです。 -注意: 標準スタイル以外のスタイルを音声合成で使うには、スタイルベクトルファイル`style_vectors.npy`を作る必要があります。これは、`Style.bat`を実行してそこで作成してください。 -動作は軽いはずなので、学習中でも実行でき、何度でも繰り返して試せます。 - ## JP-Extra版について 元とするモデル構造として [Bert-VITS2 Japanese-Extra](https://github.com/fishaudio/Bert-VITS2/releases/tag/JP-Exta) を使うことができます。 @@ -406,40 +433,60 @@ how_to_md = """ """ prepare_md = """ -まず音声データ(wavファイルで1ファイルが2-12秒程度の、長すぎず短すぎない発話のものをいくつか)と、書き起こしテキストを用意してください。 +まず音声データと、書き起こしテキストを用意してください。 それを次のように配置します。 ``` -├── Data +├── Data/ │ ├── {モデルの名前} │ │ ├── esd.list -│ │ ├── raw -│ │ │ ├── ****.wav -│ │ │ ├── ****.wav -│ │ │ ├── ... +│ │ ├── raw/ +│ │ │ ├── foo.wav +│ │ │ ├── bar.mp3 +│ │ │ ├── style1/ +│ │ │ │ ├── baz.wav +│ │ │ │ ├── qux.wav +│ │ │ ├── style2/ +│ │ │ │ ├── corge.wav +│ │ │ │ ├── grault.wav +... ``` -wavファイル名やモデルの名前は空白を含まない半角で、wavファイルの拡張子は小文字`.wav`である必要があります。 -`raw` フォルダにはすべてのwavファイルを入れ、`esd.list` ファイルには、以下のフォーマットで各wavファイルの情報を記述してください。 +### 配置の仕方 +- 上のように配置すると、`style1/`と`style2/`フォルダの内部(直下以外も含む)に入っている音声ファイルたちから、自動的にデフォルトスタイルに加えて`style1`と`style2`というスタイルが作成されます +- 特にスタイルを作る必要がない場合や、スタイル分類機能等でスタイルを作る場合は、`raw/`フォルダ直下に全てを配置してください。このように`raw/`のサブディレクトリの個数が0または1の場合は、スタイルはデフォルトスタイルのみが作成されます。 +- 音声ファイルのフォーマットはwav形式以外にもmp3等の多くの音声ファイルに対応しています + +### 書き起こしファイル`esd.list` + +`Data/{モデルの名前}/esd.list` ファイルには、以下のフォーマットで各音声ファイルの情報を記述してください。 + + ``` -****.wav|{話者名}|{言語ID、ZHかJPかEN}|{書き起こしテキスト} +path/to/audio.wav(wavファイル以外でもこう書く)|{話者名}|{言語ID、ZHかJPかEN}|{書き起こしテキスト} ``` +- ここで、最初の`path/to/audio.wav`は、`raw/`からの相対パスです。つまり、`raw/foo.wav`の場合は`foo.wav`、`raw/style1/bar.wav`の場合は`style1/bar.wav`となります。 +- 拡張子がwavでない場合でも、`esd.list`には`wav`と書いてください、つまり、`raw/bar.mp3`の場合でも`bar.wav`と書いてください。 + + 例: ``` -wav_number1.wav|hanako|JP|こんにちは、聞こえて、いますか? -wav_next.wav|taro|JP|はい、聞こえています……。 +foo.wav|hanako|JP|こんにちは、元気ですか? +bar.wav|taro|JP|はい、聞こえています……。何か用ですか? +style1/baz.wav|hanako|JP|今日はいい天気ですね。 +style1/qux.wav|taro|JP|はい、そうですね。 +... english_teacher.wav|Mary|EN|How are you? I'm fine, thank you, and you? ... ``` -日本語話者の単一話者データセットでも構いません。 - -- 音声ファイルはrawフォルダの直下でなくてもサブフォルダに入れても構いません。その場合は、`esd.list`の最初には`raw`からの相対パスを記述してください。 +もちろん日本語話者の単一話者データセットでも構いません。 """ def create_train_app(): - with gr.Blocks().queue() as app: + with gr.Blocks(theme=GRADIO_THEME).queue() as app: + gr.Markdown(change_log_md) with gr.Accordion("使い方", open=False): gr.Markdown(how_to_md) with gr.Accordion(label="データの前準備", open=False): @@ -491,7 +538,7 @@ def create_train_app(): ("読めないファイルは使わず続行", "skip"), ("読めないファイルも無理やり読んで学習に使う", "use"), ], - value="raise", + value="skip", ) with gr.Accordion("詳細設定", open=False): num_processes = gr.Slider( @@ -677,6 +724,11 @@ def create_train_app(): label="JP-Extra版を使う", value=True, ) + not_use_custom_batch_sampler = gr.Checkbox( + label="カスタムバッチサンプラーを無効化", + info="VRAMに余裕がある場合にチェックすると、長い音声ファイルも学習に使われるようになります", + value=False, + ) speedup = gr.Checkbox( label="ログ等をスキップして学習を高速化する", value=False, @@ -764,7 +816,13 @@ def create_train_app(): # Train train_btn.click( second_elem_of(train), - inputs=[model_name, skip_style, use_jp_extra_train, speedup], + inputs=[ + model_name, + skip_style, + use_jp_extra_train, + speedup, + not_use_custom_batch_sampler, + ], outputs=[info_train], ) tensorboard_btn.click( @@ -783,3 +841,8 @@ def create_train_app(): ) return app + + +if __name__ == "__main__": + app = create_train_app() + app.launch(inbrowser=True) diff --git a/initialize.py b/initialize.py index bfe59e4b30e8034b4afa455c4fac4af5b056a11b..267465e38d35f1da4ab4838ea7626aeef175fa1c 100644 --- a/initialize.py +++ b/initialize.py @@ -1,5 +1,6 @@ import argparse import json +import shutil from pathlib import Path import yaml @@ -9,7 +10,7 @@ from style_bert_vits2.logging import logger def download_bert_models(): - with open("bert/bert_models.json", "r", encoding="utf-8") as fp: + with open("bert/bert_models.json", encoding="utf-8") as fp: models = json.load(fp) for k, v in models.items(): local_path = Path("bert").joinpath(k) @@ -49,7 +50,7 @@ def download_jp_extra_pretrained_models(): ) -def download_jvnv_models(): +def download_default_models(): files = [ "jvnv-F1-jp/config.json", "jvnv-F1-jp/jvnv-F1-jp_e160_s14000.safetensors", @@ -71,13 +72,33 @@ def download_jvnv_models(): "litagin/style_bert_vits2_jvnv", file, local_dir="model_assets", - local_dir_use_symlinks=False, ) + additional_files = { + "litagin/sbv2_koharune_ami": [ + "koharune-ami/config.json", + "koharune-ami/style_vectors.npy", + "koharune-ami/koharune-ami.safetensors", + ], + "litagin/sbv2_amitaro": [ + "amitaro/config.json", + "amitaro/style_vectors.npy", + "amitaro/amitaro.safetensors", + ], + } + for repo_id, files in additional_files.items(): + for file in files: + if not Path(f"model_assets/{file}").exists(): + logger.info(f"Downloading {file}") + hf_hub_download( + repo_id, + file, + local_dir="model_assets", + ) def main(): parser = argparse.ArgumentParser() - parser.add_argument("--skip_jvnv", action="store_true") + parser.add_argument("--skip_default_models", action="store_true") parser.add_argument("--only_infer", action="store_true") parser.add_argument( "--dataset_root", @@ -95,19 +116,24 @@ def main(): download_bert_models() - if not args.skip_jvnv: - download_jvnv_models() + if not args.skip_default_models: + download_default_models() if not args.only_infer: download_slm_model() download_pretrained_models() download_jp_extra_pretrained_models() + # If configs/paths.yml not exists, create it + default_paths_yml = Path("configs/default_paths.yml") + paths_yml = Path("configs/paths.yml") + if not paths_yml.exists(): + shutil.copy(default_paths_yml, paths_yml) + if args.dataset_root is None and args.assets_root is None: return # Change default paths if necessary - paths_yml = Path("configs/paths.yml") - with open(paths_yml, "r", encoding="utf-8") as f: + with open(paths_yml, encoding="utf-8") as f: yml_data = yaml.safe_load(f) if args.assets_root is not None: yml_data["assets_root"] = args.assets_root diff --git a/library.ipynb b/library.ipynb index 753059a79abe27facfb074c306e5ec88534dc767..8f158eaa706dae3a5aa5ecda7725a701186da2f4 100644 --- a/library.ipynb +++ b/library.ipynb @@ -1,138 +1,135 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Style-Bert-VITS2ライブラリの使用例\n", - "\n", - "`pip install style-bert-vits2`を使った、jupyter notebookでの使用例です。Google colab等でも動きます。" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# PyTorch環境の構築(ない場合)\n", - "# 参照: https://pytorch.org/get-started/locally/\n", - "\n", - "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LLrngKcQEAyP" - }, - "outputs": [], - "source": [ - "# style-bert-vits2のインストール\n", - "\n", - "!pip install style-bert-vits2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9xRtfUg5EZkx" - }, - "outputs": [], - "source": [ - "# BERTモデルをロード(ローカルに手動でダウンロードする必要はありません)\n", - "\n", - "from style_bert_vits2.nlp import bert_models\n", - "from style_bert_vits2.constants import Languages\n", - "\n", - "\n", - "bert_models.load_model(Languages.JP, \"ku-nlp/deberta-v2-large-japanese-char-wwm\")\n", - "bert_models.load_tokenizer(Languages.JP, \"ku-nlp/deberta-v2-large-japanese-char-wwm\")\n", - "# bert_models.load_model(Languages.EN, \"microsoft/deberta-v3-large\")\n", - "# bert_models.load_tokenizer(Languages.EN, \"microsoft/deberta-v3-large\")\n", - "# bert_models.load_model(Languages.ZH, \"hfl/chinese-roberta-wwm-ext-large\")\n", - "# bert_models.load_tokenizer(Languages.ZH, \"hfl/chinese-roberta-wwm-ext-large\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "q2V9d3HyFAr_" - }, - "outputs": [], - "source": [ - "# Hugging Faceから試しにデフォルトモデルをダウンロードしてみて、それを音声合成に使ってみる\n", - "# model_assetsディレクトリにダウンロードされます\n", - "\n", - "from pathlib import Path\n", - "from huggingface_hub import hf_hub_download\n", - "\n", - "\n", - "model_file = \"jvnv-F1-jp/jvnv-F1-jp_e160_s14000.safetensors\"\n", - "config_file = \"jvnv-F1-jp/config.json\"\n", - "style_file = \"jvnv-F1-jp/style_vectors.npy\"\n", - "\n", - "for file in [model_file, config_file, style_file]:\n", - " print(file)\n", - " hf_hub_download(\n", - " \"litagin/style_bert_vits2_jvnv\",\n", - " file,\n", - " local_dir=\"model_assets\"\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hJa31MEUFhe4" - }, - "outputs": [], - "source": [ - "# 上でダウンロードしたモデルファイルを指定して音声合成のテスト\n", - "\n", - "from style_bert_vits2.tts_model import TTSModel\n", - "\n", - "assets_root = Path(\"model_assets\")\n", - "\n", - "model = TTSModel(\n", - " model_path=assets_root / model_file,\n", - " config_path=assets_root / config_file,\n", - " style_vec_path=assets_root / style_file,\n", - " device=\"cpu\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Gal0tqrtGXZx" - }, - "outputs": [], - "source": [ - "from IPython.display import Audio, display\n", - "\n", - "sr, audio = model.infer(text=\"こんにちは\")\n", - "display(Audio(audio, rate=sr))" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Style-Bert-VITS2ライブラリの使用例\n", + "\n", + "`pip install style-bert-vits2`を使った、jupyter notebookでの使用例です。Google colab等でも動きます。" + ] }, - "nbformat": 4, - "nbformat_minor": 0 + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# PyTorch環境の構築(ない場合)\n", + "# 参照: https://pytorch.org/get-started/locally/\n", + "\n", + "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LLrngKcQEAyP" + }, + "outputs": [], + "source": [ + "# style-bert-vits2のインストール\n", + "\n", + "!pip install style-bert-vits2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9xRtfUg5EZkx" + }, + "outputs": [], + "source": [ + "# BERTモデルをロード(ローカルに手動でダウンロードする必要はありません)\n", + "\n", + "from style_bert_vits2.nlp import bert_models\n", + "from style_bert_vits2.constants import Languages\n", + "\n", + "\n", + "bert_models.load_model(Languages.JP, \"ku-nlp/deberta-v2-large-japanese-char-wwm\")\n", + "bert_models.load_tokenizer(Languages.JP, \"ku-nlp/deberta-v2-large-japanese-char-wwm\")\n", + "# bert_models.load_model(Languages.EN, \"microsoft/deberta-v3-large\")\n", + "# bert_models.load_tokenizer(Languages.EN, \"microsoft/deberta-v3-large\")\n", + "# bert_models.load_model(Languages.ZH, \"hfl/chinese-roberta-wwm-ext-large\")\n", + "# bert_models.load_tokenizer(Languages.ZH, \"hfl/chinese-roberta-wwm-ext-large\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "q2V9d3HyFAr_" + }, + "outputs": [], + "source": [ + "# Hugging Faceから試しにデフォルトモデルをダウンロードしてみて、それを音声合成に使ってみる\n", + "# model_assetsディレクトリにダウンロードされます\n", + "\n", + "from pathlib import Path\n", + "from huggingface_hub import hf_hub_download\n", + "\n", + "\n", + "model_file = \"jvnv-F1-jp/jvnv-F1-jp_e160_s14000.safetensors\"\n", + "config_file = \"jvnv-F1-jp/config.json\"\n", + "style_file = \"jvnv-F1-jp/style_vectors.npy\"\n", + "\n", + "for file in [model_file, config_file, style_file]:\n", + " print(file)\n", + " hf_hub_download(\"litagin/style_bert_vits2_jvnv\", file, local_dir=\"model_assets\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hJa31MEUFhe4" + }, + "outputs": [], + "source": [ + "# 上でダウンロードしたモデルファイルを指定して音声合成のテスト\n", + "\n", + "from style_bert_vits2.tts_model import TTSModel\n", + "\n", + "assets_root = Path(\"model_assets\")\n", + "\n", + "model = TTSModel(\n", + " model_path=assets_root / model_file,\n", + " config_path=assets_root / config_file,\n", + " style_vec_path=assets_root / style_file,\n", + " device=\"cpu\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Gal0tqrtGXZx" + }, + "outputs": [], + "source": [ + "from IPython.display import Audio, display\n", + "\n", + "sr, audio = model.infer(text=\"こんにちは\")\n", + "display(Audio(audio, rate=sr))" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/preprocess_text.py b/preprocess_text.py index a65a1d6bee219504d50f8f7f71fe9f0ffa6a9b50..4dd7e33e0236e732bb134920265fdcc12a7ee2b2 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -2,12 +2,12 @@ import argparse import json from collections import defaultdict from pathlib import Path -from random import shuffle +from random import sample, shuffle from typing import Optional from tqdm import tqdm -from config import Preprocess_text_config, config +from config import get_config from style_bert_vits2.logging import logger from style_bert_vits2.nlp import clean_text from style_bert_vits2.nlp.japanese import pyopenjtalk_worker @@ -22,7 +22,7 @@ pyopenjtalk_worker.initialize_worker() update_dict() -preprocess_text_config: Preprocess_text_config = config.preprocess_text_config +preprocess_text_config = get_config().preprocess_text_config # Count lines for tqdm @@ -145,7 +145,7 @@ def preprocess( spk_utt_map[spk].append(line) # 新しい話者が出てきたら話者IDを割り当て、current_sidを1増やす - if spk not in spk_id_map.keys(): + if spk not in spk_id_map: spk_id_map[spk] = current_sid current_sid += 1 if count_same > 0 or count_not_found > 0: @@ -156,16 +156,26 @@ def preprocess( train_list: list[str] = [] val_list: list[str] = [] - # 各話者ごとにシャッフルして、val_per_lang個をval_listに、残りをtrain_listに追加 + # 各話者ごとに発話リストを処理 for spk, utts in spk_utt_map.items(): - shuffle(utts) - val_list += utts[:val_per_lang] - train_list += utts[val_per_lang:] - - shuffle(val_list) + if val_per_lang == 0: + train_list.extend(utts) + continue + # ランダムにval_per_lang個のインデックスを選択 + val_indices = set(sample(range(len(utts)), val_per_lang)) + # 元の順序を保ちながらリストを分割 + for index, utt in enumerate(utts): + if index in val_indices: + val_list.append(utt) + else: + train_list.append(utt) + + # バリデーションリストのサイズ調整 if len(val_list) > max_val_total: - train_list += val_list[max_val_total:] + extra_val = val_list[max_val_total:] val_list = val_list[:max_val_total] + # 余剰のバリデーション発話をトレーニングリストに追加(元の順序を保持) + train_list.extend(extra_val) with train_path.open("w", encoding="utf-8") as f: for line in train_list: diff --git a/pyproject.toml b/pyproject.toml index 2d26680f1ef7311472139042d2e8160539156626..45fbbcb7e60ff266245cbbecf2bcb93f0971ff55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "style-bert-vits2" dynamic = ["version"] -description = 'Style-Bert-VITS2: Bert-VITS2 with more controllable voice styles.' +description = "Style-Bert-VITS2: Bert-VITS2 with more controllable voice styles." readme = "README.md" requires-python = ">=3.9" license = "AGPL-3.0" @@ -22,25 +22,21 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", ] dependencies = [ - 'cmudict', - 'cn2an', - 'g2p_en', - 'gradio', - 'jieba', - 'librosa==0.9.2', - 'loguru', - 'num2words', - 'numba', - 'numpy', - 'pyannote.audio>=3.1.0', - 'pydantic>=2.0', - 'pyopenjtalk-dict', - 'pypinyin', - 'pyworld-prebuilt', - 'safetensors', - 'scipy', - 'torch>=2.1', - 'transformers', + "cmudict", + "cn2an", + "g2p_en", + "jieba", + "loguru", + "num2words", + "numba", + "numpy", + "pydantic>=2.0", + "pyopenjtalk-dict", + "pypinyin", + "pyworld-prebuilt", + "safetensors", + "torch>=2.1", + "transformers", ] [project.urls] @@ -63,42 +59,26 @@ only-include = [ "pyproject.toml", "README.md", ] -exclude = [ - ".git", - ".gitignore", - ".gitattributes", -] +exclude = [".git", ".gitignore", ".gitattributes"] [tool.hatch.build.targets.wheel] packages = ["style_bert_vits2"] [tool.hatch.envs.test] -dependencies = [ - "coverage[toml]>=6.5", - "pytest", -] +dependencies = ["coverage[toml]>=6.5", "pytest"] [tool.hatch.envs.test.scripts] # Usage: `hatch run test:test` test = "pytest {args:tests}" # Usage: `hatch run test:coverage` test-cov = "coverage run -m pytest {args:tests}" # Usage: `hatch run test:cov-report` -cov-report = [ - "- coverage combine", - "coverage report", -] +cov-report = ["- coverage combine", "coverage report"] # Usage: `hatch run test:cov` -cov = [ - "test-cov", - "cov-report", -] +cov = ["test-cov", "cov-report"] [tool.hatch.envs.style] detached = true -dependencies = [ - "black", - "isort", -] +dependencies = ["black[jupyter]", "isort"] [tool.hatch.envs.style.scripts] check = [ "black --check --diff .", @@ -117,17 +97,17 @@ python = ["3.9", "3.10", "3.11"] source_pkgs = ["style_bert_vits2", "tests"] branch = true parallel = true -omit = [ - "style_bert_vits2/constants.py", -] +omit = ["style_bert_vits2/constants.py"] [tool.coverage.paths] style_bert_vits2 = ["style_bert_vits2", "*/style-bert-vits2/style_bert_vits2"] tests = ["tests", "*/style-bert-vits2/tests"] [tool.coverage.report] -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + +[tool.ruff] +extend-select = ["I"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 \ No newline at end of file diff --git a/requirements-colab.txt b/requirements-colab.txt new file mode 100644 index 0000000000000000000000000000000000000000..b34b256318a935584923ea68254bb49ee38571f4 --- /dev/null +++ b/requirements-colab.txt @@ -0,0 +1,20 @@ +cmudict +cn2an +g2p_en +gradio>=4.32 +jieba +librosa==0.9.2 +loguru +num2words +numpy<2 +onnxruntime +pyannote.audio>=3.1.0 +pyloudnorm +pyopenjtalk-dict +pypinyin +pyworld-prebuilt +torch +torchaudio +torchvision +transformers +umap-learn diff --git a/requirements-infer.txt b/requirements-infer.txt new file mode 100644 index 0000000000000000000000000000000000000000..336b64ff5fcaa219bbfbf24dc84d51ac81870df4 --- /dev/null +++ b/requirements-infer.txt @@ -0,0 +1,24 @@ +cmudict +cn2an +# faster-whisper==0.10.1 +g2p_en +GPUtil +gradio +jieba +# librosa==0.9.2 +loguru +num2words +numpy<2 +# protobuf==4.25 +psutil +# punctuators +pyannote.audio>=3.1.0 +# pyloudnorm +pyopenjtalk-dict +pypinyin +pyworld-prebuilt +# stable_ts +# tensorboard +torch<2.4 +transformers +umap-learn diff --git a/requirements.txt b/requirements.txt index 114e97202038bab593e434c45fb3248938dc0b14..1d54159fdfb778493bfa8a510a2fc3ad80666fe2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,28 +3,23 @@ cn2an faster-whisper==0.10.1 g2p_en GPUtil -gradio==4.23.0 +gradio>=4.32 jieba -langid librosa==0.9.2 loguru -matplotlib num2words -numba -numpy +numpy<2 +protobuf==4.25 psutil +punctuators pyannote.audio>=3.1.0 -pydantic>=2.0 pyloudnorm -# pyopenjtalk-prebuilt # Should be manually uninstalled pyopenjtalk-dict pypinyin pyworld-prebuilt -PyYAML -requests -safetensors -scipy +stable_ts tensorboard -torch>=2.1 +torch<2.4 +torchaudio<2.4 transformers umap-learn diff --git a/resample.py b/resample.py index 55f0f7f94ec14a79eeb6eefc897490d55d09a975..285105c46485c18171252b0ce13fd3506579eb01 100644 --- a/resample.py +++ b/resample.py @@ -10,7 +10,7 @@ import soundfile from numpy.typing import NDArray from tqdm import tqdm -from config import config +from config import get_config from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT @@ -62,6 +62,7 @@ def resample( if trim: wav, _ = librosa.effects.trim(wav, top_db=30) relative_path = file.relative_to(input_dir) + # ここで拡張子が.wav以外でも.wavに置き換えられる output_path = output_dir / relative_path.with_suffix(".wav") output_path.parent.mkdir(parents=True, exist_ok=True) soundfile.write(output_path, wav, sr) @@ -70,6 +71,7 @@ def resample( if __name__ == "__main__": + config = get_config() parser = argparse.ArgumentParser() parser.add_argument( "--sr", diff --git a/scripts/Install-Style-Bert-VITS2-CPU.bat b/scripts/Install-Style-Bert-VITS2-CPU.bat index b62655ac574f933a9c2270dc51ba5ded02f0b6a4..8353a05f4478979a6046c0f632da3644ea36bd8c 100644 --- a/scripts/Install-Style-Bert-VITS2-CPU.bat +++ b/scripts/Install-Style-Bert-VITS2-CPU.bat @@ -1,123 +1,134 @@ -chcp 65001 > NUL -@echo off - -@REM エラーコードを遅延評価するために設定 -setlocal enabledelayedexpansion - -@REM PowerShellのコマンド -set PS_CMD=PowerShell -Version 5.1 -ExecutionPolicy Bypass - -@REM PortableGitのURLと保存先 -set DL_URL=https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/PortableGit-2.44.0-64-bit.7z.exe -set DL_DST=%~dp0lib\PortableGit-2.44.0-64-bit.7z.exe - -@REM Style-Bert-VITS2のリポジトリURL -set REPO_URL=https://github.com/litagin02/Style-Bert-VITS2 - -@REM カレントディレクトリをbatファイルのディレクトリに変更 -pushd %~dp0 - -@REM lib フォルダがなければ作成 -if not exist lib\ ( mkdir lib ) - -echo -------------------------------------------------- -echo PS_CMD: %PS_CMD% -echo DL_URL: %DL_URL% -echo DL_DST: %DL_DST% -echo REPO_URL: %REPO_URL% -echo -------------------------------------------------- -echo. -echo -------------------------------------------------- -echo Checking Git Installation... -echo -------------------------------------------------- -echo Executing: git --version -git --version -if !errorlevel! neq 0 ( - echo -------------------------------------------------- - echo Git is not installed, so download and use PortableGit. - echo Downloading PortableGit... - echo -------------------------------------------------- - echo Executing: curl -L %DL_URL% -o "%DL_DST%" - curl -L %DL_URL% -o "%DL_DST%" - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Extracting PortableGit... - echo -------------------------------------------------- - echo Executing: "%DL_DST%" -y - "%DL_DST%" -y - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Removing %DL_DST%... - echo -------------------------------------------------- - echo Executing: del "%DL_DST%" - del "%DL_DST%" - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - - @REM Gitコマンドのパスを設定 - echo -------------------------------------------------- - echo Setting up PATH... - echo -------------------------------------------------- - echo Executing: set "PATH=%~dp0lib\PortableGit\bin;%PATH%" - set "PATH=%~dp0lib\PortableGit\bin;%PATH%" - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Checking Git Installation... - echo -------------------------------------------------- - echo Executing: git --version - git --version - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) -) - -echo -------------------------------------------------- -echo Cloning repository... -echo -------------------------------------------------- -echo Executing: git clone %REPO_URL% -git clone %REPO_URL% -if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - -@REM Pythonのセットアップ、仮想環境が有効化されて戻って来る -echo -------------------------------------------------- -echo Setting up Python environment... -echo -------------------------------------------------- -echo Executing: call Setup-Python.bat ".\lib\python" ".\Style-Bert-VITS2\venv" -call Setup-Python.bat ".\lib\python" ".\Style-Bert-VITS2\venv" -if !errorlevel! neq 0 ( popd & exit /b !errorlevel! ) - -@REM Style-Bert-VITS2フォルダに移動 -pushd Style-Bert-VITS2 - -echo -------------------------------------------------- -echo Activating the virtual environment... -echo -------------------------------------------------- -echo Executing: call ".\venv\Scripts\activate.bat" -call ".\venv\Scripts\activate.bat" -if !errorlevel! neq 0 ( popd & exit /b !errorlevel! ) - -echo -------------------------------------------------- -echo Installing dependencies... -echo -------------------------------------------------- -echo Executing: pip install -r requirements.txt -pip install -r requirements.txt -if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - -echo ---------------------------------------- -echo Environment setup is complete. Start downloading the model. -echo ---------------------------------------- -echo Executing: python initialize.py -python initialize.py --only_infer - -echo ---------------------------------------- -echo Model download is complete. Start Style-Bert-VITS2 Editor. -echo ---------------------------------------- -echo Executing: python server_editor.py --inbrowser -python server_editor.py --inbrowser -pause - -popd - -popd - -endlocal +chcp 65001 > NUL +@echo off + +@REM エラーコードを遅延評価するために設定 +setlocal enabledelayedexpansion + +@REM PowerShellのコマンド +set PS_CMD=PowerShell -Version 5.1 -ExecutionPolicy Bypass + +@REM PortableGitのURLと保存先 +set DL_URL=https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/PortableGit-2.44.0-64-bit.7z.exe +set DL_DST=%~dp0lib\PortableGit-2.44.0-64-bit.7z.exe + +@REM Style-Bert-VITS2のリポジトリURL +set REPO_URL=https://github.com/litagin02/Style-Bert-VITS2 + +@REM カレントディレクトリをbatファイルのディレクトリに変更 +pushd %~dp0 + +@REM lib フォルダがなければ作成 +if not exist lib\ ( mkdir lib ) + +echo -------------------------------------------------- +echo PS_CMD: %PS_CMD% +echo DL_URL: %DL_URL% +echo DL_DST: %DL_DST% +echo REPO_URL: %REPO_URL% +echo -------------------------------------------------- +echo. +echo -------------------------------------------------- +echo Checking Git Installation... +echo -------------------------------------------------- +echo Executing: git --version +git --version +if !errorlevel! neq 0 ( + echo -------------------------------------------------- + echo Git is not installed, so download and use PortableGit. + echo Downloading PortableGit... + echo -------------------------------------------------- + echo Executing: curl -L %DL_URL% -o "%DL_DST%" + curl -L %DL_URL% -o "%DL_DST%" + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Extracting PortableGit... + echo -------------------------------------------------- + echo Executing: "%DL_DST%" -y + "%DL_DST%" -y + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Removing %DL_DST%... + echo -------------------------------------------------- + echo Executing: del "%DL_DST%" + del "%DL_DST%" + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + + @REM Gitコマンドのパスを設定 + echo -------------------------------------------------- + echo Setting up PATH... + echo -------------------------------------------------- + echo Executing: set "PATH=%~dp0lib\PortableGit\bin;%PATH%" + set "PATH=%~dp0lib\PortableGit\bin;%PATH%" + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Checking Git Installation... + echo -------------------------------------------------- + echo Executing: git --version + git --version + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) +) + +echo -------------------------------------------------- +echo Cloning repository... +echo -------------------------------------------------- +echo Executing: git clone %REPO_URL% +git clone %REPO_URL% +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +@REM Pythonのセットアップ、仮想環境が有効化されて戻って来る +echo -------------------------------------------------- +echo Setting up Python environment... +echo -------------------------------------------------- +echo Executing: call Setup-Python.bat ".\lib\python" ".\Style-Bert-VITS2\venv" +call Setup-Python.bat ".\lib\python" ".\Style-Bert-VITS2\venv" +if !errorlevel! neq 0 ( popd & exit /b !errorlevel! ) + +@REM Style-Bert-VITS2フォルダに移動 +pushd Style-Bert-VITS2 + +@REM 後で消す!!!!!!!!!! +@REM git checkout dev +@REM 後で消す!!!!!!!!!! + +echo -------------------------------------------------- +echo Activating the virtual environment... +echo -------------------------------------------------- +echo Executing: call ".\venv\Scripts\activate.bat" +call ".\venv\Scripts\activate.bat" +if !errorlevel! neq 0 ( popd & exit /b !errorlevel! ) + +echo -------------------------------------------------- +echo Installing package manager uv... +echo -------------------------------------------------- +echo Executing: pip install uv +pip install uv +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +echo -------------------------------------------------- +echo Installing dependencies... +echo -------------------------------------------------- +echo Executing: uv pip install -r requirements-infer.txt +uv pip install -r requirements-infer.txt +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +echo ---------------------------------------- +echo Environment setup is complete. Start downloading the model. +echo ---------------------------------------- +echo Executing: python initialize.py +python initialize.py --only_infer + +echo ---------------------------------------- +echo Model download is complete. Start Style-Bert-VITS2 Editor. +echo ---------------------------------------- +echo Executing: python server_editor.py --inbrowser +python server_editor.py --inbrowser +pause + +popd + +popd + +endlocal diff --git a/scripts/Install-Style-Bert-VITS2.bat b/scripts/Install-Style-Bert-VITS2.bat index 35ce45c0c357c118fde6ef3d19f601dba9b4dbb4..5b6715ecab0c2b66cd5be2409be04913d1f6e95a 100644 --- a/scripts/Install-Style-Bert-VITS2.bat +++ b/scripts/Install-Style-Bert-VITS2.bat @@ -1,130 +1,141 @@ -chcp 65001 > NUL -@echo off - -@REM エラーコードを遅延評価するために設定 -setlocal enabledelayedexpansion - -@REM PowerShellのコマンド -set PS_CMD=PowerShell -Version 5.1 -ExecutionPolicy Bypass - -@REM PortableGitのURLと保存先 -set DL_URL=https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/PortableGit-2.44.0-64-bit.7z.exe -set DL_DST=%~dp0lib\PortableGit-2.44.0-64-bit.7z.exe - -@REM Style-Bert-VITS2のリポジトリURL -set REPO_URL=https://github.com/litagin02/Style-Bert-VITS2 - -@REM カレントディレクトリをbatファイルのディレクトリに変更 -pushd %~dp0 - -@REM lib フォルダがなければ作成 -if not exist lib\ ( mkdir lib ) - -echo -------------------------------------------------- -echo PS_CMD: %PS_CMD% -echo DL_URL: %DL_URL% -echo DL_DST: %DL_DST% -echo REPO_URL: %REPO_URL% -echo -------------------------------------------------- -echo. -echo -------------------------------------------------- -echo Checking Git Installation... -echo -------------------------------------------------- -echo Executing: git --version -git --version -if !errorlevel! neq 0 ( - echo -------------------------------------------------- - echo Git is not installed, so download and use PortableGit. - echo Downloading PortableGit... - echo -------------------------------------------------- - echo Executing: curl -L %DL_URL% -o "%DL_DST%" - curl -L %DL_URL% -o "%DL_DST%" - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Extracting PortableGit... - echo -------------------------------------------------- - echo Executing: "%DL_DST%" -y - "%DL_DST%" -y - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Removing %DL_DST%... - echo -------------------------------------------------- - echo Executing: del "%DL_DST%" - del "%DL_DST%" - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - - @REM Gitコマンドのパスを設定 - echo -------------------------------------------------- - echo Setting up PATH... - echo -------------------------------------------------- - echo Executing: set "PATH=%~dp0lib\PortableGit\bin;%PATH%" - set "PATH=%~dp0lib\PortableGit\bin;%PATH%" - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Checking Git Installation... - echo -------------------------------------------------- - echo Executing: git --version - git --version - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) -) - -echo -------------------------------------------------- -echo Cloning repository... -echo -------------------------------------------------- -echo Executing: git clone %REPO_URL% -git clone %REPO_URL% -if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - -@REM Pythonのセットアップ -echo -------------------------------------------------- -echo Setting up Python environment... -echo -------------------------------------------------- -echo Executing: call Setup-Python.bat ".\lib\python" ".\Style-Bert-VITS2\venv" -call Setup-Python.bat ".\lib\python" ".\Style-Bert-VITS2\venv" -if !errorlevel! neq 0 ( popd & exit /b !errorlevel! ) - -@REM Style-Bert-VITS2フォルダに移動 -pushd Style-Bert-VITS2 - -echo -------------------------------------------------- -echo Activating the virtual environment... -echo -------------------------------------------------- -echo Executing: call ".\venv\Scripts\activate.bat" -call ".\venv\Scripts\activate.bat" -if !errorlevel! neq 0 ( popd & exit /b !errorlevel! ) - -echo -------------------------------------------------- -echo Installing PyTorch... -echo -------------------------------------------------- -echo Executing: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - -echo -------------------------------------------------- -echo Installing other dependencies... -echo -------------------------------------------------- -echo Executing: pip install -r requirements.txt -pip install -r requirements.txt -if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - -echo ---------------------------------------- -echo Environment setup is complete. Start downloading the model. -echo ---------------------------------------- -echo Executing: python initialize.py -python initialize.py - -echo ---------------------------------------- -echo Model download is complete. Start Style-Bert-VITS2 Editor. -echo ---------------------------------------- -echo Executing: python server_editor.py --inbrowser -python server_editor.py --inbrowser -pause - -popd - -popd - -endlocal +chcp 65001 > NUL +@echo off + +@REM エラーコードを遅延評価するために設定 +setlocal enabledelayedexpansion + +@REM PowerShellのコマンド +set PS_CMD=PowerShell -Version 5.1 -ExecutionPolicy Bypass + +@REM PortableGitのURLと保存先 +set DL_URL=https://github.com/git-for-windows/git/releases/download/v2.44.0.windows.1/PortableGit-2.44.0-64-bit.7z.exe +set DL_DST=%~dp0lib\PortableGit-2.44.0-64-bit.7z.exe + +@REM Style-Bert-VITS2のリポジトリURL +set REPO_URL=https://github.com/litagin02/Style-Bert-VITS2 + +@REM カレントディレクトリをbatファイルのディレクトリに変更 +pushd %~dp0 + +@REM lib フォルダがなければ作成 +if not exist lib\ ( mkdir lib ) + +echo -------------------------------------------------- +echo PS_CMD: %PS_CMD% +echo DL_URL: %DL_URL% +echo DL_DST: %DL_DST% +echo REPO_URL: %REPO_URL% +echo -------------------------------------------------- +echo. +echo -------------------------------------------------- +echo Checking Git Installation... +echo -------------------------------------------------- +echo Executing: git --version +git --version +if !errorlevel! neq 0 ( + echo -------------------------------------------------- + echo Git is not installed, so download and use PortableGit. + echo Downloading PortableGit... + echo -------------------------------------------------- + echo Executing: curl -L %DL_URL% -o "%DL_DST%" + curl -L %DL_URL% -o "%DL_DST%" + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Extracting PortableGit... + echo -------------------------------------------------- + echo Executing: "%DL_DST%" -y + "%DL_DST%" -y + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Removing %DL_DST%... + echo -------------------------------------------------- + echo Executing: del "%DL_DST%" + del "%DL_DST%" + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + + @REM Gitコマンドのパスを設定 + echo -------------------------------------------------- + echo Setting up PATH... + echo -------------------------------------------------- + echo Executing: set "PATH=%~dp0lib\PortableGit\bin;%PATH%" + set "PATH=%~dp0lib\PortableGit\bin;%PATH%" + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Checking Git Installation... + echo -------------------------------------------------- + echo Executing: git --version + git --version + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) +) + +echo -------------------------------------------------- +echo Cloning repository... +echo -------------------------------------------------- +echo Executing: git clone %REPO_URL% +git clone %REPO_URL% +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +@REM Pythonのセットアップ +echo -------------------------------------------------- +echo Setting up Python environment... +echo -------------------------------------------------- +echo Executing: call Setup-Python.bat ".\lib\python" ".\Style-Bert-VITS2\venv" +call Setup-Python.bat ".\lib\python" ".\Style-Bert-VITS2\venv" +if !errorlevel! neq 0 ( popd & exit /b !errorlevel! ) + +@REM Style-Bert-VITS2フォルダに移動 +pushd Style-Bert-VITS2 + +@REM 後で消す!!!!!!!!!! +@REM git checkout dev +@REM 後で消す!!!!!!!!!! + +echo -------------------------------------------------- +echo Activating the virtual environment... +echo -------------------------------------------------- +echo Executing: call ".\venv\Scripts\activate.bat" +call ".\venv\Scripts\activate.bat" +if !errorlevel! neq 0 ( popd & exit /b !errorlevel! ) + +echo -------------------------------------------------- +echo Installing package manager uv... +echo -------------------------------------------------- +echo Executing: pip install uv +pip install uv +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +echo -------------------------------------------------- +echo Installing PyTorch... +echo -------------------------------------------------- +echo Executing: uv pip install "torch<2.4" "torchaudio<2.4" --index-url https://download.pytorch.org/whl/cu118 +uv pip install "torch<2.4" "torchaudio<2.4" --index-url https://download.pytorch.org/whl/cu118 +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +echo -------------------------------------------------- +echo Installing other dependencies... +echo -------------------------------------------------- +echo Executing: uv pip install -r requirements.txt +uv pip install -r requirements.txt +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +echo ---------------------------------------- +echo Environment setup is complete. Start downloading the model. +echo ---------------------------------------- +echo Executing: python initialize.py +python initialize.py + +echo ---------------------------------------- +echo Model download is complete. Start Style-Bert-VITS2 Editor. +echo ---------------------------------------- +echo Executing: python server_editor.py --inbrowser +python server_editor.py --inbrowser +pause + +popd + +popd + +endlocal diff --git a/scripts/Setup-Python.bat b/scripts/Setup-Python.bat index 27ca69ec3b39316392cbf6a9fc76176a0298a14e..bedabfaf8c44b5605927b7d1a90e7684dd363899 100644 --- a/scripts/Setup-Python.bat +++ b/scripts/Setup-Python.bat @@ -1,115 +1,101 @@ -chcp 65001 > NUL - -@REM https://github.com/Zuntan03/EasyBertVits2 より引用・改変 - -@REM エラーコードを遅延評価するために設定 -setlocal enabledelayedexpansion - -@echo off -set PS_CMD=PowerShell -Version 5.1 -ExecutionPolicy Bypass -set CURL_CMD=C:\Windows\System32\curl.exe - -if not exist %CURL_CMD% ( - echo [ERROR] %CURL_CMD% が見つかりません。 - pause & exit /b 1 -) - -if "%1" neq "" ( - set PYTHON_DIR=%~dp0%~1 -) else ( - set PYTHON_DIR=%~dp0python -) -set PYTHON_CMD=%PYTHON_DIR%\python.exe - -if "%2" neq "" ( - set VENV_DIR=%~dp0%~2 -) else ( - set VENV_DIR=%~dp0venv -) - -echo -------------------------------------------------- -echo PS_CMD: %PS_CMD% -echo CURL_CMD: %CURL_CMD% -echo PYTHON_CMD: %PYTHON_CMD% -echo PYTHON_DIR: %PYTHON_DIR% -echo VENV_DIR: %VENV_DIR% -echo -------------------------------------------------- -echo. - -if not exist "%PYTHON_DIR%"\ ( - echo -------------------------------------------------- - echo Downloading Python... - echo -------------------------------------------------- - echo Executing: %CURL_CMD% -o python.zip https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip - %CURL_CMD% -o python.zip https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip - if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Extracting zip... - echo -------------------------------------------------- - echo Executing: %PS_CMD% Expand-Archive -Path python.zip -DestinationPath \"%PYTHON_DIR%\" - %PS_CMD% Expand-Archive -Path python.zip -DestinationPath \"%PYTHON_DIR%\" - if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Removing python.zip... - echo -------------------------------------------------- - echo Executing: del python.zip - del python.zip - if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Enabling 'site' module in the embedded Python environment... - echo -------------------------------------------------- - echo Executing: %PS_CMD% "&{(Get-Content '%PYTHON_DIR%/python310._pth') -creplace '#import site', 'import site' | Set-Content '%PYTHON_DIR%/python310._pth' }" - %PS_CMD% "&{(Get-Content '%PYTHON_DIR%/python310._pth') -creplace '#import site', 'import site' | Set-Content '%PYTHON_DIR%/python310._pth' }" - if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Installing pip and virtualenv... - echo -------------------------------------------------- - echo Executing: %CURL_CMD% -o "%PYTHON_DIR%\get-pip.py" https://bootstrap.pypa.io/get-pip.py - %CURL_CMD% -o "%PYTHON_DIR%\get-pip.py" https://bootstrap.pypa.io/get-pip.py - if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Installing pip... - echo -------------------------------------------------- - echo Executing: "%PYTHON_CMD%" "%PYTHON_DIR%\get-pip.py" --no-warn-script-location - "%PYTHON_CMD%" "%PYTHON_DIR%\get-pip.py" --no-warn-script-location - if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) - - echo -------------------------------------------------- - echo Installing virtualenv... - echo -------------------------------------------------- - echo Executing: "%PYTHON_CMD%" -m pip install virtualenv --no-warn-script-location - "%PYTHON_CMD%" -m pip install virtualenv --no-warn-script-location - if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) -) - -if not exist %VENV_DIR%\ ( - echo -------------------------------------------------- - echo Creating virtual environment... - echo -------------------------------------------------- - echo Executing: "%PYTHON_CMD%" -m virtualenv --copies "%VENV_DIR%" - "%PYTHON_CMD%" -m virtualenv --copies "%VENV_DIR%" - if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) -) - -echo -------------------------------------------------- -echo Activating virtual environment... -echo -------------------------------------------------- -echo Executing: call "%VENV_DIR%\Scripts\activate.bat" -call "%VENV_DIR%\Scripts\activate.bat" -if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) - -echo -------------------------------------------------- -echo Upgrading pip... -echo -------------------------------------------------- -echo Executing: python -m pip install --upgrade pip -python -m pip install --upgrade pip -if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) - -echo -------------------------------------------------- -echo Completed. -echo -------------------------------------------------- +chcp 65001 > NUL + +@REM https://github.com/Zuntan03/EasyBertVits2 より引用・改変 + +@REM エラーコードを遅延評価するために設定 +setlocal enabledelayedexpansion + +@echo off +set PS_CMD=PowerShell -Version 5.1 -ExecutionPolicy Bypass +set CURL_CMD=C:\Windows\System32\curl.exe + +if not exist %CURL_CMD% ( + echo [ERROR] %CURL_CMD% が見つかりません。 + pause & exit /b 1 +) + +if "%1" neq "" ( + set PYTHON_DIR=%~dp0%~1 +) else ( + set PYTHON_DIR=%~dp0python +) +set PYTHON_CMD=%PYTHON_DIR%\python.exe + +if "%2" neq "" ( + set VENV_DIR=%~dp0%~2 +) else ( + set VENV_DIR=%~dp0venv +) + +echo -------------------------------------------------- +echo PS_CMD: %PS_CMD% +echo CURL_CMD: %CURL_CMD% +echo PYTHON_CMD: %PYTHON_CMD% +echo PYTHON_DIR: %PYTHON_DIR% +echo VENV_DIR: %VENV_DIR% +echo -------------------------------------------------- +echo. + +if not exist "%PYTHON_DIR%"\ ( + echo -------------------------------------------------- + echo Downloading Python... + echo -------------------------------------------------- + echo Executing: %CURL_CMD% -o python.zip https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip + %CURL_CMD% -o python.zip https://www.python.org/ftp/python/3.10.11/python-3.10.11-embed-amd64.zip + if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Extracting zip... + echo -------------------------------------------------- + echo Executing: %PS_CMD% Expand-Archive -Path python.zip -DestinationPath \"%PYTHON_DIR%\" + %PS_CMD% Expand-Archive -Path python.zip -DestinationPath \"%PYTHON_DIR%\" + if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Removing python.zip... + echo -------------------------------------------------- + echo Executing: del python.zip + del python.zip + if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Enabling 'site' module in the embedded Python environment... + echo -------------------------------------------------- + echo Executing: %PS_CMD% "&{(Get-Content '%PYTHON_DIR%/python310._pth') -creplace '#import site', 'import site' | Set-Content '%PYTHON_DIR%/python310._pth' }" + %PS_CMD% "&{(Get-Content '%PYTHON_DIR%/python310._pth') -creplace '#import site', 'import site' | Set-Content '%PYTHON_DIR%/python310._pth' }" + if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Downloading get-pip.py... + echo -------------------------------------------------- + echo Executing: %CURL_CMD% -o "%PYTHON_DIR%\get-pip.py" https://bootstrap.pypa.io/get-pip.py + %CURL_CMD% -o "%PYTHON_DIR%\get-pip.py" https://bootstrap.pypa.io/get-pip.py + if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Installing pip... + echo -------------------------------------------------- + echo Executing: "%PYTHON_CMD%" "%PYTHON_DIR%\get-pip.py" --no-warn-script-location + "%PYTHON_CMD%" "%PYTHON_DIR%\get-pip.py" --no-warn-script-location + if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) + + echo -------------------------------------------------- + echo Installing virtualenv... + echo -------------------------------------------------- + echo Executing: "%PYTHON_CMD%" -m pip install virtualenv --no-warn-script-location + "%PYTHON_CMD%" -m pip install virtualenv --no-warn-script-location + if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) +) + +if not exist %VENV_DIR%\ ( + echo -------------------------------------------------- + echo Creating virtual environment... + echo -------------------------------------------------- + echo Executing: "%PYTHON_CMD%" -m virtualenv --copies "%VENV_DIR%" + "%PYTHON_CMD%" -m virtualenv --copies "%VENV_DIR%" + if !errorlevel! neq 0 ( pause & exit /b !errorlevel! ) +) + +echo -------------------------------------------------- +echo Completed. +echo -------------------------------------------------- diff --git a/scripts/Update-Style-Bert-VITS2.bat b/scripts/Update-Style-Bert-VITS2.bat index 5fd50e4e1a1d6e85c3a108cad7299f9df97c8f0c..2aa79e153923b6e71ed286d1f3f28da9da59ad37 100644 --- a/scripts/Update-Style-Bert-VITS2.bat +++ b/scripts/Update-Style-Bert-VITS2.bat @@ -1,62 +1,69 @@ -chcp 65001 > NUL -@echo off - -@REM エラーコードを遅延評価するために設定 -setlocal enabledelayedexpansion - -pushd %~dp0 - - -pushd Style-Bert-VITS2 - -echo -------------------------------------------------- -echo Checking Git Installation... -echo -------------------------------------------------- -git --version -if !errorlevel! neq 0 ( - echo -------------------------------------------------- - echo Global Git is not installed, so use PortableGit. - echo Setting up PATH... - echo -------------------------------------------------- - echo Executing: set "PATH=%~dp0lib\PortableGit\bin;%PATH%" - set "PATH=%~dp0lib\PortableGit\bin;%PATH%" - - echo -------------------------------------------------- - echo Checking Git Installation... - echo -------------------------------------------------- - echo Executing: git --version - git --version - if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) -) - -echo -------------------------------------------------- -echo Git pull... -echo -------------------------------------------------- -git pull -if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - -@REM 仮想環境のpip requirements.txtを更新 - -echo -------------------------------------------------- -echo Activating virtual environment... -echo -------------------------------------------------- -echo Executing: call ".\venv\Scripts\activate.bat" -call ".\venv\Scripts\activate.bat" -if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - -echo -------------------------------------------------- -echo Updating dependencies... -echo -------------------------------------------------- -echo Executing: pip install -U -r requirements.txt -pip install -U -r requirements.txt -if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) - -echo ---------------------------------------- -echo Update completed. -echo ---------------------------------------- - -pause - -popd - -popd +chcp 65001 > NUL +@echo off + +@REM エラーコードを遅延評価するために設定 +setlocal enabledelayedexpansion + +pushd %~dp0 + + +pushd Style-Bert-VITS2 + +echo -------------------------------------------------- +echo Checking Git Installation... +echo -------------------------------------------------- +git --version +if !errorlevel! neq 0 ( + echo -------------------------------------------------- + echo Global Git is not installed, so use PortableGit. + echo Setting up PATH... + echo -------------------------------------------------- + echo Executing: set "PATH=%~dp0lib\PortableGit\bin;%PATH%" + set "PATH=%~dp0lib\PortableGit\bin;%PATH%" + + echo -------------------------------------------------- + echo Checking Git Installation... + echo -------------------------------------------------- + echo Executing: git --version + git --version + if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) +) + +echo -------------------------------------------------- +echo Git pull... +echo -------------------------------------------------- +git pull +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +@REM 仮想環境のpip requirements.txtを更新 + +echo -------------------------------------------------- +echo Activating virtual environment... +echo -------------------------------------------------- +echo Executing: call ".\venv\Scripts\activate.bat" +call ".\venv\Scripts\activate.bat" +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +echo -------------------------------------------------- +echo Installing uv... +echo -------------------------------------------------- +echo Executing: pip install -U uv +pip install -U uv +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +echo -------------------------------------------------- +echo Updating dependencies... +echo -------------------------------------------------- +echo Executing: uv pip install -U -r requirements.txt +uv pip install -U -r requirements.txt +if !errorlevel! neq 0 ( pause & popd & exit /b !errorlevel! ) + +echo ---------------------------------------- +echo Update completed. +echo ---------------------------------------- + +pause + +popd + +popd diff --git a/server_editor.py b/server_editor.py index c6f856e2f6cc04d72ad4f9f77249ee259f014cac..cde2c7dae3ef1e90eb8cce9e3a947c54004d7616 100644 --- a/server_editor.py +++ b/server_editor.py @@ -22,7 +22,6 @@ import numpy as np import requests import torch import uvicorn -import yaml from fastapi import APIRouter, FastAPI, HTTPException, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response @@ -30,6 +29,7 @@ from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from scipy.io import wavfile +from config import get_path_config from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_NOISE, @@ -127,7 +127,7 @@ def download_and_extract(url, extract_to: Path): def new_release_available(latest_release): if LAST_DOWNLOAD_FILE.exists(): - with open(LAST_DOWNLOAD_FILE, "r") as file: + with open(LAST_DOWNLOAD_FILE) as file: last_download_str = file.read().strip() # 'Z'を除去して日時オブジェクトに変換 last_download_str = last_download_str.replace("Z", "+00:00") @@ -174,35 +174,32 @@ origins = [ "http://127.0.0.1:8000", ] -# Get path settings -with open(Path("configs/paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - # dataset_root = path_config["dataset_root"] - assets_root = path_config["assets_root"] - +path_config = get_path_config() parser = argparse.ArgumentParser() -parser.add_argument("--model_dir", type=str, default="model_assets/") +parser.add_argument("--model_dir", type=str, default=path_config.assets_root) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--inbrowser", action="store_true") parser.add_argument("--line_length", type=int, default=None) parser.add_argument("--line_count", type=int, default=None) -parser.add_argument( - "--dir", "-d", type=str, help="Model directory", default=assets_root -) - +# parser.add_argument("--skip_default_models", action="store_true") +parser.add_argument("--skip_static_files", action="store_true") args = parser.parse_args() device = args.device if device == "cuda" and not torch.cuda.is_available(): device = "cpu" model_dir = Path(args.model_dir) port = int(args.port) +# if not args.skip_default_models: +# download_default_models() +skip_static_files = bool(args.skip_static_files) model_holder = TTSModelHolder(model_dir, device) if len(model_holder.model_names) == 0: logger.error(f"Models not found in {model_dir}.") sys.exit(1) + app = FastAPI() @@ -444,7 +441,8 @@ def delete_user_dict_word(uuid: str): app.include_router(router, prefix="/api") if __name__ == "__main__": - download_static_files("litagin02", "Style-Bert-VITS2-Editor", "out.zip") + if not skip_static_files: + download_static_files("litagin02", "Style-Bert-VITS2-Editor", "out.zip") app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static") if args.inbrowser: webbrowser.open(f"http://localhost:{port}") diff --git a/server_fastapi.py b/server_fastapi.py index 3c9d31e6829c1cd9aea0a99708b13e3c7643543a..f1cf97f696aad9096ca6484392e8e83e8094fa06 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -20,7 +20,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, Response from scipy.io import wavfile -from config import config +from config import get_config from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_LENGTH, @@ -40,6 +40,7 @@ from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.tts_model import TTSModel, TTSModelHolder +config = get_config() ln = config.server_config.language @@ -113,6 +114,12 @@ if __name__ == "__main__": load_models(model_holder) limit = config.server_config.limit + if limit < 1: + limit = None + else: + logger.info( + f"The maximum length of the text is {limit}. If you want to change it, modify config.yml. Set limit to -1 to remove the limit." + ) app = FastAPI() allow_origins = config.server_config.origins if allow_origins: @@ -134,6 +141,10 @@ if __name__ == "__main__": request: Request, text: str = Query(..., min_length=1, max_length=limit, description="セリフ"), encoding: str = Query(None, description="textをURLデコードする(ex, `utf-8`)"), + model_name: str = Query( + None, + description="モデル名(model_idより優先)。model_assets内のディレクトリ名を指定", + ), model_id: int = Query( 0, description="モデルID。`GET /models/info`のkeyの値を指定ください" ), @@ -191,6 +202,20 @@ if __name__ == "__main__": ): # /models/refresh があるためQuery(le)で表現不可 raise_validation_error(f"model_id={model_id} not found", "model_id") + if model_name: + # load_models() の 処理内容が i の正当性を担保していることに注意 + model_ids = [i for i, x in enumerate(model_holder.models_info) if x.name == model_name] + if not model_ids: + raise_validation_error( + f"model_name={model_name} not found", "model_name" + ) + # 今の実装ではディレクトリ名が重複することは無いはずだが... + if len(model_ids) > 1: + raise_validation_error( + f"model_name={model_name} is ambiguous", "model_name" + ) + model_id = model_ids[0] + model = loaded_models[model_id] if speaker_name is None: if speaker_id not in model.id2spk.keys(): @@ -230,6 +255,10 @@ if __name__ == "__main__": wavfile.write(wavContent, sr, audio) return Response(content=wavContent.getvalue(), media_type="audio/wav") + @app.post("/g2p") + def g2p(text: str): + return g2kata_tone(normalize_text(text)) + @app.get("/models/info") def get_loaded_models_info(): """ロードされたモデル情報の取得""" @@ -305,6 +334,9 @@ if __name__ == "__main__": logger.info(f"server listen: http://127.0.0.1:{config.server_config.port}") logger.info(f"API docs: http://127.0.0.1:{config.server_config.port}/docs") + logger.info( + f"Input text length limit: {limit}. You can change it in server.limit in config.yml" + ) uvicorn.run( app, port=config.server_config.port, host="0.0.0.0", log_level="warning" ) diff --git a/slice.py b/slice.py index c4b2292fd3e447be029d963ad0d54e476a67d06c..bfcb916cd52afe76c42a8b224d27ebbae09a23e5 100644 --- a/slice.py +++ b/slice.py @@ -7,15 +7,15 @@ from typing import Any, Optional import soundfile as sf import torch -import yaml from tqdm import tqdm +from config import get_path_config from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT def is_audio_file(file: Path) -> bool: - supported_extensions = [".wav", ".flac", ".mp3", ".ogg", ".opus"] + supported_extensions = [".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a"] return file.suffix.lower() in supported_extensions @@ -150,13 +150,12 @@ if __name__ == "__main__": ) args = parser.parse_args() - with open(Path("configs/paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - dataset_root = path_config["dataset_root"] + path_config = get_path_config() + dataset_root = path_config.dataset_root model_name = str(args.model_name) input_dir = Path(args.input_dir) - output_dir = Path(dataset_root) / model_name / "raw" + output_dir = dataset_root / model_name / "raw" min_sec: float = args.min_sec max_sec: float = args.max_sec min_silence_dur_ms: int = args.min_silence_dur_ms @@ -198,11 +197,12 @@ if __name__ == "__main__": q.task_done() break try: + rel_path = file.relative_to(input_dir) time_sec, count = split_wav( vad_model=vad_model, utils=utils, audio_file=file, - target_dir=output_dir, + target_dir=output_dir / rel_path.parent, min_sec=min_sec, max_sec=max_sec, min_silence_dur_ms=min_silence_dur_ms, diff --git a/speech_mos.py b/speech_mos.py index 453b7d31300177df6d46ff9c56a2bf1c8ee3407e..c7a6a25a7640aa68d4a30524a461119592c77dce 100644 --- a/speech_mos.py +++ b/speech_mos.py @@ -10,7 +10,7 @@ import pandas as pd import torch from tqdm import tqdm -from config import config +from config import get_path_config from style_bert_vits2.logging import logger from style_bert_vits2.tts_model import TTSModel @@ -35,6 +35,8 @@ test_texts = [ "この分野の最新の研究成果を使うと、より自然で表現豊かな音声の生成が可能である。深層学習の応用により、感情やアクセントを含む声質の微妙な変化も再現することが出来る。", ] +path_config = get_path_config() + predictor = torch.hub.load( "tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True ) @@ -48,17 +50,16 @@ args = parser.parse_args() model_name: str = args.model_name device: str = args.device -model_path = Path(config.assets_root) / model_name - +model_path = path_config.assets_root / model_name # .safetensorsファイルを検索 safetensors_files = model_path.glob("*.safetensors") def get_model(model_file: Path): return TTSModel( - model_path=str(model_file), - config_path=str(model_file.parent / "config.json"), - style_vec_path=str(model_file.parent / "style_vectors.npy"), + model_path=model_file, + config_path=model_file.parent / "config.json", + style_vec_path=model_file.parent / "style_vectors.npy", device=device, ) diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py index 58b6441a6fe2479cdaeed6861f698042eff79e1f..40bfdbaeb83950e0e5cbe4df8f4e317d8b1000c7 100644 --- a/style_bert_vits2/constants.py +++ b/style_bert_vits2/constants.py @@ -4,7 +4,7 @@ from style_bert_vits2.utils.strenum import StrEnum # Style-Bert-VITS2 のバージョン -VERSION = "2.4.1" +VERSION = "2.6.1" # Style-Bert-VITS2 のベースディレクトリ BASE_DIR = Path(__file__).parent.parent @@ -32,7 +32,7 @@ DEFAULT_USER_DICT_DIR = BASE_DIR / "dict_data" # デフォルトの推論パラメータ DEFAULT_STYLE = "Neutral" -DEFAULT_STYLE_WEIGHT = 5.0 +DEFAULT_STYLE_WEIGHT = 1.0 DEFAULT_SDP_RATIO = 0.2 DEFAULT_NOISE = 0.6 DEFAULT_NOISEW = 0.8 diff --git a/style_bert_vits2/models/hyper_parameters.py b/style_bert_vits2/models/hyper_parameters.py index feb6bfbafb220d01f0d6aa59e0cad36c45795f0c..827ce6dea5b69e3dee9e3f1378bece0289790489 100644 --- a/style_bert_vits2/models/hyper_parameters.py +++ b/style_bert_vits2/models/hyper_parameters.py @@ -125,5 +125,5 @@ class HyperParameters(BaseModel): HyperParameters: ハイパーパラメータ """ - with open(json_path, "r", encoding="utf-8") as f: + with open(json_path, encoding="utf-8") as f: return HyperParameters.model_validate_json(f.read()) diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index a9bcbf658831aa2aa253f23e8feba2454a329185..a9a486975564ac61ef1491f5cbe57f3bca0ed085 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -102,6 +102,7 @@ def get_text( device: str, assist_text: Optional[str] = None, assist_text_weight: float = 0.7, + given_phone: Optional[list[str]] = None, given_tone: Optional[list[int]] = None, ): use_jp_extra = hps.version.endswith("JP-Extra") @@ -112,10 +113,44 @@ def get_text( use_jp_extra=use_jp_extra, raise_yomi_error=False, ) - if given_tone is not None: - if len(given_tone) != len(phone): + # phone と tone の両方が与えられた場合はそれを使う + if given_phone is not None and given_tone is not None: + # 指定された phone と指定された tone 両方の長さが一致していなければならない + if len(given_phone) != len(given_tone): + raise InvalidPhoneError( + f"Length of given_phone ({len(given_phone)}) != length of given_tone ({len(given_tone)})" + ) + # 与えられた音素数と pyopenjtalk で生成した読みの音素数が一致しない + if len(given_phone) != sum(word2ph): + # 日本語の場合、len(given_phone) と sum(word2ph) が一致するように word2ph を適切に調整する + # 他の言語は word2ph の調整方法が思いつかないのでエラー + if language_str == Languages.JP: + from style_bert_vits2.nlp.japanese.g2p import adjust_word2ph + + word2ph = adjust_word2ph(word2ph, phone, given_phone) + # 上記処理により word2ph の合計が given_phone の長さと一致するはず + # それでも一致しない場合、大半は読み上げテキストと given_phone が著しく乖離していて調整し切れなかったことを意味する + if len(given_phone) != sum(word2ph): + raise InvalidPhoneError( + f"Length of given_phone ({len(given_phone)}) != sum of word2ph ({sum(word2ph)})" + ) + else: + raise InvalidPhoneError( + f"Length of given_phone ({len(given_phone)}) != sum of word2ph ({sum(word2ph)})" + ) + phone = given_phone + # 生成あるいは指定された phone と指定された tone 両方の長さが一致していなければならない + if len(phone) != len(given_tone): + raise InvalidToneError( + f"Length of phone ({len(phone)}) != length of given_tone ({len(given_tone)})" + ) + tone = given_tone + # tone だけが与えられた場合は clean_text() で生成した phone と合わせて使う + elif given_tone is not None: + # 生成した phone と指定された tone 両方の長さが一致していなければならない + if len(phone) != len(given_tone): raise InvalidToneError( - f"Length of given_tone ({len(given_tone)}) != length of phone ({len(phone)})" + f"Length of phone ({len(phone)}) != length of given_tone ({len(given_tone)})" ) tone = given_tone phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) @@ -179,6 +214,7 @@ def infer( skip_end: bool = False, assist_text: Optional[str] = None, assist_text_weight: float = 0.7, + given_phone: Optional[list[str]] = None, given_tone: Optional[list[int]] = None, ): is_jp_extra = hps.version.endswith("JP-Extra") @@ -189,6 +225,7 @@ def infer( device, assist_text=assist_text, assist_text_weight=assist_text_weight, + given_phone=given_phone, given_tone=given_tone, ) if skip_start: @@ -263,5 +300,9 @@ def infer( return audio +class InvalidPhoneError(ValueError): + pass + + class InvalidToneError(ValueError): pass diff --git a/style_bert_vits2/models/models.py b/style_bert_vits2/models/models.py index 56fb27c62ed2b5d3a2c1b2d91ff77684e6d10543..eaff2fadfc168b97c1a59cdecf34a535ed9cd4b0 100644 --- a/style_bert_vits2/models/models.py +++ b/style_bert_vits2/models/models.py @@ -786,7 +786,7 @@ class ReferenceEncoder(nn.Module): for i in range(K) ] self.convs = nn.ModuleList(convs) - # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501 + # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) self.gru = nn.GRU( diff --git a/style_bert_vits2/models/models_jp_extra.py b/style_bert_vits2/models/models_jp_extra.py index 2850baf2004ff16b218e6272f4cc97feda439416..e8df7d9fa7860956ba8c58249f4cf7aac1de2f35 100644 --- a/style_bert_vits2/models/models_jp_extra.py +++ b/style_bert_vits2/models/models_jp_extra.py @@ -844,7 +844,7 @@ class ReferenceEncoder(nn.Module): for i in range(K) ] self.convs = nn.ModuleList(convs) - # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501 + # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) self.gru = nn.GRU( diff --git a/style_bert_vits2/models/utils/__init__.py b/style_bert_vits2/models/utils/__init__.py index edd51ccb691311627f93c5923c53c8fda200c5d8..b91c74aeead2f410a29f94c86f9b1861f4b32d2f 100644 --- a/style_bert_vits2/models/utils/__init__.py +++ b/style_bert_vits2/models/utils/__init__.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np import torch from numpy.typing import NDArray -from scipy.io.wavfile import read from style_bert_vits2.logging import logger from style_bert_vits2.models.utils import checkpoints # type: ignore @@ -162,6 +161,13 @@ def load_wav_to_torch(full_path: Union[str, Path]) -> tuple[torch.FloatTensor, i tuple[torch.FloatTensor, int]: 音声データのテンソルとサンプリングレート """ + # この関数は学習時以外使われないため、ライブラリとしての style_bert_vits2 が + # 重たい scipy に依存しないように遅延 import する + try: + from scipy.io.wavfile import read + except ImportError: + raise ImportError("scipy is required to load wav file") + sampling_rate, data = read(full_path) return torch.FloatTensor(data.astype(np.float32)), sampling_rate @@ -180,7 +186,7 @@ def load_filepaths_and_text( list[list[str]]: ファイルパスとテキストのリスト """ - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding="utf-8") as f: filepaths_and_text = [line.strip().split(split) for line in f] return filepaths_and_text @@ -239,9 +245,7 @@ def check_git_hash(model_dir_path: Union[str, Path]) -> None: source_dir = os.path.dirname(os.path.realpath(__file__)) if not os.path.exists(os.path.join(source_dir, ".git")): logger.warning( - "{} is not a git repository, therefore hash value comparison will be ignored.".format( - source_dir - ) + f"{source_dir} is not a git repository, therefore hash value comparison will be ignored." ) return @@ -249,13 +253,11 @@ def check_git_hash(model_dir_path: Union[str, Path]) -> None: path = os.path.join(model_dir_path, "githash") if os.path.exists(path): - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: saved_hash = f.read() if saved_hash != cur_hash: logger.warning( - "git hash values are different. {}(saved) != {}(current)".format( - saved_hash[:8], cur_hash[:8] - ) + f"git hash values are different. {saved_hash[:8]}(saved) != {cur_hash[:8]}(current)" ) else: with open(path, "w", encoding="utf-8") as f: diff --git a/style_bert_vits2/models/utils/safetensors.py b/style_bert_vits2/models/utils/safetensors.py index 52ab115b29ebd1d6d1edb4e14085410c0e8b5132..4b4ef3fbd7c7528df03252fa2840b573cfcbad06 100644 --- a/style_bert_vits2/models/utils/safetensors.py +++ b/style_bert_vits2/models/utils/safetensors.py @@ -77,7 +77,7 @@ def save_safetensors( keys = [] for k in state_dict: if "enc_q" in k and for_infer: - continue # noqa: E701 + continue keys.append(k) new_dict = ( diff --git a/style_bert_vits2/nlp/chinese/g2p.py b/style_bert_vits2/nlp/chinese/g2p.py index f38e09fa823d6058be91cfe6c062dd52485fe36d..1f0894f9526dec74977aa99896e9a4a47ee0340e 100644 --- a/style_bert_vits2/nlp/chinese/g2p.py +++ b/style_bert_vits2/nlp/chinese/g2p.py @@ -8,7 +8,7 @@ from style_bert_vits2.nlp.chinese.tone_sandhi import ToneSandhi from style_bert_vits2.nlp.symbols import PUNCTUATIONS -with open(Path(__file__).parent / "opencpop-strict.txt", "r", encoding="utf-8") as f: +with open(Path(__file__).parent / "opencpop-strict.txt", encoding="utf-8") as f: __PINYIN_TO_SYMBOL_MAP = { line.split("\t")[0]: line.strip().split("\t")[1] for line in f.readlines() } @@ -73,7 +73,7 @@ def __g2p(segments: list[str]) -> tuple[list[str], list[int], list[int]]: "iou": "iu", "uen": "un", } - if v_without_tone in v_rep_map.keys(): + if v_without_tone in v_rep_map: pinyin = c + v_rep_map[v_without_tone] else: # 单音节 @@ -83,7 +83,7 @@ def __g2p(segments: list[str]) -> tuple[list[str], list[int], list[int]]: "in": "yin", "u": "wu", } - if pinyin in pinyin_rep_map.keys(): + if pinyin in pinyin_rep_map: pinyin = pinyin_rep_map[pinyin] else: single_rep_map = { @@ -92,10 +92,10 @@ def __g2p(segments: list[str]) -> tuple[list[str], list[int], list[int]]: "i": "y", "u": "w", } - if pinyin[0] in single_rep_map.keys(): + if pinyin[0] in single_rep_map: pinyin = single_rep_map[pinyin[0]] + pinyin[1:] - assert pinyin in __PINYIN_TO_SYMBOL_MAP.keys(), ( + assert pinyin in __PINYIN_TO_SYMBOL_MAP, ( pinyin, seg, raw_pinyin, diff --git a/style_bert_vits2/nlp/chinese/normalizer.py b/style_bert_vits2/nlp/chinese/normalizer.py index c56c636f3ef2e7d1a7e7af883fd558525bd994f1..c076408dccc569058de0da499f0bab403cadf0a4 100644 --- a/style_bert_vits2/nlp/chinese/normalizer.py +++ b/style_bert_vits2/nlp/chinese/normalizer.py @@ -5,6 +5,41 @@ import cn2an from style_bert_vits2.nlp.symbols import PUNCTUATIONS +__REPLACE_MAP = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": "…", + "$": ".", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + "—": "-", + "~": "-", + "~": "-", + "「": "'", + "」": "'", +} + + def normalize_text(text: str) -> str: numbers = re.findall(r"\d+(?:\.?\d+)?", text) for number in numbers: @@ -15,44 +50,10 @@ def normalize_text(text: str) -> str: def replace_punctuation(text: str) -> str: - REPLACE_MAP = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - "·": ",", - "、": ",", - "...": "…", - "$": ".", - "“": "'", - "”": "'", - '"': "'", - "‘": "'", - "’": "'", - "(": "'", - ")": "'", - "(": "'", - ")": "'", - "《": "'", - "》": "'", - "【": "'", - "】": "'", - "[": "'", - "]": "'", - "—": "-", - "~": "-", - "~": "-", - "「": "'", - "」": "'", - } - text = text.replace("嗯", "恩").replace("呣", "母") - pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP.keys())) + pattern = re.compile("|".join(re.escape(p) for p in __REPLACE_MAP)) - replaced_text = pattern.sub(lambda x: REPLACE_MAP[x.group()], text) + replaced_text = pattern.sub(lambda x: __REPLACE_MAP[x.group()], text) replaced_text = re.sub( r"[^\u4e00-\u9fa5" + "".join(PUNCTUATIONS) + r"]+", "", replaced_text diff --git a/style_bert_vits2/nlp/chinese/tone_sandhi.py b/style_bert_vits2/nlp/chinese/tone_sandhi.py index 552cb0d366188a92b83a9184b2b541fe5d763a20..4945ab952d724bfc43321ef98a29f7cc3ccb6758 100644 --- a/style_bert_vits2/nlp/chinese/tone_sandhi.py +++ b/style_bert_vits2/nlp/chinese/tone_sandhi.py @@ -471,26 +471,27 @@ class ToneSandhi: ): finals[j] = finals[j][:-1] + "5" ge_idx = word.find("个") - if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶": - finals[-1] = finals[-1][:-1] + "5" - elif len(word) >= 1 and word[-1] in "的地得": - finals[-1] = finals[-1][:-1] + "5" - # e.g. 走了, 看着, 去过 - # elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}: - # finals[-1] = finals[-1][:-1] + "5" - elif ( - len(word) > 1 - and word[-1] in "们子" - and pos in {"r", "n"} - and word not in self.must_not_neural_tone_words + if ( + len(word) >= 1 + and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶" + or len(word) >= 1 + and word[-1] in "的地得" + or ( + ( + len(word) > 1 + and word[-1] in "们子" + and pos in {"r", "n"} + and word not in self.must_not_neural_tone_words + ) + or len(word) > 1 + and word[-1] in "上下里" + and pos in {"s", "l", "f"} + ) + or len(word) > 1 + and word[-1] in "来去" + and word[-2] in "上下进出回过起开" ): finals[-1] = finals[-1][:-1] + "5" - # e.g. 桌上, 地下, 家里 - elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}: - finals[-1] = finals[-1][:-1] + "5" - # e.g. 上来, 下去 - elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开": - finals[-1] = finals[-1][:-1] + "5" # 个做量词 elif ( ge_idx >= 1 @@ -500,12 +501,11 @@ class ToneSandhi: ) ) or word == "个": finals[ge_idx] = finals[ge_idx][:-1] + "5" - else: - if ( - word in self.must_neural_tone_words - or word[-2:] in self.must_neural_tone_words - ): - finals[-1] = finals[-1][:-1] + "5" + elif ( + word in self.must_neural_tone_words + or word[-2:] in self.must_neural_tone_words + ): + finals[-1] = finals[-1][:-1] + "5" word_list = self._split_word(word) finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]] @@ -549,10 +549,8 @@ class ToneSandhi: if finals[i + 1][-1] == "4": finals[i] = finals[i][:-1] + "2" # "一" before non-tone4 should be yi4, e.g. 一天 - else: - # "一" 后面如果是标点,还读一声 - if word[i + 1] not in self.punc: - finals[i] = finals[i][:-1] + "4" + elif word[i + 1] not in self.punc: + finals[i] = finals[i][:-1] + "4" return finals def _split_word(self, word: str) -> list[str]: diff --git a/style_bert_vits2/nlp/english/cmudict.py b/style_bert_vits2/nlp/english/cmudict.py index 7772e77b853d2716c81cc81c18a3c132730e8263..fd17405f500ab387b6b8a4b40c31687bc281b13b 100644 --- a/style_bert_vits2/nlp/english/cmudict.py +++ b/style_bert_vits2/nlp/english/cmudict.py @@ -20,7 +20,7 @@ def get_dict() -> dict[str, list[list[str]]]: def read_dict() -> dict[str, list[list[str]]]: g2p_dict = {} start_line = 49 - with open(CMU_DICT_PATH, "r", encoding="utf-8") as f: + with open(CMU_DICT_PATH, encoding="utf-8") as f: line = f.readline() line_index = 1 while line: diff --git a/style_bert_vits2/nlp/english/g2p.py b/style_bert_vits2/nlp/english/g2p.py index db2a87f97cf70e0b30302c8589f283d0f788697e..4e3f9b329b40bd6a2aa462cebfe9efb71a3ad661 100644 --- a/style_bert_vits2/nlp/english/g2p.py +++ b/style_bert_vits2/nlp/english/g2p.py @@ -8,96 +8,95 @@ from style_bert_vits2.nlp.english.cmudict import get_dict from style_bert_vits2.nlp.symbols import PUNCTUATIONS, SYMBOLS -def g2p(text: str) -> tuple[list[str], list[int], list[int]]: - - ARPA = { - "AH0", - "S", - "AH1", - "EY2", - "AE2", - "EH0", - "OW2", - "UH0", - "NG", - "B", - "G", - "AY0", - "M", - "AA0", - "F", - "AO0", - "ER2", - "UH1", - "IY1", - "AH2", - "DH", - "IY0", - "EY1", - "IH0", - "K", - "N", - "W", - "IY2", - "T", - "AA1", - "ER1", - "EH2", - "OY0", - "UH2", - "UW1", - "Z", - "AW2", - "AW1", - "V", - "UW2", - "AA2", - "ER", - "AW0", - "UW0", - "R", - "OW1", - "EH1", - "ZH", - "AE0", - "IH2", - "IH", - "Y", - "JH", - "P", - "AY1", - "EY0", - "OY2", - "TH", - "HH", - "D", - "ER0", - "CH", - "AO1", - "AE1", - "AO2", - "OY1", - "AY2", - "IH1", - "OW0", - "L", - "SH", - } +# Initialize global variables once +ARPA = { + "AH0", + "S", + "AH1", + "EY2", + "AE2", + "EH0", + "OW2", + "UH0", + "NG", + "B", + "G", + "AY0", + "M", + "AA0", + "F", + "AO0", + "ER2", + "UH1", + "IY1", + "AH2", + "DH", + "IY0", + "EY1", + "IH0", + "K", + "N", + "W", + "IY2", + "T", + "AA1", + "ER1", + "EH2", + "OY0", + "UH2", + "UW1", + "Z", + "AW2", + "AW1", + "V", + "UW2", + "AA2", + "ER", + "AW0", + "UW0", + "R", + "OW1", + "EH1", + "ZH", + "AE0", + "IH2", + "IH", + "Y", + "JH", + "P", + "AY1", + "EY0", + "OY2", + "TH", + "HH", + "D", + "ER0", + "CH", + "AO1", + "AE1", + "AO2", + "OY1", + "AY2", + "IH1", + "OW0", + "L", + "SH", +} +_g2p = G2p() +eng_dict = get_dict() - _g2p = G2p() +def g2p(text: str) -> tuple[list[str], list[int], list[int]]: phones = [] tones = [] phone_len = [] - # tokens = [tokenizer.tokenize(i) for i in words] words = __text_to_words(text) - eng_dict = get_dict() for word in words: temp_phones, temp_tones = [], [] - if len(word) > 1: - if "'" in word: - word = ["".join(word)] + if len(word) > 1 and "'" in word: + word = ["".join(word)] + for w in word: if w in PUNCTUATIONS: temp_phones.append(w) @@ -107,11 +106,9 @@ def g2p(text: str) -> tuple[list[str], list[int], list[int]]: phns, tns = __refine_syllables(eng_dict[w.upper()]) temp_phones += [__post_replace_ph(i) for i in phns] temp_tones += tns - # w2ph.append(len(phns)) else: - phone_list = list(filter(lambda p: p != " ", _g2p(w))) # type: ignore - phns = [] - tns = [] + phone_list = list(filter(lambda p: p != " ", _g2p(w))) + phns, tns = [], [] for ph in phone_list: if ph in ARPA: ph, tn = __refine_ph(ph) @@ -122,17 +119,15 @@ def g2p(text: str) -> tuple[list[str], list[int], list[int]]: tns.append(0) temp_phones += [__post_replace_ph(i) for i in phns] temp_tones += tns + phones += temp_phones tones += temp_tones phone_len.append(len(temp_phones)) - # phones = [post_replace_ph(i) for i in phones] word2ph = [] for token, pl in zip(words, phone_len): word_len = len(token) - - aaa = __distribute_phone(pl, word_len) - word2ph += aaa + word2ph += __distribute_phone(pl, word_len) phones = ["_"] + phones + ["_"] tones = [0] + tones + [0] @@ -159,13 +154,11 @@ def __post_replace_ph(ph: str) -> str: "・・・": "...", "v": "V", } - if ph in REPLACE_MAP.keys(): + if ph in REPLACE_MAP: ph = REPLACE_MAP[ph] if ph in SYMBOLS: return ph - if ph not in SYMBOLS: - ph = "UNK" - return ph + return "UNK" def __refine_ph(phn: str) -> tuple[str, int]: @@ -182,8 +175,7 @@ def __refine_syllables(syllables: list[list[str]]) -> tuple[list[str], list[int] tones = [] phonemes = [] for phn_list in syllables: - for i in range(len(phn_list)): - phn = phn_list[i] + for phn in phn_list: phn, tone = __refine_ph(phn) phonemes.append(phn) tones.append(tone) @@ -206,24 +198,22 @@ def __text_to_words(text: str) -> list[list[str]]: for idx, t in enumerate(tokens): if t.startswith("▁"): words.append([t[1:]]) - else: - if t in PUNCTUATIONS: - if idx == len(tokens) - 1: - words.append([f"{t}"]) - else: - if ( - not tokens[idx + 1].startswith("▁") - and tokens[idx + 1] not in PUNCTUATIONS - ): - if idx == 0: - words.append([]) - words[-1].append(f"{t}") - else: - words.append([f"{t}"]) - else: + elif t in PUNCTUATIONS: + if idx == len(tokens) - 1: + words.append([f"{t}"]) + elif ( + not tokens[idx + 1].startswith("▁") + and tokens[idx + 1] not in PUNCTUATIONS + ): if idx == 0: words.append([]) words[-1].append(f"{t}") + else: + words.append([f"{t}"]) + else: + if idx == 0: + words.append([]) + words[-1].append(f"{t}") return words diff --git a/style_bert_vits2/nlp/english/normalizer.py b/style_bert_vits2/nlp/english/normalizer.py index f6ddc90c2fccce15c725b53c678f66fb576c3f95..88cec023a6183d333eb238ec9353346291184ebb 100644 --- a/style_bert_vits2/nlp/english/normalizer.py +++ b/style_bert_vits2/nlp/english/normalizer.py @@ -58,7 +58,7 @@ def replace_punctuation(text: str) -> str: "「": "'", "」": "'", } - pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP.keys())) + pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP)) replaced_text = pattern.sub(lambda x: REPLACE_MAP[x.group()], text) # replaced_text = re.sub( # r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005" diff --git a/style_bert_vits2/nlp/japanese/g2p.py b/style_bert_vits2/nlp/japanese/g2p.py index 1f6b450d9795ae75ae7530d355cfb38f436e35f3..dab3dbf98d45491ffe3d74653ccb3163c63689e6 100644 --- a/style_bert_vits2/nlp/japanese/g2p.py +++ b/style_bert_vits2/nlp/japanese/g2p.py @@ -1,10 +1,11 @@ import re +from typing import TypedDict from style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger from style_bert_vits2.nlp import bert_models from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk -from style_bert_vits2.nlp.japanese.mora_list import MORA_KATA_TO_MORA_PHONEMES +from style_bert_vits2.nlp.japanese.mora_list import MORA_KATA_TO_MORA_PHONEMES, VOWELS from style_bert_vits2.nlp.japanese.normalizer import replace_punctuation from style_bert_vits2.nlp.symbols import PUNCTUATIONS @@ -23,7 +24,7 @@ def g2p( Args: norm_text (str): 正規化されたテキスト use_jp_extra (bool, optional): False の場合、「ん」の音素を「N」ではなく「n」とする。Defaults to True. - raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False. + raise_yomi_error (bool, optional): False の場合、読めない文字が「'」として発音される。Defaults to False. Returns: tuple[list[str], list[int], list[int]]: 音素のリスト、アクセントのリスト、word2ph のリスト @@ -38,8 +39,8 @@ def g2p( # punctuation がすべて消えた、音素とアクセントのタプルのリスト(「ん」は「N」) phone_tone_list_wo_punct = __g2phone_tone_wo_punct(norm_text) - # sep_text: 単語単位の単語のリスト、読めない文字があったら raise_yomi_error なら例外、そうでないなら読めない文字が消えて返ってくる - # sep_kata: 単語単位の単語のカタカナ読みのリスト + # sep_text: 単語単位の単語のリスト + # sep_kata: 単語単位の単語のカタカナ読みのリスト、読めない文字は raise_yomi_error=True なら例外、False なら読めない文字を「'」として返ってくる sep_text, sep_kata = text_to_sep_kata(norm_text, raise_yomi_error=raise_yomi_error) # sep_phonemes: 各単語ごとの音素のリストのリスト @@ -103,7 +104,7 @@ def text_to_sep_kata( Args: norm_text (str): 正規化されたテキスト - raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False. + raise_yomi_error (bool, optional): False の場合、読めない文字が「'」として発音される。Defaults to False. Returns: tuple[list[str], list[str]]: 分割された単語リストと、その読み(カタカナ or 記号1文字)のリスト @@ -137,12 +138,19 @@ def text_to_sep_kata( # word は正規化されているので、`.`, `,`, `!`, `'`, `-`, `--` のいずれか if not set(word).issubset(set(PUNCTUATIONS)): # 記号繰り返しか判定 # ここは pyopenjtalk が読めない文字等のときに起こる + ## 例外を送出する場合 if raise_yomi_error: raise YomiError(f"Cannot read: {word} in:\n{norm_text}") - logger.warning(f"Ignoring unknown: {word} in:\n{norm_text}") - continue - # yomi は元の記号のままに変更 - yomi = word + ## 例外を送出しない場合 + ## 読めない文字は「'」として扱う + logger.warning( + f'Cannot read: {word} in:\n{norm_text}, replaced with "\'"' + ) + # word の文字数分「'」を追加 + yomi = "'" * len(word) + else: + # yomi は元の記号のままに変更 + yomi = word elif yomi == "?": assert word == "?", f"yomi `?` comes from: {word}" yomi = "?" @@ -152,6 +160,217 @@ def text_to_sep_kata( return sep_text, sep_kata +def adjust_word2ph( + word2ph: list[int], + generated_phone: list[str], + given_phone: list[str], +) -> list[int]: + """ + `g2p()` で得られた `word2ph` を、generated_phone と given_phone の差分情報を使っていい感じに調整する。 + generated_phone は正規化された読み上げテキストから生成された読みの情報だが、 + given_phone で 同じ読み上げテキストに異なる読みが与えられた場合、正規化された読み上げテキストの各文字に + 音素が何文字割り当てられるかを示す word2ph の合計値が given_phone の長さ (音素数) と一致しなくなりうる + そこで generated_phone と given_phone の差分を取り変更箇所に対応する word2ph の要素の値だけを増減させ、 + アクセントへの影響を最低限に抑えつつ word2ph の合計値を given_phone の長さ (音素数) に一致させる。 + + Args: + word2ph (list[int]): 単語ごとの音素の数のリスト + generated_phone (list[str]): 生成された音素のリスト + given_phone (list[str]): 与えられた音素のリスト + + Returns: + list[int]: 修正された word2ph のリスト + """ + + # word2ph・generated_phone・given_phone 全ての先頭と末尾にダミー要素が入っているので、処理の都合上それらを削除 + # word2ph は先頭と末尾に 1 が入っている (返す際に再度追加する) + word2ph = word2ph[1:-1] + generated_phone = generated_phone[1:-1] + given_phone = given_phone[1:-1] + + class DiffDetail(TypedDict): + begin_index: int + end_index: int + value: list[str] + + class Diff(TypedDict): + generated: DiffDetail + given: DiffDetail + + def extract_differences( + generated_phone: list[str], given_phone: list[str] + ) -> list[Diff]: + """ + 最長共通部分列を基にして、二つのリストの異なる部分を抽出する。 + """ + + def longest_common_subsequence( + X: list[str], Y: list[str] + ) -> list[tuple[int, int]]: + """ + 二つのリストの最長共通部分列のインデックスのペアを返す。 + """ + m, n = len(X), len(Y) + L = [[0] * (n + 1) for _ in range(m + 1)] + # LCSの長さを構築 + for i in range(1, m + 1): + for j in range(1, n + 1): + if X[i - 1] == Y[j - 1]: + L[i][j] = L[i - 1][j - 1] + 1 + else: + L[i][j] = max(L[i - 1][j], L[i][j - 1]) + # LCSを逆方向にトレースしてインデックスのペアを取得 + index_pairs = [] + i, j = m, n + while i > 0 and j > 0: + if X[i - 1] == Y[j - 1]: + index_pairs.append((i - 1, j - 1)) + i -= 1 + j -= 1 + elif L[i - 1][j] >= L[i][j - 1]: + i -= 1 + else: + j -= 1 + index_pairs.reverse() + return index_pairs + + differences = [] + common_indices = longest_common_subsequence(generated_phone, given_phone) + prev_x, prev_y = -1, -1 + + # 共通部分のインデックスを基にして差分を抽出 + for x, y in common_indices: + diff_X = { + "begin_index": prev_x + 1, + "end_index": x, + "value": generated_phone[prev_x + 1 : x], + } + diff_Y = { + "begin_index": prev_y + 1, + "end_index": y, + "value": given_phone[prev_y + 1 : y], + } + if diff_X or diff_Y: + differences.append({"generated": diff_X, "given": diff_Y}) + prev_x, prev_y = x, y + # 最後の非共通部分を追加 + if prev_x < len(generated_phone) - 1 or prev_y < len(given_phone) - 1: + differences.append( + { + "generated": { + "begin_index": prev_x + 1, + "end_index": len(generated_phone) - 1, + "value": generated_phone[prev_x + 1 : len(generated_phone) - 1], + }, + "given": { + "begin_index": prev_y + 1, + "end_index": len(given_phone) - 1, + "value": given_phone[prev_y + 1 : len(given_phone) - 1], + }, + } + ) + # generated.value と given.value の両方が空の要素を diffrences から削除 + for diff in differences[:]: + if ( + len(diff["generated"]["value"]) == 0 + and len(diff["given"]["value"]) == 0 + ): + differences.remove(diff) + + return differences + + # 二つのリストの差分を抽出 + differences = extract_differences(generated_phone, given_phone) + + # word2ph をもとにして新しく作る word2ph のリスト + ## 長さは word2ph と同じだが、中身は 0 で初期化されている + adjusted_word2ph: list[int] = [0] * len(word2ph) + # 現在処理中の generated_phone のインデックス + current_generated_index = 0 + + # word2ph の要素数 (=正規化された読み上げテキストの文字数) を維持しながら、差分情報を使って word2ph を修正 + ## 音素数が generated_phone と given_phone で異なる場合にこの align_word2ph() が呼び出される + ## word2ph は正規化された読み上げテキストの文字数に対応しているので、要素数はそのまま given_phone で増減した音素数に合わせて各要素の値を増減する + for word2ph_element_index, word2ph_element in enumerate(word2ph): + # ここの word2ph_element は、正規化された読み上げテキストの各文字に割り当てられる音素の数を示す + # 例えば word2ph_element が 2 ならば、その文字には 2 つの音素 (例: "k", "a") が割り当てられる + # 音素の数だけループを回す + for _ in range(word2ph_element): + # difference の中に 処理中の generated_phone から始まる差分があるかどうかを確認 + current_diff: Diff | None = None + for diff in differences: + if diff["generated"]["begin_index"] == current_generated_index: + current_diff = diff + break + # current_diff が None でない場合、generated_phone から始まる差分がある + if current_diff is not None: + # generated から given で変わった音素数の差分を取得 (2増えた場合は +2 だし、2減った場合は -2) + diff_in_phonemes = \ + len(current_diff["given"]["value"]) - len(current_diff["generated"]["value"]) # fmt: skip + # adjusted_word2ph[(読み上げテキストの各文字のインデックス)] に上記差分を反映 + adjusted_word2ph[word2ph_element_index] += diff_in_phonemes + # adjusted_word2ph[(読み上げテキストの各文字のインデックス)] に処理が完了した分の音素として 1 を加える + adjusted_word2ph[word2ph_element_index] += 1 + # 処理中の generated_phone のインデックスを進める + current_generated_index += 1 + + # この時点で given_phone の長さと adjusted_word2ph に記録されている音素数の合計が一致しているはず + assert len(given_phone) == sum(adjusted_word2ph), f"{len(given_phone)} != {sum(adjusted_word2ph)}" # fmt: skip + + # generated_phone から given_phone の間で音素が減った場合 (例: a, sh, i, t, a -> a, s, u) 、 + # adjusted_word2ph の要素の値が 1 未満になることがあるので、1 になるように値を増やす + ## この時、adjusted_word2ph に記録されている音素数の合計を変えないために、 + ## 値を 1 にした分だけ右隣の要素から増やした分の差分を差し引く + for adjusted_word2ph_element_index, adjusted_word2ph_element in enumerate(adjusted_word2ph): # fmt: skip + # もし現在の要素が 1 未満ならば + if adjusted_word2ph_element < 1: + # 値を 1 にするためにどれだけ足せばいいかを計算 + diff = 1 - adjusted_word2ph_element + # adjusted_word2ph[(読み上げテキストの各文字のインデックス)] を 1 にする + # これにより、当該文字に最低ラインとして 1 つの音素が割り当てられる + adjusted_word2ph[adjusted_word2ph_element_index] = 1 + # 次の要素のうち、一番近くてかつ 1 以上の要素から diff を引く + # この時、diff を引いた結果引いた要素が 1 未満になる場合は、その要素の次の要素の中から一番近くてかつ 1 以上の要素から引く + # 上記を繰り返していって、diff が 0 になるまで続ける + for i in range(1, len(adjusted_word2ph)): + if adjusted_word2ph_element_index + i >= len(adjusted_word2ph): + break # adjusted_word2ph の最後に達した場合は諦める + if adjusted_word2ph[adjusted_word2ph_element_index + i] - diff >= 1: + adjusted_word2ph[adjusted_word2ph_element_index + i] -= diff + break + else: + diff -= adjusted_word2ph[adjusted_word2ph_element_index + i] - 1 + adjusted_word2ph[adjusted_word2ph_element_index + i] = 1 + if diff == 0: + break + + # 逆に、generated_phone から given_phone の間で音素が増えた場合 (例: a, s, u -> a, sh, i, t, a) 、 + # 1文字あたり7音素以上も割り当てられてしまう場合があるので、最大6音素にした上で削った分の差分を次の要素に加える + # 次の要素に差分を加えた結果7音素以上になってしまう場合は、その差分をさらに次の要素に加える + for adjusted_word2ph_element_index, adjusted_word2ph_element in enumerate(adjusted_word2ph): # fmt: skip + if adjusted_word2ph_element > 6: + diff = adjusted_word2ph_element - 6 + adjusted_word2ph[adjusted_word2ph_element_index] = 6 + for i in range(1, len(adjusted_word2ph)): + if adjusted_word2ph_element_index + i >= len(adjusted_word2ph): + break # adjusted_word2ph の最後に達した場合は諦める + if adjusted_word2ph[adjusted_word2ph_element_index + i] + diff <= 6: + adjusted_word2ph[adjusted_word2ph_element_index + i] += diff + break + else: + diff -= 6 - adjusted_word2ph[adjusted_word2ph_element_index + i] + adjusted_word2ph[adjusted_word2ph_element_index + i] = 6 + if diff == 0: + break + + # この時点で given_phone の長さと adjusted_word2ph に記録されている音素数の合計が一致していない場合、 + # 正規化された読み上げテキストと given_phone が著しく乖離していることを示す + # このとき、この関数の呼び出し元の get_text() にて InvalidPhoneError が送出される + + # 最初に削除した前後のダミー要素を追加して返す + return [1] + adjusted_word2ph + [1] + + def __g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]: """ テキストに対して、音素とアクセント(0か1)のペアのリストを返す。 @@ -209,15 +428,23 @@ def __g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]: return result +__PYOPENJTALK_G2P_PROSODY_A1_PATTERN = re.compile(r"/A:([0-9\-]+)\+") +__PYOPENJTALK_G2P_PROSODY_A2_PATTERN = re.compile(r"\+(\d+)\+") +__PYOPENJTALK_G2P_PROSODY_A3_PATTERN = re.compile(r"\+(\d+)/") +__PYOPENJTALK_G2P_PROSODY_E3_PATTERN = re.compile(r"!(\d+)_") +__PYOPENJTALK_G2P_PROSODY_F1_PATTERN = re.compile(r"/F:(\d+)_") +__PYOPENJTALK_G2P_PROSODY_P3_PATTERN = re.compile(r"\-(.*?)\+") + + def __pyopenjtalk_g2p_prosody( text: str, drop_unvoiced_vowels: bool = True ) -> list[str]: """ - ESPnet の実装から引用、変更点無し。「ん」は「N」なことに注意。 + ESPnet の実装から引用、概ね変更点無し。「ん」は「N」なことに注意。 ref: https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py ------------------------------------------------------------------------------------------ - Extract phoneme + prosoody symbol sequence from input full-context labels. + Extract phoneme + prosody symbol sequence from input full-context labels. The algorithm is based on `Prosodic features control by symbols as input of sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks. @@ -238,8 +465,8 @@ def __pyopenjtalk_g2p_prosody( modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104 """ - def _numeric_feature_by_regex(regex: str, s: str) -> int: - match = re.search(regex, s) + def _numeric_feature_by_regex(pattern: re.Pattern[str], s: str) -> int: + match = pattern.search(s) if match is None: return -50 return int(match.group(1)) @@ -252,7 +479,7 @@ def __pyopenjtalk_g2p_prosody( lab_curr = labels[n] # current phoneme - p3 = re.search(r"\-(.*?)\+", lab_curr).group(1) # type: ignore + p3 = __PYOPENJTALK_G2P_PROSODY_P3_PATTERN.search(lab_curr).group(1) # type: ignore # deal unvoiced vowels as normal vowels if drop_unvoiced_vowels and p3 in "AEIOU": p3 = p3.lower() @@ -264,7 +491,9 @@ def __pyopenjtalk_g2p_prosody( phones.append("^") elif n == N - 1: # check question form or not - e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr) + e3 = _numeric_feature_by_regex( + __PYOPENJTALK_G2P_PROSODY_E3_PATTERN, lab_curr + ) if e3 == 0: phones.append("$") elif e3 == 1: @@ -277,14 +506,16 @@ def __pyopenjtalk_g2p_prosody( phones.append(p3) # accent type and position info (forward or backward) - a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr) - a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr) - a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr) + a1 = _numeric_feature_by_regex(__PYOPENJTALK_G2P_PROSODY_A1_PATTERN, lab_curr) + a2 = _numeric_feature_by_regex(__PYOPENJTALK_G2P_PROSODY_A2_PATTERN, lab_curr) + a3 = _numeric_feature_by_regex(__PYOPENJTALK_G2P_PROSODY_A3_PATTERN, lab_curr) # number of mora in accent phrase - f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr) + f1 = _numeric_feature_by_regex(__PYOPENJTALK_G2P_PROSODY_F1_PATTERN, lab_curr) - a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1]) + a2_next = _numeric_feature_by_regex( + __PYOPENJTALK_G2P_PROSODY_A2_PATTERN, labels[n + 1] + ) # accent phrase border if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl": phones.append("#") @@ -341,9 +572,6 @@ def __handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]: list[list[str]]: 長音記号を処理した音素のリストのリスト """ - # 母音の集合 (便宜上「ん」を含める) - VOWELS = {"a", "i", "u", "e", "o", "N"} - for i in range(len(sep_phonemes)): if len(sep_phonemes[i]) == 0: # 空白文字等でリストが空の場合 @@ -369,6 +597,15 @@ def __handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]: return sep_phonemes +__KATAKANA_PATTERN = re.compile(r"[\u30A0-\u30FF]+") +__MORA_PATTERN = re.compile( + "|".join( + map(re.escape, sorted(MORA_KATA_TO_MORA_PHONEMES.keys(), key=len, reverse=True)) + ) +) +__LONG_PATTERN = re.compile(r"(\w)(ー*)") + + def __kata_to_phoneme_list(text: str) -> list[str]: """ 原則カタカナの `text` を受け取り、それをそのままいじらずに音素記号のリストに変換。 @@ -391,23 +628,20 @@ def __kata_to_phoneme_list(text: str) -> list[str]: if set(text).issubset(set(PUNCTUATIONS)): return list(text) # `text` がカタカナ(`ー`含む)のみからなるかどうかをチェック - if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None: + if __KATAKANA_PATTERN.fullmatch(text) is None: raise ValueError(f"Input must be katakana only: {text}") - sorted_keys = sorted(MORA_KATA_TO_MORA_PHONEMES.keys(), key=len, reverse=True) - pattern = "|".join(map(re.escape, sorted_keys)) def mora2phonemes(mora: str) -> str: - cosonant, vowel = MORA_KATA_TO_MORA_PHONEMES[mora] - if cosonant is None: + consonant, vowel = MORA_KATA_TO_MORA_PHONEMES[mora] + if consonant is None: return f" {vowel}" - return f" {cosonant} {vowel}" + return f" {consonant} {vowel}" - spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text) + spaced_phonemes = __MORA_PATTERN.sub(lambda m: mora2phonemes(m.group()), text) # 長音記号「ー」の処理 - long_pattern = r"(\w)(ー*)" long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2)) # type: ignore - spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes) + spaced_phonemes = __LONG_PATTERN.sub(long_replacement, spaced_phonemes) return spaced_phonemes.strip().split(" ") @@ -485,5 +719,3 @@ class YomiError(Exception): 基本的に「学習の前処理のテキスト処理時」には発生させ、そうでない場合は、 ignore_yomi_error=True にしておいて、この例外を発生させないようにする。 """ - - pass diff --git a/style_bert_vits2/nlp/japanese/g2p_utils.py b/style_bert_vits2/nlp/japanese/g2p_utils.py index 511793f34f18783b55d210f97b39d7f23ca30cc8..ce0b049a45075edefea0e1826834dc5abd785b08 100644 --- a/style_bert_vits2/nlp/japanese/g2p_utils.py +++ b/style_bert_vits2/nlp/japanese/g2p_utils.py @@ -1,5 +1,6 @@ from style_bert_vits2.nlp.japanese.g2p import g2p from style_bert_vits2.nlp.japanese.mora_list import ( + CONSONANTS, MORA_KATA_TO_MORA_PHONEMES, MORA_PHONEMES_TO_MORA_KATA, ) @@ -33,15 +34,6 @@ def phone_tone2kata_tone(phone_tone: list[tuple[str, int]]) -> list[tuple[str, i カタカナと音高のリスト。 """ - # 子音の集合 - CONSONANTS = set( - [ - consonant - for consonant, _ in MORA_KATA_TO_MORA_PHONEMES.values() - if consonant is not None - ] - ) - phone_tone = phone_tone[1:] # 最初の("_", 0)を無視 phones = [phone for phone, _ in phone_tone] tones = [tone for _, tone in phone_tone] diff --git a/style_bert_vits2/nlp/japanese/mora_list.py b/style_bert_vits2/nlp/japanese/mora_list.py index a0dab2fc7a6b60a000ff4a3f64d3d8fba6d0b131..db69c3495af87db2acededa4a15c193dfa3439e2 100644 --- a/style_bert_vits2/nlp/japanese/mora_list.py +++ b/style_bert_vits2/nlp/japanese/mora_list.py @@ -234,3 +234,15 @@ MORA_KATA_TO_MORA_PHONEMES: dict[str, tuple[Optional[str], str]] = { kana: (consonant, vowel) for [kana, consonant, vowel] in __MORA_LIST_MINIMUM + __MORA_LIST_ADDITIONAL } + +# 子音の集合 +CONSONANTS = set( + [ + consonant + for consonant, _ in MORA_KATA_TO_MORA_PHONEMES.values() + if consonant is not None + ] +) + +# 母音の集合 (便宜上「ん」を含める) +VOWELS = {"a", "i", "u", "e", "o", "N"} diff --git a/style_bert_vits2/nlp/japanese/normalizer.py b/style_bert_vits2/nlp/japanese/normalizer.py index 5ceb2f8c16e7f06123b6433046f29c08f1d0d7f1..7ecbfb6f0b36ff1620ede243d25112ce0dda6424 100644 --- a/style_bert_vits2/nlp/japanese/normalizer.py +++ b/style_bert_vits2/nlp/japanese/normalizer.py @@ -6,6 +6,81 @@ from num2words import num2words from style_bert_vits2.nlp.symbols import PUNCTUATIONS +# 記号類の正規化マップ +__REPLACE_MAP = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + ".": ".", + "…": "...", + "···": "...", + "・・・": "...", + "·": ",", + "・": ",", + "、": ",", + "$": ".", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + # NFKC 正規化後のハイフン・ダッシュの変種を全て通常半角ハイフン - \u002d に変換 + "\u02d7": "\u002d", # ˗, Modifier Letter Minus Sign + "\u2010": "\u002d", # ‐, Hyphen, + # "\u2011": "\u002d", # ‑, Non-Breaking Hyphen, NFKC により \u2010 に変換される + "\u2012": "\u002d", # ‒, Figure Dash + "\u2013": "\u002d", # –, En Dash + "\u2014": "\u002d", # —, Em Dash + "\u2015": "\u002d", # ―, Horizontal Bar + "\u2043": "\u002d", # ⁃, Hyphen Bullet + "\u2212": "\u002d", # −, Minus Sign + "\u23af": "\u002d", # ⎯, Horizontal Line Extension + "\u23e4": "\u002d", # ⏤, Straightness + "\u2500": "\u002d", # ─, Box Drawings Light Horizontal + "\u2501": "\u002d", # ━, Box Drawings Heavy Horizontal + "\u2e3a": "\u002d", # ⸺, Two-Em Dash + "\u2e3b": "\u002d", # ⸻, Three-Em Dash + # "~": "-", # これは長音記号「ー」として扱うよう変更 + # "~": "-", # これも長音記号「ー」として扱うよう変更 + "「": "'", + "」": "'", +} +# 記号類の正規化パターン +__REPLACE_PATTERN = re.compile("|".join(re.escape(p) for p in __REPLACE_MAP)) +# 句読点等の正規化パターン +__PUNCTUATION_CLEANUP_PATTERN = re.compile( + # ↓ ひらがな、カタカナ、漢字 + r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005" + # ↓ 半角アルファベット(大文字と小文字) + + r"\u0041-\u005A\u0061-\u007A" + # ↓ 全角アルファベット(大文字と小文字) + + r"\uFF21-\uFF3A\uFF41-\uFF5A" + # ↓ ギリシャ文字 + + r"\u0370-\u03FF\u1F00-\u1FFF" + # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている + + "".join(PUNCTUATIONS) + r"]+", # fmt: skip +) +# 数字・通貨記号の正規化パターン +__CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} +__CURRENCY_PATTERN = re.compile(r"([$¥£€])([0-9.]*[0-9])") +__NUMBER_PATTERN = re.compile(r"[0-9]+(\.[0-9]+)?") +__NUMBER_WITH_SEPARATOR_PATTERN = re.compile("[0-9]{1,3}(,[0-9]{3})+") + + def normalize_text(text: str) -> str: """ 日本語のテキストを正規化する。 @@ -62,80 +137,11 @@ def replace_punctuation(text: str) -> str: str: 正規化されたテキスト """ - # 記号類の正規化変換マップ - REPLACE_MAP = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - ".": ".", - "…": "...", - "···": "...", - "・・・": "...", - "·": ",", - "・": ",", - "、": ",", - "$": ".", - "“": "'", - "”": "'", - '"': "'", - "‘": "'", - "’": "'", - "(": "'", - ")": "'", - "(": "'", - ")": "'", - "《": "'", - "》": "'", - "【": "'", - "】": "'", - "[": "'", - "]": "'", - # NFKC 正規化後のハイフン・ダッシュの変種を全て通常半角ハイフン - \u002d に変換 - "\u02d7": "\u002d", # ˗, Modifier Letter Minus Sign - "\u2010": "\u002d", # ‐, Hyphen, - # "\u2011": "\u002d", # ‑, Non-Breaking Hyphen, NFKC により \u2010 に変換される - "\u2012": "\u002d", # ‒, Figure Dash - "\u2013": "\u002d", # –, En Dash - "\u2014": "\u002d", # —, Em Dash - "\u2015": "\u002d", # ―, Horizontal Bar - "\u2043": "\u002d", # ⁃, Hyphen Bullet - "\u2212": "\u002d", # −, Minus Sign - "\u23af": "\u002d", # ⎯, Horizontal Line Extension - "\u23e4": "\u002d", # ⏤, Straightness - "\u2500": "\u002d", # ─, Box Drawings Light Horizontal - "\u2501": "\u002d", # ━, Box Drawings Heavy Horizontal - "\u2e3a": "\u002d", # ⸺, Two-Em Dash - "\u2e3b": "\u002d", # ⸻, Three-Em Dash - # "~": "-", # これは長音記号「ー」として扱うよう変更 - # "~": "-", # これも長音記号「ー」として扱うよう変更 - "「": "'", - "」": "'", - } - - pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP.keys())) - # 句読点を辞書で置換 - replaced_text = pattern.sub(lambda x: REPLACE_MAP[x.group()], text) - - replaced_text = re.sub( - # ↓ ひらがな、カタカナ、漢字 - r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005" - # ↓ 半角アルファベット(大文字と小文字) - + r"\u0041-\u005A\u0061-\u007A" - # ↓ 全角アルファベット(大文字と小文字) - + r"\uFF21-\uFF3A\uFF41-\uFF5A" - # ↓ ギリシャ文字 - + r"\u0370-\u03FF\u1F00-\u1FFF" - # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている - + "".join(PUNCTUATIONS) + r"]+", - # 上述以外の文字を削除 - "", - replaced_text, - ) + replaced_text = __REPLACE_PATTERN.sub(lambda x: __REPLACE_MAP[x.group()], text) + + # 上述以外の文字を削除 + replaced_text = __PUNCTUATION_CLEANUP_PATTERN.sub("", replaced_text) return replaced_text @@ -151,13 +157,8 @@ def __convert_numbers_to_words(text: str) -> str: str: 変換されたテキスト """ - NUMBER_WITH_SEPARATOR_PATTERN = re.compile("[0-9]{1,3}(,[0-9]{3})+") - CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} - CURRENCY_PATTERN = re.compile(r"([$¥£€])([0-9.]*[0-9])") - NUMBER_PATTERN = re.compile(r"[0-9]+(\.[0-9]+)?") - - res = NUMBER_WITH_SEPARATOR_PATTERN.sub(lambda m: m[0].replace(",", ""), text) - res = CURRENCY_PATTERN.sub(lambda m: m[2] + CURRENCY_MAP.get(m[1], m[1]), res) - res = NUMBER_PATTERN.sub(lambda m: num2words(m[0], lang="ja"), res) + res = __NUMBER_WITH_SEPARATOR_PATTERN.sub(lambda m: m[0].replace(",", ""), text) + res = __CURRENCY_PATTERN.sub(lambda m: m[2] + __CURRENCY_MAP.get(m[1], m[1]), res) + res = __NUMBER_PATTERN.sub(lambda m: num2words(m[0], lang="ja"), res) return res diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py index 3a146b6716c578a70444484e7b3e97f77c769b10..d86646787a550aa8785b3554342be8483fb1924a 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py @@ -88,7 +88,7 @@ def initialize_worker(port: int = WORKER_PORT) -> None: client = None try: client = WorkerClient(port) - except (socket.timeout, socket.error): + except (OSError, socket.timeout): logger.debug("try starting pyopenjtalk worker server") import os import subprocess @@ -120,7 +120,7 @@ def initialize_worker(port: int = WORKER_PORT) -> None: try: client = WorkerClient(port) break - except socket.error: + except OSError: time.sleep(0.5) count += 1 # 20: max number of retries diff --git a/style_bert_vits2/nlp/japanese/user_dict/word_model.py b/style_bert_vits2/nlp/japanese/user_dict/word_model.py index c85a5b9543280d2d717799bccdf404df02184d62..43420da7103811f000d79d40df969b59d86eca8e 100644 --- a/style_bert_vits2/nlp/japanese/user_dict/word_model.py +++ b/style_bert_vits2/nlp/japanese/user_dict/word_model.py @@ -114,7 +114,7 @@ class PartOfSpeechDetail(BaseModel): part_of_speech_detail_2: str = Field(title="品詞細分類2") part_of_speech_detail_3: str = Field(title="品詞細分類3") # context_idは辞書の左・右文脈IDのこと - # https://github.com/VOICEVOX/open_jtalk/blob/427cfd761b78efb6094bea3c5bb8c968f0d711ab/src/mecab-naist-jdic/_left-id.def # noqa + # https://github.com/VOICEVOX/open_jtalk/blob/427cfd761b78efb6094bea3c5bb8c968f0d711ab/src/mecab-naist-jdic/_left-id.def context_id: int = Field(title="文脈ID") cost_candidates: List[int] = Field(title="コストのパーセンタイル") accent_associative_rules: List[str] = Field(title="アクセント結合規則の一覧") diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index 12803086ff082abc2b4a5e667ace720a2efcdd7a..6df8394aef662cc545142b07c2ee94dc24dd440b 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -1,12 +1,8 @@ -import warnings from pathlib import Path from typing import Any, Optional, Union -import gradio as gr import numpy as np -import pyannote.audio import torch -from gradio.processing_utils import convert_to_16_bit_wav from numpy.typing import NDArray from pydantic import BaseModel @@ -29,10 +25,16 @@ from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import ( SynthesizerTrn as SynthesizerTrnJPExtra, ) -from style_bert_vits2.nlp import bert_models from style_bert_vits2.voice import adjust_voice +# Gradio の import は重いため、ここでは型チェック時のみ import する +# ライブラリとしての利用を考慮し、TTSModelHolder の _for_gradio() 系メソッド以外では Gradio に依存しないようにする +# _for_gradio() 系メソッドの戻り値の型アノテーションを文字列としているのは、Gradio なしで実行できるようにするため +# if TYPE_CHECKING: +# import gradio as gr + + class TTSModel: """ Style-Bert-Vits2 の音声合成モデルを操作するクラス。 @@ -40,7 +42,11 @@ class TTSModel: """ def __init__( - self, model_path: Path, config_path: Path, style_vec_path: Path, device: str + self, + model_path: Path, + config_path: Union[Path, HyperParameters], + style_vec_path: Union[Path, NDArray[Any]], + device: str, ) -> None: """ Style-Bert-Vits2 の音声合成モデルを初期化する。 @@ -48,18 +54,34 @@ class TTSModel: Args: model_path (Path): モデル (.safetensors) のパス - config_path (Path): ハイパーパラメータ (config.json) のパス - style_vec_path (Path): スタイルベクトル (style_vectors.npy) のパス + config_path (Union[Path, HyperParameters]): ハイパーパラメータ (config.json) のパス (直接 HyperParameters を指定することも可能) + style_vec_path (Union[Path, NDArray[Any]]): スタイルベクトル (style_vectors.npy) のパス (直接 NDArray を指定することも可能) device (str): 音声合成時に利用するデバイス (cpu, cuda, mps など) """ self.model_path: Path = model_path - self.config_path: Path = config_path - self.style_vec_path: Path = style_vec_path self.device: str = device - self.hyper_parameters: HyperParameters = HyperParameters.load_from_json( - self.config_path - ) + + # ハイパーパラメータの Pydantic モデルが直接指定された + if isinstance(config_path, HyperParameters): + self.config_path: Path = Path("") # 互換性のため空の Path を設定 + self.hyper_parameters: HyperParameters = config_path + # ハイパーパラメータのパスが指定された + else: + self.config_path: Path = config_path + self.hyper_parameters: HyperParameters = HyperParameters.load_from_json( + self.config_path + ) + + # スタイルベクトルの NDArray が直接指定された + if isinstance(style_vec_path, np.ndarray): + self.style_vec_path: Path = Path("") # 互換性のため空の Path を設定 + self.__style_vectors: NDArray[Any] = style_vec_path + # スタイルベクトルのパスが指定された + else: + self.style_vec_path: Path = style_vec_path + self.__style_vectors: NDArray[Any] = np.load(self.style_vec_path) + self.spk2id: dict[str, int] = self.hyper_parameters.data.spk2id self.id2spk: dict[int, str] = {v: k for k, v in self.spk2id.items()} @@ -73,12 +95,11 @@ class TTSModel: f"Number of styles ({num_styles}) does not match the number of style2id ({len(self.style2id)})" ) - self.__style_vector_inference: Optional[pyannote.audio.Inference] = None - self.__style_vectors: NDArray[Any] = np.load(self.style_vec_path) if self.__style_vectors.shape[0] != num_styles: raise ValueError( f"The number of styles ({num_styles}) does not match the number of style vectors ({self.__style_vectors.shape[0]})" ) + self.__style_vector_inference: Optional[Any] = None self.__net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None @@ -122,8 +143,18 @@ class TTSModel: NDArray[Any]: スタイルベクトル """ - # スタイルベクトルを取得するための推論モデルを初期化 if self.__style_vector_inference is None: + + # pyannote.audio は scikit-learn などの大量の重量級ライブラリに依存しているため、 + # TTSModel.infer() に reference_audio_path を指定し音声からスタイルベクトルを推論する場合のみ遅延 import する + try: + import pyannote.audio + except ImportError: + raise ImportError( + "pyannote.audio is required to infer style vector from audio" + ) + + # スタイルベクトルを取得するための推論モデルを初期化 self.__style_vector_inference = pyannote.audio.Inference( model=pyannote.audio.Model.from_pretrained( "pyannote/wespeaker-voxceleb-resnet34-LM" @@ -138,6 +169,43 @@ class TTSModel: xvec = mean + (xvec - mean) * weight return xvec + def __convert_to_16_bit_wav(self, data: NDArray[Any]) -> NDArray[Any]: + """ + 音声データを 16-bit int 形式に変換する。 + gradio.processing_utils.convert_to_16_bit_wav() を移植したもの。 + + Args: + data (NDArray[Any]): 音声データ + + Returns: + NDArray[Any]: 16-bit int 形式の音声データ + """ + # Based on: https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.write.html + if data.dtype in [np.float64, np.float32, np.float16]: # type: ignore + data = data / np.abs(data).max() + data = data * 32767 + data = data.astype(np.int16) + elif data.dtype == np.int32: + data = data / 65536 + data = data.astype(np.int16) + elif data.dtype == np.int16: + pass + elif data.dtype == np.uint16: + data = data - 32768 + data = data.astype(np.int16) + elif data.dtype == np.uint8: + data = data * 257 - 32768 + data = data.astype(np.int16) + elif data.dtype == np.int8: + data = data * 256 + data = data.astype(np.int16) + else: + raise ValueError( + "Audio data cannot be converted automatically from " + f"{data.dtype} to 16-bit int format." + ) + return data + def infer( self, text: str, @@ -155,6 +223,7 @@ class TTSModel: use_assist_text: bool = False, style: str = DEFAULT_STYLE, style_weight: float = DEFAULT_STYLE_WEIGHT, + given_phone: Optional[list[str]] = None, given_tone: Optional[list[int]] = None, pitch_scale: float = 1.0, intonation_scale: float = 1.0, @@ -171,13 +240,14 @@ class TTSModel: noise (float, optional): DP に与えられるノイズ. Defaults to DEFAULT_NOISE. noise_w (float, optional): SDP に与えられるノイズ. Defaults to DEFAULT_NOISEW. length (float, optional): 生成音声の長さ(話速)のパラメータ。大きいほど生成音声が長くゆっくり、小さいほど短く早くなる。 Defaults to DEFAULT_LENGTH. - line_split (bool, optional): テキストを改行ごとに分割して生成するかどうか. Defaults to DEFAULT_LINE_SPLIT. + line_split (bool, optional): テキストを改行ごとに分割して生成するかどうか (True の場合 given_phone/given_tone は無視される). Defaults to DEFAULT_LINE_SPLIT. split_interval (float, optional): 改行ごとに分割する場合の無音 (秒). Defaults to DEFAULT_SPLIT_INTERVAL. assist_text (Optional[str], optional): 感情表現の参照元の補助テキスト. Defaults to None. assist_text_weight (float, optional): 感情表現の補助テキストを適用する強さ. Defaults to DEFAULT_ASSIST_TEXT_WEIGHT. use_assist_text (bool, optional): 音声合成時に感情表現の補助テキストを使用するかどうか. Defaults to False. style (str, optional): 音声スタイル (Neutral, Happy など). Defaults to DEFAULT_STYLE. style_weight (float, optional): 音声スタイルを適用する強さ. Defaults to DEFAULT_STYLE_WEIGHT. + given_phone (Optional[list[int]], optional): 読み上げテキストの読みを表す音素列。指定する場合は given_tone も別途指定が必要. Defaults to None. given_tone (Optional[list[int]], optional): アクセントのトーンのリスト. Defaults to None. pitch_scale (float, optional): ピッチの高さ (1.0 から変更すると若干音質が低下する). Defaults to 1.0. intonation_scale (float, optional): 抑揚の平均からの変化幅 (1.0 から変更すると若干音質が低下する). Defaults to 1.0. @@ -222,6 +292,7 @@ class TTSModel: assist_text=assist_text, assist_text_weight=assist_text_weight, style_vec=style_vector, + given_phone=given_phone, given_tone=given_tone, ) else: @@ -258,9 +329,7 @@ class TTSModel: pitch_scale=pitch_scale, intonation_scale=intonation_scale, ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - audio = convert_to_16_bit_wav(audio) + audio = self.__convert_to_16_bit_wav(audio) return (self.hyper_parameters.data.sampling_rate, audio) @@ -377,9 +446,9 @@ class TTSModelHolder: return self.current_model - def get_model_for_gradio( - self, model_name: str, model_path_str: str - ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: + def get_model_for_gradio(self, model_name: str, model_path_str: str): + import gradio as gr + model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") @@ -411,16 +480,22 @@ class TTSModelHolder: gr.Dropdown(choices=speakers, value=speakers[0]), # type: ignore ) - def update_model_files_for_gradio(self, model_name: str) -> gr.Dropdown: - model_files = self.model_files_dict[model_name] + def update_model_files_for_gradio(self, model_name: str): + import gradio as gr + + model_files = [str(f) for f in self.model_files_dict[model_name]] return gr.Dropdown(choices=model_files, value=model_files[0]) # type: ignore def update_model_names_for_gradio( self, - ) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]: + ): + import gradio as gr + self.refresh() initial_model_name = self.model_names[0] - initial_model_files = self.model_files_dict[initial_model_name] + initial_model_files = [ + str(f) for f in self.model_files_dict[initial_model_name] + ] return ( gr.Dropdown(choices=self.model_names, value=initial_model_name), # type: ignore gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]), # type: ignore diff --git a/style_bert_vits2/utils/subprocess.py b/style_bert_vits2/utils/subprocess.py index f8e8e94547a716267df25eda44f46d518637ade5..8f159a20fc2885853902f2bdc30a0e55aae32d5b 100644 --- a/style_bert_vits2/utils/subprocess.py +++ b/style_bert_vits2/utils/subprocess.py @@ -27,6 +27,7 @@ def run_script_with_log( stderr=subprocess.PIPE, text=True, encoding="utf-8", + check=False, ) if result.returncode != 0: logger.error(f"Error: {' '.join(cmd)}\n{result.stderr}") diff --git a/style_gen.py b/style_gen.py index a21ab2b33e8c87e3e012c31b69e2ed0a7fd5f736..1ced5f0f809705c0bdee7d5ea293f1e9d0caafaa 100644 --- a/style_gen.py +++ b/style_gen.py @@ -8,12 +8,14 @@ from numpy.typing import NDArray from pyannote.audio import Inference, Model from tqdm import tqdm -from config import config +from config import get_config from style_bert_vits2.logging import logger from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT +config = get_config() + model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM") inference = Inference(model, window="whole") device = torch.device(config.style_gen_config.device) @@ -23,8 +25,6 @@ inference.to(device) class NaNValueError(ValueError): """カスタム例外クラス。NaN値が見つかった場合に使用されます。""" - pass - # 推論時にインポートするために短いが関数を書く def get_style_vector(wav_path: str) -> NDArray[Any]: @@ -72,7 +72,7 @@ if __name__ == "__main__": device = config.style_gen_config.device training_lines: list[str] = [] - with open(hps.data.training_files, "r", encoding="utf-8") as f: + with open(hps.data.training_files, encoding="utf-8") as f: training_lines.extend(f.readlines()) with ThreadPoolExecutor(max_workers=num_processes) as executor: training_results = list( @@ -93,7 +93,7 @@ if __name__ == "__main__": ) val_lines: list[str] = [] - with open(hps.data.validation_files, "r", encoding="utf-8") as f: + with open(hps.data.validation_files, encoding="utf-8") as f: val_lines.extend(f.readlines()) with ThreadPoolExecutor(max_workers=num_processes) as executor: diff --git a/train_ms.py b/train_ms.py index 067a6a571257b7331d781bec561c29f493aed914..1adee4e7767c234d356ee8ad8deb6b231c42b801 100644 --- a/train_ms.py +++ b/train_ms.py @@ -16,7 +16,7 @@ from tqdm import tqdm # logging.getLogger("numba").setLevel(logging.WARNING) import default_style -from config import config +from config import get_config from data_utils import ( DistributedBucketSampler, TextAudioSpeakerCollate, @@ -48,7 +48,7 @@ torch.backends.cuda.enable_mem_efficient_sdp( ) # Not available if torch version is lower than 2.0 torch.backends.cuda.enable_math_sdp(True) - +config = get_config() global_step = 0 api = HfApi() @@ -97,6 +97,11 @@ def run(): help="Huggingface model repo id to backup the model.", default=None, ) + parser.add_argument( + "--not_use_custom_batch_sampler", + help="Don't use custom batch sampler for training, which was used in the version < 2.5", + action="store_true", + ) args = parser.parse_args() # Set log file @@ -108,7 +113,7 @@ def run(): envs = config.train_ms_config.env for env_name, env_value in envs.items(): if env_name not in os.environ.keys(): - logger.info("Loading configuration from config {}".format(str(env_value))) + logger.info(f"Loading configuration from config {env_value!s}") os.environ[env_name] = str(env_value) logger.info( "Loading environment variables \nMASTER_ADDR: {},\nMASTER_PORT: {},\nWORLD_SIZE: {},\nRANK: {},\nLOCAL_RANK: {}".format( @@ -142,7 +147,7 @@ def run(): if os.path.realpath(args.config) != os.path.realpath( config.train_ms_config.config_path ): - with open(args.config, "r", encoding="utf-8") as f: + with open(args.config, encoding="utf-8") as f: data = f.read() os.makedirs(os.path.dirname(config.train_ms_config.config_path), exist_ok=True) with open(config.train_ms_config.config_path, "w", encoding="utf-8") as f: @@ -192,13 +197,11 @@ def run(): os.makedirs(config.out_dir, exist_ok=True) if not args.skip_default_style: - # Save default style to out_dir - default_style.set_style_config( - args.config, os.path.join(config.out_dir, "config.json") - ) - default_style.save_neutral_vector( + default_style.save_styles_by_dirs( os.path.join(args.model, "wavs"), - os.path.join(config.out_dir, "style_vectors.npy"), + config.out_dir, + config_path=args.config, + config_output_path=os.path.join(config.out_dir, "config.json"), ) torch.manual_seed(hps.train.seed) @@ -214,28 +217,45 @@ def run(): writer = SummaryWriter(log_dir=model_dir) writer_eval = SummaryWriter(log_dir=os.path.join(model_dir, "eval")) train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data) - train_sampler = DistributedBucketSampler( - train_dataset, - hps.train.batch_size, - [32, 300, 400, 500, 600, 700, 800, 900, 1000], - num_replicas=n_gpus, - rank=rank, - shuffle=True, - ) collate_fn = TextAudioSpeakerCollate() - train_loader = DataLoader( - train_dataset, - # メモリ消費量を減らそうとnum_workersを1にしてみる - # num_workers=min(config.train_ms_config.num_workers, os.cpu_count() // 2), - num_workers=1, - shuffle=False, - pin_memory=True, - collate_fn=collate_fn, - batch_sampler=train_sampler, - persistent_workers=True, - # これもメモリ消費量を減らそうとしてコメントアウト - # prefetch_factor=4, - ) # DataLoader config could be adjusted. + if not args.not_use_custom_batch_sampler: + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [32, 300, 400, 500, 600, 700, 800, 900, 1000], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + train_loader = DataLoader( + train_dataset, + # メモリ消費量を減らそうとnum_workersを1にしてみる + # num_workers=min(config.train_ms_config.num_workers, os.cpu_count() // 2), + num_workers=1, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + # batch_size=hps.train.batch_size, + persistent_workers=True, + # これもメモリ消費量を減らそうとしてコメントアウト + # prefetch_factor=6, + ) + else: + train_loader = DataLoader( + train_dataset, + # メモリ消費量を減らそうとnum_workersを1にしてみる + # num_workers=min(config.train_ms_config.num_workers, os.cpu_count() // 2), + num_workers=1, + shuffle=True, + pin_memory=True, + collate_fn=collate_fn, + # batch_sampler=train_sampler, + batch_size=hps.train.batch_size, + persistent_workers=True, + # これもメモリ消費量を減らそうとしてコメントアウト + # prefetch_factor=6, + ) eval_dataset = None eval_loader = None if rank == 0 and not args.speedup: @@ -505,7 +525,7 @@ def run(): optim_g, hps.train.learning_rate, epoch, - os.path.join(model_dir, "G_{}.pth".format(global_step)), + os.path.join(model_dir, f"G_{global_step}.pth"), ) assert optim_d is not None utils.checkpoints.save_checkpoint( @@ -513,7 +533,7 @@ def run(): optim_d, hps.train.learning_rate, epoch, - os.path.join(model_dir, "D_{}.pth".format(global_step)), + os.path.join(model_dir, f"D_{global_step}.pth"), ) if net_dur_disc is not None: assert optim_dur_disc is not None @@ -522,7 +542,7 @@ def run(): optim_dur_disc, hps.train.learning_rate, epoch, - os.path.join(model_dir, "DUR_{}.pth".format(global_step)), + os.path.join(model_dir, f"DUR_{global_step}.pth"), ) utils.safetensors.save_safetensors( net_g, @@ -757,34 +777,32 @@ def train_and_evaluate( "loss/g/kl": loss_kl, } ) + scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)}) scalar_dict.update( - {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)} + {f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)} ) scalar_dict.update( - {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)} + {f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)} ) - scalar_dict.update( - {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)} - ) - - image_dict = { - "slice/mel_org": utils.plot_spectrogram_to_numpy( - y_mel[0].data.cpu().numpy() - ), - "slice/mel_gen": utils.plot_spectrogram_to_numpy( - y_hat_mel[0].data.cpu().numpy() - ), - "all/mel": utils.plot_spectrogram_to_numpy( - mel[0].data.cpu().numpy() - ), - "all/attn": utils.plot_alignment_to_numpy( - attn[0, 0].data.cpu().numpy() - ), - } + # 以降のログは計算が重い気がするし誰も見てない気がするのでコメントアウト + # image_dict = { + # "slice/mel_org": utils.plot_spectrogram_to_numpy( + # y_mel[0].data.cpu().numpy() + # ), + # "slice/mel_gen": utils.plot_spectrogram_to_numpy( + # y_hat_mel[0].data.cpu().numpy() + # ), + # "all/mel": utils.plot_spectrogram_to_numpy( + # mel[0].data.cpu().numpy() + # ), + # "all/attn": utils.plot_alignment_to_numpy( + # attn[0, 0].data.cpu().numpy() + # ), + # } utils.summarize( writer=writer, global_step=global_step, - images=image_dict, + # images=image_dict, scalars=scalar_dict, ) @@ -801,14 +819,14 @@ def train_and_evaluate( optim_g, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), + os.path.join(hps.model_dir, f"G_{global_step}.pth"), ) utils.checkpoints.save_checkpoint( net_d, optim_d, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), + os.path.join(hps.model_dir, f"D_{global_step}.pth"), ) if net_dur_disc is not None: utils.checkpoints.save_checkpoint( @@ -816,7 +834,7 @@ def train_and_evaluate( optim_dur_disc, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)), + os.path.join(hps.model_dir, f"DUR_{global_step}.pth"), ) keep_ckpts = config.train_ms_config.keep_ckpts if keep_ckpts > 0: @@ -853,9 +871,7 @@ def train_and_evaluate( global_step += 1 if pbar is not None: pbar.set_description( - "Epoch {}({:.0f}%)/{}".format( - epoch, 100.0 * batch_idx / len(train_loader), hps.train.epochs - ) + f"Epoch {epoch}({100.0 * batch_idx / len(train_loader):.0f}%)/{hps.train.epochs}" ) pbar.update() # 本家ではこれをスピードアップのために消すと書かれていたので、一応消してみる @@ -870,6 +886,7 @@ def evaluate(hps, generator, eval_loader, writer_eval): generator.eval() image_dict = {} audio_dict = {} + print() logger.info("Evaluating ...") with torch.no_grad(): for batch_idx, ( @@ -913,32 +930,39 @@ def evaluate(hps, generator, eval_loader, writer_eval): sdp_ratio=0.0 if not use_sdp else 1.0, ) y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length - - mel = spec_to_mel_torch( - spec, - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.mel_fmin, - hps.data.mel_fmax, - ) - y_hat_mel = mel_spectrogram_torch( - y_hat.squeeze(1).float(), - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, - hps.data.mel_fmin, - hps.data.mel_fmax, - ) - image_dict.update( - { - f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( - y_hat_mel[0].cpu().numpy() - ) - } - ) + # 以降のログは計算が重い気がするし誰も見てない気がするのでコメントアウト + # mel = spec_to_mel_torch( + # spec, + # hps.data.filter_length, + # hps.data.n_mel_channels, + # hps.data.sampling_rate, + # hps.data.mel_fmin, + # hps.data.mel_fmax, + # ) + # y_hat_mel = mel_spectrogram_torch( + # y_hat.squeeze(1).float(), + # hps.data.filter_length, + # hps.data.n_mel_channels, + # hps.data.sampling_rate, + # hps.data.hop_length, + # hps.data.win_length, + # hps.data.mel_fmin, + # hps.data.mel_fmax, + # ) + # image_dict.update( + # { + # f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( + # y_hat_mel[0].cpu().numpy() + # ) + # } + # ) + # image_dict.update( + # { + # f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( + # mel[0].cpu().numpy() + # ) + # } + # ) audio_dict.update( { f"gen/audio_{batch_idx}_{use_sdp}": y_hat[ @@ -946,13 +970,6 @@ def evaluate(hps, generator, eval_loader, writer_eval): ] } ) - image_dict.update( - { - f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( - mel[0].cpu().numpy() - ) - } - ) audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]}) utils.summarize( diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index ad3d3bd92e0682577bd43ff57121cfde4f12de97..e5c5bd198b030aad9b2f95336ce2e49472d8f9eb 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -16,7 +16,7 @@ from tqdm import tqdm # logging.getLogger("numba").setLevel(logging.WARNING) import default_style -from config import config +from config import get_config from data_utils import ( DistributedBucketSampler, TextAudioSpeakerCollate, @@ -48,6 +48,8 @@ torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp( True ) # Not available if torch version is lower than 2.0 + +config = get_config() global_step = 0 api = HfApi() @@ -96,6 +98,11 @@ def run(): help="Huggingface model repo id to backup the model.", default=None, ) + parser.add_argument( + "--not_use_custom_batch_sampler", + help="Don't use custom batch sampler for training, which was used in the version < 2.5", + action="store_true", + ) args = parser.parse_args() # Set log file @@ -107,7 +114,7 @@ def run(): envs = config.train_ms_config.env for env_name, env_value in envs.items(): if env_name not in os.environ.keys(): - logger.info("Loading configuration from config {}".format(str(env_value))) + logger.info(f"Loading configuration from config {env_value!s}") os.environ[env_name] = str(env_value) logger.info( "Loading environment variables \nMASTER_ADDR: {},\nMASTER_PORT: {},\nWORLD_SIZE: {},\nRANK: {},\nLOCAL_RANK: {}".format( @@ -141,7 +148,7 @@ def run(): if os.path.realpath(args.config) != os.path.realpath( config.train_ms_config.config_path ): - with open(args.config, "r", encoding="utf-8") as f: + with open(args.config, encoding="utf-8") as f: data = f.read() os.makedirs(os.path.dirname(config.train_ms_config.config_path), exist_ok=True) with open(config.train_ms_config.config_path, "w", encoding="utf-8") as f: @@ -191,13 +198,11 @@ def run(): os.makedirs(config.out_dir, exist_ok=True) if not args.skip_default_style: - # Save default style to out_dir - default_style.set_style_config( - args.config, os.path.join(config.out_dir, "config.json") - ) - default_style.save_neutral_vector( + default_style.save_styles_by_dirs( os.path.join(args.model, "wavs"), - os.path.join(config.out_dir, "style_vectors.npy"), + config.out_dir, + config_path=args.config, + config_output_path=os.path.join(config.out_dir, "config.json"), ) torch.manual_seed(hps.train.seed) @@ -213,28 +218,45 @@ def run(): writer = SummaryWriter(log_dir=model_dir) writer_eval = SummaryWriter(log_dir=os.path.join(model_dir, "eval")) train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data) - train_sampler = DistributedBucketSampler( - train_dataset, - hps.train.batch_size, - [32, 300, 400, 500, 600, 700, 800, 900, 1000], - num_replicas=n_gpus, - rank=rank, - shuffle=True, - ) collate_fn = TextAudioSpeakerCollate(use_jp_extra=True) - train_loader = DataLoader( - train_dataset, - # メモリ消費量を減らそうとnum_workersを1にしてみる - # num_workers=min(config.train_ms_config.num_workers, os.cpu_count() // 2), - num_workers=1, - shuffle=False, - pin_memory=True, - collate_fn=collate_fn, - batch_sampler=train_sampler, - persistent_workers=True, - # これもメモリ消費量を減らそうとしてコメントアウト - # prefetch_factor=6, - ) # DataLoader config could be adjusted. + if not args.not_use_custom_batch_sampler: + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [32, 300, 400, 500, 600, 700, 800, 900, 1000], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + train_loader = DataLoader( + train_dataset, + # メモリ消費量を減らそうとnum_workersを1にしてみる + # num_workers=min(config.train_ms_config.num_workers, os.cpu_count() // 2), + num_workers=1, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + # batch_size=hps.train.batch_size, + persistent_workers=True, + # これもメモリ消費量を減らそうとしてコメントアウト + # prefetch_factor=6, + ) + else: + train_loader = DataLoader( + train_dataset, + # メモリ消費量を減らそうとnum_workersを1にしてみる + # num_workers=min(config.train_ms_config.num_workers, os.cpu_count() // 2), + num_workers=1, + shuffle=True, + pin_memory=True, + collate_fn=collate_fn, + # batch_sampler=train_sampler, + batch_size=hps.train.batch_size, + persistent_workers=True, + # これもメモリ消費量を減らそうとしてコメントアウト + # prefetch_factor=6, + ) eval_dataset = None eval_loader = None if rank == 0 and not args.speedup: @@ -577,7 +599,7 @@ def run(): optim_g, hps.train.learning_rate, epoch, - os.path.join(model_dir, "G_{}.pth".format(global_step)), + os.path.join(model_dir, f"G_{global_step}.pth"), ) assert optim_d is not None utils.checkpoints.save_checkpoint( @@ -585,7 +607,7 @@ def run(): optim_d, hps.train.learning_rate, epoch, - os.path.join(model_dir, "D_{}.pth".format(global_step)), + os.path.join(model_dir, f"D_{global_step}.pth"), ) if net_dur_disc is not None: assert optim_dur_disc is not None @@ -594,7 +616,7 @@ def run(): optim_dur_disc, hps.train.learning_rate, epoch, - os.path.join(model_dir, "DUR_{}.pth".format(global_step)), + os.path.join(model_dir, f"DUR_{global_step}.pth"), ) if net_wd is not None: assert optim_wd is not None @@ -603,7 +625,7 @@ def run(): optim_wd, hps.train.learning_rate, epoch, - os.path.join(model_dir, "WD_{}.pth".format(global_step)), + os.path.join(model_dir, f"WD_{global_step}.pth"), ) utils.safetensors.save_safetensors( net_g, @@ -661,7 +683,7 @@ def train_and_evaluate( if writers is not None: writer, writer_eval = writers - train_loader.batch_sampler.set_epoch(epoch) + # train_loader.batch_sampler.set_epoch(epoch) global global_step net_g.train() @@ -867,14 +889,12 @@ def train_and_evaluate( "loss/g/kl": loss_kl, } ) + scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)}) scalar_dict.update( - {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)} + {f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)} ) scalar_dict.update( - {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)} - ) - scalar_dict.update( - {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)} + {f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)} ) if net_dur_disc is not None: @@ -882,23 +902,20 @@ def train_and_evaluate( scalar_dict.update( { - "loss/dur_disc_g/{}".format(i): v + f"loss/dur_disc_g/{i}": v for i, v in enumerate(losses_dur_disc_g) } ) scalar_dict.update( { - "loss/dur_disc_r/{}".format(i): v + f"loss/dur_disc_r/{i}": v for i, v in enumerate(losses_dur_disc_r) } ) scalar_dict.update({"loss/g/dur_gen": loss_dur_gen}) scalar_dict.update( - { - "loss/g/dur_gen_{}".format(i): v - for i, v in enumerate(losses_dur_gen) - } + {f"loss/g/dur_gen_{i}": v for i, v in enumerate(losses_dur_gen)} ) if net_wd is not None: @@ -910,24 +927,25 @@ def train_and_evaluate( "loss/g/lm_gen": loss_lm_gen, } ) - image_dict = { - "slice/mel_org": utils.plot_spectrogram_to_numpy( - y_mel[0].data.cpu().numpy() - ), - "slice/mel_gen": utils.plot_spectrogram_to_numpy( - y_hat_mel[0].data.cpu().numpy() - ), - "all/mel": utils.plot_spectrogram_to_numpy( - mel[0].data.cpu().numpy() - ), - "all/attn": utils.plot_alignment_to_numpy( - attn[0, 0].data.cpu().numpy() - ), - } + # 以降のログは計算が重い気がするし誰も見てない気がするのでコメントアウト + # image_dict = { + # "slice/mel_org": utils.plot_spectrogram_to_numpy( + # y_mel[0].data.cpu().numpy() + # ), + # "slice/mel_gen": utils.plot_spectrogram_to_numpy( + # y_hat_mel[0].data.cpu().numpy() + # ), + # "all/mel": utils.plot_spectrogram_to_numpy( + # mel[0].data.cpu().numpy() + # ), + # "all/attn": utils.plot_alignment_to_numpy( + # attn[0, 0].data.cpu().numpy() + # ), + # } utils.summarize( writer=writer, global_step=global_step, - images=image_dict, + # images=image_dict, scalars=scalar_dict, ) @@ -943,14 +961,14 @@ def train_and_evaluate( optim_g, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), + os.path.join(hps.model_dir, f"G_{global_step}.pth"), ) utils.checkpoints.save_checkpoint( net_d, optim_d, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), + os.path.join(hps.model_dir, f"D_{global_step}.pth"), ) if net_dur_disc is not None: utils.checkpoints.save_checkpoint( @@ -958,7 +976,7 @@ def train_and_evaluate( optim_dur_disc, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)), + os.path.join(hps.model_dir, f"DUR_{global_step}.pth"), ) if net_wd is not None: utils.checkpoints.save_checkpoint( @@ -966,7 +984,7 @@ def train_and_evaluate( optim_wd, hps.train.learning_rate, epoch, - os.path.join(hps.model_dir, "WD_{}.pth".format(global_step)), + os.path.join(hps.model_dir, f"WD_{global_step}.pth"), ) keep_ckpts = config.train_ms_config.keep_ckpts if keep_ckpts > 0: @@ -1004,9 +1022,7 @@ def train_and_evaluate( global_step += 1 if pbar is not None: pbar.set_description( - "Epoch {}({:.0f}%)/{}".format( - epoch, 100.0 * batch_idx / len(train_loader), hps.train.epochs - ) + f"Epoch {epoch}({100.0 * batch_idx / len(train_loader):.0f}%)/{hps.train.epochs}" ) pbar.update() @@ -1020,6 +1036,7 @@ def evaluate(hps, generator, eval_loader, writer_eval): generator.eval() image_dict = {} audio_dict = {} + print() logger.info("Evaluating ...") with torch.no_grad(): for batch_idx, ( @@ -1057,32 +1074,39 @@ def evaluate(hps, generator, eval_loader, writer_eval): sdp_ratio=0.0 if not use_sdp else 1.0, ) y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length - - mel = spec_to_mel_torch( - spec, - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.mel_fmin, - hps.data.mel_fmax, - ) - y_hat_mel = mel_spectrogram_torch( - y_hat.squeeze(1).float(), - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, - hps.data.mel_fmin, - hps.data.mel_fmax, - ) - image_dict.update( - { - f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( - y_hat_mel[0].cpu().numpy() - ) - } - ) + # 以降のログは計算が重い気がするし誰も見てない気がするのでコメントアウト + # mel = spec_to_mel_torch( + # spec, + # hps.data.filter_length, + # hps.data.n_mel_channels, + # hps.data.sampling_rate, + # hps.data.mel_fmin, + # hps.data.mel_fmax, + # ) + # y_hat_mel = mel_spectrogram_torch( + # y_hat.squeeze(1).float(), + # hps.data.filter_length, + # hps.data.n_mel_channels, + # hps.data.sampling_rate, + # hps.data.hop_length, + # hps.data.win_length, + # hps.data.mel_fmin, + # hps.data.mel_fmax, + # ) + # image_dict.update( + # { + # f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( + # y_hat_mel[0].cpu().numpy() + # ) + # } + # ) + # image_dict.update( + # { + # f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( + # mel[0].cpu().numpy() + # ) + # } + # ) audio_dict.update( { f"gen/audio_{batch_idx}_{use_sdp}": y_hat[ @@ -1090,13 +1114,6 @@ def evaluate(hps, generator, eval_loader, writer_eval): ] } ) - image_dict.update( - { - f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy( - mel[0].cpu().numpy() - ) - } - ) audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]}) utils.summarize( diff --git a/transcribe.py b/transcribe.py index 898fdc98e26f6c54eaa1d56980b71d6b57a04e51..27521261360dd25c9b99a5e4c42c6d0a23a81b11 100644 --- a/transcribe.py +++ b/transcribe.py @@ -1,13 +1,12 @@ import argparse -import os import sys from pathlib import Path from typing import Any, Optional -import yaml from torch.utils.data import Dataset from tqdm import tqdm +from config import get_path_config from style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT @@ -49,6 +48,7 @@ class StrListDataset(Dataset[str]): def transcribe_files_with_hf_whisper( audio_files: list[Path], model_id: str, + output_file: Path, initial_prompt: Optional[str] = None, language: str = "ja", batch_size: int = 16, @@ -69,13 +69,6 @@ def transcribe_files_with_hf_whisper( } logger.info(f"generate_kwargs: {generate_kwargs}") - if initial_prompt is not None: - prompt_ids: torch.Tensor = processor.get_prompt_ids( - initial_prompt, return_tensors="pt" - ) - prompt_ids = prompt_ids.to(device) - generate_kwargs["prompt_ids"] = prompt_ids - pipe = pipeline( model=model_id, max_new_tokens=128, @@ -83,17 +76,33 @@ def transcribe_files_with_hf_whisper( batch_size=batch_size, torch_dtype=torch.float16, device="cuda", - generate_kwargs=generate_kwargs, + trust_remote_code=True, + # generate_kwargs=generate_kwargs, ) + if initial_prompt is not None: + prompt_ids: torch.Tensor = pipe.tokenizer.get_prompt_ids( + initial_prompt, return_tensors="pt" + ).to(device) + generate_kwargs["prompt_ids"] = prompt_ids + dataset = StrListDataset([str(f) for f in audio_files]) results: list[str] = [] - for whisper_result in pipe(dataset): + for whisper_result, file in zip( + pipe(dataset, generate_kwargs=generate_kwargs), audio_files + ): text: str = whisper_result["text"] # なぜかテキストの最初に" {initial_prompt}"が入るので、文字の最初からこれを削除する # cf. https://github.com/huggingface/transformers/issues/27594 if text.startswith(f" {initial_prompt}"): text = text[len(f" {initial_prompt}") :] + # with open(output_file, "w", encoding="utf-8") as f: + # for wav_file, text in zip(wav_files, results): + # wav_rel_path = wav_file.relative_to(input_dir) + # f.write(f"{wav_rel_path}|{model_name}|{language_id}|{text}\n") + with open(output_file, "a", encoding="utf-8") as f: + wav_rel_path = file.relative_to(input_dir) + f.write(f"{wav_rel_path}|{model_name}|{language_id}|{text}\n") results.append(text) if pbar is not None: pbar.update(1) @@ -119,14 +128,14 @@ if __name__ == "__main__": parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--compute_type", type=str, default="bfloat16") parser.add_argument("--use_hf_whisper", action="store_true") + parser.add_argument("--hf_repo_id", type=str, default="") parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--num_beams", type=int, default=1) parser.add_argument("--no_repeat_ngram_size", type=int, default=10) args = parser.parse_args() - with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - dataset_root = Path(path_config["dataset_root"]) + path_config = get_path_config() + dataset_root = path_config.dataset_root model_name = str(args.model_name) @@ -144,7 +153,7 @@ if __name__ == "__main__": output_file.parent.mkdir(parents=True, exist_ok=True) wav_files = [f for f in input_dir.rglob("*.wav") if f.is_file()] - wav_files = sorted(wav_files, key=lambda x: x.name) + wav_files = sorted(wav_files, key=lambda x: str(x)) if output_file.exists(): logger.warning(f"{output_file} exists, backing up to {output_file}.bak") @@ -187,7 +196,10 @@ if __name__ == "__main__": with open(output_file, "a", encoding="utf-8") as f: f.write(f"{wav_rel_path}|{model_name}|{language_id}|{text}\n") else: - model_id = f"openai/whisper-{args.model}" + if args.hf_repo_id == "": + model_id = f"openai/whisper-{args.model}" + else: + model_id = args.hf_repo_id logger.info(f"Loading HF Whisper model ({model_id})") pbar = tqdm(total=len(wav_files), file=SAFE_STDOUT) results = transcribe_files_with_hf_whisper( @@ -200,10 +212,11 @@ if __name__ == "__main__": no_repeat_ngram_size=no_repeat_ngram_size, device=device, pbar=pbar, + output_file=output_file, ) - with open(output_file, "w", encoding="utf-8") as f: - for wav_file, text in zip(wav_files, results): - wav_rel_path = wav_file.relative_to(input_dir) - f.write(f"{wav_rel_path}|{model_name}|{language_id}|{text}\n") + # with open(output_file, "w", encoding="utf-8") as f: + # for wav_file, text in zip(wav_files, results): + # wav_rel_path = wav_file.relative_to(input_dir) + # f.write(f"{wav_rel_path}|{model_name}|{language_id}|{text}\n") sys.exit(0) diff --git a/vad_filter.py b/vad_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..1259051f58d8b87008b270bae6b4176724287807 --- /dev/null +++ b/vad_filter.py @@ -0,0 +1,92 @@ +import argparse +import os +import shutil +import sys +from pathlib import Path + +import pandas as pd +import torch +from tqdm import tqdm + +from style_bert_vits2.logging import logger + + +vad_model, utils = torch.hub.load( + repo_or_dir="litagin02/silero-vad", + model="silero_vad", + onnx=True, + trust_repo=True, +) + +(get_speech_timestamps, _, read_audio, *_) = utils + + +def get_speech_ratio(audio_file): + sampling_rate = 16000 + + wav = read_audio(audio_file, sampling_rate=sampling_rate) + speech_timestamps = get_speech_timestamps( + wav, vad_model, sampling_rate=sampling_rate + ) + + speech_dur_ms = 0 + + for ts in speech_timestamps: + start_ms = ts["start"] / 16 + end_ms = ts["end"] / 16 + speech_dur_ms += end_ms - start_ms + + total_dur_ms = len(wav) / sampling_rate * 1000 + return speech_dur_ms / total_dur_ms + + +def process(file: Path): + speech_ratio = get_speech_ratio(file) + return file, speech_ratio + + +def main(): + parser = argparse.ArgumentParser(description="Calculate speech ratio.") + parser.add_argument( + "-i", "--input", help="Directory containing audio files", required=True + ) + args = parser.parse_args() + + if os.path.exists(os.path.join(args.input, "low_speech_ratio")): + logger.info("Low speech ratio directory already exists, skipping...") + exit(0) + + data_dir = Path(args.input) + wav_files = list(data_dir.glob("*.wav")) + wav_files.sort() + + if len(wav_files) < 100: + logger.warning("Too few files, skipping...") + exit(0) + + logger.info(f"Start VAD filtering for {data_dir}...") + + results = [] + + for wav_file in tqdm(wav_files, file=sys.stdout): + speech_ratio = get_speech_ratio(wav_file) + results.append((wav_file, speech_ratio)) + + results_df = pd.DataFrame(results, columns=["file", "speech_ratio"]) + results_df.to_csv(os.path.join(data_dir, "speech_ratio.csv"), index=False) + + logger.info(f"Speech ratio stats:\n{results_df['speech_ratio'].describe()}") + threshold = 0.5 + + low_speech_ratio_dir = os.path.join(data_dir, "low_speech_ratio") + os.makedirs(low_speech_ratio_dir, exist_ok=True) + + low_speech_files = results_df[results_df["speech_ratio"] < threshold]["file"] + logger.info(f"Moving {len(low_speech_files)} files to {low_speech_ratio_dir}...") + for low_speech_file in low_speech_files: + shutil.move(low_speech_file, low_speech_ratio_dir) + logger.success("VAD filtering completed.") + + +if __name__ == "__main__": + main()