"""Gradio demo: restyle an uploaded image with FLUX.1-dev + OmniConsistency LoRAs."""

# Keep `spaces` first: ZeroGPU expects it to be imported before torch/CUDA
# packages (per the Hugging Face ZeroGPU docs).
import spaces
import os
import time

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, list_repo_files, login

from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

BASE_PATH = "black-forest-labs/FLUX.1-dev"
LOCAL_LORA_DIR = "./LoRAs"
CUSTOM_LORA_DIR = "./Custom_LoRAs"
os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)

# The longer side of the uploaded image is rescaled to this many pixels.
MAX_IMAGE_SIZE = 1024

print("downloading OmniConsistency base LoRA …")
omni_consistency_path = hf_hub_download(
    repo_id="showlab/OmniConsistency",
    filename="OmniConsistency.safetensors",
    local_dir="./Model",
)

print("loading base pipeline …")
pipe = FluxPipeline.from_pretrained(
    BASE_PATH, torch_dtype=torch.bfloat16
).to("cuda")
set_single_lora(pipe.transformer, omni_consistency_path, lora_weights=[1], cond_size=512)

# Built-in style LoRAs; each name also determines its trigger words and filename.
lora_names = [
    "3D_Chibi",
    "American_Cartoon",
    "Macaron",
    "Pixel",
    "Poly",
    "Van_Gogh",
]


def _lora_weight_name(name):
    """Return the safetensors filename of the built-in style LoRA *name*."""
    return f"{name}_rank128_bf16.safetensors"


def _trigger_for(name):
    """Return the prompt trigger for LoRA *name*, e.g. "3D_Chibi" -> "3D Chibi style,"."""
    return name.replace("_", " ") + " style,"


def download_all_loras():
    """Download every built-in style LoRA into LOCAL_LORA_DIR/LoRAs/."""
    for name in lora_names:
        hf_hub_download(
            repo_id="showlab/OmniConsistency",
            filename=f"LoRAs/{_lora_weight_name(name)}",
            local_dir=LOCAL_LORA_DIR,
        )


download_all_loras()


def reload_all_loras():
    """Unload any LoRA weights from `pipe`, then load every built-in LoRA as a named adapter."""
    pipe.unload_lora_weights()
    for name in lora_names:
        pipe.load_lora_weights(
            # hf_hub_download keeps the repo's "LoRAs/" prefix under local_dir.
            f"{LOCAL_LORA_DIR}/LoRAs",
            weight_name=_lora_weight_name(name),
            adapter_name=name,
        )


reload_all_loras()


def clear_cache(transformer):
    """Clear the `bank_kv` store of every attention processor in *transformer*.

    NOTE(review): presumably `bank_kv` holds condition key/values cached during a
    generation by the src_inference processors — confirm in src_inference.
    """
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()


@spaces.GPU(duration=30)
def generate_image(
    lora_name, prompt, uploaded_image, guidance_scale, num_inference_steps, seed
):
    """Restyle *uploaded_image* with the selected LoRA and the OmniConsistency pipeline.

    Args:
        lora_name: Adapter name from `lora_names` to activate.
        prompt: Text prompt (should contain the LoRA's trigger words).
        uploaded_image: PIL image from the upload widget, or None if nothing was uploaded.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps (cast to int).
        seed: RNG seed for the CPU generator (cast to int).

    Returns:
        ``(uploaded_image, generated_image)``, the before/after pair shown by the ImageSlider.

    Raises:
        gr.Error: If no image was uploaded or no LoRA is selected.
    """
    if uploaded_image is None:
        raise gr.Error("Please upload an image first.")
    if not lora_name:
        raise gr.Error("Please select a LoRA.")

    # Scale so the longer side equals MAX_IMAGE_SIZE, preserving aspect ratio.
    width, height = uploaded_image.size
    factor = MAX_IMAGE_SIZE / max(width, height)
    width = int(width * factor)
    height = int(height * factor)

    generator = torch.Generator("cpu").manual_seed(int(seed))
    pipe.set_adapters(lora_name)

    spatial_image = [uploaded_image.convert("RGB")]
    subject_images = []

    start = time.time()
    try:
        out_img = pipe(
            prompt,
            # Dimensions are rounded down to multiples of 8 before inference.
            height=(height // 8) * 8,
            width=(width // 8) * 8,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            max_sequence_length=512,
            generator=generator,
            spatial_images=spatial_image,
            subject_images=subject_images,
            cond_size=512,
        ).images[0]
        print(f"inference time: {time.time()-start:.2f}s")
    finally:
        # Clear even if inference fails, so a failed run cannot leave state
        # in bank_kv that affects the next request.
        clear_cache(pipe.transformer)

    return uploaded_image, out_img


# =============== Gradio UI ===============
def create_interface():
    """Build and return the Gradio Blocks UI for the demo."""

    def update_trigger_word(lora_name, prompt):
        """Replace any built-in trigger in *prompt* with the trigger for *lora_name*."""
        for name in lora_names:
            prompt = prompt.replace(_trigger_for(name), "")
        if not lora_name:
            return prompt
        return _trigger_for(lora_name) + prompt

    header = """
arXiv HuggingFace GitHub
"""

    with gr.Blocks() as demo:
        gr.Markdown("# OmniConsistency LoRA Image Generation")
        gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency.")
        gr.HTML(header)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload Image")
                prompt_box = gr.Textbox(
                    label="Prompt",
                    # Pre-seeded with the trigger of the default dropdown choice below.
                    value=_trigger_for(lora_names[0]),
                    info="Remember to include the necessary trigger words if you're using a custom LoRA.",
                )
                lora_dropdown = gr.Dropdown(
                    lora_names, value=lora_names[0], label="Select built-in LoRA"
                )
                gen_btn = gr.Button("Generate")

            with gr.Column(scale=1):
                output_image = gr.ImageSlider(label="Generated Image")
                # Output size is derived from the uploaded image (see generate_image),
                # so no height/width controls are offered here.
                with gr.Accordion("Advanced Options", open=False):
                    guidance_slider = gr.Slider(
                        0.1, 20, value=3.5, step=0.1, label="Guidance Scale"
                    )
                    steps_slider = gr.Slider(
                        1, 50, value=25, step=1, label="Inference Steps"
                    )
                    seed_slider = gr.Slider(
                        1, 2_147_483_647, value=42, step=1, label="Seed"
                    )

        lora_dropdown.select(
            fn=update_trigger_word,
            inputs=[lora_dropdown, prompt_box],
            outputs=prompt_box,
        )
        gen_btn.click(
            fn=generate_image,
            inputs=[
                lora_dropdown,
                prompt_box,
                image_input,
                guidance_slider,
                steps_slider,
                seed_slider,
            ],
            outputs=output_image,
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(ssr_mode=False)