junxiliu committed on
Commit
0ad0454
·
1 Parent(s): 3a1da90

add files used by space

Browse files
Files changed (4) hide show
  1. MeanAudio.py +147 -0
  2. app.py +421 -0
  3. easyinfer.py +3 -0
  4. requirements.txt +27 -0
MeanAudio.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore", category=FutureWarning)
3
+ import logging
4
+ from pathlib import Path
5
+ import torch
6
+ import torchaudio
7
+ from meanaudio.eval_utils import (ModelConfig, all_model_cfg, generate_mf, generate_fm, setup_eval_logging)
8
+ from meanaudio.model.flow_matching import FlowMatching
9
+ from meanaudio.model.mean_flow import MeanFlow
10
+ from meanaudio.model.networks import MeanAudio, get_mean_audio
11
+ from meanaudio.model.utils.features_utils import FeaturesUtils
12
+ from huggingface_hub import snapshot_download
13
+
14
# Allow TF32 math for CUDA matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Root logger (no name passed); shared by everything in this module.
log = logging.getLogger()
17
+
18
def _prompt_to_filename_stem(prompt, max_len=100):
    '''Turn a free-text prompt into a file-name stem.

    Spaces and '/' become '_' and '.' is dropped (the mapping this module has
    always used). In addition, characters that are invalid on common
    filesystems and ASCII control characters are replaced with '_', and the
    result is truncated to ``max_len`` characters so that very long prompts
    are less likely to exceed file-name length limits.
    '''
    stem = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
    stem = ''.join('_' if ch in '\\:*?"<>|' or ord(ch) < 32 else ch for ch in stem)
    return stem[:max_len]


@torch.inference_mode()
def MeanAudioInference(
    prompt='',
    negative_prompt='',
    model_path='',
    encoder_name='t5_clap',
    variant='meanaudio_mf',
    duration=10,
    cfg_strength=4.5,
    num_steps=1,
    output='./output',
    seed=42,
    full_precision=False,
    use_rope=True,
    text_c_dim=512,
    use_meanflow=False
):
    '''Generate one audio clip from a text prompt and save it as a WAV file.

    Args:
        prompt (str):
            The text description guiding the audio generation (e.g., "a dog is barking").
        negative_prompt (str):
            A text description for sounds that should be avoided in the generated audio.
        model_path (str):
            Path to the model weights file. If empty, it defaults to
            ./weights/{variant}.pth; when that default file is missing, the
            weights are downloaded from the Hugging Face Hub into ./weights.
        encoder_name (str):
            Specifies the text encoder to use (default: 't5_clap').
        variant (str):
            Specifies the model variant to load (default: 'meanaudio_mf'). Must be a key in all_model_cfg.
        duration (int):
            The desired duration of the generated audio in seconds (default: 10).
        cfg_strength (float):
            Classifier-Free Guidance strength (default: 4.5). Overridden with 0
            whenever the MeanFlow sampler is used.
        num_steps (int):
            Number of steps for the generation process (default: 1).
        output (str):
            Directory path where the generated audio file will be saved (default: './output').
        seed (int):
            Random seed for generation reproducibility (default: 42).
        full_precision (bool):
            If True, uses torch.float32 precision; otherwise, uses torch.bfloat16 (default: False).
        use_rope (bool):
            Forwarded to get_mean_audio as ``use_rope`` (default: True).
        text_c_dim (int):
            Forwarded to get_mean_audio as ``text_c_dim`` (default: 512).
        use_meanflow (bool):
            If True, uses the MeanFlow generation method; otherwise, uses FlowMatching.
            If variant is 'meanaudio_mf', this is automatically set to True (default: False).

    Returns:
        pathlib.Path: location of the saved ``.wav`` file inside ``output``.

    Raises:
        ValueError: if ``duration`` or ``num_steps`` is not positive, or
            ``variant`` is not a key of all_model_cfg.
        FileNotFoundError: if the weights file is missing and cannot be
            obtained by download.
    '''
    # Validate first so bad arguments fail before any directory is created
    # or any model is loaded.
    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if variant not in all_model_cfg:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")

    setup_eval_logging()
    output_dir = Path(output).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float32 if full_precision else torch.bfloat16

    weights_dir = Path('./weights')
    default_model_path = weights_dir / f'{variant}.pth'
    model_path = Path(model_path) if model_path else default_model_path
    if not model_path.exists():
        # Compare Path objects, not strings: str(Path('./weights/x.pth')) is
        # 'weights/x.pth', so a string comparison against './weights/x.pth'
        # never matches and the download below would be unreachable.
        if model_path != default_model_path:
            raise FileNotFoundError(f"Model file not found: {model_path}")
        log.info(f'Model not found at {model_path}')
        log.info('Downloading models to "./weights/"...')
        try:
            weights_dir.mkdir(parents=True, exist_ok=True)
            snapshot_download(
                repo_id="junxiliu/Meanaudio",
                local_dir=str(weights_dir),
                allow_patterns=["*.pt", "*.pth"],
            )
        except Exception as e:
            log.error(f"Failed to download model: {e}")
            raise FileNotFoundError(
                f"Model file not found and download failed: {model_path}"
            ) from e
        # The snapshot may not contain a file named after this variant.
        if not model_path.exists():
            raise FileNotFoundError(f"Model file not found and download failed: {model_path}")

    model = all_model_cfg[variant]
    seq_cfg = model.seq_cfg
    # NOTE: seq_cfg is the object stored in all_model_cfg, so this assignment
    # is visible to any later user of the same variant entry.
    seq_cfg.duration = duration

    net = get_mean_audio(model.model_name, use_rope=use_rope, text_c_dim=text_c_dim)
    net = net.to(device, dtype).eval()
    net.load_weights(torch.load(model_path, map_location=device, weights_only=True))
    net.update_seq_lengths(seq_cfg.latent_seq_len)

    if variant == 'meanaudio_mf':
        use_meanflow = True
    if use_meanflow:
        generation_func = MeanFlow(steps=num_steps)
        # NOTE(review): MeanFlow runs with cfg_strength 0 here, while the
        # Gradio app added in the same commit forces 3 -- confirm which is intended.
        cfg_strength = 0
    else:
        generation_func = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    feature_utils = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        enable_conditions=True,
        encoder_name=encoder_name,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False
    )
    feature_utils = feature_utils.to(device, dtype).eval()

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    generate_fn = generate_mf if use_meanflow else generate_fm
    kwargs = {
        'negative_text': [negative_prompt],
        'feature_utils': feature_utils,
        'net': net,
        'rng': rng,
        'cfg_strength': cfg_strength,
        # The sampler is passed under a different keyword per generator.
        ('mf' if use_meanflow else 'fm'): generation_func,
    }

    audios = generate_fn([prompt], **kwargs)
    audio = audios.float().cpu()[0]
    safe_filename = _prompt_to_filename_stem(prompt)
    save_path = output_dir / f'{safe_filename}--numsteps{num_steps}--seed{seed}.wav'
    torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
    log.info(f'Audio saved to {save_path}')
    if device == 'cuda':
        # Peak-allocation statistic is only meaningful when CUDA was used.
        log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
    return save_path
145
+
146
# Script entry point: generate a clip for a sample prompt using all defaults.
if __name__ == '__main__':
    MeanAudioInference('a dog is barking')
app.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import warnings
3
+
4
+ warnings.filterwarnings("ignore", category=FutureWarning)
5
+ import logging
6
+ from argparse import ArgumentParser
7
+ from pathlib import Path
8
+ import torch
9
+ import torchaudio
10
+ import gradio as gr
11
+ from meanaudio.eval_utils import (
12
+ ModelConfig,
13
+ all_model_cfg,
14
+ generate_mf,
15
+ generate_fm,
16
+ setup_eval_logging,
17
+ )
18
+ from meanaudio.model.flow_matching import FlowMatching
19
+ from meanaudio.model.mean_flow import MeanFlow
20
+ from meanaudio.model.networks import MeanAudio, get_mean_audio
21
+ from meanaudio.model.utils.features_utils import FeaturesUtils
22
+
23
# Allow TF32 math for CUDA matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import gc
from datetime import datetime

# Root logger (no name passed); shared by everything in this module.
log = logging.getLogger()

# Device selection at import time: CUDA first, then Apple MPS, otherwise CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    log.warning("CUDA/MPS are not available, running on CPU")
# NOTE(review): called after the device probe, so the CPU warning above is
# emitted before this runs -- presumably this configures log handlers; confirm.
setup_eval_logging()

# Directory that receives generated audio files; created on import.
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Cache of the currently loaded model. Filled by load_model_if_needed and read
# by generate_audio_gradio; "args" records the settings the cached model was
# loaded with, and all entries are None while nothing is loaded.
current_model_state = {
    "net": None,
    "feature_utils": None,
    "seq_cfg": None,
    "args": None,
}
48
+
49
+
50
def load_model_if_needed(
    variant, model_path, encoder_name, use_rope, text_c_dim, full_precision
):
    """Load the requested model into ``current_model_state`` unless it is already cached.

    The cache is keyed on all six arguments; a change to any of them triggers
    a reload. The previously cached model is released *before* the new one is
    built so that two models are never held on the device at the same time.

    Args:
        variant: key into ``all_model_cfg``.
        model_path: path of the weights file passed to ``torch.load``.
        encoder_name: forwarded to ``FeaturesUtils`` as ``encoder_name``.
        use_rope: forwarded to ``get_mean_audio`` as ``use_rope``.
        text_c_dim: forwarded to ``get_mean_audio`` as ``text_c_dim``.
        full_precision: load in float32 when True, bfloat16 otherwise.

    Returns:
        True if a model was (re)loaded, False if the cached one was reused.

    Raises:
        ValueError: if ``variant`` is not a key of ``all_model_cfg``.
        Exception: any error raised while building the model or reading the
            weights is logged and re-raised; the cache is left empty.
    """
    from types import SimpleNamespace

    requested = SimpleNamespace(
        variant=variant,
        model_path=model_path,
        encoder_name=encoder_name,
        use_rope=use_rope,
        text_c_dim=text_c_dim,
        full_precision=full_precision,
    )
    # SimpleNamespace equality compares all attributes; None never equals it.
    if current_model_state["args"] == requested:
        log.info(f"Model '{variant}' already loaded with current settings.")
        return False

    # Drop the old model first (the dict is mutated in place) so its memory
    # can be reclaimed before the replacement is allocated.
    current_model_state.update(net=None, feature_utils=None, seq_cfg=None, args=None)
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    dtype = torch.float32 if full_precision else torch.bfloat16
    try:
        if variant not in all_model_cfg:
            raise ValueError(f"Unknown model variant: {variant}")
        model = all_model_cfg[variant]

        net = (
            get_mean_audio(model.model_name, use_rope=use_rope, text_c_dim=text_c_dim)
            .to(device, dtype)
            .eval()
        )
        net.load_weights(
            torch.load(model_path, map_location=device, weights_only=True)
        )
        log.info(f"Loaded weights from {model_path}")

        feature_utils = FeaturesUtils(
            tod_vae_ckpt=model.vae_path,
            enable_conditions=True,
            encoder_name=encoder_name,
            mode=model.mode,
            bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
            need_vae_encoder=False,
        )
        feature_utils = feature_utils.to(device, dtype).eval()
    except Exception as e:
        # log.exception keeps the traceback; the bare raise preserves it for callers.
        log.exception(f"Error loading model: {e}")
        current_model_state.update(net=None, feature_utils=None, seq_cfg=None, args=None)
        raise

    current_model_state.update(
        net=net,
        feature_utils=feature_utils,
        seq_cfg=model.seq_cfg,
        args=requested,
    )
    log.info(f"Model '{variant}' loaded successfully.")
    return True
127
+
128
+
129
@torch.inference_mode()
def generate_audio_gradio(
    prompt,
    negative_prompt,
    duration,
    cfg_strength,
    num_steps,
    seed,
    variant,
    full_precision,
):
    """Gradio click handler: generate audio for ``prompt`` and save it as FLAC.

    The variant named "meanaudio_mf" is sampled with MeanFlow (and its CFG
    strength is forced to 3); every other variant is sampled with
    FlowMatching using the ``cfg_strength`` supplied by the caller. A negative
    ``seed`` requests a random seed.

    Returns:
        A ``(status_message, file_path)`` tuple. ``file_path`` is the saved
        file as a string on success and ``None`` on any failure; errors are
        reported through ``status_message`` instead of being raised.
    """
    global current_model_state
    use_meanflow = variant == "meanaudio_mf"
    if use_meanflow:
        weights_file = "./weights/meanaudio_mf.pth"
    else:
        weights_file = "./weights/fluxaudio_fm.pth"

    try:
        load_model_if_needed(
            variant,
            weights_file,
            "t5_clap",  # encoder_name
            True,  # use_rope
            512,  # text_c_dim
            full_precision,
        )
    except Exception as load_error:
        return f"Error loading model: {str(load_error)}", None

    state = current_model_state
    net = state["net"]
    if net is None:
        return "Error: Model could not be loaded.", None
    feature_utils = state["feature_utils"]
    seq_cfg = state["seq_cfg"]

    try:
        # Resize the network for the requested clip length.
        seq_cfg.duration = duration
        net.update_seq_lengths(seq_cfg.latent_seq_len)

        rng = torch.Generator(device=device)
        if seed >= 0:
            rng.manual_seed(seed)
        else:
            rng.seed()

        # Each generator takes its sampler under a different keyword.
        if use_meanflow:
            sampler_kwarg = {"mf": MeanFlow(steps=num_steps)}
            log.info("Using MeanFlow for generation.")
            run_generation = generate_mf
            cfg_strength = 3
        else:
            sampler_kwarg = {
                "fm": FlowMatching(
                    min_sigma=0, inference_mode="euler", num_steps=num_steps
                )
            }
            log.info("Using FlowMatching for generation.")
            run_generation = generate_fm

        batch = run_generation(
            [prompt],
            negative_text=[negative_prompt],
            feature_utils=feature_utils,
            net=net,
            rng=rng,
            cfg_strength=cfg_strength,
            **sampler_kwarg,
        )
        audio = batch.float().cpu()[0]

        # File name: sanitised prompt (max 50 chars) plus a microsecond timestamp.
        kept_chars = [ch for ch in prompt if ch.isalnum() or ch in (" ", "_")]
        safe_prompt = "".join(kept_chars).rstrip().replace(" ", "_")[:50]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        save_path = OUTPUT_DIR / f"{safe_prompt}_{timestamp}.flac"
        torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
        log.info(f"Audio saved to {save_path}")

        gc.collect()

        method_name = "MeanFlow" if use_meanflow else "FlowMatching"
        status = f"Generated audio for prompt: '{prompt}' using {method_name}"
        return status, str(save_path)
    except Exception as e:
        gc.collect()
        log.error(f"Generation error: {e}")
        return f"Error during generation: {str(e)}", None
226
+
227
+
228
# Gradio theme for the app: the built-in Soft theme with a blue/slate palette
# and small text/spacing, followed by .set(...) overrides of individual theme
# variables (values beginning with "*" appear to reference other theme
# variables -- see the Gradio theming guide).
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="slate",
    neutral_hue="slate",
    text_size="sm",
    spacing_size="sm",
).set(
    background_fill_primary="*neutral_50",
    background_fill_secondary="*background_fill_primary",
    block_background_fill="*background_fill_primary",
    block_border_width="0px",
    panel_background_fill="*neutral_50",
    panel_border_width="0px",
    input_background_fill="*neutral_100",
    input_border_color="*neutral_200",
    button_primary_background_fill="*primary_300",
    button_primary_background_fill_hover="*primary_400",
    button_secondary_background_fill="*neutral_200",
    button_secondary_background_fill_hover="*neutral_300",
)
248
+
249
# Extra CSS passed to gr.Blocks(css=...). The id selectors correspond to the
# elem_id values and ".setting-section" to the elem_classes value used in the
# layout below.
custom_css = """
#main-header {
text-align: center;
margin-top: 5px;
margin-bottom: 10px;
color: var(--neutral-600);
font-weight: 600;
}
#model-settings-header, #generation-settings-header {
color: var(--neutral-600);
margin-top: 8px;
margin-bottom: 8px;
font-weight: 500;
font-size: 1.1em;
}
.setting-section {
padding: 10px 12px;
border-radius: 6px;
background-color: var(--neutral-50);
margin-bottom: 10px;
border: 1px solid var(--neutral-100);
}
hr {
border: none;
height: 1px;
background-color: var(--neutral-200);
margin: 8px 0;
}
#generate-btn {
width: 100%;
max-width: 250px;
margin: 10px auto;
display: block;
padding: 10px 15px;
font-size: 16px;
border-radius: 5px;
}
#status-box {
min-height: 50px;
display: flex;
align-items: center;
justify-content: center;
padding: 8px;
border-radius: 5px;
border: 1px solid var(--neutral-200);
color: var(--neutral-700);
}
#audio-output {
height: 100px;
border-radius: 5px;
border: 1px solid var(--neutral-200);
}
.gradio-dropdown label, .gradio-checkbox label, .gradio-number label, .gradio-textbox label {
font-weight: 500;
color: var(--neutral-700);
font-size: 0.9em;
}
.gradio-row {
gap: 8px;
}
.gradio-block {
margin-bottom: 8px;
}
.setting-section .gradio-block {
margin-bottom: 6px;
}
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: var(--neutral-100);
border-radius: 4px;
}
::-webkit-scrollbar-thumb {
background: var(--neutral-300);
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background: var(--neutral-400);
}
* {
scrollbar-width: thin;
scrollbar-color: var(--neutral-300) var(--neutral-100);
}
"""
335
+
336
# Gradio UI definition: a settings section (model variant + precision), a
# generation section (prompts, duration, CFG, seed, steps), and a Generate
# button wired to generate_audio_gradio.
with gr.Blocks(title="MeanAudio Generator", theme=theme, css=custom_css) as demo:
    gr.Markdown("# MeanAudio Text-to-Audio Generator", elem_id="main-header")

    gr.Markdown("### Model and Generation Settings", elem_id="model-settings-header")
    with gr.Column(elem_classes="setting-section"):
        with gr.Row():
            # Dropdown choices are the keys registered in all_model_cfg.
            available_variants = (
                list(all_model_cfg.keys()) if all_model_cfg else []
            )
            # Preselect "small_16k_mf" when it is registered, otherwise the
            # first registered variant, otherwise an empty value.
            # NOTE(review): generate_audio_gradio treats only the variant named
            # "meanaudio_mf" as MeanFlow -- confirm "small_16k_mf" is a real key.
            default_variant = (
                "small_16k_mf"
                if "small_16k_mf" in available_variants
                else available_variants[0] if available_variants else ""
            )
            variant = gr.Dropdown(
                label="Model Variant",
                choices=available_variants,
                value=default_variant,
                interactive=True,
                scale=3,
            )
            full_precision = gr.Checkbox(
                label="Full Precision (float32)", value=True, scale=1
            )

    gr.Markdown("### Audio Generation", elem_id="generation-settings-header")
    with gr.Column(elem_classes="setting-section"):
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the sound you want to generate...",
                scale=1,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="Describe sounds you want to avoid...",
                value="",
                scale=1,
            )
        with gr.Row():
            duration = gr.Number(
                label="Duration (sec)", value=10.0, minimum=0.1, scale=1
            )
            # The handler overrides this value with 3 for the MeanFlow variant.
            cfg_strength = gr.Number(
                label="CFG (Meanflow forced to 3)", value=3, minimum=0.0, scale=1
            )
        with gr.Row():
            # A negative seed makes the handler draw a random seed.
            seed = gr.Number(
                label="Seed (-1 for random)", value=42, precision=0, scale=1
            )
            num_steps = gr.Number(
                label="Number of Steps",
                value=1,
                precision=0,
                minimum=1,
                scale=1,
            )

    generate_button = gr.Button("Generate", variant="primary", elem_id="generate-btn")
    generate_output_text = gr.Textbox(
        label="Result Status", interactive=False, elem_id="status-box"
    )
    audio_output = gr.Audio(
        label="Generated Audio", type="filepath", elem_id="audio-output"
    )
    # The inputs list is in the same order as generate_audio_gradio's
    # parameters; the handler's (status, path) return maps onto the outputs.
    generate_button.click(
        fn=generate_audio_gradio,
        inputs=[
            prompt,
            negative_prompt,
            duration,
            cfg_strength,
            num_steps,
            seed,
            variant,
            full_precision,
        ],
        outputs=[generate_output_text, audio_output],
    )
415
+
416
if __name__ == "__main__":
    # Serve the UI on the port given by --port (default 7861) and let Gradio
    # serve files from the generated-audio directory.
    cli = ArgumentParser()
    cli.add_argument("--port", type=int, default=7861)
    cli_args = cli.parse_args()
    demo.launch(
        server_port=cli_args.port,
        allowed_paths=[OUTPUT_DIR.resolve()],
    )
421
+
easyinfer.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Minimal usage example: generate audio for one prompt with the default
# settings and print the path returned by MeanAudioInference.
from MeanAudio import MeanAudioInference
audio_path=MeanAudioInference('a dog is barking')
print(audio_path)
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.5.1
2
+ huggingface_hub>=0.26
3
+ cython
4
+ gitpython>=3.1
5
+ tensorboard>=2.11
6
+ numpy>=1.21,<2.1
7
+ Pillow>=9.5
8
+ opencv-python>=4.8
9
+ scipy>=1.7
10
+ tqdm>=4.66.1
11
+ gradio>=3.34
12
+ einops>=0.6
13
+ hydra-core>=1.3.2
14
+ requests
15
+ torchdiffeq>=0.2.5
16
+ librosa>=0.8.1
17
+ nitrous-ema
18
+ hydra_colorlog
19
+ tensordict>=0.6.1
20
+ colorlog
21
+ open_clip_torch>=2.29.0
22
+ av>=14.0.1
23
+ timm>=1.0.12
24
+ python-dotenv
25
+ transformers
26
+ debugpy
27
+ laion-clap