Update ultravox_config.py
Browse files- ultravox_config.py +1 -6
ultravox_config.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import dataclasses
|
2 |
-
import torch
|
3 |
from enum import Enum
|
4 |
from typing import Any, Dict, List, Optional
|
5 |
|
@@ -116,12 +115,8 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
|
116 |
text_model_lora_config: LoraConfigSimplified | None = None,
|
117 |
audio_model_lora_config: LoraConfigSimplified | None = None,
|
118 |
audio_latency_block_size: int | None = None,
|
119 |
-
torch_type: str | None = None,
|
120 |
**kwargs,
|
121 |
):
|
122 |
-
if isinstance(torch_type, str):
|
123 |
-
torch_type = getattr(torch, torch_type)
|
124 |
-
|
125 |
self.ignore_index = ignore_index
|
126 |
|
127 |
self.audio_model_id = audio_model_id
|
@@ -169,7 +164,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
|
169 |
|
170 |
self.initializer_range = text_config.initializer_range
|
171 |
|
172 |
-
super().__init__(**kwargs
|
173 |
|
174 |
def to_diff_dict(self) -> Dict[str, Any]:
|
175 |
diff_dict = super().to_diff_dict()
|
|
|
1 |
import dataclasses
|
|
|
2 |
from enum import Enum
|
3 |
from typing import Any, Dict, List, Optional
|
4 |
|
|
|
115 |
text_model_lora_config: LoraConfigSimplified | None = None,
|
116 |
audio_model_lora_config: LoraConfigSimplified | None = None,
|
117 |
audio_latency_block_size: int | None = None,
|
|
|
118 |
**kwargs,
|
119 |
):
|
|
|
|
|
|
|
120 |
self.ignore_index = ignore_index
|
121 |
|
122 |
self.audio_model_id = audio_model_id
|
|
|
164 |
|
165 |
self.initializer_range = text_config.initializer_range
|
166 |
|
167 |
+
super().__init__(**kwargs)
|
168 |
|
169 |
def to_diff_dict(self) -> Dict[str, Any]:
|
170 |
diff_dict = super().to_diff_dict()
|