add rlhf
Browse files- app.py +41 -16
- feedback_collector.py +127 -0
app.py
CHANGED
@@ -23,6 +23,7 @@ from meanaudio.model.utils.features_utils import FeaturesUtils
|
|
23 |
torch.backends.cuda.matmul.allow_tf32 = True
|
24 |
torch.backends.cudnn.allow_tf32 = True
|
25 |
import gc
|
|
|
26 |
from datetime import datetime
|
27 |
from huggingface_hub import snapshot_download
|
28 |
import numpy as np
|
@@ -38,6 +39,11 @@ OUTPUT_DIR = Path("./output/gradio")
|
|
38 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
39 |
NUM_SAMPLE = 2
|
40 |
|
|
|
|
|
|
|
|
|
|
|
41 |
# Global model cache to avoid reloading
|
42 |
MODEL_CACHE = {}
|
43 |
FEATURE_UTILS_CACHE = {}
|
@@ -80,6 +86,22 @@ def load_model_cache():
|
|
80 |
).to(device, torch.bfloat16).eval()
|
81 |
FEATURE_UTILS_CACHE['default'] = feature_utils
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
@spaces.GPU(duration=60)
|
85 |
@torch.inference_mode()
|
@@ -97,7 +119,7 @@ def generate_audio_gradio(
|
|
97 |
raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
|
98 |
|
99 |
net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']
|
100 |
-
|
101 |
model = all_model_cfg[variant]
|
102 |
seq_cfg = model.seq_cfg
|
103 |
seq_cfg.duration = duration
|
@@ -142,21 +164,21 @@ def generate_audio_gradio(
|
|
142 |
|
143 |
audio = fade_out(audio, seq_cfg.sampling_rate)
|
144 |
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
if device == "cuda":
|
157 |
torch.cuda.empty_cache()
|
158 |
|
159 |
-
return save_paths
|
160 |
|
161 |
|
162 |
# Gradio input and output components
|
@@ -171,9 +193,13 @@ variant = gr.Dropdown(label="Model Variant", choices=list(all_model_cfg.keys()),
|
|
171 |
gr_interface = gr.Interface(
|
172 |
fn=generate_audio_gradio,
|
173 |
inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
|
174 |
-
outputs=[
|
|
|
|
|
|
|
|
|
175 |
title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
|
176 |
-
description="",
|
177 |
flagging_mode="never",
|
178 |
examples=[
|
179 |
["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
|
@@ -193,7 +219,6 @@ gr_interface = gr.Interface(
|
|
193 |
)
|
194 |
|
195 |
if __name__ == "__main__":
|
196 |
-
|
197 |
ensure_models_downloaded()
|
198 |
load_model_cache()
|
199 |
gr_interface.queue(15).launch()
|
|
|
23 |
torch.backends.cuda.matmul.allow_tf32 = True
|
24 |
torch.backends.cudnn.allow_tf32 = True
|
25 |
import gc
|
26 |
+
import json
|
27 |
from datetime import datetime
|
28 |
from huggingface_hub import snapshot_download
|
29 |
import numpy as np
|
|
|
39 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
40 |
NUM_SAMPLE = 2
|
41 |
|
42 |
+
# Directory where RLHF preference-feedback records are collected.
FEEDBACK_DIR = Path("./rlhf")
FEEDBACK_DIR.mkdir(exist_ok=True)
# NOTE(review): feedback_collector.py writes its records to
# ./rlhf_feedback/user_preferences.jsonl, not this path — confirm which
# location the downstream analyze_feedback.py tool actually reads.
FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"
|
46 |
+
|
47 |
# Global model cache to avoid reloading
|
48 |
MODEL_CACHE = {}
|
49 |
FEATURE_UTILS_CACHE = {}
|
|
|
86 |
).to(device, torch.bfloat16).eval()
|
87 |
FEATURE_UTILS_CACHE['default'] = feature_utils
|
88 |
|
89 |
+
def save_preference_feedback(prompt, audio1_path, audio2_path, preference, additional_comment=""):
    """Append one user preference record to the RLHF feedback JSONL file.

    Args:
        prompt: Text prompt the two audio samples were generated from.
        audio1_path: Path of the first generated audio file.
        audio2_path: Path of the second generated audio file.
        preference: User's verdict — "audio1", "audio2", "equal", or "both_bad".
        additional_comment: Optional free-form remark from the user.

    Returns:
        A confirmation string to display back to the user.
    """
    record = dict(
        timestamp=datetime.now().isoformat(),
        prompt=prompt,
        audio1_path=audio1_path,
        audio2_path=audio2_path,
        preference=preference,
        additional_comment=additional_comment,
    )

    # One JSON object per line (JSONL); ensure_ascii=False keeps non-ASCII
    # prompt text readable in the stored file.
    with open(FEEDBACK_FILE, "a", encoding="utf-8") as fh:
        fh.write(json.dumps(record, ensure_ascii=False) + "\n")

    log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
    return f"✅ Thanks for your feedback, preference recorded: {preference}"
|
104 |
+
|
105 |
|
106 |
@spaces.GPU(duration=60)
|
107 |
@torch.inference_mode()
|
|
|
119 |
raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
|
120 |
|
121 |
net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']
|
122 |
+
|
123 |
model = all_model_cfg[variant]
|
124 |
seq_cfg = model.seq_cfg
|
125 |
seq_cfg.duration = duration
|
|
|
164 |
|
165 |
audio = fade_out(audio, seq_cfg.sampling_rate)
|
166 |
|
167 |
+
safe_prompt = (
|
168 |
+
"".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
|
169 |
+
.rstrip()
|
170 |
+
.replace(" ", "_")[:50]
|
171 |
+
)
|
172 |
+
current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
173 |
+
filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
|
174 |
+
save_path = OUTPUT_DIR / filename
|
175 |
+
torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
|
176 |
+
log.info(f"Audio saved to {save_path}")
|
177 |
+
save_paths.append(str(save_path))
|
178 |
if device == "cuda":
|
179 |
torch.cuda.empty_cache()
|
180 |
|
181 |
+
return save_paths[0], save_paths[1], prompt
|
182 |
|
183 |
|
184 |
# Gradio input and output components
|
|
|
193 |
gr_interface = gr.Interface(
|
194 |
fn=generate_audio_gradio,
|
195 |
inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
|
196 |
+
outputs=[
|
197 |
+
gr.Audio(label="🎵 Audio Sample 1"),
|
198 |
+
gr.Audio(label="🎵 Audio Sample 2"),
|
199 |
+
gr.Textbox(label="Prompt Used", interactive=False)
|
200 |
+
],
|
201 |
title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
|
202 |
+
description="🎯 **RLHF数据收集**: 现在生成2个音频样本!收集反馈数据用于改进模型。使用分析工具: `python analyze_feedback.py`",
|
203 |
flagging_mode="never",
|
204 |
examples=[
|
205 |
["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
|
|
|
219 |
)
|
220 |
|
221 |
if __name__ == "__main__":
|
|
|
222 |
ensure_models_downloaded()
|
223 |
load_model_cache()
|
224 |
gr_interface.queue(15).launch()
|
feedback_collector.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
简单的反馈收集工具
|
4 |
+
在MeanAudio生成音频后,运行此脚本收集用户偏好
|
5 |
+
"""
|
6 |
+
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
from datetime import datetime
|
11 |
+
from pathlib import Path
|
12 |
+
import gradio as gr
|
13 |
+
|
14 |
+
# Feedback output directory for this standalone collector tool.
# NOTE(review): app.py saves its own feedback to ./rlhf/user_preferences.jsonl,
# while this tool uses ./rlhf_feedback/ — confirm the two are meant to differ.
FEEDBACK_DIR = Path("./rlhf_feedback")
FEEDBACK_DIR.mkdir(exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"
|
18 |
+
|
19 |
+
def save_feedback(audio1_path, audio2_path, prompt, preference, comment=""):
    """Persist a single preference record as one JSON line in FEEDBACK_FILE.

    Args:
        audio1_path: Path of the first audio sample.
        audio2_path: Path of the second audio sample.
        prompt: Prompt text the audio was generated from.
        preference: Chosen option ("audio1", "audio2", "equal", "both_bad").
        comment: Optional free-form remark.

    Returns:
        A confirmation message (Chinese) shown in the UI.
    """
    record = dict(
        timestamp=datetime.now().isoformat(),
        prompt=prompt,
        audio1_path=audio1_path,
        audio2_path=audio2_path,
        preference=preference,
        additional_comment=comment,
    )

    # Append as JSONL; ensure_ascii=False keeps Chinese text readable on disk.
    with open(FEEDBACK_FILE, "a", encoding="utf-8") as fh:
        fh.write(json.dumps(record, ensure_ascii=False) + "\n")

    return f"✅ 反馈已保存!偏好: {preference}"
|
34 |
+
|
35 |
+
def create_feedback_interface():
|
36 |
+
"""创建反馈收集界面"""
|
37 |
+
|
38 |
+
with gr.Blocks(title="MeanAudio 反馈收集器") as demo:
|
39 |
+
gr.Markdown("# MeanAudio 反馈收集器")
|
40 |
+
gr.Markdown("*请输入生成的音频文件路径和提示词,然后选择您的偏好*")
|
41 |
+
|
42 |
+
with gr.Row():
|
43 |
+
with gr.Column():
|
44 |
+
prompt_input = gr.Textbox(
|
45 |
+
label="提示词",
|
46 |
+
placeholder="输入用于生成音频的提示词..."
|
47 |
+
)
|
48 |
+
|
49 |
+
audio1_path = gr.Textbox(
|
50 |
+
label="音频文件1路径",
|
51 |
+
placeholder="./output/gradio/prompt_timestamp_0.flac"
|
52 |
+
)
|
53 |
+
|
54 |
+
audio2_path = gr.Textbox(
|
55 |
+
label="音频文件2路径",
|
56 |
+
placeholder="./output/gradio/prompt_timestamp_1.flac"
|
57 |
+
)
|
58 |
+
|
59 |
+
with gr.Column():
|
60 |
+
# 显示音频
|
61 |
+
audio1_player = gr.Audio(label="音频1")
|
62 |
+
audio2_player = gr.Audio(label="音频2")
|
63 |
+
|
64 |
+
load_btn = gr.Button("🔄 加载音频文件")
|
65 |
+
|
66 |
+
# 反馈区域
|
67 |
+
gr.Markdown("---")
|
68 |
+
gr.Markdown("### 请选择您的偏好")
|
69 |
+
|
70 |
+
preference = gr.Radio(
|
71 |
+
choices=[
|
72 |
+
("音频1更好", "audio1"),
|
73 |
+
("音频2更好", "audio2"),
|
74 |
+
("两者质量相等", "equal"),
|
75 |
+
("两者都不好", "both_bad")
|
76 |
+
],
|
77 |
+
label="哪个音频更好?"
|
78 |
+
)
|
79 |
+
|
80 |
+
comment = gr.Textbox(
|
81 |
+
label="额外评论 (可选)",
|
82 |
+
placeholder="关于音频质量的具体反馈...",
|
83 |
+
lines=3
|
84 |
+
)
|
85 |
+
|
86 |
+
submit_btn = gr.Button("📝 提交反馈", variant="primary")
|
87 |
+
|
88 |
+
result = gr.Textbox(label="结果", interactive=False)
|
89 |
+
|
90 |
+
# 事件处理
|
91 |
+
def load_audio_files(path1, path2):
|
92 |
+
"""加载音频文件用于播放"""
|
93 |
+
audio1 = path1 if os.path.exists(path1) else None
|
94 |
+
audio2 = path2 if os.path.exists(path2) else None
|
95 |
+
return audio1, audio2
|
96 |
+
|
97 |
+
load_btn.click(
|
98 |
+
fn=load_audio_files,
|
99 |
+
inputs=[audio1_path, audio2_path],
|
100 |
+
outputs=[audio1_player, audio2_player]
|
101 |
+
)
|
102 |
+
|
103 |
+
submit_btn.click(
|
104 |
+
fn=save_feedback,
|
105 |
+
inputs=[audio1_path, audio2_path, prompt_input, preference, comment],
|
106 |
+
outputs=[result]
|
107 |
+
)
|
108 |
+
|
109 |
+
# 使用说明
|
110 |
+
gr.Markdown("---")
|
111 |
+
gr.Markdown("""
|
112 |
+
### 使用说明
|
113 |
+
1. 先运行 MeanAudio 生成两个音频文件
|
114 |
+
2. 将生成的音频文件路径复制到上面的输入框中
|
115 |
+
3. 点击"加载音频文件"来播放音频
|
116 |
+
4. 选择您的偏好并提交反馈
|
117 |
+
5. 反馈数据将保存到 `./rlhf_feedback/user_preferences.jsonl`
|
118 |
+
6. 使用 `python analyze_feedback.py` 分析收集的反馈数据
|
119 |
+
""")
|
120 |
+
|
121 |
+
return demo
|
122 |
+
|
123 |
+
if __name__ == "__main__":
    # Build the UI and serve it locally on a port distinct from the main
    # MeanAudio app (7861 vs Gradio's default 7860).
    demo = create_feedback_interface()
    print("启动反馈收集界面...")
    print(f"反馈数据将保存到: {FEEDBACK_FILE}")
    demo.launch(server_name="127.0.0.1", server_port=7861)
|