Dionyssos's picture
determinis
bc7f42e
import torch
from torch import nn
from omegaconf import OmegaConf
import numpy as np
from huggingface_hub import hf_hub_download
import os
from audiocraft.encodec import EncodecModel
from audiocraft.lm import LMModel
from audiocraft.seanet import SEANetDecoder
from audiocraft.vq import ResidualVectorQuantizer
# Number of copies of the prompt batched into a single AudioGen.generate call
# (a "virtual batch size"): the prompt list is repeated N_REPEAT times and the
# per-item token budget is divided by N_REPEAT.
N_REPEAT = 2 # num (virtual batch_size) clones of audio sounds
def _shift(x):
    """Circularly roll every batch item of ``x`` by its own random offset.

    ``x`` is indexed as ``x[i, :, :]``, i.e. it is a 3-D tensor. Each 2-D
    item is rolled along its last axis (length ``n = x.shape[2]``) by an
    offset drawn independently per item from
    ``[int(.24 * n), int(max(1, .74 * n)))``.

    Args:
        x: 3-D tensor; it is modified in place.

    Returns:
        The same tensor object ``x`` with every item rolled.
    """
    n = x.shape[2]  # loop-invariant length of the rolled axis
    # np.random.randint truncates scalar float bounds with int(); do it
    # explicitly. max(1, ...) keeps the exclusive upper bound >= 1 so the
    # range is never empty for very short inputs.
    low = int(.24 * n)
    high = int(max(1, .74 * n))
    # NOTE: offsets come from NumPy's global RNG, so torch.manual_seed() does
    # not make them reproducible; seed np.random as well if that is required.
    for i, item in enumerate(x):
        offset = np.random.randint(low, high)
        # torch.roll returns a new tensor, so writing it back into the row
        # that `item` views is safe.
        x[i, :, :] = torch.roll(item, offset, dims=1)  # item is 2D
    return x
class AudioGen(torch.nn.Module):
    """Text-to-audio generator built from the ``facebook/audiogen-medium`` files.

    https://huggingface.co/facebook/audiogen-medium

    Construction downloads two checkpoints from the Hugging Face Hub and
    loads them into:

    * ``self.compression_model`` -- an ``EncodecModel`` (decoder + quantizer)
      used to decode generated tokens into a waveform;
    * ``self.lm`` -- an ``LMModel`` that generates the tokens from text.

    Both sub-models are put in ``eval()`` mode and are loaded on CPU.
    """

    _REPO_ID = 'facebook/audiogen-medium'
    _LIBRARY_VERSION = '1.3.0a1'  # Found at __init__.py #audiocraft.__version__

    def __init__(self):
        super().__init__()

        # Compression model (token -> audio decoder).
        pkg = self._load_checkpoint('compression_state_dict.bin')
        self.compression_model = EncodecModel(decoder=SEANetDecoder(),
                                              quantizer=ResidualVectorQuantizer(),
                                              frame_rate=50,
                                              renormalize=False,
                                              sample_rate=16000,
                                              channels=1,
                                              causal=False)
        # strict=False: ckpt has also unused encoder weights
        self.compression_model.load_state_dict(pkg['best_state'], strict=False)
        self.compression_model.eval()

        # T5 & LM
        pkg = self._load_checkpoint('state_dict.bin')
        state = pkg['best_state']
        # The checkpoint stores the text-conditioner projection under the
        # original condition-provider path; LMModel is loaded with it under
        # 't5.output_proj.*', so rename both tensors before the strict load.
        for suffix in ('weight', 'bias'):
            state[f't5.output_proj.{suffix}'] = state.pop(
                f'condition_provider.conditioners.description.output_proj.{suffix}')
        self.lm = LMModel()
        self.lm.load_state_dict(state, strict=True)
        self.lm.eval()

    @staticmethod
    def _load_checkpoint(filename):
        """Download ``filename`` from the model repo and load it on CPU.

        The download cache directory is taken from the
        ``AUDIOCRAFT_CACHE_DIR`` environment variable when it is set.
        """
        path = hf_hub_download(
            repo_id=AudioGen._REPO_ID,
            filename=filename,
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version=AudioGen._LIBRARY_VERSION)
        # SECURITY NOTE(review): torch.load unpickles the downloaded file.
        # Only load checkpoints from a trusted repo; consider passing
        # weights_only=True once confirmed the checkpoint content allows it.
        return torch.load(path, map_location='cpu')

    @torch.no_grad()
    def generate(self,
                 prompt='dogs mewo',
                 duration=2.24,  # seconds of audio
                 ):
        """Generate audio for ``prompt`` and return it as a flat 1-D tensor.

        Args:
            prompt: Text description of the sound to generate.
            duration: Requested amount of audio, in seconds.

        Returns:
            The decoded waveform of all batch items flattened with
            ``reshape(-1)``.

        Side effects:
            Reseeds torch's global RNG with 42 on every call and overwrites
            ``self.lm.n_draw``.
        """
        # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
        torch.manual_seed(42)
        # One more draw for every 12 seconds of requested audio; n_draw is
        # consumed by LMModel.generate (its exact semantics are defined there).
        self.lm.n_draw = int(duration / 12) + 1
        # Token budget per batch item: the requested duration is split across
        # the N_REPEAT clones and the n_draw draws -- presumably because they
        # are all concatenated by the final reshape(-1); confirm in LMModel.
        n_tokens = int(duration / (N_REPEAT * self.lm.n_draw)
                       * self.compression_model.frame_rate)
        # NOTE(review): the original indentation was lost; decode is assumed
        # to run inside the autocast block -- confirm against upstream.
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gen_tokens = self.lm.generate(
                # N_REPEAT copies of the prompt followed by N_REPEAT empty
                # prompts, e.g. ['dogs', 'dogs...!', '', '']
                text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,
                max_tokens=n_tokens,
            )  # [bs, 4, 74 * self.lm.n_draw]
            x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
        for _ in range(7):  # perhaps shift is too random as already lm.n_draw has randomness
            x = _shift(x)
        return x.reshape(-1)  # x / (x.abs().max() + 1e-7)