"""Projector that maps hidden states from the LLM component to multimodal logits."""
import torch
from torch import nn
from dataclasses import dataclass
from typing import Optional, Tuple
from .common import HiggsAudioPreTrainedModel
from .configuration_higgs_audio import HiggsAudioConfig


@dataclass
class HiggsAudioDecoderLayerOutput:
    logits: torch.FloatTensor
    audio_logits: torch.FloatTensor
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
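
# NOTE: `HiggsAudioDecoderLayerOutput` appears to be intended for external callers;
# `HiggsAudioDecoderProjector.forward` below returns a plain tuple, not this dataclass.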


class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
    """Projection layers that map hidden states from the LLM component to audio / text logits.

    We support two types of audio heads:

    - Basic Audio Head:
        Directly maps the hidden states to audio logits for all the codebooks.
    """

    def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None):
        super().__init__(config)
        self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
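        # Each codebook head emits (audio_codebook_size + 2) logits; the two extra
        # entries presumably cover special audio-stream tokens (e.g. stream BOS/EOS).
        # This is an inference from the head's output size, not stated in this file.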
        self.audio_lm_head = nn.Linear(
            config.text_config.hidden_size,
            config.audio_num_codebooks * (config.audio_codebook_size + 2),
            bias=False,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        hidden_states,
        audio_out_mask,
        label_audio_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        output_audio_hidden_states=False,
        cache_position=None,
    ):
"""
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
Hidden states from the LLM component
audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
Mask for identifying the audio out tokens.
label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`):
Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used.
attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
Mask to avoid performing attention on padding token indices
position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
Position ids for the input tokens
Returns:
logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`):
Logits for text tokens
audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`):
Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len`
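
        Example (illustrative only; assumes a valid `HiggsAudioConfig` named `config`
        with `audio_decoder_proj_num_layers == 0`, so no extra transformer layers run):

            projector = HiggsAudioDecoderProjector(config)
            hidden_states = torch.randn(2, 16, config.text_config.hidden_size)
            audio_out_mask = torch.zeros(2, 16, dtype=torch.bool)
            audio_out_mask[:, 8:] = True  # last 8 positions of each sample are audio-out
            logits, audio_logits, *_ = projector(hidden_states, audio_out_mask)
            # logits: (2, 16, vocab_size)
            # audio_logits: (16, audio_num_codebooks * (audio_codebook_size + 2))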
"""
        logits = self.text_lm_head(hidden_states)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input.
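        # NOTE: `self.rotary_emb`, `self.transformer_layers`, `self.norm`, and
        # `self.gradient_checkpointing` are referenced below but not defined in this
        # file's `__init__`; they are presumably set up elsewhere, or the branch is
        # skipped entirely when `audio_decoder_proj_num_layers == 0`.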
        if self.config.audio_decoder_proj_num_layers > 0:
            # create position embeddings to be shared across the decoder layers
            position_embeddings = self.rotary_emb(hidden_states, position_ids)
            for decoder_layer in self.transformer_layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)
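
                # Gradient checkpointing trades memory for compute: activations are
                # recomputed during the backward pass instead of being stored.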
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        decoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        position_ids,
                        past_key_values,
                        output_attentions,
                        use_cache,
                        cache_position,
                        position_embeddings,
                    )
                else:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )
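
                # Following the standard HF decoder-layer contract, `layer_outputs[0]`
                # holds the updated hidden states, followed by the attention weights
                # (when `output_attentions`) and the updated KV cache (when `use_cache`).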
                hidden_states = layer_outputs[0]

            hidden_states = self.norm(hidden_states)

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        next_cache = next_decoder_cache if use_cache else None
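
        # Boolean-mask indexing flattens the (batch, seq) dimensions, so `audio_logits`
        # has shape (num_audio_out_tokens, audio_num_codebooks * (audio_codebook_size + 2)).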
        audio_logits = self.audio_lm_head(hidden_states[audio_out_mask])

        if output_audio_hidden_states:
            audio_hidden_states = hidden_states[audio_out_mask]
        else:
            audio_hidden_states = None

        return (
            logits,
            audio_logits,
            all_self_attns,
            all_hidden_states,
            audio_hidden_states,
            next_cache,
        )