add rlhf
Browse files
app.py
CHANGED
@@ -102,39 +102,6 @@ def save_preference_feedback(prompt, audio1_path, audio2_path, preference, addit
|
|
102 |
log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
|
103 |
return f"✅ Thanks for your feedback, preference recorded: {preference}"
|
104 |
|
105 |
-
def save_preference_feedback_from_flag(input_text, duration, cfg_strength, num_steps, variant,
|
106 |
-
audio1_path, audio2_path, prompt_used, preference, comment):
|
107 |
-
"""处理Gradio flagging回调的反馈保存"""
|
108 |
-
try:
|
109 |
-
if not preference:
|
110 |
-
print("⚠️ 用户没有选择偏好")
|
111 |
-
return
|
112 |
-
|
113 |
-
feedback_data = {
|
114 |
-
"timestamp": datetime.now().isoformat(),
|
115 |
-
"prompt": prompt_used or input_text,
|
116 |
-
"audio1_path": audio1_path,
|
117 |
-
"audio2_path": audio2_path,
|
118 |
-
"preference": preference,
|
119 |
-
"additional_comment": comment or "",
|
120 |
-
"generation_params": {
|
121 |
-
"duration": duration,
|
122 |
-
"cfg_strength": cfg_strength,
|
123 |
-
"num_steps": num_steps,
|
124 |
-
"variant": variant
|
125 |
-
}
|
126 |
-
}
|
127 |
-
|
128 |
-
with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
|
129 |
-
f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")
|
130 |
-
|
131 |
-
log.info(f"✅ 反馈已保存: {preference} - {prompt_used[:50]}...")
|
132 |
-
print(f"✅ 用户反馈已保存到: {FEEDBACK_FILE}")
|
133 |
-
|
134 |
-
except Exception as e:
|
135 |
-
log.error(f"保存反馈时出错: {e}")
|
136 |
-
print(f"❌ 保存反馈时出错: {e}")
|
137 |
-
|
138 |
|
139 |
@spaces.GPU(duration=60)
|
140 |
@torch.inference_mode()
|
@@ -227,31 +194,13 @@ gr_interface = gr.Interface(
|
|
227 |
fn=generate_audio_gradio,
|
228 |
inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
|
229 |
outputs=[
|
230 |
-
gr.Audio(label="🎵 Audio Sample 1"),
|
231 |
-
gr.Audio(label="🎵 Audio Sample 2"),
|
232 |
gr.Textbox(label="Prompt Used", interactive=False)
|
233 |
],
|
234 |
-
additional_inputs=[
|
235 |
-
gr.Radio(
|
236 |
-
choices=[
|
237 |
-
("🎵 Audio 1 更好", "audio1"),
|
238 |
-
("🎵 Audio 2 更好", "audio2"),
|
239 |
-
("😊 两者都很好", "equal"),
|
240 |
-
("😔 两者都不好", "both_bad")
|
241 |
-
],
|
242 |
-
label="🤔 请选择您更喜欢的音频:",
|
243 |
-
value=None
|
244 |
-
),
|
245 |
-
gr.Textbox(
|
246 |
-
label="💭 评论 (可选)",
|
247 |
-
placeholder="您对音频质量的具体反馈...",
|
248 |
-
lines=2
|
249 |
-
)
|
250 |
-
],
|
251 |
title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
|
252 |
-
description="🎯 **RLHF数据收集**: 现在生成2
|
253 |
-
flagging_mode="
|
254 |
-
flagging_callback=lambda *args: save_preference_feedback_from_flag(*args),
|
255 |
examples=[
|
256 |
["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
|
257 |
["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full"],
|
@@ -266,13 +215,82 @@ gr_interface = gr.Interface(
|
|
266 |
['doorbell ding once followed by footsteps gradually getting louder and a door is opened ', 10, 3, 1, "meanaudio_s_full"],
|
267 |
["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background", 10, 3, 1, "meanaudio_s_full"]
|
268 |
],
|
269 |
-
cache_examples="lazy",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
|
272 |
if __name__ == "__main__":
|
273 |
ensure_models_downloaded()
|
274 |
load_model_cache()
|
275 |
-
gr_interface.queue(15).launch()
|
276 |
|
277 |
# theme = gr.themes.Soft(
|
278 |
# primary_hue="blue",
|
|
|
102 |
log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
|
103 |
return f"✅ Thanks for your feedback, preference recorded: {preference}"
|
104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
@spaces.GPU(duration=60)
|
107 |
@torch.inference_mode()
|
|
|
194 |
fn=generate_audio_gradio,
|
195 |
inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
|
196 |
outputs=[
|
197 |
+
gr.Audio(label="🎵 Audio Sample 1", type="filepath"),
|
198 |
+
gr.Audio(label="🎵 Audio Sample 2", type="filepath"),
|
199 |
gr.Textbox(label="Prompt Used", interactive=False)
|
200 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
|
202 |
+
description="🎯 **RLHF数据收集**: 现在生成2个音频样本!生成后请在下方选择偏好并提交。",
|
203 |
+
flagging_mode="never",
|
|
|
204 |
examples=[
|
205 |
["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
|
206 |
["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full"],
|
|
|
215 |
['doorbell ding once followed by footsteps gradually getting louder and a door is opened ', 10, 3, 1, "meanaudio_s_full"],
|
216 |
["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background", 10, 3, 1, "meanaudio_s_full"]
|
217 |
],
|
218 |
+
cache_examples="lazy",
|
219 |
+
)
|
220 |
+
|
221 |
+
# ==== Preference collection UI (RLHF) ====
|
222 |
+
|
223 |
+
# 允许用户在两段音频之间选择偏好,并补充备注
|
224 |
+
with gr.Blocks() as pref_block:
|
225 |
+
gr.Markdown("## 🧠 RLHF 偏好标注")
|
226 |
+
gr.Markdown("生成完成后,请在下方选择您更喜欢的音频(或都不好/差不多),并可附加简短备注。点“提交偏好”即可写入 `./rlhf/user_preferences.jsonl`。")
|
227 |
+
|
228 |
+
# 这里复用上面 Interface 的输出:我们需要拿到两段音频的文件路径与使用的 prompt
|
229 |
+
# 为了连接这两个“界面”,再放一组可粘连的输入组件:
|
230 |
+
with gr.Row():
|
231 |
+
gen_audio1_path = gr.Textbox(label="Audio 1 路径(自动填充)", interactive=False)
|
232 |
+
gen_audio2_path = gr.Textbox(label="Audio 2 路径(自动填充)", interactive=False)
|
233 |
+
prompt_used = gr.Textbox(label="Prompt(自动填充)", interactive=False)
|
234 |
+
|
235 |
+
# 偏好选项与备注
|
236 |
+
pref_choice = gr.Radio(
|
237 |
+
["audio1", "audio2", "equal", "both_bad"],
|
238 |
+
value="audio1",
|
239 |
+
label="你更偏好哪个?",
|
240 |
+
info="equal=差不多; both_bad=都不好"
|
241 |
+
)
|
242 |
+
pref_comment = gr.Textbox(label="可选备注(例如:哪一段更贴合描述、是否有噪声/破音等)", lines=2)
|
243 |
+
|
244 |
+
submit_btn = gr.Button("✅ 提交偏好")
|
245 |
+
submit_status = gr.Markdown()
|
246 |
+
|
247 |
+
# 小工具:读取当前标注条目数
|
248 |
+
def _count_feedback():
|
249 |
+
try:
|
250 |
+
with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
|
251 |
+
return sum(1 for _ in f)
|
252 |
+
except FileNotFoundError:
|
253 |
+
return 0
|
254 |
+
|
255 |
+
refresh_btn = gr.Button("📈 刷新统计")
|
256 |
+
count_box = gr.Markdown()
|
257 |
+
|
258 |
+
def submit_preference_ui(a1, a2, p, pref, cmt):
|
259 |
+
if not a1 or not a2:
|
260 |
+
return "❗请先在上面的生成器里生成两段音频。"
|
261 |
+
# 写入 jsonl
|
262 |
+
msg = save_preference_feedback(p, a1, a2, pref, cmt)
|
263 |
+
return msg
|
264 |
+
|
265 |
+
def refresh_count_ui():
|
266 |
+
n = _count_feedback()
|
267 |
+
return f"当前已收集 **{n}** 条偏好样本。"
|
268 |
+
|
269 |
+
submit_btn.click(
|
270 |
+
fn=submit_preference_ui,
|
271 |
+
inputs=[gen_audio1_path, gen_audio2_path, prompt_used, pref_choice, pref_comment],
|
272 |
+
outputs=submit_status
|
273 |
)
|
274 |
+
refresh_btn.click(fn=refresh_count_ui, outputs=count_box)
|
275 |
+
|
276 |
+
# —— 把 Interface 的输出“联动”到偏好区:当用户生成完成后,自动把路径和 prompt 填入偏好区输入框 ——
|
277 |
+
def _passthrough(a1, a2, p):
|
278 |
+
# 直接把接口输出透传给下方偏好区
|
279 |
+
return a1, a2, p
|
280 |
+
|
281 |
+
# 用 Events 把 Interface 的输出连到 pref_block 的三个文本框
|
282 |
+
gr_interface.submit(
|
283 |
+
fn=_passthrough,
|
284 |
+
inputs=gr_interface.outputs, # [Audio1(filepath), Audio2(filepath), PromptUsed]
|
285 |
+
outputs=[gen_audio1_path, gen_audio2_path, prompt_used],
|
286 |
+
)
|
287 |
+
|
288 |
+
|
289 |
|
290 |
if __name__ == "__main__":
|
291 |
ensure_models_downloaded()
|
292 |
load_model_cache()
|
293 |
+
gr_interface.queue(15).launch(share=False, show_api=False)
|
294 |
|
295 |
# theme = gr.themes.Soft(
|
296 |
# primary_hue="blue",
|