Vanisper commited on
Commit
8f08c38
·
1 Parent(s): 6ddf196

feat: 添加获取选中模型的功能并优化模型数据处理

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -126,7 +126,7 @@ class Demo:
126
  model_checkbox = gr.Checkbox(label=model_name, value=False)
127
  seed_infr = gr.Number(label="种子值", value=42753)
128
  slider_scale_infr = gr.Slider(-4, 4, label="滑块刻度", value=3, info="较大的滑块刻度会导致更强的编辑效果")
129
- self.model_sections.append((model_checkbox.value, seed_infr, slider_scale_infr.value))
130
 
131
  with gr.Row():
132
  self.start_noise_infr = gr.Slider(
@@ -293,14 +293,24 @@ class Demo:
293
 
294
  return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value=save_name.replace('.pt',''))]
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
  def inference(self, prompt, start_noise, model, pbar=gr.Progress(track_tqdm=True)):
298
- model_sections = self.model_sections
299
- print(111, model_sections)
300
- model_names = [section[0] for section in model_sections if section[0]]
301
- seed_list = [section[1] for section in model_sections if section[0]]
302
- scale_list = [section[2] for section in model_sections if section[0]]
303
- print(222, model_names, seed_list, scale_list)
304
 
305
  if self.current_model != model:
306
  if model=='SDXL Turbo':
 
126
  model_checkbox = gr.Checkbox(label=model_name, value=False)
127
  seed_infr = gr.Number(label="种子值", value=42753)
128
  slider_scale_infr = gr.Slider(-4, 4, label="滑块刻度", value=3, info="较大的滑块刻度会导致更强的编辑效果")
129
+ self.model_sections.append(((model_checkbox.label, model_checkbox.value), seed_infr.value, slider_scale_infr.value))
130
 
131
  with gr.Row():
132
  self.start_noise_infr = gr.Slider(
 
293
 
294
  return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value=save_name.replace('.pt',''))]
295
 
296
+ def get_selected_models(self):
297
+ # 过滤出被选中的模型数据
298
+ selected = [
299
+ (section[0][0], section[1], section[2]) # (label, seed, scale)
300
+ for section in self.model_sections
301
+ if section[0][1] # 检查 checkbox value 是否为 True
302
+ ]
303
+
304
+ if selected:
305
+ # 解包成三个数组
306
+ labels, seeds, scales = zip(*selected)
307
+ return list(labels), list(seeds), list(scales)
308
+ else:
309
+ return [], [], []
310
 
311
  def inference(self, prompt, start_noise, model, pbar=gr.Progress(track_tqdm=True)):
312
+ result = self.get_selected_models()
313
+ model_names, seed_list, scale_list = result
 
 
 
 
314
 
315
  if self.current_model != model:
316
  if model=='SDXL Turbo':