add rlhf
app.py  CHANGED
@@ -102,6 +102,39 @@ def save_preference_feedback(prompt, audio1_path, audio2_path, preference, addit
     log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
     return f"✅ Thanks for your feedback, preference recorded: {preference}"

+def save_preference_feedback_from_flag(input_text, duration, cfg_strength, num_steps, variant,
+                                       audio1_path, audio2_path, prompt_used, preference, comment):
+    """Save preference feedback coming from the Gradio flagging callback."""
+    try:
+        if not preference:
+            print("⚠️ The user did not select a preference")
+            return
+
+        feedback_data = {
+            "timestamp": datetime.now().isoformat(),
+            "prompt": prompt_used or input_text,
+            "audio1_path": audio1_path,
+            "audio2_path": audio2_path,
+            "preference": preference,
+            "additional_comment": comment or "",
+            "generation_params": {
+                "duration": duration,
+                "cfg_strength": cfg_strength,
+                "num_steps": num_steps,
+                "variant": variant
+            }
+        }
+
+        with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
+            f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")
+
+        log.info(f"✅ Feedback saved: {preference} - {prompt_used[:50]}...")
+        print(f"✅ User feedback saved to: {FEEDBACK_FILE}")
+
+    except Exception as e:
+        log.error(f"Error while saving feedback: {e}")
+        print(f"❌ Error while saving feedback: {e}")
+

 @spaces.GPU(duration=60)
 @torch.inference_mode()
@@ -159,16 +192,16 @@ def generate_audio_gradio(
         **{sampler_arg_name: sampler},
     )
     save_paths = []
+    safe_prompt = (
+        "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
+        .rstrip()
+        .replace(" ", "_")[:50]
+    )
+
     for i, audio in enumerate(audios):
         audio = audio.float().cpu()
-
         audio = fade_out(audio, seq_cfg.sampling_rate)

-        safe_prompt = (
-            "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
-            .rstrip()
-            .replace(" ", "_")[:50]
-        )
         current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
         filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
         save_path = OUTPUT_DIR / filename
@@ -198,9 +231,27 @@ gr_interface = gr.Interface(
         gr.Audio(label="🎵 Audio Sample 2"),
         gr.Textbox(label="Prompt Used", interactive=False)
     ],
+    additional_inputs=[
+        gr.Radio(
+            choices=[
+                ("🎵 Audio 1 is better", "audio1"),
+                ("🎵 Audio 2 is better", "audio2"),
+                ("😊 Both are good", "equal"),
+                ("😔 Both are bad", "both_bad")
+            ],
+            label="🤔 Please select the audio you prefer:",
+            value=None
+        ),
+        gr.Textbox(
+            label="💭 Comment (optional)",
+            placeholder="Your specific feedback on the audio quality...",
+            lines=2
+        )
+    ],
     title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
     description="🎯 **RLHF data collection**: 2 audio samples are now generated per prompt! Feedback data is collected to improve the model. Analysis tool: `python analyze_feedback.py`",
-    flagging_mode="
+    flagging_mode="manual",
+    flagging_callback=lambda *args: save_preference_feedback_from_flag(*args),
     examples=[
         ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
         ["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full"],
@@ -216,7 +267,7 @@ gr_interface = gr.Interface(
         ["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background", 10, 3, 1, "meanaudio_s_full"]
     ],
     cache_examples="lazy",  # Turn on to cache.
-)
+)

 if __name__ == "__main__":
     ensure_models_downloaded()
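The flagging handler added above appends one JSON object per line to FEEDBACK_FILE. The `analyze_feedback.py` tool mentioned in the description is not part of this diff, so as a rough illustration of consuming that log, here is a minimal sketch assuming the record layout written by save_preference_feedback_from_flag; the default path is only a placeholder for whatever FEEDBACK_FILE points to in app.py.

# Minimal sketch (assumed, not the repository's analyze_feedback.py): tally the
# preferences recorded in the JSONL log written by save_preference_feedback_from_flag.
# The default path below is a placeholder for FEEDBACK_FILE in app.py.
import json
from collections import Counter
from pathlib import Path

def summarize_feedback(path: str = "preference_feedback.jsonl") -> Counter:
    counts = Counter()
    log_file = Path(path)
    if not log_file.exists():
        return counts
    for line in log_file.read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue
        record = json.loads(line)          # one feedback record per line
        counts[record["preference"]] += 1  # "audio1", "audio2", "equal" or "both_bad"
    return counts

if __name__ == "__main__":
    print(summarize_feedback())

One caveat on the wiring in this commit: Gradio's `flagging_callback` is usually documented as taking a FlaggingCallback-style object (something exposing `setup()` and `flag()`, like the default `gr.CSVLogger`) rather than a bare lambda. If the lambda does not fire as expected, a thin wrapper along the following lines may help; whether `gr.FlaggingCallback` is the right base class, the exact method signatures, and the assumption that `flag_data` arrives in the same order as the parameters of `save_preference_feedback_from_flag` all need to be verified against the installed Gradio version.

# Hedged sketch: a FlaggingCallback-style wrapper around the handler added in this commit.
import gradio as gr

class PreferenceFlagger(gr.FlaggingCallback):
    def setup(self, components, flagging_dir):
        # Called once by gr.Interface with the components whose values will be flagged.
        self.components = components

    def flag(self, flag_data, flag_option=None, username=None):
        # flag_data holds the current values of the registered components, in order
        # (assumed here to match save_preference_feedback_from_flag's parameter order).
        save_preference_feedback_from_flag(*flag_data)
        return 1  # number of saved samples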