File size: 7,016 Bytes
cc6558b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154718e
cc6558b
 
 
83f3d87
 
 
 
 
 
 
 
 
 
cc6558b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474dcae
cc6558b
 
 
 
 
 
154718e
cc6558b
 
 
 
 
 
154718e
cc6558b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import spaces
import time
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora
import random

# Base FLUX.1-dev checkpoint; loaded once at import time (module-level side effect).
base_path = "black-forest-labs/FLUX.1-dev"
    
# Download the OmniConsistency LoRA weights from the Hub into ./Model.
# hf_hub_download returns the local path to the downloaded file.
omni_consistency_path = hf_hub_download(repo_id="showlab/OmniConsistency", 
                                        filename="OmniConsistency.safetensors", 
                                        local_dir="./Model")

# Build the FLUX pipeline in bfloat16 and move it to the GPU.
# NOTE(review): assumes a CUDA device is available at import — verify for CPU-only hosts.
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16).to("cuda")

# Attach the OmniConsistency LoRA to the transformer (project helper; cond_size
# matches the cond_size=512 passed to pipe(...) in generate_image below).
set_single_lora(pipe.transformer, omni_consistency_path, lora_weights=[1], cond_size=512)

# Function to clear cached attention state between generations
def clear_cache(transformer):
    """Clear the cached key/value bank on every attention processor.

    Called after each generation so state from one request cannot leak
    into the next.

    Args:
        transformer: object exposing an ``attn_processors`` mapping whose
            values each carry a clearable ``bank_kv`` container.
    """
    # Only the processors are needed — iterate values() instead of items()
    # (the original discarded the key).
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()

# Pre-fetch every style LoRA so the first generation request is not blocked
# by a download.
def download_all_loras():
    """Download all OmniConsistency style LoRAs from the Hub into ./LoRAs."""
    style_names = (
        "3D_Chibi", "American_Cartoon", "Chinese_Ink", 
        "Clay_Toy", "Fabric", "Ghibli", "Irasutoya",
        "Jojo", "LEGO", "Line", "Macaron",
        "Oil_Painting", "Origami", "Paper_Cutting", 
        "Picasso", "Pixel", "Poly", "Pop_Art", 
        "Rick_Morty", "Snoopy", "Van_Gogh", "Vector"
    )
    for style in style_names:
        hf_hub_download(
            repo_id="showlab/OmniConsistency",
            filename=f"LoRAs/{style}_rank128_bf16.safetensors",
            local_dir="./LoRAs",
        )

# Eagerly download every style LoRA before launching the UI, so example
# selections and first generations do not stall on network fetches.
download_all_loras()

# Main function to generate the image
@spaces.GPU()
def generate_image(lora_name, prompt, uploaded_image, width, height, guidance_scale, num_inference_steps, seed):
    """Run one stylized generation with the selected OmniConsistency LoRA.

    Args:
        lora_name: style key, e.g. ``"Ghibli"`` (must match a pre-downloaded LoRA).
        prompt: text prompt for the pipeline.
        uploaded_image: PIL image used as the spatial conditioning input.
        width, height: target canvas size; values are snapped down to a
            multiple of 8 (textbox inputs arrive as strings, hence int()).
        guidance_scale: classifier-free guidance strength.
        num_inference_steps: number of denoising steps.
        seed: RNG seed for reproducible output.

    Returns:
        (uploaded_image, generated_image) tuple, suitable for gr.ImageSlider.
    """
    # LoRAs were pre-fetched by download_all_loras(); hf_hub_download mirrors
    # the repo layout, so files live under ./LoRAs/LoRAs/.
    # (Removed an unused `lora_path` local that duplicated this information.)
    pipe.unload_lora_weights()
    pipe.load_lora_weights("./LoRAs/LoRAs", weight_name=f"{lora_name}_rank128_bf16.safetensors")

    # Prepare conditioning inputs: one spatial image, no subject images.
    spatial_image = [uploaded_image.convert("RGB")]
    subject_images = []

    start_time = time.time()

    try:
        # Generate the image. Gradio sliders deliver floats, so cast steps and
        # seed to int — torch.Generator.manual_seed rejects non-integers.
        image = pipe(
            prompt,
            height=(int(height) // 8) * 8,
            width=(int(width) // 8) * 8,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(int(seed)),
            spatial_images=spatial_image,
            subject_images=subject_images,
            cond_size=512,
        ).images[0]
    finally:
        # Always clear the attention cache, even if generation fails, so a
        # failed request cannot poison the next one.
        clear_cache(pipe.transformer)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"code running time: {elapsed_time} s")

    return (uploaded_image, image)

# Example data for gr.Examples. Each row's columns map positionally onto the
# `inputs` list passed to gr.Examples in create_gradio_interface — the two
# integer columns after the image are the size boxes (in the inputs-list
# order), then guidance scale, inference steps, and seed.
# NOTE(review): Image.open runs at import time — these files must exist or the
# app fails to start.
examples = [
    ["3D_Chibi", "3D Chibi style, Two smiling colleagues enthusiastically high-five in front of a whiteboard filled with technical notes about multimodal learning, reflecting a moment of success and collaboration at OpenAI.", 
     Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
    ["Clay_Toy", "Clay Toy style, Three team members from OpenAI are gathered around a laptop in a cozy, festive setting, with holiday decorations in the background; one waves cheerfully while the others engage in light conversation, reflecting a relaxed and collaborative atmosphere.", 
     Image.open("./test_imgs/01.png"), 560, 1024, 3.5, 24, 42],
    ["American_Cartoon", "American Cartoon style, In a dramatic and comedic moment from a classic Chinese film, an intense elder with a white beard and red hat grips a younger man, declaring something with fervor, while the subtitle at the bottom reads 'I want them all' — capturing both tension and humor.",  
     Image.open("./test_imgs/02.png"), 568, 1024, 3.5, 24, 42],
    ["Origami", "Origami style, A thrilled fan wearing a Portugal football kit poses energetically with a smiling Cristiano Ronaldo, who gives a thumbs-up, as they stand side by side in a casual, cheerful moment—capturing the excitement of meeting a football legend.", 
     Image.open("./test_imgs/03.png"), 768, 672, 3.5, 24, 42],
    ["Macaron", "Macaron style, A man glances admiringly at a passing woman, while his girlfriend looks at him in disbelief, perfectly capturing the theme of shifting attention and misplaced priorities in a humorous, relatable way.", 
     Image.open("./test_imgs/04.png"), 696, 1024, 3.5, 24, 42]
]

# Gradio interface setup
def create_gradio_interface():
    """Build and return the Gradio Blocks demo.

    Layout: left column for style/prompt/image inputs, right column for the
    output slider plus size/guidance/steps/seed controls, with clickable
    examples underneath.
    """
    lora_names = [
        "3D_Chibi", "American_Cartoon", "Chinese_Ink", 
        "Clay_Toy", "Fabric", "Ghibli", "Irasutoya",
        "Jojo", "LEGO", "Line", "Macaron",
        "Oil_Painting", "Origami", "Paper_Cutting", 
        "Picasso", "Pixel", "Poly", "Pop_Art", 
        "Rick_Morty", "Snoopy", "Van_Gogh", "Vector"
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# OmniConsistency LoRA Image Generation")
        gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency. [View on GitHub](https://github.com/showlab/OmniConsistency)")
        with gr.Row():
            with gr.Column(scale=1):
                lora_dropdown = gr.Dropdown(lora_names, label="Select LoRA")
                prompt_box = gr.Textbox(label="Prompt", placeholder="Enter a prompt...")
                image_input = gr.Image(type="pil", label="Upload Image")
            with gr.Column(scale=1):
                output_image = gr.ImageSlider(label="Generated Image")
                width_box = gr.Textbox(label="Width", value="1024")
                height_box = gr.Textbox(label="Height", value="1024")
                guidance_slider = gr.Slider(minimum=0.1, maximum=20, value=3.5, step=0.1, label="Guidance Scale")
                steps_slider = gr.Slider(minimum=1, maximum=50, value=25, step=1, label="Inference Steps")
                seed_slider = gr.Slider(minimum=1, maximum=10000000000, value=42, step=1, label="Seed")
                generate_button = gr.Button("Generate")
                
        # Add examples for Generation.
        # BUG FIX: the inputs list previously put height_box before width_box
        # while generate_image's signature — and the button's click binding —
        # take width first. Positionally the example run still fed the same
        # values to the same fn parameters, but the UI boxes were filled
        # transposed, so pressing "Generate" after selecting an example
        # produced a swapped canvas. Width now comes first everywhere.
        gr.Examples(
            examples=examples,
            inputs=[lora_dropdown, prompt_box, image_input, width_box, height_box, guidance_slider, steps_slider, seed_slider],
            outputs=output_image,
            fn=generate_image,
            cache_examples=False,
            label="Examples"
        )

        # Same input order as the Examples above, matching generate_image's
        # parameter order.
        generate_button.click(
            fn=generate_image,
            inputs=[
                lora_dropdown, prompt_box, image_input,
                width_box, height_box, guidance_slider,
                steps_slider, seed_slider
            ],
            outputs=output_image
        )

    return demo


# Launch the Gradio interface (blocking call; serves the web UI).
interface = create_gradio_interface()
interface.launch()