# text2tag-llm / genimage.py
# Uploaded by John6666 ("Upload 4 files", commit 4923b8f, verified).
import spaces
import gradio as gr
import torch
import gc, os, uuid, json
from PIL import PngImagePlugin
# Run on the first CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Backend numeric settings are only applied when the SPACES_ZERO_GPU
# environment variable is set (i.e. when running on a ZeroGPU Space).
if os.getenv("SPACES_ZERO_GPU"):
    # cuDNN: disable autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # Allow TF32 for CUDA matmuls and lower the float32 matmul precision.
    # See https://pytorch.org/blog/accelerating-generative-ai-3/
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
def load_pipeline():
    """Download/load the diffusion pipeline and return it placed on the CPU.

    The weights are loaded as bfloat16 when CUDA is available and float32
    otherwise. The pipeline is moved to the CPU here; callers move it to the
    accelerator only while generating.

    Returns:
        The loaded ``DiffusionPipeline`` instance (custom pipeline
        ``lpw_stable_diffusion_xl``).
    """
    # Imported lazily so the heavy diffusers import happens on first load.
    from diffusers import DiffusionPipeline

    weight_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    pipeline = DiffusionPipeline.from_pretrained(
        "John6666/rae-diffusion-xl-v2-sdxl-spo-pcm",
        custom_pipeline="lpw_stable_diffusion_xl",
        # Alternative tried previously:
        # custom_pipeline="nyanko7/sdxl_smoothed_energy_guidance",
        torch_dtype=weight_dtype,
    )
    pipeline.to("cpu")
    return pipeline
def token_auto_concat_embeds(pipe, positive, negative):
    """Encode a positive/negative prompt pair, chunking prompts longer than the tokenizer limit.

    Both prompts are tokenized with ``pipe.tokenizer``. If either exceeds
    ``pipe.tokenizer.model_max_length`` tokens, the shorter one is padded to
    the longer one's length. Otherwise both are padded to
    ``model_max_length``. The id sequences are then cut into
    ``model_max_length``-sized slices. Each slice is run through
    ``pipe.text_encoder``, and the first output of every call is
    concatenated along dim 1.

    Args:
        pipe: object exposing ``tokenizer`` (HF-style callable with
            ``model_max_length``) and ``text_encoder`` (callable with a
            ``device`` attribute).
        positive: positive prompt text.
        negative: negative prompt text.

    Returns:
        ``(positive_prompt_embeds, negative_prompt_embeds)``. Both have the
        same length along dim 1 because of the padding above.

    Notes:
        The slices are raw cuts of one token sequence. Only the first slice
        therefore begins with whatever leading special token the tokenizer
        emits.
        NOTE(review): only ``tokenizer``/``text_encoder`` are used. If ``pipe``
        is an SDXL pipeline with a second encoder, these embeddings alone are
        presumably not sufficient. Confirm before passing them to the pipeline.
    """
    tokenizer = pipe.tokenizer
    encoder = pipe.text_encoder
    max_length = tokenizer.model_max_length
    # Put the ids wherever the text encoder actually lives, rather than
    # hard-coding "cuda", which crashed on CPU-only machines.
    target_device = encoder.device

    def _ids(text, pad_to=None):
        # Unpadded ids when pad_to is None; otherwise padded (never truncated) to pad_to.
        if pad_to is None:
            return tokenizer(text, return_tensors="pt").input_ids
        return tokenizer(
            text,
            truncation=False,
            padding="max_length",
            max_length=pad_to,
            return_tensors="pt",
        ).input_ids

    # Tokenize once to measure. The unpadded ids are reused below for the longer prompt.
    positive_ids = _ids(positive)
    negative_ids = _ids(negative)
    positive_length = positive_ids.shape[-1]
    negative_length = negative_ids.shape[-1]
    print(f'Token length is model maximum: {max_length}, positive length: {positive_length}, negative length: {negative_length}.')

    if max_length < positive_length or max_length < negative_length:
        print('Concatenated embedding.')
        # Pad the shorter prompt so both id tensors have the same length.
        if positive_length > negative_length:
            negative_ids = _ids(negative, pad_to=positive_length)
        else:
            positive_ids = _ids(positive, pad_to=negative_length)
    else:
        positive_ids = _ids(positive, pad_to=max_length)
        negative_ids = _ids(negative, pad_to=max_length)

    positive_ids = positive_ids.to(target_device)
    negative_ids = negative_ids.to(target_device)

    positive_chunks = []
    negative_chunks = []
    # Encode matching max_length-sized windows of both prompts.
    for start in range(0, positive_ids.shape[-1], max_length):
        stop = start + max_length
        positive_chunks.append(encoder(positive_ids[:, start:stop])[0])
        negative_chunks.append(encoder(negative_ids[:, start:stop])[0])

    positive_prompt_embeds = torch.cat(positive_chunks, dim=1)
    negative_prompt_embeds = torch.cat(negative_chunks, dim=1)
    return positive_prompt_embeds, negative_prompt_embeds
def save_image(image, metadata, output_dir):
    """Save ``image`` as a uniquely named PNG in ``output_dir`` and return its path.

    ``output_dir`` is created if missing. ``metadata`` is JSON-encoded and
    stored in the PNG under the text key ``"metadata"``. The file name is a
    random UUID4 with a ``.png`` extension.

    Args:
        image: object with a PIL-style ``save(path, format, pnginfo=...)`` method.
        metadata: JSON-serializable value to embed in the file.
        output_dir: destination directory.

    Returns:
        Path of the written file (``output_dir`` joined with the file name).
    """
    target_path = os.path.join(output_dir, f"{uuid.uuid4()}.png")
    os.makedirs(output_dir, exist_ok=True)

    png_text = PngImagePlugin.PngInfo()
    png_text.add_text("metadata", json.dumps(metadata))

    image.save(target_path, "PNG", pnginfo=png_text)
    return target_path
# Build the pipeline once at import time. load_pipeline() leaves it on the CPU;
# generate_image moves it to `device` only for the duration of a call.
pipe = load_pipeline()
# NOTE(review): inference_mode wraps the spaces.GPU wrapper here. Confirm it
# actually takes effect inside the ZeroGPU worker that runs the function.
@torch.inference_mode()
@spaces.GPU(duration=10)
def generate_image(prompt, neg_prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate images for a prompt pair and save them as PNGs under ./outputs.

    Fixed quality-tag suffixes are appended to both prompts. The module-level
    ``pipe`` is moved to ``device`` for the call. It is always moved back to
    the CPU afterwards (with CUDA cache and GC cleanup), including when
    generation fails.

    Args:
        prompt: user prompt (``None`` is treated as an empty string).
        neg_prompt: user negative prompt (``None`` is treated as empty).
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars emitted during the call.

    Returns:
        List of saved PNG file paths. The list is empty when generation
        raises or produces no images.
    """
    # Single source of truth for generation settings, so the saved metadata
    # always matches what was passed to the pipeline.
    width = height = 1024
    guidance_scale = 7.0
    num_inference_steps = 28
    clip_skip = 2

    prompt = (prompt or "") + ", anime, masterpiece, best quality, very aesthetic, absurdres"
    # NOTE(review): this suffix lists "photo, deformed, disfigured, low contrast"
    # twice. It is kept verbatim because changing it changes the conditioning;
    # confirm whether the repetition is intentional.
    neg_prompt = (neg_prompt or "") + ", bad hands, bad feet, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], photo, deformed, disfigured, low contrast, photo, deformed, disfigured, low contrast"
    metadata = {
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "resolution": f"{width} x {height}",
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "clip_skip": clip_skip,
        # NOTE(review): no scheduler is configured in this file, so "Euler" is
        # recorded but not enforced. Verify against the model's scheduler config.
        "sampler": "Euler",
    }
    try:
        # Inside the try so the finally block always returns the pipeline to the CPU.
        pipe.to(device)
        images = pipe(
            prompt=prompt,
            negative_prompt=neg_prompt,
            width=width,
            height=height,
            # Alternative (presumably for the SEG custom pipeline):
            # seg_scale=3.0, seg_applied_layers=["mid"],
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            output_type="pil",
            clip_skip=clip_skip,
        ).images
        if not images:
            # Previously fell through and returned None; keep the return type a list.
            return []
        return [save_image(image, metadata, "./outputs") for image in images]
    except Exception:
        # Best effort: log the full traceback and show an empty result instead
        # of failing the Gradio handler.
        import traceback
        traceback.print_exc()
        return []
    finally:
        pipe.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()