mrfakename committed on
Commit 4300fed · 0 Parent(s)
Files changed (50)
  1. .gitattributes +36 -0
  2. .gitignore +11 -0
  3. LICENSE +19 -0
  4. README.md +162 -0
  5. app.py +23 -0
  6. logo.png +0 -0
  7. melo/__init__.py +0 -0
  8. melo/api.py +113 -0
  9. melo/attentions.py +459 -0
  10. melo/commons.py +160 -0
  11. melo/download_utils.py +47 -0
  12. melo/mel_processing.py +174 -0
  13. melo/models.py +1038 -0
  14. melo/modules.py +598 -0
  15. melo/split_utils.py +131 -0
  16. melo/text/__init__.py +35 -0
  17. melo/text/chinese.py +199 -0
  18. melo/text/chinese_bert.py +107 -0
  19. melo/text/chinese_mix.py +253 -0
  20. melo/text/cleaner.py +36 -0
  21. melo/text/cleaner_multiling.py +110 -0
  22. melo/text/cmudict.rep +0 -0
  23. melo/text/cmudict_cache.pickle +3 -0
  24. melo/text/english.py +284 -0
  25. melo/text/english_bert.py +39 -0
  26. melo/text/english_utils/__init__.py +0 -0
  27. melo/text/english_utils/abbreviations.py +35 -0
  28. melo/text/english_utils/number_norm.py +97 -0
  29. melo/text/english_utils/time_norm.py +47 -0
  30. melo/text/es_phonemizer/__init__.py +0 -0
  31. melo/text/es_phonemizer/base.py +140 -0
  32. melo/text/es_phonemizer/cleaner.py +109 -0
  33. melo/text/es_phonemizer/es_symbols.json +79 -0
  34. melo/text/es_phonemizer/es_symbols.txt +1 -0
  35. melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  36. melo/text/es_phonemizer/es_to_ipa.py +12 -0
  37. melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  38. melo/text/es_phonemizer/punctuation.py +174 -0
  39. melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  40. melo/text/es_phonemizer/test.ipynb +124 -0
  41. melo/text/fr_phonemizer/__init__.py +0 -0
  42. melo/text/fr_phonemizer/base.py +140 -0
  43. melo/text/fr_phonemizer/cleaner.py +122 -0
  44. melo/text/fr_phonemizer/en_symbols.json +78 -0
  45. melo/text/fr_phonemizer/fr_symbols.json +89 -0
  46. melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  47. melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  48. melo/text/fr_phonemizer/french_symbols.txt +1 -0
  49. melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  50. melo/text/fr_phonemizer/punctuation.py +172 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ melo/text/fr_phonemizer/example_ipa.txt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
+ __pycache__/
+ .ipynb_checkpoints/
+ basetts_outputs_use_bert/
+ basetts_outputs/
+ multilingual_ckpts
+ basetts_outputs_package/
+ build/
+ *.egg-info/
+ .DS_Store
+ *.zip
+ *.wav
LICENSE ADDED
@@ -0,0 +1,19 @@
+ Copyright (c) 2024 MyShell.ai
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,162 @@
+ <div align="center">
+   <div>&nbsp;</div>
+   <img src="logo.png" width="200"/>
+ </div>
+
+ ## Introduction
+ MeloTTS is a **high-quality multi-lingual** text-to-speech library by [MyShell.ai](https://myshell.ai). Supported languages include:
+
+ | Language | Example |
+ | --- | --- |
+ | English | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/en/EN-Default/speed_1.0/sent_000.wav) |
+ | English (American) | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/en/EN-US/speed_1.0/sent_000.wav) |
+ | English (British) | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/en/EN-BR/speed_1.0/sent_000.wav) |
+ | English (Indian) | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/en/EN_INDIA/speed_1.0/sent_000.wav) |
+ | English (Australian) | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/en/EN-AU/speed_1.0/sent_000.wav) |
+ | Spanish | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/es/ES/speed_1.0/sent_000.wav) |
+ | French | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/fr/FR/speed_1.0/sent_000.wav) |
+ | Chinese (mix EN) | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/zh/ZH/speed_1.0/sent_008.wav) |
+ | Japanese | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/jp/JP/speed_1.0/sent_000.wav) |
+ | Korean | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/kr/KR/speed_1.0/sent_000.wav) |
+
+ Other features include:
+ - The Chinese speaker supports `mixed Chinese and English`.
+ - Fast enough for `CPU real-time inference`.
+
+ ## Install on Linux
+ ```bash
+ git clone git@github.com:myshell-ai/MeloTTS.git
+ cd MeloTTS
+ pip install -e .
+ python -m unidic download
+ ```
+ We welcome the open-source community to make this repo compatible with `Mac` and `Windows`. If you find this repo useful, please consider contributing.
+
+ ## Usage
+
+ ### English with Multiple Accents
+ ```python
+ from melo.api import TTS
+
+ # Speed is adjustable
+ speed = 1.0
+
+ # CPU is sufficient for real-time inference.
+ # You can also change to cuda:0
+ device = 'cpu'
+
+ # English
+ text = "Did you ever hear a folk tale about a giant turtle?"
+ model = TTS(language='EN', device=device)
+ speaker_ids = model.hps.data.spk2id
+
+ # Default accent
+ output_path = 'en-default.wav'
+ model.tts_to_file(text, speaker_ids['EN-Default'], output_path, speed=speed)
+
+ # American accent
+ output_path = 'en-us.wav'
+ model.tts_to_file(text, speaker_ids['EN-US'], output_path, speed=speed)
+
+ # British accent
+ output_path = 'en-br.wav'
+ model.tts_to_file(text, speaker_ids['EN-BR'], output_path, speed=speed)
+
+ # Indian accent
+ output_path = 'en-india.wav'
+ model.tts_to_file(text, speaker_ids['EN_INDIA'], output_path, speed=speed)
+
+ # Australian accent
+ output_path = 'en-au.wav'
+ model.tts_to_file(text, speaker_ids['EN-AU'], output_path, speed=speed)
+ ```
+
+ ### Spanish
+ ```python
+ from melo.api import TTS
+
+ # Speed is adjustable
+ speed = 1.0
+
+ # CPU is sufficient for real-time inference.
+ # You can also change to cuda:0
+ device = 'cpu'
+
+ text = "El resplandor del sol acaricia las olas, pintando el cielo con una paleta deslumbrante."
+ model = TTS(language='ES', device=device)
+ speaker_ids = model.hps.data.spk2id
+
+ output_path = 'es.wav'
+ model.tts_to_file(text, speaker_ids['ES'], output_path, speed=speed)
+ ```
+
+ ### French
+ ```python
+ from melo.api import TTS
+
+ # Speed is adjustable
+ speed = 1.0
+ device = 'cpu' # or cuda:0
+
+ text = "La lueur dorée du soleil caresse les vagues, peignant le ciel d'une palette éblouissante."
+ model = TTS(language='FR', device=device)
+ speaker_ids = model.hps.data.spk2id
+
+ output_path = 'fr.wav'
+ model.tts_to_file(text, speaker_ids['FR'], output_path, speed=speed)
+ ```
+
+ ### Chinese
+ ```python
+ from melo.api import TTS
+
+ # Speed is adjustable
+ speed = 1.0
+ device = 'cpu' # or cuda:0
+
+ text = "我最近在学习machine learning，希望能够在未来的artificial intelligence领域有所建树。"
+ model = TTS(language='ZH', device=device)
+ speaker_ids = model.hps.data.spk2id
+
+ output_path = 'zh.wav'
+ model.tts_to_file(text, speaker_ids['ZH'], output_path, speed=speed)
+ ```
+
+ ### Japanese
+ ```python
+ from melo.api import TTS
+
+ # Speed is adjustable
+ speed = 1.0
+ device = 'cpu' # or cuda:0
+
+ text = "彼は毎朝ジョギングをして体を健康に保っています。"
+ model = TTS(language='JP', device=device)
+ speaker_ids = model.hps.data.spk2id
+
+ output_path = 'jp.wav'
+ model.tts_to_file(text, speaker_ids['JP'], output_path, speed=speed)
+ ```
+
+ ### Korean
+ ```python
+ from melo.api import TTS
+
+ # Speed is adjustable
+ speed = 1.0
+ device = 'cpu' # or cuda:0
+
+ text = "안녕하세요! 오늘은 날씨가 정말 좋네요."
+ model = TTS(language='KR', device=device)
+ speaker_ids = model.hps.data.spk2id
+
+ output_path = 'kr.wav'
+ model.tts_to_file(text, speaker_ids['KR'], output_path, speed=speed)
+ ```
+
+ ## License
+ This library is under the MIT License, free for both commercial and non-commercial use.
+
+ ## Acknowledgement
+ This implementation is based on several excellent projects: [TTS](https://github.com/coqui-ai/TTS), [VITS](https://github.com/jaywalnut310/vits), [VITS2](https://github.com/daniilrobnikov/vits2), and [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2). We appreciate their awesome work!
app.py ADDED
@@ -0,0 +1,23 @@
+ import os
+ import tempfile
+
+ import gradio as gr
+ import torch
+
+ # Fetch the unidic dictionary needed by MeloTTS's Japanese tokenizer before importing melo.
+ os.system('python -m unidic download')
+ from melo.api import TTS
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model = TTS(language='EN', device=device)
+ speaker_ids = model.hps.data.spk2id
+
+ def synthesize(speaker, text, speed=1.0):
+     # Render the text to a temporary WAV file and return its path for the Audio widget.
+     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
+         model.tts_to_file(text, speaker_ids[speaker], f.name, speed=speed)
+     return f.name
+
+ with gr.Blocks() as demo:
+     gr.Markdown('# MeloTTS\n\nAn unofficial demo of [MeloTTS](https://github.com/myshell-ai/MeloTTS) from MyShell AI. MeloTTS is a permissively licensed (MIT) SOTA multi-speaker TTS model.\n\nI am not affiliated with MyShell AI in any way.\n\nThis demo currently only supports English, but the model itself supports other languages.')
+     with gr.Group():
+         speaker = gr.Dropdown(speaker_ids.keys(), interactive=True, value='EN-Default', label='Speaker')
+         speed = gr.Slider(label='Speed', minimum=0.1, maximum=3.0, value=1.0, interactive=True)
+         text = gr.Textbox(label="Text to speak", value='The field of text to speech has seen rapid development recently')
+     btn = gr.Button('Synthesize', variant='primary')
+     aud = gr.Audio(interactive=False)
+     btn.click(synthesize, inputs=[speaker, text, speed], outputs=[aud])
+ demo.queue(api_open=False).launch(show_api=False)
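
For reference, the demo's synthesis path can be exercised without Gradio. A minimal sketch, assuming MeloTTS is installed per the README above and the English checkpoint can be downloaded on first use (the sample text and speed are illustrative):

```python
# Standalone equivalent of the demo's synthesize() helper above.
# Assumes `pip install -e .` and `python -m unidic download` have been run.
import tempfile

from melo.api import TTS

model = TTS(language='EN', device='cpu')
speaker_ids = model.hps.data.spk2id

with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
    model.tts_to_file('Hello from MeloTTS!', speaker_ids['EN-US'], f.name, speed=1.2)
    print(f.name)  # path to the rendered audio, as handed to the Gradio player
```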
logo.png ADDED
melo/__init__.py ADDED
File without changes
melo/api.py ADDED
@@ -0,0 +1,113 @@
+ import os
+ import re
+ import json
+ import torch
+ import librosa
+ import soundfile
+ import torchaudio
+ import numpy as np
+ import torch.nn as nn
+
+ from . import utils
+ from . import commons
+ from .models import SynthesizerTrn
+ from .split_utils import split_sentence
+ from .mel_processing import spectrogram_torch, spectrogram_torch_conv
+ from .download_utils import load_or_download_config, load_or_download_model
+
+ class TTS(nn.Module):
+     def __init__(self,
+                  language,
+                  device='cuda:0'):
+         super().__init__()
+         if 'cuda' in device:
+             assert torch.cuda.is_available()
+
+         hps = load_or_download_config(language)
+
+         num_languages = hps.num_languages
+         num_tones = hps.num_tones
+         symbols = hps.symbols
+
+         model = SynthesizerTrn(
+             len(symbols),
+             hps.data.filter_length // 2 + 1,
+             hps.train.segment_size // hps.data.hop_length,
+             n_speakers=hps.data.n_speakers,
+             num_tones=num_tones,
+             num_languages=num_languages,
+             **hps.model,
+         ).to(device)
+
+         model.eval()
+         self.model = model
+         self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
+         self.hps = hps
+         self.device = device
+
+         # load state_dict
+         checkpoint_dict = load_or_download_model(language, device)
+         self.model.load_state_dict(checkpoint_dict['model'], strict=True)
+
+         language = language.split('_')[0]
+         self.language = 'ZH_MIX_EN' if language == 'ZH' else language  # we support a ZH_MIX_EN model
+
+     @staticmethod
+     def audio_numpy_concat(segment_data_list, sr, speed=1.):
+         audio_segments = []
+         for segment_data in segment_data_list:
+             audio_segments += segment_data.reshape(-1).tolist()
+             audio_segments += [0] * int((sr * 0.05) / speed)
+         audio_segments = np.array(audio_segments).astype(np.float32)
+         return audio_segments
+
+     @staticmethod
+     def split_sentences_into_pieces(text, language):
+         texts = split_sentence(text, language_str=language)
+         print(" > Text split to sentences.")
+         print('\n'.join(texts))
+         print(" > ===========================")
+         return texts
+
+     def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0):
+         language = self.language
+         texts = self.split_sentences_into_pieces(text, language)
+         audio_list = []
+         for t in texts:
+             if language in ['EN', 'ZH_MIX_EN']:
+                 # Insert a space at lowercase-to-uppercase boundaries, e.g. "helloWorld" -> "hello World".
+                 t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+             device = self.device
+             bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+             with torch.no_grad():
+                 x_tst = phones.to(device).unsqueeze(0)
+                 tones = tones.to(device).unsqueeze(0)
+                 lang_ids = lang_ids.to(device).unsqueeze(0)
+                 bert = bert.to(device).unsqueeze(0)
+                 ja_bert = ja_bert.to(device).unsqueeze(0)
+                 x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+                 del phones
+                 speakers = torch.LongTensor([speaker_id]).to(device)
+                 audio = self.model.infer(
+                     x_tst,
+                     x_tst_lengths,
+                     speakers,
+                     tones,
+                     lang_ids,
+                     bert,
+                     ja_bert,
+                     sdp_ratio=sdp_ratio,
+                     noise_scale=noise_scale,
+                     noise_scale_w=noise_scale_w,
+                     length_scale=1. / speed,
+                 )[0][0, 0].data.cpu().float().numpy()
+                 del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
+             audio_list.append(audio)
+         torch.cuda.empty_cache()
+         audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+
+         if output_path is None:
+             return audio
+         else:
+             soundfile.write(output_path, audio, self.hps.data.sampling_rate)
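
Note that `tts_to_file` doubles as an in-memory API: when `output_path` is `None` it returns the concatenated waveform instead of writing a file (see the final branch above). A minimal sketch, assuming the `EN` model is available:

```python
# In-memory synthesis via the output_path=None branch of tts_to_file().
import soundfile

from melo.api import TTS

model = TTS(language='EN', device='cpu')
spk = model.hps.data.spk2id['EN-Default']

audio = model.tts_to_file('A short in-memory example.', spk, output_path=None)
# audio is a float32 NumPy array at the model's sampling rate.
soundfile.write('out.wav', audio, model.hps.data.sampling_rate)
```

The `sdp_ratio`, `noise_scale`, and `noise_scale_w` keyword arguments expose the usual VITS sampling knobs, and `speed` is implemented as `length_scale = 1 / speed`.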
melo/attentions.py ADDED
@@ -0,0 +1,459 @@
+ import math
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from . import commons
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class LayerNorm(nn.Module):
+     def __init__(self, channels, eps=1e-5):
+         super().__init__()
+         self.channels = channels
+         self.eps = eps
+
+         self.gamma = nn.Parameter(torch.ones(channels))
+         self.beta = nn.Parameter(torch.zeros(channels))
+
+     def forward(self, x):
+         x = x.transpose(1, -1)
+         x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+         return x.transpose(1, -1)
+
+
+ @torch.jit.script
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+     n_channels_int = n_channels[0]
+     in_act = input_a + input_b
+     t_act = torch.tanh(in_act[:, :n_channels_int, :])
+     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+     acts = t_act * s_act
+     return acts
+
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size=1,
+         p_dropout=0.0,
+         window_size=4,
+         isflow=True,
+         **kwargs
+     ):
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.window_size = window_size
+
+         self.cond_layer_idx = self.n_layers
+         if "gin_channels" in kwargs:
+             self.gin_channels = kwargs["gin_channels"]
+             if self.gin_channels != 0:
+                 self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
+                 self.cond_layer_idx = (
+                     kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
+                 )
+                 assert (
+                     self.cond_layer_idx < self.n_layers
+                 ), "cond_layer_idx should be less than n_layers"
+         self.drop = nn.Dropout(p_dropout)
+         self.attn_layers = nn.ModuleList()
+         self.norm_layers_1 = nn.ModuleList()
+         self.ffn_layers = nn.ModuleList()
+         self.norm_layers_2 = nn.ModuleList()
+
+         for i in range(self.n_layers):
+             self.attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels,
+                     hidden_channels,
+                     n_heads,
+                     p_dropout=p_dropout,
+                     window_size=window_size,
+                 )
+             )
+             self.norm_layers_1.append(LayerNorm(hidden_channels))
+             self.ffn_layers.append(
+                 FFN(
+                     hidden_channels,
+                     hidden_channels,
+                     filter_channels,
+                     kernel_size,
+                     p_dropout=p_dropout,
+                 )
+             )
+             self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+     def forward(self, x, x_mask, g=None):
+         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+         x = x * x_mask
+         for i in range(self.n_layers):
+             if i == self.cond_layer_idx and g is not None:
+                 g = self.spk_emb_linear(g.transpose(1, 2))
+                 g = g.transpose(1, 2)
+                 x = x + g
+                 x = x * x_mask
+             y = self.attn_layers[i](x, x, attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_1[i](x + y)
+
+             y = self.ffn_layers[i](x, x_mask)
+             y = self.drop(y)
+             x = self.norm_layers_2[i](x + y)
+         x = x * x_mask
+         return x
+
+
+ class Decoder(nn.Module):
+     def __init__(
+         self,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size=1,
+         p_dropout=0.0,
+         proximal_bias=False,
+         proximal_init=True,
+         **kwargs
+     ):
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.proximal_bias = proximal_bias
+         self.proximal_init = proximal_init
+
+         self.drop = nn.Dropout(p_dropout)
+         self.self_attn_layers = nn.ModuleList()
+         self.norm_layers_0 = nn.ModuleList()
+         self.encdec_attn_layers = nn.ModuleList()
+         self.norm_layers_1 = nn.ModuleList()
+         self.ffn_layers = nn.ModuleList()
+         self.norm_layers_2 = nn.ModuleList()
+         for i in range(self.n_layers):
+             self.self_attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels,
+                     hidden_channels,
+                     n_heads,
+                     p_dropout=p_dropout,
+                     proximal_bias=proximal_bias,
+                     proximal_init=proximal_init,
+                 )
+             )
+             self.norm_layers_0.append(LayerNorm(hidden_channels))
+             self.encdec_attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                 )
+             )
+             self.norm_layers_1.append(LayerNorm(hidden_channels))
+             self.ffn_layers.append(
+                 FFN(
+                     hidden_channels,
+                     hidden_channels,
+                     filter_channels,
+                     kernel_size,
+                     p_dropout=p_dropout,
+                     causal=True,
+                 )
+             )
+             self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+     def forward(self, x, x_mask, h, h_mask):
+         """
+         x: decoder input
+         h: encoder output
+         """
+         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+             device=x.device, dtype=x.dtype
+         )
+         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+         x = x * x_mask
+         for i in range(self.n_layers):
+             y = self.self_attn_layers[i](x, x, self_attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_0[i](x + y)
+
+             y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_1[i](x + y)
+
+             y = self.ffn_layers[i](x, x_mask)
+             y = self.drop(y)
+             x = self.norm_layers_2[i](x + y)
+         x = x * x_mask
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(
+         self,
+         channels,
+         out_channels,
+         n_heads,
+         p_dropout=0.0,
+         window_size=None,
+         heads_share=True,
+         block_length=None,
+         proximal_bias=False,
+         proximal_init=False,
+     ):
+         super().__init__()
+         assert channels % n_heads == 0
+
+         self.channels = channels
+         self.out_channels = out_channels
+         self.n_heads = n_heads
+         self.p_dropout = p_dropout
+         self.window_size = window_size
+         self.heads_share = heads_share
+         self.block_length = block_length
+         self.proximal_bias = proximal_bias
+         self.proximal_init = proximal_init
+         self.attn = None
+
+         self.k_channels = channels // n_heads
+         self.conv_q = nn.Conv1d(channels, channels, 1)
+         self.conv_k = nn.Conv1d(channels, channels, 1)
+         self.conv_v = nn.Conv1d(channels, channels, 1)
+         self.conv_o = nn.Conv1d(channels, out_channels, 1)
+         self.drop = nn.Dropout(p_dropout)
+
+         if window_size is not None:
+             n_heads_rel = 1 if heads_share else n_heads
+             rel_stddev = self.k_channels**-0.5
+             self.emb_rel_k = nn.Parameter(
+                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                 * rel_stddev
+             )
+             self.emb_rel_v = nn.Parameter(
+                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                 * rel_stddev
+             )
+
+         nn.init.xavier_uniform_(self.conv_q.weight)
+         nn.init.xavier_uniform_(self.conv_k.weight)
+         nn.init.xavier_uniform_(self.conv_v.weight)
+         if proximal_init:
+             with torch.no_grad():
+                 self.conv_k.weight.copy_(self.conv_q.weight)
+                 self.conv_k.bias.copy_(self.conv_q.bias)
+
+     def forward(self, x, c, attn_mask=None):
+         q = self.conv_q(x)
+         k = self.conv_k(c)
+         v = self.conv_v(c)
+
+         x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+         x = self.conv_o(x)
+         return x
+
+     def attention(self, query, key, value, mask=None):
+         # reshape [b, d, t] -> [b, n_h, t, d_k]
+         b, d, t_s, t_t = (*key.size(), query.size(2))
+         query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+         if self.window_size is not None:
+             assert (
+                 t_s == t_t
+             ), "Relative attention is only available for self-attention."
+             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+             rel_logits = self._matmul_with_relative_keys(
+                 query / math.sqrt(self.k_channels), key_relative_embeddings
+             )
+             scores_local = self._relative_position_to_absolute_position(rel_logits)
+             scores = scores + scores_local
+         if self.proximal_bias:
+             assert t_s == t_t, "Proximal bias is only available for self-attention."
+             scores = scores + self._attention_bias_proximal(t_s).to(
+                 device=scores.device, dtype=scores.dtype
+             )
+         if mask is not None:
+             scores = scores.masked_fill(mask == 0, -1e4)
+             if self.block_length is not None:
+                 assert (
+                     t_s == t_t
+                 ), "Local attention is only available for self-attention."
+                 block_mask = (
+                     torch.ones_like(scores)
+                     .triu(-self.block_length)
+                     .tril(self.block_length)
+                 )
+                 scores = scores.masked_fill(block_mask == 0, -1e4)
+         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+         p_attn = self.drop(p_attn)
+         output = torch.matmul(p_attn, value)
+         if self.window_size is not None:
+             relative_weights = self._absolute_position_to_relative_position(p_attn)
+             value_relative_embeddings = self._get_relative_embeddings(
+                 self.emb_rel_v, t_s
+             )
+             output = output + self._matmul_with_relative_values(
+                 relative_weights, value_relative_embeddings
+             )
+         output = (
+             output.transpose(2, 3).contiguous().view(b, d, t_t)
+         )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+         return output, p_attn
+
+     def _matmul_with_relative_values(self, x, y):
+         """
+         x: [b, h, l, m]
+         y: [h or 1, m, d]
+         ret: [b, h, l, d]
+         """
+         ret = torch.matmul(x, y.unsqueeze(0))
+         return ret
+
+     def _matmul_with_relative_keys(self, x, y):
+         """
+         x: [b, h, l, d]
+         y: [h or 1, m, d]
+         ret: [b, h, l, m]
+         """
+         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+         return ret
+
+     def _get_relative_embeddings(self, relative_embeddings, length):
+         # Pad first before slice to avoid using cond ops.
+         pad_length = max(length - (self.window_size + 1), 0)
+         slice_start_position = max((self.window_size + 1) - length, 0)
+         slice_end_position = slice_start_position + 2 * length - 1
+         if pad_length > 0:
+             padded_relative_embeddings = F.pad(
+                 relative_embeddings,
+                 commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+             )
+         else:
+             padded_relative_embeddings = relative_embeddings
+         used_relative_embeddings = padded_relative_embeddings[
+             :, slice_start_position:slice_end_position
+         ]
+         return used_relative_embeddings
+
+     def _relative_position_to_absolute_position(self, x):
+         """
+         x: [b, h, l, 2*l-1]
+         ret: [b, h, l, l]
+         """
+         batch, heads, length, _ = x.size()
+         # Concat columns of pad to shift from relative to absolute indexing.
+         x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+         # Concat extra elements so to add up to shape (len+1, 2*len-1).
+         x_flat = x.view([batch, heads, length * 2 * length])
+         x_flat = F.pad(
+             x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+         )
+
+         # Reshape and slice out the padded elements.
+         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+             :, :, :length, length - 1 :
+         ]
+         return x_final
+
+     def _absolute_position_to_relative_position(self, x):
+         """
+         x: [b, h, l, l]
+         ret: [b, h, l, 2*l-1]
+         """
+         batch, heads, length, _ = x.size()
+         # pad along column
+         x = F.pad(
+             x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+         )
+         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+         # add 0's in the beginning that will skew the elements after reshape
+         x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+         return x_final
+
+     def _attention_bias_proximal(self, length):
+         """Bias for self-attention to encourage attention to close positions.
+         Args:
+           length: an integer scalar.
+         Returns:
+           a Tensor with shape [1, 1, length, length]
+         """
+         r = torch.arange(length, dtype=torch.float32)
+         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+ class FFN(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         filter_channels,
+         kernel_size,
+         p_dropout=0.0,
+         activation=None,
+         causal=False,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.activation = activation
+         self.causal = causal
+
+         if causal:
+             self.padding = self._causal_padding
+         else:
+             self.padding = self._same_padding
+
+         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+         self.drop = nn.Dropout(p_dropout)
+
+     def forward(self, x, x_mask):
+         x = self.conv_1(self.padding(x * x_mask))
+         if self.activation == "gelu":
+             x = x * torch.sigmoid(1.702 * x)
+         else:
+             x = torch.relu(x)
+         x = self.drop(x)
+         x = self.conv_2(self.padding(x * x_mask))
+         return x * x_mask
+
+     def _causal_padding(self, x):
+         if self.kernel_size == 1:
+             return x
+         pad_l = self.kernel_size - 1
+         pad_r = 0
+         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+         x = F.pad(x, commons.convert_pad_shape(padding))
+         return x
+
+     def _same_padding(self, x):
+         if self.kernel_size == 1:
+             return x
+         pad_l = (self.kernel_size - 1) // 2
+         pad_r = self.kernel_size // 2
+         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+         x = F.pad(x, commons.convert_pad_shape(padding))
+         return x
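
To make the tensor conventions concrete: the `Encoder` above consumes `[batch, channels, time]` features plus a `[batch, 1, time]` mask. A shape-check sketch with illustrative hyperparameters (192/768/2/6 are assumptions, not values read from this repo's configs):

```python
import torch

from melo.attentions import Encoder

# Illustrative sizes; the real values come from the downloaded model config.
enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=6)
x = torch.randn(2, 192, 50)    # [batch, channels, time]
x_mask = torch.ones(2, 1, 50)  # 1 = valid frame, 0 = padding
y = enc(x, x_mask)
print(y.shape)  # torch.Size([2, 192, 50]); masked positions stay zeroed
```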
melo/commons.py ADDED
@@ -0,0 +1,160 @@
+ import math
+ import torch
+ from torch.nn import functional as F
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def convert_pad_shape(pad_shape):
+     layer = pad_shape[::-1]
+     pad_shape = [item for sublist in layer for item in sublist]
+     return pad_shape
+
+
+ def intersperse(lst, item):
+     result = [item] * (len(lst) * 2 + 1)
+     result[1::2] = lst
+     return result
+
+
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
+     """KL(P||Q)"""
+     kl = (logs_q - logs_p) - 0.5
+     kl += (
+         0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+     )
+     return kl
+
+
+ def rand_gumbel(shape):
+     """Sample from the Gumbel distribution, protect from overflows."""
+     uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+     return -torch.log(-torch.log(uniform_samples))
+
+
+ def rand_gumbel_like(x):
+     g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+     return g
+
+
+ def slice_segments(x, ids_str, segment_size=4):
+     ret = torch.zeros_like(x[:, :, :segment_size])
+     for i in range(x.size(0)):
+         idx_str = ids_str[i]
+         idx_end = idx_str + segment_size
+         ret[i] = x[i, :, idx_str:idx_end]
+     return ret
+
+
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
+     b, d, t = x.size()
+     if x_lengths is None:
+         x_lengths = t
+     ids_str_max = x_lengths - segment_size + 1
+     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+     ret = slice_segments(x, ids_str, segment_size)
+     return ret, ids_str
+
+
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+     position = torch.arange(length, dtype=torch.float)
+     num_timescales = channels // 2
+     log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+         num_timescales - 1
+     )
+     inv_timescales = min_timescale * torch.exp(
+         torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+     )
+     scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+     signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+     signal = F.pad(signal, [0, 0, 0, channels % 2])
+     signal = signal.view(1, channels, length)
+     return signal
+
+
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+     b, channels, length = x.size()
+     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+     return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+     b, channels, length = x.size()
+     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+     return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+ def subsequent_mask(length):
+     mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+     return mask
+
+
+ @torch.jit.script
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+     n_channels_int = n_channels[0]
+     in_act = input_a + input_b
+     t_act = torch.tanh(in_act[:, :n_channels_int, :])
+     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+     acts = t_act * s_act
+     return acts
+
+
+ def shift_1d(x):
+     x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+     return x
+
+
+ def sequence_mask(length, max_length=None):
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+ def generate_path(duration, mask):
+     """
+     duration: [b, 1, t_x]
+     mask: [b, 1, t_y, t_x]
+     """
+
+     b, _, t_y, t_x = mask.shape
+     cum_duration = torch.cumsum(duration, -1)
+
+     cum_duration_flat = cum_duration.view(b * t_x)
+     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+     path = path.view(b, t_x, t_y)
+     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+     path = path.unsqueeze(1).transpose(2, 3) * mask
+     return path
+
+
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     parameters = list(filter(lambda p: p.grad is not None, parameters))
+     norm_type = float(norm_type)
+     if clip_value is not None:
+         clip_value = float(clip_value)
+
+     total_norm = 0
+     for p in parameters:
+         param_norm = p.grad.data.norm(norm_type)
+         total_norm += param_norm.item() ** norm_type
+         if clip_value is not None:
+             p.grad.data.clamp_(min=-clip_value, max=clip_value)
+     total_norm = total_norm ** (1.0 / norm_type)
+     return total_norm
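
Two of the helpers above are worth illustrating, since the rest of the codebase leans on them: `intersperse` (in VITS-style pipelines, typically used to interleave a blank token between phoneme IDs) and `sequence_mask` (used to build padding masks from lengths). A quick check with made-up values:

```python
import torch

from melo.commons import intersperse, sequence_mask

print(intersperse([1, 2, 3], 0))
# [0, 1, 0, 2, 0, 3]

lengths = torch.tensor([2, 4])
print(sequence_mask(lengths, max_length=4))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])
```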
melo/download_utils.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ import os
+ from . import utils
+
+ DOWNLOAD_CKPT_URLS = {
+     'EN': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/EN/checkpoint.pth',
+     'EN_V2': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/EN_V2/checkpoint.pth',
+     'FR': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/FR/checkpoint.pth',
+     'JP': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/JP/checkpoint.pth',
+     'ES': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/ES/checkpoint.pth',
+     'ZH': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/ZH/checkpoint.pth',
+     'KR': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/KR/checkpoint.pth',
+ }
+
+ DOWNLOAD_CONFIG_URLS = {
+     'EN': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/EN/config.json',
+     'EN_V2': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/EN_V2/config.json',
+     'FR': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/FR/config.json',
+     'JP': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/JP/config.json',
+     'ES': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/ES/config.json',
+     'ZH': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/ZH/config.json',
+     'KR': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/KR/config.json',
+ }
+
+ def load_or_download_config(locale):
+     language = locale.split('-')[0].upper()
+     assert language in DOWNLOAD_CONFIG_URLS
+     config_path = os.path.expanduser(f'~/.local/share/openvoice/basespeakers/{language}/config.json')
+     try:
+         return utils.get_hparams_from_file(config_path)
+     except:
+         # download
+         os.makedirs(os.path.dirname(config_path), exist_ok=True)
+         os.system(f'wget {DOWNLOAD_CONFIG_URLS[language]} -O {config_path}')
+         return utils.get_hparams_from_file(config_path)
+
+ def load_or_download_model(locale, device):
+     language = locale.split('-')[0].upper()
+     assert language in DOWNLOAD_CKPT_URLS
+     ckpt_path = os.path.expanduser(f'~/.local/share/openvoice/basespeakers/{language}/checkpoint.pth')
+     try:
+         return torch.load(ckpt_path, map_location=device)
+     except:
+         # download
+         os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
+         os.system(f'wget {DOWNLOAD_CKPT_URLS[language]} -O {ckpt_path}')
+         return torch.load(ckpt_path, map_location=device)
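
In other words, configs and checkpoints are cached under `~/.local/share/openvoice/basespeakers/<LANG>/` and fetched with `wget` on a cache miss, so `wget` must be on `PATH` for the first run. A minimal sketch of the caching behavior (note that the first call performs a network download):

```python
import os

from melo.download_utils import load_or_download_config

# First call downloads config.json into the cache; later calls read from disk.
hps = load_or_download_config('EN')
cache = os.path.expanduser('~/.local/share/openvoice/basespeakers/EN/config.json')
print(os.path.exists(cache))   # True after the first call
print(hps.data.sampling_rate)  # hyperparameters parsed from the cached config
```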
melo/mel_processing.py ADDED
@@ -0,0 +1,174 @@
+ import torch
+ import torch.utils.data
+ import librosa
+ from librosa.filters import mel as librosa_mel_fn
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     """
+     PARAMS
+     ------
+     C: compression factor
+     """
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     """
+     PARAMS
+     ------
+     C: compression factor used to compress
+     """
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ mel_basis = {}
+ hann_window = {}
+
+
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+     if torch.min(y) < -1.1:
+         print("min value is ", torch.min(y))
+     if torch.max(y) > 1.1:
+         print("max value is ", torch.max(y))
+
+     global hann_window
+     dtype_device = str(y.dtype) + "_" + str(y.device)
+     wnsize_dtype_device = str(win_size) + "_" + dtype_device
+     if wnsize_dtype_device not in hann_window:
+         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
+             dtype=y.dtype, device=y.device
+         )
+
+     y = torch.nn.functional.pad(
+         y.unsqueeze(1),
+         (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+         mode="reflect",
+     )
+     y = y.squeeze(1)
+
+     spec = torch.stft(
+         y,
+         n_fft,
+         hop_length=hop_size,
+         win_length=win_size,
+         window=hann_window[wnsize_dtype_device],
+         center=center,
+         pad_mode="reflect",
+         normalized=False,
+         onesided=True,
+         return_complex=False,
+     )
+
+     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+     return spec
+
+
+ def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+     global hann_window
+     dtype_device = str(y.dtype) + '_' + str(y.device)
+     wnsize_dtype_device = str(win_size) + '_' + dtype_device
+     if wnsize_dtype_device not in hann_window:
+         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
+
+     y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+
+     # ******************** original ************************#
+     # y = y.squeeze(1)
+     # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
+     #                    center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+
+     # ******************** ConvSTFT ************************#
+     freq_cutoff = n_fft // 2 + 1
+     fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
+     forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
+     forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()
+
+     import torch.nn.functional as F
+
+     # if center:
+     #     signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
+     assert center is False
+
+     forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride=hop_size)
+     spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim=-1)
+
+     # ******************** Verification ************************#
+     spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
+                        center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+     assert torch.allclose(spec1, spec2, atol=1e-4)
+
+     spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
+     return spec
+
+
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
+     global mel_basis
+     dtype_device = str(spec.dtype) + "_" + str(spec.device)
+     fmax_dtype_device = str(fmax) + "_" + dtype_device
+     if fmax_dtype_device not in mel_basis:
+         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
+             dtype=spec.dtype, device=spec.device
+         )
+     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+     spec = spectral_normalize_torch(spec)
+     return spec
+
+
+ def mel_spectrogram_torch(
+     y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+ ):
+     global mel_basis, hann_window
+     dtype_device = str(y.dtype) + "_" + str(y.device)
+     fmax_dtype_device = str(fmax) + "_" + dtype_device
+     wnsize_dtype_device = str(win_size) + "_" + dtype_device
+     if fmax_dtype_device not in mel_basis:
+         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
+             dtype=y.dtype, device=y.device
+         )
+     if wnsize_dtype_device not in hann_window:
+         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
+             dtype=y.dtype, device=y.device
+         )
+
+     y = torch.nn.functional.pad(
+         y.unsqueeze(1),
+         (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+         mode="reflect",
+     )
+     y = y.squeeze(1)
+
+     spec = torch.stft(
+         y,
+         n_fft,
+         hop_length=hop_size,
+         win_length=win_size,
+         window=hann_window[wnsize_dtype_device],
+         center=center,
+         pad_mode="reflect",
+         normalized=False,
+         onesided=True,
+         return_complex=False,
+     )
+
+     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+     spec = spectral_normalize_torch(spec)
+
+     return spec
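
For orientation, `mel_spectrogram_torch` maps a raw waveform batch to a log-compressed mel batch. A shape sketch with illustrative STFT settings (these particular values are assumptions, not this repo's training config):

```python
import torch

from melo.mel_processing import mel_spectrogram_torch

y = torch.randn(1, 44100)  # one second of placeholder audio at 44.1 kHz
mel = mel_spectrogram_torch(
    y,
    n_fft=2048, num_mels=128, sampling_rate=44100,
    hop_size=512, win_size=2048, fmin=0.0, fmax=None,
    center=False,
)
print(mel.shape)  # [1, 128, n_frames], dynamic-range compressed via log
```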
melo/models.py ADDED
@@ -0,0 +1,1038 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from . import commons
7
+ from . import modules
8
+ from . import attentions
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+
13
+ from .commons import init_weights, get_padding
14
+
15
+
16
+ class DurationDiscriminator(nn.Module): # vits2
17
+ def __init__(
18
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
19
+ ):
20
+ super().__init__()
21
+ self.in_channels = in_channels
22
+ self.filter_channels = filter_channels
23
+ self.kernel_size = kernel_size
24
+ self.p_dropout = p_dropout
25
+ self.gin_channels = gin_channels
26
+
27
+ self.drop = nn.Dropout(p_dropout)
28
+ self.conv_1 = nn.Conv1d(
29
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
30
+ )
31
+ self.norm_1 = modules.LayerNorm(filter_channels)
32
+ self.conv_2 = nn.Conv1d(
33
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
34
+ )
35
+ self.norm_2 = modules.LayerNorm(filter_channels)
36
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
37
+
38
+ self.pre_out_conv_1 = nn.Conv1d(
39
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
40
+ )
41
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
42
+ self.pre_out_conv_2 = nn.Conv1d(
43
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
44
+ )
45
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
46
+
47
+ if gin_channels != 0:
48
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
49
+
50
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
51
+
52
+ def forward_probability(self, x, x_mask, dur, g=None):
53
+ dur = self.dur_proj(dur)
54
+ x = torch.cat([x, dur], dim=1)
55
+ x = self.pre_out_conv_1(x * x_mask)
56
+ x = torch.relu(x)
57
+ x = self.pre_out_norm_1(x)
58
+ x = self.drop(x)
59
+ x = self.pre_out_conv_2(x * x_mask)
60
+ x = torch.relu(x)
61
+ x = self.pre_out_norm_2(x)
62
+ x = self.drop(x)
63
+ x = x * x_mask
64
+ x = x.transpose(1, 2)
65
+ output_prob = self.output_layer(x)
66
+ return output_prob
67
+
68
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
69
+ x = torch.detach(x)
70
+ if g is not None:
71
+ g = torch.detach(g)
72
+ x = x + self.cond(g)
73
+ x = self.conv_1(x * x_mask)
74
+ x = torch.relu(x)
75
+ x = self.norm_1(x)
76
+ x = self.drop(x)
77
+ x = self.conv_2(x * x_mask)
78
+ x = torch.relu(x)
79
+ x = self.norm_2(x)
80
+ x = self.drop(x)
81
+
82
+ output_probs = []
83
+ for dur in [dur_r, dur_hat]:
84
+ output_prob = self.forward_probability(x, x_mask, dur, g)
85
+ output_probs.append(output_prob)
86
+
87
+ return output_probs
88
+
89
+
90
+ class TransformerCouplingBlock(nn.Module):
91
+ def __init__(
92
+ self,
93
+ channels,
94
+ hidden_channels,
95
+ filter_channels,
96
+ n_heads,
97
+ n_layers,
98
+ kernel_size,
99
+ p_dropout,
100
+ n_flows=4,
101
+ gin_channels=0,
102
+ share_parameter=False,
103
+ ):
104
+ super().__init__()
105
+ self.channels = channels
106
+ self.hidden_channels = hidden_channels
107
+ self.kernel_size = kernel_size
108
+ self.n_layers = n_layers
109
+ self.n_flows = n_flows
110
+ self.gin_channels = gin_channels
111
+
112
+ self.flows = nn.ModuleList()
113
+
114
+ self.wn = (
115
+ attentions.FFT(
116
+ hidden_channels,
117
+ filter_channels,
118
+ n_heads,
119
+ n_layers,
120
+ kernel_size,
121
+ p_dropout,
122
+ isflow=True,
123
+ gin_channels=self.gin_channels,
124
+ )
125
+ if share_parameter
126
+ else None
127
+ )
128
+
129
+ for i in range(n_flows):
130
+ self.flows.append(
131
+ modules.TransformerCouplingLayer(
132
+ channels,
133
+ hidden_channels,
134
+ kernel_size,
135
+ n_layers,
136
+ n_heads,
137
+ p_dropout,
138
+ filter_channels,
139
+ mean_only=True,
140
+ wn_sharing_parameter=self.wn,
141
+ gin_channels=self.gin_channels,
142
+ )
143
+ )
144
+ self.flows.append(modules.Flip())
145
+
146
+ def forward(self, x, x_mask, g=None, reverse=False):
147
+ if not reverse:
148
+ for flow in self.flows:
149
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
150
+ else:
151
+ for flow in reversed(self.flows):
152
+ x = flow(x, x_mask, g=g, reverse=reverse)
153
+ return x
154
+
155
+
156
class StochasticDurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # TODO: this override should be removed in a future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = (
                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove the unused vflow
            z = (
                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw

class DurationPredictor(nn.Module):
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask

class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        gin_channels=0,
        num_languages=None,
        num_tones=None,
    ):
        super().__init__()
        if num_languages is None:
            from text import num_languages
        if num_tones is None:
            from text import num_tones
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
        nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
        self.language_emb = nn.Embedding(num_languages, hidden_channels)
        nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, tone, language, bert, ja_bert, g=None):
        bert_emb = self.bert_proj(bert).transpose(1, 2)
        ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
        x = (
            self.emb(x)
            + self.tone_emb(tone)
            + self.language_emb(language)
            + bert_emb
            + ja_bert_emb
        ) * math.sqrt(
            self.hidden_channels
        )  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )

        x = self.encoder(x * x_mask, x_mask, g=g)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask

class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x

class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None, tau=1.0):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()

class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r] mels
    outputs --- [N, ref_enc_gru_size]
    """

    def __init__(self, spec_channels, gin_channels=0, layernorm=False):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])  # noqa: E501

        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)
        if layernorm:
            self.layernorm = nn.LayerNorm(self.spec_channels)
            print('[Ref Enc]: using layer norm')
        else:
            self.layernorm = None

    def forward(self, inputs, mask=None):
        N = inputs.size(0)

        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        if self.layernorm is not None:
            out = self.layernorm(out)

        for conv in self.convs:
            out = conv(out)
            # out = wn(out)
            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        memory, out = self.gru(out)  # out --- [1, N, 128]

        return self.proj(out.squeeze(0))

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L

class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=256,
        gin_channels=256,
        use_sdp=True,
        n_flow_layer=4,
        n_layers_trans_flow=6,
        flow_share_parameter=False,
        use_transformer_flow=True,
        use_vc=False,
        num_languages=None,
        num_tones=None,
        norm_refenc=False,
        use_se=False,
        **kwargs
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_layers_trans_flow = n_layers_trans_flow
        self.use_spk_conditioned_encoder = kwargs.get(
            "use_spk_conditioned_encoder", True
        )
        self.use_sdp = use_sdp
        self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
        self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
        self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
        self.current_mas_noise_scale = self.mas_noise_scale_initial
        if self.use_spk_conditioned_encoder and gin_channels > 0:
            self.enc_gin_channels = gin_channels
        else:
            self.enc_gin_channels = 0
        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.enc_gin_channels,
            num_languages=num_languages,
            num_tones=num_tones,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        if use_transformer_flow:
            self.flow = TransformerCouplingBlock(
                inter_channels,
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers_trans_flow,
                5,
                p_dropout,
                n_flow_layer,
                gin_channels=gin_channels,
                share_parameter=flow_share_parameter,
            )
        else:
            self.flow = ResidualCouplingBlock(
                inter_channels,
                hidden_channels,
                5,
                1,
                n_flow_layer,
                gin_channels=gin_channels,
            )
        self.sdp = StochasticDurationPredictor(
            hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
        )
        self.dp = DurationPredictor(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )

        if n_speakers > 1:
            if use_se:
                emb_dim = 512
                self.emb_g = nn.Linear(emb_dim, gin_channels)
            else:
                self.emb_g = nn.Embedding(n_speakers, gin_channels)
        else:
            self.ref_enc = ReferenceEncoder(spec_channels, gin_channels, layernorm=norm_refenc)
        self.use_vc = use_vc
        self.use_se = use_se

    def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert):
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
        if self.use_vc:
            g_p = None
        else:
            g_p = g
        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, g=g_p
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (m_p * s_p_sq_r)
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
            if self.use_noise_scaled_mas:
                epsilon = (
                    torch.std(neg_cent)
                    * torch.randn_like(neg_cent)
                    * self.current_mas_noise_scale
                )
                neg_cent = neg_cent + epsilon

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
                .unsqueeze(1)
                .detach()
            )

        w = attn.sum(2)

        l_length_sdp = self.sdp(x, x_mask, w, g=g)
        l_length_sdp = l_length_sdp / torch.sum(x_mask)

        logw_ = torch.log(w + 1e-6) * x_mask
        logw = self.dp(x, x_mask, g=g)
        l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
            x_mask
        )  # for averaging

        l_length = l_length_dp + l_length_sdp

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return (
            o,
            l_length,
            attn,
            ids_slice,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            (x, logw, logw_),
        )

    def infer(
        self,
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
        max_len=None,
        sdp_ratio=0,
        y=None,
        g=None,
    ):
        # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
        # g = self.gst(y)
        if g is None:
            if self.n_speakers > 0:
                g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
            else:
                g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
        if self.use_vc:
            g_p = None
        else:
            g_p = g
        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, g=g_p
        )
        logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
            sdp_ratio
        ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
        w = torch.exp(logw) * x_mask * length_scale

        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        # print('max/min of o:', o.max(), o.min())
        return o, attn, y_mask, (z, z_p, m_p, logs_p)

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
        if self.use_se:
            sid_src = self.emb_g(sid_src).unsqueeze(-1)
            sid_tgt = self.emb_g(sid_tgt).unsqueeze(-1)

        g_src = sid_src
        g_tgt = sid_tgt
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)

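For orientation, the `infer` path above can be smoke-tested end to end without a checkpoint. The sketch below is not part of the commit; every hyperparameter in it is an illustrative assumption (not a MeloTTS config value), it assumes the melo package is importable, and it only verifies that shapes flow through `SynthesizerTrn.infer`.

# Smoke test of SynthesizerTrn.infer with toy hyperparameters (all assumed).
import torch
from melo.models import SynthesizerTrn

net = SynthesizerTrn(
    n_vocab=100, spec_channels=513, segment_size=32,
    inter_channels=192, hidden_channels=192, filter_channels=768,
    n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4],
    n_speakers=4, gin_channels=256,
    n_layers_trans_flow=3,  # TransformerCouplingLayer asserts n_layers == 3
    num_languages=10, num_tones=16,
).eval()

t = 20  # phoneme sequence length
x = torch.randint(0, 100, (1, t))               # phone ids
tone = torch.randint(0, 16, (1, t))             # tone ids
language = torch.zeros(1, t, dtype=torch.long)  # language ids
bert = torch.zeros(1, 1024, t)                  # primary BERT features (1024-dim)
ja_bert = torch.zeros(1, 768, t)                # secondary BERT features (768-dim)
sid = torch.tensor([0])                         # speaker id

with torch.no_grad():
    audio, attn, y_mask, _ = net.infer(
        x, torch.tensor([t]), sid, tone, language, bert, ja_bert,
        sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, length_scale=1.0,
    )
print(audio.shape)  # [1, 1, T_wav]; T_wav = sum of predicted durations x 256 (8*8*2*2 upsampling)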
melo/modules.py ADDED
@@ -0,0 +1,598 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm

from . import commons
from .commons import init_weights, get_padding
from .transforms import piecewise_rational_quadratic_transform
from .attentions import Encoder

LRELU_SLOPE = 0.1


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)

class ConvReluNorm(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask

class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    groups=channels,
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask

class WN(torch.nn.Module):
    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # the last layer does not need a residual branch
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)

class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

class ResBlock2(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)

class Log(nn.Module):
    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
            logdet = torch.sum(-y, [1, 2])
            return y, logdet
        else:
            x = torch.exp(x) * x_mask
            return x

class Flip(nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])
        if not reverse:
            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
            return x, logdet
        else:
            return x

class ElementwiseAffine(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = self.m + torch.exp(self.logs) * x
            y = y * x_mask
            logdet = torch.sum(self.logs * x_mask, [1, 2])
            return y, logdet
        else:
            x = (x - self.m) * torch.exp(-self.logs) * x_mask
            return x

class ResidualCouplingLayer(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x

class ConvFlow(nn.Module):
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        n_layers,
        num_bins=10,
        tail_bound=5.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        self.proj = nn.Conv1d(
            filter_channels, self.half_channels * (num_bins * 3 - 1), 1
        )
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask

        b, c, t = x0.shape
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]

        unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
            self.filter_channels
        )
        unnormalized_derivatives = h[..., 2 * self.num_bins :]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if not reverse:
            return x, logdet
        else:
            return x

class TransformerCouplingLayer(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        n_layers,
        n_heads,
        p_dropout=0,
        filter_channels=0,
        mean_only=False,
        wn_sharing_parameter=None,
        gin_channels=0,
    ):
        assert n_layers == 3, n_layers
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = (
            Encoder(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=gin_channels,
            )
            if wn_sharing_parameter is None
            else wn_sharing_parameter
        )
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x

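Every flow module above follows one contract: called with reverse=False it returns (y, logdet), and with reverse=True it returns the reconstructed input. A minimal round-trip sketch (not part of the commit; the sizes are arbitrary assumptions, and it assumes the melo package is importable):

# Round-trip sanity check of the forward/reverse flow contract.
import torch
from melo.modules import ResidualCouplingLayer

layer = ResidualCouplingLayer(
    channels=4, hidden_channels=8, kernel_size=5, dilation_rate=1, n_layers=2
).eval()

x = torch.randn(1, 4, 10)
x_mask = torch.ones(1, 1, 10)

with torch.no_grad():
    y, logdet = layer(x, x_mask)            # forward pass returns (y, logdet)
    x_rec = layer(y, x_mask, reverse=True)  # reverse pass returns x only

print(torch.allclose(x, x_rec, atol=1e-5))  # True: the coupling layer inverts exactly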
melo/split_utils.py ADDED
@@ -0,0 +1,131 @@
import re

from txtsplit import txtsplit


def split_sentence(text, min_len=10, language_str='EN'):
    if language_str in ['EN', 'FR', 'ES', 'SP', 'DE', 'RU']:
        sentences = split_sentences_latin(text, min_len=min_len)
    else:
        sentences = split_sentences_zh(text, min_len=min_len)
    return sentences


def split_sentences_latin(text, min_len=10):
    text = re.sub('[。!?;]', '.', text)
    text = re.sub('[,]', ',', text)
    text = re.sub('[“”]', '"', text)
    text = re.sub('[‘’]', "'", text)
    text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
    return txtsplit(text, 512, 512)
    # Legacy implementation, kept for reference:
    # # Replace newlines, spaces and tabs with a single space
    # text = re.sub('[\n\t ]+', ' ', text)
    # # Insert a split marker after each punctuation mark
    # text = re.sub('([,.!?;])', r'\1 $#!', text)
    # # Split into sentences and strip surrounding whitespace
    # sentences = [s.strip() for s in text.split('$#!')]
    # if len(sentences[-1]) == 0: del sentences[-1]

    # new_sentences = []
    # new_sent = []
    # count_len = 0
    # for ind, sent in enumerate(sentences):
    #     new_sent.append(sent)
    #     count_len += len(sent.split(" "))
    #     if count_len > min_len or ind == len(sentences) - 1:
    #         count_len = 0
    #         new_sentences.append(' '.join(new_sent))
    #         new_sent = []
    # return merge_short_sentences_en(new_sentences)


def split_sentences_zh(text, min_len=10):
    text = re.sub('[。!?;]', '.', text)
    text = re.sub('[,]', ',', text)
    # Replace newlines, spaces and tabs with a single space
    text = re.sub('[\n\t ]+', ' ', text)
    # Insert a split marker after each punctuation mark
    text = re.sub('([,.!?;])', r'\1 $#!', text)
    # Split into sentences and strip surrounding whitespace
    # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
    sentences = [s.strip() for s in text.split('$#!')]
    if len(sentences[-1]) == 0:
        del sentences[-1]

    new_sentences = []
    new_sent = []
    count_len = 0
    for ind, sent in enumerate(sentences):
        new_sent.append(sent)
        count_len += len(sent)
        if count_len > min_len or ind == len(sentences) - 1:
            count_len = 0
            new_sentences.append(' '.join(new_sent))
            new_sent = []
    return merge_short_sentences_zh(new_sentences)


def merge_short_sentences_en(sens):
    """Avoid short sentences by merging them with the following sentence.

    Args:
        sens (List[str]): list of input sentences.

    Returns:
        List[str]: list of output sentences.
    """
    sens_out = []
    for s in sens:
        # If the previous sentence is too short, merge it with
        # the current sentence.
        if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
            sens_out[-1] = sens_out[-1] + " " + s
        else:
            sens_out.append(s)
    try:
        if len(sens_out[-1].split(" ")) <= 2:
            sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
            sens_out.pop(-1)
    except IndexError:
        pass
    return sens_out


def merge_short_sentences_zh(sens):
    """Avoid short sentences by merging them with the following sentence.

    Args:
        sens (List[str]): list of input sentences.

    Returns:
        List[str]: list of output sentences.
    """
    sens_out = []
    for s in sens:
        # If the previous sentence is too short, merge it with
        # the current sentence.
        if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
            sens_out[-1] = sens_out[-1] + " " + s
        else:
            sens_out.append(s)
    try:
        if len(sens_out[-1]) <= 2:
            sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
            sens_out.pop(-1)
    except IndexError:
        pass
    return sens_out


if __name__ == '__main__':
    zh_text = "好的,我来给你讲一个故事吧。从前有一个小姑娘,她叫做小红。小红非常喜欢在森林里玩耍,她经常会和她的小伙伴们一起去探险。有一天,小红和她的小伙伴们走到了森林深处,突然遇到了一只凶猛的野兽。小红的小伙伴们都吓得不敢动弹,但是小红并没有被吓倒,她勇敢地走向野兽,用她的智慧和勇气成功地制服了野兽,保护了她的小伙伴们。从那以后,小红变得更加勇敢和自信,成为了她小伙伴们心中的英雄。"
    en_text = "I didn’t know what to do. I said please kill her because it would be better than being kidnapped,” Ben, whose surname CNN is not using for security concerns, said on Wednesday. “It’s a nightmare. I said ‘please kill her, don’t take her there.’"
    sp_text = "¡Claro! ¿En qué tema te gustaría que te hable en español? Puedo proporcionarte información o conversar contigo sobre una amplia variedad de temas, desde cultura y comida hasta viajes y tecnología. ¿Tienes alguna preferencia en particular?"
    fr_text = "Bien sûr ! En quelle matière voudriez-vous que je vous parle en français ? Je peux vous fournir des informations ou discuter avec vous sur une grande variété de sujets, que ce soit la culture, la nourriture, les voyages ou la technologie. Avez-vous une préférence particulière ?"
    de_text = 'Es war das Wichtigste was wir sichern wollten da es keine Möglichkeit gab eine 20 Megatonnen- H- Bombe ab zu werfen von einem 30, C124.'
    ru_text = 'Но он был во многом, как-бы, всё равно что сын плантатора, так как являлся сыном человека, у которого было в собственности много чего.'
    print(split_sentence(zh_text, language_str='ZH'))
    print(split_sentence(en_text, language_str='EN'))
    print(split_sentence(sp_text, language_str='SP'))
    print(split_sentence(fr_text, language_str='FR'))
    print(split_sentence(de_text, language_str='DE'))
    print(split_sentence(ru_text, language_str='RU'))

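The merge helpers exist so the TTS front end never receives a near-empty chunk: any fragment of at most two words (two characters for Chinese) is glued onto a neighbour. A tiny illustration with made-up inputs (assumes the melo package and its txtsplit dependency are installed):

from melo.split_utils import merge_short_sentences_en

sens = ["Hello there.", "OK.", "This is a longer sentence.", "Bye."]
print(merge_short_sentences_en(sens))
# ['Hello there. OK.', 'This is a longer sentence. Bye.']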
melo/text/__init__.py ADDED
@@ -0,0 +1,35 @@
from .symbols import *


_symbol_to_id = {s: i for i, s in enumerate(symbols)}


def cleaned_text_to_sequence(cleaned_text, tones, language, symbol_to_id=None):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    Args:
        cleaned_text: string to convert to a sequence
    Returns:
        List of integers corresponding to the symbols in the text
    """
    symbol_to_id_map = symbol_to_id if symbol_to_id else _symbol_to_id
    phones = [symbol_to_id_map[symbol] for symbol in cleaned_text]
    tone_start = language_tone_start_map[language]
    tones = [i + tone_start for i in tones]
    lang_id = language_id_map[language]
    lang_ids = [lang_id for _ in phones]
    return phones, tones, lang_ids


def get_bert(norm_text, word2ph, language, device):
    from .chinese_bert import get_bert_feature as zh_bert
    from .english_bert import get_bert_feature as en_bert
    from .japanese_bert import get_bert_feature as jp_bert
    from .chinese_mix import get_bert_feature as zh_mix_en_bert
    from .spanish_bert import get_bert_feature as sp_bert
    from .french_bert import get_bert_feature as fr_bert
    from .korean import get_bert_feature as kr_bert

    lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert, 'ZH_MIX_EN': zh_mix_en_bert,
                          'FR': fr_bert, 'SP': sp_bert, 'ES': sp_bert, "KR": kr_bert}
    bert = lang_bert_func_map[language](norm_text, word2ph, device)
    return bert

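cleaned_text_to_sequence maps three parallel streams: phone symbols through the symbol table, tones shifted by the language's tone-range offset, and one language id per phone. The sketch below passes a hypothetical custom symbol_to_id map explicitly; the default table comes from .symbols (assumes the melo package is importable):

from melo.text import cleaned_text_to_sequence

toy_map = {"_": 0, "n": 1, "i": 2, "h": 3, "ao": 4}  # hypothetical symbol table
phones, tones, lang_ids = cleaned_text_to_sequence(
    ["n", "i", "h", "ao"], [2, 2, 3, 3], "ZH", symbol_to_id=toy_map
)
print(phones)    # [1, 2, 3, 4]
print(tones)     # input tones shifted by language_tone_start_map["ZH"]
print(lang_ids)  # one language_id_map["ZH"] entry per phone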
melo/text/chinese.py ADDED
@@ -0,0 +1,199 @@
import os
import re

import cn2an
from pypinyin import lazy_pinyin, Style

from .symbols import punctuation
from .tone_sandhi import ToneSandhi

current_file_path = os.path.dirname(__file__)
pinyin_to_symbol_map = {
    line.split("\t")[0]: line.strip().split("\t")[1]
    for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
}

import jieba.posseg as psg


rep_map = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "~": "-",
    "~": "-",
    "「": "'",
    "」": "'",
}

tone_modifier = ToneSandhi()


def replace_punctuation(text):
    text = text.replace("嗯", "恩").replace("呣", "母")
    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))

    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

    replaced_text = re.sub(
        r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
    )

    return replaced_text


def g2p(text):
    pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
    sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
    phones, tones, word2ph = _g2p(sentences)
    assert sum(word2ph) == len(phones)
    assert len(word2ph) == len(text)  # Sometimes this crashes; you can add a try-catch.
    phones = ["_"] + phones + ["_"]
    tones = [0] + tones + [0]
    word2ph = [1] + word2ph + [1]
    return phones, tones, word2ph


def _get_initials_finals(word):
    initials = []
    finals = []
    orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
    orig_finals = lazy_pinyin(
        word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
    )
    for c, v in zip(orig_initials, orig_finals):
        initials.append(c)
        finals.append(v)
    return initials, finals


def _g2p(segments):
    phones_list = []
    tones_list = []
    word2ph = []
    for seg in segments:
        # Remove all English words from the sentence
        seg = re.sub("[a-zA-Z]+", "", seg)
        seg_cut = psg.lcut(seg)
        initials = []
        finals = []
        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
        for word, pos in seg_cut:
            if pos == "eng":
                continue
            sub_initials, sub_finals = _get_initials_finals(word)
            sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
            initials.append(sub_initials)
            finals.append(sub_finals)

            # assert len(sub_initials) == len(sub_finals) == len(word)
        initials = sum(initials, [])
        finals = sum(finals, [])

        for c, v in zip(initials, finals):
            raw_pinyin = c + v
            # NOTE: post-process pypinyin outputs;
            # we discriminate i, ii and iii
            if c == v:
                assert c in punctuation
                phone = [c]
                tone = "0"
                word2ph.append(1)
            else:
                v_without_tone = v[:-1]
                tone = v[-1]

                pinyin = c + v_without_tone
                assert tone in "12345"

                if c:
                    # syllable with an initial
                    v_rep_map = {
                        "uei": "ui",
                        "iou": "iu",
                        "uen": "un",
                    }
                    if v_without_tone in v_rep_map.keys():
                        pinyin = c + v_rep_map[v_without_tone]
                else:
                    # syllable without an initial
                    pinyin_rep_map = {
                        "ing": "ying",
                        "i": "yi",
                        "in": "yin",
                        "u": "wu",
                    }
                    if pinyin in pinyin_rep_map.keys():
                        pinyin = pinyin_rep_map[pinyin]
                    else:
                        single_rep_map = {
                            "v": "yu",
                            "e": "e",
                            "i": "y",
                            "u": "w",
                        }
                        if pinyin[0] in single_rep_map.keys():
                            pinyin = single_rep_map[pinyin[0]] + pinyin[1:]

                assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
                phone = pinyin_to_symbol_map[pinyin].split(" ")
                word2ph.append(len(phone))

            phones_list += phone
            tones_list += [int(tone)] * len(phone)
    return phones_list, tones_list, word2ph


def text_normalize(text):
    numbers = re.findall(r"\d+(?:\.?\d+)?", text)
    for number in numbers:
        text = text.replace(number, cn2an.an2cn(number), 1)
    text = replace_punctuation(text)
    return text


def get_bert_feature(text, word2ph, device=None):
    from text import chinese_bert

    return chinese_bert.get_bert_feature(text, word2ph, device=device)


if __name__ == "__main__":
    from text.chinese_bert import get_bert_feature

    text = "啊!chemistry 但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
    text = text_normalize(text)
    print(text)
    phones, tones, word2ph = g2p(text)
    bert = get_bert_feature(text, word2ph)

    print(phones, tones, word2ph, bert.shape)


# Example usage
# text = "这是一个示例文本:,你好!这是一个测试...."
# print(g2p_paddle(text))  # Output: 这是一个示例文本你好这是一个测试

melo/text/chinese_bert.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
4
+
5
+
6
+ # model_id = 'hfl/chinese-roberta-wwm-ext-large'
7
+ local_path = "./bert/chinese-roberta-wwm-ext-large"
8
+
9
+
10
+ tokenizers = {}
11
+ models = {}
12
+
13
+ def get_bert_feature(text, word2ph, device=None, model_id='hfl/chinese-roberta-wwm-ext-large'):
14
+ if model_id not in models:
15
+ models[model_id] = AutoModelForMaskedLM.from_pretrained(
16
+ model_id
17
+ ).to(device)
18
+ tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
19
+ model = models[model_id]
20
+ tokenizer = tokenizers[model_id]
21
+
22
+ if (
23
+ sys.platform == "darwin"
24
+ and torch.backends.mps.is_available()
25
+ and device == "cpu"
26
+ ):
27
+ device = "mps"
28
+ if not device:
29
+ device = "cuda"
30
+
31
+ with torch.no_grad():
32
+ inputs = tokenizer(text, return_tensors="pt")
33
+ for i in inputs:
34
+ inputs[i] = inputs[i].to(device)
35
+ res = model(**inputs, output_hidden_states=True)
36
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
37
+ # import pdb; pdb.set_trace()
38
+ # assert len(word2ph) == len(text) + 2
39
+ word2phone = word2ph
40
+ phone_level_feature = []
41
+ for i in range(len(word2phone)):
42
+ repeat_feature = res[i].repeat(word2phone[i], 1)
43
+ phone_level_feature.append(repeat_feature)
44
+
45
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
46
+ return phone_level_feature.T
47
+
48
+
49
+ if __name__ == "__main__":
50
+ import torch
51
+
52
+ word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
53
+ word2phone = [
54
+ 1,
55
+ 2,
56
+ 1,
57
+ 2,
58
+ 2,
59
+ 1,
60
+ 2,
61
+ 2,
62
+ 1,
63
+ 2,
64
+ 2,
65
+ 1,
66
+ 2,
67
+ 2,
68
+ 2,
69
+ 2,
70
+ 2,
71
+ 1,
72
+ 1,
73
+ 2,
74
+ 2,
75
+ 1,
76
+ 2,
77
+ 2,
78
+ 2,
79
+ 2,
80
+ 1,
81
+ 2,
82
+ 2,
83
+ 2,
84
+ 2,
85
+ 2,
86
+ 1,
87
+ 2,
88
+ 2,
89
+ 2,
90
+ 2,
91
+ 1,
92
+ ]
93
+
94
+ # 计算总帧数
95
+ total_frames = sum(word2phone)
96
+ print(word_level_feature.shape)
97
+ print(word2phone)
98
+ phone_level_feature = []
99
+ for i in range(len(word2phone)):
100
+ print(word_level_feature[i].shape)
101
+
102
+ # repeat each word's feature word2phone[i] times
103
+ repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
104
+ phone_level_feature.append(repeat_feature)
105
+
106
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
107
+ print(phone_level_feature.shape) # torch.Size([65, 1024]): one frame per phone, sum(word2phone)
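The __main__ block above exercises the same upsampling that get_bert_feature performs. A condensed sketch of that word-to-phone expansion, with hypothetical sizes:

    import torch

    hidden = torch.rand(5, 1024)           # pretend BERT features for 5 tokens
    word2ph = [1, 2, 3, 1, 2]              # phones contributed by each token
    phone_level = torch.cat(
        [hidden[i].repeat(n, 1) for i, n in enumerate(word2ph)], dim=0
    )
    assert phone_level.shape == (sum(word2ph), 1024)   # 9 phone-level frames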
melo/text/chinese_mix.py ADDED
@@ -0,0 +1,253 @@
1
+ import os
2
+ import re
3
+
4
+ import cn2an
5
+ from pypinyin import lazy_pinyin, Style
6
+
7
+ # from text.symbols import punctuation
8
+ from .symbols import language_tone_start_map
9
+ from .tone_sandhi import ToneSandhi
10
+ from .english import g2p as g2p_en
11
+ from transformers import AutoTokenizer
12
+
13
+ punctuation = ["!", "?", "…", ",", ".", "'", "-"]
14
+ current_file_path = os.path.dirname(__file__)
15
+ pinyin_to_symbol_map = {
16
+ line.split("\t")[0]: line.strip().split("\t")[1]
17
+ for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
18
+ }
19
+
20
+ import jieba.posseg as psg
21
+
22
+
23
+ rep_map = {
24
+ ":": ",",
25
+ ";": ",",
26
+ ",": ",",
27
+ "。": ".",
28
+ "!": "!",
29
+ "?": "?",
30
+ "\n": ".",
31
+ "·": ",",
32
+ "、": ",",
33
+ "...": "…",
34
+ "$": ".",
35
+ "“": "'",
36
+ "”": "'",
37
+ "‘": "'",
38
+ "’": "'",
39
+ "(": "'",
40
+ ")": "'",
41
+ "(": "'",
42
+ ")": "'",
43
+ "《": "'",
44
+ "》": "'",
45
+ "【": "'",
46
+ "】": "'",
47
+ "[": "'",
48
+ "]": "'",
49
+ "—": "-",
50
+ "~": "-",
51
+ "~": "-",
52
+ "「": "'",
53
+ "」": "'",
54
+ }
55
+
56
+ tone_modifier = ToneSandhi()
57
+
58
+
59
+ def replace_punctuation(text):
60
+ text = text.replace("嗯", "恩").replace("呣", "母")
61
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
62
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
63
+ replaced_text = re.sub(r"[^\u4e00-\u9fa5_a-zA-Z\s" + "".join(punctuation) + r"]+", "", replaced_text)
64
+ replaced_text = re.sub(r"[\s]+", " ", replaced_text)
65
+
66
+ return replaced_text
67
+
68
+
69
+ def g2p(text, impl='v2'):
70
+ pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
71
+ sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
72
+ if impl == 'v1':
73
+ _func = _g2p
74
+ elif impl == 'v2':
75
+ _func = _g2p_v2
76
+ else:
77
+ raise NotImplementedError()
78
+ phones, tones, word2ph = _func(sentences)
79
+ assert sum(word2ph) == len(phones)
80
+ # assert len(word2ph) == len(text) # this sometimes fails; wrap it in try/except if needed
81
+ phones = ["_"] + phones + ["_"]
82
+ tones = [0] + tones + [0]
83
+ word2ph = [1] + word2ph + [1]
84
+ return phones, tones, word2ph
85
+
86
+
87
+ def _get_initials_finals(word):
88
+ initials = []
89
+ finals = []
90
+ orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
91
+ orig_finals = lazy_pinyin(
92
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
93
+ )
94
+ for c, v in zip(orig_initials, orig_finals):
95
+ initials.append(c)
96
+ finals.append(v)
97
+ return initials, finals
98
+
99
+ model_id = 'bert-base-multilingual-uncased'
100
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
101
+ def _g2p(segments):
102
+ phones_list = []
103
+ tones_list = []
104
+ word2ph = []
105
+ for seg in segments:
106
+ # Replace all English words in the sentence
107
+ # seg = re.sub("[a-zA-Z]+", "", seg)
108
+ seg_cut = psg.lcut(seg)
109
+ initials = []
110
+ finals = []
111
+ seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
112
+ for word, pos in seg_cut:
113
+ if pos == "eng":
114
+ initials.append(['EN_WORD'])
115
+ finals.append([word])
116
+ else:
117
+ sub_initials, sub_finals = _get_initials_finals(word)
118
+ sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
119
+ initials.append(sub_initials)
120
+ finals.append(sub_finals)
121
+
122
+ # assert len(sub_initials) == len(sub_finals) == len(word)
123
+ initials = sum(initials, [])
124
+ finals = sum(finals, [])
125
+ #
126
+ for c, v in zip(initials, finals):
127
+ if c == 'EN_WORD':
128
+ tokenized_en = tokenizer.tokenize(v)
129
+ phones_en, tones_en, word2ph_en = g2p_en(text=None, pad_start_end=False, tokenized=tokenized_en)
130
+ # apply offset to tones_en
131
+ tones_en = [t + language_tone_start_map['EN'] for t in tones_en]
132
+ phones_list += phones_en
133
+ tones_list += tones_en
134
+ word2ph += word2ph_en
135
+ else:
136
+ raw_pinyin = c + v
137
+ # NOTE: post process for pypinyin outputs
138
+ # we discriminate i, ii and iii
139
+ if c == v:
140
+ assert c in punctuation
141
+ phone = [c]
142
+ tone = "0"
143
+ word2ph.append(1)
144
+ else:
145
+ v_without_tone = v[:-1]
146
+ tone = v[-1]
147
+
148
+ pinyin = c + v_without_tone
149
+ assert tone in "12345"
150
+
151
+ if c:
152
+ # syllable with an initial consonant
153
+ v_rep_map = {
154
+ "uei": "ui",
155
+ "iou": "iu",
156
+ "uen": "un",
157
+ }
158
+ if v_without_tone in v_rep_map.keys():
159
+ pinyin = c + v_rep_map[v_without_tone]
160
+ else:
161
+ # standalone final, no initial consonant
162
+ pinyin_rep_map = {
163
+ "ing": "ying",
164
+ "i": "yi",
165
+ "in": "yin",
166
+ "u": "wu",
167
+ }
168
+ if pinyin in pinyin_rep_map.keys():
169
+ pinyin = pinyin_rep_map[pinyin]
170
+ else:
171
+ single_rep_map = {
172
+ "v": "yu",
173
+ "e": "e",
174
+ "i": "y",
175
+ "u": "w",
176
+ }
177
+ if pinyin[0] in single_rep_map.keys():
178
+ pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
179
+
180
+ assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
181
+ phone = pinyin_to_symbol_map[pinyin].split(" ")
182
+ word2ph.append(len(phone))
183
+
184
+ phones_list += phone
185
+ tones_list += [int(tone)] * len(phone)
186
+ return phones_list, tones_list, word2ph
187
+
188
+
189
+ def text_normalize(text):
190
+ numbers = re.findall(r"\d+(?:\.?\d+)?", text)
191
+ for number in numbers:
192
+ text = text.replace(number, cn2an.an2cn(number), 1)
193
+ text = replace_punctuation(text)
194
+ return text
195
+
196
+
197
+ def get_bert_feature(text, word2ph, device):
198
+ from . import chinese_bert
199
+ return chinese_bert.get_bert_feature(text, word2ph, model_id='bert-base-multilingual-uncased', device=device)
200
+
201
+ from .chinese import _g2p as _chinese_g2p
202
+ def _g2p_v2(segments):
203
+ spliter = '#$&^!@'
204
+
205
+ phones_list = []
206
+ tones_list = []
207
+ word2ph = []
208
+
209
+ for text in segments:
210
+ assert spliter not in text
211
+ # replace all english words
212
+ text = re.sub('([a-zA-Z\s]+)', lambda x: f'{spliter}{x.group(1)}{spliter}', text)
213
+ texts = text.split(spliter)
214
+ texts = [t for t in texts if len(t) > 0]
215
+
216
+
217
+ for text in texts:
218
+ if re.match('[a-zA-Z\s]+', text):
219
+ # english
220
+ tokenized_en = tokenizer.tokenize(text)
221
+ phones_en, tones_en, word2ph_en = g2p_en(text=None, pad_start_end=False, tokenized=tokenized_en)
222
+ # apply offset to tones_en
223
+ tones_en = [t + language_tone_start_map['EN'] for t in tones_en]
224
+ phones_list += phones_en
225
+ tones_list += tones_en
226
+ word2ph += word2ph_en
227
+ else:
228
+ phones_zh, tones_zh, word2ph_zh = _chinese_g2p([text])
229
+ phones_list += phones_zh
230
+ tones_list += tones_zh
231
+ word2ph += word2ph_zh
232
+ return phones_list, tones_list, word2ph
233
+
234
+
235
+
236
+ if __name__ == "__main__":
237
+ # from text.chinese_bert import get_bert_feature
238
+
239
+ text = "NFT啊!chemistry 但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
240
+ text = '我最近在学习machine learning,希望能够在未来的artificial intelligence领域有所建树。'
241
+ text = '今天下午,我们准备去shopping mall购物,然后晚上去看一场movie。'
242
+ text = '我们现在 also 能够 help 很多公司 use some machine learning 的 algorithms 啊!'
243
+ text = text_normalize(text)
244
+ print(text)
245
+ phones, tones, word2ph = g2p(text, impl='v2')
246
+ bert = get_bert_feature(text, word2ph, device='cuda:0')
247
+ print(phones)
248
+ import pdb; pdb.set_trace()
249
+
250
+
251
+ # # Example usage
252
+ # text = "这是一个示例文本:,你好!这是一个测试...."
253
+ # print(g2p_paddle(text)) # output: 这是一个示例文本你好这是一个测试
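A condensed sketch (hypothetical input) of the sentinel splitting that _g2p_v2 above relies on to route mixed text:

    import re

    spliter = '#$&^!@'
    text = '我们现在 also 能够 help 很多公司'
    text = re.sub(r'([a-zA-Z\s]+)', lambda x: f'{spliter}{x.group(1)}{spliter}', text)
    chunks = [t for t in text.split(spliter) if t]
    # -> ['我们现在', ' also ', '能够', ' help ', '很多公司']
    # Latin chunks go to g2p_en, everything else to the Chinese _g2p.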
melo/text/cleaner.py ADDED
@@ -0,0 +1,36 @@
1
+ from . import chinese, japanese, english, chinese_mix, korean, french, spanish
2
+ from . import cleaned_text_to_sequence
3
+ import copy
4
+
5
+ language_module_map = {"ZH": chinese, "JP": japanese, "EN": english, 'ZH_MIX_EN': chinese_mix, 'KR': korean,
6
+ 'FR': french, 'SP': spanish, 'ES': spanish}
7
+
8
+
9
+ def clean_text(text, language):
10
+ language_module = language_module_map[language]
11
+ norm_text = language_module.text_normalize(text)
12
+ phones, tones, word2ph = language_module.g2p(norm_text)
13
+ return norm_text, phones, tones, word2ph
14
+
15
+
16
+ def clean_text_bert(text, language, device=None):
17
+ language_module = language_module_map[language]
18
+ norm_text = language_module.text_normalize(text)
19
+ phones, tones, word2ph = language_module.g2p(norm_text)
20
+
21
+ word2ph_bak = copy.deepcopy(word2ph)
22
+ for i in range(len(word2ph)):
23
+ word2ph[i] = word2ph[i] * 2
24
+ word2ph[0] += 1
25
+ bert = language_module.get_bert_feature(norm_text, word2ph, device=device)
26
+
27
+ return norm_text, phones, tones, word2ph_bak, bert
28
+
29
+
30
+ def text_to_sequence(text, language):
31
+ norm_text, phones, tones, word2ph = clean_text(text, language)
32
+ return cleaned_text_to_sequence(phones, tones, language)
33
+
34
+
35
+ if __name__ == "__main__":
36
+ pass
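A hypothetical end-to-end call through this module, assuming the package is importable as melo.text:

    from melo.text.cleaner import clean_text

    norm_text, phones, tones, word2ph = clean_text("Hello there.", "EN")
    # clean_text_bert additionally doubles every word2ph entry (and adds 1 to the
    # first) before fetching BERT features, presumably to account for the blank
    # tokens interspersed between phones further down the pipeline.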
melo/text/cleaner_multiling.py ADDED
@@ -0,0 +1,110 @@
1
+ """Set of default text cleaners"""
2
+ # TODO: pick the cleaner for languages dynamically
3
+
4
+ import re
5
+
6
+ # Regular expression matching whitespace:
7
+ _whitespace_re = re.compile(r"\s+")
8
+
9
+ rep_map = {
10
+ ":": ",",
11
+ ";": ",",
12
+ ",": ",",
13
+ "。": ".",
14
+ "!": "!",
15
+ "?": "?",
16
+ "\n": ".",
17
+ "·": ",",
18
+ "、": ",",
19
+ "...": ".",
20
+ "…": ".",
21
+ "$": ".",
22
+ "“": "'",
23
+ "”": "'",
24
+ "‘": "'",
25
+ "’": "'",
26
+ "(": "'",
27
+ ")": "'",
28
+ "(": "'",
29
+ ")": "'",
30
+ "《": "'",
31
+ "》": "'",
32
+ "【": "'",
33
+ "】": "'",
34
+ "[": "'",
35
+ "]": "'",
36
+ "—": "",
37
+ "~": "-",
38
+ "~": "-",
39
+ "「": "'",
40
+ "」": "'",
41
+ }
42
+
43
+ def replace_punctuation(text):
44
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
45
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
46
+ return replaced_text
47
+
48
+ def lowercase(text):
49
+ return text.lower()
50
+
51
+
52
+ def collapse_whitespace(text):
53
+ return re.sub(_whitespace_re, " ", text).strip()
54
+
55
+ def remove_punctuation_at_begin(text):
56
+ return re.sub(r'^[,.!?]+', '', text)
57
+
58
+ def remove_aux_symbols(text):
59
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»\']+", "", text)
60
+ return text
61
+
62
+
63
+ def replace_symbols(text, lang="en"):
64
+ """Replace symbols based on the lenguage tag.
65
+
66
+ Args:
67
+ text:
68
+ Input text.
69
+ lang:
70
+ Language identifier. ex: "en", "fr", "pt", "ca".
71
+
72
+ Returns:
73
+ The modified text
74
+ example:
75
+ input args:
76
+ text: "si l'avi cau, diguem-ho"
77
+ lang: "ca"
78
+ Output:
79
+ text: "si lavi cau, diguemho"
80
+ """
81
+ text = text.replace(";", ",")
82
+ text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
83
+ text = text.replace(":", ",")
84
+ if lang == "en":
85
+ text = text.replace("&", " and ")
86
+ elif lang == "fr":
87
+ text = text.replace("&", " et ")
88
+ elif lang == "pt":
89
+ text = text.replace("&", " e ")
90
+ elif lang == "ca":
91
+ text = text.replace("&", " i ")
92
+ text = text.replace("'", "")
93
+ elif lang== "es":
94
+ text=text.replace("&","y")
95
+ text = text.replace("'", "")
96
+ return text
97
+
98
+ def unicleaners(text, cased=False, lang='en'):
99
+ """Basic pipeline for Portuguese text. There is no need to expand abbreviation and
100
+ numbers, phonemizer already does that"""
101
+ if not cased:
102
+ text = lowercase(text)
103
+ text = replace_punctuation(text)
104
+ text = replace_symbols(text, lang=lang)
105
+ text = remove_aux_symbols(text)
106
+ text = remove_punctuation_at_begin(text)
107
+ text = collapse_whitespace(text)
108
+ text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
109
+ return text
110
+
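For instance (hypothetical input), the French branch of unicleaners behaves like this:

    print(unicleaners("Bonjour & bienvenue — «test»", lang='fr'))
    # -> "bonjour et bienvenue test."
    # lowercased, "&" expanded to " et ", the em dash and aux symbols dropped,
    # and a final period appended by the trailing-punctuation rule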
melo/text/cmudict.rep ADDED
The diff for this file is too large to render. See raw diff
 
melo/text/cmudict_cache.pickle ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9b21b20325471934ba92f2e4a5976989e7d920caa32e7a286eacb027d197949
3
+ size 6212655
melo/text/english.py ADDED
@@ -0,0 +1,284 @@
1
+ import pickle
2
+ import os
3
+ import re
4
+ from g2p_en import G2p
5
+
6
+ from . import symbols
7
+
8
+ from .english_utils.abbreviations import expand_abbreviations
9
+ from .english_utils.time_norm import expand_time_english
10
+ from .english_utils.number_norm import normalize_numbers
11
+ from .japanese import distribute_phone
12
+
13
+ from transformers import AutoTokenizer
14
+
15
+ current_file_path = os.path.dirname(__file__)
16
+ CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
17
+ CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
18
+ _g2p = G2p()
19
+
20
+ arpa = {
21
+ "AH0",
22
+ "S",
23
+ "AH1",
24
+ "EY2",
25
+ "AE2",
26
+ "EH0",
27
+ "OW2",
28
+ "UH0",
29
+ "NG",
30
+ "B",
31
+ "G",
32
+ "AY0",
33
+ "M",
34
+ "AA0",
35
+ "F",
36
+ "AO0",
37
+ "ER2",
38
+ "UH1",
39
+ "IY1",
40
+ "AH2",
41
+ "DH",
42
+ "IY0",
43
+ "EY1",
44
+ "IH0",
45
+ "K",
46
+ "N",
47
+ "W",
48
+ "IY2",
49
+ "T",
50
+ "AA1",
51
+ "ER1",
52
+ "EH2",
53
+ "OY0",
54
+ "UH2",
55
+ "UW1",
56
+ "Z",
57
+ "AW2",
58
+ "AW1",
59
+ "V",
60
+ "UW2",
61
+ "AA2",
62
+ "ER",
63
+ "AW0",
64
+ "UW0",
65
+ "R",
66
+ "OW1",
67
+ "EH1",
68
+ "ZH",
69
+ "AE0",
70
+ "IH2",
71
+ "IH",
72
+ "Y",
73
+ "JH",
74
+ "P",
75
+ "AY1",
76
+ "EY0",
77
+ "OY2",
78
+ "TH",
79
+ "HH",
80
+ "D",
81
+ "ER0",
82
+ "CH",
83
+ "AO1",
84
+ "AE1",
85
+ "AO2",
86
+ "OY1",
87
+ "AY2",
88
+ "IH1",
89
+ "OW0",
90
+ "L",
91
+ "SH",
92
+ }
93
+
94
+
95
+ def post_replace_ph(ph):
96
+ rep_map = {
97
+ ":": ",",
98
+ ";": ",",
99
+ ",": ",",
100
+ "。": ".",
101
+ "!": "!",
102
+ "?": "?",
103
+ "\n": ".",
104
+ "·": ",",
105
+ "、": ",",
106
+ "...": "…",
107
+ "v": "V",
108
+ }
109
+ if ph in rep_map.keys():
110
+ ph = rep_map[ph]
111
+ if ph in symbols:
112
+ return ph
113
+ if ph not in symbols:
114
+ ph = "UNK"
115
+ return ph
116
+
117
+
118
+ def read_dict():
119
+ g2p_dict = {}
120
+ start_line = 49
121
+ with open(CMU_DICT_PATH) as f:
122
+ line = f.readline()
123
+ line_index = 1
124
+ while line:
125
+ if line_index >= start_line:
126
+ line = line.strip()
127
+ word_split = line.split("  ") # cmudict.rep separates the word and its phones with two spaces
128
+ word = word_split[0]
129
+
130
+ syllable_split = word_split[1].split(" - ")
131
+ g2p_dict[word] = []
132
+ for syllable in syllable_split:
133
+ phone_split = syllable.split(" ")
134
+ g2p_dict[word].append(phone_split)
135
+
136
+ line_index = line_index + 1
137
+ line = f.readline()
138
+
139
+ return g2p_dict
140
+
141
+
142
+ def cache_dict(g2p_dict, file_path):
143
+ with open(file_path, "wb") as pickle_file:
144
+ pickle.dump(g2p_dict, pickle_file)
145
+
146
+
147
+ def get_dict():
148
+ if os.path.exists(CACHE_PATH):
149
+ with open(CACHE_PATH, "rb") as pickle_file:
150
+ g2p_dict = pickle.load(pickle_file)
151
+ else:
152
+ g2p_dict = read_dict()
153
+ cache_dict(g2p_dict, CACHE_PATH)
154
+
155
+ return g2p_dict
156
+
157
+
158
+ eng_dict = get_dict()
159
+
160
+
161
+ def refine_ph(phn):
162
+ tone = 0
163
+ if re.search(r"\d$", phn):
164
+ tone = int(phn[-1]) + 1
165
+ phn = phn[:-1]
166
+ return phn.lower(), tone
167
+
168
+
169
+ def refine_syllables(syllables):
170
+ tones = []
171
+ phonemes = []
172
+ for phn_list in syllables:
173
+ for i in range(len(phn_list)):
174
+ phn = phn_list[i]
175
+ phn, tone = refine_ph(phn)
176
+ phonemes.append(phn)
177
+ tones.append(tone)
178
+ return phonemes, tones
179
+
180
+
181
+ def text_normalize(text):
182
+ text = text.lower()
183
+ text = expand_time_english(text)
184
+ text = normalize_numbers(text)
185
+ text = expand_abbreviations(text)
186
+ return text
187
+
188
+ model_id = 'bert-base-uncased'
189
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
190
+ def g2p_old(text):
191
+ tokenized = tokenizer.tokenize(text)
192
+ # import pdb; pdb.set_trace()
193
+ phones = []
194
+ tones = []
195
+ words = re.split(r"([,;.\-\?\!\s+])", text)
196
+ for w in words:
197
+ if w.upper() in eng_dict:
198
+ phns, tns = refine_syllables(eng_dict[w.upper()])
199
+ phones += phns
200
+ tones += tns
201
+ else:
202
+ phone_list = list(filter(lambda p: p != " ", _g2p(w)))
203
+ for ph in phone_list:
204
+ if ph in arpa:
205
+ ph, tn = refine_ph(ph)
206
+ phones.append(ph)
207
+ tones.append(tn)
208
+ else:
209
+ phones.append(ph)
210
+ tones.append(0)
211
+ # todo: implement word2ph
212
+ word2ph = [1 for i in phones]
213
+
214
+ phones = [post_replace_ph(i) for i in phones]
215
+ return phones, tones, word2ph
216
+
217
+ def g2p(text, pad_start_end=True, tokenized=None):
218
+ if tokenized is None:
219
+ tokenized = tokenizer.tokenize(text)
220
+ # import pdb; pdb.set_trace()
221
+ phs = []
222
+ ph_groups = []
223
+ for t in tokenized:
224
+ if not t.startswith("#"):
225
+ ph_groups.append([t])
226
+ else:
227
+ ph_groups[-1].append(t.replace("#", ""))
228
+
229
+ phones = []
230
+ tones = []
231
+ word2ph = []
232
+ for group in ph_groups:
233
+ w = "".join(group)
234
+ phone_len = 0
235
+ word_len = len(group)
236
+ if w.upper() in eng_dict:
237
+ phns, tns = refine_syllables(eng_dict[w.upper()])
238
+ phones += phns
239
+ tones += tns
240
+ phone_len += len(phns)
241
+ else:
242
+ phone_list = list(filter(lambda p: p != " ", _g2p(w)))
243
+ for ph in phone_list:
244
+ if ph in arpa:
245
+ ph, tn = refine_ph(ph)
246
+ phones.append(ph)
247
+ tones.append(tn)
248
+ else:
249
+ phones.append(ph)
250
+ tones.append(0)
251
+ phone_len += 1
252
+ aaa = distribute_phone(phone_len, word_len)
253
+ word2ph += aaa
254
+ phones = [post_replace_ph(i) for i in phones]
255
+
256
+ if pad_start_end:
257
+ phones = ["_"] + phones + ["_"]
258
+ tones = [0] + tones + [0]
259
+ word2ph = [1] + word2ph + [1]
260
+ return phones, tones, word2ph
261
+
262
+ def get_bert_feature(text, word2ph, device=None):
263
+ from . import english_bert # relative import, so this resolves inside the melo.text package
264
+
265
+ return english_bert.get_bert_feature(text, word2ph, device=device)
266
+
267
+ if __name__ == "__main__":
268
+ # print(get_dict())
269
+ # print(eng_word_to_phoneme("hello"))
270
+ from text.english_bert import get_bert_feature
271
+ text = "In this paper, we propose 1 DSPGAN, a N-F-T GAN-based universal vocoder."
272
+ text = text_normalize(text)
273
+ phones, tones, word2ph = g2p(text)
274
+ import pdb; pdb.set_trace()
275
+ bert = get_bert_feature(text, word2ph)
276
+
277
+ print(phones, tones, word2ph, bert.shape)
278
+
279
+ # all_phones = set()
280
+ # for k, syllables in eng_dict.items():
281
+ # for group in syllables:
282
+ # for ph in group:
283
+ # all_phones.add(ph)
284
+ # print(all_phones)
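A sketch of the wordpiece regrouping inside g2p() above, with hypothetical tokenizer output:

    tokenized = ['vo', '##cod', '##er']
    ph_groups = []
    for t in tokenized:
        if not t.startswith("#"):
            ph_groups.append([t])
        else:
            ph_groups[-1].append(t.replace("#", ""))
    # ph_groups == [['vo', 'cod', 'er']] -> joined to "vocoder" for the CMU lookup;
    # distribute_phone() then spreads the resulting phones back over the 3 subtokens.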
melo/text/english_bert.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
3
+ import sys
4
+
5
+ model_id = 'bert-base-uncased'
6
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
7
+ model = None
8
+
9
+ def get_bert_feature(text, word2ph, device=None):
10
+ global model
11
+ if (
12
+ sys.platform == "darwin"
13
+ and torch.backends.mps.is_available()
14
+ and device == "cpu"
15
+ ):
16
+ device = "mps"
17
+ if not device:
18
+ device = "cuda"
19
+ if model is None:
20
+ model = AutoModelForMaskedLM.from_pretrained(model_id).to(
21
+ device
22
+ )
23
+ with torch.no_grad():
24
+ inputs = tokenizer(text, return_tensors="pt")
25
+ for i in inputs:
26
+ inputs[i] = inputs[i].to(device)
27
+ res = model(**inputs, output_hidden_states=True)
28
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
29
+
30
+ assert inputs["input_ids"].shape[-1] == len(word2ph)
31
+ word2phone = word2ph
32
+ phone_level_feature = []
33
+ for i in range(len(word2phone)):
34
+ repeat_feature = res[i].repeat(word2phone[i], 1)
35
+ phone_level_feature.append(repeat_feature)
36
+
37
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
38
+
39
+ return phone_level_feature.T
melo/text/english_utils/__init__.py ADDED
File without changes
melo/text/english_utils/abbreviations.py ADDED
@@ -0,0 +1,35 @@
1
+ import re
2
+
3
+ # List of (regular expression, replacement) pairs for abbreviations in english:
4
+ abbreviations_en = [
5
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
6
+ for x in [
7
+ ("mrs", "misess"),
8
+ ("mr", "mister"),
9
+ ("dr", "doctor"),
10
+ ("st", "saint"),
11
+ ("co", "company"),
12
+ ("jr", "junior"),
13
+ ("maj", "major"),
14
+ ("gen", "general"),
15
+ ("drs", "doctors"),
16
+ ("rev", "reverend"),
17
+ ("lt", "lieutenant"),
18
+ ("hon", "honorable"),
19
+ ("sgt", "sergeant"),
20
+ ("capt", "captain"),
21
+ ("esq", "esquire"),
22
+ ("ltd", "limited"),
23
+ ("col", "colonel"),
24
+ ("ft", "fort"),
25
+ ]
26
+ ]
27
+
28
+ def expand_abbreviations(text, lang="en"):
29
+ if lang == "en":
30
+ _abbreviations = abbreviations_en
31
+ else:
32
+ raise NotImplementedError()
33
+ for regex, replacement in _abbreviations:
34
+ text = re.sub(regex, replacement, text)
35
+ return text
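For example, on already-lowercased input (text_normalize lowercases before calling this):

    print(expand_abbreviations("dr. smith met mr. jones."))
    # -> "doctor smith met mister jones."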
melo/text/english_utils/number_norm.py ADDED
@@ -0,0 +1,97 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ import re
4
+ from typing import Dict
5
+
6
+ import inflect
7
+
8
+ _inflect = inflect.engine()
9
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
10
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
11
+ _currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)")
12
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13
+ _number_re = re.compile(r"-?[0-9]+")
14
+
15
+
16
+ def _remove_commas(m):
17
+ return m.group(1).replace(",", "")
18
+
19
+
20
+ def _expand_decimal_point(m):
21
+ return m.group(1).replace(".", " point ")
22
+
23
+
24
+ def __expand_currency(value: str, inflection: Dict[float, str]) -> str:
25
+ parts = value.replace(",", "").split(".")
26
+ if len(parts) > 2:
27
+ return f"{value} {inflection[2]}" # Unexpected format
28
+ text = []
29
+ integer = int(parts[0]) if parts[0] else 0
30
+ if integer > 0:
31
+ integer_unit = inflection.get(integer, inflection[2])
32
+ text.append(f"{integer} {integer_unit}")
33
+ fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0
34
+ if fraction > 0:
35
+ fraction_unit = inflection.get(fraction / 100, inflection[0.02])
36
+ text.append(f"{fraction} {fraction_unit}")
37
+ if len(text) == 0:
38
+ return f"zero {inflection[2]}"
39
+ return " ".join(text)
40
+
41
+
42
+ def _expand_currency(m: "re.Match") -> str:
43
+ currencies = {
44
+ "$": {
45
+ 0.01: "cent",
46
+ 0.02: "cents",
47
+ 1: "dollar",
48
+ 2: "dollars",
49
+ },
50
+ "€": {
51
+ 0.01: "cent",
52
+ 0.02: "cents",
53
+ 1: "euro",
54
+ 2: "euros",
55
+ },
56
+ "£": {
57
+ 0.01: "penny",
58
+ 0.02: "pence",
59
+ 1: "pound sterling",
60
+ 2: "pounds sterling",
61
+ },
62
+ "¥": {
63
+ # TODO rin
64
+ 0.02: "sen",
65
+ 2: "yen",
66
+ },
67
+ }
68
+ unit = m.group(1)
69
+ currency = currencies[unit]
70
+ value = m.group(2)
71
+ return __expand_currency(value, currency)
72
+
73
+
74
+ def _expand_ordinal(m):
75
+ return _inflect.number_to_words(m.group(0))
76
+
77
+
78
+ def _expand_number(m):
79
+ num = int(m.group(0))
80
+ if 1000 < num < 3000:
81
+ if num == 2000:
82
+ return "two thousand"
83
+ if 2000 < num < 2010:
84
+ return "two thousand " + _inflect.number_to_words(num % 100)
85
+ if num % 100 == 0:
86
+ return _inflect.number_to_words(num // 100) + " hundred"
87
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
88
+ return _inflect.number_to_words(num, andword="")
89
+
90
+
91
+ def normalize_numbers(text):
92
+ text = re.sub(_comma_number_re, _remove_commas, text)
93
+ text = re.sub(_currency_re, _expand_currency, text)
94
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
95
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
96
+ text = re.sub(_number_re, _expand_number, text)
97
+ return text
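A hypothetical input that exercises the currency, ordinal and year branches above:

    print(normalize_numbers("I paid $42.50 on the 3rd of May, 1999."))
    # -> "I paid 42 dollars 50 cents on the third of May, nineteen ninety-nine."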
melo/text/english_utils/time_norm.py ADDED
@@ -0,0 +1,47 @@
1
+ import re
2
+
3
+ import inflect
4
+
5
+ _inflect = inflect.engine()
6
+
7
+ _time_re = re.compile(
8
+ r"""\b
9
+ ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours
10
+ :
11
+ ([0-5][0-9]) # minutes
12
+ \s*(a\.m\.|am|pm|p\.m\.|a\.m|p\.m)? # am/pm ("\\." in this raw string would match a literal backslash)
13
+ \b""",
14
+ re.IGNORECASE | re.X,
15
+ )
16
+
17
+
18
+ def _expand_num(n: int) -> str:
19
+ return _inflect.number_to_words(n)
20
+
21
+
22
+ def _expand_time_english(match: "re.Match") -> str:
23
+ hour = int(match.group(1))
24
+ past_noon = hour >= 12
25
+ time = []
26
+ if hour > 12:
27
+ hour -= 12
28
+ elif hour == 0:
29
+ hour = 12
30
+ past_noon = True
31
+ time.append(_expand_num(hour))
32
+
33
+ minute = int(match.group(6))
34
+ if minute > 0:
35
+ if minute < 10:
36
+ time.append("oh")
37
+ time.append(_expand_num(minute))
38
+ am_pm = match.group(7)
39
+ if am_pm is None:
40
+ time.append("p m" if past_noon else "a m")
41
+ else:
42
+ time.extend(list(am_pm.replace(".", "")))
43
+ return " ".join(time)
44
+
45
+
46
+ def expand_time_english(text: str) -> str:
47
+ return re.sub(_time_re, _expand_time_english, text)
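For example (hypothetical input; 24-hour times fold over to 12-hour words):

    print(expand_time_english("The train leaves at 14:05."))
    # -> "The train leaves at two oh five p m."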
melo/text/es_phonemizer/__init__.py ADDED
File without changes
melo/text/es_phonemizer/base.py ADDED
@@ -0,0 +1,140 @@
1
+ import abc
2
+ from typing import List, Tuple
3
+
4
+ from .punctuation import Punctuation
5
+
6
+
7
+ class BasePhonemizer(abc.ABC):
8
+ """Base phonemizer class
9
+
10
+ Phonemization follows the following steps:
11
+ 1. Preprocessing:
12
+ - remove empty lines
13
+ - remove punctuation
14
+ - keep track of punctuation marks
15
+
16
+ 2. Phonemization:
17
+ - convert text to phonemes
18
+
19
+ 3. Postprocessing:
20
+ - join phonemes
21
+ - restore punctuation marks
22
+
23
+ Args:
24
+ language (str):
25
+ Language used by the phonemizer.
26
+
27
+ punctuations (List[str]):
28
+ List of punctuation marks to be preserved.
29
+
30
+ keep_puncs (bool):
31
+ Whether to preserve punctuation marks or not.
32
+ """
33
+
34
+ def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
35
+ # ensure the backend is installed on the system
36
+ if not self.is_available():
37
+ raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover
38
+
39
+ # ensure the backend support the requested language
40
+ self._language = self._init_language(language)
41
+
42
+ # setup punctuation processing
43
+ self._keep_puncs = keep_puncs
44
+ self._punctuator = Punctuation(punctuations)
45
+
46
+ def _init_language(self, language):
47
+ """Language initialization
48
+
49
+ This method may be overloaded in child classes (see Segments backend)
50
+
51
+ """
52
+ if not self.is_supported_language(language):
53
+ raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
54
+ return language
55
+
56
+ @property
57
+ def language(self):
58
+ """The language code configured to be used for phonemization"""
59
+ return self._language
60
+
61
+ @staticmethod
62
+ @abc.abstractmethod
63
+ def name():
64
+ """The name of the backend"""
65
+ ...
66
+
67
+ @classmethod
68
+ @abc.abstractmethod
69
+ def is_available(cls):
70
+ """Returns True if the backend is installed, False otherwise"""
71
+ ...
72
+
73
+ @classmethod
74
+ @abc.abstractmethod
75
+ def version(cls):
76
+ """Return the backend version as a tuple (major, minor, patch)"""
77
+ ...
78
+
79
+ @staticmethod
80
+ @abc.abstractmethod
81
+ def supported_languages():
82
+ """Return a dict of language codes -> name supported by the backend"""
83
+ ...
84
+
85
+ def is_supported_language(self, language):
86
+ """Returns True if `language` is supported by the backend"""
87
+ return language in self.supported_languages()
88
+
89
+ @abc.abstractmethod
90
+ def _phonemize(self, text, separator):
91
+ """The main phonemization method"""
92
+
93
+ def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
94
+ """Preprocess the text before phonemization
95
+
96
+ 1. remove spaces
97
+ 2. remove punctuation
98
+
99
+ Override this if you need a different behaviour
100
+ """
101
+ text = text.strip()
102
+ if self._keep_puncs:
103
+ # a tuple (text, punctuation marks)
104
+ return self._punctuator.strip_to_restore(text)
105
+ return [self._punctuator.strip(text)], []
106
+
107
+ def _phonemize_postprocess(self, phonemized, punctuations) -> str:
108
+ """Postprocess the raw phonemized output
109
+
110
+ Override this if you need a different behaviour
111
+ """
112
+ if self._keep_puncs:
113
+ return self._punctuator.restore(phonemized, punctuations)[0]
114
+ return phonemized[0]
115
+
116
+ def phonemize(self, text: str, separator="|", language: str = None) -> str: # pylint: disable=unused-argument
117
+ """Returns the `text` phonemized for the given language
118
+
119
+ Args:
120
+ text (str):
121
+ Text to be phonemized.
122
+
123
+ separator (str):
124
+ string separator used between phonemes. Defaults to '|'.
125
+
126
+ Returns:
127
+ (str): Phonemized text
128
+ """
129
+ text, punctuations = self._phonemize_preprocess(text)
130
+ phonemized = []
131
+ for t in text:
132
+ p = self._phonemize(t, separator)
133
+ phonemized.append(p)
134
+ phonemized = self._phonemize_postprocess(phonemized, punctuations)
135
+ return phonemized
136
+
137
+ def print_logs(self, level: int = 0):
138
+ indent = "\t" * level
139
+ print(f"{indent}| > phoneme language: {self.language}")
140
+ print(f"{indent}| > phoneme backend: {self.name()}")
melo/text/es_phonemizer/cleaner.py ADDED
@@ -0,0 +1,109 @@
1
+ """Set of default text cleaners"""
2
+ # TODO: pick the cleaner for languages dynamically
3
+
4
+ import re
5
+
6
+ # Regular expression matching whitespace:
7
+ _whitespace_re = re.compile(r"\s+")
8
+
9
+ rep_map = {
10
+ ":": ",",
11
+ ";": ",",
12
+ ",": ",",
13
+ "。": ".",
14
+ "!": "!",
15
+ "?": "?",
16
+ "\n": ".",
17
+ "·": ",",
18
+ "、": ",",
19
+ "...": ".",
20
+ "…": ".",
21
+ "$": ".",
22
+ "“": "'",
23
+ "”": "'",
24
+ "‘": "'",
25
+ "’": "'",
26
+ "(": "'",
27
+ ")": "'",
28
+ "(": "'",
29
+ ")": "'",
30
+ "《": "'",
31
+ "》": "'",
32
+ "【": "'",
33
+ "】": "'",
34
+ "[": "'",
35
+ "]": "'",
36
+ "—": "",
37
+ "~": "-",
38
+ "~": "-",
39
+ "「": "'",
40
+ "」": "'",
41
+ }
42
+
43
+ def replace_punctuation(text):
44
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
45
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
46
+ return replaced_text
47
+
48
+ def lowercase(text):
49
+ return text.lower()
50
+
51
+
52
+ def collapse_whitespace(text):
53
+ return re.sub(_whitespace_re, " ", text).strip()
54
+
55
+ def remove_punctuation_at_begin(text):
56
+ return re.sub(r'^[,.!?]+', '', text)
57
+
58
+ def remove_aux_symbols(text):
59
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»\']+", "", text)
60
+ return text
61
+
62
+
63
+ def replace_symbols(text, lang="en"):
64
+ """Replace symbols based on the lenguage tag.
65
+
66
+ Args:
67
+ text:
68
+ Input text.
69
+ lang:
70
+ Language identifier. ex: "en", "fr", "pt", "ca".
71
+
72
+ Returns:
73
+ The modified text
74
+ example:
75
+ input args:
76
+ text: "si l'avi cau, diguem-ho"
77
+ lang: "ca"
78
+ Output:
79
+ text: "si lavi cau, diguemho"
80
+ """
81
+ text = text.replace(";", ",")
82
+ text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
83
+ text = text.replace(":", ",")
84
+ if lang == "en":
85
+ text = text.replace("&", " and ")
86
+ elif lang == "fr":
87
+ text = text.replace("&", " et ")
88
+ elif lang == "pt":
89
+ text = text.replace("&", " e ")
90
+ elif lang == "ca":
91
+ text = text.replace("&", " i ")
92
+ text = text.replace("'", "")
93
+ elif lang== "es":
94
+ text=text.replace("&","y")
95
+ text = text.replace("'", "")
96
+ return text
97
+
98
+ def spanish_cleaners(text):
99
+ """Basic pipeline for Portuguese text. There is no need to expand abbreviation and
100
+ numbers, phonemizer already does that"""
101
+ text = lowercase(text)
102
+ text = replace_symbols(text, lang="es")
103
+ text = replace_punctuation(text)
104
+ text = remove_aux_symbols(text)
105
+ text = remove_punctuation_at_begin(text)
106
+ text = collapse_whitespace(text)
107
+ text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
108
+ return text
109
+
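For instance (hypothetical input):

    print(spanish_cleaners("¿Cómo estás? Bien & tú"))
    # -> "¿cómo estás? bien y tú."
    # lowercased, "&" replaced by "y", a final period appended; the inverted
    # "¿" survives every step since no rule strips it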
melo/text/es_phonemizer/es_symbols.json ADDED
@@ -0,0 +1,79 @@
1
+ {
2
+ "symbols": [
3
+ "_",
4
+ ",",
5
+ ".",
6
+ "!",
7
+ "?",
8
+ "-",
9
+ "~",
10
+ "\u2026",
11
+ "N",
12
+ "Q",
13
+ "a",
14
+ "b",
15
+ "d",
16
+ "e",
17
+ "f",
18
+ "g",
19
+ "h",
20
+ "i",
21
+ "j",
22
+ "k",
23
+ "l",
24
+ "m",
25
+ "n",
26
+ "o",
27
+ "p",
28
+ "s",
29
+ "t",
30
+ "u",
31
+ "v",
32
+ "w",
33
+ "x",
34
+ "y",
35
+ "z",
36
+ "\u0251",
37
+ "\u00e6",
38
+ "\u0283",
39
+ "\u0291",
40
+ "\u00e7",
41
+ "\u026f",
42
+ "\u026a",
43
+ "\u0254",
44
+ "\u025b",
45
+ "\u0279",
46
+ "\u00f0",
47
+ "\u0259",
48
+ "\u026b",
49
+ "\u0265",
50
+ "\u0278",
51
+ "\u028a",
52
+ "\u027e",
53
+ "\u0292",
54
+ "\u03b8",
55
+ "\u03b2",
56
+ "\u014b",
57
+ "\u0266",
58
+ "\u207c",
59
+ "\u02b0",
60
+ "`",
61
+ "^",
62
+ "#",
63
+ "*",
64
+ "=",
65
+ "\u02c8",
66
+ "\u02cc",
67
+ "\u2192",
68
+ "\u2193",
69
+ "\u2191",
70
+ " ",
71
+ "\u0263",
72
+ "\u0261",
73
+ "r",
74
+ "\u0272",
75
+ "\u029d",
76
+ "\u028e",
77
+ "\u02d0"
78
+ ]
79
+ }
melo/text/es_phonemizer/es_symbols.txt ADDED
@@ -0,0 +1 @@
1
+ _,.!?-~…NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ɡrɲʝɣʎː—¿¡
melo/text/es_phonemizer/es_symbols_v2.json ADDED
@@ -0,0 +1,83 @@
1
+ {
2
+ "symbols": [
3
+ "_",
4
+ ",",
5
+ ".",
6
+ "!",
7
+ "?",
8
+ "-",
9
+ "~",
10
+ "\u2026",
11
+ "N",
12
+ "Q",
13
+ "a",
14
+ "b",
15
+ "d",
16
+ "e",
17
+ "f",
18
+ "g",
19
+ "h",
20
+ "i",
21
+ "j",
22
+ "k",
23
+ "l",
24
+ "m",
25
+ "n",
26
+ "o",
27
+ "p",
28
+ "s",
29
+ "t",
30
+ "u",
31
+ "v",
32
+ "w",
33
+ "x",
34
+ "y",
35
+ "z",
36
+ "\u0251",
37
+ "\u00e6",
38
+ "\u0283",
39
+ "\u0291",
40
+ "\u00e7",
41
+ "\u026f",
42
+ "\u026a",
43
+ "\u0254",
44
+ "\u025b",
45
+ "\u0279",
46
+ "\u00f0",
47
+ "\u0259",
48
+ "\u026b",
49
+ "\u0265",
50
+ "\u0278",
51
+ "\u028a",
52
+ "\u027e",
53
+ "\u0292",
54
+ "\u03b8",
55
+ "\u03b2",
56
+ "\u014b",
57
+ "\u0266",
58
+ "\u207c",
59
+ "\u02b0",
60
+ "`",
61
+ "^",
62
+ "#",
63
+ "*",
64
+ "=",
65
+ "\u02c8",
66
+ "\u02cc",
67
+ "\u2192",
68
+ "\u2193",
69
+ "\u2191",
70
+ " ",
71
+ "\u0261",
72
+ "r",
73
+ "\u0272",
74
+ "\u029d",
75
+ "\u0263",
76
+ "\u028e",
77
+ "\u02d0",
78
+
79
+ "\u2014",
80
+ "\u00bf",
81
+ "\u00a1"
82
+ ]
83
+ }
melo/text/es_phonemizer/es_to_ipa.py ADDED
@@ -0,0 +1,12 @@
1
+ from .cleaner import spanish_cleaners
2
+ from .gruut_wrapper import Gruut
3
+
4
+ def es2ipa(text):
5
+ e = Gruut(language="es-es", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
6
+ # text = spanish_cleaners(text)
7
+ phonemes = e.phonemize(text, separator="")
8
+ return phonemes
9
+
10
+
11
+ if __name__ == '__main__':
12
+ print(es2ipa('¿Y a quién echaría de menos, en el mundo si no fuese a vos?'))
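Note that es2ipa above constructs a fresh Gruut instance on every call. A memoized variant (hypothetical, not part of the commit) avoids re-initialising the backend:

    _cached_gruut = None

    def es2ipa_cached(text):
        global _cached_gruut
        if _cached_gruut is None:
            _cached_gruut = Gruut(language="es-es", keep_puncs=True,
                                  keep_stress=True, use_espeak_phonemes=True)
        return _cached_gruut.phonemize(text, separator="")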
melo/text/es_phonemizer/gruut_wrapper.py ADDED
@@ -0,0 +1,253 @@
1
+ import importlib
2
+ from typing import List
3
+
4
+ import gruut
5
+ from gruut_ipa import IPA # pip install gruut_ipa
6
+
7
+ from .base import BasePhonemizer
8
+ from .punctuation import Punctuation
9
+
10
+ # Table for str.translate to fix gruut/TTS phoneme mismatch
11
+ GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
12
+
13
+
14
+ class Gruut(BasePhonemizer):
15
+ """Gruut wrapper for G2P
16
+
17
+ Args:
18
+ language (str):
19
+ Valid language code for the used backend.
20
+
21
+ punctuations (str):
22
+ Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`.
23
+
24
+ keep_puncs (bool):
25
+ If true, keep the punctuations after phonemization. Defaults to True.
26
+
27
+ use_espeak_phonemes (bool):
28
+ If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False.
29
+
30
+ keep_stress (bool):
31
+ If true, keep the stress characters after phonemization. Defaults to False.
32
+
33
+ Example:
34
+
35
+ >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
36
+ >>> phonemizer = Gruut('en-us')
37
+ >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|")
38
+ 'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?'
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ language: str,
44
+ punctuations=Punctuation.default_puncs(),
45
+ keep_puncs=True,
46
+ use_espeak_phonemes=False,
47
+ keep_stress=False,
48
+ ):
49
+ super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
50
+ self.use_espeak_phonemes = use_espeak_phonemes
51
+ self.keep_stress = keep_stress
52
+
53
+ @staticmethod
54
+ def name():
55
+ return "gruut"
56
+
57
+ def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument
58
+ """Convert input text to phonemes.
59
+
60
+ Gruut phonemizes the given `str` by separating each phoneme character with `separator`, even for characters
61
+ that constitute a single sound.
62
+
63
+ It doesn't affect 🐸TTS since it individually converts each character to token IDs.
64
+
65
+ Examples::
66
+ "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ`
67
+
68
+ Args:
69
+ text (str):
70
+ Text to be converted to phonemes.
71
+
72
+ tie (bool, optional) : When True use a '͡' character between
73
+ consecutive characters of a single phoneme. Else separate phoneme
74
+ with '_'. This option requires espeak>=1.49. Default to False.
75
+ """
76
+ ph_list = []
77
+ for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes):
78
+ for word in sentence:
79
+ if word.is_break:
80
+ # Use actual character for break phoneme (e.g., comma)
81
+ if ph_list:
82
+ # Join with previous word
83
+ ph_list[-1].append(word.text)
84
+ else:
85
+ # First word is punctuation
86
+ ph_list.append([word.text])
87
+ elif word.phonemes:
88
+ # Add phonemes for word
89
+ word_phonemes = []
90
+
91
+ for word_phoneme in word.phonemes:
92
+ if not self.keep_stress:
93
+ # Remove primary/secondary stress
94
+ word_phoneme = IPA.without_stress(word_phoneme)
95
+
96
+ word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
97
+
98
+ if word_phoneme:
99
+ # Flatten phonemes
100
+ word_phonemes.extend(word_phoneme)
101
+
102
+ if word_phonemes:
103
+ ph_list.append(word_phonemes)
104
+
105
+ ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list]
106
+ ph = f"{separator} ".join(ph_words)
107
+ return ph
108
+
109
+ def _phonemize(self, text, separator):
110
+ return self.phonemize_gruut(text, separator, tie=False)
111
+
112
+ def is_supported_language(self, language):
113
+ """Returns True if `language` is supported by the backend"""
114
+ return gruut.is_language_supported(language)
115
+
116
+ @staticmethod
117
+ def supported_languages() -> List:
118
+ """Get a dictionary of supported languages.
119
+
120
+ Returns:
121
+ List: List of language codes.
122
+ """
123
+ return list(gruut.get_supported_languages())
124
+
125
+ def version(self):
126
+ """Get the version of the used backend.
127
+
128
+ Returns:
129
+ str: Version of the used backend.
130
+ """
131
+ return gruut.__version__
132
+
133
+ @classmethod
134
+ def is_available(cls):
135
+ """Return true if ESpeak is available else false"""
136
+ return importlib.util.find_spec("gruut") is not None
137
+
138
+
139
+ if __name__ == "__main__":
140
+ from es_to_ipa import es2ipa
141
+ import json
142
+
143
+ e = Gruut(language="es-es", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
144
+ symbols = [
145
+ "_",
146
+ ",",
147
+ ".",
148
+ "!",
149
+ "?",
150
+ "-",
151
+ "~",
152
+ "\u2026",
153
+ "N",
154
+ "Q",
155
+ "a",
156
+ "b",
157
+ "d",
158
+ "e",
159
+ "f",
160
+ "g",
161
+ "h",
162
+ "i",
163
+ "j",
164
+ "k",
165
+ "l",
166
+ "m",
167
+ "n",
168
+ "o",
169
+ "p",
170
+ "s",
171
+ "t",
172
+ "u",
173
+ "v",
174
+ "w",
175
+ "x",
176
+ "y",
177
+ "z",
178
+ "\u0251",
179
+ "\u00e6",
180
+ "\u0283",
181
+ "\u0291",
182
+ "\u00e7",
183
+ "\u026f",
184
+ "\u026a",
185
+ "\u0254",
186
+ "\u025b",
187
+ "\u0279",
188
+ "\u00f0",
189
+ "\u0259",
190
+ "\u026b",
191
+ "\u0265",
192
+ "\u0278",
193
+ "\u028a",
194
+ "\u027e",
195
+ "\u0292",
196
+ "\u03b8",
197
+ "\u03b2",
198
+ "\u014b",
199
+ "\u0266",
200
+ "\u207c",
201
+ "\u02b0",
202
+ "`",
203
+ "^",
204
+ "#",
205
+ "*",
206
+ "=",
207
+ "\u02c8",
208
+ "\u02cc",
209
+ "\u2192",
210
+ "\u2193",
211
+ "\u2191",
212
+ " ",
213
+ ]
214
+ with open('./text/es_phonemizer/spanish_text.txt', 'r') as f:
215
+ lines = f.readlines()
216
+
217
+
218
+ used_sym = []
219
+ not_existed_sym = []
220
+ phonemes = []
221
+
222
+ for line in lines[:400]:
223
+ text = line.split('|')[-1].strip()
224
+ ipa = es2ipa(text)
225
+ phonemes.append(ipa + '\n')
226
+ for s in ipa:
227
+ if s not in symbols:
228
+ if s not in not_existed_sym:
229
+ print(f'not_existed char: {s}')
230
+ not_existed_sym.append(s)
231
+ else:
232
+ if s not in used_sym:
233
+ # print(f'used char: {s}')
234
+ used_sym.append(s)
235
+
236
+ print(used_sym)
237
+ print(not_existed_sym)
238
+
239
+
240
+ with open('./text/es_phonemizer/es_symbols.txt', 'w') as g:
241
+ g.writelines(symbols + not_existed_sym)
242
+
243
+ with open('./text/es_phonemizer/example_ipa.txt', 'w') as g:
244
+ g.writelines(phonemes)
245
+
246
+ data = {'symbols': symbols + not_existed_sym}
247
+ with open('./text/es_phonemizer/es_symbols_v2.json', 'w') as f:
248
+ json.dump(data, f, indent=4)
249
+
250
+
251
+
252
+
253
+
melo/text/es_phonemizer/punctuation.py ADDED
@@ -0,0 +1,174 @@
1
+ import collections
2
+ import re
3
+ from enum import Enum
4
+
5
+ import six
6
+
7
+ _DEF_PUNCS = ';:,.!?¡¿—…"«»“”'
8
+
9
+ _PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"])
10
+
11
+
12
+ class PuncPosition(Enum):
13
+ """Enum for the punctuations positions"""
14
+
15
+ BEGIN = 0
16
+ END = 1
17
+ MIDDLE = 2
18
+ ALONE = 3
19
+
20
+
21
+ class Punctuation:
22
+ """Handle punctuations in text.
23
+
24
+ Just strip punctuations from text or strip and restore them later.
25
+
26
+ Args:
27
+ puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.
28
+
29
+ Example:
30
+ >>> punc = Punctuation()
31
+ >>> punc.strip("This is. example !")
32
+ 'This is example'
33
+
34
+ >>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
35
+ >>> ' '.join(text_striped)
36
+ 'This is example'
37
+
38
+ >>> text_restored = punc.restore(text_striped, punc_map)
39
+ >>> text_restored[0]
40
+ 'This is. example !'
41
+ """
42
+
43
+ def __init__(self, puncs: str = _DEF_PUNCS):
44
+ self.puncs = puncs
45
+
46
+ @staticmethod
47
+ def default_puncs():
48
+ """Return default set of punctuations."""
49
+ return _DEF_PUNCS
50
+
51
+ @property
52
+ def puncs(self):
53
+ return self._puncs
54
+
55
+ @puncs.setter
56
+ def puncs(self, value):
57
+ if not isinstance(value, six.string_types):
58
+ raise ValueError("[!] Punctuations must be of type str.")
59
+ self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder
60
+ self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
61
+
62
+ def strip(self, text):
63
+ """Remove all the punctuations by replacing with `space`.
64
+
65
+ Args:
66
+ text (str): The text to be processed.
67
+
68
+ Example::
69
+
70
+ "This is. example !" -> "This is example "
71
+ """
72
+ return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip()
73
+
74
+ def strip_to_restore(self, text):
75
+ """Remove punctuations from text to restore them later.
76
+
77
+ Args:
78
+ text (str): The text to be processed.
79
+
80
+ Examples ::
81
+
82
+ "This is. example !" -> [["This is", "example"], [".", "!"]]
83
+
84
+ """
85
+ text, puncs = self._strip_to_restore(text)
86
+ return text, puncs
87
+
88
+ def _strip_to_restore(self, text):
89
+ """Auxiliary method for Punctuation.preserve()"""
90
+ matches = list(re.finditer(self.puncs_regular_exp, text))
91
+ if not matches:
92
+ return [text], []
93
+ # the text is only punctuations
94
+ if len(matches) == 1 and matches[0].group() == text:
95
+ return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
96
+ # build a punctuation map to be used later to restore punctuations
97
+ puncs = []
98
+ for match in matches:
99
+ position = PuncPosition.MIDDLE
100
+ if match == matches[0] and text.startswith(match.group()):
101
+ position = PuncPosition.BEGIN
102
+ elif match == matches[-1] and text.endswith(match.group()):
103
+ position = PuncPosition.END
104
+ puncs.append(_PUNC_IDX(match.group(), position))
105
+ # convert str text to a List[str], each item is separated by a punctuation
106
+ splitted_text = []
107
+ for idx, punc in enumerate(puncs):
108
+ split = text.split(punc.punc)
109
+ prefix, suffix = split[0], punc.punc.join(split[1:])
110
+ splitted_text.append(prefix)
111
+ # if the text does not end with a punctuation, add it to the last item
112
+ if idx == len(puncs) - 1 and len(suffix) > 0:
113
+ splitted_text.append(suffix)
114
+ text = suffix
115
+ while splitted_text[0] == '':
116
+ splitted_text = splitted_text[1:]
117
+ return splitted_text, puncs
118
+
119
+ @classmethod
120
+ def restore(cls, text, puncs):
121
+ """Restore punctuation in a text.
122
+
123
+ Args:
124
+ text (str): The text to be processed.
125
+ puncs (List[str]): The list of punctuations map to be used for restoring.
126
+
127
+ Examples ::
128
+
129
+ ['This is', 'example'], ['.', '!'] -> "This is. example!"
130
+
131
+ """
132
+ return cls._restore(text, puncs, 0)
133
+
134
+ @classmethod
135
+ def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements
136
+ """Auxiliary method for Punctuation.restore()"""
137
+ if not puncs:
138
+ return text
139
+
140
+ # nothing have been phonemized, returns the puncs alone
141
+ if not text:
142
+ return ["".join(m.punc for m in puncs)]
143
+
144
+ current = puncs[0]
145
+
146
+ if current.position == PuncPosition.BEGIN:
147
+ return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
148
+
149
+ if current.position == PuncPosition.END:
150
+ return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)
151
+
152
+ if current.position == PuncPosition.ALONE:
153
+ return [current.punc] + cls._restore(text, puncs[1:], num + 1) # _PUNC_IDX has no "mark" field
154
+
155
+ # POSITION == MIDDLE
156
+ if len(text) == 1: # pragma: nocover
157
+ # a corner case where the final part of an intermediate
158
+ # mark (I) has not been phonemized
159
+ return cls._restore([text[0] + current.punc], puncs[1:], num)
160
+
161
+ return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
162
+
163
+
164
+ # if __name__ == "__main__":
165
+ # punc = Punctuation()
166
+ # text = "This is. This is, example!"
167
+
168
+ # print(punc.strip(text))
169
+
170
+ # split_text, puncs = punc.strip_to_restore(text)
171
+ # print(split_text, " ---- ", puncs)
172
+
173
+ # restored_text = punc.restore(split_text, puncs)
174
+ # print(restored_text)
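The commented demo above round-trips as follows (hypothetical values, worked by hand):

    punc = Punctuation()
    parts, marks = punc.strip_to_restore("Hola, mundo!")
    # parts == ['Hola', 'mundo']; marks records ', ' (MIDDLE) and '!' (END)
    print(punc.restore(parts, marks)[0])   # -> "Hola, mundo!"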
melo/text/es_phonemizer/spanish_symbols.txt ADDED
@@ -0,0 +1 @@
1
+ dˌaβˈiðkopeɾfjl unθsbmtʃwɛxɪŋʊɣɡrɲʝʎː
melo/text/es_phonemizer/test.ipynb ADDED
@@ -0,0 +1,124 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "ename": "ImportError",
10
+ "evalue": "attempted relative import with no known parent package",
11
+ "output_type": "error",
12
+ "traceback": [
13
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
14
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
15
+ "\u001b[1;32m/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb Cell 1\u001b[0m line \u001b[0;36m5\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mos\u001b[39;00m\u001b[39m,\u001b[39m \u001b[39msys\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>\u001b[0m sys\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mappend(\u001b[39m'\u001b[39m\u001b[39m/home/xumin/workspace/MyShell-VC-Training/text/es_phonemizer/\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mes_to_ipa\u001b[39;00m \u001b[39mimport\u001b[39;00m es2ipa\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=8'>9</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39msplit_sentences_en\u001b[39m(text, min_len\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m):\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9'>10</a>\u001b[0m \u001b[39m# 将文本中的换行符、空格和制表符替换为空格\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10'>11</a>\u001b[0m text \u001b[39m=\u001b[39m re\u001b[39m.\u001b[39msub(\u001b[39m'\u001b[39m\u001b[39m[\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m ]+\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m, text)\n",
16
+ "File \u001b[0;32m/data/workspace/Bert-VITS2/text/es_phonemizer/es_to_ipa.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m\u001b[39mcleaner\u001b[39;00m \u001b[39mimport\u001b[39;00m spanish_cleaners\n\u001b[1;32m 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m\u001b[39mgruut_wrapper\u001b[39;00m \u001b[39mimport\u001b[39;00m Gruut\n\u001b[1;32m 4\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mes2ipa\u001b[39m(text):\n",
17
+ "\u001b[0;31mImportError\u001b[0m: attempted relative import with no known parent package"
18
+ ]
19
+ }
20
+ ],
21
+ "source": [
22
+ "import re\n",
23
+ "import os\n",
24
+ "import os, sys\n",
25
+ "sys.path.append('/home/xumin/workspace/MyShell-VC-Training/text/es_phonemizer/')\n",
26
+ "from es_to_ipa import es2ipa\n",
27
+ "\n",
28
+ "\n",
29
+ "\n",
30
+ "def split_sentences_en(text, min_len=10):\n",
31
+ " # 将文本中的换行符、空格和制表符替换为空格\n",
32
+ " text = re.sub('[\\n\\t ]+', ' ', text)\n",
33
+ " # 在标点符号后添加一个空格\n",
34
+ " text = re.sub('([¿—¡])', r'\\1 $#!', text)\n",
35
+ " # 分隔句子并去除前后空格\n",
36
+ " \n",
37
+ " sentences = [s.strip() for s in text.split(' $#!')]\n",
38
+ " if len(sentences[-1]) == 0: del sentences[-1]\n",
39
+ "\n",
40
+ " new_sentences = []\n",
41
+ " new_sent = []\n",
42
+ " for ind, sent in enumerate(sentences):\n",
43
+ " if sent in ['¿', '—', '¡']:\n",
44
+ " new_sent.append(sent)\n",
45
+ " else:\n",
46
+ " new_sent.append(es2ipa(sent))\n",
47
+ " \n",
48
+ " \n",
49
+ " new_sentences = ''.join(new_sent)\n",
50
+ "\n",
51
+ " return new_sentences"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 3,
57
+ "metadata": {},
58
+ "outputs": [
59
+ {
60
+ "data": {
61
+ "text/plain": [
62
+ "'—¿aβˈeis estˈaðo kasˈaða alɣˈuna bˈeθ?'"
63
+ ]
64
+ },
65
+ "execution_count": 3,
66
+ "metadata": {},
67
+ "output_type": "execute_result"
68
+ }
69
+ ],
70
+ "source": [
71
+ "split_sentences_en('—¿Habéis estado casada alguna vez?')"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": 4,
77
+ "metadata": {},
78
+ "outputs": [
79
+ {
80
+ "data": {
81
+ "text/plain": [
82
+ "'aβˈeis estˈaðo kasˈaða alɣˈuna bˈeθ?'"
83
+ ]
84
+ },
85
+ "execution_count": 4,
86
+ "metadata": {},
87
+ "output_type": "execute_result"
88
+ }
89
+ ],
90
+ "source": [
91
+ "es2ipa('—¿Habéis estado casada alguna vez?')"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": []
100
+ }
101
+ ],
102
+ "metadata": {
103
+ "kernelspec": {
104
+ "display_name": "base",
105
+ "language": "python",
106
+ "name": "python3"
107
+ },
108
+ "language_info": {
109
+ "codemirror_mode": {
110
+ "name": "ipython",
111
+ "version": 3
112
+ },
113
+ "file_extension": ".py",
114
+ "mimetype": "text/x-python",
115
+ "name": "python",
116
+ "nbconvert_exporter": "python",
117
+ "pygments_lexer": "ipython3",
118
+ "version": "3.8.18"
119
+ },
120
+ "orig_nbformat": 4
121
+ },
122
+ "nbformat": 4,
123
+ "nbformat_minor": 2
124
+ }
melo/text/fr_phonemizer/__init__.py ADDED
File without changes
melo/text/fr_phonemizer/base.py ADDED
@@ -0,0 +1,140 @@
1
+ import abc
2
+ from typing import List, Tuple
3
+
4
+ from .punctuation import Punctuation
5
+
6
+
7
+ class BasePhonemizer(abc.ABC):
8
+ """Base phonemizer class
9
+
10
+     Phonemization is done in the following steps:
11
+ 1. Preprocessing:
12
+ - remove empty lines
13
+ - remove punctuation
14
+ - keep track of punctuation marks
15
+
16
+ 2. Phonemization:
17
+ - convert text to phonemes
18
+
19
+ 3. Postprocessing:
20
+ - join phonemes
21
+ - restore punctuation marks
22
+
23
+ Args:
24
+ language (str):
25
+ Language used by the phonemizer.
26
+
27
+ punctuations (List[str]):
28
+ List of punctuation marks to be preserved.
29
+
30
+ keep_puncs (bool):
31
+ Whether to preserve punctuation marks or not.
32
+ """
33
+
34
+ def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
35
+ # ensure the backend is installed on the system
36
+ if not self.is_available():
37
+ raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover
38
+
39
+ # ensure the backend support the requested language
40
+ self._language = self._init_language(language)
41
+
42
+ # setup punctuation processing
43
+ self._keep_puncs = keep_puncs
44
+ self._punctuator = Punctuation(punctuations)
45
+
46
+ def _init_language(self, language):
47
+ """Language initialization
48
+
49
+ This method may be overloaded in child classes (see Segments backend)
50
+
51
+ """
52
+ if not self.is_supported_language(language):
53
+ raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
54
+ return language
55
+
56
+ @property
57
+ def language(self):
58
+ """The language code configured to be used for phonemization"""
59
+ return self._language
60
+
61
+ @staticmethod
62
+ @abc.abstractmethod
63
+ def name():
64
+ """The name of the backend"""
65
+ ...
66
+
67
+ @classmethod
68
+ @abc.abstractmethod
69
+ def is_available(cls):
70
+ """Returns True if the backend is installed, False otherwise"""
71
+ ...
72
+
73
+ @classmethod
74
+ @abc.abstractmethod
75
+ def version(cls):
76
+ """Return the backend version as a tuple (major, minor, patch)"""
77
+ ...
78
+
79
+ @staticmethod
80
+ @abc.abstractmethod
81
+ def supported_languages():
82
+ """Return a dict of language codes -> name supported by the backend"""
83
+ ...
84
+
85
+ def is_supported_language(self, language):
86
+ """Returns True if `language` is supported by the backend"""
87
+ return language in self.supported_languages()
88
+
89
+ @abc.abstractmethod
90
+ def _phonemize(self, text, separator):
91
+ """The main phonemization method"""
92
+
93
+ def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
94
+ """Preprocess the text before phonemization
95
+
96
+ 1. remove spaces
97
+ 2. remove punctuation
98
+
99
+ Override this if you need a different behaviour
100
+ """
101
+ text = text.strip()
102
+ if self._keep_puncs:
103
+ # a tuple (text, punctuation marks)
104
+ return self._punctuator.strip_to_restore(text)
105
+ return [self._punctuator.strip(text)], []
106
+
107
+ def _phonemize_postprocess(self, phonemized, punctuations) -> str:
108
+ """Postprocess the raw phonemized output
109
+
110
+ Override this if you need a different behaviour
111
+ """
112
+ if self._keep_puncs:
113
+ return self._punctuator.restore(phonemized, punctuations)[0]
114
+ return phonemized[0]
115
+
116
+ def phonemize(self, text: str, separator="|", language: str = None) -> str: # pylint: disable=unused-argument
117
+ """Returns the `text` phonemized for the given language
118
+
119
+ Args:
120
+ text (str):
121
+ Text to be phonemized.
122
+
123
+ separator (str):
124
+                 string separator used between phonemes. Defaults to '|'.
125
+
126
+ Returns:
127
+ (str): Phonemized text
128
+ """
129
+ text, punctuations = self._phonemize_preprocess(text)
130
+ phonemized = []
131
+ for t in text:
132
+ p = self._phonemize(t, separator)
133
+ phonemized.append(p)
134
+ phonemized = self._phonemize_postprocess(phonemized, punctuations)
135
+ return phonemized
136
+
137
+ def print_logs(self, level: int = 0):
138
+ indent = "\t" * level
139
+ print(f"{indent}| > phoneme language: {self.language}")
140
+ print(f"{indent}| > phoneme backend: {self.name()}")
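BasePhonemizer only defines the preprocess, _phonemize, postprocess pipeline; concrete backends fill in the abstract methods. A toy sketch of that contract (DummyPhonemizer and its behaviour are illustrative only, not part of the repository):

from melo.text.fr_phonemizer.base import BasePhonemizer

class DummyPhonemizer(BasePhonemizer):
    """Minimal stand-in backend; a real backend would return phonemes."""

    @staticmethod
    def name():
        return "dummy"

    @classmethod
    def is_available(cls):
        return True

    @classmethod
    def version(cls):
        return (0, 0, 1)

    @staticmethod
    def supported_languages():
        return {"fr-fr": "French"}

    def _phonemize(self, text, separator):
        # echo the input words; real backends convert them to phonemes here
        return separator.join(text.split())

phonemizer = DummyPhonemizer("fr-fr", keep_puncs=True)
print(phonemizer.phonemize("Bonjour le monde !"))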
melo/text/fr_phonemizer/cleaner.py ADDED
@@ -0,0 +1,122 @@
1
+ """Set of default text cleaners"""
2
+ # TODO: pick the cleaner for languages dynamically
3
+
4
+ import re
5
+ from .french_abbreviations import abbreviations_fr
6
+
7
+ # Regular expression matching whitespace:
8
+ _whitespace_re = re.compile(r"\s+")
9
+
10
+
11
+ rep_map = {
12
+ ":": ",",
13
+ ";": ",",
14
+ ",": ",",
15
+ "。": ".",
16
+ "!": "!",
17
+ "?": "?",
18
+ "\n": ".",
19
+ "·": ",",
20
+ "、": ",",
21
+ "...": ".",
22
+ "…": ".",
23
+ "$": ".",
24
+ "“": "",
25
+ "”": "",
26
+ "‘": "",
27
+ "’": "",
28
+ "(": "",
29
+ ")": "",
30
+ "(": "",
31
+ ")": "",
32
+ "《": "",
33
+ "》": "",
34
+ "【": "",
35
+ "】": "",
36
+ "[": "",
37
+ "]": "",
38
+ "—": "",
39
+ "~": "-",
40
+ "~": "-",
41
+ "「": "",
42
+ "」": "",
43
+ "¿" : "",
44
+ "¡" : ""
45
+ }
46
+
47
+
48
+ def replace_punctuation(text):
49
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
50
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
51
+ return replaced_text
52
+
53
+ def expand_abbreviations(text, lang="fr"):
54
+ if lang == "fr":
55
+ _abbreviations = abbreviations_fr
56
+ for regex, replacement in _abbreviations:
57
+ text = re.sub(regex, replacement, text)
58
+ return text
59
+
60
+
61
+ def lowercase(text):
62
+ return text.lower()
63
+
64
+
65
+ def collapse_whitespace(text):
66
+ return re.sub(_whitespace_re, " ", text).strip()
67
+
68
+ def remove_punctuation_at_begin(text):
69
+ return re.sub(r'^[,.!?]+', '', text)
70
+
71
+ def remove_aux_symbols(text):
72
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
73
+ return text
74
+
75
+
76
+ def replace_symbols(text, lang="en"):
77
+     """Replace symbols based on the language tag.
78
+
79
+ Args:
80
+ text:
81
+ Input text.
82
+ lang:
83
+             Language identifier. ex: "en", "fr", "pt", "ca".
84
+
85
+ Returns:
86
+ The modified text
87
+ example:
88
+ input args:
89
+ text: "si l'avi cau, diguem-ho"
90
+ lang: "ca"
91
+ Output:
92
+ text: "si lavi cau, diguemho"
93
+ """
94
+ text = text.replace(";", ",")
95
+ text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
96
+ text = text.replace(":", ",")
97
+ if lang == "en":
98
+ text = text.replace("&", " and ")
99
+ elif lang == "fr":
100
+ text = text.replace("&", " et ")
101
+ elif lang == "pt":
102
+ text = text.replace("&", " e ")
103
+ elif lang == "ca":
104
+ text = text.replace("&", " i ")
105
+ text = text.replace("'", "")
106
+ elif lang== "es":
107
+ text=text.replace("&","y")
108
+ text = text.replace("'", "")
109
+ return text
110
+
111
+ def french_cleaners(text):
112
+ """Pipeline for French text. There is no need to expand numbers, phonemizer already does that"""
113
+ text = expand_abbreviations(text, lang="fr")
114
+ # text = lowercase(text) # as we use the cased bert
115
+ text = replace_punctuation(text)
116
+ text = replace_symbols(text, lang="fr")
117
+ text = remove_aux_symbols(text)
118
+ text = remove_punctuation_at_begin(text)
119
+ text = collapse_whitespace(text)
120
+ text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
121
+ return text
122
+
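A short, hedged example of what french_cleaners does to raw text (the exact result depends on the abbreviation table imported above):

from melo.text.fr_phonemizer.cleaner import french_cleaners

print(french_cleaners("M. Dupont arrive «bientôt»"))
# "M." is expanded to "monsieur", the guillemets are stripped, whitespace is
# collapsed, and a final period is appended when no terminal punctuation is present.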
melo/text/fr_phonemizer/en_symbols.json ADDED
@@ -0,0 +1,78 @@
1
+ {"symbols": [
2
+ "_",
3
+ ",",
4
+ ".",
5
+ "!",
6
+ "?",
7
+ "-",
8
+ "~",
9
+ "\u2026",
10
+ "N",
11
+ "Q",
12
+ "a",
13
+ "b",
14
+ "d",
15
+ "e",
16
+ "f",
17
+ "g",
18
+ "h",
19
+ "i",
20
+ "j",
21
+ "k",
22
+ "l",
23
+ "m",
24
+ "n",
25
+ "o",
26
+ "p",
27
+ "s",
28
+ "t",
29
+ "u",
30
+ "v",
31
+ "w",
32
+ "x",
33
+ "y",
34
+ "z",
35
+ "\u0251",
36
+ "\u00e6",
37
+ "\u0283",
38
+ "\u0291",
39
+ "\u00e7",
40
+ "\u026f",
41
+ "\u026a",
42
+ "\u0254",
43
+ "\u025b",
44
+ "\u0279",
45
+ "\u00f0",
46
+ "\u0259",
47
+ "\u026b",
48
+ "\u0265",
49
+ "\u0278",
50
+ "\u028a",
51
+ "\u027e",
52
+ "\u0292",
53
+ "\u03b8",
54
+ "\u03b2",
55
+ "\u014b",
56
+ "\u0266",
57
+ "\u207c",
58
+ "\u02b0",
59
+ "`",
60
+ "^",
61
+ "#",
62
+ "*",
63
+ "=",
64
+ "\u02c8",
65
+ "\u02cc",
66
+ "\u2192",
67
+ "\u2193",
68
+ "\u2191",
69
+ " ",
70
+ "ɣ",
71
+ "ɡ",
72
+ "r",
73
+ "ɲ",
74
+ "ʝ",
75
+ "ʎ",
76
+ "ː"
77
+ ]
78
+ }
melo/text/fr_phonemizer/fr_symbols.json ADDED
@@ -0,0 +1,89 @@
1
+ {
2
+ "symbols": [
3
+ "_",
4
+ ",",
5
+ ".",
6
+ "!",
7
+ "?",
8
+ "-",
9
+ "~",
10
+ "\u2026",
11
+ "N",
12
+ "Q",
13
+ "a",
14
+ "b",
15
+ "d",
16
+ "e",
17
+ "f",
18
+ "g",
19
+ "h",
20
+ "i",
21
+ "j",
22
+ "k",
23
+ "l",
24
+ "m",
25
+ "n",
26
+ "o",
27
+ "p",
28
+ "s",
29
+ "t",
30
+ "u",
31
+ "v",
32
+ "w",
33
+ "x",
34
+ "y",
35
+ "z",
36
+ "\u0251",
37
+ "\u00e6",
38
+ "\u0283",
39
+ "\u0291",
40
+ "\u00e7",
41
+ "\u026f",
42
+ "\u026a",
43
+ "\u0254",
44
+ "\u025b",
45
+ "\u0279",
46
+ "\u00f0",
47
+ "\u0259",
48
+ "\u026b",
49
+ "\u0265",
50
+ "\u0278",
51
+ "\u028a",
52
+ "\u027e",
53
+ "\u0292",
54
+ "\u03b8",
55
+ "\u03b2",
56
+ "\u014b",
57
+ "\u0266",
58
+ "\u207c",
59
+ "\u02b0",
60
+ "`",
61
+ "^",
62
+ "#",
63
+ "*",
64
+ "=",
65
+ "\u02c8",
66
+ "\u02cc",
67
+ "\u2192",
68
+ "\u2193",
69
+ "\u2191",
70
+ " ",
71
+ "\u0263",
72
+ "\u0261",
73
+ "r",
74
+ "\u0272",
75
+ "\u029d",
76
+ "\u028e",
77
+ "\u02d0",
78
+
79
+ "\u0303",
80
+ "\u0153",
81
+ "\u00f8",
82
+ "\u0281",
83
+ "\u0252",
84
+ "\u028c",
85
+ "\u2014",
86
+ "\u025c",
87
+ "\u0250"
88
+ ]
89
+ }
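A hedged sketch of how a symbol inventory like this is typically consumed by the text front end; the actual consumer lives elsewhere in the repository, and the path below assumes the file ships at this location:

import json

with open("melo/text/fr_phonemizer/fr_symbols.json", encoding="utf-8") as f:
    symbols = json.load(f)["symbols"]

# map each phoneme/punctuation symbol to an integer id for lookup by the text front end
symbol_to_id = {s: i for i, s in enumerate(symbols)}
print(len(symbols), symbol_to_id[","])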
melo/text/fr_phonemizer/fr_to_ipa.py ADDED
@@ -0,0 +1,30 @@
1
+ from .cleaner import french_cleaners
2
+ from .gruut_wrapper import Gruut
3
+
4
+
5
+ def remove_consecutive_t(input_str):
6
+ result = []
7
+ count = 0
8
+
9
+ for char in input_str:
10
+ if char == 't':
11
+ count += 1
12
+ else:
13
+ if count < 3:
14
+ result.extend(['t'] * count)
15
+ count = 0
16
+ result.append(char)
17
+
18
+ if count < 3:
19
+ result.extend(['t'] * count)
20
+
21
+ return ''.join(result)
22
+
23
+ def fr2ipa(text):
24
+ e = Gruut(language="fr-fr", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
25
+ # text = french_cleaners(text)
26
+ phonemes = e.phonemize(text, separator="")
27
+ # print(phonemes)
28
+ phonemes = remove_consecutive_t(phonemes)
29
+ # print(phonemes)
30
+ return phonemes
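fr2ipa chains gruut phonemization (with espeak lexicons, stress marks and punctuation kept) with remove_consecutive_t, which drops any run of three or more consecutive 't' characters while keeping shorter runs, apparently to scrub a phonemizer artifact. An illustrative call, assuming gruut with French data and the espeak phonemes are installed:

from melo.text.fr_phonemizer.fr_to_ipa import fr2ipa

print(fr2ipa("Bonjour, comment allez-vous ?"))
# returns an IPA string with stress marks and punctuation preserved, since the
# wrapper is built with keep_puncs=True and keep_stress=True.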
melo/text/fr_phonemizer/french_abbreviations.py ADDED
@@ -0,0 +1,48 @@
1
+ import re
2
+
3
+ # List of (regular expression, replacement) pairs for abbreviations in french:
4
+ abbreviations_fr = [
5
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
6
+ for x in [
7
+ ("M", "monsieur"),
8
+ ("Mlle", "mademoiselle"),
9
+ ("Mlles", "mesdemoiselles"),
10
+ ("Mme", "Madame"),
11
+ ("Mmes", "Mesdames"),
12
+ ("N.B", "nota bene"),
13
+ ("M", "monsieur"),
14
+ ("p.c.q", "parce que"),
15
+ ("Pr", "professeur"),
16
+ ("qqch", "quelque chose"),
17
+ ("rdv", "rendez-vous"),
18
+ ("max", "maximum"),
19
+ ("min", "minimum"),
20
+ ("no", "numéro"),
21
+ ("adr", "adresse"),
22
+ ("dr", "docteur"),
23
+ ("st", "saint"),
24
+         ("co", "compagnie"),
25
+ ("jr", "junior"),
26
+ ("sgt", "sergent"),
27
+         ("capt", "capitaine"),
28
+ ("col", "colonel"),
29
+ ("av", "avenue"),
30
+ ("av. J.-C", "avant Jésus-Christ"),
31
+ ("apr. J.-C", "après Jésus-Christ"),
32
+ ("art", "article"),
33
+ ("boul", "boulevard"),
34
+ ("c.-à-d", "c’est-à-dire"),
35
+ ("etc", "et cetera"),
36
+ ("ex", "exemple"),
37
+ ("excl", "exclusivement"),
38
+ ("boul", "boulevard"),
39
+ ]
40
+ ] + [
41
+ (re.compile("\\b%s" % x[0]), x[1])
42
+ for x in [
43
+ ("Mlle", "mademoiselle"),
44
+ ("Mlles", "mesdemoiselles"),
45
+ ("Mme", "Madame"),
46
+ ("Mmes", "Mesdames"),
47
+ ]
48
+ ]
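The first list only fires on abbreviations followed by a period (case-insensitively); the second fires on the bare titles. A tiny illustrative pass over the table:

import re
from melo.text.fr_phonemizer.french_abbreviations import abbreviations_fr

text = "Mme Dupont voit le Dr. Martin."
for regex, replacement in abbreviations_fr:
    text = re.sub(regex, replacement, text)
print(text)  # -> "Madame Dupont voit le docteur Martin."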
melo/text/fr_phonemizer/french_symbols.txt ADDED
@@ -0,0 +1 @@
1
+ _,.!?-~…NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ɣɡrɲʝʎː̃œøʁɒʌ—ɜɐ
melo/text/fr_phonemizer/gruut_wrapper.py ADDED
@@ -0,0 +1,258 @@
1
+ import importlib
2
+ from typing import List
3
+
4
+ import gruut
5
+ from gruut_ipa import IPA # pip install gruut_ipa
6
+
7
+ from .base import BasePhonemizer
8
+ from .punctuation import Punctuation
9
+
10
+ # Table for str.translate to fix gruut/TTS phoneme mismatch
11
+ GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
12
+
13
+
14
+ class Gruut(BasePhonemizer):
15
+ """Gruut wrapper for G2P
16
+
17
+ Args:
18
+ language (str):
19
+ Valid language code for the used backend.
20
+
21
+ punctuations (str):
22
+ Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`.
23
+
24
+ keep_puncs (bool):
25
+ If true, keep the punctuations after phonemization. Defaults to True.
26
+
27
+ use_espeak_phonemes (bool):
28
+ If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False.
29
+
30
+ keep_stress (bool):
31
+ If true, keep the stress characters after phonemization. Defaults to False.
32
+
33
+ Example:
34
+
35
+ >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
36
+ >>> phonemizer = Gruut('en-us')
37
+ >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|")
38
+ 'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?'
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ language: str,
44
+ punctuations=Punctuation.default_puncs(),
45
+ keep_puncs=True,
46
+ use_espeak_phonemes=False,
47
+ keep_stress=False,
48
+ ):
49
+ super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
50
+ self.use_espeak_phonemes = use_espeak_phonemes
51
+ self.keep_stress = keep_stress
52
+
53
+ @staticmethod
54
+ def name():
55
+ return "gruut"
56
+
57
+ def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument
58
+ """Convert input text to phonemes.
59
+
60
+         Gruut phonemizes the given `str` by separating each phoneme character with `separator`, even for characters
61
+         that constitute a single sound.
62
+
63
+ It doesn't affect 🐸TTS since it individually converts each character to token IDs.
64
+
65
+ Examples::
66
+ "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ`
67
+
68
+ Args:
69
+ text (str):
70
+ Text to be converted to phonemes.
71
+
72
+ tie (bool, optional) : When True use a '͡' character between
73
+ consecutive characters of a single phoneme. Else separate phoneme
74
+ with '_'. This option requires espeak>=1.49. Default to False.
75
+ """
76
+ ph_list = []
77
+ for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes):
78
+ for word in sentence:
79
+ if word.is_break:
80
+ # Use actual character for break phoneme (e.g., comma)
81
+ if ph_list:
82
+ # Join with previous word
83
+ ph_list[-1].append(word.text)
84
+ else:
85
+ # First word is punctuation
86
+ ph_list.append([word.text])
87
+ elif word.phonemes:
88
+ # Add phonemes for word
89
+ word_phonemes = []
90
+
91
+ for word_phoneme in word.phonemes:
92
+ if not self.keep_stress:
93
+ # Remove primary/secondary stress
94
+ word_phoneme = IPA.without_stress(word_phoneme)
95
+
96
+ word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
97
+
98
+ if word_phoneme:
99
+ # Flatten phonemes
100
+ word_phonemes.extend(word_phoneme)
101
+
102
+ if word_phonemes:
103
+ ph_list.append(word_phonemes)
104
+
105
+ ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list]
106
+ ph = f"{separator} ".join(ph_words)
107
+ return ph
108
+
109
+ def _phonemize(self, text, separator):
110
+ return self.phonemize_gruut(text, separator, tie=False)
111
+
112
+ def is_supported_language(self, language):
113
+ """Returns True if `language` is supported by the backend"""
114
+ return gruut.is_language_supported(language)
115
+
116
+ @staticmethod
117
+ def supported_languages() -> List:
118
+ """Get a dictionary of supported languages.
119
+
120
+ Returns:
121
+ List: List of language codes.
122
+ """
123
+ return list(gruut.get_supported_languages())
124
+
125
+ def version(self):
126
+ """Get the version of the used backend.
127
+
128
+ Returns:
129
+ str: Version of the used backend.
130
+ """
131
+ return gruut.__version__
132
+
133
+ @classmethod
134
+ def is_available(cls):
135
+ """Return true if ESpeak is available else false"""
136
+ return importlib.util.find_spec("gruut") is not None
137
+
138
+
139
+ if __name__ == "__main__":
140
+ from cleaner import french_cleaners
141
+ import json
142
+
143
+ e = Gruut(language="fr-fr", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
144
+ symbols = [ # en + sp
145
+ "_",
146
+ ",",
147
+ ".",
148
+ "!",
149
+ "?",
150
+ "-",
151
+ "~",
152
+ "\u2026",
153
+ "N",
154
+ "Q",
155
+ "a",
156
+ "b",
157
+ "d",
158
+ "e",
159
+ "f",
160
+ "g",
161
+ "h",
162
+ "i",
163
+ "j",
164
+ "k",
165
+ "l",
166
+ "m",
167
+ "n",
168
+ "o",
169
+ "p",
170
+ "s",
171
+ "t",
172
+ "u",
173
+ "v",
174
+ "w",
175
+ "x",
176
+ "y",
177
+ "z",
178
+ "\u0251",
179
+ "\u00e6",
180
+ "\u0283",
181
+ "\u0291",
182
+ "\u00e7",
183
+ "\u026f",
184
+ "\u026a",
185
+ "\u0254",
186
+ "\u025b",
187
+ "\u0279",
188
+ "\u00f0",
189
+ "\u0259",
190
+ "\u026b",
191
+ "\u0265",
192
+ "\u0278",
193
+ "\u028a",
194
+ "\u027e",
195
+ "\u0292",
196
+ "\u03b8",
197
+ "\u03b2",
198
+ "\u014b",
199
+ "\u0266",
200
+ "\u207c",
201
+ "\u02b0",
202
+ "`",
203
+ "^",
204
+ "#",
205
+ "*",
206
+ "=",
207
+ "\u02c8",
208
+ "\u02cc",
209
+ "\u2192",
210
+ "\u2193",
211
+ "\u2191",
212
+ " ",
213
+ "ɣ",
214
+ "ɡ",
215
+ "r",
216
+ "ɲ",
217
+ "ʝ",
218
+ "ʎ",
219
+ "ː"
220
+ ]
221
+ with open('/home/xumin/workspace/VITS-Training-Multiling/230715_fr/metadata.txt', 'r') as f:
222
+ lines = f.readlines()
223
+
224
+
225
+ used_sym = []
226
+ not_existed_sym = []
227
+ phonemes = []
228
+
229
+ for line in lines:
230
+ text = line.split('|')[-1].strip()
231
+ text = french_cleaners(text)
232
+ ipa = e.phonemize(text, separator="")
233
+ phonemes.append(ipa)
234
+ for s in ipa:
235
+ if s not in symbols:
236
+ if s not in not_existed_sym:
237
+ print(f'not_existed char: {s}')
238
+ not_existed_sym.append(s)
239
+ else:
240
+ if s not in used_sym:
241
+ # print(f'used char: {s}')
242
+ used_sym.append(s)
243
+
244
+ print(used_sym)
245
+ print(not_existed_sym)
246
+
247
+
248
+ with open('./text/fr_phonemizer/french_symbols.txt', 'w') as g:
249
+ g.writelines(symbols + not_existed_sym)
250
+
251
+ with open('./text/fr_phonemizer/example_ipa.txt', 'w') as g:
252
+ g.writelines(phonemes)
253
+
254
+ data = {'symbols': symbols + not_existed_sym}
255
+
256
+ with open('./text/fr_phonemizer/fr_symbols.json', 'w') as f:
257
+ json.dump(data, f, indent=4)
258
+
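The __main__ block above regenerates the French symbol inventory from a training metadata file. For everyday use, the wrapper is instantiated the same way fr_to_ipa.py does; gruut, gruut_ipa and the espeak lexicons must be installed:

from melo.text.fr_phonemizer.gruut_wrapper import Gruut

phonemizer = Gruut(language="fr-fr", keep_puncs=True, keep_stress=True,
                   use_espeak_phonemes=True)
print(phonemizer.phonemize("Bonjour tout le monde !", separator=""))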
melo/text/fr_phonemizer/punctuation.py ADDED
@@ -0,0 +1,172 @@
1
+ import collections
2
+ import re
3
+ from enum import Enum
4
+
5
+ import six
6
+
7
+ _DEF_PUNCS = ';:,.!?¡¿—…"«»“”'
8
+
9
+ _PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"])
10
+
11
+
12
+ class PuncPosition(Enum):
13
+ """Enum for the punctuations positions"""
14
+
15
+ BEGIN = 0
16
+ END = 1
17
+ MIDDLE = 2
18
+ ALONE = 3
19
+
20
+
21
+ class Punctuation:
22
+ """Handle punctuations in text.
23
+
24
+ Just strip punctuations from text or strip and restore them later.
25
+
26
+ Args:
27
+ puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.
28
+
29
+ Example:
30
+ >>> punc = Punctuation()
31
+ >>> punc.strip("This is. example !")
32
+ 'This is example'
33
+
34
+ >>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
35
+ >>> ' '.join(text_striped)
36
+ 'This is example'
37
+
38
+ >>> text_restored = punc.restore(text_striped, punc_map)
39
+ >>> text_restored[0]
40
+ 'This is. example !'
41
+ """
42
+
43
+ def __init__(self, puncs: str = _DEF_PUNCS):
44
+ self.puncs = puncs
45
+
46
+ @staticmethod
47
+ def default_puncs():
48
+ """Return default set of punctuations."""
49
+ return _DEF_PUNCS
50
+
51
+ @property
52
+ def puncs(self):
53
+ return self._puncs
54
+
55
+ @puncs.setter
56
+ def puncs(self, value):
57
+ if not isinstance(value, six.string_types):
58
+ raise ValueError("[!] Punctuations must be of type str.")
59
+         self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the order
60
+ self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
61
+
62
+ def strip(self, text):
63
+ """Remove all the punctuations by replacing with `space`.
64
+
65
+ Args:
66
+ text (str): The text to be processed.
67
+
68
+ Example::
69
+
70
+ "This is. example !" -> "This is example "
71
+ """
72
+ return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip()
73
+
74
+ def strip_to_restore(self, text):
75
+ """Remove punctuations from text to restore them later.
76
+
77
+ Args:
78
+ text (str): The text to be processed.
79
+
80
+ Examples ::
81
+
82
+ "This is. example !" -> [["This is", "example"], [".", "!"]]
83
+
84
+ """
85
+ text, puncs = self._strip_to_restore(text)
86
+ return text, puncs
87
+
88
+ def _strip_to_restore(self, text):
89
+ """Auxiliary method for Punctuation.preserve()"""
90
+ matches = list(re.finditer(self.puncs_regular_exp, text))
91
+ if not matches:
92
+ return [text], []
93
+ # the text is only punctuations
94
+ if len(matches) == 1 and matches[0].group() == text:
95
+ return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
96
+ # build a punctuation map to be used later to restore punctuations
97
+ puncs = []
98
+ for match in matches:
99
+ position = PuncPosition.MIDDLE
100
+ if match == matches[0] and text.startswith(match.group()):
101
+ position = PuncPosition.BEGIN
102
+ elif match == matches[-1] and text.endswith(match.group()):
103
+ position = PuncPosition.END
104
+ puncs.append(_PUNC_IDX(match.group(), position))
105
+ # convert str text to a List[str], each item is separated by a punctuation
106
+ splitted_text = []
107
+ for idx, punc in enumerate(puncs):
108
+ split = text.split(punc.punc)
109
+ prefix, suffix = split[0], punc.punc.join(split[1:])
110
+ splitted_text.append(prefix)
111
+ # if the text does not end with a punctuation, add it to the last item
112
+ if idx == len(puncs) - 1 and len(suffix) > 0:
113
+ splitted_text.append(suffix)
114
+ text = suffix
115
+ return splitted_text, puncs
116
+
117
+ @classmethod
118
+ def restore(cls, text, puncs):
119
+ """Restore punctuation in a text.
120
+
121
+ Args:
122
+ text (str): The text to be processed.
123
+ puncs (List[str]): The list of punctuations map to be used for restoring.
124
+
125
+ Examples ::
126
+
127
+ ['This is', 'example'], ['.', '!'] -> "This is. example!"
128
+
129
+ """
130
+ return cls._restore(text, puncs, 0)
131
+
132
+ @classmethod
133
+ def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements
134
+ """Auxiliary method for Punctuation.restore()"""
135
+ if not puncs:
136
+ return text
137
+
138
+         # nothing has been phonemized, return the puncs alone
139
+ if not text:
140
+ return ["".join(m.punc for m in puncs)]
141
+
142
+ current = puncs[0]
143
+
144
+ if current.position == PuncPosition.BEGIN:
145
+ return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
146
+
147
+ if current.position == PuncPosition.END:
148
+ return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)
149
+
150
+ if current.position == PuncPosition.ALONE:
151
+             return [current.punc] + cls._restore(text, puncs[1:], num + 1)
152
+
153
+ # POSITION == MIDDLE
154
+ if len(text) == 1: # pragma: nocover
155
+ # a corner case where the final part of an intermediate
156
+ # mark (I) has not been phonemized
157
+ return cls._restore([text[0] + current.punc], puncs[1:], num)
158
+
159
+ return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
160
+
161
+
162
+ # if __name__ == "__main__":
163
+ # punc = Punctuation()
164
+ # text = "This is. This is, example!"
165
+
166
+ # print(punc.strip(text))
167
+
168
+ # split_text, puncs = punc.strip_to_restore(text)
169
+ # print(split_text, " ---- ", puncs)
170
+
171
+ # restored_text = punc.restore(split_text, puncs)
172
+ # print(restored_text)
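For reference, a minimal round trip through the strip/restore protocol that BasePhonemizer relies on (output shapes are illustrative):

from melo.text.fr_phonemizer.punctuation import Punctuation

punc = Punctuation()
chunks, marks = punc.strip_to_restore("Bonjour, le monde !")
# chunks: text segments without punctuation; marks: each mark plus its position
print(chunks, marks)
restored = Punctuation.restore(chunks, marks)
print(restored)  # ['Bonjour, le monde !']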