import torch
from torch import nn
from omegaconf import OmegaConf
import numpy as np
from huggingface_hub import hf_hub_download
import os
from audiocraft.encodec import EncodecModel
from audiocraft.lm import LMModel
from audiocraft.seanet import SEANetDecoder
from audiocraft.vq import ResidualVectorQuantizer
N_REPEAT = 2  # number of (virtual batch-size) clones of the generated audio


def _shift(x):
    # Roll every batch item by its own random offset so the N_REPEAT clones decorrelate.
    for i, _slice in enumerate(x):
        n = x.shape[2]
        offset = np.random.randint(int(.24 * n), max(1, int(.74 * n)))  # high must stay >= 1
        print(offset)
        x[i, :, :] = torch.roll(_slice, offset, dims=1)  # _slice is 2D [channels, time]
    return x
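
# Illustrative sketch (not part of the original module): _shift applied to a dummy
# [batch, channels, time] tensor, the shape EncodecModel.decode() is assumed to
# return further below.
#
#   >>> x = torch.arange(16.).reshape(2, 1, 8)
#   >>> _shift(x)   # each of the 2 items is rolled by its own offset in [1, 4]
#
# The circular roll keeps the waveform content but desynchronises the clones.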
class AudioGen(torch.nn.Module):
    # https://huggingface.co/facebook/audiogen-medium

    def __init__(self):
        super().__init__()

        # EnCodec compression model (decoder + RVQ); the checkpoint also holds
        # unused encoder weights, hence strict=False below.
        _file_1 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="compression_state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')  # audiocraft.__version__, see audiocraft/__init__.py
        pkg = torch.load(_file_1, map_location='cpu')
        decoder = SEANetDecoder()
        quantizer = ResidualVectorQuantizer()
        self.compression_model = EncodecModel(decoder=decoder,
                                              quantizer=quantizer,
                                              frame_rate=50,
                                              renormalize=False,
                                              sample_rate=16000,
                                              channels=1,
                                              causal=False)
        self.compression_model.load_state_dict(pkg['best_state'], strict=False)
        self.compression_model.eval()

        # T5 conditioner & language model
        _file_2 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')
        pkg = torch.load(_file_2, map_location='cpu')
        cfg = OmegaConf.create(pkg['xp.cfg'])  # the training config is stored inside the torch .bin
        _best = pkg['best_state']
        # Rename the conditioner projection keys to the layout expected by the local LMModel.
        _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
        _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
        self.lm = LMModel()
        self.lm.load_state_dict(_best, strict=True)
        self.lm.eval()
    @torch.no_grad()
    def generate(self,
                 prompt='dogs meow',
                 duration=2.24,  # seconds of audio
                 ):
        torch.manual_seed(42)  # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
        self.lm.n_draw = int(duration / 12) + 1  # one extra draw (beam) per ~12 seconds of audio
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gen_tokens = self.lm.generate(
                text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,  # e.g. ['dogs', 'dogs...!', '', '']
                max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate)
            )  # [bs, 4, 74 * self.lm.n_draw]
        x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
        for _ in range(7):  # the shift may be redundant as lm.n_draw already adds randomness
            x = _shift(x)
        return x.reshape(-1)  # optionally normalise: x / (x.abs().max() + 1e-7)
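

# Minimal usage sketch (an assumption, not part of the original module): it presumes
# a CUDA device, since generate() autocasts with device_type='cuda', and uses
# soundfile only to write the 16 kHz mono output configured above.
if __name__ == '__main__':
    import soundfile as sf

    model = AudioGen().cuda()
    wav = model.generate(prompt='dogs barking', duration=4.48).cpu().numpy()
    sf.write('audiogen_sample.wav', wav, samplerate=16000)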