import gradio as gr
import torch
import torchaudio
import yaml
from huggingface_hub import hf_hub_download
from models.gense_wavlm import N2S, S2S


class AttrDict(dict):
    """Dict subclass that exposes its keys as attributes (config.key)."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

def get_firstchannel_read(path, target_sr=16000):
    # Load the audio, keep only the first channel, and resample to target_sr.
    wav, sr = torchaudio.load(path)
    if wav.shape[0] > 1:
        wav = wav[0].unsqueeze(0)
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        wav = resampler(wav)
    return wav.unsqueeze(0)  # add a batch dimension: (1, 1, num_samples)

def inference(noisy_path):
    noisy_wav = get_firstchannel_read(noisy_path).to(device)
    with torch.no_grad():
        # Two-stage GenSE pipeline: N2S produces token sequences from the
        # noisy input, then S2S generates the enhanced waveform.
        noisy_s, clean_s = n2s_model.generate(noisy_wav)
        enhanced_wav = s2s_model.generate(noisy_wav, noisy_s, clean_s)
    out_path = 'enhanced2.wav'
    torchaudio.save(out_path, enhanced_wav.cpu(), sample_rate=16000)
    return out_path


# Download the config and checkpoints from the Hugging Face Hub.
config_path = hf_hub_download(repo_id="yaoxunji/gense", filename="gense.yaml")
n2s_ckpt_path = hf_hub_download(repo_id="yaoxunji/gense", filename="n2s_wavlm.ckpt")
s2s_ckpt_path = hf_hub_download(repo_id="yaoxunji/gense", filename="s2s_wavlm.ckpt")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
with open(config_path, "r") as f:
    config = yaml.safe_load(f)
config = AttrDict(config)

# Load both stages, switch to eval mode, and move them to the target device.
n2s_model = N2S(config)
n2s_model.load_state_dict(torch.load(n2s_ckpt_path, map_location=device)["state_dict"])
n2s_model = n2s_model.eval()
n2s_model = n2s_model.to(device)

s2s_model = S2S(config)
s2s_model.load_state_dict(torch.load(s2s_ckpt_path, map_location=device)["state_dict"])
s2s_model = s2s_model.eval()
s2s_model = s2s_model.to(device)

# Build the Gradio demo.

demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Audio(label="Upload Noisy Wav", type="filepath"),
    ],
    outputs=gr.Audio(label="Enhanced Audio"),
    title="GenSE Demo",
    description="""
    [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/pdf/2502.02942)
    """,
)

demo.launch()