Spaces:
Running
Running
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Created on Wed Apr 16 17:23:28 2025 | |
CFGZero* implementation for Blissful Tuner extension based on https://github.com/WeichenFan/CFG-Zero-star/blob/main/models/wan/wan_pipeline.py | |
License: Apache 2.0 | |
@author: blyss | |
""" | |
import torch | |
def apply_zerostar(cond: torch.Tensor, uncond: torch.Tensor, current_step: int, guidance_scale: float, use_scaling: bool = True, zero_init_steps: int = -1) -> torch.Tensor: | |
if (current_step <= zero_init_steps): | |
return cond * 0 | |
if not use_scaling: | |
# CFG formula | |
noise_pred = uncond + guidance_scale * (cond - uncond) | |
else: | |
batch_size = cond.shape[0] | |
positive_flat = cond.view(batch_size, -1) | |
negative_flat = uncond.view(batch_size, -1) | |
alpha = optimized_scale(positive_flat, negative_flat) | |
alpha = alpha.view(batch_size, *([1] * (len(cond.shape) - 1))) | |
alpha = alpha.to(cond.dtype) | |
# CFG formula modified with alpha | |
noise_pred = uncond * alpha + guidance_scale * (cond - uncond * alpha) | |
return noise_pred | |
def optimized_scale(positive_flat, negative_flat): | |
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) | |
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 | |
# st_star = v_cond^T * v_uncond / ||v_uncond||^2 | |
st_star = dot_product / squared_norm | |
return st_star | |