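# Chunked inference script for a look2hear model (e.g. audio restoration / enhancement;
# the exact task depends on the checkpoint supplied). The input is split into overlapping
# chunks, each chunk is run through the model on the GPU, and the chunk outputs are
# recombined with a crossfaded overlap-add before being written to a WAV file.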
import os
import torch
import librosa
import look2hear.models
import soundfile as sf
import argparse
import numpy as np
import yaml
from ml_collections import ConfigDict
import json
import time
import warnings
warnings.filterwarnings("ignore")


def get_config(config_path):
    # Read the YAML config and wrap it in a ConfigDict for convenient access.
    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    return config


def load_audio(file_path):
    # librosa resamples to 44.1 kHz and keeps the original channel layout (mono=False).
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    return torch.from_numpy(audio), samplerate


def save_audio(file_path, audio, samplerate=44100):
    # soundfile expects (frames, channels), hence the transpose.
    sf.write(file_path, audio.T, samplerate, subtype="PCM_16")


def process_chunk(chunk):
    # Add a batch dimension, run the global model on the GPU, and return the result on the CPU.
    chunk = chunk.unsqueeze(0).cuda()
    with torch.no_grad():
        return model(chunk).squeeze(0).squeeze(0).cpu()


def _getWindowingArray(window_size, fade_size):
    # Build a linear crossfade window: ramp up over the first fade_size samples,
    # ramp down over the last fade_size samples, and stay at 1 in between.
    fadein = torch.linspace(0, 1, fade_size)
    fadeout = torch.linspace(1, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window
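
# Illustrative example (values assume the linear crossfade above), with
# window_size=8 and fade_size=3:
#   fadein  = [0.0, 0.5, 1.0]
#   fadeout = [1.0, 0.5, 0.0]
#   window  = [0.0, 0.5, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0]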


def dBgain(audio, volume_gain_dB):
    # Convert a gain in dB to a linear factor and apply it to the audio.
    gain = 10 ** (volume_gain_dB / 20)
    gained_audio = audio * gain
    return gained_audio
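
# Hypothetical usage (dBgain is defined as a helper but not called in main() below):
#   dBgain(audio, -6.0)  # scales the signal by 10 ** (-6 / 20) ~= 0.501
#   dBgain(audio, +6.0)  # scales the signal by 10 ** (+6 / 20) ~= 1.995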


def main(input_wav, output_wav, ckpt_path):
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    global model
    # Instantiate the model from the checkpoint using the architecture hyperparameters from the config.
    feature_dim = config['model']['feature_dim']
    sr = config['model']['sr']
    win = config['model']['win']
    layer = config['model']['layer']
    model = look2hear.models.BaseModel.from_pretrain(ckpt_path, sr=sr, win=win, feature_dim=feature_dim, layer=layer).cuda()
    test_data, samplerate = load_audio(input_wav)

    C = chunk_size * samplerate  # chunk_size seconds to samples
    N = overlap
    step = C // N
    fade_size = 3 * 44100  # 3 seconds
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")

    border = C - step

    # Ensure the shape is (channels, samples), even for mono input.
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0)

    # Reflect-pad both ends so the first and last chunks get full context.
    if test_data.shape[1] > 2 * border and border > 0:
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

    # Overlap-add buffers: weighted sum of chunk outputs and the sum of window weights.
    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    i = 0
    total_samples = test_data.shape[1]
    start_time = time.time()

    while i < total_samples:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            # Pad the final, shorter chunk up to the full chunk length.
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        out = process_chunk(part)

        # Clone so the per-chunk adjustments below do not modify the shared window.
        window = windowingArray.clone()
        if i == 0:  # first chunk: no fade-in
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # last chunk: no fade-out
            window[-fade_size:] = 1

        result[..., i:i + length] += out[..., :length] * window[..., :length]
        counter[..., i:i + length] += window[..., :length]
        i += step

        # Report progress as JSON so a wrapping process can parse it.
        percentage = (i / total_samples) * 100
        processed_samples = min(i, total_samples)
        elapsed_time = time.time() - start_time
        progress_data = {
            "percentage": percentage,
            "processed_samples": processed_samples,
            "total_samples": total_samples,
            "elapsed_time": elapsed_time
        }
        print(json.dumps(progress_data), flush=True)
    # Normalize by the accumulated window weights to complete the overlap-add.
    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Remove the reflection padding that was added before processing.
    if test_data.shape[1] > 2 * border and border > 0:
        final_output = final_output[..., border:-border]

    save_audio(output_wav, final_output, samplerate)
    print(f'Success! Output file saved as {output_wav}')

    # Release the GPU memory held by the model.
    model.cpu()
    del model
    torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Inference Script")
    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint file")
    parser.add_argument("--config", type=str, required=True, help="Path to model config file")
    parser.add_argument("--chunk_size", type=int, default=10, help="Chunk size in seconds")
    parser.add_argument("--overlap", type=int, default=2, help="Overlap factor (step = chunk_size / overlap)")
    args = parser.parse_args()

    ckpt_path = args.ckpt
    chunk_size = args.chunk_size
    overlap = args.overlap
    config = get_config(args.config)

    print(config['model'])
    print(f'ckpt_path = {ckpt_path}')
    print(f'chunk_size = {chunk_size}, overlap = {overlap}')

    main(args.in_wav, args.out_wav, ckpt_path)
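
# Example invocation (all paths and the script name "inference.py" are placeholders):
#   python inference.py --in_wav input.wav --out_wav output.wav \
#       --ckpt checkpoints/model.ckpt --config configs/config.yaml --chunk_size 10 --overlap 2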