AndreasXi commited on
Commit
51fb3d2
·
1 Parent(s): 22e35a0
Files changed (1) hide show
  1. app.py +59 -8
app.py CHANGED
@@ -102,6 +102,39 @@ def save_preference_feedback(prompt, audio1_path, audio2_path, preference, addit
102
  log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
103
  return f"✅ Thanks for your feedback, preference recorded: {preference}"
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  @spaces.GPU(duration=60)
107
  @torch.inference_mode()
@@ -159,16 +192,16 @@ def generate_audio_gradio(
159
  **{sampler_arg_name: sampler},
160
  )
161
  save_paths = []
 
 
 
 
 
 
162
  for i, audio in enumerate(audios):
163
  audio = audio.float().cpu()
164
-
165
  audio = fade_out(audio, seq_cfg.sampling_rate)
166
 
167
- safe_prompt = (
168
- "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
169
- .rstrip()
170
- .replace(" ", "_")[:50]
171
- )
172
  current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
173
  filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
174
  save_path = OUTPUT_DIR / filename
@@ -198,9 +231,27 @@ gr_interface = gr.Interface(
198
  gr.Audio(label="🎵 Audio Sample 2"),
199
  gr.Textbox(label="Prompt Used", interactive=False)
200
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
202
  description="🎯 **RLHF数据收集**: 现在生成2个音频样本!收集反馈数据用于改进模型。使用分析工具: `python analyze_feedback.py`",
203
- flagging_mode="never",
 
204
  examples=[
205
  ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
206
  ["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full"],
@@ -216,7 +267,7 @@ gr_interface = gr.Interface(
216
  ["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background", 10, 3, 1, "meanaudio_s_full"]
217
  ],
218
  cache_examples="lazy", # Turn on to cache.
219
- )
220
 
221
  if __name__ == "__main__":
222
  ensure_models_downloaded()
 
102
  log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
103
  return f"✅ Thanks for your feedback, preference recorded: {preference}"
104
 
105
+ def save_preference_feedback_from_flag(input_text, duration, cfg_strength, num_steps, variant,
106
+ audio1_path, audio2_path, prompt_used, preference, comment):
107
+ """处理Gradio flagging回调的反馈保存"""
108
+ try:
109
+ if not preference:
110
+ print("⚠️ 用户没有选择偏好")
111
+ return
112
+
113
+ feedback_data = {
114
+ "timestamp": datetime.now().isoformat(),
115
+ "prompt": prompt_used or input_text,
116
+ "audio1_path": audio1_path,
117
+ "audio2_path": audio2_path,
118
+ "preference": preference,
119
+ "additional_comment": comment or "",
120
+ "generation_params": {
121
+ "duration": duration,
122
+ "cfg_strength": cfg_strength,
123
+ "num_steps": num_steps,
124
+ "variant": variant
125
+ }
126
+ }
127
+
128
+ with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
129
+ f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")
130
+
131
+ log.info(f"✅ 反馈已保存: {preference} - {prompt_used[:50]}...")
132
+ print(f"✅ 用户反馈已保存到: {FEEDBACK_FILE}")
133
+
134
+ except Exception as e:
135
+ log.error(f"保存反馈时出错: {e}")
136
+ print(f"❌ 保存反馈时出错: {e}")
137
+
138
 
139
  @spaces.GPU(duration=60)
140
  @torch.inference_mode()
 
192
  **{sampler_arg_name: sampler},
193
  )
194
  save_paths = []
195
+ safe_prompt = (
196
+ "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
197
+ .rstrip()
198
+ .replace(" ", "_")[:50]
199
+ )
200
+
201
  for i, audio in enumerate(audios):
202
  audio = audio.float().cpu()
 
203
  audio = fade_out(audio, seq_cfg.sampling_rate)
204
 
 
 
 
 
 
205
  current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
206
  filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
207
  save_path = OUTPUT_DIR / filename
 
231
  gr.Audio(label="🎵 Audio Sample 2"),
232
  gr.Textbox(label="Prompt Used", interactive=False)
233
  ],
234
+ additional_inputs=[
235
+ gr.Radio(
236
+ choices=[
237
+ ("🎵 Audio 1 更好", "audio1"),
238
+ ("🎵 Audio 2 更好", "audio2"),
239
+ ("😊 两者都很好", "equal"),
240
+ ("😔 两者都不好", "both_bad")
241
+ ],
242
+ label="🤔 请选择您更喜欢的音频:",
243
+ value=None
244
+ ),
245
+ gr.Textbox(
246
+ label="💭 评论 (可选)",
247
+ placeholder="您对音频质量的具体反馈...",
248
+ lines=2
249
+ )
250
+ ],
251
  title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
252
  description="🎯 **RLHF数据收集**: 现在生成2个音频样本!收集反馈数据用于改进模型。使用分析工具: `python analyze_feedback.py`",
253
+ flagging_mode="manual",
254
+ flagging_callback=lambda *args: save_preference_feedback_from_flag(*args),
255
  examples=[
256
  ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
257
  ["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full"],
 
267
  ["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background", 10, 3, 1, "meanaudio_s_full"]
268
  ],
269
  cache_examples="lazy", # Turn on to cache.
270
+ )
271
 
272
  if __name__ == "__main__":
273
  ensure_models_downloaded()