# Source: Framepack-H111 / blissful_tuner/model_utility.py
# Uploaded by rahul7star ("Upload 303 files"), commit e0336bc (verified)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Model inspection and conversion utility for Blissful Tuner Extension
License: Apache 2.0
Created on Wed Apr 23 10:19:19 2025
@author: blyss
"""
import os
import argparse
import torch
import safetensors
from safetensors.torch import save_file
from tqdm import tqdm
# Command-line interface. NOTE: arguments are parsed at module level, so this
# file is meant to be executed as a script rather than imported.
parser = argparse.ArgumentParser(
    description="Convert any model checkpoint (single file or shard directory) to safetensors with dtype cast."
)
parser.add_argument(
    "--input",
    required=True,
    help="Checkpoint file or directory of shards to convert/inspect",
)
parser.add_argument(
    "--convert",
    type=str,
    default=None,
    help=(
        "Output path for the converted weights; a .pt/.pth extension is "
        "replaced with .safetensors. Omit to only inspect/count dtypes."
    ),
)
parser.add_argument(
    "--inspect",
    action="store_true",
    help="Print the key, shape and dtype of every processed tensor",
)
parser.add_argument(
    "--key_target",
    type=str,
    help="Only process keys that contain this substring",
)
parser.add_argument(
    "--weights_only",
    type=str,
    default="true",
    help=(
        'Use "true" (default) to load non-safetensors files with '
        'torch.load(weights_only=True), or "false" to allow full unpickling '
        "(trusted files only)"
    ),
)
parser.add_argument(
    "--dtype",
    type=str,
    help=(
        "Target dtype for conversion (case-insensitive): fp16/float16, "
        "bf16/bfloat16, fp32/float32. Defaults to keeping each tensor's dtype."
    ),
)
args = parser.parse_args()
def load_torch_file(ckpt, weights_only=True, device=None, return_metadata=False):
    """
    Load a state dict from a single checkpoint file or a directory of shards.

    Args:
        ckpt: Path to a checkpoint file, or to a directory. For a directory,
            every regular file ending in .safetensors/.sft/.pt/.pth is loaded
            in sorted filename order and merged into one dict (on duplicate
            keys the later file wins).
        weights_only: Forwarded to ``torch.load`` for non-safetensors files.
        device: Target device as a ``torch.device`` or anything
            ``torch.device()`` accepts (e.g. ``"cpu"``). Defaults to CPU.
        return_metadata: If True, return ``(state_dict, metadata)`` instead of
            just the state dict. Metadata is only read from a single
            safetensors file; it is None for directories and torch files.

    Returns:
        The state dict, or a ``(state_dict, metadata)`` tuple when
        ``return_metadata`` is True. For torch files, a nested
        ``"state_dict"`` entry is unwrapped when present.

    Raises:
        ValueError: If a safetensors file cannot be opened or read.
    """
    # Normalize so that strings like "cpu" work in both loading paths.
    device = torch.device("cpu") if device is None else torch.device(device)

    # --- shard support ---
    if os.path.isdir(ckpt):
        all_sd = {}
        for fname in sorted(os.listdir(ckpt)):
            path = os.path.join(ckpt, fname)
            # only load regular files with a supported extension
            if not os.path.isfile(path):
                continue
            if not path.lower().endswith((".safetensors", ".sft", ".pt", ".pth")):
                continue
            # load each shard (metadata is ignored for shards)
            all_sd.update(load_torch_file(path, weights_only, device, return_metadata=False))
        return (all_sd, None) if return_metadata else all_sd

    # --- single file ---
    metadata = None
    if ckpt.lower().endswith((".safetensors", ".sft")):
        try:
            # str(device) keeps an explicit index (e.g. "cuda:1"), which
            # device.type alone would drop.
            with safetensors.safe_open(ckpt, framework="pt", device=str(device)) as f:
                sd = {k: f.get_tensor(k) for k in f.keys()}
                if return_metadata:
                    metadata = f.metadata()
        except Exception as e:
            # Chain the original error so the underlying traceback is kept.
            raise ValueError(f"Safetensors load failed: {e}\nFile: {ckpt}") from e
    else:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=weights_only)
        # Some checkpoints nest the weights under a "state_dict" entry.
        sd = pl_sd.get("state_dict", pl_sd)
    return (sd, metadata) if return_metadata else sd
# ---------------------------------------------------------------------------
# Script body: validate options, load the checkpoint, then inspect/convert.
# All option validation and the overwrite prompt happen BEFORE loading so a
# bad flag or a declined overwrite does not cost a full checkpoint load.
# ---------------------------------------------------------------------------
dtype_mapping = {
    "fp16": torch.float16,
    "float16": torch.float16,
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "fp32": torch.float32,
    "float32": torch.float32,
}

# Only an explicit "true"/"false" is accepted: treating every other spelling
# as False would silently turn off torch.load's safe weights-only mode.
weights_only_flag = args.weights_only.strip().lower()
if weights_only_flag not in ("true", "false"):
    raise SystemExit(
        f"Invalid --weights_only value '{args.weights_only}'; expected 'true' or 'false'."
    )
weights_only = weights_only_flag == "true"

# Resolve the requested dtype up front; an unknown name is an error instead
# of being silently ignored.
target_dtype = None
if args.dtype:
    target_dtype = dtype_mapping.get(args.dtype.lower())
    if target_dtype is None:
        raise SystemExit(
            f"Unknown --dtype '{args.dtype}'. Choose from: {', '.join(dtype_mapping)}"
        )

# Work out the real output path first, so the overwrite check looks at the
# file that will actually be written. Only a trailing .pt/.pth extension is
# swapped for .safetensors; the rest of the path is left untouched.
output_file = None
if args.convert:
    root, ext = os.path.splitext(args.convert)
    output_file = root + ".safetensors" if ext.lower() in (".pt", ".pth") else args.convert
    if os.path.exists(output_file):
        confirm = input(f"{output_file} exists. Overwrite? [y/N]: ").strip().lower()
        if confirm != "y":
            print("Aborting.")
            raise SystemExit(0)

print("Loading checkpoint...")
checkpoint = load_torch_file(args.input, weights_only)

converted_state_dict = {}
keys_to_process = (
    [k for k in checkpoint if args.key_target in k] if args.key_target else list(checkpoint)
)
# Maps resulting dtype -> number of tensors (the target dtype when --dtype is
# given, otherwise each tensor's own dtype).
dtypes_in_model = {}
for key in tqdm(keys_to_process, desc="Processing tensors"):
    value = checkpoint[key]
    if not isinstance(value, torch.Tensor):
        # Pickled checkpoints may carry non-tensor entries; these have no
        # dtype and cannot be stored in a safetensors file.
        print(f"{key}: skipping non-tensor entry ({type(value).__name__})")
        continue
    if args.inspect:
        print(f"{key}: {value.shape} ({value.dtype})")
    dtype_to_use = target_dtype if target_dtype is not None else value.dtype
    dtypes_in_model[dtype_to_use] = dtypes_in_model.get(dtype_to_use, 0) + 1
    if output_file is not None:
        # save_file rejects non-contiguous tensors, so make them contiguous.
        converted_state_dict[key] = value.to(dtype_to_use).contiguous()
print(f"Dtypes in model: {dtypes_in_model}")
if output_file is not None:
    print(f"Saving converted tensors to '{output_file}'...")
    save_file(converted_state_dict, output_file)