import os
import torch
import librosa
import look2hear.models
import soundfile as sf
import argparse
import numpy as np
import yaml
from ml_collections import ConfigDict
import json
import time
import warnings
warnings.filterwarnings("ignore")

def get_config(config_path):
    # Load the YAML model config into a ConfigDict for dictionary-style access.
    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    return config
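
# main() below reads four hyperparameters from the `model` section of the config.
# A minimal sketch of the expected YAML layout, with purely illustrative values
# (not taken from any real checkpoint's config):
#
#   model:
#     sr: 44100
#     win: 20
#     feature_dim: 256
#     layer: 6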

def load_audio(file_path):
    # Load the file without downmixing and resample it to 44.1 kHz.
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    return torch.from_numpy(audio), samplerate

def save_audio(file_path, audio, samplerate=44100):
    # soundfile expects (frames, channels), hence the transpose; write 16-bit PCM.
    sf.write(file_path, audio.T, samplerate, subtype="PCM_16")

def process_chunk(chunk):
    # Add a batch dimension, run the model on the GPU, and move the result back to the CPU.
    chunk = chunk.unsqueeze(0).cuda()
    with torch.no_grad():
        return model(chunk).squeeze(0).squeeze(0).cpu()

def _getWindowingArray(window_size, fade_size):
    # Crossfade window for overlap-add: linear fade-in over the first fade_size
    # samples, ones in the middle, linear fade-out over the last fade_size samples.
    fadein = torch.linspace(0, 1, fade_size)
    fadeout = torch.linspace(1, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window

def dBgain(audio, volume_gain_dB):
    # Convert a gain in dB to a linear factor and apply it to the signal.
    gain = 10 ** (volume_gain_dB / 20)
    return audio * gain
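
# Example: dBgain(audio, -6) scales the signal by 10 ** (-6 / 20) ≈ 0.501, i.e.
# roughly halves the amplitude. (The helper is defined here but not called below.)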

def main(input_wav, output_wav, ckpt_path):
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    global model
    # Build the model from the checkpoint, using the hyperparameters from the config.
    feature_dim = config['model']['feature_dim']
    sr = config['model']['sr']
    win = config['model']['win']
    layer = config['model']['layer']
    model = look2hear.models.BaseModel.from_pretrain(ckpt_path, sr=sr, win=win, feature_dim=feature_dim, layer=layer).cuda()

    test_data, samplerate = load_audio(input_wav)
    
    C = chunk_size * samplerate  # chunk_size seconds to samples
    N = overlap
    step = C // N
    fade_size = 3 * 44100  # 3 seconds
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
    
    border = C - step  # amount of reflect padding added to (and later trimmed from) both ends

    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0)  # mono -> (1, samples)

    # Pad both ends so the first and last chunks get full-window context.
    if test_data.shape[1] > 2 * border and border > 0:
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

    # Overlap-add accumulators: weighted sum of chunk outputs and the running sum of window weights.
    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    total_samples = test_data.shape[1]
    start_time = time.time()

    while i < total_samples:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            # Pad the final short chunk up to a full window of C samples.
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        out = process_chunk(part)

        # Clone the window so the edge adjustments below do not leak into later chunks.
        window = windowingArray.clone()
        if i == 0:  # first chunk: keep the head at full weight (no fade-in)
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # last chunk: keep the tail at full weight (no fade-out)
            window[-fade_size:] = 1

        # Overlap-add the weighted chunk output and remember the weights for normalization.
        result[..., i:i+length] += out[..., :length] * window[..., :length]
        counter[..., i:i+length] += window[..., :length]

        i += step
        # Report progress as a JSON line so a wrapper process can parse it from stdout.
        processed_samples = min(i, total_samples)
        percentage = (processed_samples / total_samples) * 100
        elapsed_time = time.time() - start_time
        progress_data = {
            "percentage": percentage,
            "processed_samples": processed_samples,
            "total_samples": total_samples,
            "elapsed_time": elapsed_time
        }
        print(json.dumps(progress_data), flush=True)

    # Normalize by the accumulated window weights and zero out any NaNs from division by zero.
    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Trim the reflect padding that was added before chunking.
    if test_data.shape[1] > 2 * border and border > 0:
        final_output = final_output[..., border:-border]

    save_audio(output_wav, final_output, samplerate)
    print(f'Success! Output file saved as {output_wav}')

    model.cpu()
    del model
    torch.cuda.empty_cache()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Inference Script")
    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint file")
    parser.add_argument("--config", type=str, required=True, help="Path to model config file")
    parser.add_argument("--chunk_size", type=int, default=10, help="Chunk size in seconds")
    parser.add_argument("--overlap", type=int, default=2, help="Overlap")
    args = parser.parse_args()
    
    ckpt_path = args.ckpt
    chunk_size = args.chunk_size
    overlap = args.overlap
    config = get_config(args.config)
    print(config['model'])
    print(f'ckpt_path = {ckpt_path}')
    print(f'chunk_size = {chunk_size}, overlap = {overlap}')
    
    main(args.in_wav, args.out_wav, ckpt_path)
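
# Example invocation (script and file names are placeholders for illustration):
#
#   python inference.py \
#       --in_wav input.wav \
#       --out_wav restored.wav \
#       --ckpt checkpoints/model.ckpt \
#       --config configs/model.yaml \
#       --chunk_size 10 \
#       --overlap 2
#
# Progress is printed to stdout as one JSON object per processed chunk,
# followed by a success message with the output path.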