# SPDX-FileCopyrightText: 2024 Idiap Research Institute
# SPDX-FileContributor: Karl El Hajal
#
# SPDX-License-Identifier: MIT
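
"""Gradio demo app for kNN-TTS: zero-shot multi-speaker text-to-speech via kNN retrieval."""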

import os
import zipfile

import gradio as gr
import spaces
from huggingface_hub import snapshot_download

from knn_tts.synthesizer import Synthesizer
from knn_tts.utils import get_vocoder_checkpoint_path

# Extract the precomputed target speaker features from target_feats.zip on first run.
if not os.path.exists("target_feats"):
    if os.path.exists("target_feats.zip"):
        with zipfile.ZipFile("target_feats.zip", "r") as zip_ref:
            zip_ref.extractall(".")
    else:
        raise FileNotFoundError("target_feats.zip not found.")

SAMPLE_RATE = 16000  # sample rate (Hz) of the synthesized waveform

CHECKPOINTS_DIR = "./checkpoints"

# Download the pretrained kNN-TTS checkpoints and locate the vocoder checkpoint.
tts_checkpoints_dir = snapshot_download(repo_id="idiap/kNN-TTS", local_dir=CHECKPOINTS_DIR)
vocoder_checkpoint_path = get_vocoder_checkpoint_path(CHECKPOINTS_DIR)

tts_checkpoint_name = "best_model_646135.pth"
synthesizer = Synthesizer(tts_checkpoints_dir, tts_checkpoint_name, vocoder_checkpoint_path, model_name="glowtts")

# Precomputed WavLM feature directories for each selectable target voice.
target_speakers = {
    "Libri 7127": {
        "feats_path": "target_feats/LibriSpeech-test-clean/7127/wavlm",
    },
    "Libri 7729": {
        "feats_path": "target_feats/LibriSpeech-test-clean/7729/wavlm",
    },
    "Libri 6829": {
        "feats_path": "target_feats/LibriSpeech-test-clean/6829/wavlm",
    },
    "Libri 8555": {
        "feats_path": "target_feats/LibriSpeech-test-clean/8555/wavlm",
    },
    "Thorsten Neutral": {
        "feats_path": "target_feats/Thorsten/neutral/wavlm/",
    },
    "Thorsten Whisper": {
        "feats_path": "target_feats/Thorsten/whisper/wavlm/",
    },
    "ESD 0018 Neutral": {
        "feats_path": "target_feats/ESD/0018/neutral/wavlm/",
    },
    "ESD 0018 Surprised": {
        "feats_path": "target_feats/ESD/0018/surprised/wavlm/",
    },
}

@spaces.GPU
def run(text_input, target_speaker, lambda_rate, topk, weighted_average):
    """Synthesize text in the selected target voice; returns (sample_rate, waveform) for gr.Audio."""
    feats_path = target_speakers[target_speaker]["feats_path"]
    wav = synthesizer(text_input, feats_path, interpolation_rate=lambda_rate, knnvc_topk=topk, weighted_average=weighted_average, max_target_num_files=500)
    return SAMPLE_RATE, wav.squeeze().cpu().numpy()


def get_title(text, size=1):
    return f"""
    <center>

    <h{size}> {text} </h{size}>

    </center>
    """

def create_gradio_interface():
    with gr.Blocks(
        theme=gr.themes.Default(
            text_size="lg",
        ),
        title="kNN-TTS"
    ) as iface:
        
        gr.HTML(get_title("kNN-TTS: kNN Retrieval for Simple and Effective Zero-Shot Multi-speaker Text-to-Speech", size=1))

        with gr.Tabs():
            with gr.TabItem("Generate Speech"):
                with gr.Row():
                    # Left column - inputs
                    with gr.Column():
                        gr.Markdown("## Input")
                        text_box = gr.Textbox(
                            lines=3, 
                            placeholder="Enter the text to convert to speech...",
                            label="Text",
                            elem_id="text-input"
                        )
                        
                        target_speaker_dropdown = gr.Dropdown(
                            choices=list(target_speakers.keys()),
                            value="Libri 7127",
                            label="Target Voice",
                            elem_id="target-voice"
                        )
                        
                        rate_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.01,
                            label="Voice Morphing (λ)",
                            info="Higher values give more weight to target voice characteristics"
                        )
                            
                        with gr.Accordion("Advanced Settings", open=False):
                            k_slider = gr.Slider(
                                minimum=1,
                                maximum=50,
                                value=4,
                                step=1,
                                label="Top-k Retrieval",
                                info="k closest neighbors to retrieve"
                            )
                            weighted_toggle = gr.Checkbox(
                                label="Use Weighted Averaging",
                                value=False,
                                info="Weight neighbors by similarity distance"
                            )
                        
                        submit_button = gr.Button("Generate Audio", variant="primary", size="lg")
                    
                    # Right column - outputs
                    with gr.Column():
                        gr.Markdown("## Generated Audio")
                        with gr.Group():
                            audio_output = gr.Audio(
                                type="numpy",
                                label="Output Speech",
                                elem_id="audio-output"
                            )
                            with gr.Row():
                                clear_btn = gr.ClearButton([text_box, target_speaker_dropdown, rate_slider, audio_output], variant="secondary", size="lg")

                # Example section
                with gr.Row():
                    gr.Examples(
                        examples=[
                            ["I think foosball is a combination of football and shish kebabs.", "Thorsten Whisper", 1.0, 8, True],
                            ["I think foosball is a combination of football and shish kebabs.", "Thorsten Neutral", 1.0, 4, False],
                            ["If you're traveling in the north country fair.", "Libri 7127", 1.0, 4, False],
                            ["Like a vision she dances across the porch as the radio plays.", "Libri 7729", 1.0, 8, True],
                            ["There weren't another other way to be.", "Libri 6829", 1.0, 4, False],
                        ],
                        inputs=[text_box, target_speaker_dropdown, rate_slider, k_slider, weighted_toggle],
                        outputs=audio_output,
                        fn=run,
                        cache_examples=True
                    )
            
            # Additional tabs
            with gr.TabItem("Model Details"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("""
                        ## kNN-TTS Technical Details
                        
                        kNN-TTS uses self-supervised learning (SSL) features and kNN retrieval to achieve robust zero-shot multi-speaker TTS.
                        
                        ### Key Components
                        
                        1. **Feature Extraction**: We extract discrete representations from target speaker speech using a pre-trained SSL encoder (the 6th layer of WavLM Large).
                        2. **Text-to-SSL**: We train a lightweight TTS model to predict the same representations from text. For simplicity, we train it on a single-speaker dataset.
                        3. **Retrieval Mechanism**: For each unit in the generated features, we use kNN to retrieve its closest matches from the target voice's unit database (see the sketch after this list).
                        4. **Voice Morphing**: By linearly interpolating between the source features and the retrieved target features, we can morph the two voices. The interpolation parameter λ controls the balance between source and target characteristics.
                        5. **Vocoder**: We use a pre-trained vocoder to convert the resulting features into a waveform.
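
                        Steps 3 and 4 can be sketched in a few lines of PyTorch. This is an illustrative sketch only, not the exact implementation in `knn_tts`; the function name, the Euclidean distance, and the tensor shapes are assumptions:

                        ```python
                        import torch

                        def knn_morph(source_feats, target_feats, k=4, lam=1.0, weighted=False):
                            # source_feats: (T, D) SSL features predicted from text
                            # target_feats: (N, D) SSL features from the target voice unit database
                            dists = torch.cdist(source_feats, target_feats)  # (T, N) pairwise distances
                            topk = dists.topk(k, dim=1, largest=False)       # k nearest target units per frame
                            neighbors = target_feats[topk.indices]           # (T, k, D)
                            if weighted:
                                # weight each neighbor by inverse distance, so closer units count more
                                w = 1.0 / (topk.values + 1e-8)               # (T, k)
                                matched = (neighbors * w.unsqueeze(-1)).sum(1) / w.sum(1, keepdim=True)
                            else:
                                matched = neighbors.mean(dim=1)              # plain average of the k neighbors
                            # linear interpolation between source and matched target features (lam = λ)
                            return lam * matched + (1.0 - lam) * source_feats
                        ```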
                       
                        ### Performance
                        
                        Our simple and efficient model achieves results comparable to state-of-the-art models while being trained on 100× to 1000× less transcribed data.
                        This framework is therefore particularly well-suited for low-resource domains.

                        For more details, please refer to [our paper](https://arxiv.org/abs/2408.10771).
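
                        For reference, a minimal sketch of driving this demo's synthesizer outside the UI, using the same calls as this app (`checkpoints_dir` and `vocoder_path` stand for the paths obtained via `snapshot_download` and `get_vocoder_checkpoint_path`):

                        ```python
                        from knn_tts.synthesizer import Synthesizer

                        synth = Synthesizer(checkpoints_dir, "best_model_646135.pth", vocoder_path, model_name="glowtts")
                        wav = synth(
                            "If you're traveling in the north country fair.",
                            "target_feats/LibriSpeech-test-clean/7127/wavlm",  # target voice unit database
                            interpolation_rate=1.0,  # λ: weight of the target voice
                            knnvc_topk=4,
                            weighted_average=False,
                            max_target_num_files=500,
                        )  # 16 kHz waveform tensor
                        ```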
                        """)
                    with gr.Column():
                        # gr.Image expects an integer `scale`, so sizing is left to the column layout.
                        gr.Image("assets/diagram.png", label="Model Architecture", show_label=False, show_download_button=False, show_fullscreen_button=False)
            
            with gr.TabItem("About"):
                gr.Markdown("""
                ## About the Project
                
                This demo showcases kNN-TTS, a lightweight zero-shot text-to-speech synthesis model.
                
                ### Authors
                
                - Karl El Hajal
                - Ajinkya Kulkarni
                - Enno Hermann
                - Mathew Magimai.-Doss
                
                ### Citation
                
                If you use kNN-TTS in your research, please cite our paper:
                
                ```
                @inproceedings{hajal-etal-2025-knn,
                    title = "k{NN} Retrieval for Simple and Effective Zero-Shot Multi-speaker Text-to-Speech",
                    author = "Hajal, Karl El  and
                      Kulkarni, Ajinkya  and
                      Hermann, Enno  and
                      Magimai Doss, Mathew",
                    editor = "Chiruzzo, Luis  and
                      Ritter, Alan  and
                      Wang, Lu",
                    booktitle = "Proceedings of the 2025 Conference of the Nations of the Americas Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 2: Short Papers)",
                    month = apr,
                    year = "2025",
                    address = "Albuquerque, New Mexico",
                    publisher = "Association for Computational Linguistics",
                    url = "https://aclanthology.org/2025.naacl-short.65/",
                    pages = "778--786",
                    ISBN = "979-8-89176-190-2"
                }
                ```
                
                ### Acknowledgments
                
                The target voices featured in this demo were sourced from the following datasets:
                
                - [Thorsten Dataset](https://www.thorsten-voice.de/)
                - [LibriSpeech Dataset](https://www.openslr.org/12)
                - [Emotional Speech Dataset (ESD)](https://hltsingapore.github.io/ESD/)
                
                ### License
                
                This project is licensed under the MIT License.
                """)
        
        # Event handlers
        submit_button.click(
            fn=run,
            inputs=[text_box, target_speaker_dropdown, rate_slider, k_slider, weighted_toggle],
            outputs=[audio_output]
        )
                
    return iface

demo = create_gradio_interface()
demo.launch(debug=False)  # Spaces provides the public URL, so share=True is unnecessary here