farzadab commited on
Commit
e91d05d
·
verified ·
1 Parent(s): 2e7bce3

Update ultravox_config.py

Browse files
Files changed (1) hide show
  1. ultravox_config.py +1 -6
ultravox_config.py CHANGED
@@ -1,5 +1,4 @@
1
  import dataclasses
2
- import torch
3
  from enum import Enum
4
  from typing import Any, Dict, List, Optional
5
 
@@ -116,12 +115,8 @@ class UltravoxConfig(transformers.PretrainedConfig):
116
  text_model_lora_config: LoraConfigSimplified | None = None,
117
  audio_model_lora_config: LoraConfigSimplified | None = None,
118
  audio_latency_block_size: int | None = None,
119
- torch_type: str | None = None,
120
  **kwargs,
121
  ):
122
- if isinstance(torch_type, str):
123
- torch_type = getattr(torch, torch_type)
124
-
125
  self.ignore_index = ignore_index
126
 
127
  self.audio_model_id = audio_model_id
@@ -169,7 +164,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
169
 
170
  self.initializer_range = text_config.initializer_range
171
 
172
- super().__init__(**kwargs, torch_type=torch_type)
173
 
174
  def to_diff_dict(self) -> Dict[str, Any]:
175
  diff_dict = super().to_diff_dict()
 
1
  import dataclasses
 
2
  from enum import Enum
3
  from typing import Any, Dict, List, Optional
4
 
 
115
  text_model_lora_config: LoraConfigSimplified | None = None,
116
  audio_model_lora_config: LoraConfigSimplified | None = None,
117
  audio_latency_block_size: int | None = None,
 
118
  **kwargs,
119
  ):
 
 
 
120
  self.ignore_index = ignore_index
121
 
122
  self.audio_model_id = audio_model_id
 
164
 
165
  self.initializer_range = text_config.initializer_range
166
 
167
+ super().__init__(**kwargs)
168
 
169
  def to_diff_dict(self) -> Dict[str, Any]:
170
  diff_dict = super().to_diff_dict()