junxiliu commited on
Commit
2496580
·
1 Parent(s): 508272b

change demo.py

Browse files
Files changed (1) hide show
  1. app.py +3 -14
app.py CHANGED
@@ -1,6 +1,6 @@
1
 
2
  import warnings
3
-
4
  warnings.filterwarnings("ignore", category=FutureWarning)
5
  import logging
6
  from argparse import ArgumentParser
@@ -27,13 +27,6 @@ from datetime import datetime
27
 
28
  log = logging.getLogger()
29
 
30
- device = "cpu"
31
- if torch.cuda.is_available():
32
- device = "cuda"
33
- elif torch.backends.mps.is_available():
34
- device = "mps"
35
- else:
36
- log.warning("CUDA/MPS are not available, running on CPU")
37
  setup_eval_logging()
38
 
39
  OUTPUT_DIR = Path("./output/gradio")
@@ -125,8 +118,7 @@ def load_model_if_needed(
125
  log.info(f"Model '{variant}' already loaded with current settings.")
126
  return False
127
 
128
-
129
- @torch.inference_mode()
130
  def generate_audio_gradio(
131
  prompt,
132
  negative_prompt,
@@ -414,8 +406,5 @@ with gr.Blocks(title="MeanAudio Generator", theme=theme, css=custom_css) as demo
414
  )
415
 
416
  if __name__ == "__main__":
417
- parser = ArgumentParser()
418
- parser.add_argument("--port", type=int, default=7861)
419
- args = parser.parse_args()
420
- demo.launch(server_port=args.port, allowed_paths=[OUTPUT_DIR.resolve()])
421
 
 
1
 
2
  import warnings
3
+ import space
4
  warnings.filterwarnings("ignore", category=FutureWarning)
5
  import logging
6
  from argparse import ArgumentParser
 
27
 
28
  log = logging.getLogger()
29
 
 
 
 
 
 
 
 
30
  setup_eval_logging()
31
 
32
  OUTPUT_DIR = Path("./output/gradio")
 
118
  log.info(f"Model '{variant}' already loaded with current settings.")
119
  return False
120
 
121
+ @spaces.GPU
 
122
  def generate_audio_gradio(
123
  prompt,
124
  negative_prompt,
 
406
  )
407
 
408
  if __name__ == "__main__":
409
+ demo.launch()
 
 
 
410