update
app.py
CHANGED
@@ -1,528 +0,0 @@
import warnings
import spaces
warnings.filterwarnings("ignore", category=FutureWarning)
import logging
from argparse import ArgumentParser
from pathlib import Path
import torch
import torchaudio
import gradio as gr
from transformers import AutoModel
from meanaudio.eval_utils import (
    ModelConfig,
    all_model_cfg,
    generate_mf,
    generate_fm,
    setup_eval_logging,
)
from meanaudio.model.flow_matching import FlowMatching
from meanaudio.model.mean_flow import MeanFlow
from meanaudio.model.networks import MeanAudio, get_mean_audio
from meanaudio.model.utils.features_utils import FeaturesUtils
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import gc
from datetime import datetime
from huggingface_hub import snapshot_download

log = logging.getLogger()
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
setup_eval_logging()
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Pre-download dependencies ---
snapshot_download(repo_id="google/flan-t5-large")
# snapshot_download(repo_id="google-bert/bert-base-uncased")
a = AutoModel.from_pretrained('bert-base-uncased')  # pre-caches the text encoder weights
b = AutoModel.from_pretrained('roberta-base')
# snapshot_download(repo_id="FacebookAI/roberta-base")
snapshot_download(repo_id="junxiliu/Meanaudio", local_dir="./weights", allow_patterns=["*.pt", "*.pth"])

# --- Model state storage (a dict so several models can stay loaded) ---
# Stores states for multiple models, keyed by variant name
current_model_states = {
    # Example structure:
    # "meanaudio_mf": { "net": ..., "feature_utils": ..., "seq_cfg": ..., "args": ... },
    # "fluxaudio_fm": { "net": ..., "feature_utils": ..., "seq_cfg": ..., "args": ... },
}

# --- Model loading logic (supports multiple models) ---
def load_model_if_needed(variant, model_path, encoder_name, use_rope, text_c_dim):
    global current_model_states
    dtype = torch.float32

    # Check if this specific variant with these args is already loaded
    existing_state = current_model_states.get(variant)
    needs_reload = (
        existing_state is None
        or existing_state["args"].variant != variant
        or existing_state["args"].model_path != model_path
        or existing_state["args"].encoder_name != encoder_name
        or existing_state["args"].use_rope != use_rope
        or existing_state["args"].text_c_dim != text_c_dim
    )

    if needs_reload:
        log.info(f"Loading/reloading model '{variant}'.")
        if variant not in all_model_cfg:
            raise ValueError(f"Unknown model variant: {variant}")
        model: ModelConfig = all_model_cfg[variant]
        seq_cfg = model.seq_cfg

        # Create a mock args object
        class MockArgs:
            pass

        mock_args = MockArgs()
        mock_args.variant = variant
        mock_args.model_path = model_path
        mock_args.encoder_name = encoder_name
        mock_args.use_rope = use_rope
        mock_args.text_c_dim = text_c_dim

        # Load network
        net: MeanAudio = (
            get_mean_audio(
                model.model_name,
                use_rope=mock_args.use_rope,
                text_c_dim=mock_args.text_c_dim,
            )
            .to(device, dtype)
            .eval()
        )
        net.load_weights(
            torch.load(mock_args.model_path, map_location=device, weights_only=True)
        )
        log.info(f"Loaded weights from {mock_args.model_path}")

        # Load feature utils
        feature_utils = FeaturesUtils(
            tod_vae_ckpt=model.vae_path,
            enable_conditions=True,
            encoder_name=mock_args.encoder_name,
            mode=model.mode,
            bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
            need_vae_encoder=False,
        )
        feature_utils = feature_utils.to(device, dtype).eval()

        # Store the loaded model state
        current_model_states[variant] = {
            "net": net,
            "feature_utils": feature_utils,
            "seq_cfg": seq_cfg,
            "args": mock_args,
        }
        log.info(f"Model '{variant}' loaded successfully.")
        # Return the loaded components for immediate use if needed
        return net, feature_utils, seq_cfg, mock_args

    else:
        log.info(f"Model '{variant}' already loaded with current settings. Skipping reload.")
        # Return the existing components
        return (
            existing_state["net"],
            existing_state["feature_utils"],
            existing_state["seq_cfg"],
            existing_state["args"],
        )

# --- Initialization: preload all default models at startup ---
def initialize_all_default_models():
    """Load all default model configurations at startup."""
    log.info("Initializing default models...")
    default_models = ['meanaudio_mf', 'fluxaudio_fm']  # List of default variants to load
    common_params = {
        "encoder_name": "t5_clap",
        "use_rope": True,
        "text_c_dim": 512,
        # Match the default values in the UI / desired startup precision
    }

    for variant in default_models:
        model_path = f"./weights/{variant}.pth"
        # This loads the model if it is not already loaded or if the params differ
        try:
            load_model_if_needed(variant, model_path, **common_params)
            log.info(f"Default model '{variant}' initialized successfully.")
        except Exception as e:
            log.error(f"Failed to initialize default model '{variant}': {e}")
            # Depending on requirements, a failure here could abort the app;
            # for now, just log it and keep trying the other variants.

initialize_all_default_models()

# --- Generation function with the GPU decorator (uses the preloaded models) ---
@spaces.GPU(duration=8)  # The decorator only moves the already-loaded model onto the GPU instance for execution
@torch.inference_mode()
def generate_audio_gradio(
    prompt,
    negative_prompt,
    duration,
    cfg_strength,
    num_steps,
    seed,
    variant,  # Determines which preloaded model to use
):
    global current_model_states
    # Determine model parameters based on input (mainly for the path)
    model_path = f"./weights/{variant}.pth"  # Path derived from the variant
    encoder_name = "t5_clap"
    use_rope = True
    text_c_dim = 512

    # --- Key change: fetch the already-loaded model components ---
    # Fetch the preloaded model components for the selected variant
    model_state = current_model_states.get(variant)
    if model_state is None:
        # This should not happen if initialization succeeded,
        # but handle it gracefully in case of an unexpected state.
        error_msg = f"Error: Model '{variant}' is not available. It may not have been loaded correctly during startup."
        log.error(error_msg)
        return error_msg, None

    # Use the preloaded components
    net = model_state["net"]
    feature_utils = model_state["feature_utils"]
    seq_cfg = model_state["seq_cfg"]
    # Keep the args stored with the model for consistency
    args = model_state["args"]
    # Fixed precision for inference
    dtype = torch.float32

    # --- Generation logic ---
    # Update sequence length based on the requested duration
    temp_seq_cfg = type(seq_cfg)(**seq_cfg.__dict__)  # Temporary copy so the shared config is not modified
    temp_seq_cfg.duration = duration
    # Update network sequence lengths
    net.update_seq_lengths(temp_seq_cfg.latent_seq_len)

    # Set up the random number generator
    rng = torch.Generator(device=device)
    if seed >= 0:
        rng.manual_seed(seed)
    else:
        rng.seed()

    # Select sampler and generation function based on the variant
    use_meanflow = variant == "meanaudio_mf"
    if use_meanflow:
        sampler = MeanFlow(steps=num_steps)
        log.info("Using MeanFlow for generation.")
        generation_func = generate_mf
        sampler_arg_name = "mf"
        # Note: cfg_strength is forced to 3 for MeanFlow
        cfg_strength = 3
    else:
        sampler = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
        log.info("Using FlowMatching for generation.")
        generation_func = generate_fm
        sampler_arg_name = "fm"

    # Perform generation
    prompts = [prompt]
    audios = generation_func(
        prompts,
        negative_text=[negative_prompt],
        feature_utils=feature_utils,
        net=net,
        rng=rng,
        cfg_strength=cfg_strength,
        **{sampler_arg_name: sampler},
    )
    audio = audios.float().cpu()[0]  # First generated audio, moved to the CPU

    # Save the generated audio
    safe_prompt = (
        "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
        .rstrip()
        .replace(" ", "_")[:50]
    )
    current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"{safe_prompt}_{current_time_string}.flac"
    save_path = OUTPUT_DIR / filename
    torchaudio.save(str(save_path), audio, temp_seq_cfg.sampling_rate)  # Use temp_seq_cfg for the correct sample rate
    log.info(f"Audio saved to {save_path}")

    # Cleanup
    gc.collect()
    # torch.cuda.empty_cache()  # Optional: be more aggressive when using CUDA

    return (
        f"Generated audio for prompt: '{prompt}' using {'MeanFlow' if use_meanflow else 'FlowMatching'}",
        str(save_path),
    )


# --- Gradio UI and events ---
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="slate",
    neutral_hue="slate",
    text_size="sm",
    spacing_size="sm",
).set(
    background_fill_primary="*neutral_50",
    background_fill_secondary="*background_fill_primary",
    block_background_fill="*background_fill_primary",
    block_border_width="0px",
    panel_background_fill="*neutral_50",
    panel_border_width="0px",
    input_background_fill="*neutral_100",
    input_border_color="*neutral_200",
    button_primary_background_fill="*primary_300",
    button_primary_background_fill_hover="*primary_400",
    button_secondary_background_fill="*neutral_200",
    button_secondary_background_fill_hover="*neutral_300",
)
custom_css = """
#main-headertitle {
    text-align: center;
    margin-top: 15px;
    margin-bottom: 10px;
    color: var(--neutral-600);
    font-weight: 600;
}
#main-header {
    text-align: center;
    margin-top: 5px;
    margin-bottom: 10px;
    color: var(--neutral-600);
    font-weight: 600;
}
#model-settings-header, #generation-settings-header {
    color: var(--neutral-600);
    margin-top: 8px;
    margin-bottom: 8px;
    font-weight: 500;
    font-size: 1.1em;
}
.setting-section {
    padding: 10px 12px;
    border-radius: 6px;
    background-color: var(--neutral-50);
    margin-bottom: 10px;
    border: 1px solid var(--neutral-100);
}
hr {
    border: none;
    height: 1px;
    background-color: var(--neutral-200);
    margin: 8px 0;
}
#generate-btn {
    width: 100%;
    max-width: 250px;
    margin: 10px auto;
    display: block;
    padding: 10px 15px;
    font-size: 16px;
    border-radius: 5px;
}
#status-box {
    min-height: 50px;
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 8px;
    border-radius: 5px;
    border: 1px solid var(--neutral-200);
    color: var(--neutral-700);
}
#project-badges {
    text-align: center; /* center the contents */
    margin-top: 30px;
    margin-bottom: 20px; /* spacing below the badges */
}

/* Badge container div */
#project-badges #badge-container {
    display: flex;
    gap: 10px;
    align-items: center;
    justify-content: center;
    flex-wrap: wrap;
}

/* The badge images themselves */
#project-badges img { /* scoping to #project-badges only affects images in this area */
    border-radius: 5px; /* rounded corners */
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); /* subtle shadow */
    height: 20px; /* uniform height */
    transition: transform 0.1s ease, box-shadow 0.1s ease; /* transition for the hover effect */
}

/* Optional: hover effect */
#project-badges a:hover img {
    transform: translateY(-2px); /* lift slightly */
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); /* stronger shadow */
}
#audio-output {
    height: 200px;
    border-radius: 5px;
    border: 1px solid var(--neutral-200);
}
.gradio-dropdown label, .gradio-checkbox label, .gradio-number label, .gradio-textbox label {
    font-weight: 500;
    color: var(--neutral-700);
    font-size: 0.9em;
}
.gradio-row {
    gap: 8px;
}
.gradio-block {
    margin-bottom: 8px;
}
.setting-section .gradio-block {
    margin-bottom: 6px;
}
::-webkit-scrollbar {
    width: 8px;
    height: 8px;
}
::-webkit-scrollbar-track {
    background: var(--neutral-100);
    border-radius: 4px;
}
::-webkit-scrollbar-thumb {
    background: var(--neutral-300);
    border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
    background: var(--neutral-400);
}
* {
    scrollbar-width: thin;
    scrollbar-color: var(--neutral-300) var(--neutral-100);
}
"""
with gr.Blocks(title="MeanAudio Generator", theme=theme, css=custom_css) as demo:
    gr.Markdown("# MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows", elem_id="main-header")
    # gr.Markdown("### Model and Generation Settings", elem_id="model-settings-header")
    # [](https://arxiv.org/abs/2412.21037)
    project_badges_markdown = '''
<div style="display: flex; gap: 10px; align-items: center; justify-content: center; flex-wrap: wrap; margin-bottom: 20px;">
  <a href="https://huggingface.co/junxiliu/MeanAudio">
    <img src="https://img.shields.io/badge/Model-HuggingFace-violet?logo=huggingface" alt="Hugging Face Model">
  </a>
  <a href="https://huggingface.co/spaces/chenxie95/MeanAudio">
    <img src="https://img.shields.io/badge/Space-HuggingFace-8A2BE2?logo=huggingface" alt="Hugging Face Space">
  </a>
  <a href="https://meanaudio.github.io/">
    <img src="https://img.shields.io/badge/Project-Page-brightred?style=flat" alt="Project Page">
  </a>
  <a href="https://github.com/xiquan-li/MeanAudio">
    <img src="https://img.shields.io/badge/Code-GitHub-black?logo=github" alt="GitHub">
  </a>
</div>
'''
    # Render the badge row with gr.Markdown
    gr.Markdown(project_badges_markdown, elem_id="project-badges")
    with gr.Column(elem_classes="setting-section"):
        with gr.Row():
            available_variants = list(all_model_cfg.keys()) if all_model_cfg else []
            default_variant = 'meanaudio_mf'
            variant = gr.Dropdown(
                label="Model Variant",
                choices=available_variants,
                value=default_variant,
                interactive=True,
                scale=3,
            )
    # gr.Markdown("### Audio Generation", elem_id="generation-settings-header")
    with gr.Column(elem_classes="setting-section"):
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the sound you want to generate...",
                scale=1,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="Describe sounds you want to avoid...",
                value="",
                scale=1,
            )
        with gr.Row():
            duration = gr.Number(
                label="Duration (sec)", value=10.0, minimum=0.1, scale=1
            )
            cfg_strength = gr.Number(
                label="CFG (MeanFlow forced to 3)", value=3, minimum=0.0, scale=1
            )
        with gr.Row():
            seed = gr.Number(
                label="Seed (-1 for random)", value=42, precision=0, scale=1
            )
            num_steps = gr.Number(
                label="Number of Steps",
                value=1,
                precision=0,
                minimum=1,
                scale=1,
            )
    generate_button = gr.Button("Generate", variant="primary", elem_id="generate-btn")
    generate_output_text = gr.Textbox(
        label="Result Status", interactive=False, elem_id="status-box"
    )
    audio_output = gr.Audio(
        label="Generated Audio", type="filepath", elem_id="audio-output"
    )
    generate_button.click(
        fn=generate_audio_gradio,
        inputs=[
            prompt,
            negative_prompt,
            duration,
            cfg_strength,
            num_steps,
            seed,
            variant,
        ],
        outputs=[generate_output_text, audio_output],
    )
    audio_examples = [
        # [prompt, negative_prompt, duration, cfg_strength, num_steps, seed, variant]
        ["A speech and gunfire followed by a gun being loaded", "", 10.0, 3, 1, 42, "meanaudio_mf"],
        ["Typing on a keyboard", "", 10.0, 3, 1, 42, "meanaudio_mf"],
        ["A man speaks followed by a popping noise and laughter", "", 10.0, 3, 2, 42, "meanaudio_mf"],
        ["Some humming followed by a toilet flushing", "", 10.0, 3, 2, 42, "meanaudio_mf"],
        ["Rain falling on a hard surface as thunder roars in the distance", "", 10.0, 3, 5, 42, "meanaudio_mf"],
        ["Food sizzling and oil popping", "", 10.0, 3, 25, 42, "meanaudio_mf"],
        ["Pots and dishes clanking as a man talks followed by liquid pouring into a container", "", 8.0, 3, 2, 42, "meanaudio_mf"],
        ["A few seconds of silence then a rasping sound against wood", "", 12.0, 3, 2, 42, "meanaudio_mf"],
        ["A man speaks as he gives a speech and then the crowd cheers", "", 10.0, 3, 25, 42, "fluxaudio_fm"],
        ["A goat bleating repeatedly", "", 10.0, 3, 50, 123, "fluxaudio_fm"],
        ["Tires squealing followed by an engine revving", "", 12.0, 4, 25, 456, "fluxaudio_fm"],
        ["Hammer slowly hitting the wooden table", "", 10.0, 3.5, 25, 42, "fluxaudio_fm"],
        ["Dog barking excitedly and man shouting as race car engine roars past", "", 10.0, 3, 1, 42, "meanaudio_mf"],
        ["A dog barking and a cat mewing and a racing car passes by", "", 12.0, 3, 5, -1, "meanaudio_mf"],
        ["Whistling with birds chirping", "", 10.0, 4, 50, 42, "fluxaudio_fm"],
    ]
    gr.Examples(
        examples=audio_examples,
        inputs=[prompt, negative_prompt, duration, cfg_strength, num_steps, seed, variant],  # must match the order of the example rows
        outputs=[generate_output_text, audio_output],  # output components
        fn=generate_audio_gradio,  # handler function
        # cache_examples=True,  # optional: cache example results (requires Space support or specific settings)
        examples_per_page=5,  # optional: number of examples shown per page
        label="Example Prompts",  # optional: label for the examples group
    )
    # --- Key change: a load event could initialize all default models ---
    # This would ensure all default models are loaded when the app starts
    # demo.load(fn=initialize_all_default_models, inputs=None, outputs=None)

if __name__ == "__main__":
    demo.launch()
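# --- Usage sketch (added note, not part of the original Space) ---
# Assuming the `meanaudio` package and the weights downloaded above are in place,
# the handler defined in this file can also be exercised outside the UI, e.g.:
#
#   status, wav_path = generate_audio_gradio(
#       prompt="Rain falling on a hard surface",
#       negative_prompt="",
#       duration=10.0,
#       cfg_strength=3,
#       num_steps=1,
#       seed=42,
#       variant="meanaudio_mf",
#   )
#
# Note that the @spaces.GPU decorator assumes the Hugging Face Spaces (ZeroGPU)
# runtime; when running locally, launching the app with `python app.py` and using
# the web UI is the more reliable path.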