AndreasXi committed on
Commit
079604c
·
1 Parent(s): 19ec831
Files changed (1) hide show
  1. app.py +6 -74
app.py CHANGED
@@ -37,7 +37,7 @@ setup_eval_logging()
37
 
38
  OUTPUT_DIR = Path("./output/gradio")
39
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
40
- NUM_SAMPLE = 2
41
 
42
  # 创建RLHF反馈数据目录
43
  FEEDBACK_DIR = Path("./rlhf")
@@ -175,10 +175,11 @@ def generate_audio_gradio(
175
  torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
176
  log.info(f"Audio saved to {save_path}")
177
  save_paths.append(str(save_path))
 
178
  if device == "cuda":
179
  torch.cuda.empty_cache()
180
 
181
- return save_paths[0], save_paths[1], prompt
182
 
183
 
184
  # Gradio input and output components
@@ -194,12 +195,11 @@ gr_interface = gr.Interface(
194
  fn=generate_audio_gradio,
195
  inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
196
  outputs=[
197
- gr.Audio(label="🎵 Audio Sample 1", type="filepath"),
198
- gr.Audio(label="🎵 Audio Sample 2", type="filepath"),
199
  gr.Textbox(label="Prompt Used", interactive=False)
200
  ],
201
  title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
202
- description="🎯 **RLHF数据收集**: 现在生成2个音频样本!生成后请在下方选择偏好并提交。",
203
  flagging_mode="never",
204
  examples=[
205
  ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
@@ -218,79 +218,11 @@ gr_interface = gr.Interface(
218
  cache_examples="lazy",
219
  )
220
 
221
- # ==== Preference collection UI (RLHF) ====
222
-
223
- # 允许用户在两段音频之间选择偏好,并补充备注
224
- with gr.Blocks() as pref_block:
225
- gr.Markdown("## 🧠 RLHF 偏好标注")
226
- gr.Markdown("生成完成后,请在下方选择您更喜欢的音频(或都不好/差不多),并可附加简短备注。点“提交偏好”即可写入 `./rlhf/user_preferences.jsonl`。")
227
-
228
- # 这里复用上面 Interface 的输出:我们需要拿到两段音频的文件路径与使用的 prompt
229
- # 为了连接这两个“界面”,再放一组可粘连的输入组件:
230
- with gr.Row():
231
- gen_audio1_path = gr.Textbox(label="Audio 1 路径(自动填充)", interactive=False)
232
- gen_audio2_path = gr.Textbox(label="Audio 2 路径(自动填充)", interactive=False)
233
- prompt_used = gr.Textbox(label="Prompt(自动填充)", interactive=False)
234
-
235
- # 偏好选项与备注
236
- pref_choice = gr.Radio(
237
- ["audio1", "audio2", "equal", "both_bad"],
238
- value="audio1",
239
- label="你更偏好哪个?",
240
- info="equal=差不多; both_bad=都不好"
241
- )
242
- pref_comment = gr.Textbox(label="可选备注(例如:哪一段更贴合描述、是否有噪声/破音等)", lines=2)
243
-
244
- submit_btn = gr.Button("✅ 提交偏好")
245
- submit_status = gr.Markdown()
246
-
247
- # 小工具:读取当前标注条目数
248
- def _count_feedback():
249
- try:
250
- with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
251
- return sum(1 for _ in f)
252
- except FileNotFoundError:
253
- return 0
254
-
255
- refresh_btn = gr.Button("📈 刷新统计")
256
- count_box = gr.Markdown()
257
-
258
- def submit_preference_ui(a1, a2, p, pref, cmt):
259
- if not a1 or not a2:
260
- return "❗请先在上面的生成器里生成两段音频。"
261
- # 写入 jsonl
262
- msg = save_preference_feedback(p, a1, a2, pref, cmt)
263
- return msg
264
-
265
- def refresh_count_ui():
266
- n = _count_feedback()
267
- return f"当前已收集 **{n}** 条偏好样本。"
268
-
269
- submit_btn.click(
270
- fn=submit_preference_ui,
271
- inputs=[gen_audio1_path, gen_audio2_path, prompt_used, pref_choice, pref_comment],
272
- outputs=submit_status
273
- )
274
- refresh_btn.click(fn=refresh_count_ui, outputs=count_box)
275
-
276
- # —— 把 Interface 的输出“联动”到偏好区:当用户生成完成后,自动把路径和 prompt 填入偏好区输入框 ——
277
- def _passthrough(a1, a2, p):
278
- # 直接把接口输出透传给下方偏好区
279
- return a1, a2, p
280
-
281
- # 用 Events 把 Interface 的输出连到 pref_block 的三个文本框
282
- gr_interface.submit(
283
- fn=_passthrough,
284
- inputs=gr_interface.outputs, # [Audio1(filepath), Audio2(filepath), PromptUsed]
285
- outputs=[gen_audio1_path, gen_audio2_path, prompt_used],
286
- )
287
-
288
-
289
 
290
  if __name__ == "__main__":
291
  ensure_models_downloaded()
292
  load_model_cache()
293
- gr_interface.queue(15).launch(share=False, show_api=False)
294
 
295
  # theme = gr.themes.Soft(
296
  # primary_hue="blue",
 
37
 
38
  OUTPUT_DIR = Path("./output/gradio")
39
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
40
+ NUM_SAMPLE = 1
41
 
42
  # 创建RLHF反馈数据目录
43
  FEEDBACK_DIR = Path("./rlhf")
 
175
  torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
176
  log.info(f"Audio saved to {save_path}")
177
  save_paths.append(str(save_path))
178
+
179
  if device == "cuda":
180
  torch.cuda.empty_cache()
181
 
182
+ return save_paths[0], prompt
183
 
184
 
185
  # Gradio input and output components
 
195
  fn=generate_audio_gradio,
196
  inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
197
  outputs=[
198
+ gr.Audio(label="🎵 Audio Sample", type="filepath"),
 
199
  gr.Textbox(label="Prompt Used", interactive=False)
200
  ],
201
  title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
202
+ description="",
203
  flagging_mode="never",
204
  examples=[
205
  ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
 
218
  cache_examples="lazy",
219
  )
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  if __name__ == "__main__":
223
  ensure_models_downloaded()
224
  load_model_cache()
225
+ gr_interface.queue(15).launch()
226
 
227
  # theme = gr.themes.Soft(
228
  # primary_hue="blue",