yiren98 commited on
Commit
d45d6cb
·
verified ·
1 Parent(s): fdf7ba1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -104
app.py CHANGED
@@ -1,103 +1,143 @@
1
  import spaces
 
2
  import time
3
  import torch
4
  import gradio as gr
5
  from PIL import Image
6
- from huggingface_hub import hf_hub_download
7
  from src_inference.pipeline import FluxPipeline
8
  from src_inference.lora_helper import set_single_lora
9
- import random
10
 
11
- base_path = "black-forest-labs/FLUX.1-dev"
12
-
13
- # Download OmniConsistency LoRA using hf_hub_download
14
- omni_consistency_path = hf_hub_download(repo_id="showlab/OmniConsistency",
15
- filename="OmniConsistency.safetensors",
16
- local_dir="./Model")
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # Initialize the pipeline with the model
19
- pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16).to("cuda")
20
-
21
- # Set LoRA weights
22
- set_single_lora(pipe.transformer, omni_consistency_path, lora_weights=[1], cond_size=512)
23
-
24
- # Function to clear cache
25
- def clear_cache(transformer):
26
- for name, attn_processor in transformer.attn_processors.items():
27
- attn_processor.bank_kv.clear()
28
-
29
- # Function to download all LoRAs in advance
30
  def download_all_loras():
31
  lora_names = [
32
- "3D_Chibi", "American_Cartoon", "Chinese_Ink",
33
- "Clay_Toy", "Fabric", "Ghibli", "Irasutoya",
34
- "Jojo", "LEGO", "Line", "Macaron",
35
- "Oil_Painting", "Origami", "Paper_Cutting",
36
- "Picasso", "Pixel", "Poly", "Pop_Art",
37
- "Rick_Morty", "Snoopy", "Van_Gogh", "Vector"
38
  ]
39
- for lora_name in lora_names:
40
- hf_hub_download(repo_id="showlab/OmniConsistency",
41
- filename=f"LoRAs/{lora_name}_rank128_bf16.safetensors",
42
- local_dir="./LoRAs")
43
-
44
- # Download all LoRAs in advance before the interface is launched
45
  download_all_loras()
46
 
47
- # Main function to generate the image
 
 
 
48
  @spaces.GPU()
49
- def generate_image(lora_name, prompt, uploaded_image, width, height, guidance_scale, num_inference_steps, seed):
50
- # Download specific LoRA based on selection (use local directory as LoRAs are already downloaded)
51
- lora_path = f"./LoRAs/LoRAs/{lora_name}_rank128_bf16.safetensors"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # Load the specific LoRA weights
54
  pipe.unload_lora_weights()
55
- pipe.load_lora_weights("./LoRAs/LoRAs", weight_name=f"{lora_name}_rank128_bf16.safetensors")
 
 
 
 
 
 
56
 
57
- # Prepare input image
58
- spatial_image = [uploaded_image.convert("RGB")]
59
  subject_images = []
60
-
61
- start_time = time.time()
62
-
63
- # Generate the image
64
- image = pipe(
65
  prompt,
66
- height=(int(height) // 8) * 8,
67
- width=(int(width) // 8) * 8,
68
  guidance_scale=guidance_scale,
69
  num_inference_steps=num_inference_steps,
70
  max_sequence_length=512,
71
- generator=torch.Generator("cpu").manual_seed(seed),
72
  spatial_images=spatial_image,
73
  subject_images=subject_images,
74
  cond_size=512,
75
  ).images[0]
 
76
 
77
- end_time = time.time()
78
- elapsed_time = end_time - start_time
79
- print(f"code running time: {elapsed_time} s")
80
-
81
- # Clear cache after generation
82
  clear_cache(pipe.transformer)
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- return (uploaded_image, image)
85
-
86
- # Example data
87
- examples = [
88
- ["3D_Chibi", "3D Chibi style, Two smiling colleagues enthusiastically high-five in front of a whiteboard filled with technical notes about multimodal learning, reflecting a moment of success and collaboration at OpenAI.",
89
- Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
90
- ["Clay_Toy", "Clay Toy style, Three team members from OpenAI are gathered around a laptop in a cozy, festive setting, with holiday decorations in the background; one waves cheerfully while the others engage in light conversation, reflecting a relaxed and collaborative atmosphere.",
91
- Image.open("./test_imgs/01.png"), 560, 1024, 3.5, 24, 42],
92
- ["American_Cartoon", "American Cartoon style, In a dramatic and comedic moment from a classic Chinese film, an intense elder with a white beard and red hat grips a younger man, declaring something with fervor, while the subtitle at the bottom reads 'I want them all' capturing both tension and humor.",
93
- Image.open("./test_imgs/02.png"), 568, 1024, 3.5, 24, 42],
94
- ["Origami", "Origami style, A thrilled fan wearing a Portugal football kit poses energetically with a smiling Cristiano Ronaldo, who gives a thumbs-up, as they stand side by side in a casual, cheerful moment—capturing the excitement of meeting a football legend.",
95
- Image.open("./test_imgs/03.png"), 768, 672, 3.5, 24, 42],
96
- ["Macaron", "Macaron style, A man glances admiringly at a passing woman, while his girlfriend looks at him in disbelief, perfectly capturing the theme of shifting attention and misplaced priorities in a humorous, relatable way.",
97
- Image.open("./test_imgs/04.png"), 696, 1024, 3.5, 24, 42]
98
- ]
99
-
100
- header = """
101
  <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
102
  <a href="https://arxiv.org/abs/2505.18445"><img src="https://img.shields.io/badge/ariXv-2505.18445-A42C25.svg" alt="arXiv"></a>
103
  <a href="https://huggingface.co/showlab/OmniConsistency"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
@@ -105,58 +145,55 @@ header = """
105
  </div>
106
  """
107
 
108
- # Gradio interface setup
109
- def create_gradio_interface():
110
- lora_names = [
111
- "3D_Chibi", "American_Cartoon", "Chinese_Ink",
112
- "Clay_Toy", "Fabric", "Ghibli", "Irasutoya",
113
- "Jojo", "LEGO", "Line", "Macaron",
114
- "Oil_Painting", "Origami", "Paper_Cutting",
115
- "Picasso", "Pixel", "Poly", "Pop_Art",
116
- "Rick_Morty", "Snoopy", "Van_Gogh", "Vector"
117
- ]
118
-
119
  with gr.Blocks() as demo:
120
  gr.Markdown("# OmniConsistency LoRA Image Generation")
121
- gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency. [View on GitHub](https://github.com/showlab/OmniConsistency)")
122
  gr.HTML(header)
 
123
  with gr.Row():
124
  with gr.Column(scale=1):
125
- lora_dropdown = gr.Dropdown(lora_names, label="Select LoRA")
126
- prompt_box = gr.Textbox(label="Prompt", placeholder="Enter a prompt...")
 
 
 
 
 
 
 
 
 
127
  image_input = gr.Image(type="pil", label="Upload Image")
128
  with gr.Column(scale=1):
129
  output_image = gr.ImageSlider(label="Generated Image")
130
- width_box = gr.Textbox(label="Width", value="1024")
131
- height_box = gr.Textbox(label="Height", value="1024")
132
- guidance_slider = gr.Slider(minimum=0.1, maximum=20, value=3.5, step=0.1, label="Guidance Scale")
133
- steps_slider = gr.Slider(minimum=1, maximum=50, value=25, step=1, label="Inference Steps")
134
- seed_slider = gr.Slider(minimum=1, maximum=10000000000, value=42, step=1, label="Seed")
135
- generate_button = gr.Button("Generate")
136
-
137
- # Add examples for Generation
 
 
138
  gr.Examples(
139
  examples=examples,
140
- inputs=[lora_dropdown, prompt_box, image_input, height_box, width_box, guidance_slider, steps_slider, seed_slider],
 
141
  outputs=output_image,
142
  fn=generate_image,
143
  cache_examples=False,
144
  label="Examples"
145
  )
146
 
147
- generate_button.click(
148
  fn=generate_image,
149
- inputs=[
150
- lora_dropdown, prompt_box, image_input,
151
- width_box, height_box, guidance_slider,
152
- steps_slider, seed_slider
153
- ],
154
  outputs=output_image
155
  )
156
-
157
  return demo
158
 
159
-
160
- # Launch the Gradio interface
161
- interface = create_gradio_interface()
162
- interface.launch()
 
1
  import spaces
2
+ import os
3
  import time
4
  import torch
5
  import gradio as gr
6
  from PIL import Image
7
+ from huggingface_hub import hf_hub_download, list_repo_files
8
  from src_inference.pipeline import FluxPipeline
9
  from src_inference.lora_helper import set_single_lora
 
10
 
11
+ BASE_PATH = "black-forest-labs/FLUX.1-dev"
12
+ LOCAL_LORA_DIR = "./LoRAs"
13
+ CUSTOM_LORA_DIR = "./Custom_LoRAs"
14
+ os.makedirs(LOCAL_LORA_DIR, exist_ok=True)
15
+ os.makedirs(CUSTOM_LORA_DIR, exist_ok=True)
16
+
17
+ print("downloading OmniConsistency base LoRA …")
18
+ omni_consistency_path = hf_hub_download(
19
+ repo_id="showlab/OmniConsistency",
20
+ filename="OmniConsistency.safetensors",
21
+ local_dir="./Model"
22
+ )
23
+
24
+ print("loading base pipeline …")
25
+ pipe = FluxPipeline.from_pretrained(
26
+ BASE_PATH, torch_dtype=torch.bfloat16
27
+ ).to("cuda")
28
+ set_single_lora(pipe.transformer, omni_consistency_path,
29
+ lora_weights=[1], cond_size=512)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def download_all_loras():
32
  lora_names = [
33
+ "3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
34
+ "Fabric", "Ghibli", "Irasutoya", "Jojo", "LEGO", "Line",
35
+ "Macaron", "Oil_Painting", "Origami", "Paper_Cutting",
36
+ "Picasso", "Pixel", "Poly", "Pop_Art", "Rick_Morty",
37
+ "Snoopy", "Van_Gogh", "Vector"
 
38
  ]
39
+ for name in lora_names:
40
+ hf_hub_download(
41
+ repo_id="showlab/OmniConsistency",
42
+ filename=f"LoRAs/{name}_rank128_bf16.safetensors",
43
+ local_dir=LOCAL_LORA_DIR,
44
+ )
45
  download_all_loras()
46
 
47
+ def clear_cache(transformer):
48
+ for _, attn_processor in transformer.attn_processors.items():
49
+ attn_processor.bank_kv.clear()
50
+
51
  @spaces.GPU()
52
+ def generate_image(
53
+ lora_name,
54
+ custom_repo_id,
55
+ prompt,
56
+ uploaded_image,
57
+ width, height,
58
+ guidance_scale,
59
+ num_inference_steps,
60
+ seed
61
+ ):
62
+ width, height = int(width), int(height)
63
+ generator = torch.Generator("cpu").manual_seed(seed)
64
+
65
+ if custom_repo_id and custom_repo_id.strip():
66
+ repo_id = custom_repo_id.strip()
67
+ try:
68
+ files = list_repo_files(repo_id)
69
+ print("using custom LoRA from:", repo_id)
70
+ safetensors_files = [f for f in files if f.endswith(".safetensors")]
71
+ print("found safetensors files:", safetensors_files)
72
+ if not safetensors_files:
73
+ raise ValueError("No .safetensors files were found in this repo")
74
+ fname = safetensors_files[0]
75
+ lora_path = hf_hub_download(
76
+ repo_id=repo_id,
77
+ filename=fname,
78
+ local_dir=CUSTOM_LORA_DIR,
79
+ )
80
+ except Exception as e:
81
+ raise gr.Error(f"Load custom LoRA failed: {e}")
82
+ else:
83
+ lora_path = os.path.join(
84
+ f"{LOCAL_LORA_DIR}/LoRAs", f"{lora_name}_rank128_bf16.safetensors"
85
+ )
86
 
 
87
  pipe.unload_lora_weights()
88
+ try:
89
+ pipe.load_lora_weights(
90
+ os.path.dirname(lora_path),
91
+ weight_name=os.path.basename(lora_path)
92
+ )
93
+ except Exception as e:
94
+ raise gr.Error(f"Load LoRA failed: {e}")
95
 
96
+ spatial_image = [uploaded_image.convert("RGB")]
 
97
  subject_images = []
98
+ start = time.time()
99
+ out_img = pipe(
 
 
 
100
  prompt,
101
+ height=(height // 8) * 8,
102
+ width=(width // 8) * 8,
103
  guidance_scale=guidance_scale,
104
  num_inference_steps=num_inference_steps,
105
  max_sequence_length=512,
106
+ generator=generator,
107
  spatial_images=spatial_image,
108
  subject_images=subject_images,
109
  cond_size=512,
110
  ).images[0]
111
+ print(f"inference time: {time.time()-start:.2f}s")
112
 
 
 
 
 
 
113
  clear_cache(pipe.transformer)
114
+ return uploaded_image, out_img
115
+
116
+ # =============== Gradio UI ===============
117
+ def create_interface():
118
+ demo_lora_names = [
119
+ "3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
120
+ "Fabric", "Ghibli", "Irasutoya", "Jojo", "LEGO", "Line",
121
+ "Macaron", "Oil_Painting", "Origami", "Paper_Cutting",
122
+ "Picasso", "Pixel", "Poly", "Pop_Art", "Rick_Morty",
123
+ "Snoopy", "Van_Gogh", "Vector"
124
+ ]
125
 
126
+ # Example data
127
+ examples = [
128
+ ["3D_Chibi", "", "3D Chibi style, Two smiling colleagues enthusiastically high-five in front of a whiteboard filled with technical notes about multimodal learning, reflecting a moment of success and collaboration at OpenAI.",
129
+ Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
130
+ ["Clay_Toy", "", "Clay Toy style, Three team members from OpenAI are gathered around a laptop in a cozy, festive setting, with holiday decorations in the background; one waves cheerfully while the others engage in light conversation, reflecting a relaxed and collaborative atmosphere.",
131
+ Image.open("./test_imgs/01.png"), 560, 1024, 3.5, 24, 42],
132
+ ["American_Cartoon", "", "American Cartoon style, In a dramatic and comedic moment from a classic Chinese film, an intense elder with a white beard and red hat grips a younger man, declaring something with fervor, while the subtitle at the bottom reads 'I want them all' capturing both tension and humor.",
133
+ Image.open("./test_imgs/02.png"), 568, 1024, 3.5, 24, 42],
134
+ ["Origami", "", "Origami style, A thrilled fan wearing a Portugal football kit poses energetically with a smiling Cristiano Ronaldo, who gives a thumbs-up, as they stand side by side in a casual, cheerful moment—capturing the excitement of meeting a football legend.",
135
+ Image.open("./test_imgs/03.png"), 768, 672, 3.5, 24, 42],
136
+ ["Vector", "", "Vector style, A man glances admiringly at a passing woman, while his girlfriend looks at him in disbelief, perfectly capturing the theme of shifting attention and misplaced priorities in a humorous, relatable way.",
137
+ Image.open("./test_imgs/04.png"), 512, 1024, 3.5, 24, 42]
138
+ ]
139
+
140
+ header = """
 
 
141
  <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
142
  <a href="https://arxiv.org/abs/2505.18445"><img src="https://img.shields.io/badge/ariXv-2505.18445-A42C25.svg" alt="arXiv"></a>
143
  <a href="https://huggingface.co/showlab/OmniConsistency"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
 
145
  </div>
146
  """
147
 
 
 
 
 
 
 
 
 
 
 
 
148
  with gr.Blocks() as demo:
149
  gr.Markdown("# OmniConsistency LoRA Image Generation")
150
+ gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency.")
151
  gr.HTML(header)
152
+
153
  with gr.Row():
154
  with gr.Column(scale=1):
155
+ lora_dropdown = gr.Dropdown(
156
+ demo_lora_names, label="Select built-in LoRA")
157
+ custom_repo_box = gr.Textbox(
158
+ label="Enter Custom LoRA",
159
+ placeholder="LoRA Hugging Face path (e.g., 'username/repo_name')",
160
+ info="If you want to use a custom LoRA, enter its Hugging Face repo ID here and built-in LoRA will be Overridden. Leave empty to use built-in LoRAs. [Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)"
161
+ )
162
+ prompt_box = gr.Textbox(label="Prompt",
163
+ placeholder="Enter your prompt here",
164
+ info="Remember to include the necessary trigger words if you're using a custom LoRA."
165
+ )
166
  image_input = gr.Image(type="pil", label="Upload Image")
167
  with gr.Column(scale=1):
168
  output_image = gr.ImageSlider(label="Generated Image")
169
+ height_box = gr.Textbox(value="1024", label="Height")
170
+ width_box = gr.Textbox(value="1024", label="Width")
171
+ guidance_slider = gr.Slider(
172
+ 0.1, 20, value=3.5, step=0.1, label="Guidance Scale")
173
+ steps_slider = gr.Slider(
174
+ 1, 50, value=25, step=1, label="Inference Steps")
175
+ seed_slider = gr.Slider(
176
+ 1, 2_147_483_647, value=42, step=1, label="Seed")
177
+ gen_btn = gr.Button("Generate")
178
+
179
  gr.Examples(
180
  examples=examples,
181
+ inputs=[lora_dropdown, custom_repo_box, prompt_box, image_input,
182
+ height_box, width_box, guidance_slider, steps_slider, seed_slider],
183
  outputs=output_image,
184
  fn=generate_image,
185
  cache_examples=False,
186
  label="Examples"
187
  )
188
 
189
+ gen_btn.click(
190
  fn=generate_image,
191
+ inputs=[lora_dropdown, custom_repo_box, prompt_box, image_input,
192
+ width_box, height_box, guidance_slider, steps_slider, seed_slider],
 
 
 
193
  outputs=output_image
194
  )
 
195
  return demo
196
 
197
+ if __name__ == "__main__":
198
+ demo = create_interface()
199
+ demo.launch()