preserve only few last kv
Browse files
- README.md +3 -1
- audiocraft/builders.py +8 -14
- audiocraft/transformer.py +13 -21
- msinference.py +1 -1
README.md
CHANGED
@@ -67,7 +67,9 @@ CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=/data/dkounadis/.hf7/ CUDA_VISIBLE_DEVICES=
|
|
67 |
|
68 |
Following examples need `api.py` to be running. [Set this IP](https://huggingface.co/dkounadis/artificial-styletts2/blob/main/tts.py#L85) to the IP shown when starting `api.py`.
|
69 |
|
70 |
-
|
|
|
|
|
71 |
|
72 |
</details>
|
73 |
|
|
|
67 |
|
68 |
Following examples need `api.py` to be running. [Set this IP](https://huggingface.co/dkounadis/artificial-styletts2/blob/main/tts.py#L85) to the IP shown when starting `api.py`.
|
69 |
|
70 |
+
```
|
71 |
+
python tts.py --text assets/ocr.txt --image assets/ocr.jpg --soundscape "battle hero" --voice romanian
|
72 |
+
```
|
73 |
|
74 |
</details>
|
75 |
|
audiocraft/builders.py
CHANGED
@@ -11,15 +11,13 @@ from .lm import LMModel
|
|
11 |
from .seanet import SEANetDecoder
|
12 |
from .vq import ResidualVectorQuantizer
|
13 |
|
14 |
-
N_REPEAT =
|
15 |
|
16 |
def _shift(x):
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
x[i, :] = torch.roll(batch_elem, offset, dims=0) # batch_elem = [400000, ]
|
22 |
-
return x
|
23 |
|
24 |
def _delete_param(cfg, full_name):
|
25 |
parts = full_name.split('.')
|
@@ -70,18 +68,14 @@ class AudioGen(nn.Module):
|
|
70 |
|
71 |
# AudioGen 16KHZ / StyleTTS2 24 KHz / MMSTTS 24 KHz
|
72 |
|
73 |
-
x = self.resample_fn(x)
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
x = x.repeat(1, N_REPEAT)
|
78 |
-
|
79 |
-
# less periodic - shift every batch elem
|
80 |
|
81 |
for _ in range(7):
|
82 |
x = _shift(x)
|
83 |
|
84 |
-
|
85 |
print(x.abs().max(), 'MAX')
|
86 |
return x / (x.abs().max() + 1e-7)
|
87 |
|
|
|
11 |
from .seanet import SEANetDecoder
|
12 |
from .vq import ResidualVectorQuantizer
|
13 |
|
14 |
+
N_REPEAT = 3 # num (virtual batch_size) clones of audio sounds
|
15 |
|
16 |
def _shift(x):
|
17 |
+
n = x.shape[0]
|
18 |
+
offset = np.random.randint(.24 * n, max(1, .74 * n)) # `high` must stay above `low`; max(1, ...) floors it at 1 for tiny n - TBD
|
19 |
+
return torch.roll(x, offset, dims=0)
|
20 |
+
|
|
|
|
|
21 |
|
22 |
def _delete_param(cfg, full_name):
|
23 |
parts = full_name.split('.')
|
|
|
68 |
|
69 |
# AudioGen 16KHZ / StyleTTS2 24 KHz / MMSTTS 24 KHz
|
70 |
|
71 |
+
x = self.resample_fn(x) # [N_REPEAT, duration]
|
72 |
|
73 |
+
x = x.repeat(1, N_REPEAT).reshape(-1)
|
|
|
|
|
|
|
|
|
74 |
|
75 |
for _ in range(7):
|
76 |
x = _shift(x)
|
77 |
|
78 |
+
|
79 |
print(x.abs().max(), 'MAX')
|
80 |
return x / (x.abs().max() + 1e-7)
|
81 |
|
audiocraft/transformer.py
CHANGED
@@ -3,8 +3,8 @@ import torch.nn as nn
|
|
3 |
from torch.nn import functional as F
|
4 |
from einops import rearrange
|
5 |
|
6 |
-
def create_sin_embedding(positions,
|
7 |
-
dim,
|
8 |
max_period = 10000,
|
9 |
dtype = torch.float32):
|
10 |
"""Create sinusoidal positional embedding, with shape `[B, T, C]`.
|
@@ -78,28 +78,20 @@ class StreamingMultiheadAttention(nn.Module):
|
|
78 |
|
79 |
|
80 |
if self.k_history is not None:
|
81 |
-
#
|
82 |
-
|
|
|
|
|
|
|
|
|
83 |
self.k_history = torch.cat([self.k_history, k], 2) # IF ctrl^c here during live demo it is non-atomic k!=v
|
84 |
self.v_history = torch.cat([self.v_history, v], 2) # thus it will try to continue with incompatible k/v dims!
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
# find LOWEST l2 norm of keys > https://arxiv.org/pdf/2406.11430v4
|
89 |
-
|
90 |
-
low_norm = (self.k_history * self.k_history).mean(3, keepdims=True).sum(1, keepdims=True) # [bs, 24, T, 64] -> [bs, T]
|
91 |
-
_, _ix = torch.topk(low_norm, k=10, dim=2, largest=False) # shows background music due to cfg - loses the text conditioning if flushed!
|
92 |
-
_ix = _ix.repeat(1, 24, 1, 64)
|
93 |
-
# print(_ix.shape)
|
94 |
-
self.k_history = torch.gather(self.k_history, 2, _ix)
|
95 |
-
self.v_history = torch.gather(self.v_history, 2, _ix)
|
96 |
-
|
97 |
-
else:
|
98 |
-
# init on 1st token (for all 47 transf layers)
|
99 |
-
print(f'AudioGen kv cache Flush')
|
100 |
self.k_history = k
|
101 |
-
self.v_history = v
|
102 |
-
|
103 |
k = self.k_history
|
104 |
v = self.v_history
|
105 |
|
|
|
3 |
from torch.nn import functional as F
|
4 |
from einops import rearrange
|
5 |
|
6 |
+
def create_sin_embedding(positions,
|
7 |
+
dim,
|
8 |
max_period = 10000,
|
9 |
dtype = torch.float32):
|
10 |
"""Create sinusoidal positional embedding, with shape `[B, T, C]`.
|
|
|
78 |
|
79 |
|
80 |
if self.k_history is not None:
|
81 |
+
# flush
|
82 |
+
if self.k_history.shape[2] > 71:
|
83 |
+
|
84 |
+
self.k_history = torch.cat([self.k_history[:, :, :4, :], self.k_history[:, :, -1:, :]], 2)
|
85 |
+
self.v_history = torch.cat([self.v_history[:, :, :4, :], self.v_history[:, :, -1:, :]], 2)
|
86 |
+
# fill new k/v
|
87 |
self.k_history = torch.cat([self.k_history, k], 2) # IF ctrl^c here during live demo it is non-atomic k!=v
|
88 |
self.v_history = torch.cat([self.v_history, v], 2) # thus it will try to continue with incompatible k/v dims!
|
89 |
+
|
90 |
+
else:
|
91 |
+
# init
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
self.k_history = k
|
93 |
+
self.v_history = v
|
94 |
+
# For self attn prepare
|
95 |
k = self.k_history
|
96 |
v = self.v_history
|
97 |
|
msinference.py
CHANGED
@@ -390,7 +390,7 @@ def foreign(text=None, # split sentences here so we can prepend a txt for germ
|
|
390 |
|
391 |
x = net_g(input_ids=inputs.input_ids.to(device),
|
392 |
attention_mask=inputs.attention_mask.to(device),
|
393 |
-
speed = .94 + .
|
394 |
)[0, :]
|
395 |
|
396 |
# crop the 1st audio - is PREFIX text 156000 samples to choose deu voice / VitsAttention()
|
|
|
390 |
|
391 |
x = net_g(input_ids=inputs.input_ids.to(device),
|
392 |
attention_mask=inputs.attention_mask.to(device),
|
393 |
+
speed = .94 + .84 * np.random.rand() # variable speed / sentence
|
394 |
)[0, :]
|
395 |
|
396 |
# crop the 1st audio - is PREFIX text 156000 samples to choose deu voice / VitsAttention()
|