add rlhf
Browse files
app.py
CHANGED
@@ -37,7 +37,7 @@ setup_eval_logging()
|
|
37 |
|
38 |
OUTPUT_DIR = Path("./output/gradio")
|
39 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
40 |
-
NUM_SAMPLE = 2
|
41 |
|
42 |
# 创建RLHF反馈数据目录
|
43 |
FEEDBACK_DIR = Path("./rlhf")
|
@@ -175,10 +175,11 @@ def generate_audio_gradio(
|
|
175 |
torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
|
176 |
log.info(f"Audio saved to {save_path}")
|
177 |
save_paths.append(str(save_path))
|
|
|
178 |
if device == "cuda":
|
179 |
torch.cuda.empty_cache()
|
180 |
|
181 |
-
return save_paths[0], save_paths[1], prompt
|
182 |
|
183 |
|
184 |
# Gradio input and output components
|
@@ -194,12 +195,11 @@ gr_interface = gr.Interface(
|
|
194 |
fn=generate_audio_gradio,
|
195 |
inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
|
196 |
outputs=[
|
197 |
-
gr.Audio(label="🎵 Audio Sample 1", type="filepath"),
|
198 |
-
gr.Audio(label="🎵 Audio Sample 2", type="filepath"),
|
199 |
gr.Textbox(label="Prompt Used", interactive=False)
|
200 |
],
|
201 |
title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
|
202 |
-
description="
|
203 |
flagging_mode="never",
|
204 |
examples=[
|
205 |
["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
|
@@ -218,79 +218,11 @@ gr_interface = gr.Interface(
|
|
218 |
cache_examples="lazy",
|
219 |
)
|
220 |
|
221 |
-
# ==== Preference collection UI (RLHF) ====
|
222 |
-
|
223 |
-
# 允许用户在两段音频之间选择偏好,并补充备注
|
224 |
-
with gr.Blocks() as pref_block:
|
225 |
-
gr.Markdown("## 🧠 RLHF 偏好标注")
|
226 |
-
gr.Markdown("生成完成后,请在下方选择您更喜欢的音频(或都不好/差不多),并可附加简短备注。点“提交偏好”即可写入 `./rlhf/user_preferences.jsonl`。")
|
227 |
-
|
228 |
-
# 这里复用上面 Interface 的输出:我们需要拿到两段音频的文件路径与使用的 prompt
|
229 |
-
# 为了连接这两个“界面”,再放一组可粘连的输入组件:
|
230 |
-
with gr.Row():
|
231 |
-
gen_audio1_path = gr.Textbox(label="Audio 1 路径(自动填充)", interactive=False)
|
232 |
-
gen_audio2_path = gr.Textbox(label="Audio 2 路径(自动填充)", interactive=False)
|
233 |
-
prompt_used = gr.Textbox(label="Prompt(自动填充)", interactive=False)
|
234 |
-
|
235 |
-
# 偏好选项与备注
|
236 |
-
pref_choice = gr.Radio(
|
237 |
-
["audio1", "audio2", "equal", "both_bad"],
|
238 |
-
value="audio1",
|
239 |
-
label="你更偏好哪个?",
|
240 |
-
info="equal=差不多; both_bad=都不好"
|
241 |
-
)
|
242 |
-
pref_comment = gr.Textbox(label="可选备注(例如:哪一段更贴合描述、是否有噪声/破音等)", lines=2)
|
243 |
-
|
244 |
-
submit_btn = gr.Button("✅ 提交偏好")
|
245 |
-
submit_status = gr.Markdown()
|
246 |
-
|
247 |
-
# 小工具:读取当前标注条目数
|
248 |
-
def _count_feedback():
|
249 |
-
try:
|
250 |
-
with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
|
251 |
-
return sum(1 for _ in f)
|
252 |
-
except FileNotFoundError:
|
253 |
-
return 0
|
254 |
-
|
255 |
-
refresh_btn = gr.Button("📈 刷新统计")
|
256 |
-
count_box = gr.Markdown()
|
257 |
-
|
258 |
-
def submit_preference_ui(a1, a2, p, pref, cmt):
|
259 |
-
if not a1 or not a2:
|
260 |
-
return "❗请先在上面的生成器里生成两段音频。"
|
261 |
-
# 写入 jsonl
|
262 |
-
msg = save_preference_feedback(p, a1, a2, pref, cmt)
|
263 |
-
return msg
|
264 |
-
|
265 |
-
def refresh_count_ui():
|
266 |
-
n = _count_feedback()
|
267 |
-
return f"当前已收集 **{n}** 条偏好样本。"
|
268 |
-
|
269 |
-
submit_btn.click(
|
270 |
-
fn=submit_preference_ui,
|
271 |
-
inputs=[gen_audio1_path, gen_audio2_path, prompt_used, pref_choice, pref_comment],
|
272 |
-
outputs=submit_status
|
273 |
-
)
|
274 |
-
refresh_btn.click(fn=refresh_count_ui, outputs=count_box)
|
275 |
-
|
276 |
-
# —— 把 Interface 的输出“联动”到偏好区:当用户生成完成后,自动把路径和 prompt 填入偏好区输入框 ——
|
277 |
-
def _passthrough(a1, a2, p):
|
278 |
-
# 直接把接口输出透传给下方偏好区
|
279 |
-
return a1, a2, p
|
280 |
-
|
281 |
-
# 用 Events 把 Interface 的输出连到 pref_block 的三个文本框
|
282 |
-
gr_interface.submit(
|
283 |
-
fn=_passthrough,
|
284 |
-
inputs=gr_interface.outputs, # [Audio1(filepath), Audio2(filepath), PromptUsed]
|
285 |
-
outputs=[gen_audio1_path, gen_audio2_path, prompt_used],
|
286 |
-
)
|
287 |
-
|
288 |
-
|
289 |
|
290 |
if __name__ == "__main__":
|
291 |
ensure_models_downloaded()
|
292 |
load_model_cache()
|
293 |
-
gr_interface.queue(15).launch(
|
294 |
|
295 |
# theme = gr.themes.Soft(
|
296 |
# primary_hue="blue",
|
|
|
37 |
|
38 |
OUTPUT_DIR = Path("./output/gradio")
|
39 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
40 |
+
NUM_SAMPLE = 1
|
41 |
|
42 |
# 创建RLHF反馈数据目录
|
43 |
FEEDBACK_DIR = Path("./rlhf")
|
|
|
175 |
torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
|
176 |
log.info(f"Audio saved to {save_path}")
|
177 |
save_paths.append(str(save_path))
|
178 |
+
|
179 |
if device == "cuda":
|
180 |
torch.cuda.empty_cache()
|
181 |
|
182 |
+
return save_paths[0], prompt
|
183 |
|
184 |
|
185 |
# Gradio input and output components
|
|
|
195 |
fn=generate_audio_gradio,
|
196 |
inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
|
197 |
outputs=[
|
198 |
+
gr.Audio(label="🎵 Audio Sample", type="filepath"),
|
|
|
199 |
gr.Textbox(label="Prompt Used", interactive=False)
|
200 |
],
|
201 |
title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
|
202 |
+
description="",
|
203 |
flagging_mode="never",
|
204 |
examples=[
|
205 |
["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
|
|
|
218 |
cache_examples="lazy",
|
219 |
)
|
220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
if __name__ == "__main__":
|
223 |
ensure_models_downloaded()
|
224 |
load_model_cache()
|
225 |
+
gr_interface.queue(15).launch()
|
226 |
|
227 |
# theme = gr.themes.Soft(
|
228 |
# primary_hue="blue",
|