#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Utility functions for Blissful Tuner extension
License: Apache 2.0
Created on Sat Apr 12 14:09:37 2025
@author: blyss
"""
import argparse
import hashlib
import logging
from typing import Dict, Optional, Tuple, Union

import torch
import safetensors
from rich.logging import RichHandler


# Adapted from ComfyUI
def load_torch_file(
    ckpt: str,
    safe_load: bool = True,
    device: Optional[Union[str, torch.device]] = None,
    return_metadata: bool = False
) -> Union[
    Dict[str, torch.Tensor],
    Tuple[Dict[str, torch.Tensor], Optional[Dict[str, str]]]
]:
    if device is None:
        device = torch.device("cpu")
    elif isinstance(device, str):
        device = torch.device(device)
    metadata = None
    if ckpt.lower().endswith((".safetensors", ".sft")):
        try:
            # str(device) keeps any device index (e.g. "cuda:1"); device.type would drop it
            with safetensors.safe_open(ckpt, framework="pt", device=str(device)) as f:
                sd = {k: f.get_tensor(k) for k in f.keys()}
                if return_metadata:
                    metadata = f.metadata()
        except Exception as e:
            if len(e.args) > 0:
                message = e.args[0]
                if "HeaderTooLarge" in message:
                    raise ValueError(
                        "{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. "
                        "Make sure this is actually a safetensors file and not a ckpt, pt, or other filetype.".format(message, ckpt)
                    )
                if "MetadataIncompleteBuffer" in message:
                    raise ValueError(
                        "{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. "
                        "Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt)
                    )
            raise e
    else:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=safe_load)
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        elif len(pl_sd) == 1:
            # A single top-level key usually wraps the real state dict
            key = list(pl_sd.keys())[0]
            sd = pl_sd[key]
            if not isinstance(sd, dict):
                sd = pl_sd
        else:
            sd = pl_sd
    return (sd, metadata) if return_metadata else sd
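
# A minimal usage sketch for load_torch_file ("model.safetensors" is a hypothetical path):
#
#     sd, meta = load_torch_file("model.safetensors", return_metadata=True)
#     print(f"Loaded {len(sd)} tensors, metadata: {meta}")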


def add_noise_to_reference_video(
    image: torch.Tensor,
    ratio: Optional[float] = None
) -> torch.Tensor:
    """
    Add Gaussian noise (scaled by `ratio`) to an image or batch of images.

    Supports:
      • Single image: (C, H, W)
      • Batch of images: (B, C, H, W)

    Any pixel exactly == -1 gets zero noise (treated as a mask value).
    """
    if ratio is None or ratio == 0.0:
        return image

    dims = image.ndim
    if dims == 3:
        # Single image -> make it a batch of 1
        image = image.unsqueeze(0)  # -> (1, C, H, W)
        squeeze_back = True
    elif dims == 4:
        squeeze_back = False
    else:
        raise ValueError(
            f"add_noise_to_reference_video() expected 3D or 4D tensor, got {dims}D"
        )

    # image is now (B, C, H, W)
    B, C, H, W = image.shape
    # make a (B,) sigma vector, all entries = ratio
    sigma = image.new_ones((B,)) * ratio
    # sample noise and scale by sigma
    noise = torch.randn_like(image) * sigma.view(B, 1, 1, 1)
    # zero out noise wherever the original was -1
    noise = torch.where(image == -1, torch.zeros_like(image), noise)
    out = image + noise
    return out.squeeze(0) if squeeze_back else out
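
# A minimal usage sketch: add 2% Gaussian noise to a hypothetical batch of frames in [-1, 1]:
#
#     frames = torch.rand(4, 3, 64, 64) * 2 - 1
#     noisy = add_noise_to_reference_video(frames, ratio=0.02)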


# Below here, Blyss wrote it!
class BlissfulLogger:
    def __init__(self, logging_source: str, log_color: str, do_announce: bool = False):
        self.logging_source = "{:<8}".format(logging_source)
        self.log_color = log_color
        self.logger = logging.getLogger(self.logging_source)
        self.logger.setLevel(logging.DEBUG)
        self.handler = RichHandler(
            show_time=False,
            show_level=True,
            show_path=True,
            rich_tracebacks=True,
            markup=True
        )
        formatter = logging.Formatter(
            f"[{self.log_color} bold]%(name)s[/] | %(message)s [dim](%(funcName)s)[/]"
        )
        self.handler.setFormatter(formatter)
        # Clear any handlers a previous instance attached to this name so lines aren't duplicated
        self.logger.handlers.clear()
        self.logger.addHandler(self.handler)
        if do_announce:
            self.logger.info("Set up logging!")

    def set_color(self, new_color: str) -> None:
        self.log_color = new_color
        formatter = logging.Formatter(
            f"[{self.log_color} bold]%(name)s[/] | %(message)s [dim](%(funcName)s)[/]"
        )
        self.handler.setFormatter(formatter)

    def set_name(self, new_name: str) -> None:
        self.logging_source = "{:<8}".format(new_name)
        self.logger = logging.getLogger(self.logging_source)
        self.logger.setLevel(logging.DEBUG)
        # Remove any existing handlers so we don't double-log, then attach ours
        self.logger.handlers.clear()
        self.logger.addHandler(self.handler)

    def info(self, msg):
        self.logger.info(msg, stacklevel=2)

    def debug(self, msg):
        self.logger.debug(msg, stacklevel=2)

    def warning(self, msg, levelmod=0):
        self.logger.warning(msg, stacklevel=2 + levelmod)

    def warn(self, msg):  # alias for warning()
        self.logger.warning(msg, stacklevel=2)

    def error(self, msg):
        self.logger.error(msg, stacklevel=2)

    def critical(self, msg):
        self.logger.critical(msg, stacklevel=2)

    def setLevel(self, level) -> None:
        self.logger.setLevel(level)  # Logger has setLevel(), not set_level()
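
# A minimal usage sketch for BlissfulLogger (the source name and color are arbitrary examples):
#
#     logger = BlissfulLogger("Tuner", "#8e00ed")
#     logger.info("Model loaded!")
#     logger.set_color("green")  # recolor the source name for subsequent messages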


def parse_scheduled_cfg(schedule: str, infer_steps: int, guidance_scale: float) -> Dict[int, float]:
    """
    Parse a schedule string like "1-10,20,!5,e~3" into a dict mapping each included step
    to its guidance scale.

    - "start-end" includes all steps in [start, end]
    - "e~n" includes every nth step (n, 2n, ...) up to infer_steps
    - "x" includes the single step x
    - Prefix "!" on any token to exclude those steps instead of including them.
    - Postfix ":float" e.g. ":6.0" on any step or range to override guidance_scale for those steps.

    Raises argparse.ArgumentTypeError on malformed tokens or out-of-range steps.
    """
    excluded = set()
    guidance_scale_dict = {}
    for raw in schedule.split(","):
        token = raw.strip()
        if not token:
            continue  # skip empty tokens
        # exclusion if it starts with "!"
        if token.startswith("!"):
            target = "exclude"
            token = token[1:]
        else:
            target = "include"
        weight = guidance_scale
        if ":" in token:
            token, float_part = token.rsplit(":", 1)
            try:
                weight = float(float_part)
            except ValueError:
                raise argparse.ArgumentTypeError(f"Invalid guidance scale override in '{raw}'")
        # modulus syntax: e.g. "e~3"
        if token.startswith("e~"):
            num_str = token[2:]
            try:
                n = int(num_str)
            except ValueError:
                raise argparse.ArgumentTypeError(f"Invalid modulus in '{raw}'")
            if n < 1:
                raise argparse.ArgumentTypeError(f"Modulus must be >= 1 in '{raw}'")
            steps = range(n, infer_steps + 1, n)
        # range syntax: e.g. "5-10"
        elif "-" in token:
            parts = token.split("-")
            if len(parts) != 2:
                raise argparse.ArgumentTypeError(f"Malformed range '{raw}'")
            start_str, end_str = parts
            try:
                start = int(start_str)
                end = int(end_str)
            except ValueError:
                raise argparse.ArgumentTypeError(f"Non-integer in range '{raw}'")
            if start < 1 or end < 1:
                raise argparse.ArgumentTypeError(f"Steps must be >= 1 in '{raw}'")
            if start > end:
                raise argparse.ArgumentTypeError(f"Start > end in '{raw}'")
            if end > infer_steps:
                raise argparse.ArgumentTypeError(f"End > infer_steps ({infer_steps}) in '{raw}'")
            steps = range(start, end + 1)
        # single-step syntax: e.g. "7"
        else:
            try:
                step = int(token)
            except ValueError:
                raise argparse.ArgumentTypeError(f"Invalid token '{raw}'")
            if step < 1 or step > infer_steps:
                raise argparse.ArgumentTypeError(f"Step {step} out of range 1-{infer_steps} in '{raw}'")
            steps = [step]
        # apply include/exclude
        if target == "include":
            for step in steps:
                guidance_scale_dict[step] = weight
        else:
            excluded.update(steps)
    for step in excluded:
        guidance_scale_dict.pop(step, None)
    return guidance_scale_dict
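
# A worked example of the schedule syntax (values are illustrative):
#
#     parse_scheduled_cfg("1-4,!2,8:6.0", infer_steps=10, guidance_scale=5.0)
#     # -> {1: 5.0, 3: 5.0, 4: 5.0, 8: 6.0}  (step 2 removed by "!2", step 8 overridden to 6.0)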


def setup_compute_context(
    device: Optional[Union[torch.device, str]] = None,
    dtype: Optional[Union[torch.dtype, str]] = None
) -> Tuple[torch.device, torch.dtype]:
    dtype_mapping = {
        "fp16": torch.float16,
        "float16": torch.float16,
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp32": torch.float32,
        "float32": torch.float32,
        "fp8": torch.float8_e4m3fn,
        "float8": torch.float8_e4m3fn
    }
    if device is None:
        # Prefer CUDA, then MPS, falling back to CPU
        device = torch.device("cpu")
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
    elif isinstance(device, str):
        device = torch.device(device)
    if dtype is None:
        dtype = torch.float32
    elif isinstance(dtype, str):
        if dtype not in dtype_mapping:
            raise ValueError(f"Unknown dtype string '{dtype}'")
        dtype = dtype_mapping[dtype]
    torch.set_float32_matmul_precision('high')
    if dtype in (torch.float16, torch.bfloat16):
        # Only available in newer PyTorch builds, hence the hasattr guard
        if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
            torch.backends.cuda.matmul.allow_fp16_accumulation = True
            print("FP16 accumulation enabled.")
    return device, dtype
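
# A minimal usage sketch: pick the best available device with a bf16 compute dtype
# ("model" is a hypothetical nn.Module):
#
#     device, dtype = setup_compute_context(dtype="bf16")
#     model = model.to(device=device, dtype=dtype)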


def string_to_seed(s: str, bits: int = 63) -> int:
    """
    Turn any string into a reproducible integer in [0, 2**bits) by combining a SHA-256
    hash with a simple position-dependent accumulator.

    Args:
        s: Input string
        bits: Number of bits for the final seed (PyTorch accepts up to 63 safely, numpy likes 32)

    Returns:
        A non-negative int < 2**bits
    """
    digest = hashlib.sha256(s.encode("utf-8")).digest()
    crypto = int.from_bytes(digest, byteorder="big")
    mask = (1 << bits) - 1
    algo = 0
    for i, char in enumerate(s):
        char_val = ord(char)
        if i % 2 == 0:
            algo *= char_val
        elif i % 3 == 0:
            algo -= char_val
        elif i % 5 == 0:
            algo /= char_val
        else:
            algo += char_val
    seed = abs(crypto - int(algo)) & mask
    return seed
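
# A minimal usage sketch: derive a reproducible generator seed from a prompt string:
#
#     seed = string_to_seed("a cat wearing a top hat")
#     generator = torch.Generator().manual_seed(seed)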


def error_out(error: type, message: str) -> None:
    """Log `message` as a warning, then raise it as an exception of the given `error` type."""
    logger = BlissfulLogger(__name__, "#8e00ed")
    logger.warning(message, levelmod=1)
    raise error(message)
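
# A minimal usage sketch (the condition and message are hypothetical):
#
#     if infer_steps < 1:
#         error_out(ValueError, "infer_steps must be >= 1!")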