Spaces:
Running
Running
###
# Author: Kai Li
# Date: 2024-01-22 01:16:22
# Email: lk21@mails.tsinghua.edu.cn
# LastEditTime: 2024-01-24 00:05:10
###
import torch | |
import torchaudio | |
import json | |
from typing import Any, Dict, List, Optional, Tuple | |
import os | |
from omegaconf import OmegaConf | |
import argparse | |
import pytorch_lightning as pl | |
import hydra | |
from pytorch_lightning.strategies.ddp import DDPStrategy | |
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer | |
# from pytorch_lightning.loggers import Logger | |
from omegaconf import DictConfig | |
import look2hear.system | |
import look2hear.datas | |
import look2hear.losses | |
import look2hear.models | |
from look2hear.metrics import MetricsTracker | |
from look2hear.utils import RankedLogger, instantiate, print_only | |
import warnings | |
warnings.filterwarnings("ignore") | |
class Owndata():
    """Index of paired codec/original recordings found under a directory tree.

    Every file named ``codec_wav.wav`` below ``root`` is one sample; its
    reference is the ``ori_wav.wav`` file located in the same directory.
    """

    # File names that identify the degraded input and its clean reference.
    CODEC_NAME = "codec_wav.wav"
    ORI_NAME = "ori_wav.wav"

    def __init__(self, root):
        """Collect the path of every ``codec_wav.wav`` below ``root``.

        Args:
            root: Directory that is walked recursively.
        """
        self.root = root
        # os.walk yields entries in filesystem order, which differs between
        # machines; sorting keeps sample indices (and any per-index logging)
        # reproducible from run to run.
        self.data_lists = sorted(
            os.path.join(dirpath, filename)
            for dirpath, _dirnames, filenames in os.walk(self.root)
            for filename in filenames
            if filename == self.CODEC_NAME
        )

    def __len__(self):
        """Return the number of codec files that were found."""
        return len(self.data_lists)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A tuple ``(ori_audio, codec_audio, ori)`` where the first two
            items are the waveform tensors returned by ``torchaudio.load``
            and ``ori`` is the path of the reference file.
        """
        codec_path = self.data_lists[idx]
        # Swap only the file name: a str.replace over the whole path would
        # also rewrite any directory component containing the codec name.
        ori = os.path.join(os.path.dirname(codec_path), self.ORI_NAME)
        ori_audio = torchaudio.load(ori)[0]
        codec_audio = torchaudio.load(codec_path)[0]
        return ori_audio, codec_audio, ori
def test(cfg: DictConfig, data_root: str = "./codec-test", sample_rate: int = 44100) -> None:
    """Run the pretrained model over a test folder and log metrics.

    For each sample found by ``Owndata`` the codec audio is passed through
    the model, the estimate is written next to the reference file as
    ``ests_wav.wav``, and the estimate is scored with ``MetricsTracker``.
    Metrics are written to ``<exp.dir>/<exp.name>/results/metrics.csv``.

    Args:
        cfg: Experiment configuration. Reads ``cfg.exp.dir``,
            ``cfg.exp.name`` and ``cfg.model``. Note that the ``_target_``
            key is removed from ``cfg.model`` in place.
        data_root: Directory scanned for test samples.
        sample_rate: Sample rate used when saving the estimated audio.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    data_val = Owndata(data_root)
    exp_dir = os.path.join(cfg.exp.dir, cfg.exp.name)

    # The remaining model entries are forwarded as keyword arguments, so the
    # instantiation target key must not be among them.
    cfg.model.pop("_target_", None)
    model = look2hear.models.GullFullband.from_pretrain(
        os.path.join(exp_dir, "best_model.pth"), **cfg.model
    ).cuda()
    # Make inference mode explicit so layers such as dropout or batch norm
    # do not run in training mode during evaluation.
    model.eval()

    results_dir = os.path.join(exp_dir, "results/")
    os.makedirs(results_dir, exist_ok=True)
    metrics = MetricsTracker(save_file=results_dir + "metrics.csv")

    length = len(data_val)
    for idx in range(length):
        ori_wav, codec_wav, key = data_val[idx]
        ori_wav = ori_wav.cuda()
        # Add a leading batch dimension before feeding the model.
        codec_wav = codec_wav.unsqueeze(0).cuda()
        with torch.no_grad():
            ests = model(codec_wav)
        # `key` is the reference path; the estimate is stored beside it.
        torchaudio.save(key.replace("ori_wav.wav", "ests_wav.wav"), ests.squeeze(0).cpu(), sample_rate)
        metrics(ori_wav, ests, key)
        if idx % 10 == 0:
            dicts = metrics.update()
            print(f"Processed {idx}/{length} samples: SDR: {dicts['sdr']}, SI-SNR: {dicts['si-snr']}, VISQOL: {dicts['visqol']}")
    metrics.final()
    print("Finished testing")
if __name__ == "__main__":
    # Command-line entry point: load the experiment config, store a copy of
    # it inside the experiment directory, then run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--conf_dir",
        default="Exps/Apollo/config.yaml",
        # The previous help text described a model-saving path, but this
        # option is the config file that gets loaded below.
        help="Path to the experiment YAML config file",
    )
    args = parser.parse_args()
    cfg = OmegaConf.load(args.conf_dir)
    exp_dir = os.path.join(cfg.exp.dir, cfg.exp.name)
    os.makedirs(exp_dir, exist_ok=True)
    # Keep the exact config used for this run alongside its results.
    OmegaConf.save(cfg, os.path.join(exp_dir, "config.yaml"))
    test(cfg)