# Source: higgs_audio_v2/higgs_audio/model/custom_modules.py
# Uploaded by zachzzc in commit 07f1f64 ("Upload tts playground and serving engine").
import torch
import torch.nn as nn
class PartiallyFrozenEmbedding(nn.Module):
    """Wrap an existing `nn.Embedding`, splitting its table into two parts.

    - A frozen embedding for indices ``[0, freeze_until_idx)``.
    - A trainable embedding for indices ``[freeze_until_idx, vocab_size)``.

    Only rows at or beyond ``freeze_until_idx`` can receive gradients.
    Intended to work with both ZeRO-2 and ZeRO-3 sharding.
    """

    def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
        """
        :param original_embedding: An instance of nn.Embedding (the original embedding layer).
        :param freeze_until_idx: Number of leading rows to freeze (exclusive bound): rows
            ``0..freeze_until_idx-1`` are frozen, row ``freeze_until_idx`` onward is trainable.
            Must satisfy ``0 <= freeze_until_idx <= original_embedding.num_embeddings``.
        :raises ValueError: If ``freeze_until_idx`` is outside that range.
        """
        super().__init__()
        vocab_size = original_embedding.num_embeddings
        # Out-of-range values would otherwise surface as an opaque negative-dimension error.
        if not 0 <= freeze_until_idx <= vocab_size:
            raise ValueError(
                f"freeze_until_idx must be in [0, {vocab_size}], got {freeze_until_idx}"
            )
        self.freeze_until_idx = freeze_until_idx
        self.original_vocab_size = vocab_size
        self.embedding_dim = original_embedding.embedding_dim
        weight = original_embedding.weight

        # Split the original embedding into frozen and trainable parts
        self.embedding_frozen = nn.Embedding(
            freeze_until_idx,
            self.embedding_dim,
            dtype=weight.dtype,
            device=weight.device,
        )
        self.embedding_trainable = nn.Embedding(
            vocab_size - freeze_until_idx,
            self.embedding_dim,
            dtype=weight.dtype,
            device=weight.device,
        )

        # Copy weights from the original embedding into the frozen and trainable parts
        with torch.no_grad():
            self.embedding_frozen.weight.copy_(weight[:freeze_until_idx])
            self.embedding_trainable.weight.copy_(weight[freeze_until_idx:])

        # Freeze the frozen embedding
        self.embedding_frozen.weight.requires_grad = False

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Look up embeddings, routing each id to the frozen or the trainable table.

        :param input_ids: Integer tensor of any shape (typically ``[batch_size, seq_len]``)
            with values in ``[0, original_vocab_size)``.
        :return: Tensor of shape ``(*input_ids.shape, embedding_dim)`` with the dtype of the
            embedding weights, on ``input_ids.device``.
        """
        # Masks have the same shape as input_ids.
        mask_frozen = input_ids < self.freeze_until_idx
        mask_trainable = ~mask_frozen

        # Output tensor for embedding results; every position is overwritten below.
        embeddings = torch.zeros(
            *input_ids.shape,
            self.embedding_dim,
            device=input_ids.device,
            dtype=self.embedding_frozen.weight.dtype,
        )

        # NOTE(review): whether each sub-embedding is called depends on the batch contents;
        # confirm this is safe under ZeRO-3 parameter gathering when ranks see different ids.
        if mask_frozen.any():
            embeddings[mask_frozen] = self.embedding_frozen(input_ids[mask_frozen])

        if mask_trainable.any():
            # Shift ids into the local index space of the trainable table.
            local_ids = input_ids[mask_trainable] - self.freeze_until_idx
            embeddings[mask_trainable] = self.embedding_trainable(local_ids)

        return embeddings

    def to_unsplit(self) -> nn.Embedding:
        """Merge the frozen and trainable tables back into one ``nn.Embedding``.

        :return: A new ``nn.Embedding`` with ``original_vocab_size`` rows: the current frozen
            rows followed by the current trainable rows.
        """
        unsplit_embedding = nn.Embedding(
            self.original_vocab_size,
            self.embedding_dim,
            dtype=self.embedding_frozen.weight.dtype,
            device=self.embedding_frozen.weight.device,
        )
        with torch.no_grad():
            unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
            unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)
        return unsplit_embedding
class PartiallyFrozenLinear(nn.Module):
    """A wrapper around nn.Linear that freezes the leading rows of the weight matrix.

    Weight rows (output features) ``[0, freeze_until_idx)`` are frozen, and rows
    ``[freeze_until_idx, out_features)`` stay trainable. If the original layer has a bias,
    the bias is split and frozen the same way.
    """

    def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
        """
        :param original_linear: The original nn.Linear layer (with or without bias).
        :param freeze_until_idx: Number of leading output rows to freeze (exclusive bound).
            Must satisfy ``0 <= freeze_until_idx <= original_linear.out_features``.
        :raises ValueError: If ``freeze_until_idx`` is outside that range.
        """
        super().__init__()
        output_dim = original_linear.out_features
        # Out-of-range values would otherwise surface as an opaque negative-dimension error.
        if not 0 <= freeze_until_idx <= output_dim:
            raise ValueError(
                f"freeze_until_idx must be in [0, {output_dim}], got {freeze_until_idx}"
            )
        self.freeze_until_idx = freeze_until_idx
        self.input_dim = original_linear.in_features
        self.output_dim = output_dim

        weight = original_linear.weight
        bias = original_linear.bias
        has_bias = bias is not None

        # Create frozen and trainable linear layers
        self.linear_frozen = nn.Linear(
            self.input_dim,
            freeze_until_idx,
            bias=has_bias,
            dtype=weight.dtype,
            device=weight.device,
        )
        self.linear_trainable = nn.Linear(
            self.input_dim,
            output_dim - freeze_until_idx,
            bias=has_bias,
            dtype=weight.dtype,
            device=weight.device,
        )

        # Copy weights (and bias, if any) from the original linear layer
        with torch.no_grad():
            self.linear_frozen.weight.copy_(weight[:freeze_until_idx])
            self.linear_trainable.weight.copy_(weight[freeze_until_idx:])
            if has_bias:
                self.linear_frozen.bias.copy_(bias[:freeze_until_idx])
                self.linear_trainable.bias.copy_(bias[freeze_until_idx:])

        # Freeze the frozen linear layer
        self.linear_frozen.weight.requires_grad = False
        if has_bias:
            self.linear_frozen.bias.requires_grad = False

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Apply both halves and concatenate along the last (output feature) dimension.

        :param input_tensor: Tensor whose last dimension is ``input_dim``
            (e.g. ``(bsz, seq_len, hidden_state_dim)``).
        :return: Tensor with the same leading dimensions and last dimension ``output_dim``.
        """
        frozen_output = self.linear_frozen(input_tensor)
        trainable_output = self.linear_trainable(input_tensor)
        return torch.cat((frozen_output, trainable_output), dim=-1)

    def to_unsplit(self) -> nn.Linear:
        """Merge the frozen and trainable layers back into one ``nn.Linear``.

        :return: A new ``nn.Linear(input_dim, output_dim)`` whose weight rows (and bias
            entries, if present) are the frozen part followed by the trainable part.
        """
        has_bias = self.linear_frozen.bias is not None
        unsplit_linear = nn.Linear(
            self.input_dim,
            self.output_dim,
            bias=has_bias,
            dtype=self.linear_frozen.weight.dtype,
            device=self.linear_frozen.weight.device,
        )
        # Copy weights from the frozen and trainable layers into the unsplit linear layer
        with torch.no_grad():
            unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
            unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)
            if has_bias:
                unsplit_linear.bias[: self.freeze_until_idx].copy_(self.linear_frozen.bias)
                unsplit_linear.bias[self.freeze_until_idx :].copy_(self.linear_trainable.bias)
        return unsplit_linear