"""Gradio demo comparing CFG-Zero* against standard CFG for Stable Diffusion 3/3.5
(image generation) and Wan2.1 T2V (video generation)."""

import gc
import os
import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import AutoencoderKLWan
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video
from huggingface_hub import login

from sd3_pipeline import StableDiffusion3Pipeline
from wan_pipeline import WanPipeline

login(token=os.getenv("HF_TOKEN"))


def set_seed(seed):
    """Seed every RNG that affects generation so runs are reproducible."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


# Model paths
model_paths = {
    "sd3.5": "stabilityai/stable-diffusion-3.5-large",
    "sd3": "stabilityai/stable-diffusion-3-medium-diffusers",
    "wan-t2v": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
}

# Global variable for the currently loaded pipeline
current_model = None

# Folder to save video outputs
OUTPUT_DIR = "generated_videos"
os.makedirs(OUTPUT_DIR, exist_ok=True)


def load_model(model_name):
    """Load the requested pipeline, releasing the previous one first."""
    global current_model
    if current_model is not None:
        del current_model          # Delete the old model
        torch.cuda.empty_cache()   # Free GPU memory
        gc.collect()               # Force garbage collection

    if "wan-t2v" in model_name:
        vae = AutoencoderKLWan.from_pretrained(
            model_paths[model_name], subfolder="vae", torch_dtype=torch.bfloat16
        )
        scheduler = UniPCMultistepScheduler(
            prediction_type="flow_prediction",
            use_flow_sigmas=True,
            num_train_timesteps=1000,
            flow_shift=8.0,
        )
        current_model = WanPipeline.from_pretrained(
            model_paths[model_name], vae=vae, torch_dtype=torch.float16
        ).to("cuda")
        current_model.scheduler = scheduler
    else:
        current_model = StableDiffusion3Pipeline.from_pretrained(
            model_paths[model_name], torch_dtype=torch.bfloat16
        ).to("cuda")

    return current_model


def save_video(frames, filename):
    """Write frames to OUTPUT_DIR and return the resulting path.

    Mirrors the compare-mode export settings (fps=16)."""
    path = os.path.join(OUTPUT_DIR, filename)
    export_to_video(frames, path, fps=16)
    return path


@spaces.GPU(duration=120)
def generate_content(
    prompt,
    model_name,
    guidance_scale=7.5,
    num_inference_steps=50,
    use_cfg_zero_star=True,
    use_zero_init=True,
    zero_steps=0,
    seed=None,
    compare_mode=False,
):
    model = load_model(model_name)

    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    set_seed(seed)

    is_video_model = "wan-t2v" in model_name

    if is_video_model:
        if compare_mode:
            # Generate the same video twice from the same seed:
            # once with CFG-Zero* and once with standard CFG.
            set_seed(seed)
            video1_frames = model(
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_frames=81,
                use_cfg_zero_star=True,
                use_zero_init=use_zero_init,
                zero_steps=zero_steps,
            ).frames[0]
            video1_path = save_video(video1_frames, f"{seed}_CFG-Zero-Star.mp4")

            set_seed(seed)
            video2_frames = model(
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_frames=81,
                use_cfg_zero_star=False,
                use_zero_init=use_zero_init,
                zero_steps=zero_steps,
            ).frames[0]
            video2_path = save_video(video2_frames, f"{seed}_CFG.mp4")

            return None, None, video1_path, video2_path, seed
        else:
            video_frames = model(
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_frames=81,
                use_cfg_zero_star=use_cfg_zero_star,
                use_zero_init=use_zero_init,
                zero_steps=zero_steps,
            ).frames[0]
            video_path = save_video(video_frames, f"{seed}.mp4")
            return None, None, video_path, None, seed

    if compare_mode:
        # Same seed, two images: CFG-Zero* vs. standard CFG.
        set_seed(seed)
        image1 = model(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            use_cfg_zero_star=True,
            use_zero_init=use_zero_init,
            zero_steps=zero_steps,
        ).images[0]

        set_seed(seed)
        image2 = model(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            use_cfg_zero_star=False,
            use_zero_init=use_zero_init,
            zero_steps=zero_steps,
        ).images[0]

        return image1, image2, None, None, seed
    else:
        image = model(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            use_cfg_zero_star=use_cfg_zero_star,
            use_zero_init=use_zero_init,
            zero_steps=zero_steps,
        ).images[0]
        # Route the single image to the output slot matching the chosen mode.
        if use_cfg_zero_star:
            return image, None, None, None, seed
        else:
            return None, image, None, None, seed


# Gradio UI
demo = gr.Interface(
    fn=generate_content,
    inputs=[
        gr.Textbox(
            value="A cosmic whale swimming through a galaxy with stars and swirling cosmic dust.",
            label="Enter your prompt",
        ),
        gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model"),
        gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale"),
        gr.Slider(10, 100, value=28, step=5, label="Inference Steps"),
        gr.Checkbox(value=True, label="Use CFG Zero Star"),
        gr.Checkbox(value=True, label="Use Zero Init"),
        gr.Slider(0, 20, value=0, step=1, label="Zero out steps"),
        gr.Number(value=42, label="Seed (Leave blank for random)"),
        gr.Checkbox(value=True, label="Compare Mode"),
    ],
    outputs=[
        gr.Image(type="pil", label="CFG-Zero* Image"),
        gr.Image(type="pil", label="CFG Image"),
        gr.Video(label="CFG-Zero* Video"),
        gr.Video(label="CFG Video"),
        gr.Textbox(label="Used Seed"),
    ],
    title="CFG-Zero*: Improved Classifier-Free Guidance for Flow Matching Models",
)

demo.launch(ssr_mode=False)
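
# Usage sketch (kept commented out since demo.launch() blocks): a hypothetical
# direct call to generate_content, bypassing the Gradio UI. It illustrates the
# 5-tuple return order (CFG-Zero* image, CFG image, CFG-Zero* video path,
# CFG video path, used seed); the prompt and argument values are illustrative.
#
# image_star, image_cfg, _, _, used_seed = generate_content(
#     prompt="A red fox in the snow",
#     model_name="sd3.5",
#     guidance_scale=4.0,
#     num_inference_steps=28,
#     seed=42,
#     compare_mode=True,
# )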