|
import torch |
|
from torch import nn |
|
from omegaconf import OmegaConf |
|
import numpy as np |
|
from huggingface_hub import hf_hub_download |
|
import os |
|
from audiocraft.encodec import EncodecModel |
|
from audiocraft.lm import LMModel |
|
from audiocraft.seanet import SEANetDecoder |
|
from audiocraft.vq import ResidualVectorQuantizer |
|
|
|
|
|
# How many copies of the prompt (and, separately, of the empty string) make up
# the text conditions that AudioGen.generate passes to the LM; the per-call
# token budget computed there is also divided by this value.
N_REPEAT: int = 2
|
|
|
def _shift(x): |
|
|
|
for i, _slice in enumerate(x): |
|
n = x.shape[2] |
|
offset = np.random.randint(.24 * n, max(1, .74 * n)) |
|
print(offset) |
|
x[i, :, :] = torch.roll(_slice, offset, dims=1) |
|
return x |
|
|
|
class AudioGen(torch.nn.Module):
    """Text-to-audio generator assembled from ``facebook/audiogen-medium`` checkpoints.

    Construction downloads two files from the Hugging Face Hub. They are
    cached under ``$AUDIOCRAFT_CACHE_DIR`` when that variable is set.

    * ``compression_state_dict.bin`` -> ``self.compression_model``, an
      ``EncodecModel`` (frame_rate=50, sample_rate=16000, channels=1,
      causal=False, renormalize=False). :meth:`generate` uses it to decode
      the LM's tokens.
    * ``state_dict.bin`` -> ``self.lm``, an ``LMModel`` that generates those
      tokens.

    Both sub-models are switched to eval mode.
    """

    def __init__(self):
        super().__init__()
        # Each checkpoint is loaded inside its own helper, so the first raw
        # checkpoint dict can be freed before the second one is loaded.
        self.compression_model = self._build_compression_model()
        self.lm = self._build_lm()

    @staticmethod
    def _download(filename):
        """Return the local path of ``filename`` from the audiogen-medium Hub repo."""
        return hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename=filename,
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')

    @classmethod
    def _build_compression_model(cls):
        """Download, construct and load the compression model in eval mode."""
        # NOTE(security): torch.load unpickles the downloaded file; only use
        # checkpoints from a trusted source.
        pkg = torch.load(cls._download("compression_state_dict.bin"), map_location='cpu')
        model = EncodecModel(decoder=SEANetDecoder(),
                             quantizer=ResidualVectorQuantizer(),
                             frame_rate=50,
                             renormalize=False,
                             sample_rate=16000,
                             channels=1,
                             causal=False)
        # strict=False: mismatched checkpoint/model keys are ignored instead
        # of raising. No encoder is passed above, so the skipped keys are
        # presumably the checkpoint's encoder weights; TODO confirm.
        model.load_state_dict(pkg['best_state'], strict=False)
        model.eval()
        return model

    @classmethod
    def _build_lm(cls):
        """Download, construct and load the language model in eval mode."""
        # NOTE(security): torch.load unpickles the downloaded file; only use
        # checkpoints from a trusted source.
        pkg = torch.load(cls._download("state_dict.bin"), map_location='cpu')
        best = pkg['best_state']
        # Move the description conditioner's output projection to the
        # ``t5.output_proj`` keys that the strict load below requires.
        best['t5.output_proj.weight'] = best.pop(
            'condition_provider.conditioners.description.output_proj.weight')
        best['t5.output_proj.bias'] = best.pop(
            'condition_provider.conditioners.description.output_proj.bias')
        lm = LMModel()
        lm.load_state_dict(best, strict=True)
        lm.eval()
        return lm

    @torch.no_grad()
    def generate(self,
                 prompt='dogs mewo',
                 duration=2.24,
                 ):
        """Generate audio for ``prompt`` and return it flattened to 1-D.

        Args:
            prompt: Text description to condition generation on.
            duration: Target duration in seconds. It sets ``self.lm.n_draw``
                and the ``max_tokens`` budget passed to ``self.lm.generate``.

        Returns:
            The output of ``self.compression_model.decode``, after seven rounds
            of ``_shift`` (a random in-place roll per item), reshaped with
            ``reshape(-1)``.

        Side effects:
            Reseeds torch's global RNG with 42 and overwrites ``self.lm.n_draw``.
            ``_shift`` draws from NumPy's global RNG, which is not reseeded here.
        """
        torch.manual_seed(42)
        # NOTE(review): 12 is an unexplained constant; its meaning depends on
        # how LMModel uses n_draw. Confirm before changing.
        self.lm.n_draw = int(duration / 12) + 1

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            # Token budget per draw, scaled by the codec's frame rate.
            max_tokens = int(duration / (N_REPEAT * self.lm.n_draw)
                             * self.compression_model.frame_rate)
            gen_tokens = self.lm.generate(
                # N_REPEAT copies of the prompt, then N_REPEAT empty strings.
                # The empty strings are presumably the unconditional branch
                # for classifier-free guidance; TODO confirm in LMModel.
                text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,
                max_tokens=max_tokens
            )
            x = self.compression_model.decode(gen_tokens, None)

        for _ in range(7):
            x = _shift(x)

        return x.reshape(-1)
|
|