# TIGER-audio-extraction / gradio_app.py
# Last updated by svjack in commit 12bc96d ("Update gradio_app.py").
import os
import uuid
import torch
import torchaudio
import torchaudio.transforms as T
import soundfile as sf
import gradio as gr
import spaces
import look2hear.models
# Sample rate that separate_speakers resamples uploads to before calling
# sep_model, and the number of speaker output slots shown in the UI.
TARGET_SR = 16000
MAX_SPEAKERS = 4

# Run on the GPU when CUDA is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _load_model(model_cls, repo_id):
    """Load ``repo_id`` with ``model_cls.from_pretrained`` (cache dir ``cache``),
    move it to ``device`` and switch it to eval mode."""
    model = model_cls.from_pretrained(repo_id, cache_dir="cache")
    model.to(device).eval()
    return model


# Load models
dnr_model = _load_model(look2hear.models.TIGERDNR, "JusperLee/TIGER-DnR")
sep_model = _load_model(look2hear.models.TIGER, "JusperLee/TIGER-speech")
# --- DnR Function ---
@spaces.GPU()
def separate_dnr(audio_file):
    """Split a mixture into dialog, effects and music stems with ``dnr_model``.

    Args:
        audio_file: Path of the uploaded audio file (the input component uses
            ``type="filepath"``), or None when nothing was uploaded.

    Returns:
        Tuple ``(dialog_path, effect_path, music_path)`` of WAV files written
        into a fresh ``output_dnr/<session_id>/`` directory at the input's
        original sample rate.

    Raises:
        gr.Error: If no audio file was provided, so the UI shows a readable
            message instead of a torchaudio traceback.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file first.")

    audio, sr = torchaudio.load(audio_file)
    audio = audio.to(device)
    with torch.no_grad():
        # audio[None] prepends a batch dimension of size 1.
        dialog, effect, music = dnr_model(audio[None])

    # A new folder per call keeps simultaneous requests from overwriting each
    # other's output files.
    session_id = uuid.uuid4().hex[:8]
    output_dir = os.path.join("output_dnr", session_id)
    os.makedirs(output_dir, exist_ok=True)

    paths = []
    for stem_name, stem in (("dialog", dialog), ("effect", effect), ("music", music)):
        path = os.path.join(output_dir, f"{stem_name}.wav")
        torchaudio.save(path, stem.cpu(), sr)
        paths.append(path)
    return tuple(paths)
# --- Speaker Separation Function ---
@spaces.GPU()
def separate_speakers(audio_path):
    """Separate individual speakers from a speech mixture with ``sep_model``.

    The upload is downmixed to one channel and resampled to TARGET_SR. It is
    then run through the model, and each estimated source is written to
    ``output_sep/<session_id>/speaker_<k>.wav``.

    Args:
        audio_path: Path of the uploaded audio file (the input component uses
            ``type="filepath"``), or None when nothing was uploaded.

    Returns:
        A list of exactly MAX_SPEAKERS ``gr.update`` objects, one per output
        Audio component. Slots with a separated track are shown and labelled
        "Speaker k". The remaining slots are cleared and hidden. Sources
        beyond MAX_SPEAKERS are written to disk but not displayed.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if audio_path is None:
        raise gr.Error("Please upload an audio file first.")

    # torchaudio.load returns a (channels, frames) tensor.
    waveform, original_sr = torchaudio.load(audio_path)
    # NOTE(review): TIGER is described as a single-channel separator, so
    # multi-channel uploads are averaged down to one channel instead of being
    # fed in as extra input channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if original_sr != TARGET_SR:
        waveform = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR)(waveform)

    audio_input = waveform.unsqueeze(0).to(device)  # add batch dim -> (1, 1, frames)
    with torch.no_grad():
        ests_speech = sep_model(audio_input)
    ests_speech = ests_speech.squeeze(0)  # drop the batch dimension

    # A new folder per call keeps simultaneous requests from overwriting each
    # other's output files.
    session_id = uuid.uuid4().hex[:8]
    output_dir = os.path.join("output_sep", session_id)
    os.makedirs(output_dir, exist_ok=True)

    output_files = []
    for k, estimate in enumerate(ests_speech, start=1):
        path = os.path.join(output_dir, f"speaker_{k}.wav")
        # soundfile expects (frames, channels): .T is a no-op for a 1-D array
        # and turns a (channels, frames) array into (frames, channels).
        sf.write(path, estimate.cpu().numpy().T, TARGET_SR)
        output_files.append(path)

    updates = []
    for i in range(MAX_SPEAKERS):
        if i < len(output_files):
            updates.append(gr.update(value=output_files[i], visible=True, label=f"Speaker {i+1}"))
        else:
            updates.append(gr.update(value=None, visible=False))
    return updates
# --- Gradio App ---
# Badge row (project page + "duplicate this Space") shown under the title.
_BADGES_HTML = """
<div style="display:flex;column-gap:4px;">
<a href="https://cslikai.cn/TIGER/">
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
<a href="https://huggingface.co/spaces/fffiloni/TIGER-audio-extraction?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
</div>
"""

with gr.Blocks() as demo:
    gr.Markdown("# TIGER: Time-frequency Interleaved Gain Extraction and Reconstruction for Efficient Speech Separation")
    gr.Markdown("TIGER is a lightweight model for speech separation which effectively extracts key acoustic features through frequency band-split, multi-scale and full-frequency-frame modeling.")
    gr.HTML(_BADGES_HTML)

    with gr.Tabs():
        # Tab 1: dialog / effects / music stems from a mixed track.
        with gr.Tab("Dialog/Effects/Music Separation (DnR)"):
            gr.Markdown("### Separate Dialog, Effects, and Music from Mixed Audio")
            dnr_input = gr.Audio(type="filepath", label="Upload Audio File")
            dnr_button = gr.Button("Separate Audio")
            gr.Examples(examples=["./test/test_mixture_466.wav"], inputs=dnr_input)

            dnr_output_dialog = gr.Audio(label="Dialog", type="filepath")
            dnr_output_effect = gr.Audio(label="Effects", type="filepath")
            dnr_output_music = gr.Audio(label="Music", type="filepath")

            dnr_button.click(
                fn=separate_dnr,
                inputs=dnr_input,
                outputs=[dnr_output_dialog, dnr_output_effect, dnr_output_music],
            )

        # Tab 2: individual speakers from a speech mixture.
        with gr.Tab("Speaker Separation"):
            gr.Markdown("### Separate Individual Speakers from Mixed Speech")
            sep_input = gr.Audio(type="filepath", label="Upload Speech Audio")
            sep_button = gr.Button("Separate Speakers")
            gr.Examples(examples=["./test/mix.wav"], inputs=sep_input)

            gr.Markdown("#### Separated Speakers")
            # One read-only player per possible speaker; only the first starts
            # visible, and separate_speakers toggles the rest.
            sep_outputs = [
                gr.Audio(label=f"Speaker {slot + 1}", visible=(slot == 0), interactive=False)
                for slot in range(MAX_SPEAKERS)
            ]

            sep_button.click(
                fn=separate_speakers,
                inputs=sep_input,
                outputs=sep_outputs,
            )
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)