AndreasXi commited on
Commit
1bb9319
·
1 Parent(s): 0ff9928

fix model variants

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. meanaudio/model/networks.py +2 -2
app.py CHANGED
@@ -132,7 +132,7 @@ def generate_audio_gradio(
132
  current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
133
  filename = f"{safe_prompt}_{current_time_string}.flac"
134
  save_path = OUTPUT_DIR / filename
135
- torchaudio.save(str(save_path), audio, temp_seq_cfg.sampling_rate)
136
  log.info(f"Audio saved to {save_path}")
137
 
138
  gc.collect()
 
132
  current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
133
  filename = f"{safe_prompt}_{current_time_string}.flac"
134
  save_path = OUTPUT_DIR / filename
135
+ torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
136
  log.info(f"Audio saved to {save_path}")
137
 
138
  gc.collect()
meanaudio/model/networks.py CHANGED
@@ -600,9 +600,9 @@ def meanaudio_s(**kwargs) -> MeanAudio:
600
 
601
 
602
  def get_mean_audio(name: str, **kwargs) -> MeanAudio:
603
- if name == 'meanaudio_s':
604
  return meanaudio_s(**kwargs)
605
- if name == 'fluxaudio_s':
606
  return fluxaudio_s(**kwargs)
607
 
608
  raise ValueError(f'Unknown model name: {name}')
 
600
 
601
 
602
  def get_mean_audio(name: str, **kwargs) -> MeanAudio:
603
+ if name == 'meanaudio_s_ac' or 'meanaudio_s_full':
604
  return meanaudio_s(**kwargs)
605
+ if name == 'fluxaudio_s_full':
606
  return fluxaudio_s(**kwargs)
607
 
608
  raise ValueError(f'Unknown model name: {name}')