import spaces
import gradio as gr
import torch
from typing import TypedDict
from PIL import Image, ImageDraw, ImageFont
from diffusers.pipelines import FluxPipeline
from diffusers import FluxTransformer2DModel
import numpy as np

import examples_db
from flux.condition import Condition
from flux.generate import seed_everything, generate
from flux.lora_controller import set_lora_scale

pipe = None
current_adapter = None
use_int8 = False

model_config = {
    "union_cond_attn": True,
    "add_cond_attn": False,
    "latent_lora": False,
    "independent_condition": True,
}
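# Note: these flags are forwarded verbatim to flux.generate as model_config.
# Their exact semantics live in the ZenCtrl / OminiControl-style attention
# code; as a rough, unverified reading, they control whether condition tokens
# share attention with the latent image tokens and whether the LoRA also
# applies to the latent branch.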
def get_gpu_memory():
    """Total memory of CUDA device 0, in GiB. Assumes a CUDA GPU is present."""
    return torch.cuda.get_device_properties(0).total_memory / 1024**3
def init_pipeline():
    global pipe
    if use_int8 or get_gpu_memory() < 33:
        # Not enough VRAM for the full bf16 transformer: swap in an int8 checkpoint
        transformer_model = FluxTransformer2DModel.from_pretrained(
            "sayakpaul/flux.1-schell-int8wo-improved",
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
        )
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            transformer=transformer_model,
            torch_dtype=torch.bfloat16,
        )
    else:
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
        )
    pipe = pipe.to("cuda")

    # Load the default "subject" LoRA adapter
    pipe.load_lora_weights(
        "fotographerai/zenctrl_tools",
        weight_name="weights/zen2con_1024_10000/pytorch_lora_weights.safetensors",
        adapter_name="subject",
    )
    # Optional: load additional LoRA weights
    # pipe.load_lora_weights("XLabs-AI/flux-RealismLora", adapter_name="realism")
def paste_on_white_background(image: Image.Image) -> Image.Image:
    """Paste a transparent image onto a white background of the same size."""
    if image.mode != "RGBA":
        image = image.convert("RGBA")
    # Create a white background and composite the image over it using its alpha
    white_bg = Image.new("RGBA", image.size, (255, 255, 255, 255))
    white_bg.paste(image, (0, 0), mask=image)
    return white_bg.convert("RGB")  # Back to RGB; the alpha is no longer needed
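# Example (sketch, assuming the bundled sample image exists): flatten a
# transparent PNG before using it as a conditioning image.
#   subject = paste_on_white_background(Image.open("samples/1.png"))
#   subject.save("subject_rgb.png")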
# @spaces.GPU
def process_image_and_text(image, text, steps=8, strength_sub=1.0, strength_spat=1.0, size=1024):
    # Center-crop to a square, then resize to the conditioning size
    w, h, min_size = image.size[0], image.size[1], min(image.size)
    image = image.crop(
        (
            (w - min_size) // 2,
            (h - min_size) // 2,
            (w + min_size) // 2,
            (h + min_size) // 2,
        )
    )
    image = image.resize((size, size))
    image = paste_on_white_background(image)

    # Two copies of the subject condition, offset down and up by size // 16
    condition0 = Condition("subject", image, position_delta=(0, size // 16))
    condition1 = Condition("subject", image, position_delta=(0, -size // 16))

    pipe = get_pipeline()
    with set_lora_scale(["subject"], scale=3.0):
        result_img = generate(
            pipe,
            prompt=text.strip(),
            conditions=[condition0, condition1],
            num_inference_steps=steps,
            height=1024,  # output resolution is fixed; `size` only affects the condition image
            width=1024,
            condition_scale=[strength_sub, strength_spat],
            model_config=model_config,
        ).images[0]
    return result_img
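# Example (sketch; assumes a CUDA GPU and that init_pipeline has loaded the
# "subject" adapter):
#   out = process_image_and_text(Image.open("samples/2.png"),
#                                "this item on a marble table", steps=8)
#   out.save("out.png")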
# ================== MODE CONFIG =====================

Mode = TypedDict(
    "Mode",
    {
        "model": str,
        "prompt": str,
        "default_strength": float,
        "default_height": int,
        "default_width": int,
        "models": list[str],
        "remove_bg": bool,
    },
)

MODEL_TO_LORA: dict[str, str] = {
    # dropdown value -> relative path inside the HF repo
    "zen2con_1024_10000": "weights/zen2con_1024_10000/pytorch_lora_weights.safetensors",
    "zen2con_1440_17000": "weights/zen2con_1440_17000/pytorch_lora_weights.safetensors",
    "zen_sub_sub_1024_10000": "weights/zen_sub_sub_1024_10000/pytorch_lora_weights.safetensors",
    "zen_toys_1024_4000": "weights/zen_toys_1024_4000/12000/pytorch_lora_weights.safetensors",
    "zen_toys_1024_15000": "weights/zen_toys_1024_4000/zen_toys_1024_15000/pytorch_lora_weights.safetensors",
    # add more as you upload them
}
MODE_DEFAULTS: dict[str, Mode] = {
    "Subject Generation": {
        "model": "zen2con_1024_10000",
        "prompt": "A vibrant background with dynamic lighting and textures",
        "default_strength": 1.2,
        "default_height": 1024,
        "default_width": 1024,
        "models": list(MODEL_TO_LORA.keys()),
        "remove_bg": True,
    },
    # "Image fix": {
    #     "model": "zen_toys_1024_4000",
    #     "prompt": "A detailed portrait with soft lighting",
    #     "default_strength": 1.2,
    #     "default_height": 1024,
    #     "default_width": 1024,
    #     "models": ["weights/zen_toys_1024_4000/12000/", "weights/zen_toys_1024_4000/12000/"],
    #     "remove_bg": True,
    # },
}
def get_pipeline():
    """Lazy-build the pipeline inside the GPU worker."""
    global pipe
    if pipe is None:
        init_pipeline()  # safe here: this runs inside the GPU worker
    return pipe
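# On ZeroGPU Spaces, CUDA is only available inside @spaces.GPU-decorated
# handlers, so the pipeline has to be built lazily in the worker rather than
# at import time (which is why the module-level init_pipeline() call in the
# __main__ block stays commented out).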
def get_samples():
    sample_list = [
        {
            "image": "samples/1.png",
            "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
        },
        {
            "image": "samples/2.png",
            "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
        },
        {
            "image": "samples/3.png",
            "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
        },
        {
            "image": "samples/4.png",
            "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
        },
        {
            "image": "samples/5.png",
            "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her.",
        },
    ]
    return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
# =============== UI ===============

header = """
<h1>ZenCtrl medium</h1>
<div align="center" style="line-height: 1;">
    <a href="https://github.com/FotographerAI/ZenCtrl/tree/main" target="_blank" style="margin: 2px;" name="github_repo_link"><img src="https://img.shields.io/badge/GitHub-Repo-181717.svg" alt="GitHub Repo" style="display: inline-block; vertical-align: middle;"></a>
    <a href="https://huggingface.co/fotographerai/zenctrl_tools" target="_blank" name="huggingface_space_link"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace Model" style="display: inline-block; vertical-align: middle;"></a>
    <a href="https://discord.com/invite/b9RuYQ3F8k" target="_blank" style="margin: 2px;" name="discord_link"><img src="https://img.shields.io/badge/Discord-Join-7289da.svg?logo=discord" alt="Discord" style="display: inline-block; vertical-align: middle;"></a>
    <a href="https://fotographer.ai/zen-control" target="_blank" style="margin: 2px;" name="lp_link"><img src="https://img.shields.io/badge/Website-Landing_Page-blue" alt="LP" style="display: inline-block; vertical-align: middle;"></a>
    <a href="https://x.com/FotographerAI" target="_blank" style="margin: 2px;" name="twitter_link"><img src="https://img.shields.io/twitter/follow/FotographerAI?style=social" alt="X" style="display: inline-block; vertical-align: middle;"></a>
</div>
"""
with gr.Blocks(title="ZenCtrl-medium") as demo:
    # ---------- banner ----------
    gr.HTML(header)
    gr.Markdown(
        """
        # ZenCtrl Demo
        One framework to generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject image, without fine-tuning.
        We are releasing some of the task-specific weights first and will release the code soon.
        The goal is to unify all visual content generation tasks with a single LLM...

        **Modes:**
        - **Subject-driven Image Generation:** Generate in-context images of your subject with high fidelity and from different perspectives.

        For more details, shoot us a message on Discord.
        """
    )
    # ---------- tab bar ----------
    with gr.Tabs():
        for mode_name, defaults in MODE_DEFAULTS.items():
            with gr.Tab(mode_name):
                gr.Markdown(f"### {mode_name}")

                # -------- left (input) column --------
                with gr.Row():
                    with gr.Column(scale=2):
                        input_image = gr.Image(label="Input Image", type="pil")
                        model_dropdown = gr.Dropdown(
                            label="Model (LoRA adapter)",
                            choices=defaults["models"],
                            value=defaults["model"],
                            interactive=True,
                        )
                        prompt_box = gr.Textbox(label="Prompt", value=defaults["prompt"], lines=2)
                        generate_btn = gr.Button("Generate")

                        with gr.Accordion("Generation Parameters", open=False):
                            step_slider = gr.Slider(2, 28, value=12, step=2, label="Steps")
                            strength_sub_slider = gr.Slider(
                                0.0, 2.0,
                                value=defaults["default_strength"],
                                step=0.1, label="Strength (subject)",
                            )
                            strength_spat_slider = gr.Slider(
                                0.0, 2.0,
                                value=defaults["default_strength"],
                                step=0.1, label="Strength (spatial)",
                            )
                            size_slider = gr.Slider(
                                512, 2048,
                                value=defaults["default_height"],
                                step=64, label="Size (px)",
                            )

                    # -------- right (output) column --------
                    with gr.Column(scale=2):
                        output_image = gr.Image(label="Output Image", type="pil")
                # ---------- click handler ----------
                def _run(image, model_name, prompt, steps, s_sub, s_spat, size):
                    global current_adapter
                    pipe = get_pipeline()

                    # -- switch adapter if needed --------------------------
                    if model_name != current_adapter:
                        lora_path = MODEL_TO_LORA[model_name]
                        # load & activate the chosen adapter
                        pipe.load_lora_weights(
                            "fotographerai/zenctrl_tools",
                            weight_name=lora_path,
                            adapter_name=model_name,
                        )
                        pipe.set_adapters([model_name])
                        current_adapter = model_name

                    # -- run generation ------------------------------------
                    return process_image_and_text(
                        image, prompt, steps=steps,
                        strength_sub=s_sub, strength_spat=s_spat, size=size,
                    )

                generate_btn.click(
                    fn=_run,
                    inputs=[input_image, model_dropdown, prompt_box,
                            step_slider, strength_sub_slider,
                            strength_spat_slider, size_slider],
                    outputs=[output_image],
                )
                # ---------------- Templates --------------------
                if examples_db.MODE_EXAMPLES.get(mode_name):
                    gr.Examples(
                        examples=examples_db.MODE_EXAMPLES[mode_name],
                        inputs=[
                            input_image,     # Image widget for input
                            model_dropdown,  # Dropdown for adapter
                            prompt_box,      # Textbox for prompt
                            output_image,    # Image widget for output
                        ],
                        label="Presets (Image / Model / Prompt)",
                        examples_per_page=15,
                    )
# =============== launch ===============
if __name__ == "__main__":
    # init_pipeline()  # built lazily in the GPU worker instead
    demo.launch(
        debug=True,
        share=True,
    )