import spaces
import gradio as gr
from PIL import Image
import io
import base64
import subprocess
import os
from concept_attention import ConceptAttentionFluxPipeline
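# Heatmap thumbnail size (px) and the default number of gallery columns.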
IMG_SIZE = 210
COLUMNS = 5
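# Suggest a default set of concepts for each example prompt; unknown prompts clear the selection.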
def update_default_concepts(prompt):
default_concepts = {
"A dog by a tree": ["dog", "grass", "tree", "background"],
"A dragon": ["dragon", "sky", "rock", "cloud"],
"A hot air balloon": ["balloon", "sky", "water", "tree"]
}
return gr.update(value=default_concepts.get(prompt, []))
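# Load the ConceptAttention FLUX-Schnell pipeline once at startup.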
pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda") # , offload_model=True)
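# Downscale a PIL image and encode it as a base64 PNG string (note: returns a str, not raw bytes).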
def convert_pil_to_bytes(img):
img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str
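# Reserve a ZeroGPU slot for up to 60 seconds per call.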
@spaces.GPU(duration=60)
def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_index):
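    # Validate inputs before spending GPU time.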
    if not prompt or not prompt.strip():
        raise gr.exceptions.InputError("prompt", "Please enter a prompt")
    prompt = prompt.strip()
    if not concepts:
        raise gr.exceptions.InputError("words", "Please enter at least 1 concept")
    if len(concepts) > 9:
        raise gr.exceptions.InputError("words", "Please enter at most 9 concepts")
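    # Generate the image and per-concept heatmaps in a single pass.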
pipeline_output = pipeline.generate_image(
prompt=prompt,
concepts=concepts,
width=1024,
height=1024,
seed=seed,
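        # FLUX-Schnell uses 4 denoising steps; only steps from timestep_start_index onward contribute.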
timesteps=list(range(timestep_start_index, 4)),
num_inference_steps=4,
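        # Collect attention from the multi-modal DiT blocks in [layer_start_index, 19).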
layer_indices=list(range(layer_start_index, 19)),
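        # Softmax-normalizing across concepts only makes sense with more than one concept.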
        softmax=len(concepts) > 1
)
output_image = pipeline_output.image
output_space_heatmaps = pipeline_output.concept_heatmaps
output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps]
    output_space_maps_and_labels = list(zip(output_space_heatmaps, concepts))
cross_attention_heatmaps = pipeline_output.cross_attention_maps
cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
    cross_attention_maps_and_labels = list(zip(cross_attention_heatmaps, concepts))
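    # Size each gallery so every concept fits on a single row.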
    return (
        output_image,
        gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)),
        gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels)),
    )
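# Build the Gradio UI; the custom CSS below handles layout and hides the sidebar SVG on narrow screens.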
with gr.Blocks(
css="""
.container {
max-width: 1400px;
margin: 0 auto;
padding: 20px;
}
.authors { text-align: center; margin-bottom: 10px; }
.affiliations { text-align: center; color: #666; margin-bottom: 10px; }
.abstract { text-align: center; margin-bottom: 40px; }
.generated-image {
display: flex;
align-items: center;
justify-content: center;
height: 100%; /* Ensures full height */
}
.header {
display: flex;
flex-direction: column;
}
.input {
height: 47px;
}
.input-column {
flex-direction: column;
gap: 0px;
}
.input-column-label {}
.gallery {}
.run-button-column {
width: 100px !important;
}
#title {
font-size: 2.4em;
text-align: center;
margin-bottom: 10px;
}
#subtitle {
font-size: 2.0em;
text-align: center;
}
#concept-attention-callout-svg {
width: 250px;
}
    /* Show the sidebar SVG only on wide screens; it is hidden below 1024px */
@media (min-width: 1024px) {
.svg-container {
min-width: 150px;
width: 200px;
padding-top: 540px;
}
}
@media (min-width: 1280px) {
.svg-container {
min-width: 200px;
width: 300px;
padding-top: 420px;
}
}
@media (min-width: 1530px) {
.svg-container {
min-width: 200px;
width: 300px;
padding-top: 400px;
}
}
@media (max-width: 1024px) {
.svg-container {
display: none;
}
}
"""
) as demo:
with gr.Row(elem_classes="container"):
with gr.Column(elem_classes="application", scale=15):
with gr.Row(scale=3, elem_classes="header"):
gr.HTML("<h1 id='title'> ConceptAttention: Visualize Any Concepts in Your Generated Images</h1>")
gr.HTML("<h2 id='subtitle'> Interpret generative models with precise, high-quality heatmaps. <br/> Check out our paper <a href='https://arxiv.org/abs/2502.04320'> here </a>. </h2>")
with gr.Row(scale=1, equal_height=True):
with gr.Column(scale=4, elem_classes="input-column", min_width=250):
gr.HTML(
"Write a Prompt",
elem_classes="input-column-label"
)
prompt = gr.Dropdown(
["A dog by a tree", "A dragon", "A hot air balloon"],
container=False,
allow_custom_value=True,
elem_classes="input"
)
with gr.Column(scale=7, elem_classes="input-column"):
gr.HTML(
"Select or Write Concepts",
elem_classes="input-column-label"
)
concepts = gr.Dropdown(
["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
value=["dog", "grass", "tree", "background"],
multiselect=True,
label="Concepts",
container=False,
allow_custom_value=True,
elem_classes="input",
max_choices=5
)
with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
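                    # A zero-width space stands in for a label so the Run button aligns with the inputs.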
gr.HTML(
"&#8203;",
elem_classes="input-column-label"
)
submit_btn = gr.Button(
"Run",
elem_classes="input"
)
with gr.Row(elem_classes="gallery", scale=8):
with gr.Column(scale=1, min_width=250):
generated_image = gr.Image(
elem_classes="generated-image",
show_label=False
)
with gr.Column(scale=4):
concept_attention_gallery = gr.Gallery(
label="Concept Attention (Ours)",
show_label=True,
                        # columns is set dynamically by process_inputs to match the number of concepts
rows=1,
object_fit="contain",
height="200px",
elem_classes="gallery",
elem_id="concept-attention-gallery"
)
cross_attention_gallery = gr.Gallery(
label="Cross Attention",
show_label=True,
                        # columns is set dynamically by process_inputs to match the number of concepts
rows=1,
object_fit="contain",
height="200px",
elem_classes="gallery"
)
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
submit_btn.click(
fn=process_inputs,
inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
)
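            # Swap in the matching default concepts whenever the example prompt changes.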
prompt.change(update_default_concepts, inputs=[prompt], outputs=[concepts])
# Automatically process the first example on launch
demo.load(
process_inputs,
inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
)
with gr.Column(scale=4, min_width=250, elem_classes="svg-container"):
concept_attention_callout_svg = gr.HTML(
"<img src='/gradio_api/file=ConceptAttentionCallout.svg' id='concept-attention-callout-svg'/>",
)
if __name__ == "__main__":
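    # On ZeroGPU hosts, clear any stale offload cache from a previous run.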
if os.path.exists("/data-nvme/zerogpu-offload"):
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
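    # allowed_paths lets Gradio serve the local callout SVG referenced in the UI.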
demo.launch(
allowed_paths=["."]
)