import gradio as gr
import torch
import matplotlib.pyplot as plt
import os
import traceback
from PIL import Image
import numpy as np

# Import your custom modules
from SDLens import HookedStableDiffusionXLPipeline
from training.k_sparse_autoencoder import SparseAutoencoder
from utils.hooks import add_feature_on_text_prompt


# Function to modulate hooks on prompt
def modulate_hook_prompt(sae, steering_feature, block):
    """Build a hook callable that injects ``steering_feature`` via the SAE.

    Args:
        sae: Sparse autoencoder, forwarded to ``add_feature_on_text_prompt``.
        steering_feature: Tensor produced in
            ``activation_modulation_across_prompt``, forwarded unchanged.
        block: Name of the hooked module. The hook itself does not use it;
            it is kept so callers can build one hook per block uniformly.

    Returns:
        A function that forwards all of its hook arguments to
        ``add_feature_on_text_prompt(sae, steering_feature, ...)``.
    """
    def hook_function(*args, **kwargs):
        return add_feature_on_text_prompt(sae, steering_feature, *args, **kwargs)

    return hook_function


# Function to load models
def load_models():
    """Load the hooked SDXL-Turbo pipeline and the sparse autoencoder.

    Returns:
        ``(pipe, blocks_to_save, sae)`` on success, or ``(None, None, None)``
        if anything fails. The failure is printed (with traceback) instead of
        raised so the UI can still start and report the problem.
    """
    try:
        # Load the Pipeline
        pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
        pipe.set_progress_bar_config(disable=True)

        # Text-encoder layers whose outputs are cached and steered:
        # one layer from each of SDXL's two text encoders.
        blocks_to_save = [
            'text_encoder.text_model.encoder.layers.10',
            'text_encoder_2.text_model.encoder.layers.28',
        ]

        # Load the sparse autoencoder
        sae_path = "Checkpoints/dahyecheckpoint"
        sae = SparseAutoencoder.load_from_disk(os.path.join(sae_path, 'final'))

        return pipe, blocks_to_save, sae
    except Exception as e:
        print(f"Error loading models: {e}")
        traceback.print_exc()
        return None, None, None


# Function to generate images with activation modulation
def activation_modulation_across_prompt(pipe, sae, blocks_to_save, steer_prompt, strength, prompt,
                                        guidance_scale, num_inference_steps, seed):
    """Generate an image for ``prompt`` steered by the concept in ``steer_prompt``.

    The steering prompt is run through the pipeline once to cache the outputs
    of the first two entries of ``blocks_to_save``. Those activations are
    concatenated along the last dimension and encoded by the SAE (without
    top-k). The codes are scaled by ``strength`` and projected back through
    the SAE decoder weights. The result is injected into every block in
    ``blocks_to_save`` while generating ``prompt``.

    Args:
        pipe: Hooked SDXL pipeline supporting ``run_with_cache`` and
            ``run_with_hooks``.
        sae: Sparse autoencoder with ``encode_without_topk`` and ``decoder``.
        blocks_to_save: Module names to cache and hook; the first two are
            concatenated to form the SAE input.
        steer_prompt: Prompt describing the steering concept.
        strength: Scale applied to the SAE codes; negative values are allowed.
        prompt: Main prompt to generate.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        num_inference_steps: Denoising steps for the final image.
        seed: Seed for the CPU ``torch.Generator`` used in both runs.

    Returns:
        The first generated image (``output.images[0]``).
    """
    # Slider values may arrive as floats; the generator and the step count need ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    # Cache the steering prompt's activations. Only the cached text-encoder
    # outputs are read below, so presumably a single denoising step suffices.
    _, cache = pipe.run_with_cache(
        steer_prompt,
        positions_to_cache=blocks_to_save,
        save_input=True,
        save_output=True,
        num_inference_steps=1,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device="cpu").manual_seed(seed),
    )

    diff = torch.cat(
        [cache['output'][blocks_to_save[0]], cache['output'][blocks_to_save[1]]],
        dim=-1,
    )
    # Drop leading singleton dims (assumes a batch of one steering prompt).
    diff = diff.squeeze(0).squeeze(0)

    with torch.no_grad():
        activated = sae.encode_without_topk(diff)  # [77, 81920] per the original author's note
        mask = activated * strength
        # Project the scaled SAE codes back to activation space.
        steering_feature = mask @ sae.decoder.weight.T

    # Generate image with modulation
    output = pipe.run_with_hooks(
        prompt,
        position_hook_dict={
            block: modulate_hook_prompt(sae, steering_feature, block)
            for block in blocks_to_save
        },
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device="cpu").manual_seed(seed),
    )

    return output.images[0]


# Function to generate images for the Gradio app
def generate_comparison(prompt, steer_prompt, strength, seed, guidance_scale, steps):
    """Gradio callback: return (standard image, modulated image, status text).

    Reads the module-level ``pipe``, ``sae`` and ``blocks_to_save``. On a load
    or generation failure it returns two red 512x512 placeholder images and an
    error message instead of raising, so the UI always receives valid outputs.
    """
    if pipe is None or sae is None or blocks_to_save is None:
        return Image.new('RGB', (512, 512), color='red'), Image.new('RGB', (512, 512), color='red'), "Error: Models failed to load"

    try:
        # Slider values may arrive as floats; the generator and the step count need ints.
        seed = int(seed)
        steps = int(steps)

        # Generate image with standard model (strength = 0)
        standard_image = pipe(
            prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        ).images[0]

        # Negative strengths are valid (the slider spans -2.5..2.5), so only an
        # exact 0 skips modulation; a `> 0` test would silently ignore them.
        if strength != 0:
            modified_image = activation_modulation_across_prompt(
                pipe, sae, blocks_to_save,
                steer_prompt, strength, prompt,
                guidance_scale, steps, seed,
            )
        else:
            # If strength is 0, just return the standard image again to avoid redundant computation
            modified_image = standard_image

        comparison_message = f"Generated images with modulation strength: {strength}"
        return standard_image, modified_image, comparison_message

    except Exception as e:
        # Keep the server-side traceback; the UI only gets the short message.
        traceback.print_exc()
        error_image = Image.new('RGB', (512, 512), color='red')
        return error_image, error_image, f"Error during generation: {str(e)}"


# Load the models at startup
print("Loading models...")
pipe, blocks_to_save, sae = load_models()
if pipe is not None:
    print("Models loaded successfully!")
else:
    print("Failed to load models")

# Define the Gradio interface
with gr.Blocks(title="SDXL Activation Modulation") as app:
    gr.Markdown("# SDXL Activation Modulation Comparison")

    gr.Markdown("""
    This app demonstrates activation modulation in Stable Diffusion XL using sparse autoencoders.
    It compares standard SDXL-Turbo outputs with modulated outputs that can steer the generation based on a separate concept.
    """)

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your main image prompt here...", value="A photo of a tree")
            steer_prompt = gr.Textbox(label="Steering Prompt", placeholder="Enter concept to steer with...", value="tree with autumn leaves")
            strength = gr.Slider(minimum=-2.5, maximum=2.5, value=0.8, step=0.05, label="Modulation Strength (λ)")

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(minimum=0, maximum=2147483647, step=1, value=61730, label="Seed")
                guidance_scale = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.5, label="Guidance Scale")
                steps = gr.Slider(minimum=1, maximum=50, value=3, step=1, label="Inference Steps")

            generate_btn = gr.Button("Generate Comparison", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        standard_output = gr.Image(label="Standard SDXL-Turbo")
        modified_output = gr.Image(label="Modulated Output")

    gr.Markdown("""
    ## Examples from the notebook:
    - Main prompt: "A photo of a tree" with steering prompt: "tree with autumn leaves"
    - Main prompt: "A dog" with steering prompt: "full shot"
    - Main prompt: "A car" with steering prompt: "A blue car"
    """)

    with gr.Row():
        example1 = gr.Button("Example 1: Tree with autumn leaves")
        example2 = gr.Button("Example 2: Dog with full shot")
        example3 = gr.Button("Example 3: Blue car")

    # Components every example button fills, in the order the example lambdas return values.
    settings_outputs = [prompt, steer_prompt, strength, seed, guidance_scale, steps]

    # Set up button actions
    generate_btn.click(
        fn=generate_comparison,
        inputs=[prompt, steer_prompt, strength, seed, guidance_scale, steps],
        outputs=[standard_output, modified_output, status],
    )

    # Set up example button click events
    example1.click(
        fn=lambda: ["A photo of a tree", "tree with autumn leaves", 0.5, 61730, 0.0, 3],
        inputs=None,
        outputs=settings_outputs,
    )
    example2.click(
        fn=lambda: ["A dog", "full shot", 0.4, 61730, 0.0, 3],
        inputs=None,
        outputs=settings_outputs,
    )
    example3.click(
        fn=lambda: ["A car", "A blue car", 0.3, 61730, 0.0, 3],
        inputs=None,
        outputs=settings_outputs,
    )

    gr.Markdown("""
    ## How to Use
    1. Enter your main prompt (what you want to generate)
    2. Enter a steering prompt (concept to influence the generation)
    3. Adjust the modulation strength slider (λ) - higher values mean stronger influence
    4. Click "Generate Comparison" to see the results side by side
    5. Use advanced settings if needed to adjust seed, guidance scale, or steps

    ## About
    This app demonstrates activation modulation using a sparse autoencoder trained on SDXL text encoder layers.
    The modulation allows steering the generation toward specific concepts without changing the main prompt.
    """)

# Launch the app
if __name__ == "__main__":
    app.launch()