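"""Gradio demo for TIGER audio separation: dialog/effects/music (DnR) and multi-speaker separation."""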
import os
import uuid
import torch
import torchaudio
import torchaudio.transforms as T
import soundfile as sf
import gradio as gr
import spaces
import look2hear.models

# Set up device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained TIGER models from the Hugging Face Hub (cached in ./cache)
dnr_model = look2hear.models.TIGERDNR.from_pretrained("JusperLee/TIGER-DnR", cache_dir="cache")
dnr_model.to(device).eval()

sep_model = look2hear.models.TIGER.from_pretrained("JusperLee/TIGER-speech", cache_dir="cache")
sep_model.to(device).eval()

TARGET_SR = 16000
MAX_SPEAKERS = 4

# --- DnR Function ---
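# @spaces.GPU() requests a GPU slot per call on Hugging Face ZeroGPU Spaces
# (it has no effect when a dedicated GPU is attached).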
@spaces.GPU()
def separate_dnr(audio_file):
    audio, sr = torchaudio.load(audio_file)
    audio = audio.to(device)

    with torch.no_grad():
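        # audio[None] adds the batch dimension the model expects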
        dialog, effect, music = dnr_model(audio[None])

    # Unique output folder
    session_id = uuid.uuid4().hex[:8]
    output_dir = os.path.join("output_dnr", session_id)
    os.makedirs(output_dir, exist_ok=True)

    dialog_path = os.path.join(output_dir, "dialog.wav")
    effect_path = os.path.join(output_dir, "effect.wav")
    music_path = os.path.join(output_dir, "music.wav")

    torchaudio.save(dialog_path, dialog.cpu(), sr)
    torchaudio.save(effect_path, effect.cpu(), sr)
    torchaudio.save(music_path, music.cpu(), sr)

    return dialog_path, effect_path, music_path

# --- Speaker Separation Function ---
@spaces.GPU()
def separate_speakers(audio_path):
    waveform, original_sr = torchaudio.load(audio_path)
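    # The separation model runs at 16 kHz (TARGET_SR); resample if the input differs.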
    if original_sr != TARGET_SR:
        waveform = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR)(waveform)

    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
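    # Add a batch dimension: [channels, T] -> [1, channels, T]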
    audio_input = waveform.unsqueeze(0).to(device)

    with torch.no_grad():
        ests_speech = sep_model(audio_input)

    ests_speech = ests_speech.squeeze(0)  # [1, num_speakers, T] -> [num_speakers, T]

    # Unique output folder
    session_id = uuid.uuid4().hex[:8]
    output_dir = os.path.join("output_sep", session_id)
    os.makedirs(output_dir, exist_ok=True)

    output_files = []
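    # Write one WAV per estimated speaker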
    for i in range(ests_speech.shape[0]):
        path = os.path.join(output_dir, f"speaker_{i+1}.wav")
        audio_np = ests_speech[i].cpu().numpy()
        sf.write(path, audio_np.T, TARGET_SR)  # sf.write expects [frames, channels]; .T is a no-op for 1-D audio
        output_files.append(path)

    # One gr.update per fixed output slot; hide any slots beyond the detected speakers
    updates = []
    for i in range(MAX_SPEAKERS):
        if i < len(output_files):
            updates.append(gr.update(value=output_files[i], visible=True, label=f"Speaker {i+1}"))
        else:
            updates.append(gr.update(value=None, visible=False))
    return updates

# --- Gradio App ---
with gr.Blocks() as demo:
    gr.Markdown("# TIGER: Time-frequency Interleaved Gain Extraction and Reconstruction for Efficient Speech Separation")
    gr.Markdown("TIGER is a lightweight model for speech separation which effectively extracts key acoustic features through frequency band-split, multi-scale and full-frequency-frame modeling.")

    gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://cslikai.cn/TIGER/">
                <img src="https://img.shields.io/badge/Project-Page-green" alt="Project Page">
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/TIGER-audio-extraction?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
        </div>
    """)
    with gr.Tabs():
        # --- Tab 1: DnR ---
        with gr.Tab("Dialog/Effects/Music Separation (DnR)"):
            gr.Markdown("### Separate Dialog, Effects, and Music from Mixed Audio")

            dnr_input = gr.Audio(type="filepath", label="Upload Audio File")
            dnr_button = gr.Button("Separate Audio")

            gr.Examples(
                examples=["./test/test_mixture_466.wav"],
                inputs=dnr_input
            )

            dnr_output_dialog = gr.Audio(label="Dialog", type="filepath")
            dnr_output_effect = gr.Audio(label="Effects", type="filepath")
            dnr_output_music = gr.Audio(label="Music", type="filepath")

            dnr_button.click(
                fn=separate_dnr,
                inputs=dnr_input,
                outputs=[dnr_output_dialog, dnr_output_effect, dnr_output_music]
            )

        # --- Tab 2: Speaker Separation ---
        with gr.Tab("Speaker Separation"):
            gr.Markdown("### Separate Individual Speakers from Mixed Speech")

            sep_input = gr.Audio(type="filepath", label="Upload Speech Audio")
            sep_button = gr.Button("Separate Speakers")

            gr.Examples(
                examples=["./test/mix.wav"],
                inputs=sep_input
            )

            gr.Markdown("#### Separated Speakers")
            sep_outputs = []
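            # Pre-create a fixed number of output slots; separate_speakers() shows/hides them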
            for i in range(MAX_SPEAKERS):
                sep_outputs.append(gr.Audio(label=f"Speaker {i+1}", visible=(i == 0), interactive=False))

            sep_button.click(
                fn=separate_speakers,
                inputs=sep_input,
                outputs=sep_outputs
            )

if __name__ == "__main__":
    demo.launch(share=True)