junxiliu commited on
Commit
ec164a8
·
1 Parent(s): ce52318
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -1,6 +1,6 @@
1
 
2
  import warnings
3
- import spaces#space
4
  warnings.filterwarnings("ignore", category=FutureWarning)
5
  import logging
6
  from argparse import ArgumentParser
@@ -27,6 +27,10 @@ from datetime import datetime
27
 
28
  log = logging.getLogger()
29
 
 
 
 
 
30
  setup_eval_logging()
31
 
32
  OUTPUT_DIR = Path("./output/gradio")
@@ -335,9 +339,7 @@ with gr.Blocks(title="MeanAudio Generator", theme=theme, css=custom_css) as demo
335
  list(all_model_cfg.keys()) if all_model_cfg else []
336
  )
337
  default_variant = (
338
- "small_16k_mf"
339
- if "small_16k_mf" in available_variants
340
- else available_variants[0] if available_variants else ""
341
  )
342
  variant = gr.Dropdown(
343
  label="Model Variant",
@@ -406,5 +408,8 @@ with gr.Blocks(title="MeanAudio Generator", theme=theme, css=custom_css) as demo
406
  )
407
 
408
  if __name__ == "__main__":
409
- demo.launch()
 
 
 
410
 
 
1
 
2
  import warnings
3
+ import spaces
4
  warnings.filterwarnings("ignore", category=FutureWarning)
5
  import logging
6
  from argparse import ArgumentParser
 
27
 
28
  log = logging.getLogger()
29
 
30
+ device = "cpu"
31
+ if torch.cuda.is_available():
32
+ device = "cuda"
33
+
34
  setup_eval_logging()
35
 
36
  OUTPUT_DIR = Path("./output/gradio")
 
339
  list(all_model_cfg.keys()) if all_model_cfg else []
340
  )
341
  default_variant = (
342
+ 'meanaudio_mf'
 
 
343
  )
344
  variant = gr.Dropdown(
345
  label="Model Variant",
 
408
  )
409
 
410
  if __name__ == "__main__":
411
+ parser = ArgumentParser()
412
+ parser.add_argument("--port", type=int, default=7861)
413
+ args = parser.parse_args()
414
+ demo.launch(server_port=args.port, allowed_paths=[OUTPUT_DIR.resolve()])
415