AndreasXi commited on
Commit
19ec831
·
1 Parent(s): 51fb3d2
Files changed (1) hide show
  1. app.py +75 -57
app.py CHANGED
@@ -102,39 +102,6 @@ def save_preference_feedback(prompt, audio1_path, audio2_path, preference, addit
102
  log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
103
  return f"✅ Thanks for your feedback, preference recorded: {preference}"
104
 
105
- def save_preference_feedback_from_flag(input_text, duration, cfg_strength, num_steps, variant,
106
- audio1_path, audio2_path, prompt_used, preference, comment):
107
- """处理Gradio flagging回调的反馈保存"""
108
- try:
109
- if not preference:
110
- print("⚠️ 用户没有选择偏好")
111
- return
112
-
113
- feedback_data = {
114
- "timestamp": datetime.now().isoformat(),
115
- "prompt": prompt_used or input_text,
116
- "audio1_path": audio1_path,
117
- "audio2_path": audio2_path,
118
- "preference": preference,
119
- "additional_comment": comment or "",
120
- "generation_params": {
121
- "duration": duration,
122
- "cfg_strength": cfg_strength,
123
- "num_steps": num_steps,
124
- "variant": variant
125
- }
126
- }
127
-
128
- with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
129
- f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")
130
-
131
- log.info(f"✅ 反馈已保存: {preference} - {prompt_used[:50]}...")
132
- print(f"✅ 用户反馈已保存到: {FEEDBACK_FILE}")
133
-
134
- except Exception as e:
135
- log.error(f"保存反馈时出错: {e}")
136
- print(f"❌ 保存反馈时出错: {e}")
137
-
138
 
139
  @spaces.GPU(duration=60)
140
  @torch.inference_mode()
@@ -227,31 +194,13 @@ gr_interface = gr.Interface(
227
  fn=generate_audio_gradio,
228
  inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
229
  outputs=[
230
- gr.Audio(label="🎵 Audio Sample 1"),
231
- gr.Audio(label="🎵 Audio Sample 2"),
232
  gr.Textbox(label="Prompt Used", interactive=False)
233
  ],
234
- additional_inputs=[
235
- gr.Radio(
236
- choices=[
237
- ("🎵 Audio 1 更好", "audio1"),
238
- ("🎵 Audio 2 更好", "audio2"),
239
- ("😊 两者都很好", "equal"),
240
- ("😔 两者都不好", "both_bad")
241
- ],
242
- label="🤔 请选择您更喜欢的音频:",
243
- value=None
244
- ),
245
- gr.Textbox(
246
- label="💭 评论 (可选)",
247
- placeholder="您对音频质量的具体反馈...",
248
- lines=2
249
- )
250
- ],
251
  title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
252
- description="🎯 **RLHF数据收集**: 现在生成2个音频样本!收集反馈数据用于改进模型。使用分析工具: `python analyze_feedback.py`",
253
- flagging_mode="manual",
254
- flagging_callback=lambda *args: save_preference_feedback_from_flag(*args),
255
  examples=[
256
  ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
257
  ["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full"],
@@ -266,13 +215,82 @@ gr_interface = gr.Interface(
266
  ['doorbell ding once followed by footsteps gradually getting louder and a door is opened ', 10, 3, 1, "meanaudio_s_full"],
267
  ["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background", 10, 3, 1, "meanaudio_s_full"]
268
  ],
269
- cache_examples="lazy", # Turn on to cache.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
  if __name__ == "__main__":
273
  ensure_models_downloaded()
274
  load_model_cache()
275
- gr_interface.queue(15).launch()
276
 
277
  # theme = gr.themes.Soft(
278
  # primary_hue="blue",
 
102
  log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
103
  return f"✅ Thanks for your feedback, preference recorded: {preference}"
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  @spaces.GPU(duration=60)
107
  @torch.inference_mode()
 
194
  fn=generate_audio_gradio,
195
  inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
196
  outputs=[
197
+ gr.Audio(label="🎵 Audio Sample 1", type="filepath"),
198
+ gr.Audio(label="🎵 Audio Sample 2", type="filepath"),
199
  gr.Textbox(label="Prompt Used", interactive=False)
200
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
202
+ description="🎯 **RLHF数据收集**: 现在生成2个音频样本!生成后请在下方选择偏好并提交。",
203
+ flagging_mode="never",
 
204
  examples=[
205
  ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
206
  ["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full"],
 
215
  ['doorbell ding once followed by footsteps gradually getting louder and a door is opened ', 10, 3, 1, "meanaudio_s_full"],
216
  ["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background", 10, 3, 1, "meanaudio_s_full"]
217
  ],
218
+ cache_examples="lazy",
219
+ )
220
+
221
+ # ==== Preference collection UI (RLHF) ====
222
+
223
+ # 允许用户在两段音频之间选择偏好,并补充备注
224
+ with gr.Blocks() as pref_block:
225
+ gr.Markdown("## 🧠 RLHF 偏好标注")
226
+ gr.Markdown("生成完成后,请在下方选择您更喜欢的音频(或都不好/差不多),并可附加简短备注。点“提交偏好”即可写入 `./rlhf/user_preferences.jsonl`。")
227
+
228
+ # 这里复用上面 Interface 的输出:我们需要拿到两段音频的文件路径与使用的 prompt
229
+ # 为了连接这两个“界面”,再放一组可粘连的输入组件:
230
+ with gr.Row():
231
+ gen_audio1_path = gr.Textbox(label="Audio 1 路径(自动填充)", interactive=False)
232
+ gen_audio2_path = gr.Textbox(label="Audio 2 路径(自动填充)", interactive=False)
233
+ prompt_used = gr.Textbox(label="Prompt(自动填充)", interactive=False)
234
+
235
+ # 偏好选项与备注
236
+ pref_choice = gr.Radio(
237
+ ["audio1", "audio2", "equal", "both_bad"],
238
+ value="audio1",
239
+ label="你更偏好哪个?",
240
+ info="equal=差不多; both_bad=都不好"
241
+ )
242
+ pref_comment = gr.Textbox(label="可选备注(例如:哪一段更贴合描述、是否有噪声/破音等)", lines=2)
243
+
244
+ submit_btn = gr.Button("✅ 提交偏好")
245
+ submit_status = gr.Markdown()
246
+
247
+ # 小工具:读取当前标注条目数
248
+ def _count_feedback():
249
+ try:
250
+ with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
251
+ return sum(1 for _ in f)
252
+ except FileNotFoundError:
253
+ return 0
254
+
255
+ refresh_btn = gr.Button("📈 刷新统计")
256
+ count_box = gr.Markdown()
257
+
258
+ def submit_preference_ui(a1, a2, p, pref, cmt):
259
+ if not a1 or not a2:
260
+ return "❗请先在上面的生成器里生成两段音频。"
261
+ # 写入 jsonl
262
+ msg = save_preference_feedback(p, a1, a2, pref, cmt)
263
+ return msg
264
+
265
+ def refresh_count_ui():
266
+ n = _count_feedback()
267
+ return f"当前已收集 **{n}** 条偏好样本。"
268
+
269
+ submit_btn.click(
270
+ fn=submit_preference_ui,
271
+ inputs=[gen_audio1_path, gen_audio2_path, prompt_used, pref_choice, pref_comment],
272
+ outputs=submit_status
273
  )
274
+ refresh_btn.click(fn=refresh_count_ui, outputs=count_box)
275
+
276
+ # —— 把 Interface 的输出“联动”到偏好区:当用户生成完成后,自动把路径和 prompt 填入偏好区输入框 ——
277
+ def _passthrough(a1, a2, p):
278
+ # 直接把接口输出透传给下方偏好区
279
+ return a1, a2, p
280
+
281
+ # 用 Events 把 Interface 的输出连到 pref_block 的三个文本框
282
+ gr_interface.submit(
283
+ fn=_passthrough,
284
+ inputs=gr_interface.outputs, # [Audio1(filepath), Audio2(filepath), PromptUsed]
285
+ outputs=[gen_audio1_path, gen_audio2_path, prompt_used],
286
+ )
287
+
288
+
289
 
290
  if __name__ == "__main__":
291
  ensure_models_downloaded()
292
  load_model_cache()
293
+ gr_interface.queue(15).launch(share=False, show_api=False)
294
 
295
  # theme = gr.themes.Soft(
296
  # primary_hue="blue",