Dionyssos committed
Commit bc7f42e · 1 Parent(s): 53c0776

determinis

Files changed (3)
  1. audiocraft/builders.py +30 -45
  2. audiocraft/lm.py +49 -59
  3. audiocraft/transformer.py +69 -77
audiocraft/builders.py CHANGED
@@ -1,26 +1,25 @@
-import omegaconf
-import torchaudio
 import torch
 from torch import nn
+from omegaconf import OmegaConf
 import numpy as np
 from huggingface_hub import hf_hub_download
 import os
-from omegaconf import OmegaConf
-from .encodec import EncodecModel
-from .lm import LMModel
-from .seanet import SEANetDecoder
-from .vq import ResidualVectorQuantizer
+from audiocraft.encodec import EncodecModel
+from audiocraft.lm import LMModel
+from audiocraft.seanet import SEANetDecoder
+from audiocraft.vq import ResidualVectorQuantizer
+
 
-# torch.backends.cudnn.deterministic = True
 N_REPEAT = 2  # num (virtual batch_size) clones of audio sounds
 
 def _shift(x):
-    n = len(x)
-    offset = np.random.randint(.24 * n, max(1, .74 * n))  # high has to stay >= 1 (TBD)
-    if isinstance(x, torch.Tensor):
-        return torch.roll(x, offset, dims=0)
-    elif isinstance(x, str):
-        return x[offset:] + x[:offset]  # np.roll(x, offset)
+    # print(x.shape, 'batch-independent shift / AudioGen')
+    for i, _slice in enumerate(x):
+        n = x.shape[2]
+        offset = np.random.randint(.24 * n, max(1, .74 * n))  # high has to stay >= 1 (TBD)
+        print(offset)
+        x[i, :, :] = torch.roll(_slice, offset, dims=1)  # _slice is 2D
+    return x
 
 class AudioGen(torch.nn.Module):
 
@@ -29,18 +28,13 @@ class AudioGen(torch.nn.Module):
     def __init__(self):
 
         super().__init__()
-        # self.autocast = TorchAutocast(
-        #     enabled=True, device_type='cuda', dtype=torch.float16)
-        # Vocoder
         _file_1 = hf_hub_download(
             repo_id='facebook/audiogen-medium',
             filename="compression_state_dict.bin",
             cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
             library_name="audiocraft",
             library_version='1.3.0a1')  # found in __init__.py  # audiocraft.__version__
-        pkg = torch.load(_file_1, map_location='cpu')
-        # kwargs = OmegaConf.create(pkg['xp.cfg'])
-        # kwargs.device = 'cpu'
+        pkg = torch.load(_file_1, map_location='cpu')  # kwargs = OmegaConf.create(pkg['xp.cfg'])
         decoder = SEANetDecoder()
         quantizer = ResidualVectorQuantizer()
         self.compression_model = EncodecModel(decoder=decoder,
@@ -50,12 +44,10 @@ class AudioGen(torch.nn.Module):
                                               sample_rate=16000,
                                               channels=1,
                                               causal=False)  # .to(cfg.device)
-        # self.compression_model = self.get_compression_model(cfg)
-        self.compression_model.load_state_dict(pkg['best_state'], strict=False)  # ckpt also has unused encoder weights
-        self.resample_fn = torchaudio.transforms.Resample(16000, 24000)  # AudioGen = 16 kHz, StyleTTS2 = 24 kHz, MMS-TTS = 24 kHz
-        # # T5 &
+        self.compression_model.load_state_dict(pkg['best_state'], strict=False)
+        self.compression_model.eval()  # ckpt also has unused encoder weights
+        # T5 &
         # LM
-
         _file_2 = hf_hub_download(
             repo_id='facebook/audiogen-medium',
             filename="state_dict.bin",
@@ -65,37 +57,30 @@ class AudioGen(torch.nn.Module):
         pkg = torch.load(_file_2, map_location='cpu')
         cfg = OmegaConf.create(pkg['xp.cfg'])  # cfg lives inside the torch bin
         _best = pkg['best_state']
-        # _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
-        # _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
         _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')  # .to(torch.float)
         _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')  # .to(torch.float)
-        self.lm = LMModel()  # to(torch.float16)
-        self.lm.load_state_dict(pkg['best_state'],
-                                strict=True)
-        #
+        self.lm = LMModel()
+        self.lm.load_state_dict(pkg['best_state'], strict=True)
         self.lm.eval()
-        self.compression_model.eval()
+
 
     @torch.no_grad()
     def generate(self,
                  prompt='dogs mewo',
-                 duration=2.24,  ## seconds of audio
+                 duration=2.24,  # seconds of audio
                  ):
+        torch.manual_seed(42)  # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
+        self.lm.n_draw = int(duration / 12) + 1  # one extra draw per 12 seconds of requested audio
+
         with torch.autocast(device_type='cuda', dtype=torch.float16):
             gen_tokens = self.lm.generate(
-                text_condition=[prompt] + [prompt[:10] + _shift(prompt) for _ in range(N_REPEAT-1)] + [''] * N_REPEAT,  # '' is the null condition, e.g. ['trance', 'dogs meow', '', '']
-                max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
-
+                text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,  # ['dogs', 'dogs...!', '', '']
+                max_tokens=int(duration / (N_REPEAT * self.lm.n_draw) * self.compression_model.frame_rate)
+                )  # [bs, 4, 74 * self.lm.n_draw]
         x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
 
-        x = x[:, 0, :]  # the last samples have splash sounds; discard the last 25000 samples
-
-        # AudioGen 16 kHz / StyleTTS2 24 kHz / MMS-TTS 24 kHz
-
-        x = self.resample_fn(x)  # [N_REPEAT, duration]
 
-        x = x.reshape(-1)
+        for _ in range(7):  # perhaps the shift is too random, as lm.n_draw already adds randomness
+            x = _shift(x)
 
-        # for _ in range(7):
-        #     x = _shift(x)
-        return x  # x / (x.abs().max() + 1e-7)
+        return x.reshape(-1)  # x / (x.abs().max() + 1e-7)
 
 
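The new _shift rolls each batch item along the time axis by its own random offset, so the N_REPEAT clones (and the 7 repeated shifts at the end of generate()) decorrelate without altering the audio content. A minimal, self-contained illustration with hypothetical tensor sizes:

import numpy as np
import torch

x = torch.arange(12.).reshape(2, 1, 6)                # [batch, channels, time], toy values
for i, _slice in enumerate(x):
    n = x.shape[2]
    offset = np.random.randint(int(.24 * n), max(1, int(.74 * n)))
    x[i, :, :] = torch.roll(_slice, offset, dims=1)   # wrap the tail of row i to its front
print(x)                                              # each row is rotated by a different offset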
 
 
audiocraft/lm.py CHANGED
@@ -6,14 +6,14 @@ from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 class T5(nn.Module):
 
     def __init__(self):
-        # run this from within lm so it autocasts, thus matching the exact values of t5 in official audiogen
+
         super().__init__()
         self.output_proj = nn.Linear(1024,  # t5-large
                                      1536)  # lm hidden
         self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
         t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
 
-        # this makes sure that the t5 models is not part
+        # this makes sure that the t5 is not part
         # of the saved checkpoint
         self.__dict__['t5'] = t5.to('cuda:0')
 
@@ -28,9 +28,8 @@ class T5(nn.Module):
 
         x = self.t5(input_ids=d['input_ids'],
                     attention_mask=d['attention_mask']).last_hidden_state  # no kv
-
-        # output_proj as float32
-        print('BEF PROJ', x[0, :, :].sum(), x[1, :, :].sum(), self.output_proj.weight.sum(), self.output_proj.weight.dtype, self.output_proj.bias.sum(), 'GEN\n\n143')
+        # Float 16
+        # > self.output_proj() is outside the autocast of t5, but inside the autocast of lm, thus computed in torch.float16
         x = self.output_proj(x)  # nn.Linear() - produces a different result if there is no duplicate txt condition here
         x[bs:, :, :] = 0  # venv/../site-packages/audiocraft/modules/conditioners.py -> tokenize()
         return x
@@ -41,71 +40,55 @@ class LMModel(nn.Module):
     def __init__(self,
                  n_q = 4,
                  card = 2048,
-                 dim = 1536,
-                 num_heads = 24,
-                 hidden_scale = 4,  # FFN of Transformer
+                 dim = 1536
                  ):
         super().__init__()
         self.t5 = T5()
-        self.card = card  # 2048 ?
-        self.n_draw = 1  # draw additional tokens at each call:
-        # Batch size is slower than n_draw as it calls the transformer on a larger batch
-        # n_draw instead draws more tokens/phonemes from torch.multinomial - after execution of lm
-        embed_dim = self.card + 1
-        self.n_q = n_q
-        self.dim = dim
-        self.emb = nn.ModuleList([nn.Embedding(embed_dim, dim) for _ in range(n_q)])  # EMBEDDING HAS 2049
-        self.transformer = StreamingTransformer(
-            d_model=dim,
-            num_heads=num_heads,
-            dim_feedforward=int(hidden_scale * dim),
-            num_layers=48,
-            positional_embedding='sin',
-        )
+        self.card = card  # 2048
+        self.n_draw = 1  # draw > 1 tokens, each with a different CFG scale
+        # batch size > 1 is slower than n_draw, as it calls the transformer on a larger batch
+        self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)])  # EMBEDDING HAS 2049
+        self.transformer = StreamingTransformer()
         self.out_norm = nn.LayerNorm(dim, eps=1e-5)
         self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])  # LINEAR DOESNT HAVE 2049
 
     def forward(self,
                 sequence,
                 condition_tensors=None,
-                token_count=None):
+                cache_position=None):
 
         bs, n_q, time_frames = sequence.shape  # [bs, 4, time]
 
-        input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
+        input_ = sum([self.emb[k](sequence[:, k]) for k in range(n_q)])
 
         out = self.transformer(torch.cat([input_, input_], 0),  # duplicate null condition (bs x 2) for ClassifierFreeGuidance
                                cross_attention_src=condition_tensors,
-                               token_count=token_count
+                               cache_position=cache_position
                                )
 
-        logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(self.n_q)], dim=1)  # [2*bs, 4, 1, 2048]
-
-        logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :]  # [3, 4, 1, 2048]
-
-        # SAMPLE TOP K
-        k = 400  # 450 is a nice sound, still the train honk is clear!
-        p = torch.softmax(logits, dim=3)
-        top_k_value, _ = torch.topk(p, k, dim=3)  # [3, 4, 1, k]
-        min_value_top_k = top_k_value[:, :, :, -1:]
-        p *= (p >= min_value_top_k).float()  # zero low probs
-        p.div_(p.sum(dim=-1, keepdim=True))  # renormalise on non-zero probs
-
-        # BRING THE n_q = 4 INTO THE BATCH
-        p = p.reshape(bs * self.n_q, 2048)
-        out = torch.multinomial(p,  # p=[bs,2048], out=[bs, num_samples]
-                                num_samples=self.n_draw,
-                                replacement=False)  # [bs*4, self.n_draw]
-        # print('DRAW', 'c', out)
-        return out.reshape(bs, self.n_q, self.n_draw).transpose(1, 2)  # [bs=3 not 6, self.n_draw, 4]
+        logits = torch.stack([self.linears[k](self.out_norm(out)) for k in range(n_q)], dim=1)  # [2*bs, 4, 1, 2048]
+        logits = 3 * logits[:bs, :, :, :] - self._scale * logits[bs:, :, :, :]  # broadcasts over _scale -> [bs, 4, n_draw, 2048]
+
+        k = 24
+        logits = torch.softmax(logits / 1.0, dim=3)  # [bs, 4, n_draw, 2048]
+        p, ix = torch.topk(logits, k, dim=3)  # p = [bs, 4, n_draw, 24], ix = [bs, 4, n_draw, 24]
+        # Exponential Distribution
+        deflation = torch.empty_like(p).exponential_(lambd=1)
+        p = p / deflation
+        # dividing by Exponential(1) noise and taking the argmax samples from the top-k probs (Gumbel-max style)
+        p = p.argmax(dim=3, keepdim=True)  # [bs, 4, n_draw, 1]
+        tok = ix.gather(dim=3, index=p).to(torch.int64)  # [bs, 4, n_draw, 1]
+        return tok[:, :, :, 0].transpose(1, 2)  # [bs, n_draw, 4]
 
     @torch.no_grad()
     def generate(self,
                  max_tokens=None,
-                 text_condition=None
-                 ):
+                 text_condition=None):
         x = self.t5(text_condition)
         bs = x.shape[0] // 2  # has null conditions - bs*2*N_REPEAT applies in builders.py
+        self._scale = .3 * torch.rand(1, 1, self.n_draw, 1, device=x.device) + 1.94
+        cache_position = 0
+
         out_codes = torch.full((bs,
                                 self.n_draw,
                                 4,
@@ -113,14 +96,15 @@ class LMModel(nn.Module):
                                self.card,
                                dtype=torch.long,
                                device=x.device)  # [bs, n_draw, 4, dur]
-        # =========================================
+
+        # A/R
         for offset in range(0, max_tokens + 4 - 1):  # max_tokens + n_q - 1
 
             # extract the antidiagonal via indexing out_codes[[0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset]
             next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],  # index the antidiagonal & expand to [bs, n_q, dur=1]
                                       # DIAG-INDEXING: the lm prediction is placed diagonally (with delay) into gen_sequence; gen_sequence is un-delayed at the end for the vocoder, so only the lm input needs to see the delay, hence feeding by diag-gather, e.g. a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5]
                                       condition_tensors=x,  # utilisation of the attention mask of the txt condition ?
-                                      token_count=offset)  # [bs, n_draw, 4]
+                                      cache_position=cache_position)  # [bs, n_draw, 4]
 
             # The fill of next_token should also be placed on the antidiagonal [not the column]
 
@@ -133,11 +117,11 @@ class LMModel(nn.Module):
             #  [2048, 2048, 2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6]]
             # NO overwriting
             if offset == 0:
-
+
                 next_token[:, :, 1:4] = 2048  # self.card - the bottom 3 entries of the antidiagonal should remain 2048
 
             elif offset == 1:
-
+
                 next_token[:, :, 2:4] = 2048  # the bottom 2 entries of the antidiagonal should remain 2048
 
             elif offset == 2:
@@ -157,16 +141,22 @@ class LMModel(nn.Module):
                 next_token[:, :, 0:3] = 2048
 
             else:  # offset 3, 4, 5, 6, 7, ..., max_tokens-1  # FILL the complete n_q = 4 ANTIDIAGONAL ENTRIES
-
+
                 pass  # print('No delete anti-diag')
 
             out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
-        print('\nFULL FINAL TOKENS UNFILT\n', out_codes[:, 0, :, 4:max_tokens+4], out_codes[0, 0, :, 4:max_tokens+4].shape)
-        # EXTRACT COLUMNS, AS THE ALIGNMENT IS ALREADY DONE BY FILLING DIAGONALLY
-        out_codes = out_codes[:, :, :, 4:max_tokens+4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)  # [bs, 4, duration*n_draw] DISCARD FILL 2048
+            # Sink Attn
+            if (offset > 0) and (offset % 71) == 0:
+                n_preserve = 4
+                self.transformer._flush(n_preserve=n_preserve)
+                cache_position = n_preserve
+            else:
+                cache_position += 1
+
+        # [bs, n_draw, 4, time+xtra] -> [bs, 4, n_draw, time] -> [bs, 4, time * n_draw]
+        out_codes = out_codes[:, :, :, 4:max_tokens+4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)
 
-        for lay in self.transformer.layers:
-            lay.self_attn.k_history = None
-            lay.self_attn.v_history = None
+        # flush for the next API call
+        self.transformer._flush()
 
-        return out_codes  # SKIP THE 4 fill 2048 bs*n_draw, duration -> repeat/shift in api.py
+        return out_codes  # SKIP THE 4 fill-2048 frames
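A note on the new sampler in LMModel.forward(): dividing the top-k probabilities by i.i.d. Exponential(1) noise and taking the argmax is the Gumbel-max trick in disguise, so it draws an index with probability proportional to its top-k probability — the same distribution torch.multinomial sampled before, without the renormalisation and reshape. The per-draw guidance scale self._scale (uniform in roughly [1.94, 2.24]) is what makes the n_draw variants differ. A small sanity-check sketch, not part of the commit:

import torch

torch.manual_seed(0)
p = torch.tensor([0.7, 0.2, 0.1])                 # stand-in for the k=24 top-k probabilities
counts = torch.zeros(3)
for _ in range(20_000):
    noise = torch.empty_like(p).exponential_(lambd=1)
    counts[(p / noise).argmax()] += 1
print(counts / counts.sum())                      # ~= tensor([0.70, 0.20, 0.10])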
 
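For readers new to MusicGen/AudioGen-style interleaving: the loop above maintains a "delay pattern" — at step offset the model emits one token per codebook, and codebook k's token belongs to time frame offset - k, which is why writes land on an antidiagonal and why the first three steps force the lower rows to 2048 (the "empty" id). The toy re-enactment below (hypothetical max_tokens, with the sampled token replaced by the step index, batch and n_draw dims dropped) shows that slicing columns 4:max_tokens+4 afterwards yields time-aligned codebooks; the masking branches elided from the hunk are assumed to mirror the first three symmetrically at the end.

import torch

card, n_q, max_tokens = 2048, 4, 5
buf = torch.full((n_q, max_tokens + 8), card, dtype=torch.long)   # extra room, as in the commit
for offset in range(max_tokens + n_q - 1):
    nxt = torch.full((n_q,), offset)              # stand-in for the 4 sampled codes
    if offset == 0:
        nxt[1:4] = card                           # lower codebooks not valid yet
    elif offset == 1:
        nxt[2:4] = card
    elif offset == 2:
        nxt[3:4] = card
    elif offset >= max_tokens:
        nxt[0:offset - max_tokens + 1] = card     # upper codebooks already finished
    buf[[0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = nxt
print(buf[:, 4:max_tokens + 4])
# tensor([[0, 1, 2, 3, 4],
#         [1, 2, 3, 4, 5],
#         [2, 3, 4, 5, 6],
#         [3, 4, 5, 6, 7]])  -> frame t of codebook k was produced at step t + k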
audiocraft/transformer.py CHANGED
@@ -5,17 +5,23 @@ from einops import rearrange
 
 torch.backends.cuda.enable_mem_efficient_sdp(True)
 
-def create_sin_embedding(positions,
+
+def create_sin_embedding(positions,
                          dim,
                          max_period=10000
                          ):
-    assert dim % 2 == 0
+    # assert dim % 2 == 0
     half_dim = dim // 2
     positions = positions.to(torch.float)
-    adim = torch.arange(half_dim, device=positions.device, dtype=torch.float).view(1, 1, -1)
-    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=torch.float)  # avoid sync point
+    adim = torch.arange(half_dim, device=positions.device,
+                        dtype=torch.float).view(1, 1, -1)
+    max_period_tensor = torch.full([],
+                                   max_period,
+                                   device=positions.device,
+                                   dtype=torch.float)  # avoid sync point
     phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
-    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)  # OFFICIAL is torch.float32 HOWEVER self_attn.in_prod_weight = torch.float16
+    # OFFICIAL is torch.float32 HOWEVER self_attn.in_proj_weight = torch.float16
+    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
 
 
 class StreamingMultiheadAttention(nn.Module):
@@ -23,19 +29,20 @@ class StreamingMultiheadAttention(nn.Module):
     def __init__(self,
                  embed_dim,
                  num_heads,
-                 cross_attention = False,
+                 cross_attention=False,
                  ):
 
         super().__init__()
 
         self.cross_attention = cross_attention
-        self.embed_dim = embed_dim
-        self.k_history = None  # previous k from the tokens already seen in the current generation - only for self.attn
-        self.v_history = None  # cleaned up IN LM after finishing GENERATION - each of the 1..47 mha has a different kv history
+        # if not self.cross_attention then it has kv caching
+        self.k_history = None
+        # cleanup of the history happens in the LM during GENERATION - each of the 0..47 mha has a different kv history
+        self.v_history = None
         self.num_heads = num_heads
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
         self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
-                                                          dtype=torch.float))
+                                                          dtype=torch.float))
 
     def forward(self,
                 query,
@@ -44,15 +51,16 @@ class StreamingMultiheadAttention(nn.Module):
         layout = "b h t d"
         if self.cross_attention:
 
-            # Different queries, keys, values; we have to split the in_proj_weight manually
-
+            # Different queries, keys, values > split in_proj_weight
+
             dim = self.in_proj_weight.shape[0] // 3
 
             q = nn.functional.linear(query, self.in_proj_weight[:dim])
             k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
             v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
 
-            q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
+            q, k, v = [
+                rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
 
         else:
             # 1st projection makes k, v (instantaneous)
@@ -60,59 +68,45 @@ class StreamingMultiheadAttention(nn.Module):
 
             # HISTORY - DIFFERENT FOR EACH TRANSFORMER LAYER
 
-            projected = nn.functional.linear(query, self.in_proj_weight, None)  # here we have different floating values from official
+            # here we have different floating values from official
+            projected = nn.functional.linear(query, self.in_proj_weight, None)
             # print(query.sum(), projected.sum(), self.in_proj_weight.sum(), 'Lc')  # verified official AudioGen values
             bound_layout = "b h p t d"
-            packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
+            packed = rearrange(
+                projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
             q, k, v = packed.unbind(dim=2)
             if self.k_history is not None:
-                # flush
-                if self.k_history.shape[2] > 71:
-
-                    self.k_history = torch.cat([self.k_history[:, :, :4, :], self.k_history[:, :, -1:, :]], 2)
-                    self.v_history = torch.cat([self.v_history[:, :, :4, :], self.v_history[:, :, -1:, :]], 2)
-                # fill new k/v
-                self.k_history = torch.cat([self.k_history, k], 2)  # IF ctrl^c here during the live demo it is non-atomic: k != v
-                self.v_history = torch.cat([self.v_history, v], 2)  # thus it will try to continue with incompatible k/v dims!
-
+                # IF ctrl^c during the live demo, the assignment of k/v is non-atomic (k != v),
+                # thus it would try to continue with incompatible k/v dims!
+                self.k_history = torch.cat([self.k_history, k], 2)
+                self.v_history = torch.cat([self.v_history, v], 2)
             else:
-                # init
                 self.k_history = k
                 self.v_history = v
-            # For self attn prepare
+
+            # Assign the completed k / v
+
             k = self.k_history
             v = self.v_history
 
+            # -> the kv cache only applies if not self.cross_attention
 
-
-        # KV COMPLETION ONLY ON SELF ATTENTION
-
         x = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, is_causal=False, dropout_p=0)
+            q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)
 
         x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
         x = self.out_proj(x)
         return x
 
 
-class StreamingTransformerLayer(nn.TransformerEncoderLayer):
-
+class StreamingTransformerLayer(nn.Module):
 
     def __init__(self,
                  d_model,
                  num_heads,
                  dim_feedforward):
-
-        super().__init__(d_model,
-                         num_heads,
-                         dim_feedforward=dim_feedforward,
-                         dropout=0.0,
-                         device='cuda',
-                         dtype=torch.float32,
-                         batch_first=True,
-                         norm_first=True,
-                         activation='gelu')
-        # super().__init__()
+
+        super().__init__()
 
         self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                      num_heads=num_heads)
@@ -125,15 +119,14 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
         self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
         self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
 
-
     def forward(self,
                 x,
-                cross_attention_src=None):  # txt cond
+                cross_attention_src=None):
         x = x + self.self_attn(self.norm1(x))
-        x = x + self.cross_attention(query = self.norm_cross(x),
-                                     key = cross_attention_src,
-                                     value = cross_attention_src)  # txt condition
-        x = x + self.linear2(F.gelu(self.linear1( self.norm2(x) )))
+        x = x + self.cross_attention(query=self.norm_cross(x),
+                                     key=cross_attention_src,
+                                     value=cross_attention_src)  # txt condition
+        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
         return x
 
 
@@ -143,39 +136,38 @@ class StreamingTransformer(nn.Module):
                  d_model=1536,
                  num_heads=24,
                  num_layers=48,
-                 dim_feedforward=6144,
-                 cross_attention = True,
-                 positional_embedding: str = 'sin',
-                 max_period: float = 10_000
-                 ):
+                 dim_feedforward=6144):
         super().__init__()
-        assert d_model % num_heads == 0
-
-        self.positional_embedding = positional_embedding
-        self.max_period = max_period
-        self.layers = nn.ModuleList()
-        for idx in range(num_layers):
-            self.layers.append(
-                StreamingTransformerLayer(
-                    d_model=d_model,
-                    num_heads=num_heads,
-                    dim_feedforward=dim_feedforward
-                )
-            )
+
+        self.layers = nn.ModuleList(
+            [
+                StreamingTransformerLayer(d_model=d_model,
+                                          num_heads=num_heads,
+                                          dim_feedforward=dim_feedforward) for _ in range(num_layers)
+            ]
+        )
 
     def forward(self,
                 x,
-                token_count=None,
+                cache_position=None,
                 cross_attention_src=None):
 
-        if self.positional_embedding in ['sin', 'sin_rope']:
-            pos_emb = create_sin_embedding(torch.zeros(x.shape[0], 1, 1, device=x.device) + token_count,
-                                           1536,
-                                           max_period=self.max_period)
+        x = x + create_sin_embedding(
+            torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position, 1536)
 
-        x = x + pos_emb
-        for j, lay in enumerate(self.layers):
-            x = lay(x, cross_attention_src=cross_attention_src)  # cross_attention_src = txt-cond x audio
-            # self attn = audio x audio
-            # every layer (mha) keeps its own kv cache
+        for lay in self.layers:
+            x = lay(x,
+                    cross_attention_src=cross_attention_src)
         return x
+
+    def _flush(self,
+               n_preserve=None):
+
+        for lay in self.layers:
+            if n_preserve is not None:
+                # keep only the first n_preserve positions; a cache position that also preserves kv from the end is difficult to choose
+                lay.self_attn.k_history = lay.self_attn.k_history[:, :, :n_preserve, :]
+                lay.self_attn.v_history = lay.self_attn.v_history[:, :, :n_preserve, :]
+            else:
+                lay.self_attn.k_history = None
+                lay.self_attn.v_history = None
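The kv-cache trimming that used to live inside the attention forward (if self.k_history.shape[2] > 71) now sits in StreamingTransformer._flush(), driven from lm.py: every 71 offsets the LM keeps only the first 4 cached positions (the "Sink Attn" reset) and restarts cache_position at 4, and a final _flush() clears the cache between API calls. A minimal sketch of the pattern for a single layer, with simplified shapes ([batch, heads, time, head_dim]):

import torch

k_history = None

def step(k_new):
    global k_history
    k_history = k_new if k_history is None else torch.cat([k_history, k_new], dim=2)
    return k_history                              # attention then runs against the full history

def flush(n_preserve=None):
    global k_history
    # keep only the first n_preserve positions, or drop the cache entirely
    k_history = None if n_preserve is None else k_history[:, :, :n_preserve, :]

for _ in range(72):
    step(torch.randn(1, 24, 1, 64))
print(k_history.shape[2])                         # 72
flush(n_preserve=4)
print(k_history.shape[2])                         # 4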
 
 
 
 
 
 
 
 
 
 
 
 
 
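Since the transformer is fed one frame per step, create_sin_embedding is evaluated for a single position (cache_position) and broadcast over the batch; after a sink flush, lm.py restarts the position at 4 rather than the true offset. A quick shape check using the function exactly as added in this commit (batch size and position below are illustrative):

import torch

def create_sin_embedding(positions, dim, max_period=10000):
    half_dim = dim // 2
    positions = positions.to(torch.float)
    adim = torch.arange(half_dim, device=positions.device, dtype=torch.float).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=torch.float)
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

x = torch.zeros(4, 1, 1536)                       # one new frame for a batch of 4
cache_position = 7                                # frames already held in the kv cache
pos_emb = create_sin_embedding(torch.zeros(x.shape[0], 1, 1) + cache_position, 1536)
print(pos_emb.shape)                              # torch.Size([4, 1, 1536])
x = x + pos_emb                                   # what StreamingTransformer.forward() does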