import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import DiffusionPipeline, QwenImageEditPipeline, FlowMatchEulerDiscreteScheduler
import random
import uuid
import numpy as np
import time
import zipfile
import os
import requests
from urllib.parse import urlparse
import tempfile
import shutil
import math

# --- App Description ---
DESCRIPTION = """## Qwen Image Hpc/."""

# --- Helper Functions for Both Tabs ---
MAX_SEED = np.iinfo(np.int32).max

def save_image(img):
    """Saves a PIL image to a temporary file with a unique name."""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Returns a random seed if randomize_seed is True, otherwise returns the original seed."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

# --- Model Loading ---
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Qwen-Image-Gen Model ---
pipe_qwen_gen = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image",
    torch_dtype=dtype
).to(device)

# --- Qwen-Image-Edit Model with Lightning LoRA ---
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
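# The config above follows the flow-match Euler settings commonly used with the
# Lightning few-step LoRA: dynamic shifting with an exponential time shift and
# base/max shift fixed at math.log(3). This gloss is descriptive only, not an
# official specification.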
pipe_qwen_edit = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit",
    scheduler=scheduler,
    torch_dtype=dtype
).to(device)

try:
    pipe_qwen_edit.load_lora_weights(
        "lightx2v/Qwen-Image-Lightning",
        weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
    )
    pipe_qwen_edit.fuse_lora()
    print("Successfully loaded Lightning LoRA weights for Qwen-Image-Edit")
except Exception as e:
    print(f"Warning: Could not load Lightning LoRA weights for Qwen-Image-Edit: {e}")
    print("Continuing with the base Qwen-Image-Edit model...")
# --- Qwen-Image-Gen Functions ---
aspect_ratios = {
    "1:1": (1328, 1328),
    "16:9": (1664, 928),
    "9:16": (928, 1664),
    "4:3": (1472, 1140),
    "3:4": (1140, 1472)
}
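# Resolution presets shown in the "Aspect Ratio" dropdown; each ratio maps to a
# (width, height) pair, e.g. selecting "16:9" sets the sliders to 1664 x 928.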
def load_lora_opt(pipe, lora_input):
    """Loads a LoRA from a local path, Hugging Face repo, or URL."""
    lora_input = lora_input.strip()
    if not lora_input:
        return
    if "/" in lora_input and not lora_input.startswith("http"):
        pipe.load_lora_weights(lora_input, adapter_name="default")
        return
    if lora_input.startswith("http"):
        url = lora_input
        if "huggingface.co" in url and "/blob/" not in url and "/resolve/" not in url:
            repo_id = urlparse(url).path.strip("/")
            pipe.load_lora_weights(repo_id, adapter_name="default")
            return
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        tmp_dir = tempfile.mkdtemp()
        local_path = os.path.join(tmp_dir, os.path.basename(urlparse(url).path))
        try:
            print(f"Downloading LoRA from {url}...")
            resp = requests.get(url, stream=True)
            resp.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"Saved LoRA to {local_path}")
            pipe.load_lora_weights(local_path, adapter_name="default")
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)
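# Assumption: this Space targets ZeroGPU (hence `import spaces` at the top), where
# the heavy inference entry points are normally wrapped with @spaces.GPU so a GPU
# is attached for the duration of each call. The decorator and the 120 s duration
# chosen here are an educated guess, not stated elsewhere in this file; outside
# ZeroGPU hardware the decorator is effectively a no-op.
@spaces.GPU(duration=120)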
def generate_qwen(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 4.0,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    num_images: int = 1,
    zip_images: bool = False,
    lora_input: str = "",
    lora_scale: float = 1.0,
    progress=gr.Progress(track_tqdm=True),
):
    """Main generation function for Qwen-Image-Gen."""
    seed = randomize_seed_fn(seed, randomize_seed)
    generator = torch.Generator(device).manual_seed(seed)
    start_time = time.time()

    # Drop any LoRA adapters left over from a previous request so load_lora_opt()
    # can register a fresh "default" adapter without a name clash.
    # get_list_adapters() returns {component_name: [adapter_name, ...]}.
    leftover = {name for names in pipe_qwen_gen.get_list_adapters().values() for name in names}
    if leftover:
        pipe_qwen_gen.delete_adapters(list(leftover))
    pipe_qwen_gen.disable_lora()

    if lora_input and lora_input.strip() != "":
        load_lora_opt(pipe_qwen_gen, lora_input)
        pipe_qwen_gen.set_adapters(["default"], adapter_weights=[lora_scale])

    images = pipe_qwen_gen(
        prompt=prompt,
        negative_prompt=negative_prompt if negative_prompt else " ",
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        generator=generator,
    ).images

    end_time = time.time()
    duration = end_time - start_time
    image_paths = [save_image(img) for img in images]

    zip_path = None
    if zip_images and len(image_paths) > 0:
        zip_name = str(uuid.uuid4()) + ".zip"
        with zipfile.ZipFile(zip_name, 'w') as zipf:
            for i, img_path in enumerate(image_paths):
                zipf.write(img_path, arcname=f"Img_{i}.png")
        zip_path = zip_name

    # Unload the request-specific LoRA again before returning.
    leftover = {name for names in pipe_qwen_gen.get_list_adapters().values() for name in names}
    if leftover:
        pipe_qwen_gen.delete_adapters(list(leftover))
    pipe_qwen_gen.disable_lora()

    return image_paths, seed, f"{duration:.2f}", zip_path
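# Minimal usage sketch (not wired into the UI; the prompt and sizes below are
# illustrative placeholders):
#
#   paths, used_seed, secs, _ = generate_qwen(
#       prompt="A lighthouse on a cliff at sunset, oil painting",
#       width=1328, height=1328, guidance_scale=4.0,
#       num_inference_steps=50, num_images=1,
#   )
#   print(f"Saved {paths} (seed {used_seed}, {secs}s)")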
def generate(
    prompt: str,
    negative_prompt: str,
    use_negative_prompt: bool,
    seed: int,
    width: int,
    height: int,
    guidance_scale: float,
    randomize_seed: bool,
    num_inference_steps: int,
    num_images: int,
    zip_images: bool,
    lora_input: str,
    lora_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """UI wrapper for the Qwen-Image-Gen generation function."""
    final_negative_prompt = negative_prompt if use_negative_prompt else ""
    return generate_qwen(
        prompt=prompt,
        negative_prompt=final_negative_prompt,
        seed=seed,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        randomize_seed=randomize_seed,
        num_inference_steps=num_inference_steps,
        num_images=num_images,
        zip_images=zip_images,
        lora_input=lora_input,
        lora_scale=lora_scale,
        progress=progress,
    )

# --- Qwen-Image-Edit Functions ---
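# Assumption, as with generate_qwen above: wrap the editing entry point for ZeroGPU.
# Plain @spaces.GPU uses the default duration, which should cover an 8-step edit.
@spaces.GPU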
def infer_edit(
    image,
    prompt,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=1.0,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """Main inference function for Qwen-Image-Edit."""
    if image is None:
        raise gr.Error("Please upload an image to edit.")
    negative_prompt = " "
    seed = randomize_seed_fn(seed, randomize_seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    print(f"Original prompt: '{prompt}'")
    print(f"Negative Prompt: '{negative_prompt}'")
    print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}")

    try:
        images = pipe_qwen_edit(
            image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            generator=generator,
            true_cfg_scale=true_guidance_scale,
            num_images_per_prompt=1
        ).images
        return images[0], seed
    except Exception as e:
        print(f"Error during inference: {e}")
        raise gr.Error(f"An error occurred during image editing: {e}")
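# Minimal usage sketch (illustrative only; the file path matches the bundled
# example assets referenced further below and may not exist locally):
#
#   edited, used_seed = infer_edit(Image.open("image-edit/cat.png"),
#                                  "make the cat wear sunglasses")
#   edited.save("edited.png")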
# --- Gradio UI ---
css = '''
.gradio-container {
  max-width: 800px !important;
  margin: 0 auto !important;
}
h1 {
  text-align: center;
}
footer {
  visibility: hidden;
}
'''

with gr.Blocks(css=css, theme="bethecloud/storj_theme", delete_cache=(240, 240)) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        with gr.TabItem("Qwen-Image-Gen"):
            with gr.Column():
                with gr.Row():
                    prompt_gen = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="✦︎ Enter your prompt for generation",
                        container=False,
                    )
                    run_button_gen = gr.Button("Generate", scale=0, variant="primary")
                result_gen = gr.Gallery(label="Result", columns=2, show_label=False, preview=True, height="auto")
                with gr.Row():
                    aspect_ratio_gen = gr.Dropdown(
                        label="Aspect Ratio",
                        choices=list(aspect_ratios.keys()),
                        value="1:1",
                    )
                    lora_gen = gr.Textbox(label="Optional LoRA", placeholder="Enter Hugging Face repo ID or URL...")
                with gr.Accordion("Additional Options", open=False):
                    use_negative_prompt_gen = gr.Checkbox(label="Use negative prompt", value=True)
                    negative_prompt_gen = gr.Text(
                        label="Negative prompt",
                        max_lines=1,
                        placeholder="Enter a negative prompt",
                        value="text, watermark, copyright, blurry, low resolution",
                    )
                    seed_gen = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed_gen = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Row():
                        width_gen = gr.Slider(label="Width", minimum=512, maximum=2048, step=64, value=1328)
                        height_gen = gr.Slider(label="Height", minimum=512, maximum=2048, step=64, value=1328)
                    guidance_scale_gen = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.1, value=4.0)
                    num_inference_steps_gen = gr.Slider(label="Number of inference steps", minimum=1, maximum=100, step=1, value=50)
                    num_images_gen = gr.Slider(label="Number of images", minimum=1, maximum=5, step=1, value=1)
                    zip_images_gen = gr.Checkbox(label="Zip generated images", value=False)
                    with gr.Row():
                        lora_scale_gen = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.01, value=1)
                    gr.Markdown("### Output Information")
                    seed_display_gen = gr.Textbox(label="Seed used", interactive=False)
                    generation_time_gen = gr.Textbox(label="Generation time (seconds)", interactive=False)
                    zip_file_gen = gr.File(label="Download ZIP")

            # --- Gen Tab Logic ---
            def set_dimensions(ar):
                w, h = aspect_ratios[ar]
                return gr.update(value=w), gr.update(value=h)

            aspect_ratio_gen.change(fn=set_dimensions, inputs=aspect_ratio_gen, outputs=[width_gen, height_gen])
            use_negative_prompt_gen.change(fn=lambda x: gr.update(visible=x), inputs=use_negative_prompt_gen, outputs=negative_prompt_gen)

            gen_inputs = [
                prompt_gen, negative_prompt_gen, use_negative_prompt_gen, seed_gen, width_gen, height_gen,
                guidance_scale_gen, randomize_seed_gen, num_inference_steps_gen, num_images_gen,
                zip_images_gen, lora_gen, lora_scale_gen
            ]
            gen_outputs = [result_gen, seed_display_gen, generation_time_gen, zip_file_gen]
            gr.on(triggers=[prompt_gen.submit, run_button_gen.click], fn=generate, inputs=gen_inputs, outputs=gen_outputs)

            gen_examples = [
                "A decadent slice of layered chocolate cake on a ceramic plate with a drizzle of chocolate syrup and powdered sugar dusted on top.",
                "A young girl wearing school uniform stands in a classroom, writing on a chalkboard. The text 'Introducing Qwen-Image' appears in neat white chalk.",
                "一幅精致细腻的工笔画,画面中心是一株蓬勃生长的红色牡丹,花朵繁茂。",
                "Realistic still life photography style: A single, fresh apple, resting on a clean, soft-textured surface.",
            ]
            gr.Examples(examples=gen_examples, inputs=prompt_gen, outputs=gen_outputs, fn=generate, cache_examples=False)

        with gr.TabItem("Qwen-Image-Edit"):
            with gr.Column():
                with gr.Row():
                    input_image_edit = gr.Image(label="Input Image", type="pil", height=400)
                    result_edit = gr.Image(label="Result", type="pil", height=400)
                with gr.Row():
                    prompt_edit = gr.Text(
                        label="Edit Instruction",
                        show_label=False,
                        placeholder="Describe the edit you want to make",
                        container=False,
                    )
                    run_button_edit = gr.Button("Edit", variant="primary")
                with gr.Accordion("Advanced Settings", open=False):
                    seed_edit = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                    randomize_seed_edit = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Row():
                        true_guidance_scale_edit = gr.Slider(
                            label="True guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0
                        )
                        num_inference_steps_edit = gr.Slider(
                            label="Inference steps (Lightning LoRA)", minimum=4, maximum=28, step=1, value=8
                        )

            # --- Edit Tab Logic ---
            edit_inputs = [
                input_image_edit, prompt_edit, seed_edit, randomize_seed_edit,
                true_guidance_scale_edit, num_inference_steps_edit
            ]
            edit_outputs = [result_edit, seed_edit]
            gr.on(triggers=[prompt_edit.submit, run_button_edit.click], fn=infer_edit, inputs=edit_inputs, outputs=edit_outputs)

            edit_examples = [
                ["image-edit/cat.png", "make the cat wear sunglasses"],
                ["image-edit/girl.png", "change her hair to blonde"],
            ]
            gr.Examples(examples=edit_examples, inputs=[input_image_edit, prompt_edit], outputs=edit_outputs, fn=infer_edit, cache_examples=True)

if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=False, debug=True)