#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Model inspection and conversion utility for Blissful Tuner Extension
License: Apache 2.0
Created on Wed Apr 23 10:19:19 2025
@author: blyss
"""
import os
import argparse
import torch
import safetensors
from safetensors.torch import save_file
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Command-line interface. Every option carries a help string so that
# `--help` fully documents the tool.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Convert any model checkpoint (single file or shard directory) to safetensors with dtype cast."
)
parser.add_argument(
    "--input",
    required=True,
    help="Checkpoint file or directory of shards to convert/inspect"
)
parser.add_argument(
    "--convert",
    type=str,
    default=None,
    help="Output path for the converted safetensors file; omit to only inspect/tally",
)
parser.add_argument(
    "--inspect",
    action="store_true",
    help="Print each tensor's name, shape, and dtype while processing",
)
parser.add_argument(
    "--key_target",
    type=str,
    help="Only process state-dict keys containing this substring",
)
parser.add_argument(
    "--weights_only",
    type=str,
    default="true",
    help="Forwarded to torch.load for .pt/.pth inputs ('true'/'false'); 'true' is the safe default",
)
parser.add_argument(
    "--dtype",
    type=str,
    help="Cast tensors to this dtype (fp16/bf16/fp32); unrecognized values keep each tensor's original dtype",
)
args = parser.parse_args()
def load_torch_file(ckpt, weights_only=True, device=None, return_metadata=False):
    """
    Load a single checkpoint file or merge all shards in a directory.

    Args:
        ckpt: Path to a checkpoint file, or a directory containing shard
            files (.safetensors, .sft, .pt, .pth).
        weights_only: Forwarded to torch.load for .pt/.pth files (safer
            unpickling when True). Ignored for safetensors files.
        device: torch.device to map tensors onto; defaults to CPU.
        return_metadata: If True, also return the safetensors metadata dict
            (None for .pt/.pth files and for shard directories).

    Returns:
        The state_dict, or (state_dict, metadata) when return_metadata is True.

    Raises:
        ValueError: If a safetensors file cannot be read.
    """
    if device is None:
        device = torch.device("cpu")

    # --- shard support: merge every supported file in the directory ---
    if os.path.isdir(ckpt):
        all_sd = {}
        for fname in sorted(os.listdir(ckpt)):
            path = os.path.join(ckpt, fname)
            # Only load regular files with supported extensions.
            if not os.path.isfile(path):
                continue
            if not path.lower().endswith((".safetensors", ".sft", ".pt", ".pth")):
                continue
            # Metadata is ignored for shards; later shards win on key clashes.
            shard_sd = load_torch_file(path, weights_only, device, return_metadata=False)
            all_sd.update(shard_sd)
        return (all_sd, None) if return_metadata else all_sd

    # --- single file ---
    metadata = None
    if ckpt.lower().endswith((".safetensors", ".sft")):
        try:
            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
                sd = {k: f.get_tensor(k) for k in f.keys()}
                metadata = f.metadata() if return_metadata else None
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(f"Safetensors load failed: {e}\nFile: {ckpt}") from e
    else:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=weights_only)
        # Lightning-style checkpoints nest tensors under "state_dict"; guard
        # with isinstance so non-dict checkpoints don't crash on .get().
        if isinstance(pl_sd, dict) and "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd
    return (sd, metadata) if return_metadata else sd
print("Loading checkpoint...")
# Interpret --weights_only leniently: any of true/1/yes (case-insensitive,
# whitespace-tolerant) enables safe unpickling; everything else disables it.
# "true" (the default) behaves exactly as before.
weights_only = args.weights_only.strip().lower() in ("true", "1", "yes")
checkpoint = load_torch_file(args.input, weights_only)
# Map user-facing dtype aliases onto their torch dtype objects; each dtype
# accepts both its short and long spelling.
dtype_mapping = {
    alias: torch_dtype
    for torch_dtype, aliases in (
        (torch.float16, ("fp16", "float16")),
        (torch.bfloat16, ("bf16", "bfloat16")),
        (torch.float32, ("fp32", "float32")),
    )
    for alias in aliases
}
# ---------------------------------------------------------------------------
# Overwrite guard, tensor inspection/tally, optional dtype cast, and save.
# ---------------------------------------------------------------------------
if args.convert is not None and os.path.exists(args.convert):
    confirm = input(f"{args.convert} exists. Overwrite? [y/N]: ").strip().lower()
    if confirm != "y":
        print("Aborting.")
        exit()
converted_state_dict = {}
# Restrict to keys containing --key_target when given; otherwise all keys.
keys_to_process = (
    [k for k in checkpoint if args.key_target in k] if args.key_target else checkpoint.keys()
)
# Resolve the requested dtype once, outside the loop (it is loop-invariant).
# None means "keep each tensor's own dtype" — including when --dtype is an
# unrecognized string, preserving the original silent-fallback behavior.
requested_dtype = dtype_mapping.get(args.dtype.lower()) if args.dtype else None
dtypes_in_model = {}
for key in tqdm(keys_to_process, desc="Processing tensors"):
    value = checkpoint[key]
    if args.inspect:
        print(f"{key}: {value.shape} ({value.dtype})")
    dtype_to_use = requested_dtype if requested_dtype is not None else value.dtype
    # Tally how many tensors end up in each dtype.
    dtypes_in_model[dtype_to_use] = dtypes_in_model.get(dtype_to_use, 0) + 1
    if args.convert:
        converted_state_dict[key] = value.to(dtype_to_use)
print(f"Dtypes in model: {dtypes_in_model}")
if args.convert:
    # Swap the extension for .safetensors. str.replace on the whole path
    # could corrupt names like "my.pth.dir/model.bin", and save_file always
    # writes safetensors format regardless of name, so split the extension.
    output_file = os.path.splitext(args.convert)[0] + ".safetensors"
    print(f"Saving converted tensors to '{output_file}'...")
    save_file(converted_state_dict, output_file)