from torch import nn

from transformers.modeling_utils import PreTrainedModel

from .configuration_higgs_audio import HiggsAudioConfig


class HiggsAudioPreTrainedModel(PreTrainedModel):
    """Base class handling weight initialization and providing the interface for loading pretrained Higgs Audio checkpoints."""

    config_class = HiggsAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        # Fall back to the audio encoder's init_std when the top-level config does not define one.
        std = (
            self.config.init_std
            if hasattr(self.config, "init_std")
            else self.config.audio_encoder_config.init_std
        )
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
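

# --- Illustrative usage sketch (not part of the original module) ---
# Concrete Higgs Audio models are expected to subclass HiggsAudioPreTrainedModel and call
# `self.post_init()` at the end of `__init__`; `PreTrainedModel.post_init()` then applies
# `_init_weights` above to every submodule. The subclass and the layer/config attribute
# names below (e.g. `config.text_config.vocab_size`) are assumptions made only to show
# the wiring, not part of the actual Higgs Audio implementation.
#
# class HiggsAudioDummyModel(HiggsAudioPreTrainedModel):
#     def __init__(self, config: HiggsAudioConfig):
#         super().__init__(config)
#         self.embed_tokens = nn.Embedding(config.text_config.vocab_size, config.text_config.hidden_size)
#         self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
#         # Triggers weight initialization (via _init_weights) and other finalization steps.
#         self.post_init()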