Dionyssos committed · Commit 17a68db · 1 Parent(s): 5f5d0ea

preserve only few last kv

Files changed (4):
  1. README.md +3 -1
  2. audiocraft/builders.py +8 -14
  3. audiocraft/transformer.py +13 -21
  4. msinference.py +1 -1
README.md CHANGED
@@ -67,7 +67,9 @@ CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=/data/dkounadis/.hf7/ CUDA_VISIBLE_DEVICES=
 
 Following examples need `api.py` to be running. [Set this IP](https://huggingface.co/dkounadis/artificial-styletts2/blob/main/tts.py#L85) to the IP shown when starting `api.py`.
 
-<div><iframe width="560" height="315" src="https://www.youtube.com/embed/2YjxAPkdXIc?si=eVpClu_7whMAdWi0" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" allowfullscreen></iframe></div>
+```
+python tts.py --text assets/ocr.txt --image assets/ocr.jpg --soundscape "battle hero" --voice romanian
+```
 
 </details>
 
audiocraft/builders.py CHANGED
@@ -11,15 +11,13 @@ from .lm import LMModel
 from .seanet import SEANetDecoder
 from .vq import ResidualVectorQuantizer
 
-N_REPEAT = 7  # num (virtual batch_size) clones of audio sounds
+N_REPEAT = 3  # num (virtual batch_size) clones of audio sounds
 
 def _shift(x):
-    # [bs, samples] circular-shift each batch elem of sound
-    n = x.shape[1]
-    for i, batch_elem in enumerate(x):
-        offset = np.random.randint(.24 * n, max(1, .74 * n))  # high must stay >= low
-        x[i, :] = torch.roll(batch_elem, offset, dims=0)  # batch_elem = [400000, ]
-    return x
+    # circular-shift the flattened 1D sound
+    n = x.shape[0]
+    offset = np.random.randint(.24 * n, max(1, .74 * n))  # high must stay >= low
+    return torch.roll(x, offset, dims=0)
 
 def _delete_param(cfg, full_name):
     parts = full_name.split('.')
@@ -70,18 +68,14 @@ class AudioGen(nn.Module):
 
         # AudioGen 16 kHz / StyleTTS2 24 kHz / MMSTTS 24 kHz
 
-        x = self.resample_fn(x)
-
-        # batch size = different sounds for same txt
-
-        x = x.repeat(1, N_REPEAT)
-
-        # less periodic - shift every batch elem
+        x = self.resample_fn(x)  # [N_REPEAT, duration]
+
+        x = x.repeat(1, N_REPEAT).reshape(-1)
 
         for _ in range(7):
             x = _shift(x)
 
-        x = x.reshape(-1)
         print(x.abs().max(), 'MAX')
         return x / (x.abs().max() + 1e-7)

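The repeat-and-roll trick above can be exercised on its own: the sound is now tiled and flattened to a single 1D waveform, then circularly shifted several times so the tiling period is no longer audible. A minimal standalone sketch, assuming a mono tensor as input (the 16 kHz one-second length here is illustrative, not from the repo):

```python
import numpy as np
import torch

N_REPEAT = 3  # clones of the generated sound, as in builders.py

def _shift(x):
    # circular-shift a flattened 1D waveform by a random offset in [0.24*n, 0.74*n)
    n = x.shape[0]
    offset = np.random.randint(int(.24 * n), max(1, int(.74 * n)))
    return torch.roll(x, offset, dims=0)

x = torch.randn(1, 16000)              # stand-in for 1 s of 16 kHz AudioGen output
x = x.repeat(1, N_REPEAT).reshape(-1)  # tile, then flatten to [N_REPEAT * 16000]
for _ in range(7):
    x = _shift(x)                      # repeated rolls blur the exact periodicity
x = x / (x.abs().max() + 1e-7)         # peak-normalise, as in AudioGen.forward
```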
audiocraft/transformer.py CHANGED
@@ -3,8 +3,8 @@ import torch.nn as nn
 from torch.nn import functional as F
 from einops import rearrange
 
-def create_sin_embedding(positions,
-                         dim,
+def create_sin_embedding(positions,
+                         dim,
                          max_period = 10000,
                          dtype = torch.float32):
     """Create sinusoidal positional embedding, with shape `[B, T, C]`.
@@ -78,28 +78,20 @@ class StreamingMultiheadAttention(nn.Module):
 
 
         if self.k_history is not None:
-            # k_history.shape = torch.Size([2*N_REPEAT, 24, 3, 64]) for cfg > k.shape = torch.Size([2, 24, 1, 64])
-            # 24 heads, 64 dim
+            # flush: preserve only the first 4 and the most recent kv entries
+            if self.k_history.shape[2] > 71:
+                self.k_history = torch.cat([self.k_history[:, :, :4, :], self.k_history[:, :, -1:, :]], 2)
+                self.v_history = torch.cat([self.v_history[:, :, :4, :], self.v_history[:, :, -1:, :]], 2)
+            # fill new k/v
             self.k_history = torch.cat([self.k_history, k], 2)  # IF ctrl^c here during live demo it is non-atomic k != v
             self.v_history = torch.cat([self.v_history, v], 2)  # thus it will try to continue with incompatible k/v dims!
-            # Preserve first 4-10 tokens & flush kv
-            if self.k_history.shape[2] > 24:
-                # find LOWEST l2 norm of keys > https://arxiv.org/pdf/2406.11430v4
-                low_norm = (self.k_history * self.k_history).mean(3, keepdims=True).sum(1, keepdims=True)  # [bs, 24, T, 64] -> [bs, 1, T, 1]
-                _, _ix = torch.topk(low_norm, k=10, dim=2, largest=False)  # shows background music due to cfg - loses the txt conditioning if flushed!
-                _ix = _ix.repeat(1, 24, 1, 64)
-                self.k_history = torch.gather(self.k_history, 2, _ix)
-                self.v_history = torch.gather(self.v_history, 2, _ix)
         else:
-            # init on 1st token (for all 47 transf layers)
-            print(f'AudioGen kv cache Flush')
+            # init
             self.k_history = k
             self.v_history = v
+        # For self attn prepare
         k = self.k_history
         v = self.v_history

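In isolation, the new policy replaces the low-L2-norm key eviction (the removed lines citing https://arxiv.org/pdf/2406.11430v4) with an attention-sink style cache: once the cache exceeds 71 steps, everything between the first 4 entries and the latest entry is dropped. A minimal sketch of this policy, with tensor shapes taken from the removed comments; `prune_kv` and the demo tensors are hypothetical, not part of the repo:

```python
import torch

MAX_T = 71   # flush threshold used in StreamingMultiheadAttention
N_SINK = 4   # leading kv entries that are always preserved

def prune_kv(k_history, v_history):
    # hypothetical helper: k/v_history are [bs, heads, T, head_dim]
    if k_history.shape[2] > MAX_T:
        # keep the first N_SINK cached steps plus the most recent one
        k_history = torch.cat([k_history[:, :, :N_SINK, :], k_history[:, :, -1:, :]], 2)
        v_history = torch.cat([v_history[:, :, :N_SINK, :], v_history[:, :, -1:, :]], 2)
    return k_history, v_history

k = torch.randn(2, 24, 80, 64)  # e.g. 24 heads, 64-dim, 80 cached steps
v = torch.randn(2, 24, 80, 64)
k, v = prune_kv(k, v)
print(k.shape)  # torch.Size([2, 24, 5, 64]) -> 4 sink entries + the last one
```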
msinference.py CHANGED
@@ -390,7 +390,7 @@ def foreign(text=None,  # split sentences here so we can prepend a txt for germ
 
     x = net_g(input_ids=inputs.input_ids.to(device),
               attention_mask=inputs.attention_mask.to(device),
-              speed=.94 + .4 * np.random.rand()  # variable speed / sentence
+              speed=.94 + .84 * np.random.rand()  # variable speed / sentence
              )[0, :]
 
     # crop the 1st audio - it is the PREFIX text (156000 samples) to choose the deu voice / VitsAttention()
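
The only change here widens the per-sentence speed jitter: since `np.random.rand()` is uniform in [0, 1), the sampled speed moves from the range [0.94, 1.34) to [0.94, 1.78). A quick illustrative check of the new range:

```python
import numpy as np

speeds = .94 + .84 * np.random.rand(10000)  # per-sentence speed, new range
print(speeds.min(), speeds.max())           # ~0.94 ... ~1.78
```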