import spaces
import gradio as gr
from PIL import Image
import math
import io
import base64
import subprocess
import os

from concept_attention import ConceptAttentionFluxPipeline

IMG_SIZE = 210
COLUMNS = 5


def update_default_concepts(prompt):
    """Return the default concept list associated with a preset prompt."""
    default_concepts = {
        "A dog by a tree": ["dog", "grass", "tree", "background"],
        "A dragon": ["dragon", "sky", "rock", "cloud"],
        "A hot air balloon": ["balloon", "sky", "water", "tree"],
    }
    return gr.update(value=default_concepts.get(prompt, []))


pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")  # , offload_model=True)


def convert_pil_to_bytes(img):
    """Downscale an image and encode it as a base64 PNG string."""
    img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


@spaces.GPU(duration=60)
def process_inputs(prompt, concepts, seed, layer_start_index, timestep_start_index):
    # Validate the prompt once: reject both missing and whitespace-only input.
    if not prompt or not prompt.strip():
        raise gr.exceptions.InputError("prompt", "Please enter a prompt")
    prompt = prompt.strip()
    if len(concepts) == 0:
        raise gr.exceptions.InputError("words", "Please enter at least 1 concept")
    if len(concepts) > 9:
        raise gr.exceptions.InputError("words", "Please enter at most 9 concepts")

    pipeline_output = pipeline.generate_image(
        prompt=prompt,
        concepts=concepts,
        width=1024,
        height=1024,
        seed=seed,
        timesteps=list(range(timestep_start_index, 4)),
        num_inference_steps=4,
        layer_indices=list(range(layer_start_index, 19)),
        # Softmax over concepts is only meaningful when comparing more than one.
        softmax=len(concepts) > 1,
    )

    output_image = pipeline_output.image

    # Resize each heatmap for display and pair it with its concept label
    # in the (image, caption) format expected by gr.Gallery.
    output_space_heatmaps = [
        heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
        for heatmap in pipeline_output.concept_heatmaps
    ]
    output_space_maps_and_labels = list(zip(output_space_heatmaps, concepts))

    cross_attention_heatmaps = [
        heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
        for heatmap in pipeline_output.cross_attention_maps
    ]
    cross_attention_maps_and_labels = list(zip(cross_attention_heatmaps, concepts))

    return (
        output_image,
        gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)),
        gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels)),
    )
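# A minimal sketch of driving the pipeline outside the Gradio UI, for local
# experimentation. It assumes the same generate_image API used in
# process_inputs above (returning .image and .concept_heatmaps as PIL images);
# the prompt, concepts, and output filenames are illustrative only.
#
# example_concepts = ["dog", "tree", "background"]
# example_output = pipeline.generate_image(
#     prompt="A dog by a tree",
#     concepts=example_concepts,
#     width=1024,
#     height=1024,
#     seed=0,
#     num_inference_steps=4,
# )
# example_output.image.save("generated.png")
# for concept, heatmap in zip(example_concepts, example_output.concept_heatmaps):
#     heatmap.save(f"heatmap_{concept}.png")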
with gr.Blocks(
    css="""
        .container { max-width: 1400px; margin: 0 auto; padding: 20px; }
        .authors { text-align: center; margin-bottom: 10px; }
        .affiliations { text-align: center; color: #666; margin-bottom: 10px; }
        .abstract { text-align: center; margin-bottom: 40px; }
        .generated-image { display: flex; align-items: center; justify-content: center; height: 100%; /* Ensures full height */ }
        .header { display: flex; flex-direction: column; }
        .input { height: 47px; }
        .input-column { flex-direction: column; gap: 0px; }
        .input-column-label {}
        .gallery {}
        .run-button-column { width: 100px !important; }
        #title { font-size: 2.4em; text-align: center; margin-bottom: 10px; }
        #subtitle { font-size: 2.0em; text-align: center; }
        #concept-attention-callout-svg { width: 250px; }
        /* Show the callout SVG only on screens wider than 1024px (adjust as needed) */
        @media (min-width: 1024px) { .svg-container { min-width: 150px; width: 200px; padding-top: 540px; } }
        @media (min-width: 1280px) { .svg-container { min-width: 200px; width: 300px; padding-top: 420px; } }
        @media (min-width: 1530px) { .svg-container { min-width: 200px; width: 300px; padding-top: 400px; } }
        @media (max-width: 1024px) { .svg-container { display: none; } }
    """
    # , elem_classes="container"
) as demo:
    with gr.Row(elem_classes="container"):
        with gr.Column(elem_classes="application", scale=15):
            with gr.Row(scale=3, elem_classes="header"):
                gr.HTML("