Vanisper commited on
Commit
a819b10
·
1 Parent(s): af05561

fix: 去除多种子的业务逻辑

Browse files
Files changed (1) hide show
  1. app.py +44 -26
app.py CHANGED
@@ -91,6 +91,7 @@ class Demo:
91
 
92
  self.guidance_scale = 1
93
  self.num_inference_steps = 3
 
94
 
95
  with gr.Blocks() as demo:
96
  self.layout()
@@ -124,18 +125,18 @@ class Demo:
124
  for model_name in model_map.keys():
125
  with gr.Row():
126
  model_checkbox = gr.Checkbox(label=model_name, value=False)
127
- seed_infr = gr.Number(label="种子值", value=42753)
128
  slider_scale_infr = gr.Slider(-4, 4, label="滑块刻度", value=3, info="较大的滑块刻度会导致更强的编辑效果")
129
- self.model_sections.append(((model_checkbox.label, model_checkbox.value), seed_infr.value, slider_scale_infr.value))
130
 
131
  # 添加复选框的change事件处理程序
132
  model_checkbox.change(
133
  fn=self.update_model_sections,
134
- inputs=[gr.Text(value=f"{model_checkbox.label}"), model_checkbox, seed_infr, slider_scale_infr],
135
  outputs=[]
136
  )
137
 
138
  with gr.Row():
 
139
  self.start_noise_infr = gr.Slider(
140
  600, 900,
141
  value=750,
@@ -249,7 +250,8 @@ class Demo:
249
  self.infr_button.click(self.inference, inputs=[
250
  self.prompt_input_infr,
251
  self.start_noise_infr,
252
- self.model_type
 
253
  ],
254
  outputs=[
255
  self.image_new,
@@ -270,10 +272,10 @@ class Demo:
270
  outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
271
  )
272
 
273
- def update_model_sections(self, label, checkbox, seed, scale):
274
  for i, section in enumerate(self.model_sections):
275
  if section[0][0] == label:
276
- self.model_sections[i] = ((label, checkbox), seed, scale)
277
  break
278
 
279
  def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)):
@@ -309,23 +311,24 @@ class Demo:
309
  def get_selected_models(self):
310
  # 过滤出被选中的模型数据
311
  selected = [
312
- (section[0][0], section[1], section[2]) # (label, seed, scale)
313
  for section in self.model_sections
314
  if section[0][1] # 检查 checkbox value 是否为 True
315
  ]
316
 
317
  if selected:
318
- # 解包成三个数组
319
- labels, seeds, scales = zip(*selected)
320
- return list(labels), list(seeds), list(scales)
321
  else:
322
- return [], [], []
323
 
324
- def inference(self, prompt, start_noise, model, pbar=gr.Progress(track_tqdm=True)):
 
325
  result = self.get_selected_models()
326
  print(111, self.model_sections)
327
- model_names, seed_list, scale_list = result
328
- print(222, model_names, seed_list, scale_list)
329
 
330
  if self.current_model != model:
331
  if model=='SDXL Turbo':
@@ -347,7 +350,7 @@ class Demo:
347
  self.current_model = 'SDXL'
348
 
349
  networks = []
350
- for i, model_name in enumerate(model_names):
351
  model_path = model_map[model_name]
352
  unet = self.pipe.unet
353
  network_type = "c3lier"
@@ -380,17 +383,32 @@ class Demo:
380
  train_method=train_method,
381
  ).to(self.device, dtype=self.weight_dtype)
382
  network.load_state_dict(torch.load(model_path))
383
- networks.append((network, seed_list[i], scale_list[i]))
384
-
385
- generator = torch.manual_seed(seed_list[0])
386
- edited_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=networks[0][0], start_noise=int(start_noise), scale=networks[0][2], unet=unet, guidance_scale=self.guidance_scale).images[0]
387
-
388
- generator = torch.manual_seed(seed_list[0])
389
- original_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=networks[0][0], start_noise=start_noise, scale=0, unet=unet, guidance_scale=self.guidance_scale).images[0]
390
-
391
- for network, seed, scale in networks[1:]:
392
- generator = torch.manual_seed(seed)
393
- edited_image = self.pipe(prompt, num_images_per_prompt=1, num_inference_steps=self.num_inference_steps, generator=generator, network=network, start_noise=int(start_noise), scale=scale, unet=unet, guidance_scale=self.guidance_scale).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
 
395
  del unet, networks
396
  unet = None
 
91
 
92
  self.guidance_scale = 1
93
  self.num_inference_steps = 3
94
+ self.seed = 42753 # 默认种子值
95
 
96
  with gr.Blocks() as demo:
97
  self.layout()
 
125
  for model_name in model_map.keys():
126
  with gr.Row():
127
  model_checkbox = gr.Checkbox(label=model_name, value=False)
 
128
  slider_scale_infr = gr.Slider(-4, 4, label="滑块刻度", value=3, info="较大的滑块刻度会导致更强的编辑效果")
129
+ self.model_sections.append(((model_checkbox.label, model_checkbox.value), slider_scale_infr.value))
130
 
131
  # 添加复选框的change事件处理程序
132
  model_checkbox.change(
133
  fn=self.update_model_sections,
134
+ inputs=[gr.Text(value=f"{model_checkbox.label}"), model_checkbox, slider_scale_infr],
135
  outputs=[]
136
  )
137
 
138
  with gr.Row():
139
+ self.seed_infr = gr.Number(label="种子值", value=self.seed)
140
  self.start_noise_infr = gr.Slider(
141
  600, 900,
142
  value=750,
 
250
  self.infr_button.click(self.inference, inputs=[
251
  self.prompt_input_infr,
252
  self.start_noise_infr,
253
+ self.model_type,
254
+ self.seed_infr
255
  ],
256
  outputs=[
257
  self.image_new,
 
272
  outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
273
  )
274
 
275
+ def update_model_sections(self, label, checkbox, scale):
276
  for i, section in enumerate(self.model_sections):
277
  if section[0][0] == label:
278
+ self.model_sections[i] = ((label, checkbox), scale)
279
  break
280
 
281
  def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)):
 
311
  def get_selected_models(self):
312
  # 过滤出被选中的模型数据
313
  selected = [
314
+ (section[0][0], section[1]) # (label, scale)
315
  for section in self.model_sections
316
  if section[0][1] # 检查 checkbox value 是否为 True
317
  ]
318
 
319
  if selected:
320
+ # 解包成两个数组
321
+ labels, scales = zip(*selected)
322
+ return list(labels), list(scales)
323
  else:
324
+ return [], []
325
 
326
+ def inference(self, prompt, start_noise, model, seed, pbar=gr.Progress(track_tqdm=True)):
327
+ self.seed = seed # 更新种子值
328
  result = self.get_selected_models()
329
  print(111, self.model_sections)
330
+ model_names, scale_list = result
331
+ print(222, model_names, scale_list)
332
 
333
  if self.current_model != model:
334
  if model=='SDXL Turbo':
 
350
  self.current_model = 'SDXL'
351
 
352
  networks = []
353
+ for model_name in model_names:
354
  model_path = model_map[model_name]
355
  unet = self.pipe.unet
356
  network_type = "c3lier"
 
383
  train_method=train_method,
384
  ).to(self.device, dtype=self.weight_dtype)
385
  network.load_state_dict(torch.load(model_path))
386
+ networks.append(network)
387
+
388
+ # 设置种子
389
+ generator = torch.manual_seed(self.seed)
390
+ # 生成编辑后的图像(应用多权重)
391
+ edited_image = self.pipe(
392
+ prompt,
393
+ num_images_per_prompt=1,
394
+ num_inference_steps=self.num_inference_steps,
395
+ generator=generator,
396
+ networks=networks, # 加载多个 LoRA 模型
397
+ scales=scale_list, # 设置每个 LoRA 的权重
398
+ guidance_scale=self.guidance_scale
399
+ ).images[0]
400
+
401
+ # 生成原始图像(不应用权重)
402
+ generator = torch.manual_seed(self.seed)
403
+ original_image = self.pipe(
404
+ prompt,
405
+ num_images_per_prompt=1,
406
+ num_inference_steps=self.num_inference_steps,
407
+ generator=generator,
408
+ networks=[], # 不加载任何 LoRA 模型
409
+ scales=[], # 不设置任何权重
410
+ guidance_scale=self.guidance_scale
411
+ ).images[0]
412
 
413
  del unet, networks
414
  unet = None