AndreasXi commited on
Commit
629a90b
·
1 Parent(s): 2b7760c
Files changed (2) hide show
  1. app.py +41 -16
  2. feedback_collector.py +127 -0
app.py CHANGED
@@ -23,6 +23,7 @@ from meanaudio.model.utils.features_utils import FeaturesUtils
23
  torch.backends.cuda.matmul.allow_tf32 = True
24
  torch.backends.cudnn.allow_tf32 = True
25
  import gc
 
26
  from datetime import datetime
27
  from huggingface_hub import snapshot_download
28
  import numpy as np
@@ -38,6 +39,11 @@ OUTPUT_DIR = Path("./output/gradio")
38
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
39
  NUM_SAMPLE = 2
40
 
 
 
 
 
 
41
  # Global model cache to avoid reloading
42
  MODEL_CACHE = {}
43
  FEATURE_UTILS_CACHE = {}
@@ -80,6 +86,22 @@ def load_model_cache():
80
  ).to(device, torch.bfloat16).eval()
81
  FEATURE_UTILS_CACHE['default'] = feature_utils
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  @spaces.GPU(duration=60)
85
  @torch.inference_mode()
@@ -97,7 +119,7 @@ def generate_audio_gradio(
97
  raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
98
 
99
  net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']
100
-
101
  model = all_model_cfg[variant]
102
  seq_cfg = model.seq_cfg
103
  seq_cfg.duration = duration
@@ -142,21 +164,21 @@ def generate_audio_gradio(
142
 
143
  audio = fade_out(audio, seq_cfg.sampling_rate)
144
 
145
- safe_prompt = (
146
- "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
147
- .rstrip()
148
- .replace(" ", "_")[:50]
149
- )
150
- current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
151
- filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
152
- save_path = OUTPUT_DIR / filename
153
- torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
154
- log.info(f"Audio saved to {save_path}")
155
- save_paths.append(str(save_path))
156
  if device == "cuda":
157
  torch.cuda.empty_cache()
158
 
159
- return save_paths
160
 
161
 
162
  # Gradio input and output components
@@ -171,9 +193,13 @@ variant = gr.Dropdown(label="Model Variant", choices=list(all_model_cfg.keys()),
171
  gr_interface = gr.Interface(
172
  fn=generate_audio_gradio,
173
  inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
174
- outputs=["audio", "audio"],
 
 
 
 
175
  title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
176
- description="",
177
  flagging_mode="never",
178
  examples=[
179
  ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
@@ -193,7 +219,6 @@ gr_interface = gr.Interface(
193
  )
194
 
195
  if __name__ == "__main__":
196
-
197
  ensure_models_downloaded()
198
  load_model_cache()
199
  gr_interface.queue(15).launch()
 
23
  torch.backends.cuda.matmul.allow_tf32 = True
24
  torch.backends.cudnn.allow_tf32 = True
25
  import gc
26
+ import json
27
  from datetime import datetime
28
  from huggingface_hub import snapshot_download
29
  import numpy as np
 
39
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
40
  NUM_SAMPLE = 2
41
 
42
+ # 创建RLHF反馈数据目录
43
+ FEEDBACK_DIR = Path("./rlhf")
44
+ FEEDBACK_DIR.mkdir(exist_ok=True)
45
+ FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"
46
+
47
  # Global model cache to avoid reloading
48
  MODEL_CACHE = {}
49
  FEATURE_UTILS_CACHE = {}
 
86
  ).to(device, torch.bfloat16).eval()
87
  FEATURE_UTILS_CACHE['default'] = feature_utils
88
 
89
+ def save_preference_feedback(prompt, audio1_path, audio2_path, preference, additional_comment=""):
90
+ feedback_data = {
91
+ "timestamp": datetime.now().isoformat(),
92
+ "prompt": prompt,
93
+ "audio1_path": audio1_path,
94
+ "audio2_path": audio2_path,
95
+ "preference": preference, # "audio1", "audio2", "equal", "both_bad"
96
+ "additional_comment": additional_comment
97
+ }
98
+
99
+ with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
100
+ f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")
101
+
102
+ log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
103
+ return f"✅ Thanks for your feedback, preference recorded: {preference}"
104
+
105
 
106
  @spaces.GPU(duration=60)
107
  @torch.inference_mode()
 
119
  raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
120
 
121
  net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']
122
+
123
  model = all_model_cfg[variant]
124
  seq_cfg = model.seq_cfg
125
  seq_cfg.duration = duration
 
164
 
165
  audio = fade_out(audio, seq_cfg.sampling_rate)
166
 
167
+ safe_prompt = (
168
+ "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
169
+ .rstrip()
170
+ .replace(" ", "_")[:50]
171
+ )
172
+ current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
173
+ filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
174
+ save_path = OUTPUT_DIR / filename
175
+ torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
176
+ log.info(f"Audio saved to {save_path}")
177
+ save_paths.append(str(save_path))
178
  if device == "cuda":
179
  torch.cuda.empty_cache()
180
 
181
+ return save_paths[0], save_paths[1], prompt
182
 
183
 
184
  # Gradio input and output components
 
193
  gr_interface = gr.Interface(
194
  fn=generate_audio_gradio,
195
  inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
196
+ outputs=[
197
+ gr.Audio(label="🎵 Audio Sample 1"),
198
+ gr.Audio(label="🎵 Audio Sample 2"),
199
+ gr.Textbox(label="Prompt Used", interactive=False)
200
+ ],
201
  title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
202
+ description="🎯 **RLHF数据收集**: 现在生成2个音频样本!收集反馈数据用于改进模型。使用分析工具: `python analyze_feedback.py`",
203
  flagging_mode="never",
204
  examples=[
205
  ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
 
219
  )
220
 
221
  if __name__ == "__main__":
 
222
  ensure_models_downloaded()
223
  load_model_cache()
224
  gr_interface.queue(15).launch()
feedback_collector.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 简单的反馈收集工具
4
+ 在MeanAudio生成音频后,运行此脚本收集用户偏好
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import sys
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ import gradio as gr
13
+
14
+ # 设置反馈目录
15
+ FEEDBACK_DIR = Path("./rlhf_feedback")
16
+ FEEDBACK_DIR.mkdir(exist_ok=True)
17
+ FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"
18
+
19
+ def save_feedback(audio1_path, audio2_path, prompt, preference, comment=""):
20
+ """保存反馈数据"""
21
+ feedback_data = {
22
+ "timestamp": datetime.now().isoformat(),
23
+ "prompt": prompt,
24
+ "audio1_path": audio1_path,
25
+ "audio2_path": audio2_path,
26
+ "preference": preference,
27
+ "additional_comment": comment
28
+ }
29
+
30
+ with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
31
+ f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")
32
+
33
+ return f"✅ 反馈已保存!偏好: {preference}"
34
+
35
+ def create_feedback_interface():
36
+ """创建反馈收集界面"""
37
+
38
+ with gr.Blocks(title="MeanAudio 反馈收集器") as demo:
39
+ gr.Markdown("# MeanAudio 反馈收集器")
40
+ gr.Markdown("*请输入生成的音频文件路径和提示词,然后选择您的偏好*")
41
+
42
+ with gr.Row():
43
+ with gr.Column():
44
+ prompt_input = gr.Textbox(
45
+ label="提示词",
46
+ placeholder="输入用于生成音频的提示词..."
47
+ )
48
+
49
+ audio1_path = gr.Textbox(
50
+ label="音频文件1路径",
51
+ placeholder="./output/gradio/prompt_timestamp_0.flac"
52
+ )
53
+
54
+ audio2_path = gr.Textbox(
55
+ label="音频文件2路径",
56
+ placeholder="./output/gradio/prompt_timestamp_1.flac"
57
+ )
58
+
59
+ with gr.Column():
60
+ # 显示音频
61
+ audio1_player = gr.Audio(label="音频1")
62
+ audio2_player = gr.Audio(label="音频2")
63
+
64
+ load_btn = gr.Button("🔄 加载音频文件")
65
+
66
+ # 反馈区域
67
+ gr.Markdown("---")
68
+ gr.Markdown("### 请选择您的偏好")
69
+
70
+ preference = gr.Radio(
71
+ choices=[
72
+ ("音频1更好", "audio1"),
73
+ ("音频2更好", "audio2"),
74
+ ("两者质量相等", "equal"),
75
+ ("两者都不好", "both_bad")
76
+ ],
77
+ label="哪个音频更好?"
78
+ )
79
+
80
+ comment = gr.Textbox(
81
+ label="额外评论 (可选)",
82
+ placeholder="关于音频质量的具体反馈...",
83
+ lines=3
84
+ )
85
+
86
+ submit_btn = gr.Button("📝 提交反馈", variant="primary")
87
+
88
+ result = gr.Textbox(label="结果", interactive=False)
89
+
90
+ # 事件处理
91
+ def load_audio_files(path1, path2):
92
+ """加载音频文件用于播放"""
93
+ audio1 = path1 if os.path.exists(path1) else None
94
+ audio2 = path2 if os.path.exists(path2) else None
95
+ return audio1, audio2
96
+
97
+ load_btn.click(
98
+ fn=load_audio_files,
99
+ inputs=[audio1_path, audio2_path],
100
+ outputs=[audio1_player, audio2_player]
101
+ )
102
+
103
+ submit_btn.click(
104
+ fn=save_feedback,
105
+ inputs=[audio1_path, audio2_path, prompt_input, preference, comment],
106
+ outputs=[result]
107
+ )
108
+
109
+ # 使用说明
110
+ gr.Markdown("---")
111
+ gr.Markdown("""
112
+ ### 使用说明
113
+ 1. 先运行 MeanAudio 生成两个音频文件
114
+ 2. 将生成的音频文件路径复制到上面的输入框中
115
+ 3. 点击"加载音频文件"来播放音频
116
+ 4. 选择您的偏好并提交反馈
117
+ 5. 反馈数据将保存到 `./rlhf_feedback/user_preferences.jsonl`
118
+ 6. 使用 `python analyze_feedback.py` 分析收集的反馈数据
119
+ """)
120
+
121
+ return demo
122
+
123
+ if __name__ == "__main__":
124
+ demo = create_feedback_interface()
125
+ print("启动反馈收集界面...")
126
+ print(f"反馈数据将保存到: {FEEDBACK_FILE}")
127
+ demo.launch(server_name="127.0.0.1", server_port=7861)