numerals

- api.py +11 -55
- audiocraft/activations.py +0 -96
- audiocraft/builders.py +44 -91
- audiocraft/conditioners.py +3 -3
- audiocraft/lm.py +28 -222
- audiocraft/transformer.py +94 -287
- msinference.py +57 -23
- requirements.txt +1 -1
api.py
CHANGED
@@ -2,7 +2,6 @@
 # -*- coding: utf-8 -*-
 import numpy as np
 import soundfile
-import audresample
 from Utils.text_utils import split_into_sentences
 import msinference
 import re
@@ -15,10 +14,12 @@ from flask import Flask, request, send_from_directory
 from moviepy.video.io.VideoFileClip import VideoFileClip
 from moviepy.video.VideoClip import ImageClip
 from audiocraft.builders import AudioGen
-CACHE_DIR = 'flask_cache/'
-NUM_SOUND_GENERATIONS = 3  # batch size to generate same text (same soundscape for long video)
-
+CACHE_DIR = 'flask_cache/'
+PIECE_OF_SOUND_DURATION = 4.74  # seconds
+sound_generator = AudioGen(
+    duration=PIECE_OF_SOUND_DURATION
+    ).to('cuda:0').eval()
 
 Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
 
@@ -57,62 +58,17 @@ def _resize(image, width=None, height=None, inter=cv2.INTER_AREA):
     # return the resized image
     return resized
 
-
-def _shift(x):
-    n = x.shape[0]
-    i = np.random.randint(.24 * n, max(1, .74 * n))  # high should be above >= 0
-    x = np.roll(x, i)
-    # we can add the one or fade it and then amplify
-    # the audio is so short 6s that is difficult to not hear the shift somewhere
-    # Just concatenate - raw - and then shift - the longconcat audio - many times may fix it
-    # fade_in = 1 - .5 * np.tanh(-4*(np.linspace(-10, 10, n) - 9.4)) + .5 * np.tanh(4*(np.linspace(-10, 10, n) + 9.4))
-    return x  # * fade_in  # silence this
-
 def overlay(x, soundscape=None):
-
+    # pre-calculate the n_repeat here then apply torchaudio.resample and repeat insd sound_gen forward()
     if soundscape is not None:
 
-        # SOUNDS
-
-        background = sound_generator.generate(
-            [soundscape] * NUM_SOUND_GENERATIONS
-            ).reshape(-1).detach().cpu().numpy()  # bs, 11400 @.74s
-
-        # upsample 16 kHz AudioGen to 24kHZ of VITS/StyleTTS2
-
-        print('Resampling')  # soundscape each generation in batch differs from the other generations thus clone/shift each element in batch, finally concat w/o shift
-
-        background = audresample.resample(
-            background,
-            original_rate=16000,  # sound_generator.sample_rate,
-            target_rate=24000)[0, :-250]  # last samples have splash sounds DISCARD 25000 last samples
-
-        n_repeat = len(x) // background.shape[0] + 1
-        total = np.tile(background, n_repeat)
-
-        # less periodic
-        for _ in range(4):
-            total = _shift(total)
-
-        # amplify sounds full [-1,1]
-        total /= np.abs(total).max() + 1e-7
-        x = .5 * x + .5 * total[:len(x)]
+        background = sound_generator.generate(soundscape,
+            n_repeat=int(len(x) / (PIECE_OF_SOUND_DURATION * 16000)) + 1
+            ).detach().cpu().numpy()  # bs, 11400 @.74s
+
+        # blend TTS
+
+        x = .5 * x + .5 * background[:len(x)]
 
     else:
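After this change overlay() delegates looping, shifting and normalisation of the soundscape to AudioGen.generate() and only keeps the equal-gain mix in api.py. A minimal sketch of that blend step, assuming background is already at the TTS sample rate (the helper name and the tiling fallback are illustrative, not part of the diff):

    import numpy as np

    def blend_speech_with_background(x, background, gain=0.5):
        # x: TTS waveform, background: generated soundscape (both 1-D, same sample rate)
        n = len(x)
        if len(background) < n:
            # assumption: tile if the soundscape is shorter than the speech
            background = np.tile(background, n // len(background) + 1)
        mix = gain * x + (1 - gain) * background[:n]
        return mix / (np.abs(mix).max() + 1e-7)  # keep the blend inside [-1, 1]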
audiocraft/activations.py
DELETED
@@ -1,96 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-from torch import Tensor
-from typing import Union, Callable
-
-
-class CustomGLU(nn.Module):
-    """Custom Gated Linear Unit activation.
-    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
-    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
-    function (i.e. sigmoid, swish, etc.).
-
-    Args:
-        activation (nn.Module): The custom activation to apply in the Gated Linear Unit
-        dim (int): the dimension on which to split the input. Default: -1
-
-    Shape:
-        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
-          dimensions
-        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
-
-    Examples::
-        >>> m = CustomGLU(nn.Sigmoid())
-        >>> input = torch.randn(4, 2)
-        >>> output = m(input)
-    """
-    def __init__(self, activation: nn.Module, dim: int = -1):
-        super(CustomGLU, self).__init__()
-        self.dim = dim
-        self.activation = activation
-
-    def forward(self, x: Tensor):
-        assert x.shape[self.dim] % 2 == 0  # M = N / 2
-        a, b = torch.chunk(x, 2, dim=self.dim)
-        return a * self.activation(b)
-
-
-class SwiGLU(CustomGLU):
-    """SiLU Gated Linear Unit activation.
-    Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
-    the first half of the input matrices, :math:`b` is the second half.
-
-    Args:
-        dim (int): the dimension on which to split the input. Default: -1
-    """
-    def __init__(self, dim: int = -1):
-        super(SwiGLU, self).__init__(nn.SiLU(), dim)
-
-
-class GeGLU(CustomGLU):
-    """GeLU Gated Linear Unit activation.
-    Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
-    the first half of the input matrices, :math:`b` is the second half.
-
-    Args:
-        dim (int): the dimension on which to split the input. Default: -1
-    """
-    def __init__(self, dim: int = -1):
-        super(GeGLU, self).__init__(nn.GELU(), dim)
-
-
-class ReGLU(CustomGLU):
-    """ReLU Gated Linear Unit activation.
-    Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
-    the first half of the input matrices, :math:`b` is the second half.
-
-    Args:
-        dim (int): the dimension on which to split the input. Default: -1
-    """
-    def __init__(self, dim: int = -1):
-        super(ReGLU, self).__init__(nn.ReLU(), dim)
-
-
-def get_activation_fn(
-    activation: Union[str, Callable[[Tensor], Tensor]]
-) -> Union[str, Callable[[Tensor], Tensor]]:
-    """Helper function to map an activation string to the activation class.
-    If the supplied activation is not a string that is recognized, the activation is passed back.
-
-    Args:
-        activation (str, or Callable[[Tensor], Tensor]): Activation to check
-    """
-    if isinstance(activation, str):
-        if activation == "reglu":
-            return ReGLU()
-        elif activation == "geglu":
-            return GeGLU()
-        elif activation == "swiglu":
-            return SwiGLU()
-    return activation
audiocraft/builders.py
CHANGED
@@ -1,22 +1,25 @@
-import typing as tp
 import omegaconf
+import torchaudio
 from torch import nn
 import torch
+import numpy as np
 from huggingface_hub import hf_hub_download
 import os
-from omegaconf import OmegaConf
-
+from omegaconf import OmegaConf
 from .encodec import EncodecModel
 from .lm import LMModel
 from .seanet import SEANetDecoder
-from .codebooks_patterns import DelayedPatternProvider
-from .conditioners import T5Conditioner
 from .vq import ResidualVectorQuantizer
 
-
-def _delete_param(cfg: DictConfig, full_name: str):
+def _shift(x):
+    # [bs, samples] shift circular each batch elem of sound
+    n = x.shape[1]
+    for i, batch_elem in enumerate(x):
+        offset = np.random.randint(.24 * n, max(1, .74 * n))  # high should be above >= 0 TBD
+        x[i, :] = torch.roll(batch_elem, offset, dims=0)  # batch_elem = [400000, ]
+    return x
+
+def _delete_param(cfg, full_name):
     parts = full_name.split('.')
     for part in parts[:-1]:
         if part in cfg:
@@ -35,48 +38,53 @@ def dict_from_config(cfg):
     return dct
 
 
-
-# ============================================== DEFINE AUDIOGEN
-
-
 class AudioGen(nn.Module):
 
     # https://huggingface.co/facebook/audiogen-medium
 
     def __init__(self,
-                 duration=
-
+                 duration=2.24,  # s
+                 ):
 
         super().__init__()
-        self.device = device  # needed for loading & select float16 LM
         self.load_compression_model()
         self.load_lm_model()
         self.duration = duration
+        # AudioGen = 16KHZ StyleTTS2 = 24 KHz / MMSTTS = 24 KHz
+        self.resample_fn = torchaudio.transforms.Resample(16000, 24000)
 
     @property
     def frame_rate(self):
        return self.compression_model.frame_rate
 
     def generate(self,
-                 descriptions
+                 descriptions,
+                 n_repeat=3):
+
         with torch.no_grad():
             gen_tokens = self.lm.generate(
-                descriptions=descriptions,
+                descriptions=[descriptions]*3,
                 max_gen_len=int(self.duration * self.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
             x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
-
+
+        x = x[:, 0, :-250]  # last samples have splash sounds DISCARD 25000 last samples
+
+        # AudioGen 16KHZ / StyleTTS2 24 KHz / MMSTTS 24 KHz
+        # x = self.resample_fn(x)
+
+        # batch size = different sounds for same txt
+        x = x.repeat(1, n_repeat)
+
+        # less periodic - shift every batch elem
+        for _ in range(7):
+            x = _shift(x)
+
+        x = x.reshape(-1)
+        print(x.abs().max(), 'MAX')
+        return x / (x.abs().max() + 1e-7)
 
     # == BUILD Fn
     def get_quantizer(self, quantizer, cfg, dimension):
@@ -126,58 +134,7 @@ class AudioGen(nn.Module):
             ).to(cfg.device)
         else:
             raise KeyError(f"Unexpected compression model {cfg.compression_model}")
-
-
-    def get_lm_model(self, cfg):
-        """Instantiate a transformer LM."""
-        if cfg.lm_model in ['transformer_lm',
-                            'transformer_lm_magnet']:
-            kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
-            n_q = kwargs['n_q']
-            q_modeling = kwargs.pop('q_modeling', None)
-            codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
-            attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
-            cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
-            cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
-
-
-            # if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programmatically
-            kwargs['cross_attention'] = True
-            if codebooks_pattern_cfg.modeling is None:
-                print('Q MODELING\n=\n=><')
-                assert q_modeling is not None, \
-                    "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
-                codebooks_pattern_cfg = omegaconf.OmegaConf.create(
-                    {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
-                )
-
-            pattern_provider = self.get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
-            return LMModel(
-                pattern_provider=pattern_provider,
-                condition_provider=T5Conditioner(name='t5-large', output_dim=kwargs["dim"], device=self.device),
-                cfg_dropout=cfg_prob,
-                cfg_coef=cfg_coef,
-                attribute_dropout=attribute_dropout,
-                dtype=getattr(torch, cfg.dtype),
-                device=self.device,
-                **kwargs
-            ).to(cfg.device)
-        else:
-            raise KeyError(f"Unexpected LM model {cfg.lm_model}")
-
-
-    def get_codebooks_pattern_provider(self, n_q, cfg):
-        pattern_providers = {
-            'delay': DelayedPatternProvider,   # THIS
-        }
-        name = cfg.modeling
-        kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
-
-        klass = pattern_providers[name]
-        return klass(n_q, **kwargs)
-
-    # ======================
+
     def load_compression_model(self):
         file = hf_hub_download(
             repo_id='facebook/audiogen-medium',
@@ -204,24 +161,20 @@ class AudioGen(nn.Module):
             library_name="audiocraft",
             library_version='1.3.0a1')  # Found at __init__.py  # audiocraft.__version__)
         pkg = torch.load(file,
-                         map_location=
-        cfg = OmegaConf.create(pkg['xp.cfg'])
-        # cfg.device = 'cpu'
-        if self.device == 'cpu':
-            cfg.dtype = 'float32'
-        else:
-            cfg.dtype = 'float16'
+                         map_location='cpu')
+        cfg = OmegaConf.create(pkg['xp.cfg'])  # CFG inside torch bin
         _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
         _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
         _delete_param(cfg, 'conditioners.args.drop_desc_p')
-
-
+        print('___________________________CFG___________________', cfg, '\n=======================')
+        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
+        print('___________________________Kwarg___________________', kwargs, '\n=======================')
+        model = LMModel().to(getattr(torch, cfg.dtype))  # .to(cfg.device)
         _best = pkg['best_state']
         _best['condition_provider.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
         _best['condition_provider.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
         model.load_state_dict(pkg['best_state'])
-        model.cfg = cfg
-        # return model
+        # model.cfg = cfg
        self.lm = model.to(torch.float)
 
     # def _flush(self):
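AudioGen.generate() now tiles one short generated clip n_repeat times and circularly rolls each batch element several times so the repetition is less audible, then flattens and peak-normalises. A standalone sketch of that repeat-and-shift trick on a dummy tensor, assuming a [batch, samples] layout (the helper name is illustrative):

    import torch
    import numpy as np

    def make_loop_less_periodic(x, rounds=7):
        # x: [bs, samples]; roll each batch element by a random offset, several times,
        # so the tiled soundscape does not sound like an exact loop
        n = x.shape[1]
        for _ in range(rounds):
            for i in range(x.shape[0]):
                offset = np.random.randint(int(.24 * n), max(1, int(.74 * n)))
                x[i] = torch.roll(x[i], shifts=offset, dims=0)
        x = x.reshape(-1)                      # flatten batch into one long soundscape
        return x / (x.abs().max() + 1e-7)      # peak-normalise to [-1, 1]

    clip = torch.randn(3, 16000 * 2)           # 3 variations of a 2 s sound @ 16 kHz
    loop = make_loop_less_periodic(clip.repeat(1, 4))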
audiocraft/conditioners.py
CHANGED
@@ -25,7 +25,7 @@ class T5Conditioner(nn.Module):
     def __init__(self,
                  name,
                  output_dim,
-                 device,
+                 device='cuda:0',
                  finetune=False):
         print(f'{finetune=}')
         assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
@@ -36,7 +36,7 @@ class T5Conditioner(nn.Module):
         self.device = device
         self.name = name
 
-        self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+        self.t5_tokenizer = T5Tokenizer.from_pretrained(name, legacy=True)
         t5 = T5EncoderModel.from_pretrained(name).eval()  # .train(mode=finetune)
         if finetune:
             self.t5 = t5
@@ -65,7 +65,7 @@ class T5Conditioner(nn.Module):
             embeds = self.t5(input_ids=d['input_ids'],
                              attention_mask=d['attention_mask']
                              ).last_hidden_state  # no kvcache for txt conditioning
-
+        embeds = self.output_proj(embeds.to(self.output_proj.weight))
         embeds = (embeds * d['attention_mask'].unsqueeze(-1))
 
         return embeds  # , d['attention_mask']
audiocraft/lm.py
CHANGED
@@ -1,237 +1,45 @@
-from dataclasses import dataclass
-import logging
-import math
-import typing as tp
 import torch
 import torch.nn.functional as F
 from audiocraft.transformer import StreamingTransformer
-from dataclasses import dataclass
-from functools import partial
 from torch import nn
-from audiocraft.
+from audiocraft.codebooks_patterns import DelayedPatternProvider
+from audiocraft.conditioners import T5Conditioner
 import numpy as np
 
-def _shift(x):
-    # cyclic shift of [1, 4, seq_len] slices from [bs, 4, seq_len]
-    print(x.shape, 'SHIFT\n= = = = = ')
-    for i, _slice in enumerate(x):
-        n = x.shape[2]
-        offset = np.random.randint(.24 * n, max(1, .74 * n))  # high should be above >= 0 TBD
-        print(offset)
-        x[i, :, :] = torch.roll(_slice, offset, dims=1)
-    return x
-
-
-def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
-    """LM layer initialization.
-    Inspired from xlformers: https://github.com/fairinternal/xlformers
-
-    Args:
-        method (str): Method name for init function. Valid options are:
-            'gaussian', 'uniform'.
-        input_dim (int): Input dimension of the initialized module.
-        init_depth (int, optional): Optional init depth value used to rescale
-            the standard deviation if defined.
-    """
-    # Compute std
-    std = 1 / math.sqrt(input_dim)
-    # Rescale with depth
-    if init_depth is not None:
-        std = std / math.sqrt(2 * init_depth)
-
-    if method == 'gaussian':
-        return partial(
-            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
-        )
-    elif method == 'uniform':
-        bound = math.sqrt(3) * std  # ensure the standard deviation is `std`
-        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
-    else:
-        raise ValueError("Unsupported layer initialization method")
-
-
-def init_layer(m: nn.Module,
-               method: str,
-               init_depth: tp.Optional[int] = None,
-               zero_bias_init: bool = False):
-    """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
-
-    Args:
-        m (nn.Module): Module to initialize.
-        method (str): Method name for the init function.
-        init_depth (int, optional): Optional init depth value used to rescale
-            the standard deviation if defined.
-        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
-    """
-    if isinstance(m, nn.Linear):
-        init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
-        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
-            weight = m.weight.float()
-            init_fn(weight)
-            m.weight.data[:] = weight.half()
-        else:
-            init_fn(m.weight)
-        if zero_bias_init and m.bias is not None:
-            nn.init.constant_(m.bias, 0)
-    elif isinstance(m, nn.Embedding):
-        init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
-        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
-            weight = m.weight.float()
-            init_fn(weight)
-            m.weight.data[:] = weight.half()
-        else:
-            init_fn(m.weight)
-
-
-class ScaledEmbedding(nn.Embedding):
-    """Boost learning rate for embeddings (with `scale`).
-    """
-    def __init__(self, *args, lr=None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.lr = lr
-
-    def make_optim_group(self):
-        group = {"params": list(self.parameters())}
-        if self.lr is not None:
-            group["lr"] = self.lr
-        return group
-
-
-@dataclass
-class LMOutput:
-    # The logits are already re-aligned with the input codes
-    # hence no extra shift is required, e.g. when computing CE
-    logits: torch.Tensor  # [B, K, T, card]
-    mask: torch.Tensor  # [B, K, T]
-
 
 class LMModel(nn.Module):
-
-    Args:
-        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
-        condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
-        fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
-        n_q (int): Number of parallel streams to model.
-        card (int): Cardinality, vocabulary size.
-        dim (int): Dimension of the transformer encoder.
-        num_heads (int): Number of heads for the transformer encoder.
-        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
-        norm (str): Normalization method.
-        norm_first (bool): Use pre-norm instead of post-norm.
-        emb_lr (float, optional): Embedding-specific learning rate.
-        bias_proj (bool): Use bias for output projections.
-        weight_init (str, optional): Method for weight initialization.
-        depthwise_init (str, optional): Method for depthwise weight initialization.
-        zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
-        cfg_dropout (float): Classifier-free guidance dropout.
-        cfg_coef (float): Classifier-free guidance coefficient.
-        attribute_dropout (dict): Attribute dropout probabilities.
-        two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
-        **kwargs: Additional parameters for the transformer encoder.
-    """
+
     def __init__(self,
-
-                 hidden_scale: int = 4,
-                 norm: str = 'layer_norm',
-                 norm_first: bool = False,
-                 emb_lr: tp.Optional[float] = None,
-                 bias_proj: bool = True,
-                 weight_init: tp.Optional[str] = None,
-                 depthwise_init: tp.Optional[str] = None,
-                 zero_bias_init: bool = False, cfg_dropout: float = 0,
-                 cfg_coef: float = 1.0,
-                 two_step_cfg: bool = False,
-                 **kwargs):
+                 n_q=4,
+                 card=2048,
+                 dim=1536,
+                 num_heads=24,
+                 hidden_scale=4,  # FFN of Transformer
+                 ):
         super().__init__()
-        self.
-
+        self.condition_provider = T5Conditioner(name='t5-large',
+                                                output_dim=dim)
         self.card = card  # 2048 ?
         self.n_draw = 1  # replicate so many times the generation of each text in batch
+        # the batch is more expensive than n_draw as it re-runs the model bs times
+        # n_draw just draws more phonemes from the multinomial - after running the lm
         embed_dim = self.card + 1
         self.n_q = n_q
         self.dim = dim
-        self.pattern_provider =
-        self.
-        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
-        if 'activation' in kwargs:
-            kwargs['activation'] = get_activation_fn(kwargs['activation'])
-        # ========================================================================
-        # {
-        # 'dtype': torch.float16, 'device': 'cuda',
-        # 'num_layers': 48, 'dropout': 0.0, 'activation': 'gelu',
-        # 'bias_ff': False, 'bias_attn': False,
-        # 'past_context': None, 'causal': True,
-        # 'custom': False, 'memory_efficient': True,
-        # 'attention_as_float32': False, 'positional_embedding': 'sin', 'xpos': False,
-        # 'checkpointing': 'none', 'cross_attention': True, 'qk_layer_norm': False,
-        # 'qk_layer_norm_cross': False, 'attention_dropout': None, 'kv_repeat': 1
-        # }
-        # ==========================================================================
-        kwargs.pop('layer_scale')  # nn.Indentity()
-
+        self.pattern_provider = DelayedPatternProvider()
+        self.emb = nn.ModuleList([nn.Embedding(embed_dim, dim) for _ in range(n_q)])  # EMBEDDING HAS 2049
         self.transformer = StreamingTransformer(
             d_model=dim,
             num_heads=num_heads,
             dim_feedforward=int(hidden_scale * dim),
-
-        self.
-        self.
-        self._fsdp: tp.Optional[nn.Module]
-        self.__dict__['_fsdp'] = None
-
-    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
-        """Initialization of the transformer module weights.
-
-        Args:
-            weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
-            depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
-                'current' where the depth corresponds to the current layer index or 'global' where the total number
-                of layer is used as depth. If not set, no depthwise initialization strategy is used.
-            zero_bias_init (bool): Whether to initialize bias to zero or not.
-        """
-        assert depthwise_init is None or depthwise_init in ['current', 'global']
-        assert depthwise_init is None or weight_init is not None, \
-            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
-        assert not zero_bias_init or weight_init is not None, \
-            "If 'zero_bias_init', a 'weight_init' method should be provided"
-
-        if weight_init is None:
-            return
-
-        for emb_layer in self.emb:
-            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
-
-        for layer_idx, tr_layer in enumerate(self.transformer.layers):
-            depth = None
-            if depthwise_init == 'current':
-                depth = layer_idx + 1
-            elif depthwise_init == 'global':
-                depth = len(self.transformer.layers)
-            init_fn = partial(init_layer,
-                              method=weight_init,
-                              init_depth=depth,
-                              zero_bias_init=zero_bias_init)
-            tr_layer.apply(init_fn)
-
-        for linear in self.linears:
-            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
-
-    @property
-    def special_token_id(self) -> int:
-        return self.card
-
+            num_layers=48,
+            positional_embedding='sin',
+            )
+        self.out_norm = nn.LayerNorm(dim, eps=1e-5)
+        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])  # LINEAR DOESNT HAVE 2049
+        # self._init_weights(weight_init, depthwise_init, zero_bias_init)
+        # self.__dict__['_fsdp'] = None
 
     def forward(self,
                 sequence,
@@ -293,7 +101,7 @@ class LMModel(nn.Module):
                               max_gen_len), -1, dtype=torch.long,
                               device=text_condition.device)
 
-        gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.
+        gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.card)
         _, _, audiodur = gen_sequence.shape  # bs, 4, 7=audiodur
 
         # print(gen_sequence.shape, mask.shape, 'F')  # mask has no batch = [4,audio_duration]
@@ -313,7 +121,7 @@ class LMModel(nn.Module):
         for offset in range(1, audiodur):
 
             # forward duplicates the query to nullcond - then cfg & returns deduplicate token
-            next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset],
+            next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset],  # DIAGINDEXING for setting prediction of lm into gen_sequence THE GENSEQUENCE has to be un-delayed in the end [Because it has to be de-delayed for the vocoder then is actually only the lm input that requires to see the delay thus we could just feed by diaggather] so it matches gen_codes -1 a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5] the gen_sequence is indexed by vertical column and fed to lm however the prediction of lm is place diagonally with delay to the gen_sequence
                                       condition_tensors=text_condition,  # utilisation of the attention mask of txt condition ?
                                       token_count=offset-1)  # [bs, 4, 1, 2048]
 
@@ -322,7 +130,7 @@ class LMModel(nn.Module):
 
             # MASK is not full 1---- HAS 4 x audioduration PATTERN
             m = mask[:, :, :, offset]
-            next_token[~m] = self.
+            next_token[~m] = self.card
             gen_sequence[:, :, :, offset] = torch.where(
                 gen_sequence[:, :, :, offset] == -1,  # unknown_token,
                 next_token,
@@ -333,7 +141,7 @@ class LMModel(nn.Module):
         # 1. reshape n_draw as bs * n_draw
         # 2. invert all short-sequences
         # 3. reshape bs * n_draw -> bs, n_draw * audiodur ELONGATION
-        out_codes
+        out_codes = pattern.revert_pattern_sequence(
             gen_sequence.reshape(bs * self.n_draw, 4, audiodur),  # [3,8,4,7]
             special_token=-1)
         # print(f'{gen_sequence.shape=} {out_codes.shape=} Ha')  # REVERT PATTERN REDUCES DURATION?
@@ -341,12 +149,10 @@ class LMModel(nn.Module):
         out_codes = out_codes.reshape(bs, self.n_draw, 4, new_len)
         out_codes = out_codes.transpose(1, 2).reshape(bs, 4, self.n_draw * new_len)
         print(out_codes.shape, 'o')
-        for _ in range(7):
-            out_codes = _shift(out_codes)
 
-        # Clear
+        # Clear k/v cache (Different kv is saved by every 48x selfattn)
         for lay in self.transformer.layers:
             lay.self_attn.k_history = None
             lay.self_attn.v_history = None
 
-        return out_codes
+        return out_codes  # bs*n_draw, duration -> repeat/shift in api.py
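LMModel now builds its own DelayedPatternProvider: codebook k of each frame is offset by k steps, so the four codebooks are predicted on a diagonal and the sequence is "un-delayed" before the vocoder. A toy sketch of that interleaving idea, assuming 4 codebooks and a special token for not-yet-valid positions (this illustrates the pattern, it is not audiocraft's implementation):

    import torch

    def build_delayed(codes, special):
        # codes: [K, T] codebook indices; codebook k is shifted right by k steps
        K, T = codes.shape
        out = torch.full((K, T + K - 1), special, dtype=codes.dtype)
        for k in range(K):
            out[k, k:k + T] = codes[k]
        return out

    def revert_delayed(delayed, K):
        # undo the shift to recover aligned [K, T] codes for the vocoder
        T = delayed.shape[1] - K + 1
        return torch.stack([delayed[k, k:k + T] for k in range(K)])

    codes = torch.arange(12).reshape(4, 3)            # toy [K=4, T=3]
    assert torch.equal(revert_delayed(build_delayed(codes, special=2048), 4), codes)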
audiocraft/transformer.py
CHANGED
@@ -1,26 +1,12 @@
-import typing as tp
-from einops import rearrange
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from
-
-
-_efficient_attention_backend: str = 'torch'
-
-
-def _get_attention_time_dimension(memory_efficient: bool) -> int:
-    if _efficient_attention_backend == 'torch' and memory_efficient:
-        return 2
-    else:
-        return 1
-
+from einops import rearrange
 
-def create_sin_embedding(positions
-
+def create_sin_embedding(positions,
+                         dim,
+                         max_period=10000,
+                         dtype=torch.float32):
     """Create sinusoidal positional embedding, with shape `[B, T, C]`.
 
     Args:
@@ -41,256 +27,102 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
     return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
 
 
-def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
-    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
-    if n_rep == 1:
-        return x
-    if _efficient_attention_backend == 'torch' and memory_efficient:
-        bs, n_kv_heads, slen, head_dim = x.shape
-        return (
-            x[:, :, None, :, :]
-            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
-            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
-        )
-    else:
-        bs, slen, n_kv_heads, head_dim = x.shape
-        return (
-            x[:, :, :, None, :]
-            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
-            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-        )
-
-
 class StreamingMultiheadAttention(nn.Module):
 
     def __init__(self,
                  embed_dim,
-                 num_heads,
-                 memory_efficient: bool = False, attention_as_float32: bool = False,
-                 cross_attention: bool = False,
-                 kv_repeat: int = 1,
-                 device=None, dtype=None):
+                 num_heads,
+                 cross_attention=False):
         super().__init__()
-
-        if past_context is not None:
-            assert causal
-
+        self.cross_attention = cross_attention
         self.embed_dim = embed_dim
-
         self.k_history = None  # previous k from the previous tokens seen in the current generation - only for selt.attn
-        self.v_history = None  # clean up IN LM after finishing GENERATION - Each 1...47 mha has different kv history
-
-        self.memory_efficient = memory_efficient
-
-        self.cross_attention = cross_attention
-
+        self.v_history = None  # clean up IN LM after finishing GENERATION - Each 1...47 mha has different kv history
         self.num_heads = num_heads
-        self.
-        self.
-
-        self.custom = True  # _is_custom(custom, memory_efficient)
-        if not self.custom:
-            print(f'{self.custom}')
-        if self.custom:
-            out_dim = embed_dim
-            assert num_heads % kv_repeat == 0
-            assert not cross_attention or kv_repeat == 1
-            num_kv = num_heads // kv_repeat
-            kv_dim = (embed_dim // num_heads) * num_kv
-            out_dim += 2 * kv_dim
-            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
-            # We try to follow the default PyTorch MHA convention, to easily compare results.
-            self.in_proj_weight = in_proj.weight
-            self.in_proj_bias = in_proj.bias
-            if bias:
-                self.in_proj_bias.data.zero_()  # Following Pytorch convention
-            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
-            if bias:
-                self.out_proj.bias.data.zero_()
-        else:
-            assert kv_repeat == 1
-            self.mha = nn.MultiheadAttention(
-                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
-                **factory_kwargs)
-
-
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        if not self.custom:
-            # Support compat with regular MHA
-            keys = [n for n, _ in self.mha.named_parameters()]
-            for key in keys:
-                if prefix + key in state_dict:
-                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
-        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
+                                                          dtype=torch.float))
 
     def forward(self,
                 query,
-                key=None,  # ignores those 2 args if not self.cross_attn
+                key=None,
                 value=None):
-
-        # time_dim = _get_attention_time_dimension(self.memory_efficient)
-        # if time_dim == 2:
         layout = "b h t d"
-
         if self.cross_attention:
-            # Different queries, keys, values, we have to spit manually the weights
-            # before applying the linear.
+
+            # Different queries, keys, values, we have to spit manually the in_proj_weight
+
             dim = self.in_proj_weight.shape[0] // 3
-            if self.in_proj_bias is None:
-                bias_q, bias_k, bias_v = None, None, None
-            else:
-                bias_q = self.in_proj_bias[:dim]
-                bias_k = self.in_proj_bias[dim: 2 * dim]
-                bias_v = self.in_proj_bias[2 * dim:]
-            q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
-            # todo: when streaming, we could actually save k, v and check the shape actually match.
-            k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
-            v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
+
+            q = nn.functional.linear(query, self.in_proj_weight[:dim])
+            k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
+            v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
 
             q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
             # print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(),'CROSS A5')
         else:
             # 1st projected makes k,v (instantaneous)
             # 2nd cat
 
-            # else:
-            #     bound_layout = "b t p h d"
-            packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
-            q, k, v = packed.unbind(dim=2)
+            # HISTORY - DIFFERENT FOR EACH TRANSF LAYER
+
+            projected = nn.functional.linear(query, self.in_proj_weight)
+
+            bound_layout = "b h p t d"
+            packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
+            q, k, v = packed.unbind(dim=2)
 
             if self.k_history is not None:
                 #
                 # pk.shape=torch.Size([2, 24, 3, 64]) k.shape=torch.Size([2, 24, 1, 64]) CONCAT
                 # has to be 4D with batch 1 due to single condition 3=seqlen
                 # 24 heads 64 dimofh
                 self.k_history = torch.cat([self.k_history, k], 2)
                 self.v_history = torch.cat([self.v_history, v], 2)
+            else:
+                # init on 1st token (for all 47 transf layers)
+                print(f'AudioGen kv cache Flush')
+                self.k_history = k
+                self.v_history = v
+
+            k = self.k_history
+            v = self.v_history
 
             # KV COMPLETION ONLY ON SELF ATTENTION
-            # print('KV5', self.k_history.sum(), self.v_history.sum(), self.k_history.shape, self.v_history.shape)
 
-            # print('EVER IN MEMORY EFFICIENT A')
-
-        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
-        x = self.out_proj(x)
+        x = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, is_causal=False, dropout_p=0
+            )
+
+        x = x.to(q.dtype)
+        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
+        x = self.out_proj(x)
         return x
 
 
-class StreamingTransformerLayer(nn.Module):
-    # INHERITS MHA !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+class StreamingTransformerLayer(nn.Module):
 
     def __init__(self,
-                 d_model
-                 num_heads
-                 dim_feedforward:
-                 dropout: float = 0.1,
-                 bias_ff: bool = True,
-                 bias_attn: bool = True,
-                 custom: bool = False,
-                 memory_efficient: bool = False,
-                 attention_as_float32: bool = False,
-                 cross_attention: bool = False,
-                 attention_dropout: tp.Optional[float] = None,
-                 kv_repeat: int = 1,
-                 norm: str = 'layer_norm',
-                 device=None,
-                 dtype=None,
-                 **kwargs):
-
-        super().__init__()  # d_model, num_heads, dim_feedforward, dropout,
-        # device=device, dtype=dtype, batch_first=True, **kwargs)
-        # print(kwargs['activation'], 'ACTIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII\n\n\n\n')
-        # -- EN Layer
-        # DOES NOT INHERIT NO VARIABLE FROM nn.TransformerEncoderLayer only the _sa_block function
-
-            'memory_efficient': memory_efficient,
-            'attention_as_float32': attention_as_float32,
-        }
-        self.self_attn = StreamingMultiheadAttention(
-            kv_repeat=kv_repeat,
-            **attn_kwargs,
-            **factory_kwargs)  # type: ignore
-        # Redefine feedforward layers to expose bias parameter
-        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
-        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
-        # print('LAYER scale', layer_scale, '\n\n\n\n\n\n\n\n\n')  # always
-
-        self.cross_attention = None
-        if cross_attention:
-            self.cross_attention = StreamingMultiheadAttention(
-                cross_attention=True,
-                **attn_kwargs,
-                **factory_kwargs)
-
-            self.dropout_cross = nn.Dropout(dropout)
-
-            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
+                 d_model,
+                 num_heads,
+                 dim_feedforward):
+
+        super().__init__()
+
+        self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
+                                                     num_heads=num_heads)
+        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
+        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
+        self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
+                                                           num_heads=num_heads,
+                                                           cross_attention=True)
+        self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
         self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
         self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
 
@@ -316,59 +148,34 @@ class StreamingTransformerLayer(nn.Module):  # nn.TransformerEncoderLayer):
 
 class StreamingTransformer(nn.Module):
 
-    def __init__(self,
-
-                 bias_attn: bool = True,
-                 custom: bool = False,
-                 memory_efficient: bool = False,
-                 attention_as_float32: bool = False,
-                 cross_attention: bool = False,
+    def __init__(self,
+                 d_model=1536,
+                 num_heads=24,
+                 num_layers=48,
+                 dim_feedforward=6144,
+                 cross_attention=True,
                  positional_embedding: str = 'sin',
-                 max_period: float = 10_000
-
-                 checkpointing: str = 'none',
-                 device=None,
-                 dtype=None,
-                 **kwargs):
+                 max_period: float = 10_000
+                 ):
         super().__init__()
         assert d_model % num_heads == 0
 
         self.positional_embedding = positional_embedding
         self.max_period = max_period
-
-        # self._stream_off = 0  # the llm should reinitialize this at ery generate()
-
-        self.checkpointing = checkpointing
-
         self.layers = nn.ModuleList()
         for idx in range(num_layers):
             self.layers.append(
-
-                d_model=d_model,
-
-                device=device, dtype=dtype, **kwargs))
-
-        if self.checkpointing != 'none':
-            for layer in self.layers:
-                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
-                # backward hook inside of FSDP...
-                layer._magma_checkpointed = True  # type: ignore
-
+                StreamingTransformerLayer(
+                    d_model=d_model,
+                    num_heads=num_heads,
+                    dim_feedforward=dim_feedforward
+                    )
+                )
 
-    def forward(self,
+    def forward(self,
+                x,
+                token_count=None,
+                cross_attention_src=None):
 
         B, T, C = x.shape
 
@@ -376,7 +183,7 @@ class StreamingTransformer(nn.Module):
         if self.positional_embedding in ['sin', 'sin_rope']:
 
             positions = torch.arange(T, device=x.device).view(1, -1, 1)
-            positions = positions +
+            positions = positions + token_count  # offsets.view(-1, 1, 1)
             pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
             x = x + pos_emb
 
@@ -384,6 +191,6 @@ class StreamingTransformer(nn.Module):
 
         for j, lay in enumerate(self.layers):
             # print(f'Transf Layer{j} {pos_emb.sum()=} {pos_emb.shape=}{x.shape=}___________________')
-            x = lay(x, cross_attention_src=
+            x = lay(x, cross_attention_src=cross_attention_src)  # cross_attention_src = txt-cond
             # each layer (mha) keeps history of its own k,v for all tokens
         return x
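The slimmed-down self-attention keeps a per-layer k/v history that grows by one step per generated token and is flushed by the LM between generations. A small standalone sketch of that caching pattern (the class name is illustrative; only torch's scaled_dot_product_attention is assumed):

    import torch
    import torch.nn.functional as F

    class TinyKVCacheAttention(torch.nn.Module):
        # illustration of the pattern: concatenate the newest k/v onto a per-layer history
        def __init__(self):
            super().__init__()
            self.k_history = None
            self.v_history = None

        def step(self, q, k, v):
            # q, k, v: [batch, heads, 1, head_dim] for the newest token
            if self.k_history is None:
                self.k_history, self.v_history = k, v
            else:
                self.k_history = torch.cat([self.k_history, k], dim=2)
                self.v_history = torch.cat([self.v_history, v], dim=2)
            # attend over the full history; no causal mask needed since only past steps are cached
            return F.scaled_dot_product_attention(q, self.k_history, self.v_history)

    attn = TinyKVCacheAttention()
    for _ in range(5):
        out = attn.step(*(torch.randn(1, 24, 1, 64) for _ in range(3)))
    attn.k_history = attn.v_history = None  # flush between generations, as LMModel.generate() does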
msinference.py
CHANGED
@@ -293,10 +293,41 @@ with open(f"Utils/all_langs.csv") as f:
 
 
 
-# LOAD hun / ron / serbian - rmc-script_latin / cyrillic-Carpathian (not Vlax)
-
-
-
+# LOAD hun / ron / serbian - rmc-script_latin / cyrillic-Carpathian (not Vlax)
+# ==============================================================================================
+import re
+from num2words import num2words
+
+PHONEME_MAP = {
+    'q': 'ku',
+    'w': 'aou',
+    'z': 's',
+    "š": "s",
+    'th': 'ta',
+    'v': 'vv',
+    # "ć": "č",
+    # "đ": "ď",
+    # "lj": "ľ",
+    # "nj": "ň",
+    "ž": "z",
+    # "c": "č"
+}
+
+# ALLOWED_PHONEMES = set("šč_bďph`-3žt 'ľzj5yuoóx1vfnaiedt́sṁkň2rčlg")
+
+def number_to_phonemes(match):
+    number = int(match.group())
+    words = num2words(number, lang='sr')
+    return fix_phones(words.lower())
+    # return words
+
+def fix_phones(text):
+    for src, target in PHONEME_MAP.items():
+        text = text.replace(src, target)
+    # text = re.sub(r'\s+', '` `', text)  # .strip() #.lower()
+    # text = re.sub(r'\s+', '_ _', text)  # almost proper pausing
+    return text.replace(',', '_ _').replace('.', '_ _')
 
 def has_cyrillic(text):
     # https://stackoverflow.com/questions/48255244/python-check-if-a-string-contains-cyrillic-characters
@@ -358,7 +389,7 @@ class TextForeign(object):
 def foreign(text=None,  # list of text
             lang='romanian',
             speed=None):
-
+
     lang = lang.lower()  # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/Utils/all_langs.csv
 
     # https://huggingface.co/spaces/mms-meta/MMS
@@ -367,11 +398,11 @@ def foreign(text=None,  # list of text
 
         lang_code = 'hun'
 
-    elif
+    elif any([i in lang for i in ['ser', 'bosn', 'herzegov', 'montenegr', 'macedon']]):
 
         if has_cyrillic(text[0]):  # check 0-th sentence if is cyrillic
 
-            lang_code = 'rmc-script_cyrillic'  # romani carpathian (also has
+            lang_code = 'rmc-script_cyrillic'  # romani carpathian (also has latin / cyrillic Vlax)
 
         else:
 
@@ -387,6 +418,11 @@ def foreign(text=None,  # list of text
         lang_code = 'deu'
         speed = 1.14 if speed is None else speed
 
+    elif 'alban' in lang:
+
+        lang_code = 'sqi'
+        speed = 1.04 if speed is None else speed
+
     else:
 
         lang_code = lang.split()[0].strip()
@@ -431,20 +467,29 @@ def foreign(text=None,  # list of text
     x = []
 
     for _t in text:
-
-
-
         if is_uroman:
             uroman_dir = "Utils/uroman"
             assert os.path.exists(uroman_dir)
             uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
             _t = text_mapper.uromanize(_t, uroman_pl)
 
-        _t = _t.lower()
+        _t = _t.lower()
+
+        if lang_code == 'rmc-script_latin':
+
+            _t = re.sub(r'\d+', number_to_phonemes, _t)
+            _t = fix_phones(_t)
+
+        elif lang_code == 'ron':
+
+            _t = _t.replace("ţ", "ț"
+                            ).replace('ț', 'ts').replace('î', 'u')
+
+        # /data/dkounadis/.hf7/hub/models--facebook--mms-tts/snapshots/44cc7fb408064ef9ea6e7c59130d88cac1274671/models/rmc-script_latin/vocab.txt
 
         _t = text_mapper.filter_oov(_t, lang=lang)
 
-
+        print(f'{speed=}\n\n\n\n_______________________________ {_t}')
         stn_tst = text_mapper.get_text(_t, hps)
         with torch.no_grad():
             x_tst = stn_tst.unsqueeze(0).to(device)
@@ -468,14 +513,3 @@ def foreign(text=None,  # list of text
         original_rate=16000,
         target_rate=24000)[0, :]  # reshapes (64,) -> (1,64)
     return x
-
-
-
-# LANG = 'eng'
-# _t = 'Converts a string of text to a sequence of IDs corresponding to the symbols in the text. Args: text: string to convert to a sequence'
-
-# x = synthesize(text=_t, lang=LANG, speed=1.14)
-# audiofile.write('_r.wav', x, 16000)  # mms-tts = 16,000
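The "numerals" change itself is the new rmc-script_latin path above: digit runs are expanded with num2words before the MMS vocabulary filter, so numbers are spoken instead of being dropped as out-of-vocabulary characters. A short usage sketch of that path, mirroring the functions added in the diff (num2words with lang='sr' is the Serbian backend used there):

    import re
    from num2words import num2words

    PHONEME_MAP = {'q': 'ku', 'w': 'aou', 'z': 's', 'š': 's', 'th': 'ta', 'v': 'vv', 'ž': 'z'}

    def fix_phones(text):
        for src, target in PHONEME_MAP.items():
            text = text.replace(src, target)
        return text.replace(',', '_ _').replace('.', '_ _')

    def number_to_phonemes(match):
        return fix_phones(num2words(int(match.group()), lang='sr').lower())

    print(re.sub(r'\d+', number_to_phonemes, 'cena je 1250 dinara'))
    # each digit run becomes Serbian words, re-spelled so the rmc-script_latin vocab accepts them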
requirements.txt
CHANGED
@@ -2,7 +2,7 @@ torch
 torchaudio
 numpy
 audiofile
-
+num2words
 cached_path
 einops
 flask