import torch
import numpy as np
from PIL import Image
import gradio as gr
import os
import json
import argparse
from diffusers import FluxTransformer2DModel, AutoencoderKL
from diffusers.hooks import apply_group_offloading
from transformers import T5EncoderModel, CLIPTextModel
from src.pipeline_tryon import FluxTryonPipeline
from optimum.quanto import freeze, qfloat8, quantize

device = torch.device("cuda")
torch_dtype = torch.bfloat16  # torch.float16

def load_models(device=device, torch_dtype=torch_dtype, group_offloading=False):
    bfl_repo = "Fynd/flux-dev-1-clone"
    # Enable memory efficient attention
    text_encoder = CLIPTextModel.from_pretrained(bfl_repo, subfolder="text_encoder", torch_dtype=torch_dtype)
    text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=torch_dtype)
    transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer", torch_dtype=torch_dtype)
    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=torch_dtype)
    # transformer = FluxTransformer2DModel.from_single_file("Kijai/flux-fp8/flux1-dev-fp8.safetensors", torch_dtype=torch_dtype)
    pipe = FluxTryonPipeline.from_pretrained(
        bfl_repo,
        transformer=transformer,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=torch_dtype,
    )  # .to(device="cpu", dtype=torch_dtype)
    # pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)  # Do not use this if resolution can change
    # # Quantizing the transformer causes severe quality degradation
    # quantize(pipe.transformer, weights=qfloat8)
    # freeze(pipe.transformer)
    quantize(pipe.text_encoder_2, weights=qfloat8)
    freeze(pipe.text_encoder_2)
    # pipe.to(device=device)
    # Enable memory efficient attention and VAE optimization
    pipe.enable_attention_slicing()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    pipe.enable_model_cpu_offload()
    # pipe.enable_sequential_cpu_offload()
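    # Note: enable_model_cpu_offload() keeps submodules on CPU and moves each one to the GPU only
    # while it runs; the commented-out enable_sequential_cpu_offload() would lower peak VRAM
    # further, at a significant speed cost, by offloading at the parameter level.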
    pipe.load_lora_weights(
        "loooooong/Any2anyTryon",
        weight_name="dev_lora_any2any_alltasks.safetensors",
        adapter_name="tryon",
    )
    pipe.remove_all_hooks()
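    # remove_all_hooks() drops the accelerate offload hooks installed above, presumably so the
    # optional group offloading below starts from a clean state. Leaf-level group offloading with
    # use_stream=True prefetches weights on a separate CUDA stream to overlap transfers with compute.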
    if group_offloading:
        # https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#group-offloading
        apply_group_offloading(
            pipe.transformer,
            offload_type="leaf_level",
            offload_device=torch.device("cpu"),
            onload_device=torch.device(device),
            use_stream=True,
        )
        apply_group_offloading(
            pipe.text_encoder,
            offload_device=torch.device("cpu"),
            onload_device=torch.device(device),
            offload_type="leaf_level",
            use_stream=True,
        )
        # apply_group_offloading(
        #     pipe.text_encoder_2,
        #     offload_device=torch.device("cpu"),
        #     onload_device=torch.device(device),
        #     offload_type="leaf_level",
        #     use_stream=True,
        # )
        apply_group_offloading(
            pipe.vae,
            offload_device=torch.device("cpu"),
            onload_device=torch.device(device),
            offload_type="leaf_level",
            use_stream=True,
        )
    pipe.to(device=device)
    return pipe
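
# Note: the FLUX VAE downsamples by 8x and the transformer packs latents into 2x2 patches,
# so the helpers below keep spatial sizes divisible by 16.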
def crop_to_multiple_of_16(img):
    width, height = img.size
    # Calculate new dimensions that are multiples of 16
    new_width = width - (width % 16)
    new_height = height - (height % 16)
    # Calculate crop box coordinates to center-crop the image
    left = (width - new_width) // 2
    top = (height - new_height) // 2
    right = left + new_width
    bottom = top + new_height
    # Crop the image
    cropped_img = img.crop((left, top, right, bottom))
    return cropped_img
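
# Fit the image inside (target_width, target_height) without distortion, pad the remainder with
# white, and return the per-side padding so the output can later be cropped back to the original framing.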
def resize_and_pad_to_size(image, target_width, target_height):
    # Convert numpy array to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Get original dimensions
    orig_width, orig_height = image.size
    # Calculate aspect ratios
    target_ratio = target_width / target_height
    orig_ratio = orig_width / orig_height
    # Calculate new dimensions while maintaining aspect ratio
    if orig_ratio > target_ratio:
        # Image is wider than target ratio - scale by width
        new_width = target_width
        new_height = int(new_width / orig_ratio)
    else:
        # Image is taller than target ratio - scale by height
        new_height = target_height
        new_width = int(new_height * orig_ratio)
    # Resize image
    resized_image = image.resize((new_width, new_height))
    # Create white background image of target size
    padded_image = Image.new('RGB', (target_width, target_height), 'white')
    # Calculate padding to center the image
    left_padding = (target_width - new_width) // 2
    top_padding = (target_height - new_height) // 2
    # Paste resized image onto padded background
    padded_image.paste(resized_image, (left_padding, top_padding))
    return padded_image, left_padding, top_padding, target_width - new_width - left_padding, target_height - new_height - top_padding

def resize_by_height(image, height):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # image is a PIL image
    image = image.resize((int(image.width * height / image.height), height))
    return crop_to_multiple_of_16(image)
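
# The try-on task is posed as inpainting over a horizontally concatenated canvas:
# [blank target panel | model image | garment image]. The blank left panel is the region the
# model fills in; the prompt describes the panels with <MODEL>/<GARMENT>/<TARGET> tags.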
# @spaces.GPU()
def generate_image(prompt, model_image, garment_image, height=512, width=384, seed=0, guidance_scale=3.5, show_type="follow model image", num_inference_steps=30):
    height, width = int(height), int(width)
    width = width - (width % 16)
    height = height - (height % 16)
    concat_image_list = [np.zeros((height, width, 3), dtype=np.uint8)]
    has_model_image = model_image is not None
    has_garment_image = garment_image is not None
    if has_model_image:
        if has_garment_image:
            # if both model and garment images are provided, ensure the model image and target image have the same size
            input_height, input_width = model_image.shape[:2]
            model_image, lp, tp, rp, bp = resize_and_pad_to_size(Image.fromarray(model_image), width, height)
        else:
            model_image = resize_by_height(model_image, height)
            # model_image = resize_and_pad_to_size(Image.fromarray(model_image), width, height)
        concat_image_list.append(model_image)
    if has_garment_image:
        # if has_model_image:
        #     garment_image = resize_and_pad_to_size(Image.fromarray(garment_image), width, height)
        # else:
        garment_image = resize_by_height(garment_image, height)
        concat_image_list.append(garment_image)
    image = np.concatenate([np.array(img) for img in concat_image_list], axis=1)
    image = Image.fromarray(image)
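    # Inpainting mask: white (255) over the left target panel marks the region to generate,
    # while the model/garment panels stay black (0) and serve as visual conditioning.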
    mask = np.zeros_like(image)
    mask[:, :width] = 255
    mask_image = Image.fromarray(mask)
    assert height == image.height, "ensure same height"
    # with torch.cuda.amp.autocast():  # this causes black images
    # with torch.no_grad():
    output = pipe(
        prompt,
        image=image,
        mask_image=mask_image,
        strength=1.,
        height=height,
        width=image.width,
        target_width=width,
        tryon=has_model_image and has_garment_image,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        generator=torch.Generator().manual_seed(seed),
        output_type="latent",
    ).images
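    # Unpack the packed latents, optionally keep only the target panel (its latent width is the
    # pixel width divided by the VAE scale factor), then de-normalize and decode with the VAE.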
    latents = pipe._unpack_latents(output, image.height, image.width, pipe.vae_scale_factor)
    if show_type != "all outputs":
        latents = latents[:, :, :, :width // pipe.vae_scale_factor]
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    output = image
    if show_type == "follow model image" and has_model_image and has_garment_image:
        output = output.crop((lp, tp, output.width - rp, output.height - bp)).resize((input_width, input_height))
    return output
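
# When "Detect Image Size" is checked, mirror the uploaded image's dimensions into the
# height/width fields, scaling down to at most 1024 px and up toward a 384 px minimum.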
def update_dimensions(model_image, garment_image, height, width, auto_ar):
    if not auto_ar:
        return height, width
    if model_image is not None:
        height = model_image.shape[0]
        width = model_image.shape[1]
    elif garment_image is not None:
        height = garment_image.shape[0]
        width = garment_image.shape[1]
    else:
        height = 512
        width = 384
    # Set max dimensions and minimum size
    max_height = 1024
    max_width = 1024
    min_size = 384
    # Scale down if exceeding max dimensions while maintaining aspect ratio
    if height > max_height or width > max_width:
        aspect_ratio = width / height
        if height > max_height:
            height = max_height
            width = int(height * aspect_ratio)
        if width > max_width:
            width = max_width
            height = int(width / aspect_ratio)
    # Scale up if below minimum size while maintaining aspect ratio
    if height < min_size and width < min_size:
        aspect_ratio = width / height
        if height < width:
            height = min_size
            width = int(height * aspect_ratio)
        else:
            width = min_size
            height = int(width / aspect_ratio)
    return height, width

model1 = Image.open("asset/images/model/model1.png")
model2 = Image.open("asset/images/model/model2.jpg")
model3 = Image.open("asset/images/model/model3.png")
model4 = Image.open("asset/images/model/model4.png")
garment1 = Image.open("asset/images/garment/garment1.jpg")
garment2 = Image.open("asset/images/garment/garment2.jpg")
garment3 = Image.open("asset/images/garment/garment3.jpg")
garment4 = Image.open("asset/images/garment/garment4.jpg")

def launch_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Any2AnyTryon")
        gr.Markdown("Demo (experimental) for [Any2AnyTryon: Leveraging Adaptive Position Embeddings for Versatile Virtual Clothing Tasks](https://arxiv.org/abs/2501.15891) ([Code](https://github.com/logn-2024/Any2anyTryon)).")
        with gr.Row():
            with gr.Column():
                model_image = gr.Image(label="Model Image", type="numpy", interactive=True)
                with gr.Row():
                    garment_image = gr.Image(label="Garment Image", type="numpy", interactive=True)
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    info="Try example prompts from the right side",
                    placeholder="Enter your prompt here...",
                    value="",
                    # visible=False,
                )
                with gr.Row():
                    height = gr.Number(label="Height", value=576, precision=0)
                    width = gr.Number(label="Width", value=576, precision=0)
                    seed = gr.Number(label="Seed", value=0, precision=0)
                with gr.Accordion("Advanced Settings", open=False):
                    guidance_scale = gr.Number(label="Guidance Scale", value=3.5)
                    num_inference_steps = gr.Number(label="Inference Steps", value=15)
                    show_type = gr.Radio(label="Show Type", choices=["follow model image", "follow height & width", "all outputs"], value="follow model image")
                    auto_ar = gr.Checkbox(label="Detect Image Size (From Uploaded Images)", value=False, visible=True)
                btn = gr.Button("Generate")
            with gr.Column():
                output = gr.Image(label="Generated Image")
                example_prompts = gr.Examples(
                    [
                        "<MODEL> a person with fashion garment. <GARMENT> a garment. <TARGET> model with fashion garment",
                        "<MODEL> a person with fashion garment. <TARGET> the same garment laid flat.",
                        "<GARMENT> The image shows a fashion garment. <TARGET> a smiling person with the garment in white background",
                    ],
                    inputs=prompt,
                    label="Example Prompts",
                    # visible=False
                )
                example_model = gr.Examples(
                    examples=[
                        model1, model2, model3, model4
                    ],
                    inputs=model_image,
                    label="Example Model Images"
                )
                example_garment = gr.Examples(
                    examples=[
                        garment1, garment2, garment3, garment4
                    ],
                    inputs=garment_image,
                    label="Example Garment Images"
                )

        # Update dimensions when images change
        model_image.change(fn=update_dimensions,
                           inputs=[model_image, garment_image, height, width, auto_ar],
                           outputs=[height, width])
        garment_image.change(fn=update_dimensions,
                             inputs=[model_image, garment_image, height, width, auto_ar],
                             outputs=[height, width])
        btn.click(fn=generate_image,
                  inputs=[prompt, model_image, garment_image, height, width, seed, guidance_scale, show_type, num_inference_steps],
                  outputs=output)

        demo.title = "FLUX Image Generation Demo"
        demo.description = "Generate images using FLUX model with LoRA"

        examples = [
            # tryon
            [
                '''<MODEL> a man <GARMENT> a medium-sized, short-sleeved, blue t-shirt with a round neckline and a pocket on the front. <TARGET> model with fashion garment''',
                model1,
                garment1,
                576, 576
            ],
            [
                '''<MODEL> a man with gray hair and a beard wearing a black jacket and sunglasses, standing in front of a body of water with mountains in the background and a cloudy sky above <GARMENT> a black and white striped t-shirt with a red heart embroidered on the chest <TARGET> ''',
                model2,
                garment2,
                576, 576
            ],
            [
                '''<MODEL> a person with fashion garment. <GARMENT> a garment. <TARGET> model with fashion garment''',
                model3,
                garment3,
                576, 576
            ],
            [
                '''<MODEL> a woman lift up her right leg. <GARMENT> a pair of black and white patterned pajama pants. <TARGET> model with fashion garment''',
                model4,
                garment4,
                576, 576
            ],
        ]
        gr.Examples(
            examples=examples,
            inputs=[prompt, model_image, garment_image],
            outputs=output,
            fn=generate_image,
            cache_examples=False,
            examples_per_page=20
        )

    demo.queue().launch(share=False, show_error=False, server_name="0.0.0.0")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--group_offloading', action="store_true")
    args = parser.parse_args()

    pipe = load_models(group_offloading=args.group_offloading)
    launch_demo()
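
# Example launch (assuming this file is saved as app.py and a CUDA GPU is available):
#   python app.py                     # default: model CPU offload
#   python app.py --group_offloading  # lower peak VRAM via leaf-level group offloading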