import os

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

from audiocraft.encodec import EncodecModel
from audiocraft.lm import LMModel
from audiocraft.seanet import SEANetDecoder
from audiocraft.vq import ResidualVectorQuantizer

N_REPEAT = 2  # number of (virtual batch-size) clones of the audio prompt


def _shift(x):
    # Roll each batch element along time by an independent random offset so
    # the cloned drafts decorrelate from one another.
    for i, _slice in enumerate(x):
        n = x.shape[2]
        # low/high cast to int; the max() guard keeps high > low for short clips
        offset = np.random.randint(int(0.24 * n), max(int(0.24 * n) + 1, int(0.74 * n)))
        x[i, :, :] = torch.roll(_slice, offset, dims=1)  # _slice is 2D [channels, time]
    return x


class AudioGen(torch.nn.Module):
    # Weights: https://huggingface.co/facebook/audiogen-medium

    def __init__(self):
        super().__init__()

        # EnCodec compression model
        _file_1 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename='compression_state_dict.bin',
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name='audiocraft',
            library_version='1.3.0a1')  # audiocraft.__version__, see audiocraft/__init__.py
        pkg = torch.load(_file_1, map_location='cpu')
        _cfg = OmegaConf.create(pkg['xp.cfg'])  # training config stored in the checkpoint (unused here)
        decoder = SEANetDecoder()
        quantizer = ResidualVectorQuantizer()
        self.compression_model = EncodecModel(decoder=decoder,
                                              quantizer=quantizer,
                                              frame_rate=50,
                                              renormalize=False,
                                              sample_rate=16000,
                                              channels=1,
                                              causal=False)
        # strict=False: the checkpoint also carries unused encoder weights
        self.compression_model.load_state_dict(pkg['best_state'], strict=False)
        self.compression_model.eval()

        # T5 conditioner & language model
        _file_2 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename='state_dict.bin',
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name='audiocraft',
            library_version='1.3.0a1')
        pkg = torch.load(_file_2, map_location='cpu')
        cfg = OmegaConf.create(pkg['xp.cfg'])  # training config bundled inside the torch bin (unused here)
        _best = pkg['best_state']
        # Rename the conditioner projection keys to match this LMModel
        _best['t5.output_proj.weight'] = _best.pop(
            'condition_provider.conditioners.description.output_proj.weight')
        _best['t5.output_proj.bias'] = _best.pop(
            'condition_provider.conditioners.description.output_proj.bias')
        self.lm = LMModel()
        self.lm.load_state_dict(_best, strict=True)
        self.lm.eval()

    @torch.no_grad()
    def generate(self,
                 prompt='dogs meow',
                 duration=2.24,  # seconds of audio
                 ):
        # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
        torch.manual_seed(42)
        self.lm.n_draw = int(duration / 12) + 1  # a different draw roughly every 12 seconds of audio
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gen_tokens = self.lm.generate(
                text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,  # e.g. ['dogs', 'dogs', '', '']
                max_tokens=int(duration / (N_REPEAT * self.lm.n_draw)
                               * self.compression_model.frame_rate))  # [bs, 4, max_tokens * n_draw]
        x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, n_samples]
        # Shift repeatedly; this may be more randomness than needed, since
        # lm.n_draw already injects randomness across draws
        for _ in range(7):
            x = _shift(x)
        return x.reshape(-1)  # optionally normalize: x / (x.abs().max() + 1e-7)
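

# Minimal usage sketch (an assumption, not shipped with the checkpoint):
# synthesize a short clip and write it to disk. Assumes a CUDA device is
# available (generate() autocasts on CUDA) and that the `soundfile` package
# is installed; the 'dogs barking' prompt and out.wav filename are examples.
if __name__ == '__main__':
    import soundfile as sf

    model = AudioGen().cuda()
    wav = model.generate(prompt='dogs barking', duration=4.0).cpu().numpy()
    sf.write('out.wav', wav, samplerate=16000)  # EnCodec above is configured for 16 kHz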