Update audiocraft/models/loaders.py
Browse files
audiocraft/models/loaders.py
CHANGED
@@ -100,14 +100,19 @@ def _delete_param(cfg: DictConfig, full_name: str):
|
|
100 |
OmegaConf.set_struct(cfg, True)
|
101 |
|
102 |
|
103 |
-
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
104 |
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
105 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
106 |
cfg.device = str(device)
|
107 |
if cfg.device == 'cpu':
|
108 |
cfg.dtype = 'float32'
|
109 |
else:
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
111 |
_delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
|
112 |
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
113 |
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|
|
|
100 |
OmegaConf.set_struct(cfg, True)
|
101 |
|
102 |
|
103 |
+
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', depth='float32', cache_dir: tp.Optional[str] = None):
|
104 |
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
105 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
106 |
cfg.device = str(device)
|
107 |
if cfg.device == 'cpu':
|
108 |
cfg.dtype = 'float32'
|
109 |
else:
|
110 |
+
if depth=='float32':
|
111 |
+
cfg.dtype = 'float32'
|
112 |
+
if depth=='bfloat16':
|
113 |
+
cfg.dtype = 'bfloat16'
|
114 |
+
if depth=='float16':
|
115 |
+
cfg.dtype = 'float16'
|
116 |
_delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
|
117 |
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
118 |
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|