determinis
- audiocraft/builders.py +30 -45
- audiocraft/lm.py +49 -59
- audiocraft/transformer.py +69 -77
audiocraft/builders.py CHANGED
@@ -1,26 +1,25 @@
-import omegaconf
-import torchaudio
 import torch
 from torch import nn
+from omegaconf import OmegaConf
 import numpy as np
 from huggingface_hub import hf_hub_download
 import os
-from
-from .
-from .
-from .
-
+from audiocraft.encodec import EncodecModel
+from audiocraft.lm import LMModel
+from audiocraft.seanet import SEANetDecoder
+from audiocraft.vq import ResidualVectorQuantizer
+

-# torch.backends.cudnn.deterministic = True
 N_REPEAT = 2  # num (virtual batch_size) clones of audio sounds

 def _shift(x):
-    n
-
-
+    # print(x.shape, 'BATCH Independent SHIFT\n AudioGen')
+    for i, _slice in enumerate(x):
+        n = x.shape[2]
+        offset = np.random.randint(.24 * n, max(1, .74 * n))  # high should be above >= 0 TBD
+        print(offset)
+        x[i, :, :] = torch.roll(_slice, offset, dims=1)  # _slice 2D
+    return x

 class AudioGen(torch.nn.Module):

@@ -29,18 +28,13 @@ class AudioGen(torch.nn.Module):
     def __init__(self):

         super().__init__()
-        # self.autocast = TorchAutocast(
-        #     enabled=True, device_type='cuda', dtype=torch.float16)
-        # Vocoder
         _file_1 = hf_hub_download(
             repo_id='facebook/audiogen-medium',
             filename="compression_state_dict.bin",
             cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
             library_name="audiocraft",
             library_version= '1.3.0a1')  # Found at __init__.py #audiocraft.__version__)
-        pkg = torch.load(_file_1, map_location='cpu')
-        # kwargs = OmegaConf.create(pkg['xp.cfg'])
-        # kwargs.device = 'cpu'
+        pkg = torch.load(_file_1, map_location='cpu')  # kwargs = OmegaConf.create(pkg['xp.cfg'])
         decoder = SEANetDecoder()
         quantizer = ResidualVectorQuantizer()
         self.compression_model = EncodecModel(decoder=decoder,
@@ -50,12 +44,10 @@ class AudioGen(torch.nn.Module):
                                               sample_rate=16000,
                                               channels=1,
                                               causal=False)  #.to(cfg.device)
-
-        self.compression_model.
-
-        # # T5 &
+        self.compression_model.load_state_dict(pkg['best_state'], strict=False)
+        self.compression_model.eval()  # ckpt has also unused encoder weights
+        # T5 &
         # LM
-
         _file_2 = hf_hub_download(
             repo_id='facebook/audiogen-medium',
             filename="state_dict.bin",
@@ -65,37 +57,30 @@ class AudioGen(torch.nn.Module):
         pkg = torch.load(_file_2, map_location='cpu')
         cfg = OmegaConf.create(pkg['xp.cfg'])  # CFG inside torch bin
         _best = pkg['best_state']
-        # _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
-        # _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
         _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')  #.to(torch.float)
         _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')  #.to(torch.float)
-        self.lm = LMModel()
-        self.lm.load_state_dict(pkg['best_state'],
-                                strict=True)
-        #
+        self.lm = LMModel()
+        self.lm.load_state_dict(pkg['best_state'], strict=True)
         self.lm.eval()
-
+

     @torch.no_grad()
     def generate(self,
                  prompt='dogs mewo',
-                 duration=2.24,
+                 duration=2.24,  # seconds of audio
                  ):
+        torch.manual_seed(42)  # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
+        self.lm.n_draw = int(duration / 12) + 1  # different beam every 7 seconds of audio
+
         with torch.autocast(device_type='cuda', dtype=torch.float16):
             gen_tokens = self.lm.generate(
-                text_condition=[prompt]
-                max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate)
-
+                text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,  # ['dogs', 'dogs...!', '', '']
+                max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate)
+                )  # [bs, 4, 74 * self.lm.n_draw]
         x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]

-        x = x[:, 0, :]  # last samples have splash sounds DISCARD 25000 last samples
-
-        # AudioGen 16KHZ / StyleTTS2 24 KHz / MMSTTS 24 KHz
-
-        x = self.resample_fn(x)  # [N_REPEAT, duration]

-        #
-        # x = _shift(x)
-        return x  # x / (x.abs().max() + 1e-7)
+        for _ in range(7):  # perhaps shift is too random as already lm.n_draw has randomness
+            x = _shift(x)

+        return x.reshape(-1)  # x / (x.abs().max() + 1e-7)
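Note on the new _shift helper above: it decorrelates the N_REPEAT clones by circularly rolling each batch item along the time axis by its own random offset, so the repeated generations do not all start on the same event. A minimal standalone sketch of the same idea (illustrative only; the shapes and names below are assumptions, not code from this repo):

import numpy as np
import torch

def roll_each_item(x):                            # x: [batch, channels, time]
    T = x.shape[2]
    for i, item in enumerate(x):                  # item: [channels, time]
        # pick an offset somewhere in the middle of the clip
        offset = int(np.random.randint(int(0.24 * T), int(0.74 * T)))
        x[i] = torch.roll(item, offset, dims=1)   # circular shift along time
    return x

waves = torch.randn(2, 1, 16000)                  # two one-second clips at 16 kHz
shifted = roll_each_item(waves.clone())           # each clip gets a different offset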
audiocraft/lm.py CHANGED
@@ -6,14 +6,14 @@ from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 class T5(nn.Module):

     def __init__(self):
-
+
         super().__init__()
         self.output_proj = nn.Linear(1024,  # t5-large
                                      1536)  # lm hidden
         self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
         t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)

-        # this makes sure that the t5
+        # this makes sure that the t5 is not part
         # of the saved checkpoint
         self.__dict__['t5'] = t5.to('cuda:0')

@@ -28,9 +28,8 @@ class T5(nn.Module):

         x = self.t5(input_ids=d['input_ids'],
                     attention_mask=d['attention_mask']).last_hidden_state  # no kv
-
-
-        print('BEF PROJ', x[0, :, :].sum(), x[1, :, :].sum(), self.output_proj.weight.sum(), self.output_proj.weight.dtype, self.output_proj.bias.sum(), 'GEN\n\n143')
+        # Float 16
+        # > self.output_proj() is outside of autocast of t5 - however inside the autocast of lm thus computed in torch.float16
         x = self.output_proj(x)  # nn.Linear() - produces different result if there is no duplicate txt condition here
         x[bs:, :, :] = 0  # venv/../site-packages/audiocraft/modules/conditioners.py -> tokenize()
         return x
@@ -41,71 +40,55 @@ class LMModel(nn.Module):
     def __init__(self,
                  n_q = 4,
                  card = 2048,
-                 dim = 1536
-                 num_heads = 24,
-                 hidden_scale = 4,  # FFN of Transformer
+                 dim = 1536
                  ):
         super().__init__()
         self.t5 = T5()
-        self.card = card
-        self.n_draw = 1  # draw
-
-        self.n_q = n_q
-        self.dim = dim
-        self.emb = nn.ModuleList([nn.Embedding(embed_dim, dim) for _ in range(n_q)])  # EMBEDDING HAS 2049
-        self.transformer = StreamingTransformer(
-            d_model=dim,
-            num_heads=num_heads,
-            dim_feedforward=int(hidden_scale * dim),
-            num_layers=48,
-            positional_embedding='sin',
-            )
+        self.card = card  # 2048
+        self.n_draw = 1  # draw > 1 tokens of different CFG scale
+        # batch size > 1 is slower from n_draw as calls transformer on larger batch
+        self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)])  # EMBEDDING HAS 2049
+        self.transformer = StreamingTransformer()
         self.out_norm = nn.LayerNorm(dim, eps=1e-5)
         self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])  # LINEAR DOESNT HAVE 2049

     def forward(self,
                 sequence,
                 condition_tensors=None,
-
+                cache_position=None):

         bs, n_q, time_frames = sequence.shape  # [bs, 4, time]

-        input_ = sum([self.emb[k](sequence[:, k]) for k in range(
+        input_ = sum([self.emb[k](sequence[:, k]) for k in range(n_q)])

         out = self.transformer(torch.cat([input_, input_], 0),  # duplicate null condition (bs x 2) for ClassifierFreeGuidance
                                cross_attention_src=condition_tensors,
-
+                               cache_position=cache_position
                                )

-        logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(
-
-        #
-
-        p
-
-        p = p.reshape(bs * self.n_q, 2048)
-        out = torch.multinomial(p,  # p=[bs,2048], out=[bs, num_samples]
-                                num_samples=self.n_draw,
-                                replacement=False)  # [bs*4, self.n_draw]
-        # print('DRAW','c', out)
-        return out.reshape(bs, self.n_q, self.n_draw).transpose(1,2)  # [bs=3not6, self.n_draw, 4]
+        logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(n_q)], dim=1)  # [2*bs, 4, 1, 2048]
+        logits = 3 * logits[:bs, :, :, :] - self._scale * logits[bs:, :, :, :]  # [bs, 4, n_draw, 2048]
+
+        k = 24
+        logits = torch.softmax(logits / 1.0, dim=3)  # [bs, 4, 1, 2048]
+        p, ix = torch.topk(logits, k, dim=3)  # p = [bs, 4, 1, 24], ix = [bs, 4, 1, 2048]
+        # Exponential Distribution
+        deflation = torch.empty_like(p).exponential_(lambd=1)
+        p = p / deflation
+        # divide large probs with exp(prob) If prob=.001 then 1/exp(1*.001) -> almost by 0 --> exp doesnt really produce (0, Inf)
+        p = p.argmax(dim=3, keepdim=True)  # [bs, 4, n_draw, 24]
+        tok = ix.gather(dim=3, index=p).to(torch.int64)  # [bs, 4, n_draw, 1]
+        return tok[:, :, :, 0].transpose(1, 2)  # [bs, n_draw, 4]

     @torch.no_grad()
     def generate(self,
                  max_tokens=None,
-                 text_condition=None
-                 ):
+                 text_condition=None):
         x = self.t5(text_condition)
         bs = x.shape[0] // 2  # has null conditions - bs*2*N_REPEAT applys in builders.py
+        self._scale = .3 * torch.rand(1, 1, self.n_draw, 1, device=x.device) + 1.94
+        cache_position = 0
+
         out_codes = torch.full((bs,
                                 self.n_draw,
                                 4,
@@ -113,14 +96,15 @@ class LMModel(nn.Module):
                                self.card,
                                dtype=torch.long,
                                device=x.device)  # [bs, n_draw, 4, dur]
-
+
+        # A/R
         for offset in range(0, max_tokens + 4 - 1):  # max_tokens + n_q - 1

            # extract diagonal via indexing out_codes[ [0, 1, 2, 3], [0, 1, 2, 3] ]
            next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],  # index diagonal & exapnd to [bs, n_q, dur=1]
                                      #gen_sequence[:, 0, :, offset-1:offset],  # DIAGINDEXING for setting prediction of lm into gen_sequence THE GENSEQUENCE has to be un-delayed in the end [Because it has to be de-delayed for the vocoder then is actually only the lm input that requires to see the delay thus we could just feed by diaggather] so it matches gen_codes -1 a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5] the gen_sequence is indexed by vertical column and fed to lm however the prediction of lm is place diagonally with delay to the gen_sequence
                                      condition_tensors=x,  # utilisation of the attention mask of txt condition ?
-
+                                     cache_position=cache_position)  # [bs, n_draw, 4]

            # Fill of next_token should be also placed on antidiagonal [not column]

@@ -133,11 +117,11 @@ class LMModel(nn.Module):
            # [2048, 2048, 2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6]]
            # NO OVerWriting
            if offset == 0:
-
+
                next_token[:, :, 1:4] = 2048  # self.card - bottom 3 entries of the antidiagonal should remain 2048

            elif offset == 1:
-
+
                next_token[:, :, 2:4] = 2048  # bottom 2 entries of the antidiagonal should remain 2048

            elif offset == 2:
@@ -157,16 +141,22 @@ class LMModel(nn.Module):
                next_token[:, :, 0:3] = 2048

            else:  # offset 3,4,5,6,7...... max_tokens-1  # FILL Complete n_q = 4 ANTIDIAGONAL ENTRIES
-
+
                pass  #print('No delete anti-diag')

            out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
-
-
+           # Sink Attn
+           if (offset > 0) and (offset % 71) == 0:
+               n_preserve = 4
+               self.transformer._flush(n_preserve=n_preserve)
+               cache_position = n_preserve
+           else:
+               cache_position += 1
+
+        # [bs, n_draw, 4, time+xtra] -> [bs, 4, n_draw, time] -> [bs, 4, time * n_draw]
+        out_codes = out_codes[:, :, :, 4:max_tokens+4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)

-        for
-
-            lay.self_attn.v_history = None
+        # flush for next API call
+        self.transformer._flush()

-        return out_codes  # SKIP THE 4 fill 2048
+        return out_codes  # SKIP THE 4 fill 2048
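Note on the sampling change in LMModel.forward above: torch.multinomial is replaced by a top-k draw using Exponential noise. Dividing the top-k probabilities by i.i.d. Exponential(1) samples and taking the argmax selects index i with probability proportional to p_i (the exponential-race / Gumbel-max argument), so the result is still a sample from the truncated softmax rather than a greedy pick; in the diff this is applied per codebook and per draw after the classifier-free-guidance mix of conditional and unconditional logits. A small self-contained sketch of the equivalence (illustrative only, not code from the diff):

import torch

torch.manual_seed(0)
logits = torch.randn(4, 2048)                      # [codebooks, vocab]
probs = torch.softmax(logits, dim=-1)
p, ix = torch.topk(probs, k=24, dim=-1)            # keep the 24 most likely tokens

# argmax(p_i / E_i) with E_i ~ Exponential(1) samples i with probability proportional to p_i
noise = torch.empty_like(p).exponential_(1.0)
choice = (p / noise).argmax(dim=-1, keepdim=True)  # [codebooks, 1]
tokens = ix.gather(-1, choice)                     # map back to vocabulary ids

# reference draw over the same truncated distribution
ref = ix.gather(-1, torch.multinomial(p, num_samples=1))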
audiocraft/transformer.py CHANGED
@@ -5,17 +5,23 @@ from einops import rearrange

 torch.backends.cuda.enable_mem_efficient_sdp(True)

-
+
+def create_sin_embedding(positions,
                          dim,
                          max_period=10000
                          ):
-    assert dim % 2 == 0
+    # assert dim % 2 == 0
     half_dim = dim // 2
     positions = positions.to(torch.float)
-    adim = torch.arange(half_dim, device=positions.device,
-
+    adim = torch.arange(half_dim, device=positions.device,
+                        dtype=torch.float).view(1, 1, -1)
+    max_period_tensor = torch.full([],
+                                   max_period,
+                                   device=positions.device,
+                                   dtype=torch.float)  # avoid sync point
     phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
-
+    # OFFICIAL is torch.float32 HOWEVER self_attn.in_prod_weight = torch.float16
+    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)


 class StreamingMultiheadAttention(nn.Module):
@@ -23,19 +29,20 @@ class StreamingMultiheadAttention(nn.Module):
     def __init__(self,
                  embed_dim,
                  num_heads,
-                 cross_attention
+                 cross_attention=False,
                  ):

         super().__init__()

         self.cross_attention = cross_attention
-        self.
-        self.k_history = None
-
+        # if not self.cross_attention then it has kvcachingn
+        self.k_history = None
+        # cleanup history through LM inside GENERATION - Each 0,..,47 mha has different kv history
+        self.v_history = None
         self.num_heads = num_heads
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
         self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
-
+                                                          dtype=torch.float))

     def forward(self,
                 query,
@@ -44,15 +51,16 @@ class StreamingMultiheadAttention(nn.Module):
         layout = "b h t d"
         if self.cross_attention:

-            # Different queries, keys, values
-
+            # Different queries, keys, values > split in_proj_weight
+
             dim = self.in_proj_weight.shape[0] // 3

             q = nn.functional.linear(query, self.in_proj_weight[:dim])
             k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
             v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])

-            q, k, v = [
+            q, k, v = [
+                rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]

         else:
             # 1st projected makes k,v (instantaneous)
@@ -60,59 +68,45 @@ class StreamingMultiheadAttention(nn.Module):

             # HISTORY - DIFFERENT FOR EACH TRANSF LAYER

-
+            # here we have different floating values from official
+            projected = nn.functional.linear(query, self.in_proj_weight, None)
             # print(query.sum(), projected.sum() , self.in_proj_weight.sum(), 'Lc')  # verified official AudioGen values
             bound_layout = "b h p t d"
-            packed = rearrange(
+            packed = rearrange(
+                projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
             q, k, v = packed.unbind(dim=2)
             if self.k_history is not None:
-                #
-
-                self.v_history = torch.cat([self.v_history[:, :, :4, :], self.v_history[:, :, -1:, :]], 2)
-                # fill new k/v
-                self.k_history = torch.cat([self.k_history, k], 2)  # IF ctrl^c here during live demo it is non-atomic k!=v
-                self.v_history = torch.cat([self.v_history, v], 2)  # thus it will try to continue with incompatible k/v dims!
-
+                # IF ctrl^c during live_demo the assigning of each of kv is non-atomic k!=v
+                # thus it will try to continue with incompatible k/v dims!
+                self.k_history = torch.cat([self.k_history, k], 2)
+                self.v_history = torch.cat([self.v_history, v], 2)
             else:
-                # init
                 self.k_history = k
                 self.v_history = v
-
+
+            # Assign Completed k / v to k / v
+
             k = self.k_history
             v = self.v_history

+            # -> kv CACHE ONLY APPLIES if not self.cross_attention

-
-        # KV COMPLETION ONLY ON SELF ATTENTION
-
         x = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, is_causal=False, dropout_p=0)
+            q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)

         x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
         x = self.out_proj(x)
         return x


-class StreamingTransformerLayer(nn.TransformerEncoderLayer):
-
+class StreamingTransformerLayer(nn.Module):

     def __init__(self,
                  d_model,
                  num_heads,
                  dim_feedforward):
-
-        super().__init__(
-            num_heads,
-            dim_feedforward=dim_feedforward,
-            dropout=0.0,
-            device='cuda',
-            dtype=torch.float32,
-            batch_first=True,
-            norm_first=True,
-            activation='gelu')
-        # super().__init__()
+
+        super().__init__()

         self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                      num_heads=num_heads)
@@ -125,15 +119,14 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
         self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
         self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

-
     def forward(self,
                 x,
-                cross_attention_src=None):
+                cross_attention_src=None):
         x = x + self.self_attn(self.norm1(x))
-        x = x + self.cross_attention(query
-                                     key
-                                     value
-        x = x + self.linear2(F.gelu(self.linear1(
+        x = x + self.cross_attention(query=self.norm_cross(x),
+                                     key=cross_attention_src,
+                                     value=cross_attention_src)  # txtcondition
+        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
         return x


@@ -143,39 +136,38 @@ class StreamingTransformer(nn.Module):
                  d_model=1536,
                  num_heads=24,
                  num_layers=48,
-                 dim_feedforward=6144
-                 cross_attention = True,
-                 positional_embedding: str = 'sin',
-                 max_period: float = 10_000
-                 ):
+                 dim_feedforward=6144):
         super().__init__()
-
-
-            d_model=d_model,
-            num_heads=num_heads,
-            dim_feedforward=dim_feedforward
-            )
-        )
+
+        self.layers = nn.ModuleList(
+            [
+                StreamingTransformerLayer(d_model=d_model,
+                                          num_heads=num_heads,
+                                          dim_feedforward=dim_feedforward) for _ in range(num_layers)
+            ]
+        )

     def forward(self,
                 x,
-
+                cache_position=None,
                 cross_attention_src=None):

-
-            1536,
-            max_period=self.max_period)
+        x = x + create_sin_embedding(
+            torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position, 1536)

-
-        # self attn = audio x audio
-        # Every layer (mha) keeps itsw own kv cachE
+        for lay in self.layers:
+            x = lay(x,
+                    cross_attention_src=cross_attention_src)
         return x
+
+    def _flush(self,
+               n_preserve=None):
+
+        for lay in self.layers:
+            if n_preserve is not None:
+                # cache position is difficult to choose to also preserve kv from end
+                lay.self_attn.k_history = lay.self_attn.k_history[:, :, :n_preserve, :]
+                lay.self_attn.v_history = lay.self_attn.v_history[:, :, :n_preserve, :]
+            else:
+                lay.self_attn.k_history = None
+                lay.self_attn.v_history = None