import base64
import io

import spaces
import gradio as gr
from PIL import Image
import requests
import numpy as np
import PIL

from concept_attention import ConceptAttentionFluxPipeline

# Default arguments for the ConceptAttention pipeline (flux-schnell variant).
concept_attention_default_args = {
    "model_name": "flux-schnell",
    "device": "cuda",
    "layer_indices": list(range(10, 19)),
    "timesteps": list(range(4)),
    "num_samples": 4,
    "num_inference_steps": 4,
}
IMG_SIZE = 250


def download_image(url):
    return Image.open(io.BytesIO(requests.get(url).content))


EXAMPLES = [
    [
        "A dog by a tree",  # prompt
        download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dog_by_tree.png?raw=true"),
        "tree, dog, grass, background",  # words
        42,  # seed
    ],
    [
        "A dragon",  # prompt
        download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/dragon_image.png?raw=true"),
        "dragon, sky, rock, cloud",  # words
        42,  # seed
    ],
    [
        "A hot air balloon",  # prompt
        download_image("https://github.com/helblazer811/ConceptAttention/blob/master/images/hot_air_balloon.png?raw=true"),
        "balloon, sky, water, tree",  # words
        42,  # seed
    ],
]

pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")


@spaces.GPU(duration=60)
def process_inputs(prompt, input_image, word_list, seed):
    print("Processing inputs")
    prompt = prompt.strip()
    if not word_list.strip():
        return None, "Please enter comma-separated words"
    concepts = [w.strip() for w in word_list.split(",")]

    if input_image is not None:
        # Normalize the upload to a 1024x1024 RGB PIL image, whether Gradio
        # hands us a numpy array or a PIL image.
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)
            input_image = input_image.convert("RGB")
            input_image = input_image.resize((1024, 1024))
        elif isinstance(input_image, PIL.Image.Image):
            input_image = input_image.convert("RGB")
            input_image = input_image.resize((1024, 1024))
        print(input_image.size)

        # An input image was provided: compute concept heatmaps for it.
        pipeline_output = pipeline.encode_image(
            image=input_image,
            concepts=concepts,
            prompt=prompt,
            width=1024,
            height=1024,
            seed=seed,
            num_samples=concept_attention_default_args["num_samples"],
        )
    else:
        # No input image: generate one from the prompt and collect heatmaps.
        pipeline_output = pipeline.generate_image(
            prompt=prompt,
            concepts=concepts,
            width=1024,
            height=1024,
            seed=seed,
            timesteps=concept_attention_default_args["timesteps"],
            num_inference_steps=concept_attention_default_args["num_inference_steps"],
        )

    output_image = pipeline_output.image
    concept_heatmaps = pipeline_output.concept_heatmaps

    # Render each concept heatmap as a base64-encoded PNG for the HTML gallery.
    html_elements = []
    for concept, heatmap in zip(concepts, concept_heatmaps):
        img = heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        html = f"""