ford442 commited on
Commit
2de17e7
·
verified ·
1 Parent(s): dd54b0a

Update audiocraft/models/loaders.py

Browse files
Files changed (1) hide show
  1. audiocraft/models/loaders.py +7 -2
audiocraft/models/loaders.py CHANGED
@@ -100,14 +100,19 @@ def _delete_param(cfg: DictConfig, full_name: str):
100
  OmegaConf.set_struct(cfg, True)
101
 
102
 
103
- def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
104
  pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
105
  cfg = OmegaConf.create(pkg['xp.cfg'])
106
  cfg.device = str(device)
107
  if cfg.device == 'cpu':
108
  cfg.dtype = 'float32'
109
  else:
110
- cfg.dtype = 'float16'
 
 
 
 
 
111
  _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
112
  _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
113
  _delete_param(cfg, 'conditioners.args.drop_desc_p')
 
100
  OmegaConf.set_struct(cfg, True)
101
 
102
 
103
+ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', depth='float32', cache_dir: tp.Optional[str] = None):
104
  pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
105
  cfg = OmegaConf.create(pkg['xp.cfg'])
106
  cfg.device = str(device)
107
  if cfg.device == 'cpu':
108
  cfg.dtype = 'float32'
109
  else:
110
+ if depth=='float32':
111
+ cfg.dtype = 'float32'
112
+ if depth=='bfloat16':
113
+ cfg.dtype = 'bfloat16'
114
+ if depth=='float16':
115
+ cfg.dtype = 'float16'
116
  _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
117
  _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
118
  _delete_param(cfg, 'conditioners.args.drop_desc_p')