Spaces: Running on Zero
import os | |
import gradio as gr | |
from gradio_client import Client, handle_file | |
import torch | |
import spaces | |
from diffusers import Lumina2Pipeline | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Preferred tensor dtype for this process: bfloat16 when a CUDA device is
# visible to torch, float32 otherwise.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
def set_client_for_session(request: gr.Request):
    """Create a per-session client for the remote LeX-Enhancer Space.

    Registered with ``demo.load`` so each browser session gets its own
    ``gradio_client.Client``, stored in a ``gr.State``.

    Args:
        request: The incoming Gradio request. It is currently unused and kept
            (with its ``gr.Request`` annotation) for interface compatibility.

    Returns:
        A ``Client`` connected to the ``stzhao/LeX-Enhancer`` Space.
    """
    # NOTE(review): the previous version did request.headers['x-ip-token'] and
    # then never used the value, because forwarding it as an "X-IP-Token"
    # header was commented out. That lookup raised KeyError whenever the
    # header was missing (e.g. a local run). To re-enable forwarding, read it
    # with request.headers.get("x-ip-token") and pass
    # headers={"X-IP-Token": token} only when the token is present.
    return Client("stzhao/LeX-Enhancer")
# Load models
def load_models():
    """Load the LeX-Lumina text-to-image pipeline onto the best available device.

    Returns:
        A ``Lumina2Pipeline``. It is placed on CUDA in bfloat16 when
        ``torch.cuda.is_available()`` is true, and on CPU in float32 otherwise.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Same dtype policy as the module-level torch_dtype selection. The GPU
    # path keeps the original bfloat16.
    dtype = torch.bfloat16 if use_cuda else torch.float32
    pipe = Lumina2Pipeline.from_pretrained(
        "X-ART/LeX-Lumina",
        torch_dtype=dtype,
    )
    # Previously this was a hard-coded pipe.to("cuda"), which fails on hosts
    # without CUDA even though `device` was computed for exactly that case.
    pipe.to(device)
    return pipe
def prompt_enhance(client, image_caption, text_caption):
    """Ask the remote LeX-Enhancer Space to expand the two captions.

    Args:
        client: Object exposing ``predict`` (a gradio_client ``Client`` here).
        image_caption: Description of the visual content.
        text_caption: Description of the text to render in the image.

    Returns:
        ``(combined_caption, enhanced_caption)``: the two-item result of the
        ``/generate_enhanced_caption`` endpoint.
    """
    endpoint = "/generate_enhanced_caption"
    prediction = client.predict(image_caption, text_caption, api_name=endpoint)
    combined, enhanced = prediction
    return combined, enhanced
# Load the pipeline once, at import time; generate_image() uses this
# module-level `pipe` for every request.
pipe = load_models()
# NOTE(review): the disabled helper below calls `tokenizer`, which is not
# defined in the code shown here. Re-enabling it would first require creating
# a tokenizer (AutoTokenizer is imported but unused).
# def truncate_caption_by_tokens(caption, max_tokens=256):
#     """Truncate the caption to fit within the max token limit"""
#     tokens = tokenizer.encode(caption)
#     if len(tokens) > max_tokens:
#         truncated_tokens = tokens[:max_tokens]
#         caption = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
#         print(f"Caption was truncated from {len(tokens)} tokens to {max_tokens} tokens")
#     return caption
def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
    """Generate image using LeX-Lumina.

    Renders a single 1024x1024 image from ``enhanced_caption`` with the
    module-level ``pipe``.

    Args:
        enhanced_caption: Prompt text passed to the pipeline.
        seed: RNG seed. A falsy value (0, as the UI labels "random", or None)
            means no explicit generator is used.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        The first image of the pipeline output.
    """
    # Re-enabled on every call because the pipeline is moved back to CPU at
    # the end of each call (see the finally clause).
    pipe.enable_model_cpu_offload()
    # UI sliders may deliver floats; manual_seed and the step count expect
    # ints, so coerce explicitly.
    generator = torch.Generator("cpu").manual_seed(int(seed)) if seed else None
    try:
        image = pipe(
            enhanced_caption,
            height=1024,
            width=1024,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            cfg_trunc_ratio=1,
            cfg_normalization=True,
            max_sequence_length=256,
            generator=generator,
            system_prompt="You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts.",
        ).images[0]
    finally:
        # Release device memory even if generation raised, so a failed
        # request does not leave the weights resident on the accelerator.
        pipe.to("cpu")
        torch.cuda.empty_cache()
    print(image)
    return image
# NOTE: the ZeroGPU allocation decorator is left disabled, as in the original:
# @spaces.GPU(duration=130)
def run_pipeline(image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer, client):
    """Handle a Generate click end to end.

    Builds a plain combined caption and, if requested, replaces it with the
    LeX-Enhancer output. It then renders the image and returns
    ``(image, combined_caption, enhanced_caption)`` for the three output widgets.
    """
    combined_caption = f"{image_caption}, with the text on it: {text_caption}."
    if not enable_enhancer:
        # Enhancer off: the locally built caption is used verbatim as the prompt.
        enhanced_caption = combined_caption
    else:
        # Enhancer on: the remote service also returns its own combined caption,
        # which replaces the locally built one in the result.
        combined_caption, enhanced_caption = prompt_enhance(client, image_caption, text_caption)
        print(f"enhanced caption:\n{enhanced_caption}")
    return (
        generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale),
        combined_caption,
        enhanced_caption,
    )
# Gradio interface
# Builds the demo UI. Components are rendered in the order they are created
# inside the nested layout contexts, so statement order here defines the layout.
with gr.Blocks() as demo:
    # Per-session holder for the enhancer Client; filled by demo.load() below.
    client = gr.State()
    gr.Markdown("# LeX-Enhancer & LeX-Lumina Demo")
    gr.Markdown("## Project Page: https://zhaoshitian.github.io/lexart/")
    gr.Markdown("Generate enhanced captions from simple image and text descriptions, then create images with LeX-Lumina")
    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column():
            image_caption = gr.Textbox(
                lines=2,
                label="Image Caption",
                placeholder="Describe the visual content of the image",
                value="A picture of a group of people gathered in front of a world map"
            )
            text_caption = gr.Textbox(
                lines=2,
                label="Text Caption",
                placeholder="Describe any text that should appear in the image",
                value="\"Communicate\" in purple, \"Execute\" in yellow"
            )
            with gr.Accordion("Advanced Settings", open=False):
                enable_enhancer = gr.Checkbox(
                    label="Enable LeX-Enhancer",
                    value=True,
                    info="When enabled, the caption will be enhanced before image generation"
                )
                # 0 means "random": generate_image() then passes no explicit generator.
                seed = gr.Slider(
                    minimum=0,
                    maximum=100000,
                    value=0,
                    step=1,
                    label="Seed (0 for random)"
                )
                num_inference_steps = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=40,
                    step=1,
                    label="Number of Inference Steps"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=7.5,
                    step=0.1,
                    label="Guidance Scale"
                )
            submit_btn = gr.Button("Generate", variant="primary")
        # Right column: outputs.
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            combined_caption_box = gr.Textbox(
                label="Combined Caption",
                interactive=False
            )
            # Initial label comes from the checkbox's initial value at build time;
            # later toggles are handled by update_caption_label() below.
            enhanced_caption_box = gr.Textbox(
                label="Enhanced Caption" if enable_enhancer.value else "Final Caption",
                interactive=False,
                lines=5
            )
    # Example prompts
    examples = [
        ["A modern office workspace", "\"Innovation\" in bold blue letters at the center"],
        ["A beach sunset scene", "\"Relax\" in cursive white text in the corner"],
        ["A futuristic city skyline", "\"The Future is Now\" in neon pink glowing letters"]
    ]
    # No fn/outputs are given, so examples only populate the two caption inputs.
    gr.Examples(
        examples=examples,
        inputs=[image_caption, text_caption],
        label="Example Inputs"
    )
    # Update the label of enhanced_caption_box based on checkbox state
    def update_caption_label(enable_enhancer):
        """Return a Textbox update carrying the label that matches the checkbox state."""
        return gr.Textbox(label="Enhanced Caption" if enable_enhancer else "Final Caption")
    enable_enhancer.change(
        fn=update_caption_label,
        inputs=enable_enhancer,
        outputs=enhanced_caption_box
    )
    # Inputs are listed in run_pipeline's parameter order; outputs match its
    # (image, combined_caption, enhanced_caption) return tuple.
    submit_btn.click(
        fn=run_pipeline,
        inputs=[image_caption, text_caption, seed, num_inference_steps, guidance_scale, enable_enhancer, client],
        outputs=[output_image, combined_caption_box, enhanced_caption_box]
    )
    # On page load, create this session's enhancer client and store it in `client`.
    demo.load(set_client_for_session, None, client)
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)