File size: 3,855 Bytes
d72b2c3
0230db1
bc7f42e
b399825
54adc39
 
bc7f42e
 
 
 
 
d912185
4eabff6
cf02fb0
b399825
bc7f42e
 
 
 
 
 
 
54adc39
bc08da5
4eabff6
54adc39
4eabff6
cf02fb0
54adc39
 
0230db1
4eabff6
0230db1
 
4eabff6
0230db1
bc7f42e
0230db1
 
 
 
 
 
 
 
 
bc7f42e
 
 
0230db1
 
 
4eabff6
0230db1
4eabff6
0230db1
 
 
 
 
 
bc7f42e
 
0230db1
bc7f42e
0230db1
4eabff6
54adc39
0230db1
bc7f42e
cf02fb0
bc7f42e
 
 
4eabff6
 
bc7f42e
 
 
4eabff6
 
b399825
bc7f42e
 
17a68db
bc7f42e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from torch import nn
from omegaconf import OmegaConf
import numpy as np
from huggingface_hub import hf_hub_download
import os
from audiocraft.encodec import EncodecModel
from audiocraft.lm import LMModel
from audiocraft.seanet import SEANetDecoder
from audiocraft.vq import ResidualVectorQuantizer


# Number of copies of the prompt (and of the empty prompt) batched into one
# LM call by AudioGen.generate; the requested duration is divided by this
# value and the decoded batch is flattened into a single waveform.
N_REPEAT = 2

def _shift(x):
    """Circularly shift every batch item of ``x`` along its last axis, in place.

    Each item receives its own random offset, drawn from NumPy's global RNG
    (``np.random``), lying between roughly 24% and 74% of the last-axis
    length. The offsets are therefore independent across the batch.

    Args:
        x: 3-D ``torch.Tensor``. It is indexed as ``x[i, :, :]`` and every
            2-D slice is rolled along its ``dims=1`` (axis 2 of ``x``).

    Returns:
        The same tensor object ``x``, modified in place.
    """
    n = x.shape[2]  # same for every batch item, so computed once
    # Explicit integer bounds rather than handing floats to randint and
    # relying on implicit truncation. `high` is exclusive and must be
    # strictly greater than `low`, which the max() guarantees even when the
    # axis has length 0 or 1.
    low = int(.24 * n)
    high = max(low + 1, int(.74 * n))
    for i, _slice in enumerate(x):
        offset = np.random.randint(low, high)
        x[i, :, :] = torch.roll(_slice, offset, dims=1)  # _slice 2D
    return x

class AudioGen(torch.nn.Module):
    """Text-to-audio wrapper around the ``facebook/audiogen-medium`` checkpoints.

    Construction downloads ``compression_state_dict.bin`` and
    ``state_dict.bin`` from the Hugging Face Hub and loads them into an
    ``EncodecModel`` (built from a ``SEANetDecoder`` and a
    ``ResidualVectorQuantizer``) and an ``LMModel`` respectively. Both
    sub-models are switched to eval mode.

    https://huggingface.co/facebook/audiogen-medium
    """

    @staticmethod
    def _load_checkpoint(filename):
        """Download ``filename`` from the AudioGen repo and load it onto CPU.

        The download cache directory is taken from the
        ``AUDIOCRAFT_CACHE_DIR`` environment variable when it is set.
        """
        path = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename=filename,
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')  # Found at __init__.py #audiocraft.__version__
        # NOTE(review): torch.load unpickles the downloaded file; this is only
        # safe as long as the hub repository is trusted. Consider
        # `weights_only=True` if the checkpoint contents allow it.
        return torch.load(path, map_location='cpu')

    def __init__(self):
        """Download both checkpoints and build the compression model and LM."""
        super().__init__()

        # Compression model (decoder + quantizer only).
        pkg = self._load_checkpoint("compression_state_dict.bin")
        decoder = SEANetDecoder()
        quantizer = ResidualVectorQuantizer()
        self.compression_model = EncodecModel(decoder=decoder,
                                              quantizer=quantizer,
                                              frame_rate=50,
                                              renormalize=False,
                                              sample_rate=16000,
                                              channels=1,
                                              causal=False)
        # strict=False: ckpt has also unused encoder weights
        self.compression_model.load_state_dict(pkg['best_state'], strict=False)
        self.compression_model.eval()

        # T5 & LM
        pkg = self._load_checkpoint("state_dict.bin")
        state = pkg['best_state']
        # The checkpoint stores the text-conditioner projection under the
        # 'condition_provider...' prefix; the state dict loaded below (with
        # strict=True) needs it under 't5.output_proj.*' instead.
        for suffix in ('weight', 'bias'):
            state[f't5.output_proj.{suffix}'] = state.pop(
                f'condition_provider.conditioners.description.output_proj.{suffix}')
        self.lm = LMModel()
        self.lm.load_state_dict(state, strict=True)
        self.lm.eval()

    @torch.no_grad()
    def generate(self,
                 prompt='dogs mewo',
                 duration=2.24,  # seconds of audio
                 ):
        """Generate audio for ``prompt`` and return it as a flat tensor.

        Args:
            prompt: Text description. It is sent to the LM ``N_REPEAT`` times
                together with ``N_REPEAT`` empty strings (presumably the
                unconditional branch for classifier-free guidance - TODO
                confirm against ``LMModel.generate``).
            duration: Requested length in seconds; used to size ``max_tokens``
                and ``self.lm.n_draw``.

        Returns:
            The decoded batch after seven rounds of ``_shift``, flattened
            with ``reshape(-1)`` into a 1-D tensor.

        Note:
            Re-seeds torch's global RNG with 42 on every call and runs the LM
            under CUDA float16 autocast.
        """
        torch.manual_seed(42)  # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
        # One extra draw for every full 12 seconds of requested audio. (An
        # earlier comment said "every 7 seconds", which did not match the
        # divisor.) What a "draw" means is defined by LMModel.
        self.lm.n_draw = int(duration / 12) + 1
        max_tokens = int(duration / (N_REPEAT * self.lm.n_draw)
                         * self.compression_model.frame_rate)

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gen_tokens = self.lm.generate(
                text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,  # e.g. ['dogs', 'dogs', '', '']
                max_tokens=max_tokens,
            )  # presumably [bs, 4, n_tokens * n_draw] - TODO confirm
        x = self.compression_model.decode(gen_tokens, None)  # presumably [bs, 1, n_samples] - TODO confirm

        # perhaps shift is too random as already lm.n_draw has randomness
        for _ in range(7):
            x = _shift(x)

        return x.reshape(-1)