import os
import torch
import librosa
import look2hear.models
import soundfile as sf
import argparse
import numpy as np
import yaml
from ml_collections import ConfigDict
import json
import time
import warnings

warnings.filterwarnings("ignore")


def get_config(config_path):
    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    return config


def load_audio(file_path):
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    return torch.from_numpy(audio), samplerate


def save_audio(file_path, audio, samplerate=44100):
    sf.write(file_path, audio.T, samplerate, subtype="PCM_16")


def process_chunk(chunk):
    # Run one chunk through the model on the GPU and return the result on the CPU.
    chunk = chunk.unsqueeze(0).cuda()
    with torch.no_grad():
        return model(chunk).squeeze(0).squeeze(0).cpu()


def _getWindowingArray(window_size, fade_size):
    # NOTE: as written, fadein is constant 1 and fadeout is constant 0, so the
    # window is a hard cut (ones, then zeros over the last fade_size samples)
    # rather than a linear cross-fade.
    fadein = torch.linspace(1, 1, fade_size)
    fadeout = torch.linspace(0, 0, fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window


def dBgain(audio, volume_gain_dB):
    gain = 10 ** (volume_gain_dB / 20)
    gained_audio = audio * gain
    return gained_audio


def main(input_wav, output_wav, ckpt_path):
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    global model
    feature_dim = config['model']['feature_dim']
    sr = config['model']['sr']
    win = config['model']['win']
    layer = config['model']['layer']
    model = look2hear.models.BaseModel.from_pretrain(
        ckpt_path, sr=sr, win=win, feature_dim=feature_dim, layer=layer
    ).cuda()

    test_data, samplerate = load_audio(input_wav)

    C = chunk_size * samplerate  # chunk_size seconds to samples
    N = overlap
    step = C // N                # hop between the starts of consecutive chunks
    fade_size = 3 * 44100        # 3 seconds
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")

    border = C - step

    # Ensure a channel dimension for mono input.
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0)

    # Reflect-pad both ends so the borders get full overlap coverage.
    if test_data.shape[1] > 2 * border and border > 0:
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    total_samples = test_data.shape[1]
    start_time = time.time()

    while i < total_samples:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        out = process_chunk(part)

        # Note: this aliases windowingArray, so the in-place edits below persist
        # across iterations.
        window = windowingArray
        if i == 0:  # first chunk: no fade-in
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # last chunk: no fade-out
            window[-fade_size:] = 1

        # Overlap-add the windowed chunk and accumulate the window weights.
        result[..., i:i + length] += out[..., :length] * window[..., :length]
        counter[..., i:i + length] += window[..., :length]
        i += step

        # Emit progress as JSON so a caller can track the run.
        percentage = (i / total_samples) * 100
        processed_samples = min(i, total_samples)
        elapsed_time = time.time() - start_time
        progress_data = {
            "percentage": percentage,
            "processed_samples": processed_samples,
            "total_samples": total_samples,
            "elapsed_time": elapsed_time
        }
        print(json.dumps(progress_data), flush=True)

    # Normalize by the accumulated window weights and drop the border padding.
    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    if test_data.shape[1] > 2 * border and border > 0:
        final_output = final_output[..., border:-border]

    save_audio(output_wav, final_output, samplerate)
    print(f'Success!\nOutput file saved as {output_wav}')
    model.cpu()
    del model
    torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Inference Script")
    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint file")
    parser.add_argument("--config", type=str, required=True, help="Path to model config file")
    parser.add_argument("--chunk_size", type=int, default=10, help="Chunk size in seconds")
    parser.add_argument("--overlap", type=int, default=2, help="Overlap")
    args = parser.parse_args()

    ckpt_path = args.ckpt
    chunk_size = args.chunk_size
    overlap = args.overlap

    config = get_config(args.config)
    print(config['model'])
    print(f'ckpt_path = {ckpt_path}')
    print(f'chunk_size = {chunk_size}, overlap = {overlap}')

    main(args.in_wav, args.out_wav, ckpt_path)
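
# Example invocation (a minimal sketch; the script name and file paths below are
# placeholders, not part of the repo -- only the flags are defined above):
#   python inference.py --in_wav mix.wav --out_wav restored.wav \
#       --ckpt model.ckpt --config config.yaml --chunk_size 10 --overlap 2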