Spaces:
Running
on
Zero
Running
on
Zero
from torch import nn | |
from transformers.modeling_utils import PreTrainedModel | |
from .configuration_higgs_audio import HiggsAudioConfig | |
class HiggsAudioPreTrainedModel(PreTrainedModel):
    """Shared base class for Higgs Audio models.

    Declares the class-level settings below and provides the weight
    initialization routine used for linear, 1-D convolution and embedding
    layers.
    """

    # Class-level settings; presumably consumed by the `PreTrainedModel`
    # machinery (loading, device placement, attention backend selection) --
    # confirm against the installed transformers version.
    config_class = HiggsAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the parameters of a single sub-module in place.

        Weights of `nn.Linear`, `nn.Conv1d` and `nn.Embedding` modules are
        drawn from a normal distribution with mean 0 and a standard deviation
        taken from the config. Biases (when present) and the embedding row at
        `padding_idx` (when set) are zeroed. Any other module type is left
        untouched.

        Args:
            module: The module whose parameters should be initialized.
        """
        cfg = self.config
        # Use the top-level `init_std` when the config defines one; only
        # otherwise look at the audio encoder sub-config.
        if hasattr(cfg, "init_std"):
            std = cfg.init_std
        else:
            std = cfg.audio_encoder_config.init_std

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if not isinstance(module, nn.Embedding):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        pad_row = module.padding_idx
        if pad_row is not None:
            # Zero the padding row after the random initialization.
            module.weight.data[pad_row].zero_()