import os
import json
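
# Round-trip sanity check: encode one dataset sample's log-mel spectrogram with
# the AudioLDM2 VAE, decode it back, and vocode the result to a waveform with
# the SpeechT5 HiFi-GAN vocoder. Paths point at a local copy of the AudioLDM2
# checkpoints.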
def test_reconstruct():
    import yaml
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan
    from audioldm2.utilities.data.dataset import AudioDataset
    from utils import load_clip, load_clap, load_t5

    model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'
    with open('config/16k_64.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    print(config)

    t5 = load_t5('cuda', max_length=256)
    clap = load_clap('cuda', max_length=256)
    dataset = AudioDataset(
        config=config, split="train", waveform_only=False, dataset_json_path='mini_dataset.json',
        tokenizer=clap.tokenizer,
        uncond_pro=0.1,
        text_ctx_len=77,
        tokenizer_t5=t5.tokenizer,
        text_ctx_len_t5=256,
        uncond_pro_t5=0.1,
    )
    print(dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0).size())

    vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae'))
    vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder'))

    # encode the log-mel spectrogram into the VAE latent space
    latents = vae.encode(dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0)).latent_dist.sample().mul_(vae.config.scaling_factor)
    print('latent size:', latents.size())

    # decode the latents back to a mel spectrogram
    latents = 1 / vae.config.scaling_factor * latents
    mel_spectrogram = vae.decode(latents).sample
    print(mel_spectrogram.size())
    if mel_spectrogram.dim() == 4:
        mel_spectrogram = mel_spectrogram.squeeze(1)

    # vocode the mel spectrogram back to a waveform
    waveform = vocoder(mel_spectrogram)
    waveform = waveform[0].cpu().float().detach().numpy()
    print(waveform.shape)
    # import soundfile as sf
    # sf.write('reconstruct.wav', waveform, samplerate=16000)
    from scipy.io import wavfile
    # wavfile.write('reconstruct.wav', 16000, waveform)
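
# Build a tiny debugging manifest: `num` copies of the same (wav, caption) pair
# in the {'wav': ..., 'label': ...} format that AudioDataset expects.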
def mini_dataset(num=32):
    data = []
    for _ in range(num):
        data.append(
            {
                'wav': 'case.mp3',
                'label': 'a beautiful music',
            }
        )
    with open('mini_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
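
# Build a training manifest from the FMA-large parquet annotations, keeping
# only entries whose .mp3 file actually exists on disk.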
def fma_dataset():
    import pandas as pd
    annotation_prefix = "/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/annotation"
    annotation_list = ['test-00000-of-00001.parquet', 'train-00000-of-00001.parquet', 'valid-00000-of-00001.parquet']
    dataset_prefix = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/fma_large'
    data = []
    for annotation_file in annotation_list:
        annotation_file = os.path.join(annotation_prefix, annotation_file)
        df = pd.read_parquet(annotation_file)
        print(df.shape)
        for idx, row in df.iterrows():
            # print(idx, row['pseudo_caption'], row['path'])
            tmp_path = os.path.join(dataset_prefix, row['path'] + '.mp3')
            # print(tmp_path)
            # keep only clips that are actually present on disk
            if os.path.exists(tmp_path):
                data.append(
                    {
                        'wav': tmp_path,
                        'label': row['pseudo_caption'],
                    }
                )
            # break
    print(len(data))
    with open('fma_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
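
# Build a manifest from the AudioSet balanced-train parquet annotations,
# keeping only entries whose .flac file exists on disk.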
def audioset_dataset():
    import pandas as pd
    dataset_prefix = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset'
    annotation_path = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset/balanced_train-00000-of-00001.parquet'
    df = pd.read_parquet(annotation_path)
    print(df.shape)
    data = []
    for idx, row in df.iterrows():
        # print(idx, row['pseudo_caption'], row['path'])
        try:
            tmp_path = os.path.join(dataset_prefix, row['path'] + '.flac')
        except Exception:
            # skip rows whose path is missing or not a string
            print(row['path'])
            continue
        if os.path.exists(tmp_path):
            # print(tmp_path)
            data.append(
                {
                    'wav': tmp_path,
                    'label': row['pseudo_caption'],
                }
            )
    print(len(data))
    with open('audioset_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
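
# Merge the FMA and AudioSet manifests into a single combined training manifest.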
def combine_dataset():
    data_list = ['fma_dataset.json', 'audioset_dataset.json']
    data = []
    for data_file in data_list:
        with open(data_file, 'r') as f:
            data += json.load(f)
    print(len(data))
    with open('combine_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
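
# Quick check that torchaudio can decode a sample .flac file and report its
# sample rate.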
def test_music_format():
    import torchaudio
    filename = '2.flac'
    waveform, sr = torchaudio.load(filename)
    print(waveform, sr)
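
# Profile the model returned by build_model('giant') with thop, using dummy
# latent / text / timestep inputs shaped like a single training sample.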
def test_flops():
    version = 'giant'
    import torch
    from constants import build_model
    from thop import profile

    model = build_model(version).cuda()
    # dummy inputs shaped like one training sample
    img_ids = torch.randn((1, 1024, 3)).cuda()
    txt = torch.randn((1, 256, 4096)).cuda()
    txt_ids = torch.randn((1, 256, 3)).cuda()
    y = torch.randn((1, 768)).cuda()
    x = torch.randn((1, 1024, 32)).cuda()
    t = torch.tensor([1] * 1).cuda()
    flops, _ = profile(model, inputs=(x, img_ids, txt, txt_ids, t, y))
    # thop counts multiply-accumulates, so double the count to approximate FLOPs
    print('FLOPs = ' + str(flops * 2 / 1000 ** 3) + 'G')
if __name__ == '__main__':
    # test_music_format()
    # test_reconstruct()
    # mini_dataset()
    # fma_dataset()
    # audioset_dataset()
    # combine_dataset()
    test_flops()