zachzzc's picture
Upload tts playground and serving engine
07f1f64
"""Projector that maps hidden states from the LLM component to multimodal logits."""
import torch
from torch import nn
from dataclasses import dataclass
from typing import Optional, Tuple
from .common import HiggsAudioPreTrainedModel
from .configuration_higgs_audio import HiggsAudioConfig
@dataclass
class HiggsAudioDecoderLayerOutput:
    """Container bundling text logits, audio logits and optional attention / cache outputs.

    NOTE(review): ``HiggsAudioDecoderProjector.forward`` in this module returns a plain
    tuple, not this dataclass; confirm which callers construct this container.
    """

    # Logits over the text vocabulary.
    logits: torch.FloatTensor
    # Logits for the audio codebooks.
    audio_logits: torch.FloatTensor
    # Optional attention weights (presumably one entry per layer -- TODO confirm); None by default.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional key/value cache; None by default.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
    """Projection layers that map hidden states from the LLM component to text / audio logits.

    One audio head type is implemented:

    - Basic Audio Head:
        Directly map the hidden states to audio logits for all the codebooks with a single
        linear layer (``audio_lm_head``).

    ``forward`` also has an optional path that runs extra decoder layers before the audio
    head when ``config.audio_decoder_proj_num_layers > 0``. That path reads
    ``self.rotary_emb``, ``self.transformer_layers`` and ``self.norm``, which this class's
    ``__init__`` does not create; they must be provided elsewhere (e.g. by a subclass),
    otherwise the attribute lookup fails.
    """

    def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None):
        """Create the text and audio output heads.

        Args:
            config: Model configuration. Reads ``text_config.hidden_size``,
                ``text_config.vocab_size``, ``audio_num_codebooks`` and
                ``audio_codebook_size``.
            layer_idx: Accepted for interface compatibility; not used by this class.
        """
        super().__init__(config)
        hidden_size = config.text_config.hidden_size
        self.text_lm_head = nn.Linear(hidden_size, config.text_config.vocab_size, bias=False)
        # Each codebook gets `audio_codebook_size + 2` output slots. The two extra slots are
        # presumably special tokens (e.g. stream begin/end) -- TODO confirm with the tokenizer.
        self.audio_lm_head = nn.Linear(
            hidden_size,
            config.audio_num_codebooks * (config.audio_codebook_size + 2),
            bias=False,
        )
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        hidden_states,
        audio_out_mask,
        label_audio_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        output_audio_hidden_states=False,
        cache_position=None,
    ):
        """Compute text logits for every position and audio logits for audio-out positions.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
                Hidden states from the LLM component.
            audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Mask selecting the audio-out tokens. Used for boolean-mask indexing, so it is
                expected to be a bool tensor.
            label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`):
                Label tokens for the audio-out part (intended for an RQ-Transformer head).
                Not used by this implementation.
            attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Mask to avoid performing attention on padding token indices. Only used by
                the extra decoder layers.
            position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Position ids for the input tokens. Only used by the extra decoder layers.
            past_key_values, use_cache, output_attentions, cache_position:
                Forwarded to the extra decoder layers; ignored when
                ``config.audio_decoder_proj_num_layers == 0``.
            output_hidden_states: If truthy, return the per-layer hidden states of the
                extra decoder layers (an empty tuple when there are none).
            output_audio_hidden_states: If True, also return the hidden states at the
                audio-out positions.

        Returns:
            A 6-tuple ``(logits, audio_logits, all_self_attns, all_hidden_states,
            audio_hidden_states, next_cache)``:

            - logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`): text
              logits, computed from the *input* hidden states at every position.
            - audio_logits (`torch.Tensor` of shape
              `(num_audio_out_tokens, audio_num_codebooks * (audio_codebook_size + 2))`):
              audio logits for the positions selected by ``audio_out_mask``.
            - all_self_attns: tuple of per-layer attentions, or None.
            - all_hidden_states: tuple of hidden states, or None.
            - audio_hidden_states: hidden states at audio-out positions, or None.
            - next_cache: cache returned by the last decoder layer when ``use_cache``,
              else None.
        """
        # Text logits are taken from the incoming hidden states, before any extra layers.
        logits = self.text_lm_head(hidden_states)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input.
        if self.config.audio_decoder_proj_num_layers > 0:
            (
                hidden_states,
                all_hidden_states,
                all_self_attns,
                next_decoder_cache,
            ) = self._run_decoder_layers(
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                use_cache,
                output_attentions,
                output_hidden_states,
                cache_position,
                all_hidden_states,
                all_self_attns,
            )

        next_cache = next_decoder_cache if use_cache else None

        # Gather the audio-out positions once; reused for the logits and the optional output.
        audio_hidden = hidden_states[audio_out_mask]
        audio_logits = self.audio_lm_head(audio_hidden)
        audio_hidden_states = audio_hidden if output_audio_hidden_states else None

        return (
            logits,
            audio_logits,
            all_self_attns,
            all_hidden_states,
            audio_hidden_states,
            next_cache,
        )

    def _run_decoder_layers(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_values,
        use_cache,
        output_attentions,
        output_hidden_states,
        cache_position,
        all_hidden_states,
        all_self_attns,
    ):
        """Run ``self.transformer_layers`` followed by ``self.norm``, collecting optional outputs.

        Returns:
            ``(hidden_states, all_hidden_states, all_self_attns, next_decoder_cache)``, where
            ``hidden_states`` is the normalized output of the last layer and
            ``next_decoder_cache`` is the cache returned by the last layer (None if
            ``use_cache`` is falsy or there are no layers).
        """
        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        next_decoder_cache = None

        for decoder_layer in self.transformer_layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Arguments are passed positionally here, so their order must match the
                # decoder layer's parameter order.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            # Collect per layer so `all_self_attns` has one entry per layer (matching
            # `all_hidden_states`). Doing this after the loop kept only the last layer's
            # attention and hit an unbound `layer_outputs` when there were no layers.
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            if use_cache:
                # The cache follows the attention weights in the layer output when those are
                # requested.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return hidden_states, all_hidden_states, all_self_attns, next_decoder_cache