#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Model inspection and conversion utility for Blissful Tuner Extension

License: Apache 2.0
Created on Wed Apr 23 10:19:19 2025
@author: blyss
"""
import os
import sys
import argparse
import torch
import safetensors
from safetensors.torch import save_file
from tqdm import tqdm

parser = argparse.ArgumentParser(
    description="Convert any model checkpoint (single file or shard directory) to safetensors with dtype cast."
)
parser.add_argument(
    "--input",
    required=True,
    help="Checkpoint file or directory of shards to convert/inspect"
)
parser.add_argument("--convert", type=str, default=None)
parser.add_argument("--inspect", action="store_true")
parser.add_argument("--key_target", type=str)
parser.add_argument("--weights_only", type=str, default="true")
parser.add_argument("--dtype", type=str)
args = parser.parse_args()
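
# Example invocations (hypothetical file names; the script's own filename is an
# assumption, substitute the real path in this repo):
#   python convert.py --input model.pt --inspect
#   python convert.py --input shards/ --convert merged.safetensors --dtype bf16
#   python convert.py --input model.safetensors --inspect --key_target attn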


def load_torch_file(ckpt, weights_only=True, device=None, return_metadata=False):
    """
    Load a single checkpoint file or all shards in a directory.
    - If `ckpt` is a dir, iterates over supported files, loads each, and merges.
    - Returns state_dict (and metadata if return_metadata=True and single file).
    """
    if device is None:
        device = torch.device("cpu")

    # --- shard support ---
    if os.path.isdir(ckpt):
        all_sd = {}
        for fname in sorted(os.listdir(ckpt)):
            path = os.path.join(ckpt, fname)
            # only load supported extensions
            if not os.path.isfile(path):
                continue
            if not path.lower().endswith((".safetensors", ".sft", ".pt", ".pth")):
                continue
            # load each shard (we ignore metadata for shards)
            shard_sd = load_torch_file(path, weights_only, device, return_metadata=False)
            all_sd.update(shard_sd)
        return (all_sd, None) if return_metadata else all_sd

    # --- single file ---
    metadata = None
    if ckpt.lower().endswith((".safetensors", ".sft")):
        try:
            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
                sd = {k: f.get_tensor(k) for k in f.keys()}
                metadata = f.metadata() if return_metadata else None
        except Exception as e:
            raise ValueError(f"Safetensors load failed: {e}\nFile: {ckpt}")
    else:
        pl_sd = torch.load(ckpt, map_location=device, weights_only=weights_only)
        sd = pl_sd.get("state_dict", pl_sd)

    return (sd, metadata) if return_metadata else sd
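
# Usage sketch for load_torch_file (hypothetical paths): a single file and a
# shard directory both yield a flat {name: tensor} state_dict; metadata is
# only meaningful for single safetensors files.
#   sd = load_torch_file("model.safetensors")
#   sd, meta = load_torch_file("model.safetensors", return_metadata=True)
#   merged = load_torch_file("shards/")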


print("Loading checkpoint...")
weights_only = args.weights_only.lower() == "true"
checkpoint = load_torch_file(args.input, weights_only)

dtype_mapping = {
    "fp16": torch.float16,
    "float16": torch.float16,
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "fp32": torch.float32,
    "float32": torch.float32,
}

output_file = None
if args.convert is not None:
    # Normalize the output name to a .safetensors extension up front, so the
    # overwrite prompt and the final save refer to the same path.
    root, _ = os.path.splitext(args.convert)
    output_file = root + ".safetensors"
    if os.path.exists(output_file):
        confirm = input(f"{output_file} exists. Overwrite? [y/N]: ").strip().lower()
        if confirm != "y":
            print("Aborting.")
            sys.exit(0)

converted_state_dict = {}
keys_to_process = (
    [k for k in checkpoint if args.key_target in k] if args.key_target else checkpoint.keys()
)
dtypes_in_model = {}
for key in tqdm(keys_to_process, desc="Processing tensors"):
    value = checkpoint[key]
    if args.inspect:
        # tqdm.write keeps per-tensor output from clobbering the progress bar
        tqdm.write(f"{key}: {value.shape} ({value.dtype})")
    # Unrecognized --dtype strings fall back to the tensor's original dtype.
    dtype_to_use = (
        dtype_mapping.get(args.dtype.lower(), value.dtype)
        if args.dtype
        else value.dtype
    )
    dtypes_in_model[dtype_to_use] = dtypes_in_model.get(dtype_to_use, 0) + 1
    if args.convert:
        converted_state_dict[key] = value.to(dtype_to_use)


print(f"Dtypes in model: {dtypes_in_model}")
if args.convert:
    print(f"Saving converted tensors to '{output_file}'...")
    save_file(converted_state_dict, output_file)