# Spaces: Running on Zero
import contextlib
from contextlib import contextmanager
from functools import wraps

import torch
from transformers.integrations import is_deepspeed_available

if is_deepspeed_available():
    from deepspeed.utils import groups as deepspeed_groups
    from deepspeed.sequence.layer import _SeqAllToAll
else:
    deepspeed_groups = None
    _SeqAllToAll = None
def _ceil_to_nearest(n, round_to):
    """Round ``n`` up to the nearest multiple of ``round_to``.

    Works on Python ints as well as integer tensors (``merge_input_ids_with_audio_features``
    passes a 0-d tensor).
    """
    numerator = n + round_to - 1
    whole_chunks = numerator // round_to
    return whole_chunks * round_to
def count_parameters(model, trainable_only=True):
    """Return the total number of scalar elements across ``model.parameters()``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        trainable_only: When True (the default), only parameters with ``requires_grad`` set are counted.
    """
    params = model.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
# TODO(sxjscience) Consider to move the function to audio_processing/utils.py
def build_delay_pattern_mask(
    input_ids: torch.LongTensor,
    bos_token_id: int,
    pad_token_id: int,
):
    """Apply the delay pattern from "Simple and Controllable Music Generation", https://arxiv.org/pdf/2306.05284

    Codebook ``k`` is shifted right by ``k`` steps relative to codebook 0. The ``k`` leading slots of
    codebook ``k`` hold the delay token, and the slots after its shifted sequence ends hold the padding
    token. With 4 codebooks and an audio sequence of length 5, the output has length
    ``seq_len + num_codebooks - 1 = 8``:

    - [ *, *, *, *, *, P, P, P]
    - [ B, *, *, *, *, *, P, P]
    - [ B, B, *, *, *, *, *, P]
    - [ B, B, B, *, *, *, *, *]

    where B is the delay token id, P is the padding token id and `*` is an original audio token.

    When conditioning on a prompt of audio tokens, e.g. in the non-delayed form

    - [a, b]
    - [c, d]
    - [e, f]
    - [g, h]

    the second return value holds the delayed form with `-1` in the trailing slots:

    - [a, b, -1, -1, -1]
    - [B, c, d, -1, -1]
    - [B, B, e, f, -1]
    - [B, B, B, g, h]

    The `-1` entries mark positions to be overwritten by newly generated tokens during auto-regressive
    generation.

    Args:
        input_ids (:obj:`torch.LongTensor`):
            The input ids of the prompt, of shape (bsz, num_codebooks, seq_len).
        bos_token_id (:obj:`int`):
            The id of the special delay token.
        pad_token_id (:obj:`int`):
            The id of the padding token. Should be the same as eos_token_id.

    Returns:
        input_ids (:obj:`torch.LongTensor`):
            The delayed ids, of shape (bsz, num_codebooks, seq_len + num_codebooks - 1), with trailing
            slots set to `pad_token_id`.
        input_ids_with_gen_mask (:obj:`torch.LongTensor`):
            Same layout, but with trailing slots set to -1 to mark tokens that still have to be generated.
    """
    bsz, num_codebooks, seq_len = input_ids.shape
    total_len = seq_len + num_codebooks - 1
    device = input_ids.device
    full_shape = (bsz, num_codebooks, total_len)

    # Column index t and codebook index k, broadcastable against each other.
    step = torch.arange(total_len, device=device).view(1, 1, total_len)
    shift = torch.arange(num_codebooks, device=device).view(1, num_codebooks, 1)

    # For codebook k: [0, k) is delay, [k, k + seq_len) is audio, [k + seq_len, total_len) is trailing.
    is_delay = (step < shift).expand(full_shape)
    is_trailing = (step >= shift + seq_len).expand(full_shape)
    is_audio = ~(is_delay | is_trailing)

    # Trailing slots keep the -1 "to be generated" marker.
    with_gen_mask = torch.full(full_shape, -1, dtype=torch.long, device=device)
    with_gen_mask[is_delay] = bos_token_id
    # Boolean-mask assignment visits positions in row-major order, matching the flattened input.
    with_gen_mask[is_audio] = input_ids.reshape(-1)

    delayed_ids = with_gen_mask.masked_fill(is_trailing, pad_token_id)
    return delayed_ids, with_gen_mask
def revert_delay_pattern(data):
    """Undo the delay pattern for a single sample.

    Codebook ``k`` is shifted back left by ``k`` steps, dropping its leading delay slots and trailing
    padding slots.

    Args:
        data (:obj:`torch.Tensor`):
            Delayed data of shape (num_codebooks, seq_len + num_codebooks - 1).

    Returns:
        ret (:obj:`torch.Tensor`):
            Recovered data of shape (num_codebooks, seq_len).
    """
    assert len(data.shape) == 2
    num_codebooks, delayed_len = data.shape
    seq_len = delayed_len - num_codebooks + 1
    rows = [data[k, k : k + seq_len] for k in range(num_codebooks)]
    return torch.stack(rows, dim=0)
def merge_input_ids_with_audio_features(
    audio_features_embed,
    audio_features_length,
    audio_in_embed,
    audio_in_ids_start,
    audio_out_embed,
    audio_out_ids_start,
    audio_in_token_idx,
    audio_out_token_idx,
    inputs_embeds,
    input_ids,
    attention_mask,
    label_ids,
    pad_token_id,
    ignore_index=-100,
    round_to=8,
    left_padding=True,
):
    """
    Merge input_ids with audio features into final embeddings.

    Each <|AUDIO|> (audio-in) placeholder token in `input_ids` is expanded in place. Its slots are filled
    with that audio's continuous features from `audio_features_embed` (if given), followed by its discrete
    audio-in code embeddings from `audio_in_embed` (if given). Each <|AUDIO_OUT|> placeholder is expanded
    into its audio-out code embeddings from `audio_out_embed`. All other tokens are copied over unchanged.

    Args:
        audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
            Encoded vectors of all audios in the batch (obtained from the semantic encoder). A tensor whose
            first dimension is 0 is treated the same as None.
        audio_features_length (`torch.LongTensor` of shape `(num_audios,)`):
            The number of valid (non-padded) positions of each audio in `audio_features_embed`
        audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`):
            The embeddings of audio-in tokens of all audios, concatenated along the first dimension. A tensor
            whose first dimension is 0 is treated the same as None.
        audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
            The start offset of each audio's tokens in `audio_in_embed`. Each audio runs up to the next
            start; the last one runs to the end of `audio_in_embed`.
        audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`):
            The embeddings of audio-out tokens of all audios, concatenated along the first dimension. A
            tensor whose first dimension is 0 is treated the same as None.
        audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
            The start offset of each audio's tokens in `audio_out_embed` (same convention as
            `audio_in_ids_start`)
        audio_in_token_idx
            The index of the audio-in token (<|AUDIO|>) in the vocabulary
        audio_out_token_idx
            The index of the audio-out token (<|AUDIO_OUT|>) in the vocabulary
        inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
            Token embeddings before merging with audio embeddings
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Input ids of tokens, possibly containing audio placeholder tokens
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Mask to avoid performing attention on padding token indices.
        label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
            Labels to be re-laid-out to support training. If None, no labels are produced.
        pad_token_id (`int`):
            The index of the pad token in the vocabulary. Used to fill positions of `final_input_ids`
            that no token maps to.
        ignore_index
            The label value written at audio positions and at positions no label maps to (ignored in
            the loss calculation)
        round_to
            The merged sequence length is padded to, and trimmed at, a multiple of this number
        left_padding
            Whether the batch is left padded. If None, it is inferred: True when any row has
            `attention_mask == 0` in its first column.

    Returns:
        A tuple, in this order:
        final_embedding (`(batch_size, new_len, embed_dim)`)
            The final embeddings after merging audio embeddings with text embeddings. `new_len` is a
            multiple of `round_to`.
        final_attention_mask (`(batch_size, new_len)`, same dtype as `attention_mask`)
            The final attention mask after merging audio embeddings with text embeddings. Audio positions
            are always attended.
        final_labels (`(batch_size, new_len)` or None)
            The labels for the text stream. Audio positions hold `ignore_index`. None when `label_ids` is
            None.
        position_ids (`(batch_size, new_len)`)
            Positional ids for the merged data: the running count of attended positions minus one, with
            non-attended positions set to 1.
        final_input_ids (`(batch_size, new_len)`)
            The final input_ids after merging audio embeddings with text embeddings. Expanded audio
            positions hold the corresponding placeholder id.
        final_audio_in_mask (`(batch_size, new_len)`, bool)
            True at every audio-in position (continuous features and discrete codes)
        final_audio_in_discrete_codes_mask (`(batch_size, new_len)`, bool)
            True only at audio-in positions filled from `audio_in_embed`
        final_audio_out_mask (`(batch_size, new_len)`, bool)
            True at audio-out positions

    Explanation:
        Each audio has variable-length embeddings, with lengths specified by
        - audio_features_length
        - audio_in_ids_start
        - audio_out_ids_start
    Task:
        - fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks)
        - fill each <|AUDIO_OUT|> with the audio-out embeddings
    Example:
        <|AUDIO_OUT|>: X (5 tokens), Y (3 tokens)
        <|AUDIO|>: Z (8 tokens)
        X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding).
        if right padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                o p q r Z s t u v _ _ _ _ _ _
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
            ]
        elif left padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                _ _ _ _ _ _ o p q r Z s t u v
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
            ]
    """
    # Labels are only re-laid-out when they were provided.
    if label_ids is None:
        skip_labels = True
    else:
        skip_labels = False
    # Treat empty audio tensors (zero first dimension) the same as missing ones.
    if audio_features_embed is not None and audio_features_embed.shape[0] == 0:
        audio_features_embed = None
    if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
        audio_in_embed = None
    if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
        audio_out_embed = None
    batch_size, sequence_length, embed_dim = inputs_embeds.shape
    target_device = inputs_embeds.device
    # Infer the padding side: treat the batch as left padded if any row starts with a masked-out position.
    if left_padding is None:
        left_padding = torch.any(attention_mask[:, 0] == 0)
    audio_in_token_mask = input_ids == audio_in_token_idx
    audio_out_token_mask = input_ids == audio_out_token_idx
    text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)
    # 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]).
    # Text tokens keep a single slot. A placeholder gets one slot per position of the audio it stands for.
    token_placeholder_num = torch.ones_like(input_ids)
    if audio_features_embed is not None:
        num_audios, max_audio_tokens, _ = audio_features_embed.shape
        # (num_audios, max_audio_tokens) mask of the valid, non-padded feature positions.
        audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
            audio_features_length.device
        ) < audio_features_length.unsqueeze(1)
        # Valid features of all audios, flattened in (audio, position) order.
        masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim)
        # NOTE(review): assumes one <|AUDIO|> token per entry of audio_features_length, in the same order.
        token_placeholder_num[audio_in_token_mask] = audio_features_length.long()
    if audio_in_embed is not None:
        # Per-audio code length is the distance to the next start offset; the last audio runs to the end.
        audio_in_codes_length = torch.concat(
            [
                audio_in_ids_start[1:] - audio_in_ids_start[:-1],
                torch.tensor(
                    [audio_in_embed.shape[0] - audio_in_ids_start[-1]],
                    device=audio_in_ids_start.device,
                    dtype=torch.long,
                ),
            ],
            dim=0,
        )
        # Discrete codes are appended to the continuous features of the same placeholder, if both exist.
        if audio_features_embed is not None:
            token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long()
        else:
            token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long()
    if audio_out_embed is not None:
        # Same start-offset convention as for audio-in codes.
        audio_out_codes_length = torch.concat(
            [
                audio_out_ids_start[1:] - audio_out_ids_start[:-1],
                torch.tensor(
                    [audio_out_embed.shape[0] - audio_out_ids_start[-1]],
                    device=audio_out_ids_start.device,
                    dtype=torch.long,
                ),
            ],
            dim=0,
        )
        token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long()
    # Index of the LAST merged slot occupied by each original token.
    new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
    # Longest merged row, rounded up to a multiple of round_to.
    max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to)
    # Number of unused slots at the end of each merged row.
    nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_audio_pad[:, None]  # offset for left padding
    # 2. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        (batch_size, max_token_num, embed_dim),
        dtype=inputs_embeds.dtype,
        device=inputs_embeds.device,
    )
    final_attention_mask = torch.zeros(
        (batch_size, max_token_num),
        dtype=attention_mask.dtype,
        device=inputs_embeds.device,
    )
    final_input_ids = torch.full(
        (batch_size, max_token_num),
        pad_token_id,
        dtype=input_ids.dtype,
        device=inputs_embeds.device,
    )
    if skip_labels:
        final_labels = None
    else:
        final_labels = torch.full(
            (batch_size, max_token_num),
            ignore_index,
            dtype=label_ids.dtype,
            device=inputs_embeds.device,
        )
    final_audio_in_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    final_audio_in_discrete_codes_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    final_audio_out_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    # 3. Get the audio-in token positions and audio-out token positions
    # Batch row of every original position, used to map each placeholder back to its row.
    batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length)
    audio_in_batch_id = batch_id[audio_in_token_mask]  # Shape (num_audio_in,)
    audio_out_batch_id = batch_id[audio_out_token_mask]  # Shape (num_audio_out,)
    # The last merged slot of each placeholder, i.e. where each audio span ends.
    audio_features_token_ends = new_token_positions[audio_in_token_mask]  # Shape (num_audio_in,)
    audio_out_embed_ends = new_token_positions[audio_out_token_mask]  # Shape (num_audio_out,)
    if audio_in_embed is not None:
        # Fill in the audio-in embeddings
        # Discrete codes occupy the tail of each <|AUDIO|> span: [end - codes_len + 1, end].
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_in_ids_start.shape[0], max_token_num)
        )
        audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1
        # The row index returned by torch.where is the placeholder index. Hits are returned in row-major
        # order, which must match the concatenated order of audio_in_embed (assumed: audios and
        # placeholders appear in the same order).
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_in_embed_token_starts.unsqueeze(1))
            & (seq_indices <= audio_features_token_ends.unsqueeze(1))
        )
        batch_indices = audio_in_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = audio_in_embed
        final_input_ids[batch_indices, col_indices] = audio_in_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_in_mask[batch_indices, col_indices] = True
        final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True
        # Continuous features (if any) end right before the discrete codes.
        audio_features_token_ends = audio_features_token_ends - audio_in_codes_length
    if audio_features_embed is not None:
        # Fill in the audio features
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_features_embed.shape[0], max_token_num)
        )
        audio_features_token_starts = audio_features_token_ends - audio_features_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_features_token_starts.unsqueeze(1))
            & (seq_indices <= audio_features_token_ends.unsqueeze(1))
        )
        batch_indices = audio_in_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = masked_audio_in_features
        final_input_ids[batch_indices, col_indices] = audio_in_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_in_mask[batch_indices, col_indices] = True
    if audio_out_embed is not None:
        # Fill in the audio-out embeddings
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_out_ids_start.shape[0], max_token_num)
        )
        audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_out_embed_token_starts.unsqueeze(1))
            & (seq_indices <= audio_out_embed_ends.unsqueeze(1))
        )
        batch_indices = audio_out_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = audio_out_embed
        final_input_ids[batch_indices, col_indices] = audio_out_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_out_mask[batch_indices, col_indices] = True
    # Fill in the original text embeddings and labels
    # Each non-placeholder token moves to its (possibly shifted) merged position.
    batch_indices, non_audio_indices = torch.where(text_token_mask)
    text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
    if not skip_labels:
        final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices]
    final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
    # Audio positions are always attended.
    final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask
    # Trim the tensor if there are redundant padding tokens
    # (columns that are masked out in every row), keeping the length a multiple of round_to.
    if left_padding:
        first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
        # Round down so the number of dropped leading columns is a multiple of round_to.
        first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
        if first_non_zero_loc > 0:
            final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
            final_embedding = final_embedding[:, first_non_zero_loc:]
            if not skip_labels:
                final_labels = final_labels[:, first_non_zero_loc:]
            final_input_ids = final_input_ids[:, first_non_zero_loc:]
            final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
            final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
    else:
        # We have done right padding, so we need to trim the mask
        last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
        # Round up so the kept length is a multiple of round_to.
        last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
        if last_non_zero_loc < max_token_num:
            final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
            final_embedding = final_embedding[:, :last_non_zero_loc]
            if not skip_labels:
                final_labels = final_labels[:, :last_non_zero_loc]
            final_input_ids = final_input_ids[:, :last_non_zero_loc]
            final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
            final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]
    # Positions count attended slots only; masked-out slots get a dummy position of 1.
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
    return (
        final_embedding,
        final_attention_mask,
        final_labels,
        position_ids,
        final_input_ids,
        final_audio_in_mask,
        final_audio_in_discrete_codes_mask,
        final_audio_out_mask,
    )
def is_deepspeed_ulysses_enabled():
    """Check if sequence parallelism is enabled.

    Returns False when DeepSpeed is unavailable (``deepspeed_groups`` is None). Otherwise returns True
    when DeepSpeed reports a sequence-parallel world size greater than 1.
    """
    # The docstring must be the first statement; after the early return it was a dead string.
    if deepspeed_groups is None:
        return False
    return deepspeed_groups._get_sequence_parallel_world_size() > 1
def support_deepspeed_ulysses(module):
    """Class decorator (e.g. for a PyTorch module class) adding lazily cached sequence-parallel accessors.

    The decorated class gets ``sp_size()``, ``sp_rank()`` and ``sp_group()`` methods and class-level
    ``_sp_size`` / ``_sp_rank`` / ``_sp_group`` attributes set to None. On first use, each accessor
    queries DeepSpeed if Ulysses is enabled and caches the answer on the instance. Without Ulysses,
    ``sp_size()`` is 1, ``sp_rank()`` is 0 and ``sp_group()`` is None. Returns the same class.
    """

    def sp_size(self):
        if self._sp_size is not None:
            return self._sp_size
        self._sp_size = 1
        if is_deepspeed_ulysses_enabled():
            self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
        return self._sp_size

    def sp_rank(self):
        if self._sp_rank is not None:
            return self._sp_rank
        self._sp_rank = 0
        if is_deepspeed_ulysses_enabled():
            self._sp_rank = deepspeed_groups._get_sequence_parallel_rank()
        return self._sp_rank

    def sp_group(self):
        if self._sp_group is not None:
            return self._sp_group
        # The group is only cached once Ulysses is enabled; otherwise None is returned (and re-checked).
        if is_deepspeed_ulysses_enabled():
            self._sp_group = deepspeed_groups._get_sequence_parallel_group()
        return self._sp_group

    for cache_attr in ("_sp_size", "_sp_rank", "_sp_group"):
        setattr(module, cache_attr, None)
    module.sp_size = sp_size
    module.sp_rank = sp_rank
    module.sp_group = sp_group
    return module
def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
    """Perform all-to-all before and after the attention function.

    Decorator factory. When DeepSpeed-Ulysses sequence parallelism is enabled, the first three positional
    arguments of the decorated function (presumably query, key and value; verify at call sites) are
    exchanged with ``_SeqAllToAll``. The exchange scatters along ``head_dim`` and gathers along ``seq_dim``.
    The output is then exchanged back the opposite way. When disabled, the function is called unchanged.

    Args:
        seq_dim (`int`): Dimension holding the sequence length.
        head_dim (`int`): Dimension holding the attention heads.
    """

    def attention_decorator(attn_func=None):
        # Preserve the wrapped function's name/docstring (the file imports `wraps` for this).
        @wraps(attn_func)
        def wrapped(*args, **kwargs):
            # Evaluate once so the pre- and post-exchange always agree and `sp_group` is always bound.
            ulysses_enabled = is_deepspeed_ulysses_enabled()
            batch_dim_idx = 0
            if ulysses_enabled:
                sp_group = deepspeed_groups._get_sequence_parallel_group()
                scatter_idx = head_dim  # Scatter on num_heads dimension
                gather_idx = seq_dim  # Gather on seq_len dimension
                args = list(args)
                for i in range(3):
                    args[i] = _SeqAllToAll.apply(sp_group, args[i], scatter_idx, gather_idx, batch_dim_idx)
                args = tuple(args)
            attn_output = attn_func(*args, **kwargs)
            if ulysses_enabled:
                scatter_idx = seq_dim  # Scatter back on seq_len dimension
                gather_idx = head_dim  # Gather on num_heads dimension
                attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx)
            return attn_output

        return wrapped

    return attention_decorator
def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
    """Slice the corresponding cos and sin chunks for rope.

    Decorator factory. When DeepSpeed-Ulysses sequence parallelism is enabled, positional arguments 2 and 3
    of the decorated function (presumably the cos and sin tables; verify at call sites) are narrowed along
    ``trig_seq_dim``. The kept window is ``[sp_rank * chunk, (sp_rank + 1) * chunk)``, where ``chunk`` is
    the size of positional argument 0 along ``state_seq_dim``. When disabled, the function is called
    unchanged.
    """

    def rope_decorator(rope_func=None):
        # Preserve the wrapped function's name/docstring (the file imports `wraps` for this).
        @wraps(rope_func)
        def wrapped(*args, **kwargs):
            if is_deepspeed_ulysses_enabled():
                sp_rank = deepspeed_groups._get_sequence_parallel_rank()
                args = list(args)
                seq_chunk_size = args[0].size(state_seq_dim)
                chunk_start = sp_rank * seq_chunk_size
                for i in (2, 3):
                    args[i] = torch.narrow(args[i], trig_seq_dim, chunk_start, seq_chunk_size)
                args = tuple(args)
            return rope_func(*args, **kwargs)

        return wrapped

    return rope_decorator
def _gather_tensors(input_, group=None):
    """All-gather ``input_`` from every rank of ``group``, allowing the shape to differ per rank.

    Each rank's shape is exchanged first, so receive buffers of the right size can be allocated.

    Returns:
        A list with one tensor per rank, in rank order. The tensors are NOT concatenated. When the group
        has a single rank, ``input_`` itself (made contiguous) is returned instead of a list.
    """
    input_ = input_.contiguous()
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_

    device = input_.device
    all_shapes = [torch.empty(input_.dim(), dtype=torch.int64, device=device) for _ in range(world_size)]
    local_shape = torch.tensor(input_.size(), dtype=torch.int64, device=device)
    torch.distributed.all_gather(all_shapes, local_shape, group=group)

    gathered = [torch.empty(shape.tolist(), dtype=input_.dtype, device=device) for shape in all_shapes]
    torch.distributed.all_gather(gathered, input_, group=group)
    return gathered
def _scatter_tensors(input_, group=None):
    """Return this rank's entry ``input_[rank]``, or ``input_`` unchanged for a single-rank group."""
    if torch.distributed.get_world_size(group) == 1:
        return input_
    return input_[torch.distributed.get_rank(group)]
class _GatherTensors(torch.autograd.Function):
    """All gather tensors among the ranks.

    The forward pass packs the (possibly differently shaped) per-rank tensors returned by
    ``_gather_tensors`` into a jagged nested tensor. The backward pass returns this rank's entry of the
    gradient (``grad_output[rank]``).
    """

    # autograd.Function methods are static: PyTorch invokes them on the class, with `ctx` passed explicitly.
    @staticmethod
    def symbolic(graph, input_, group):
        return _gather_tensors(input_, group)

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        return torch.nested.as_nested_tensor(_gather_tensors(input_, group), layout=torch.jagged)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; `group` is not differentiable.
        return _scatter_tensors(grad_output, ctx.group), None
def all_gather_tensors(input_, size=None, dim=0, group=None):
    """All-gather ``input_`` across ``group`` (autograd-aware) and concatenate the result along ``dim``.

    Args:
        input_: Local tensor; its shape may differ across ranks.
        size: Optional per-rank split sizes. Each entry is presumably a 1-D tensor, since ``.tolist()`` is
            called on it. Each rank's gathered tensor is split along dim 0 with these sizes. The pieces are
            then interleaved across ranks (piece 0 of every rank, then piece 1, ...) before concatenation.
        dim: Dimension to concatenate along.
        group: Process group; ``input_`` is returned unchanged when it has a single rank.
    """
    if torch.distributed.get_world_size(group) == 1:
        # no sequence parallelism
        return input_
    gathered_tensors = _GatherTensors.apply(input_, group)
    if size:
        per_rank_pieces = [torch.split(tensor, sizes.tolist()) for sizes, tensor in zip(size, gathered_tensors)]
        gathered_tensors = [piece for pieces in zip(*per_rank_pieces) for piece in pieces]
    # NOTE(review): without `size`, gathered_tensors is the nested tensor from _GatherTensors; confirm
    # torch.cat accepts it here.
    return torch.cat(gathered_tensors, dim).contiguous()
def get_sequence_data_parallel_world_size():
    """Return the number of ranks used for sequence-data parallelism (the default, global process group)."""
    return torch.distributed.get_world_size(group=None)
def get_sequence_data_parallel_rank():
    """Return this process's rank for sequence-data parallelism (its rank in the default, global process group)."""
    return torch.distributed.get_rank(group=None)
def get_sequence_data_parallel_group():
    """Return the process group used for sequence-data parallelism: ``torch.distributed.group.WORLD``."""
    world_group = torch.distributed.group.WORLD
    return world_group
if is_deepspeed_available():
    # Route DeepSpeed's sequence-data-parallel lookups to the global-process-group helpers defined above,
    # replacing whatever DeepSpeed's `groups` module provides for these three names.
    deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group
    deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank
    deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size
def _gather_tokens(input_, dim=0, group=None):
    """Gather tensors and concatenate them along a dimension

    Every rank is assumed to contribute a tensor of the same shape. The result is the rank-ordered
    concatenation along ``dim``. With a single-rank group, ``input_`` is returned (made contiguous).
    """
    input_ = input_.contiguous()
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_

    numel = input_.numel()
    flat_buffer = torch.empty(world_size * numel, dtype=input_.dtype, device=input_.device)
    torch.distributed.all_gather_into_tensor(flat_buffer, input_, group=group)

    if dim != 0:
        # The flat buffer is rank-major: cut it into per-rank tensors and re-join them along `dim`.
        per_rank = [flat_buffer.narrow(0, rank * numel, numel).view_as(input_) for rank in range(world_size)]
        return torch.cat(per_rank, dim=dim).contiguous()

    # A rank-major layout already equals concatenation along dim 0, so a reshape is enough.
    stacked_shape = list(input_.size())
    stacked_shape[0] = stacked_shape[0] * world_size
    return flat_buffer.view(stacked_shape)
def _drop_tokens(input_, dim=0, group=None):
    """Divide a tensor among the sequence parallel ranks

    Returns a view of this rank's equal-sized chunk of ``input_`` along ``dim``. Returns ``input_``
    unchanged for a single-rank group. The size along ``dim`` must be divisible by the world size.
    """
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_
    this_rank = torch.distributed.get_rank(group)
    dim_size = input_.shape[dim]
    assert dim_size % world_size == 0, (
        f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})"
    )
    chunk_size = dim_size // world_size
    return input_.narrow(dim, this_rank * chunk_size, chunk_size)
class _DropTokens(torch.autograd.Function):
    """Divide tokens equally among the sequence parallel ranks.

    Forward keeps this rank's chunk along ``dim``. Backward all-gathers the gradient along ``dim`` and
    divides it by ``grad_scale`` (when it is not 1).
    """

    # autograd.Function methods are static: PyTorch invokes them on the class, with `ctx` passed explicitly.
    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            # Out of place: with a single rank `_gather_tokens` can return `grad_output` itself, and the
            # incoming gradient should not be mutated.
            grad_input = grad_input / ctx.grad_scale
        # Only `input_` is differentiable; dim, group and grad_scale get no gradient.
        return grad_input, None, None, None
class _GatherTokens(torch.autograd.Function):
    """Gather tokens among the sequence parallel ranks.

    Forward all-gathers and concatenates along ``dim``. Backward keeps this rank's chunk of the gradient
    and multiplies it by ``grad_scale`` (when it is not 1).
    """

    # autograd.Function methods are static: PyTorch invokes them on the class, with `ctx` passed explicitly.
    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            # Out of place: `_drop_tokens` returns a view of `grad_output` (or `grad_output` itself).
            # An in-place `*=` would mutate the incoming gradient, and it raises when that gradient is an
            # expanded (zero-stride) tensor, e.g. the gradient of `.sum()`.
            grad_input = grad_input * ctx.grad_scale
        # Only `input_` is differentiable; dim, group and grad_scale get no gradient.
        return grad_input, None, None, None
def drop_tokens(input_, dim=0, group=None, grad_scale=1):
    """Keep only this rank's chunk of ``input_`` along ``dim``, differentiably.

    Delegates to ``_DropTokens``, whose backward all-gathers the gradient and divides it by
    ``grad_scale``. Returns ``input_`` unchanged when ``group`` has a single rank.
    """
    if torch.distributed.get_world_size(group) != 1:
        return _DropTokens.apply(input_, dim, group, grad_scale)
    # no sequence parallelism
    return input_
def gather_tokens(input_, dim=0, group=None, grad_scale=1):
    """All-gather ``input_`` along ``dim`` across ``group``, differentiably.

    Delegates to ``_GatherTokens``, whose backward keeps this rank's gradient chunk and multiplies it by
    ``grad_scale``. Returns ``input_`` unchanged when ``group`` has a single rank.
    """
    if torch.distributed.get_world_size(group) != 1:
        return _GatherTokens.apply(input_, dim, group, grad_scale)
    # no sequence parallelism
    return input_
def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
    """
    Slice the inputs into per-rank chunks for sequence (context) parallel training.

    Args:
        sp_size (`int`):
            Sequence parallel size.
        sp_rank (`int`):
            Sequence parallel rank for the current process.
        *args (`torch.Tensor`):
            Tensors to slice; all must have the same size along `dim`, divisible by `sp_size`.
        dim (`int`):
            The dimension to slice

    Returns:
        The `sp_rank`-th of `sp_size` equal chunks of every input along `dim`: a single tensor when one input
        is given, otherwise a tuple. When `sp_size == 1` the inputs are returned untouched (a single tensor,
        or the tuple of inputs).
    """
    if sp_size == 1:
        return args if len(args) != 1 else args[0]

    seq_length = args[0].size(dim)
    for arg in args[1:]:
        assert arg.size(dim) == seq_length, (
            f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}"
        )
    assert seq_length % sp_size == 0, (
        f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})"
    )

    chunk_length = seq_length // sp_size
    chunk_start = sp_rank * chunk_length
    chunks = tuple(torch.narrow(tensor, dim, chunk_start, chunk_length) for tensor in args)
    return chunks[0] if len(chunks) == 1 else chunks
@contextmanager
def disable_deepspeed_ulysses():
    """Disable deepspeed ulysses (sequence parallelism) if it is enabled.

    Context manager. While active, DeepSpeed's sequence-parallel world size getter is replaced by one that
    returns 1, so ``is_deepspeed_ulysses_enabled()`` reports False. The original getter is restored on exit,
    even if the body raises. If Ulysses is not enabled on entry, this is a no-op.

    Example::

        with disable_deepspeed_ulysses():
            ...
    """
    # Without @contextmanager this was a bare generator and could not be used in a `with` statement.
    if not is_deepspeed_ulysses_enabled():
        yield
        return

    original_getter = deepspeed_groups._get_sequence_parallel_world_size

    def _get_sequence_parallel_world_size():
        return 1

    deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size
    try:
        yield
    finally:
        deepspeed_groups._get_sequence_parallel_world_size = original_getter