Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,016 Bytes
cc6558b 154718e cc6558b 83f3d87 cc6558b 474dcae cc6558b 154718e cc6558b 154718e cc6558b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import spaces
import time
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora
import random
# Hub id of the base text-to-image model that all style LoRAs are applied on.
base_path = "black-forest-labs/FLUX.1-dev"
# Fetch the OmniConsistency consistency-LoRA weights once at startup.
omni_consistency_path = hf_hub_download(repo_id="showlab/OmniConsistency",
                                        filename="OmniConsistency.safetensors",
                                        local_dir="./Model")
# Build the Flux pipeline in bfloat16 and move it to the GPU.
# NOTE(review): assumes a CUDA device is present at import time — confirm the
# Space hardware, otherwise .to("cuda") raises here.
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16).to("cuda")
# Attach the consistency LoRA to the transformer; cond_size=512 matches the
# cond_size passed in the pipeline call inside generate_image.
set_single_lora(pipe.transformer, omni_consistency_path, lora_weights=[1], cond_size=512)
def clear_cache(transformer):
    """Empty the cached key/value banks held by each attention processor.

    Called after every generation so cached conditioning state from one
    request does not leak into (or grow across) subsequent requests.

    Args:
        transformer: object exposing ``attn_processors`` (a mapping of
            processor name -> processor); each processor is expected to
            carry a ``bank_kv`` container with a ``clear()`` method.
    """
    # Only the processors themselves are needed — iterate values, not items.
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()
def download_all_loras():
    """Pre-fetch every style LoRA from the OmniConsistency repo.

    Files land under ./LoRAs (hf_hub_download preserves the repo's
    "LoRAs/" sub-folder, so the final path is ./LoRAs/LoRAs/<name>...).
    Downloading everything up front keeps the first generation request fast.
    """
    style_names = (
        "3D_Chibi", "American_Cartoon", "Chinese_Ink",
        "Clay_Toy", "Fabric", "Ghibli", "Irasutoya",
        "Jojo", "LEGO", "Line", "Macaron",
        "Oil_Painting", "Origami", "Paper_Cutting",
        "Picasso", "Pixel", "Poly", "Pop_Art",
        "Rick_Morty", "Snoopy", "Van_Gogh", "Vector",
    )
    for style in style_names:
        hf_hub_download(repo_id="showlab/OmniConsistency",
                        filename=f"LoRAs/{style}_rank128_bf16.safetensors",
                        local_dir="./LoRAs")
# Eagerly fetch every style LoRA before the UI launches so generation
# requests never block on a network download.
download_all_loras()
# Main function to generate the image (runs on ZeroGPU via the spaces decorator).
@spaces.GPU()
def generate_image(lora_name, prompt, uploaded_image, width, height, guidance_scale, num_inference_steps, seed):
    """Generate a stylised version of *uploaded_image* with the chosen LoRA.

    Args:
        lora_name: style name matching a pre-downloaded LoRA file.
        prompt: text prompt; by convention it starts with the style phrase.
        uploaded_image: PIL image used as the spatial conditioning input.
        width, height: target size; strings from the UI, coerced via int()
            and rounded down to multiples of 8 as the pipeline requires.
        guidance_scale: classifier-free guidance strength.
        num_inference_steps: denoising step count.
        seed: RNG seed for reproducible output.

    Returns:
        (uploaded_image, generated_image) — the pair shown in the ImageSlider.
    """
    # LoRAs were pre-fetched by download_all_loras(); hf_hub_download keeps the
    # repo's "LoRAs/" sub-folder, hence the ./LoRAs/LoRAs directory below.
    # Unload any LoRA left over from the previous request, then load this one.
    pipe.unload_lora_weights()
    pipe.load_lora_weights("./LoRAs/LoRAs", weight_name=f"{lora_name}_rank128_bf16.safetensors")

    # The uploaded image is the single spatial conditioning input; no subject images.
    spatial_image = [uploaded_image.convert("RGB")]
    subject_images = []

    start_time = time.time()
    image = pipe(
        prompt,
        height=(int(height) // 8) * 8,   # dimensions must be divisible by 8
        width=(int(width) // 8) * 8,
        guidance_scale=guidance_scale,
        # Gradio sliders can deliver floats; the pipeline and manual_seed need ints.
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(int(seed)),
        spatial_images=spatial_image,
        subject_images=subject_images,
        cond_size=512,
    ).images[0]
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"code running time: {elapsed_time} s")

    # Drop the attention processors' cached state so requests stay independent.
    clear_cache(pipe.transformer)
    return (uploaded_image, image)
# Example rows for gr.Examples. Column order matches the `inputs` list passed
# to gr.Examples: [lora_name, prompt, input image, height, width,
# guidance_scale, inference_steps, seed].
# NOTE(review): gr.Examples is wired as (height_box, width_box) while the
# Generate button is wired as (width_box, height_box); the 4th/5th values here
# therefore populate height then width — confirm this ordering is intended.
examples = [
["3D_Chibi", "3D Chibi style, Two smiling colleagues enthusiastically high-five in front of a whiteboard filled with technical notes about multimodal learning, reflecting a moment of success and collaboration at OpenAI.",
Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
["Clay_Toy", "Clay Toy style, Three team members from OpenAI are gathered around a laptop in a cozy, festive setting, with holiday decorations in the background; one waves cheerfully while the others engage in light conversation, reflecting a relaxed and collaborative atmosphere.",
Image.open("./test_imgs/01.png"), 560, 1024, 3.5, 24, 42],
["American_Cartoon", "American Cartoon style, In a dramatic and comedic moment from a classic Chinese film, an intense elder with a white beard and red hat grips a younger man, declaring something with fervor, while the subtitle at the bottom reads 'I want them all' — capturing both tension and humor.",
Image.open("./test_imgs/02.png"), 568, 1024, 3.5, 24, 42],
["Origami", "Origami style, A thrilled fan wearing a Portugal football kit poses energetically with a smiling Cristiano Ronaldo, who gives a thumbs-up, as they stand side by side in a casual, cheerful moment—capturing the excitement of meeting a football legend.",
Image.open("./test_imgs/03.png"), 768, 672, 3.5, 24, 42],
["Macaron", "Macaron style, A man glances admiringly at a passing woman, while his girlfriend looks at him in disbelief, perfectly capturing the theme of shifting attention and misplaced priorities in a humorous, relatable way.",
Image.open("./test_imgs/04.png"), 696, 1024, 3.5, 24, 42]
]
# Gradio interface setup
def create_gradio_interface():
    """Assemble and return the Gradio Blocks demo for OmniConsistency."""
    available_styles = [
        "3D_Chibi", "American_Cartoon", "Chinese_Ink",
        "Clay_Toy", "Fabric", "Ghibli", "Irasutoya",
        "Jojo", "LEGO", "Line", "Macaron",
        "Oil_Painting", "Origami", "Paper_Cutting",
        "Picasso", "Pixel", "Poly", "Pop_Art",
        "Rick_Morty", "Snoopy", "Van_Gogh", "Vector",
    ]
    with gr.Blocks() as demo:
        gr.Markdown("# OmniConsistency LoRA Image Generation")
        gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency. [View on GitHub](https://github.com/showlab/OmniConsistency)")
        with gr.Row():
            # Left column: generation inputs.
            with gr.Column(scale=1):
                style_dropdown = gr.Dropdown(available_styles, label="Select LoRA")
                prompt_field = gr.Textbox(label="Prompt", placeholder="Enter a prompt...")
                source_image = gr.Image(type="pil", label="Upload Image")
            # Right column: result view plus tuning controls.
            with gr.Column(scale=1):
                result_slider = gr.ImageSlider(label="Generated Image")
                width_field = gr.Textbox(label="Width", value="1024")
                height_field = gr.Textbox(label="Height", value="1024")
                cfg_control = gr.Slider(minimum=0.1, maximum=20, value=3.5, step=0.1, label="Guidance Scale")
                steps_control = gr.Slider(minimum=1, maximum=50, value=25, step=1, label="Inference Steps")
                seed_control = gr.Slider(minimum=1, maximum=10000000000, value=42, step=1, label="Seed")
                run_button = gr.Button("Generate")
        # Clickable example rows (not cached, so they only populate the inputs).
        gr.Examples(
            examples=examples,
            inputs=[style_dropdown, prompt_field, source_image, height_field, width_field, cfg_control, steps_control, seed_control],
            outputs=result_slider,
            fn=generate_image,
            cache_examples=False,
            label="Examples"
        )
        # Wire the button to the generation function; argument order matches
        # the generate_image signature.
        run_button.click(
            fn=generate_image,
            inputs=[
                style_dropdown, prompt_field, source_image,
                width_field, height_field, cfg_control,
                steps_control, seed_control
            ],
            outputs=result_slider
        )
    return demo
# Build the UI and start the Gradio server (blocks until shutdown).
interface = create_gradio_interface()
interface.launch()
|