import gradio as gr
import torch
import os
from utils import call
from diffusers import (
DDPMScheduler,
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from diffusers.pipelines import StableDiffusionXLPipeline
# Monkey-patch the SDXL pipeline class so that calling a pipeline instance runs
# the project's own `call` (imported from utils) instead of the stock __call__.
# The Demo class below passes extra keyword arguments (networks, scales,
# start_noise, unet) to the pipeline, so it relies on this patch being in place.
StableDiffusionXLPipeline.__call__ = call
import os
from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
from trainscripts.textsliders.demotrain import train_xl
# NOTE(review): an empty CURL_CA_BUNDLE presumably disables CA-bundle
# verification for HTTPS requests made by this process -- confirm this is
# intended, since it weakens TLS checking for model downloads.
os.environ['CURL_CA_BUNDLE'] = ''
# Maps the label shown for each pretrained slider in the UI (used as the
# checkbox label) to the path of its LoRA weights file. Demo.train() adds an
# entry here at runtime for every newly trained slider.
model_map = {
'年龄调整': 'models/age.pt',
'体型丰满': 'models/chubby.pt',
'肌肉感': 'models/muscular.pt',
'惊讶表情': 'models/suprised_look.pt',
'微笑': 'models/smiling.pt',
'职业感': 'models/professional.pt',
'长发': 'models/long_hair.pt',
'卷发': 'models/curlyhair.pt',
'Pixar风格': 'models/pixar_style.pt',
'雕塑风格': 'models/sculpture_style.pt',
'陶土风格': 'models/clay_style.pt',
'修复图像': 'models/repair_slider.pt',
'修复手部': 'models/fix_hands.pt',
'杂乱房间': 'models/cluttered_room.pt',
'阴暗天气': 'models/dark_weather.pt',
'节日氛围': 'models/festive.pt',
'热带天气': 'models/tropical_weather.pt',
'冬季天气': 'models/winter_weather.pt',
'弯眉': 'models/eyebrow.pt',
'眼睛大小 (使用刻度 -3, -1, 1, 3)': 'models/eyesize.pt',
}
# Identifier that SPACE_ID would be compared against; no active code visible
# in this file reads it.
ORIGINAL_SPACE_ID = 'baulab/ConceptSliders'
# Value of the SPACE_ID environment variable, or None when it is not set.
SPACE_ID = os.getenv('SPACE_ID')
# Markdown warning text (Chinese) about slow training on the shared UI; no
# active code visible in this file displays it.
SHARED_UI_WARNING = f'''## 注意 - 在此共享UI中训练可能会很慢。您可以选择复制并使用至少40GB GPU的设备,或克隆此存储库以在自己的机器上运行。
'''
def merge_lora_networks(networks):
    """Sum the parameters of several LoRA networks into the first one.

    The first network is modified in place: every parameter of the remaining
    networks is added to the parameter of the same name in the first network.

    Args:
        networks: Sequence of ``torch.nn.Module`` objects. May be empty or
            ``None``.

    Returns:
        The first network (mutated in place), or ``None`` when ``networks``
        is empty or ``None``.

    Notes:
        Parameters that exist in a later network but not in the first one
        cannot be merged: writing into the dict returned by ``state_dict()``
        does not register anything on the module. Such parameters are skipped
        and reported with a ``RuntimeWarning`` instead of being dropped
        silently.
    """
    import warnings

    if not networks:
        return None

    base_network = networks[0]
    # state_dict() tensors share storage with the module's parameters, so one
    # snapshot is enough and in-place updates reach the real weights.
    base_state = base_network.state_dict()
    skipped = []

    # Merging is plain weight arithmetic; keep it out of the autograd graph.
    with torch.no_grad():
        for network in networks[1:]:
            for name, param in network.named_parameters():
                target = base_state.get(name)
                if target is None:
                    skipped.append(name)
                    continue
                target.add_(param.to(device=target.device, dtype=target.dtype))

    if skipped:
        warnings.warn(
            "merge_lora_networks: parameters missing from the base network "
            f"were not merged: {sorted(set(skipped))}",
            RuntimeWarning,
            stacklevel=2,
        )
    return base_network
# Modify the call method to support passing a networks parameter (disabled draft, kept for reference).
# def rw_sd_call(self, *args, networks=None, scales=None, **kwargs):
# if networks is not None and scales is not None:
# for network, scale in zip(networks, scales):
# for name, param in network.named_parameters():
# if name in self.unet.state_dict():
# self.unet.state_dict()[name].add_(param * scale)
# else:
# self.unet.state_dict()[name] = param.clone() * scale
# return self.__original_call__(*args, **kwargs)
# StableDiffusionXLPipeline.__original_call__ = StableDiffusionXLPipeline.__call__
# StableDiffusionXLPipeline.__call__ = rw_sd_call
class Demo:
    """Gradio demo for testing and training Concept Slider LoRAs on SDXL.

    Constructing an instance loads the SDXL Turbo pipeline, builds the UI and
    launches the Gradio server (the constructor only returns when ``launch``
    does).

    Attributes:
        training: True while a training job is running; used to reject
            concurrent training requests.
        generating: Initialised to False; not read by the code in this class.
        weight_dtype: dtype used for the pipelines and LoRA networks.
        model_sections: One ``((label, checked), scale)`` entry per slider row
            shown in the test tab, kept up to date by UI change events.
        device: ``'cuda'`` when available, otherwise ``'cpu'``.
        pipe: The currently loaded SDXL pipeline.
        current_model: ``'SDXL Turbo'`` or ``'SDXL'``, matching ``pipe``.
        guidance_scale / num_inference_steps: Sampling defaults for ``pipe``.
        seed: Seed used by the most recent inference.

    NOTE(review): ``model_sections`` lives on the single Demo instance, so
    slider selections appear to be shared by every browser session talking to
    this server -- confirm whether per-session state is required.
    """

    def __init__(self) -> None:
        self.training = False
        self.generating = False
        self.weight_dtype = torch.bfloat16
        self.model_sections = []
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Load the SDXL base pipeline once and discard it, presumably so its
        # weights are already downloaded when the user switches models.
        # It is no longer moved to the device: the object is thrown away
        # immediately, so the transfer only cost time and GPU memory.
        StableDiffusionXLPipeline.from_pretrained(
            'stabilityai/stable-diffusion-xl-base-1.0',
            torch_dtype=self.weight_dtype,
        )

        self._load_pipeline('SDXL Turbo')
        self.seed = 42753  # default seed

        with gr.Blocks() as demo:
            self.layout()
        demo.queue(max_size=5).launch(share=True, max_threads=2)

    def _load_pipeline(self, model) -> None:
        """Load the pipeline for ``model`` and set its sampling defaults.

        ``'SDXL Turbo'`` loads sdxl-turbo with an Euler-ancestral scheduler;
        any other value loads the SDXL base model and is recorded as
        ``'SDXL'``.
        """
        if model == 'SDXL Turbo':
            model_id = "stabilityai/sdxl-turbo"
            euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                model_id, scheduler=euler_anc, torch_dtype=self.weight_dtype
            ).to(self.device)
            guidance_scale, num_inference_steps, current_model = 1, 3, 'SDXL Turbo'
        else:
            model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                model_id, torch_dtype=self.weight_dtype
            ).to(self.device)
            guidance_scale, num_inference_steps, current_model = 7.5, 20, 'SDXL'

        if torch.cuda.is_available():
            self.pipe.enable_xformers_memory_efficient_attention()
        self.guidance_scale = guidance_scale
        self.num_inference_steps = num_inference_steps
        self.current_model = current_model

    def _section_updater(self, label):
        """Return a UI callback that records one slider row's state.

        The label is bound here, so no extra UI component is needed to carry
        it and every row gets its own label despite being created in a loop.
        """
        def record(checked, scale):
            self.update_model_sections(label, checked, scale)

        return record

    def layout(self):
        """Build the test and training tabs and wire up their events.

        Must be called inside an active ``gr.Blocks`` context.
        """
        with gr.Tab("测试"):
            with gr.Row():
                self.explain_infr = gr.Markdown(value='这是[概念滑块:用于扩散模型的LoRA适配器](https://sliders.baulab.info/)的演示。要尝试可以控制特定概念的模型,请选择一个模型并输入任何提示词,选择一个种子值,最后选择SDEdit时间步以保持结构。较高的SDEdit时间步会导致更多的结构变化。例如,如果选择“惊讶表情”模型,可以生成提示词“A picture of a person, realistic, 8k”的图像,并将滑块效果与原始模型生成的图像进行比较。我们还提供了几个其他预先微调的模型,如“修复”滑块,用于修复SDXL生成图像中的缺陷(请查看“预训练滑块”下拉菜单)。您还可以训练和运行自己的自定义滑块。请查看“训练”部分以进行自定义概念滑块训练。当前推理正在运行SDXL Turbo!')
            with gr.Row():
                self.prompt_input_infr = gr.Text(
                    placeholder="photo of a person, with bokeh street background, realistic, 8k",
                    label="提示词",
                    info="生成图像的提示词",
                    value="photo of a person, with bokeh street background, realistic, 8k",
                )

            # One row per known slider: a checkbox to enable it and a scale.
            for model_name in model_map:
                with gr.Row():
                    model_checkbox = gr.Checkbox(label=model_name, value=False)
                    slider_scale_infr = gr.Slider(-4, 4, label="滑块刻度", value=3, info="较大的滑块刻度会导致更强的编辑效果")
                self.model_sections.append(
                    ((model_name, model_checkbox.value), slider_scale_infr.value)
                )
                # Record the row state when EITHER control changes. Wiring
                # only the checkbox would ignore a scale adjusted after the
                # box was ticked.
                record_row = self._section_updater(model_name)
                for control in (model_checkbox, slider_scale_infr):
                    control.change(
                        fn=record_row,
                        inputs=[model_checkbox, slider_scale_infr],
                        outputs=[],
                        queue=False,  # tiny bookkeeping call; keep it out of the job queue
                    )

            with gr.Row():
                self.seed_infr = gr.Number(label="种子值", value=self.seed)
                self.start_noise_infr = gr.Slider(
                    600, 900,
                    value=750,
                    label="SDEdit时间步",
                    info="选择较小的值以保持更多结构",
                )
                self.model_type = gr.Dropdown(
                    label="模型",
                    choices=['SDXL Turbo', 'SDXL'],
                    value='SDXL Turbo',
                    interactive=True,
                )
            with gr.Row():
                self.infr_button = gr.Button(
                    value="生成",
                    interactive=True,
                )
            with gr.Row():
                self.image_orig = gr.Image(
                    label="原始SD",
                    interactive=False,
                    type='pil',
                )
                self.image_new = gr.Image(
                    label="概念滑块",
                    interactive=False,
                    type='pil',
                )

        with gr.Tab("训练"):
            with gr.Row():
                self.explain_train = gr.Markdown(value='在这一部分,您可以为Stable Diffusion XL训练文本概念滑块。输入您希望进行编辑的目标概念(例如:人)。接下来,输入您希望编辑的属性的增强提示词(例如:控制人的年龄,输入“person, old”)。然后,输入属性的抑制提示词(例如:输入“person, young”)。然后按“训练”按钮。使用默认设置,训练一个滑块大约需要25分钟;然后您可以在上面的“测试”选项卡中尝试推理或下载权重。为了更快的训练,请复制此存储库并使用A100或更大的GPU进行训练。代码和详细信息在[github链接](https://github.com/rohitgandikota/sliders)。')
            with gr.Row():
                with gr.Column(scale=3):
                    self.target_concept = gr.Text(
                        placeholder="输入要进行编辑的目标概念...",
                        label="编辑概念的提示词",
                        info="对应于要编辑的概念的提示词(例如:“person”)",
                        value='',
                    )
                    self.positive_prompt = gr.Text(
                        placeholder="输入编辑的增强提示词...",
                        label="增强提示词",
                        info="对应于要增强的概念的提示词(例如:“person, old”)",
                        value='',
                    )
                    self.negative_prompt = gr.Text(
                        placeholder="输入编辑的抑制提示词...",
                        label="抑制提示词",
                        info="对应于要抑制的概念的提示词(例如:“person, young”)",
                        value='',
                    )
                    self.attributes_input = gr.Text(
                        placeholder="输入要保留的概念(用逗号分隔)。如果不需要,请留空...",
                        label="要保留的概念",
                        info="要保留/解缠的概念(例如:“male, female”)",
                        value='',
                    )
                    self.is_person = gr.Checkbox(
                        label="人",
                        info="您是否在为人训练滑块?",
                    )
                    self.rank = gr.Number(
                        value=4,
                        label="滑块等级",
                        info='要训练的滑块等级',
                    )
                    choices = ['xattn', 'noxattn']
                    self.train_method_input = gr.Dropdown(
                        choices=choices,
                        value='xattn',
                        label='训练方法',
                        info='训练方法。如果[* xattn *] - loras将仅在交叉注意层上。如果[* noxattn *](官方实现) - 除交叉注意层外的所有层',
                        interactive=True,
                    )
                    self.iterations_input = gr.Number(
                        value=500,
                        precision=0,
                        label="迭代次数",
                        info='用于训练的迭代次数 - 最大为1000',
                    )
                    self.lr_input = gr.Number(
                        value=2e-4,
                        label="学习率",
                        info='用于训练的学习率',
                    )
                with gr.Column(scale=1):
                    self.train_status = gr.Button(value='', variant='primary', interactive=False)
                    self.train_button = gr.Button(
                        value="训练",
                    )
                    self.download = gr.Files()
                    # NOTE(review): the flattened source does not show where
                    # this dropdown was nested; it is placed next to the
                    # other training outputs -- confirm the intended position.
                    self.model_dropdown = gr.Dropdown(choices=list(model_map))

        # inference() returns (edited, original); the output order below
        # routes them to the "concept slider" and "original SD" images.
        self.infr_button.click(
            self.inference,
            inputs=[
                self.prompt_input_infr,
                self.start_noise_infr,
                self.model_type,
                self.seed_infr,
            ],
            outputs=[
                self.image_new,
                self.image_orig,
            ],
        )
        self.train_button.click(
            self.train,
            inputs=[
                self.target_concept,
                self.positive_prompt,
                self.negative_prompt,
                self.rank,
                self.iterations_input,
                self.lr_input,
                self.attributes_input,
                self.is_person,
                self.train_method_input,
            ],
            outputs=[self.train_button, self.train_status, self.download, self.model_dropdown],
        )

    def update_model_sections(self, label, checkbox, scale):
        """Store the checkbox state and scale for the slider row ``label``.

        Does nothing when no row with that label exists.
        """
        for i, section in enumerate(self.model_sections):
            if section[0][0] == label:
                self.model_sections[i] = ((label, checkbox), scale)
                break

    def train(self, target_concept, positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar=gr.Progress(track_tqdm=True)):
        """Train a new slider and register it in ``model_map``.

        Args:
            target_concept: Prompt for the concept being edited.
            positive_prompt: Prompt describing the attribute to enhance.
            negative_prompt: Prompt describing the attribute to suppress.
            rank: LoRA rank (converted to int).
            iterations_input: Training iterations, capped at 1000.
            lr_input: Learning rate.
            attributes_input: Comma-separated concepts to preserve; an empty
                string means none.
            is_person: When true, a fixed list of demographic attributes
                replaces ``attributes_input``.
            train_method_input: Training method name passed to ``train_xl``
                and embedded in the saved file name.
            pbar: Gradio progress tracker (tracks tqdm output).

        Returns:
            Four values matching the click handler's outputs: train-button
            update, status update, path of the saved weights (or None when
            busy) and model-dropdown update.

        Raises:
            Whatever ``train_xl`` raises; the busy flag is cleared first.
        """
        iterations_input = min(int(iterations_input), 1000)
        if attributes_input == '':
            attributes_input = None
        print(target_concept, positive_prompt, negative_prompt, attributes_input, is_person)

        if self.training:
            return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]

        # The prompt is user-supplied and ends up in a file name: keep only
        # alphanumerics, '-' and '_' so it cannot contain path separators.
        # This still strips the commas, spaces and dots removed before.
        slug = ''.join(ch for ch in positive_prompt if ch.isalnum() or ch in '-_')[:20]
        randn = torch.randint(1, 10000000, (1,)).item()
        # inference() sniffs the training method and rank back out of this
        # name, so the "_<method>_rank_<n>.pt" suffix must stay as it is.
        save_name = f"{randn}_{slug}_alpha-1_{train_method_input}_rank_{int(rank)}.pt"

        attributes = attributes_input
        if is_person:
            attributes = 'white, black, asian, hispanic, indian, male, female'

        self.training = True
        try:
            train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), train_method=train_method_input, device=self.device, attributes=attributes, save_name=save_name)
        finally:
            # Always release the busy flag; otherwise one failed run would
            # block every later training request.
            self.training = False
            torch.cuda.empty_cache()

        model_label = save_name.replace('.pt', '')
        model_map[model_label] = f'models/{save_name}'
        return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map), value=model_label)]

    def get_selected_models(self):
        """Return ``(labels, scales)`` for the slider rows that are ticked.

        Both lists are empty when nothing is selected.
        """
        selected = [
            (section[0][0], section[1])  # (label, scale)
            for section in self.model_sections
            if section[0][1]  # checkbox is ticked
        ]
        if not selected:
            return [], []
        labels, scales = zip(*selected)
        return list(labels), list(scales)

    def _load_slider(self, model_path, unet):
        """Build a LoRANetwork for ``unet`` and load weights from ``model_path``.

        The training method and rank are inferred from substrings of the
        path, so file names must follow the pattern produced by ``train``.
        """
        network_type = "c3lier"
        if 'full' in model_path:
            train_method = 'full'
        elif 'noxattn' in model_path:
            train_method = 'noxattn'
        elif 'xattn' in model_path:
            train_method = 'xattn'
            network_type = 'lierla'
        else:
            train_method = 'noxattn'

        # NOTE(review): `modules` is never read again, but if
        # DEFAULT_TARGET_REPLACE is a list this `+=` extends the object
        # imported from the lora module in place. LoRANetwork may depend on
        # that side effect to target conv blocks, so it is kept exactly as
        # written -- confirm against lora.py before changing it.
        modules = DEFAULT_TARGET_REPLACE
        if network_type == "c3lier":
            modules += UNET_TARGET_REPLACE_MODULE_CONV

        rank = 4
        if 'rank' in model_path:
            # File names end in "..._rank_<n>.pt".
            rank = int(float(model_path.split('_')[-1].replace('.pt', '')))
        alpha = 1.0 if 'alpha1' in model_path else 1

        network = LoRANetwork(
            unet,
            rank=rank,
            multiplier=1.0,
            alpha=alpha,
            train_method=train_method,
        ).to(self.device, dtype=self.weight_dtype)
        # map_location lets weights saved on a GPU load on a CPU-only host.
        network.load_state_dict(
            torch.load(model_path, map_location=self.device, weights_only=True)
        )
        return network

    def _generate(self, prompt, start_noise, unet, networks, scales):
        """Run the pipeline once from ``self.seed`` and return the first image."""
        generator = torch.manual_seed(self.seed)
        return self.pipe(
            prompt,
            num_images_per_prompt=1,
            num_inference_steps=self.num_inference_steps,
            generator=generator,
            networks=networks,
            start_noise=int(start_noise),
            scales=scales,
            unet=unet,
            guidance_scale=self.guidance_scale,
        ).images[0]

    def inference(self, prompt, start_noise, model, seed, pbar=gr.Progress(track_tqdm=True)):
        """Generate the slider-edited image and the unedited baseline.

        Args:
            prompt: Text prompt.
            start_noise: SDEdit timestep forwarded to the pipeline call.
            model: ``'SDXL Turbo'`` or ``'SDXL'``; the pipeline is reloaded
                when it differs from the one currently loaded.
            seed: Seed for both generations. ``None`` (an empty number field)
                keeps the previously used seed.
            pbar: Gradio progress tracker (tracks tqdm output).

        Returns:
            ``(edited_image, original_image)``.
        """
        if seed is not None:
            self.seed = int(seed)

        model_names, scale_list = self.get_selected_models()

        if self.current_model != model:
            self._load_pipeline(model)

        # Bind the UNet before the loop: it is needed by the pipeline call
        # even when no slider is selected. Binding it only inside the loop
        # made the default "nothing ticked" state fail with an unbound local.
        unet = self.pipe.unet
        networks = [self._load_slider(model_map[name], unet) for name in model_names]

        try:
            # Edited image: every selected slider at its chosen scale.
            edited_image = self._generate(prompt, start_noise, unet, networks, scale_list)
            # Baseline: same seed and networks, but every scale set to zero.
            original_image = self._generate(prompt, start_noise, unet, networks, [0] * len(networks))
        finally:
            del networks
            torch.cuda.empty_cache()
        return edited_image, original_image
# Instantiating Demo loads the pipelines, builds the UI and launches the Gradio
# server; all of that happens inside Demo.__init__, at import time of this module.
demo = Demo()