File size: 9,254 Bytes
09794cb
 
dd97a96
09794cb
 
 
 
 
 
 
fec53ba
09794cb
 
 
 
 
 
 
 
 
 
 
 
 
 
629a90b
09794cb
 
dd97a96
0ff9928
09794cb
 
0ff9928
09794cb
 
 
0ff9928
09794cb
 
079604c
09794cb
629a90b
 
 
 
 
f47e09f
 
 
 
2b7760c
 
 
 
 
 
 
 
 
 
f47e09f
 
 
 
0fe93da
f47e09f
 
d712cde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0685a2c
d712cde
09794cb
629a90b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09794cb
f47e09f
09794cb
 
 
 
 
 
 
085b825
09794cb
a9bb66f
0ff9928
 
 
 
09794cb
d712cde
629a90b
0ff9928
 
 
 
 
09794cb
98c6962
0ff9928
 
 
 
09794cb
 
 
 
 
0ff9928
09794cb
 
 
 
 
 
 
0ff9928
 
085b825
0ff9928
09794cb
fec53ba
97dea78
09794cb
 
 
 
 
 
2b7760c
51fb3d2
 
 
 
 
 
2b7760c
 
085b825
09794cb
22e35a0
 
 
 
 
 
079604c
f47e09f
 
09794cb
079604c
09794cb
6a37b4f
 
 
98c6962
6a37b4f
2b7760c
 
6a37b4f
085b825
bbd22e4
796bea7
bbd22e4
7ba28c1
 
 
bbd22e4
6a37b4f
 
085b825
629a90b
079604c
629a90b
 
e2dd3f3
085b825
19ec831
6a37b4f
085b825
f4337c3
085b825
 
 
 
 
 
 
 
6a37b4f
19ec831
 
 
6a37b4f
d712cde
0950fa7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import warnings
import spaces
warnings.filterwarnings("ignore")
import logging
from argparse import ArgumentParser
from pathlib import Path
import torch
import torchaudio
import gradio as gr
from transformers import AutoModel
import laion_clap
from meanaudio.eval_utils import (
    ModelConfig,
    all_model_cfg,
    generate_mf,
    generate_fm,
    setup_eval_logging,
)
from meanaudio.model.flow_matching import FlowMatching
from meanaudio.model.mean_flow import MeanFlow
from meanaudio.model.networks import MeanAudio, get_mean_audio
from meanaudio.model.utils.features_utils import FeaturesUtils
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import gc
import json
from datetime import datetime
from huggingface_hub import snapshot_download
import numpy as np

log = logging.getLogger()
device = "cuda" if torch.cuda.is_available() else "cpu"
setup_eval_logging()

# Directory that receives the generated audio files.
OUTPUT_DIR = Path("output", "gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# How many clips are generated for each request.
NUM_SAMPLE = 1

# Create the RLHF feedback data directory.
FEEDBACK_DIR = Path("rlhf")
FEEDBACK_DIR.mkdir(exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR.joinpath("user_preferences.jsonl")

# Process-wide caches so models are loaded only once:
# variant name -> network, and 'default' -> shared FeaturesUtils.
MODEL_CACHE = dict()
FEATURE_UTILS_CACHE = dict()


def fade_out(x, sr, fade_ms=50):
    """Apply a linear fade-out to the tail of an audio signal, in place.

    The ramp is applied along the LAST axis, so both mono 1-D signals and
    ``(channels, samples)`` arrays/tensors (the layout ``torchaudio.save``
    expects) are handled. Using ``len(x)`` / ``x[-k:]`` would instead act on
    the channel axis, which for a ``(1, samples)`` clip silently skips the fade.

    Args:
        x: ``numpy.ndarray`` or ``torch.Tensor`` with time on the last axis;
            modified in place. A plain sequence is converted to a float
            ndarray first (and that array is returned).
        sr: Sample rate in Hz.
        fade_ms: Fade length in milliseconds.

    Returns:
        ``x`` with its last ``sr * fade_ms / 1000`` samples scaled from 1 down
        to 0. Returned unchanged when that length is zero or not shorter than
        the signal.
    """
    if not hasattr(x, "shape"):
        x = np.asarray(x, dtype=float)
    n = x.shape[-1]
    k = int(sr * fade_ms / 1000)
    if k <= 0 or k >= n:
        return x
    ramp = np.linspace(1.0, 0.0, k)
    if isinstance(x, torch.Tensor):
        # Match the tensor's device/dtype so the multiply also works for GPU
        # and reduced-precision tensors.
        ramp = torch.from_numpy(ramp).to(device=x.device, dtype=x.dtype)
    x[..., -k:] = x[..., -k:] * ramp
    return x

def ensure_models_downloaded():
    """Download the MeanAudio weights snapshot into ./weights if any checkpoint is missing.

    Variants are checked in ``all_model_cfg`` order and checking stops at the
    first missing ``model_path``. At most one ``snapshot_download`` call is
    made; it is presumably expected to provide every variant's files — TODO
    confirm against the repo layout.
    """
    first_missing = next(
        (name for name, cfg in all_model_cfg.items() if not cfg.model_path.exists()),
        None,
    )
    if first_missing is None:
        return
    log.info(f'Model {first_missing} not found, downloading...')
    snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights")

def load_model_cache():
    """Load every configured model variant, plus the shared FeaturesUtils, into the caches.

    Variants already present in ``MODEL_CACHE`` are skipped, so calling this
    more than once only loads what is missing. Earlier behaviour had two flaws:
    an early ``return`` on the first cached variant stopped the loop before
    later variants were loaded, and a fresh ``FeaturesUtils`` was built for
    every variant only to be overwritten by the next one.

    ``FeaturesUtils`` is now built once and stored under ``'default'``. It uses
    the LAST variant's config, which is the same config that previously ended
    up cached.
    NOTE(review): assumes all variants share the same VAE / vocoder / mode —
    confirm against ``all_model_cfg``.

    Returns:
        None; results are stored in ``MODEL_CACHE`` and ``FEATURE_UTILS_CACHE``.
    """
    last_cfg = None
    for variant, model_cfg in all_model_cfg.items():
        last_cfg = model_cfg
        if variant in MODEL_CACHE:
            continue
        log.info(f"Loading model {variant} for the first time...")
        net = get_mean_audio(model_cfg.model_name, use_rope=True, text_c_dim=512)
        net = net.to(device, torch.bfloat16).eval()
        net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
        MODEL_CACHE[variant] = net

    if last_cfg is not None and 'default' not in FEATURE_UTILS_CACHE:
        FEATURE_UTILS_CACHE['default'] = FeaturesUtils(
            tod_vae_ckpt=last_cfg.vae_path,
            enable_conditions=True,
            encoder_name="t5_clap",
            mode=last_cfg.mode,
            bigvgan_vocoder_ckpt=last_cfg.bigvgan_16k_path,
            need_vae_encoder=False
        ).to(device, torch.bfloat16).eval()

def save_preference_feedback(prompt, audio1_path, audio2_path, preference, additional_comment=""):
    """Append one pairwise-preference record (a JSON line) to ``FEEDBACK_FILE``.

    Args:
        prompt: Text prompt the two compared clips were generated from.
        audio1_path: Path of the first clip.
        audio2_path: Path of the second clip.
        preference: The user's verdict; expected values are "audio1",
            "audio2", "equal" or "both_bad".
        additional_comment: Optional free-text remark.

    Returns:
        A confirmation message suitable for showing to the user.
    """
    record = dict(
        timestamp=datetime.now().isoformat(),
        prompt=prompt,
        audio1_path=audio1_path,
        audio2_path=audio2_path,
        preference=preference,
        additional_comment=additional_comment,
    )

    with open(FEEDBACK_FILE, "a", encoding="utf-8") as sink:
        # print() appends the trailing newline, giving one record per line.
        print(json.dumps(record, ensure_ascii=False), file=sink)

    log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
    return f"✅ Thanks for your feedback, preference recorded: {preference}"


def _select_sampler(variant, num_steps):
    """Build the sampler for ``variant``'s model family.

    ``meanaudio*`` variants use MeanFlow with ``generate_mf``. ``fluxaudio*``
    variants use Euler flow matching with ``generate_fm``.

    Returns:
        ``(sampler, generation_func, sampler_arg_name, use_meanflow)``, where
        ``sampler_arg_name`` is the keyword ``generation_func`` expects the
        sampler under.

    Raises:
        ValueError: if ``variant`` belongs to neither family. A hard-coded
            variant list previously left ``use_meanflow`` unbound here and
            crashed with ``UnboundLocalError``.
    """
    if variant.startswith("meanaudio"):
        sampler = MeanFlow(steps=num_steps)
        log.info("Using MeanFlow for generation.")
        return sampler, generate_mf, "mf", True
    if variant.startswith("fluxaudio"):
        sampler = FlowMatching(
            min_sigma=0, inference_mode="euler", num_steps=num_steps
        )
        log.info("Using FlowMatching for generation.")
        return sampler, generate_fm, "fm", False
    raise ValueError(
        f"Cannot determine sampler for model variant: {variant}. "
        "Expected a name starting with 'meanaudio' or 'fluxaudio'."
    )


def _prompt_to_filename_stem(prompt):
    """Reduce ``prompt`` to a filename-safe stem.

    Keeps alphanumerics, spaces and underscores, strips trailing spaces,
    turns spaces into underscores and caps the result at 50 characters.
    """
    kept = "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
    return kept.rstrip().replace(" ", "_")[:50]


@spaces.GPU(duration=60)
@torch.inference_mode()
def generate_audio_gradio(
    prompt,
    duration,
    cfg_strength,
    num_steps,
    variant,
    seed
):
    """Generate audio for ``prompt`` and save each clip as FLAC under ``OUTPUT_DIR``.

    Args:
        prompt: Text description of the desired audio.
        duration: Requested clip duration. It is assigned to the variant's
            ``seq_cfg.duration`` (presumably seconds).
        cfg_strength: Guidance scale; overridden to 0 for MeanFlow variants.
        num_steps: Number of sampling steps.
        variant: Key into ``all_model_cfg``.
        seed: Seed for the torch random generator.

    Returns:
        ``(path_of_first_saved_clip, prompt)`` for the two Gradio outputs.

    Raises:
        ValueError: on non-positive duration/steps, an unknown variant, or a
            variant whose sampler family cannot be determined.
    """
    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if variant not in all_model_cfg:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")

    # Slider values may arrive as floats; the samplers and manual_seed need ints.
    num_steps = int(num_steps)
    seed = int(seed)

    # Fill the caches lazily when the app was not started through __main__.
    if variant not in MODEL_CACHE or 'default' not in FEATURE_UTILS_CACHE:
        load_model_cache()
    net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']

    # NOTE: this mutates the shared config object and network, so it is only
    # safe while requests are processed one at a time.
    seq_cfg = all_model_cfg[variant].seq_cfg
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len)

    sampler, generation_func, sampler_arg_name, use_meanflow = _select_sampler(variant, num_steps)
    if use_meanflow:
        # The guidance-scale input is not used for MeanFlow variants.
        cfg_strength = 0

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    audios = generation_func(
        [prompt] * NUM_SAMPLE,
        negative_text=None,
        feature_utils=feature_utils,
        net=net,
        rng=rng,
        cfg_strength=cfg_strength,
        **{sampler_arg_name: sampler},
    )

    safe_prompt = _prompt_to_filename_stem(prompt)
    save_paths = []
    for i, audio in enumerate(audios):
        audio = audio.float().cpu()
        # Linear 100 ms fade over the clip's tail.
        audio = fade_out(audio, seq_cfg.sampling_rate, fade_ms=100)

        current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
        save_path = OUTPUT_DIR / filename
        torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
        log.info(f"Audio saved to {save_path}")
        save_paths.append(str(save_path))

    if device == "cuda":
        torch.cuda.empty_cache()

    return save_paths[0], prompt


# Gradio components. The inputs are passed to the Interface positionally, in
# the same order as generate_audio_gradio's parameters.
input_text = gr.Textbox(label="Prompt", lines=2)
variant = gr.Dropdown(
    label="Model Variant",
    choices=list(all_model_cfg.keys()),
    value='meanaudio_s_full',
    interactive=True,
)
output_audio = gr.Audio(type="filepath", label="Generated Audio")
denoising_steps = gr.Slider(
    label="Sampling Steps",
    minimum=1,
    maximum=25,
    step=1,
    value=1,
    interactive=True,
)
cfg_strength = gr.Slider(
    label="Guidance Scale",
    minimum=1,
    maximum=10,
    step=0.5,
    value=4.5,
    interactive=True,
)
duration = gr.Slider(
    label="Duration",
    minimum=1,
    maximum=30,
    step=1,
    value=10,
    interactive=True,
)
seed = gr.Slider(
    label="Seed",
    minimum=1,
    maximum=100,
    step=1,
    value=42,
    interactive=True,
)


description_text = """
### **MeanAudio** is a novel text-to-audio generator that uses **MeanFlow** to synthesize realistic and faithful audio in few sampling steps. It achieves state-of-the-art performance in single-step audio generation and delivers strong performance in multi-step audio generation.

### [📖 **Arxiv**](https://arxiv.org/abs/2508.06098) | [💻 **GitHub**](https://github.com/xiquan-li/MeanAudio) | [🤗 **Model**](https://huggingface.co/AndreasXi/MeanAudio) | [🚀 **Space**](https://huggingface.co/spaces/chenxie95/MeanAudio) | [🌐 **Project Page**](https://meanaudio.github.io/)
"""

# Example prompts shown under the UI.
_EXAMPLE_PROMPTS = [
    "Guitar and piano playing a warm music, with a soft and gentle melody, perfect for a romantic evening.",
    "Melodic human whistling harmonizing with natural birdsong",
    "A parade marches through a town square, with drumbeats pounding, children clapping, and a horse neighing amidst the commotion",
    "Quiet speech and then and airplane flying away",
    "A basketball bounces rhythmically on a court, shoes squeak against the floor, and a referee’s whistle cuts through the air",
    "Chopping meat on a wooden table.",
    "A vehicle engine revving then accelerating at a high rate as a metal surface is whipped followed by tires skidding.",
    "Battlefield scene, continuous roar of artillery and gunfire, high fidelity, the sharp crack of bullets, the thundering explosions of bombs, and the screams of wounded soldiers.",
    "Pop music that upbeat, catchy, and easy to listen, high fidelity, with simple melodies, electronic instruments and polished production.",
    "A fast-paced instrumental piece with a classical vibe featuring stringed instruments, evoking an energetic and uplifting mood.",
]
# Settings shared by every example, in input order: duration, guidance
# scale, sampling steps, model variant, seed.
_EXAMPLE_SETTINGS = [10, 3, 1, "meanaudio_s_full", 42]

gr_interface = gr.Interface(
    title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
    description=description_text,
    fn=generate_audio_gradio,
    inputs=[input_text, duration, cfg_strength, denoising_steps, variant, seed],
    outputs=[
        gr.Audio(type="filepath", label="🎵 Audio Sample"),
        gr.Textbox(interactive=False, label="Prompt Used"),
    ],
    flagging_mode="never",
    examples=[[prompt_text, *_EXAMPLE_SETTINGS] for prompt_text in _EXAMPLE_PROMPTS],
    cache_examples="lazy",
)


if __name__ == "__main__":
    # Fetch any missing checkpoints and warm the model caches before serving.
    ensure_models_downloaded()
    load_model_cache()
    # In Gradio 4/5 the first positional parameter of Blocks.queue() is
    # status_update_rate, so queue(15) did not limit the queue at all. Pass
    # max_size explicitly. Concurrency stays at Gradio's default limit because
    # requests mutate shared model state (seq_cfg.duration, update_seq_lengths).
    gr_interface.queue(max_size=15).launch()