Spaces:
Sleeping
Sleeping
fix: 去除多种子的业务逻辑
Browse files
app.py
CHANGED
@@ -91,6 +91,7 @@ class Demo:
|
|
91 |
|
92 |
self.guidance_scale = 1
|
93 |
self.num_inference_steps = 3
|
|
|
94 |
|
95 |
with gr.Blocks() as demo:
|
96 |
self.layout()
|
@@ -124,18 +125,18 @@ class Demo:
|
|
124 |
for model_name in model_map.keys():
|
125 |
with gr.Row():
|
126 |
model_checkbox = gr.Checkbox(label=model_name, value=False)
|
127 |
-
seed_infr = gr.Number(label="种子值", value=42753)
|
128 |
slider_scale_infr = gr.Slider(-4, 4, label="滑块刻度", value=3, info="较大的滑块刻度会导致更强的编辑效果")
|
129 |
-
self.model_sections.append(((model_checkbox.label, model_checkbox.value),
|
130 |
|
131 |
# 添加复选框的change事件处理程序
|
132 |
model_checkbox.change(
|
133 |
fn=self.update_model_sections,
|
134 |
-
inputs=[gr.Text(value=f"{model_checkbox.label}"), model_checkbox,
|
135 |
outputs=[]
|
136 |
)
|
137 |
|
138 |
with gr.Row():
|
|
|
139 |
self.start_noise_infr = gr.Slider(
|
140 |
600, 900,
|
141 |
value=750,
|
@@ -249,7 +250,8 @@ class Demo:
|
|
249 |
self.infr_button.click(self.inference, inputs=[
|
250 |
self.prompt_input_infr,
|
251 |
self.start_noise_infr,
|
252 |
-
self.model_type
|
|
|
253 |
],
|
254 |
outputs=[
|
255 |
self.image_new,
|
@@ -270,10 +272,10 @@ class Demo:
|
|
270 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
271 |
)
|
272 |
|
273 |
-
def update_model_sections(self, label, checkbox,
|
274 |
for i, section in enumerate(self.model_sections):
|
275 |
if section[0][0] == label:
|
276 |
-
self.model_sections[i] = ((label, checkbox),
|
277 |
break
|
278 |
|
279 |
def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)):
|
@@ -309,23 +311,24 @@ class Demo:
|
|
309 |
def get_selected_models(self):
|
310 |
# 过滤出被选中的模型数据
|
311 |
selected = [
|
312 |
-
(section[0][0], section[1]
|
313 |
for section in self.model_sections
|
314 |
if section[0][1] # 检查 checkbox value 是否为 True
|
315 |
]
|
316 |
|
317 |
if selected:
|
318 |
-
#
|
319 |
-
labels,
|
320 |
-
return list(labels), list(
|
321 |
else:
|
322 |
-
return [], []
|
323 |
|
324 |
-
def inference(self, prompt, start_noise, model, pbar=gr.Progress(track_tqdm=True)):
|
|
|
325 |
result = self.get_selected_models()
|
326 |
print(111, self.model_sections)
|
327 |
-
model_names,
|
328 |
-
print(222, model_names,
|
329 |
|
330 |
if self.current_model != model:
|
331 |
if model=='SDXL Turbo':
|
@@ -347,7 +350,7 @@ class Demo:
|
|
347 |
self.current_model = 'SDXL'
|
348 |
|
349 |
networks = []
|
350 |
-
for
|
351 |
model_path = model_map[model_name]
|
352 |
unet = self.pipe.unet
|
353 |
network_type = "c3lier"
|
@@ -380,17 +383,32 @@ class Demo:
|
|
380 |
train_method=train_method,
|
381 |
).to(self.device, dtype=self.weight_dtype)
|
382 |
network.load_state_dict(torch.load(model_path))
|
383 |
-
networks.append(
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
generator
|
393 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
|
395 |
del unet, networks
|
396 |
unet = None
|
|
|
91 |
|
92 |
self.guidance_scale = 1
|
93 |
self.num_inference_steps = 3
|
94 |
+
self.seed = 42753 # 默认种子值
|
95 |
|
96 |
with gr.Blocks() as demo:
|
97 |
self.layout()
|
|
|
125 |
for model_name in model_map.keys():
|
126 |
with gr.Row():
|
127 |
model_checkbox = gr.Checkbox(label=model_name, value=False)
|
|
|
128 |
slider_scale_infr = gr.Slider(-4, 4, label="滑块刻度", value=3, info="较大的滑块刻度会导致更强的编辑效果")
|
129 |
+
self.model_sections.append(((model_checkbox.label, model_checkbox.value), slider_scale_infr.value))
|
130 |
|
131 |
# 添加复选框的change事件处理程序
|
132 |
model_checkbox.change(
|
133 |
fn=self.update_model_sections,
|
134 |
+
inputs=[gr.Text(value=f"{model_checkbox.label}"), model_checkbox, slider_scale_infr],
|
135 |
outputs=[]
|
136 |
)
|
137 |
|
138 |
with gr.Row():
|
139 |
+
self.seed_infr = gr.Number(label="种子值", value=self.seed)
|
140 |
self.start_noise_infr = gr.Slider(
|
141 |
600, 900,
|
142 |
value=750,
|
|
|
250 |
self.infr_button.click(self.inference, inputs=[
|
251 |
self.prompt_input_infr,
|
252 |
self.start_noise_infr,
|
253 |
+
self.model_type,
|
254 |
+
self.seed_infr
|
255 |
],
|
256 |
outputs=[
|
257 |
self.image_new,
|
|
|
272 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
273 |
)
|
274 |
|
275 |
+
def update_model_sections(self, label, checkbox, scale):
|
276 |
for i, section in enumerate(self.model_sections):
|
277 |
if section[0][0] == label:
|
278 |
+
self.model_sections[i] = ((label, checkbox), scale)
|
279 |
break
|
280 |
|
281 |
def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)):
|
|
|
311 |
def get_selected_models(self):
|
312 |
# 过滤出被选中的模型数据
|
313 |
selected = [
|
314 |
+
(section[0][0], section[1]) # (label, scale)
|
315 |
for section in self.model_sections
|
316 |
if section[0][1] # 检查 checkbox value 是否为 True
|
317 |
]
|
318 |
|
319 |
if selected:
|
320 |
+
# 解包成两个数组
|
321 |
+
labels, scales = zip(*selected)
|
322 |
+
return list(labels), list(scales)
|
323 |
else:
|
324 |
+
return [], []
|
325 |
|
326 |
+
def inference(self, prompt, start_noise, model, seed, pbar=gr.Progress(track_tqdm=True)):
|
327 |
+
self.seed = seed # 更新种子值
|
328 |
result = self.get_selected_models()
|
329 |
print(111, self.model_sections)
|
330 |
+
model_names, scale_list = result
|
331 |
+
print(222, model_names, scale_list)
|
332 |
|
333 |
if self.current_model != model:
|
334 |
if model=='SDXL Turbo':
|
|
|
350 |
self.current_model = 'SDXL'
|
351 |
|
352 |
networks = []
|
353 |
+
for model_name in model_names:
|
354 |
model_path = model_map[model_name]
|
355 |
unet = self.pipe.unet
|
356 |
network_type = "c3lier"
|
|
|
383 |
train_method=train_method,
|
384 |
).to(self.device, dtype=self.weight_dtype)
|
385 |
network.load_state_dict(torch.load(model_path))
|
386 |
+
networks.append(network)
|
387 |
+
|
388 |
+
# 设置种子
|
389 |
+
generator = torch.manual_seed(self.seed)
|
390 |
+
# 生成编辑后的图像(应用多权重)
|
391 |
+
edited_image = self.pipe(
|
392 |
+
prompt,
|
393 |
+
num_images_per_prompt=1,
|
394 |
+
num_inference_steps=self.num_inference_steps,
|
395 |
+
generator=generator,
|
396 |
+
networks=networks, # 加载多个 LoRA 模型
|
397 |
+
scales=scale_list, # 设置每个 LoRA 的权重
|
398 |
+
guidance_scale=self.guidance_scale
|
399 |
+
).images[0]
|
400 |
+
|
401 |
+
# 生成原始图像(不应用权重)
|
402 |
+
generator = torch.manual_seed(self.seed)
|
403 |
+
original_image = self.pipe(
|
404 |
+
prompt,
|
405 |
+
num_images_per_prompt=1,
|
406 |
+
num_inference_steps=self.num_inference_steps,
|
407 |
+
generator=generator,
|
408 |
+
networks=[], # 不加载任何 LoRA 模型
|
409 |
+
scales=[], # 不设置任何权重
|
410 |
+
guidance_scale=self.guidance_scale
|
411 |
+
).images[0]
|
412 |
|
413 |
del unet, networks
|
414 |
unet = None
|